Natural Gradient - UCL Computer Science

Natural Gradient
Daniel Worrall
I. PROBLEM SETTING
A. A generalisation
Say I have an objective: to minimise a scalar-valued function f. I can parametrise this function with respect to some coordinate system Θ as f(θ) and then find a solution to
$$\theta^* = \arg\min_{\theta}\,[f(\theta)]. \tag{1}$$
Alternatively, I could find another parameterisation of f with respect to a different basis Φ as f(φ) and perform
$$\phi^* = \arg\min_{\phi}\,[f(\phi)]. \tag{2}$$
For argument’s sake, I’m going to restrict my analysis to the scenario where we can define a continuous, invertible function with a continuous inverse between parameter spaces, $T: \Theta \to \Phi$, so in topology speak we say that Θ and Φ are topologically equivalent. I’m also going to focus on smooth, well-behaved f only.
Now, say we cannot find an exact solution to our optimisation, so we perform an iterative technique to find the minimiser. For either basis, if we initialise at the same point $\phi^0 = T(\theta^0)$ and then run the algorithm, we will return a trajectory of points $\{\theta_0^\infty\} = \{\theta^0, \theta^1, \ldots, \theta^\infty\}$ and $\{\phi_0^\infty\} = \{\phi^0, \phi^1, \ldots, \phi^\infty\}$.
We have already established that our start points are the same, $\phi^0 = T(\theta^0)$; we might also wish that the end points of the algorithm in either basis are the same, $\phi^\infty = T(\theta^\infty)$. This is necessarily true if we are performing true global optimisation, but generally we focus on local problems. Now, there is a school of thought called the principle of covariance¹, which says that ‘a consistent algorithm should give the same results independent of the units in which quantities are measured’, i.e. every point of either trajectory should be the same:
$$\{\phi\} = T(\{\theta\}). \tag{3}$$
So we are aiming to develop a basis-independent algorithm to perform our optimisation. This is an attractive idea, because it offers a level of robustness with respect to how we decide to represent our data and, as we shall see, it is in a sense optimal, in that we are working with the natural parametrisation of the problem.
¹ I have taken the naming convention from MacKay, who took the original idea from Knuth. I believe the name comes from the fact that we are dealing with covariant gradients.
Let’s define a new basis N. Suppose we estimate $\nu^*$ with successive samples of an iterative scheme where
$$\nu^{k+1} \leftarrow \nu^k + g(\nu_0^k, f) \tag{4}$$
where $\nu^k$ is our estimate of $\nu^*$ at iteration k, $\nu_0^k = \{\nu^0, \nu^1, \ldots, \nu^k\}$ is the current partial trajectory of samples and g is some function we are yet to define. If we want a reasonable solution, then we wish for $\nu^k$ to be bounded, so
$$\|\nu^k\| = \Big\|\nu^0 + \sum_{i=0}^{k-1} g(\nu_0^i, f)\Big\| < \infty. \tag{5}$$
Starting from a finite point and running the algorithm to infinity, this implies that the accumulated updates must stay finite, which is guaranteed if they converge absolutely:
$$\Big\|\sum_{i=0}^{\infty} g(\nu_0^i, f)\Big\| \le \sum_{i=0}^{\infty} \|g(\nu_0^i, f)\| < \infty. \tag{6}$$
There are many different trajectories satisfying this constraint, so we also note that eventually we wish for the algorithm to stop, so $\lim_{k\to\infty} \|g(\nu_0^k, f)\| = 0$. Note, however, that this condition alone isn’t enough: steps of size 1/k vanish, yet their sum diverges.
B. Local minimisation
Let’s consider the simpler problem of descending on a unique local minimum. We aim to satisfy the sufficient conditions
$$\nabla_\nu f(\nu) = \mathbf{0} \tag{7}$$
$$\mathbf{x}^\top \nabla^2_\nu f(\nu)\,\mathbf{x} > 0 \quad \text{for all } \mathbf{x} \neq \mathbf{0}. \tag{8}$$
1) Steepest descent: The simplest descent scheme is steepest descent (SD), which seeks to satisfy (7) only. There is therefore no guarantee of even descending upon a minimum, only upon a stationary point; nonetheless, SD is widely used due to its ease of implementation. The updates take the form
$$\nu^{k+1} \leftarrow \nu^k + \alpha \nabla_\nu f(\nu^k) \tag{9}$$
where α is small and negative. We thus see that this is similar to $\lim_{k\to\infty}\|g(\nu_0^k, f)\| = 0$, in that if we are slowly converging on our minimum (by correct choice of α) then we expect $\lim_{k\to\infty}\|\nabla f(\nu^k)\| = 0$, so in this case
$$g(\nu_0^k, f) = g(\nu^k, f) = \alpha \nabla_\nu f(\nu^k). \tag{10}$$
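Below is a minimal sketch, in Python with NumPy, of steepest descent as an instance of the scheme (4), with $g(\nu^k, f) = \alpha\nabla_\nu f(\nu^k)$ as in (10). The quadratic objective, the matrix `H`, the minimiser `c`, the step size and the iteration count are hypothetical choices made for illustration. Following the sign convention of (9), `alpha` is negative.

```python
import numpy as np

def steepest_descent(grad_f, nu0, alpha=-0.1, iters=200):
    """Iterate nu^{k+1} = nu^k + alpha * grad_f(nu^k), returning every iterate.

    alpha is negative, following the sign convention of (9).
    """
    trajectory = [np.asarray(nu0, dtype=float)]
    for _ in range(iters):
        nu = trajectory[-1]
        # First-order Markov: the update uses the current iterate only.
        trajectory.append(nu + alpha * grad_f(nu))
    return np.array(trajectory)

# Hypothetical objective f(nu) = 0.5 * (nu - c)^T H (nu - c), minimised at c.
H = np.array([[3.0, 0.5], [0.5, 1.0]])
c = np.array([1.0, -2.0])

def grad_f(nu):
    return H @ (nu - c)

traj = steepest_descent(grad_f, nu0=[0.0, 0.0])
print(traj[-1])                          # close to c
print(np.linalg.norm(grad_f(traj[-1])))  # close to 0, i.e. condition (7)
```

Here `H` is positive definite, so the stationary point SD finds is also the minimum. For a general f, SD alone only targets (7), not (8).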
SD is first-order Markov², in that the next move depends on the current state only.
² My naming convention.
The problem with SD is that it is basis-dependent. To see this explicitly we use Θ and Φ again as our basis pair and define the function f expressed in different bases as
$$f_\theta(\theta) = f_\theta\big(T^{-1}(\phi)\big) = f_\phi(\phi) = f_\phi\big(T(\theta)\big). \tag{11}$$
The steepest descent update in Φ is
$$\phi^{k+1} = \phi^k + \alpha \nabla_\phi f_\phi(\phi^k) \tag{12}$$
and the same update in Θ is
$$\theta^{k+1} = \theta^k + \alpha \nabla_\theta f_\theta(\theta^k). \tag{13}$$
Given $T(\theta^k) = \phi^k$, in order for $\{\phi\} = T(\{\theta\})$ we require $T(\theta^{k+1}) = \phi^{k+1}$ after the update. Is this the case? Immediately we see that the algorithm can only be globally invariant under linear transformations, because what we are really asking is to evaluate whether
$$T\big(\nabla_\theta f_\theta(\theta^k)\big) = \nabla_\phi f_\phi(\phi^k). \tag{14}$$
In reality, though, we are only ever concerned with small volumes of parameter space, which we can approximate as linear. To make reading easier, we adopt the notation $\mathbf{g} = \nabla_\theta f_\theta(\theta^k)$ and $\mathbf{g}' = \nabla_\phi f_\phi(\phi^k)$, so that, by the chain rule,
$$\mathbf{g} = [\nabla_\theta T(\theta^k)]\,\mathbf{g}'. \tag{15}$$
Locally, T carries a step in Θ into Φ through the Jacobian $[\nabla_\theta T(\theta^k)]^\top$. The mapped SD step is therefore $\alpha[\nabla_\theta T(\theta^k)]^\top[\nabla_\theta T(\theta^k)]\,\mathbf{g}'$, rather than the $\alpha\mathbf{g}'$ we were hoping for. The only transformations which leave the SD algorithm invariant are those whose Jacobian is orthonormal, i.e. (locally) volume-preserving rotations³. This is a fairly poor set of transformations. Surely we can do better than this!
³ We can also add reflections, but we want to impose the positive definiteness conditions, which aren’t strictly part of SD.
The problem is that we have not actually chosen the direction of greatest reduction in f. Note that this is different from the naive steepest descent direction ∇f!
C. Another point of view
The partial gradient direction ∇f paradoxically does not change f the most. How can this be? The problem arises from the fact that we are assuming that we are computing everything in Euclidean space. In fact, we need to consider the more general Riemannian space.
In Euclidean space the distance between two points x and y is computed as the square root of the squared length of their difference, i.e.
$$d(\mathbf{x}, \mathbf{y}) = \sqrt{(\mathbf{x} - \mathbf{y})^\top (\mathbf{x} - \mathbf{y})}. \tag{16}$$
If we consider differentials, then this becomes
$$|d\mathbf{w}|^2 = \sum_{i=1}^{n} (dw_i)^2. \tag{17}$$
If the coordinate system is non-orthonormal, however, then the squared length is given locally by
$$|d\mathbf{w}|^2 = \sum_{i,j}^{n} g_{ij}(\mathbf{w})\, dw_i\, dw_j. \tag{18}$$
Evaluated at a particular point in space with respect to a given basis, the $g_{ij}$ form a matrix G called the Riemannian metric tensor. When the parameter space is a curved manifold, we have to resort to using this kind of local approximation at each point in space; this is Riemannian space.
In the usual Euclidean setting $g_{ij} = \delta_{ij}$, the Kronecker delta, so the squared length reduces to the usual dot-product form. This is the source of much of our confusion, and it is why only transformations that preserve volume and angle leave naive SD invariant.
So what is the steepest descent direction, factoring in the Riemannian metric? This can be found with a simple optimisation over a small step $d\nu = \epsilon\mathbf{a}$ of fixed length $\epsilon > 0$:
$$\arg\min_{\mathbf{a}\,:\,d\nu = \epsilon\mathbf{a}} \; f(\nu + d\nu) - f(\nu) \quad \text{subject to} \quad |\mathbf{a}|^2 = \mathbf{a}^\top G\, \mathbf{a} = 1. \tag{19}$$
This can be solved simply using Lagrange multipliers to yield the natural gradient
$$\tilde{\nabla} f(\nu) = G^{-1} \nabla f(\nu). \tag{20}$$
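Here is a sketch of the Lagrange multiplier step behind (20). It assumes the step is written as $d\nu = \epsilon\mathbf{a}$ for a small $\epsilon > 0$, that f is linearised to first order, and that G is held fixed at ν:

```latex
\begin{aligned}
f(\nu + \epsilon\mathbf{a}) - f(\nu) &\approx \epsilon\, \nabla f(\nu)^\top \mathbf{a}, \\
\mathcal{L}(\mathbf{a}, \lambda) &= \epsilon\, \nabla f(\nu)^\top \mathbf{a} - \lambda\left(\mathbf{a}^\top G\, \mathbf{a} - 1\right), \\
\frac{\partial \mathcal{L}}{\partial \mathbf{a}} = \epsilon\, \nabla f(\nu) - 2\lambda G\, \mathbf{a} = \mathbf{0}
  &\;\Rightarrow\; \mathbf{a} = \frac{\epsilon}{2\lambda}\, G^{-1} \nabla f(\nu), \\
\mathbf{a}^\top G\, \mathbf{a} = 1
  &\;\Rightarrow\; \mathbf{a} = \pm \frac{G^{-1}\nabla f(\nu)}{\sqrt{\nabla f(\nu)^\top G^{-1} \nabla f(\nu)}}.
\end{aligned}
```

The minimising choice is the negative sign, because $G^{-1}$ is positive definite and so this sign makes the first-order change $\epsilon\nabla f^\top\mathbf{a}$ negative. The steepest descent direction is therefore proportional to $-G^{-1}\nabla f(\nu)$. With the negative α used in (9), this corresponds to the update $\nu + \alpha G^{-1}\nabla f(\nu)$.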
The natural gradient changes the nature of the partial gradient
∇f (ν) such that it transforms in a different way when we change
basis. Anyone with a background in differential geometry will
recognise immediately that we are simply converting a covariant
gradient into a contravariant gradient by index raising using the
metric tensor.
D. How the natural gradient links with covariant optimisation
The natural gradient steepest descent method is a covariant optimisation on a local level, because (differentiable, invertible) transformations of the parameter space leave each step invariant to first order in α. Globally, we cannot define a mapping under which the partial gradient is completely invariant, but do we really need to? In reality, we are only going to concern ourselves with small patches of the parameter space, because we want to keep our step sizes α small. We treat the parameter space as a Riemannian manifold, on which each small patch is approximately flat and distances are measured locally by the metric tensor. So whilst we view transformations locally as linear mappings, we also need to remember to re-evaluate what we mean by distance within each locality to maintain a coherent view of the world.
The natural gradient is the gradient of the function with
respect to the parameters, with this redefinition of distance
in mind.
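The claim that natural gradient descent is covariant only locally can be checked numerically. The sketch below uses assumed choices:

- a hypothetical nonlinear reparametrisation $T(\theta) = \exp(\theta)$, applied elementwise;
- a hypothetical objective $f_\phi(\phi) = \tfrac12\|\phi - c\|^2$;
- the Euclidean metric in Φ;
- in Θ, the pulled-back metric $J^\top J$, where J is the Jacobian of T.

It takes one step from matching start points and measures $\|T(\theta^{k+1}) - \phi^{k+1}\|$ as α is halved.

```python
import numpy as np

c = np.array([2.0, 0.5])  # hypothetical target: f_phi(phi) = 0.5 * ||phi - c||^2

def grad_f_phi(phi):
    return phi - c

def T(theta):
    return np.exp(theta)  # hypothetical nonlinear, invertible reparametrisation

def jacobian_T(theta):
    return np.diag(np.exp(theta))  # J = d(phi) / d(theta)

def one_step_mismatch(theta0, alpha, natural):
    phi0 = T(theta0)
    J = jacobian_T(theta0)
    g_prime = grad_f_phi(phi0)       # gradient in Phi
    g = J.T @ g_prime                # chain rule: gradient in Theta
    if natural:
        G_theta = J.T @ J            # metric in Theta pulled back from the identity in Phi
        step_theta = np.linalg.solve(G_theta, g)
    else:
        step_theta = g               # naive SD ignores the metric
    theta1 = theta0 + alpha * step_theta
    # With the identity metric in Phi, SD and the natural gradient coincide there.
    phi1 = phi0 + alpha * g_prime
    return np.linalg.norm(T(theta1) - phi1)

theta0 = np.array([0.3, -0.4])
ratios = {}
for natural in (False, True):
    errs = [one_step_mismatch(theta0, a, natural) for a in (-0.1, -0.05, -0.025)]
    ratios["natural" if natural else "sd"] = [errs[i] / errs[i + 1] for i in range(2)]
print(ratios)  # SD ratios near 2 (mismatch ~ alpha); natural ratios near 4 (mismatch ~ alpha^2)
```

Halving α roughly halves the SD mismatch but roughly quarters the natural gradient mismatch. This is the sense in which the invariance holds only locally: the per-step discrepancy is second order in the step size.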
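Going back to the transformation example, if we were to use natural gradient descent we would get
$$\theta^{k+1} = \theta^k + \alpha G^{-1}\mathbf{g}. \tag{21}$$
This becomes the following under transformation, to first order in α (exactly, if T is linear). We use $A^\top = [\nabla_\theta T(\theta^k)]$ to clean up notation, so that $\mathbf{g} = A^\top\mathbf{g}'$ by the chain rule. We also write G′ for the metric in Φ. Lengths must agree, $d\theta^\top G\,d\theta = d\phi^\top G'\,d\phi$ with $d\phi = A\,d\theta$, so the metric in Θ is $G = A^\top G' A$. Then
$$\begin{aligned}
T(\theta^{k+1}) &= T(\theta^k) + \alpha A\,G^{-1}\mathbf{g} \qquad (22)\\
&= T(\theta^k) + \alpha A\,(A^\top G' A)^{-1} A^\top \mathbf{g}' \qquad (23)\\
&= T(\theta^k) + \alpha A A^{-1} G'^{-1} A^{-\top} A^\top \mathbf{g}' \qquad (24)\\
&= \phi^k + \alpha G'^{-1}\mathbf{g}'. \qquad (25)
\end{aligned}$$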
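The derivation (21) to (25) can be checked numerically. This sketch uses assumed choices:

- a hypothetical linear, non-orthonormal reparametrisation $T(\theta) = A\theta$;
- a hypothetical quadratic objective in Φ;
- an assumed metric `G_phi` in Φ;
- the pulled-back metric `G_theta` $= A^\top$ `G_phi` $A$ in Θ.

Both methods are run in both bases from matching start points. Because T is linear here, the natural gradient trajectories should satisfy $\{\phi\} = T(\{\theta\})$ up to floating-point rounding, while the naive SD trajectories should not.

```python
import numpy as np

A = np.array([[2.0, 1.0], [0.0, 0.5]])      # hypothetical linear T(theta) = A @ theta
G_phi = np.array([[2.0, 0.3], [0.3, 1.0]])  # assumed metric in Phi (symmetric positive definite)
G_theta = A.T @ G_phi @ A                   # metric in Theta pulled back through A
H = np.array([[1.0, 0.2], [0.2, 0.6]])      # hypothetical f_phi(phi) = 0.5 * phi^T H phi

def grad_phi(phi):
    return H @ phi

def grad_theta(theta):
    return A.T @ grad_phi(A @ theta)        # chain rule: g = A^T g'

def run(step, x0, alpha=-0.1, iters=50):
    """Iterate x <- x + alpha * step(x) and return all iterates as rows."""
    xs = [x0]
    for _ in range(iters):
        xs.append(xs[-1] + alpha * step(xs[-1]))
    return np.array(xs)

theta0 = np.array([1.0, -1.0])
phi0 = A @ theta0                           # matching start points: phi^0 = T(theta^0)

# Natural gradient descent in each basis, as in (21) and (25).
ng_theta = run(lambda t: np.linalg.solve(G_theta, grad_theta(t)), theta0)
ng_phi = run(lambda p: np.linalg.solve(G_phi, grad_phi(p)), phi0)

# Naive steepest descent in each basis, as in (12) and (13).
sd_theta = run(grad_theta, theta0)
sd_phi = run(grad_phi, phi0)

# Map every Theta iterate through T (row-wise A @ theta) and compare with Phi.
ng_gap = np.max(np.abs(ng_theta @ A.T - ng_phi))
sd_gap = np.max(np.abs(sd_theta @ A.T - sd_phi))
print(ng_gap, sd_gap)  # ng_gap is at rounding level; sd_gap is not
```

For a nonlinear T the same comparison would only agree to first order in α per step, consistent with the local covariance discussed in section D.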