Variational Methods

Zoubin Ghahramani
[email protected]
http://www.gatsby.ucl.ac.uk

Statistical Approaches to Learning and Discovery
Carnegie Mellon University, April 2003

The Expectation Maximization (EM) algorithm

Given observed/visible variables y, unobserved/hidden/latent/missing variables x, and model parameters θ, maximize the likelihood w.r.t. θ:

L(θ) = log p(y|θ) = log ∫ p(x, y|θ) dx,

where we have written the marginal for the visibles in terms of an integral over the joint distribution for hidden and visible variables.

Using Jensen's inequality, any distribution¹ over hidden variables q(x) gives:

L(θ) = log ∫ q(x) [p(x, y|θ) / q(x)] dx ≥ ∫ q(x) log [p(x, y|θ) / q(x)] dx = F(q, θ),

defining the F(q, θ) functional, which is a lower bound on the log likelihood.

In the EM algorithm, we alternately optimize F(q, θ) w.r.t. q and θ, and we can prove that this will never decrease L(θ).

¹ s.t. q(x) > 0 if p(x, y|θ) > 0.

The E and M steps of EM

The lower bound on the log likelihood:

F(q, θ) = ∫ q(x) log [p(x, y|θ) / q(x)] dx = ∫ q(x) log p(x, y|θ) dx + H(q),

where H(q) = −∫ q(x) log q(x) dx is the entropy of q. We iteratively alternate:

E step: maximize F(q, θ) w.r.t. the distribution over hidden variables given the parameters:

q^(k)(x) := argmax_{q(x)} F(q(x), θ^(k−1)).

M step: maximize F(q, θ) w.r.t. the parameters given the hidden distribution:

θ^(k) := argmax_θ F(q^(k)(x), θ) = argmax_θ ∫ q^(k)(x) log p(x, y|θ) dx,

which is equivalent to optimizing the expected complete-data log likelihood, ⟨log p(x, y|θ)⟩_{q(x)}, since the entropy of q(x) does not depend on θ.
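To make these two steps concrete, here is a minimal sketch (not from the slides) of EM for a two-component, one-dimensional Gaussian mixture, where both the E step and the M step have closed form. The synthetic data, initialization and all names are illustrative.

```python
# Minimal EM sketch for a 2-component 1-D Gaussian mixture (illustrative, not from the slides).
import numpy as np

rng = np.random.default_rng(0)
y = np.concatenate([rng.normal(-2.0, 1.0, 200), rng.normal(3.0, 1.0, 300)])  # synthetic data

# Initial parameters theta = (mixing weights, means, variances)
w = np.array([0.5, 0.5])
mu = np.array([-1.0, 1.0])
var = np.array([1.0, 1.0])

def log_lik(y, w, mu, var):
    # log p(y|theta) = sum_n log sum_k w_k N(y_n; mu_k, var_k)
    comp = -0.5 * (y[:, None] - mu) ** 2 / var - 0.5 * np.log(2 * np.pi * var) + np.log(w)
    return np.logaddexp.reduce(comp, axis=1).sum()

prev = -np.inf
for it in range(100):
    # E step: q(x_n = k) = p(x_n = k | y_n, theta), the posterior responsibilities
    comp = -0.5 * (y[:, None] - mu) ** 2 / var - 0.5 * np.log(2 * np.pi * var) + np.log(w)
    q = np.exp(comp - np.logaddexp.reduce(comp, axis=1, keepdims=True))
    # M step: maximize  sum_n sum_k q(x_n=k) log p(x_n=k, y_n | theta)  w.r.t. theta
    Nk = q.sum(axis=0)
    w = Nk / len(y)
    mu = (q * y[:, None]).sum(axis=0) / Nk
    var = (q * (y[:, None] - mu) ** 2).sum(axis=0) / Nk
    ll = log_lik(y, w, mu, var)
    assert ll >= prev - 1e-9   # L(theta) never decreases (coordinate ascent on F)
    prev = ll

print(w, mu, var, prev)
```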

EM as Coordinate Ascent in F

The EM algorithm never decreases the log likelihood

The difference between the log likelihood and the bound:

L(θ) − F(q, θ) = log p(y|θ) − ∫ q(x) log [p(x, y|θ) / q(x)] dx
             = log p(y|θ) − ∫ q(x) log [p(x|y, θ) p(y|θ) / q(x)] dx
             = −∫ q(x) log [p(x|y, θ) / q(x)] dx = KL(q(x), p(x|y, θ)).

This is the Kullback–Leibler divergence; it is non-negative, and zero if and only if q(x) = p(x|y, θ) (thus this is the E step). Although we are working with a bound on the likelihood, the likelihood is non-decreasing in every iteration:

L(θ^(k−1)) = F(q^(k), θ^(k−1)) ≤ F(q^(k), θ^(k)) ≤ L(θ^(k)),

where the first equality holds because of the E step, the first inequality comes from the M step, and the final inequality from Jensen. Usually EM converges to a local optimum of L (although there are exceptions).
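The identity L(θ) − F(q, θ) = KL(q(x), p(x|y, θ)) is easy to verify numerically. The sketch below (illustrative, with a made-up discrete joint distribution) checks it for an arbitrary q.

```python
# Numerical check of  L - F(q) = KL(q || p(x|y))  for a tiny discrete model (illustrative).
import numpy as np

rng = np.random.default_rng(1)
p_xy = rng.random((4, 3))          # made-up joint p(x, y) over 4 hidden and 3 visible states
p_xy /= p_xy.sum()

y = 1                              # an observed value of the visible variable
p_y = p_xy[:, y].sum()             # p(y) = sum_x p(x, y)
post = p_xy[:, y] / p_y            # exact posterior p(x|y)

q = rng.random(4); q /= q.sum()    # an arbitrary distribution over the hidden variable

L = np.log(p_y)                                  # log likelihood
F = np.sum(q * np.log(p_xy[:, y] / q))           # lower bound F(q) = sum_x q log p(x,y)/q
KL = np.sum(q * np.log(q / post))                # KL(q || p(x|y))

print(L - F, KL)                   # the two quantities agree
assert np.allclose(L - F, KL) and F <= L
```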

A Generative Model for Generative Models

[Figure: a "family tree" of generative models, linked by operations labelled mix (mixture), red-dim (reduced dimension), dyn (dynamics), distrib (distributed representation), hier (hierarchical), nonlin (nonlinear) and switch (switching). The models shown include: Gaussian, Mixture of Gaussians (VQ), Factor Analysis (PCA), Mixture of Factor Analyzers, ICA, Nonlinear Gaussian Belief Nets, SBN and Boltzmann Machines, Cooperative Vector Quantization, HMM, Factorial HMM, Mixture of HMMs, Linear Dynamical Systems (SSMs), Mixture of LDSs, Switching State-space Models, and Nonlinear Dynamical Systems.]

Example: Dynamic Bayesian Networks

Factorial Hidden Markov Models, Switching State-space Models, Nonlinear Dynamical Systems, ...

[Figures: graphical-model diagrams of these DBNs, showing multiple hidden state chains (S and X variables) coupled to a common observation sequence Y, a standard HMM with states S1, S2, S3 and observations Y1, Y2, Y3, and a coupled network over variables At, Bt, Ct, Dt unrolled over time.]



Intractability

For many models of interest, exact inference is not computationally feasible. This occurs for two (main) reasons:
• distributions may have complicated forms (non-linearities in the generative model)
• "explaining away" causes coupling from observations: observing the value of a child induces dependencies amongst its parents

[Figure: a belief network with parents X1, ..., X5 and child Y = X1 + 2 X2 + X3 + X4 + 3 X5.]

We can still work with such models by using approximate inference techniques to estimate the latent variables.

Variational Approximations

Assume your goal is to maximize the likelihood ln p(y|θ). Any distribution q(x) over the hidden variables defines a lower bound on ln p(y|θ):

ln p(y|θ) ≥ Σ_x q(x) ln [p(x, y|θ) / q(x)] = F(q, θ)

Constrain q(x) to be of a particular tractable form (e.g. factorised) and maximise F subject to this constraint.

• E-step: maximise F w.r.t. q with θ fixed, subject to the constraint on q; equivalently, minimize:

ln p(y|θ) − F(q, θ) = Σ_x q(x) ln [q(x) / p(x|y, θ)] = KL(q‖p)

The inference step therefore tries to find the q closest to the exact posterior distribution.

• M-step: maximise F w.r.t. θ with q fixed (related to mean-field approximations).
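As a small illustration of the constrained E-step (not from the slides), the sketch below fits a factorised q(x1)q(x2) to a coupled two-variable discrete posterior by coordinate updates of the form qi(xi) ∝ exp⟨ln p(x, y|θ)⟩ under the other factor; the 3×3 joint distribution is made up.

```python
# Mean-field sketch: fit q(x1)q(x2) to a coupled two-variable discrete posterior (illustrative).
import numpy as np

rng = np.random.default_rng(2)
log_p = np.log(rng.random((3, 3)))      # made-up unnormalized log p(x1, x2, y) for the observed y

q1 = np.full(3, 1.0 / 3)                # factorized approximation q(x1) q(x2)
q2 = np.full(3, 1.0 / 3)

def normalize_exp(a):
    a = np.exp(a - a.max())             # exponentiate and normalize (softmax)
    return a / a.sum()

for sweep in range(50):
    q1 = normalize_exp(log_p @ q2)      # q1(x1) propto exp( sum_{x2} q2(x2) log p(x1, x2, y) )
    q2 = normalize_exp(log_p.T @ q1)    # q2(x2) propto exp( sum_{x1} q1(x1) log p(x1, x2, y) )

post = np.exp(log_p)                    # exact posterior p(x1, x2 | y), for comparison
post /= post.sum()
print(q1, post.sum(axis=1))             # marginals roughly agree; correlations are lost
print(q2, post.sum(axis=0))
```

Each update increases F, so the sweeps converge to a (local) minimum of KL(q‖p) within the factorised family.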

Variational Approximations for Bayesian Learning

Model structure and overfitting: a simple example

[Figure: polynomial fits of order M = 0, 1, ..., 7 to the same small data set (x from 0 to 10, y roughly between −20 and 40), illustrating underfitting at low order and overfitting at high order.]

Learning Model Structure

• Conditional Independence Structure: What is the structure of the graph (i.e. what ⊥⊥ relations hold)? [Figure: a small directed graph over nodes A, B, C, D, E.]

• Feature Selection: Is some input relevant to predicting some output?

• Cardinality of Discrete Latent Variables: How many clusters in the data? How many states in a hidden Markov model? (example: the sequence SVYDAAAQLTADVKKDLRDSWKVIGSDKKGNGVALMTTY)

• Dimensionality of Real-Valued Latent Vectors: What choice of dimensionality in a PCA/FA model of the data? How many state variables in a linear-Gaussian state-space model?

Using Bayesian Occam's Razor to Learn Model Structure

Select the model class, m, with the highest probability given the data, y:

p(m|y) = p(y|m) p(m) / p(y),    where    p(y|m) = ∫ p(y|θ, m) p(θ|m) dθ

Interpretation of the marginal likelihood ("evidence"): the probability that randomly selected parameters from the prior would generate y.

Model classes that are too simple are unlikely to generate the data set. Model classes that are too complex can generate many possible data sets, so again, they are unlikely to generate that particular data set at random.

[Figure: P(Y|Mi) plotted against all possible data sets Y, for a model that is too simple, one that is "just right", and one that is too complex.]

Bayesian Model Selection: Occam's Razor at Work

[Figure: polynomial fits of order M = 0, 1, ..., 7 to the same data set as before, together with a bar chart of the model evidence P(Y|M) for M = 0, ..., 7.]

demo: polybayes
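A hedged sketch of this kind of computation for polynomial regression models of increasing order (in the spirit of the polybayes demo above): assuming a conjugate Gaussian prior on the weights and a known noise variance, the evidence p(y|m) is a Gaussian density that can be evaluated in closed form. The data, basis rescaling and hyperparameter values below are all made up.

```python
# Sketch: Bayesian Occam's razor for polynomial models of different order (illustrative).
# Model m: y = Phi_m(x) w + noise, prior w ~ N(0, (1/alpha) I), known noise variance sigma2,
# so the evidence is p(y|m) = N(y; 0, sigma2 I + Phi_m Phi_m^T / alpha), in closed form.
import numpy as np
from scipy.stats import multivariate_normal

rng = np.random.default_rng(3)
x = np.linspace(0, 10, 20)
y = 2.0 + 1.5 * x - 0.2 * x**2 + rng.normal(0.0, 1.0, x.size)   # data from a quadratic

alpha, sigma2 = 1e-2, 1.0                                       # made-up hyperparameters
t = x / 10.0                                                    # rescaled input keeps the basis well-conditioned

for M in range(8):                                              # model order M = 0, ..., 7
    Phi = np.vander(t, M + 1, increasing=True)                  # basis [1, t, t^2, ..., t^M]
    cov = sigma2 * np.eye(x.size) + Phi @ Phi.T / alpha
    log_evidence = multivariate_normal.logpdf(y, mean=np.zeros(x.size), cov=cov)
    print(f"M={M}: log p(y|m) = {log_evidence:8.1f}")
# The log evidence typically peaks near the order that generated the data,
# rather than growing monotonically with model complexity.
```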

Subtleties of Occam’s Hill

Computing Marginal Likelihoods can be Computationally Intractable

p(y|m) = ∫ p(y|θ, m) p(θ|m) dθ

• This can be a very high dimensional integral.
• The presence of latent variables results in additional dimensions that need to be marginalized out:

p(y|m) = ∫∫ p(y, x|θ, m) p(θ|m) dx dθ

• The likelihood term can be complicated.

Practical Bayesian approaches

• Laplace approximations: appeal to asymptotic normality to make a Gaussian approximation about the posterior mode of the parameters.
• Large sample approximations (e.g. BIC).
• Markov chain Monte Carlo methods (MCMC): converge to the desired distribution in the limit, but many samples are required to ensure accuracy, and it is sometimes hard to assess convergence and reliably compute the marginal likelihood.
• Variational approximations...

Note: other deterministic approximations are also available now: e.g. Bethe/Kikuchi approximations, Expectation Propagation, Tree-based reparameterizations.

Lower Bounding the Marginal Likelihood: Variational Bayesian Learning

Let the latent variables be x, the data y and the parameters θ. We can lower bound the marginal likelihood (by Jensen's inequality):

ln p(y|m) = ln ∫ p(y, x, θ|m) dx dθ
          = ln ∫ q(x, θ) [p(y, x, θ|m) / q(x, θ)] dx dθ
          ≥ ∫ q(x, θ) ln [p(y, x, θ|m) / q(x, θ)] dx dθ.

Use a simpler, factorised approximation q(x, θ) ≈ qx(x) qθ(θ):

ln p(y|m) ≥ ∫ qx(x) qθ(θ) ln [p(y, x, θ|m) / (qx(x) qθ(θ))] dx dθ = Fm(qx(x), qθ(θ), y).

Variational Bayesian Learning . . .

Maximizing this lower bound, Fm, leads to EM-like iterative updates:

qx^(t+1)(x) ∝ exp( ∫ ln p(x, y|θ, m) qθ^(t)(θ) dθ )            E-like step

qθ^(t+1)(θ) ∝ p(θ|m) exp( ∫ ln p(x, y|θ, m) qx^(t+1)(x) dx )   M-like step

Maximizing Fm is equivalent to minimizing the KL-divergence between the approximate posterior, qx(x) qθ(θ), and the true posterior, p(θ, x|y, m):

ln p(y|m) − Fm(qx(x), qθ(θ), y) = ∫ qx(x) qθ(θ) ln [qx(x) qθ(θ) / p(θ, x|y, m)] dx dθ = KL(q‖p)

In the limit as n → ∞, for identifiable models, the variational lower bound approaches Schwarz's (1978) BIC criterion.
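As a deliberately simple instance of these coordinate updates (not from the slides), the sketch below applies the factorised approximation q(μ, τ) ≈ q(μ) q(τ) to a univariate Gaussian with unknown mean μ and precision τ under a Normal-Gamma prior. There are no latent variables x here, only parameters, and the data and hyperparameter values are made up.

```python
# Sketch: factorised VB for a Gaussian with unknown mean and precision (illustrative).
# q(mu, tau) ~ q(mu) q(tau) under a Normal-Gamma prior; each update is the exponentiated
# expected log joint under the other factor, mirroring the E-like / M-like steps above.
import numpy as np

rng = np.random.default_rng(4)
x = rng.normal(5.0, 2.0, 50)              # made-up data
N, xbar = x.size, x.mean()

mu0, lam0, a0, b0 = 0.0, 1e-3, 1e-3, 1e-3 # broad Normal-Gamma prior (assumed values)

E_tau = 1.0                               # initial guess for E_q[tau]
for it in range(100):
    # q(mu) = N(mu_N, 1/lam_N): depends on q(tau) only through E[tau]
    mu_N = (lam0 * mu0 + N * xbar) / (lam0 + N)
    lam_N = (lam0 + N) * E_tau
    # q(tau) = Gamma(a_N, b_N): depends on q(mu) through its mean and variance
    a_N = a0 + (N + 1) / 2
    b_N = b0 + 0.5 * (np.sum((x - mu_N) ** 2) + N / lam_N
                      + lam0 * ((mu_N - mu0) ** 2 + 1.0 / lam_N))
    E_tau = a_N / b_N

print("E[mu] =", mu_N, "  E[1/tau] (variance estimate) =", b_N / (a_N - 1))
```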

Conjugate-Exponential models

Let's focus on conjugate-exponential (CE) models, which satisfy (1) and (2):

Condition (1). The joint probability over variables is in the exponential family:

p(x, y|θ) = f(x, y) g(θ) exp{ φ(θ)ᵀ u(x, y) }

where φ(θ) is the vector of natural parameters and u are the sufficient statistics.

Condition (2). The prior over parameters is conjugate to this joint probability:

p(θ|η, ν) = h(η, ν) g(θ)^η exp{ φ(θ)ᵀ ν }

where η and ν are hyperparameters of the prior.

Conjugate priors are computationally convenient and have an intuitive interpretation:
• η: number of pseudo-observations
• ν: values of pseudo-observations
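For example, a Bernoulli likelihood with a Beta prior is in the CE family, and the conjugate update simply adds the observed sufficient statistics to the prior's pseudo-counts. A minimal sketch with made-up counts:

```python
# Conjugate update sketch: Bernoulli likelihood with a Beta prior (illustrative).
# The prior hyperparameters act as pseudo-observations of heads and tails.
import numpy as np

alpha, beta = 2.0, 2.0                 # assumed prior pseudo-counts
y = np.array([1, 0, 1, 1, 1, 0, 1])    # made-up observed Bernoulli data

# Posterior is Beta(alpha + sum(y), beta + n - sum(y)): just add the sufficient statistics.
alpha_post = alpha + y.sum()
beta_post = beta + len(y) - y.sum()

print("posterior hyperparameters:", alpha_post, beta_post)
print("posterior mean of theta:", alpha_post / (alpha_post + beta_post))
```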

Conjugate-Exponential examples

In the CE family:
• Gaussian mixtures
• factor analysis, probabilistic PCA
• hidden Markov models and factorial HMMs
• linear dynamical systems and switching models
• discrete-variable belief networks

Other as yet undreamt-of models can combine Gaussian, Gamma, Poisson, Dirichlet, Wishart, Multinomial and others.

Not in the CE family:
• Boltzmann machines, MRFs (no conjugacy)
• logistic regression (no conjugacy)
• sigmoid belief networks (not exponential)
• independent components analysis (not exponential)

Note: one can often approximate these models with models in the CE family.

The Variational Bayesian EM algorithm

EM for MAP estimation:
  Goal: maximize p(θ|y, m) w.r.t. θ
  E step: compute qx^(t+1)(x) = p(x|y, θ^(t))
  M step: θ^(t+1) = argmax_θ ∫ qx^(t+1)(x) ln p(x, y, θ) dx

Variational Bayesian EM:
  Goal: lower bound p(y|m)
  VB-E step: compute qx^(t+1)(x) = p(x|y, φ̄^(t))
  VB-M step: qθ^(t+1)(θ) ∝ exp( ∫ qx^(t+1)(x) ln p(x, y, θ) dx )

Properties:
• Reduces to the EM algorithm if qθ(θ) = δ(θ − θ*).
• Fm increases monotonically, and incorporates the model complexity penalty.
• Analytical parameter distributions (but not constrained to be Gaussian).
• The VB-E step has the same complexity as the corresponding E step.
• We can use the junction tree, belief propagation, Kalman filter, etc. algorithms in the VB-E step of VB-EM, but using expected natural parameters, φ̄.

Variational Bayesian EM

The Variational Bayesian EM algorithm has been used to approximate Bayesian learning in a wide range of models, such as:
• probabilistic PCA and factor analysis
• mixtures of Gaussians and mixtures of factor analysers
• hidden Markov models
• state-space models (linear dynamical systems)
• independent components analysis (ICA) and mixtures
• discrete graphical models...

The main advantage is that it can be used to automatically do model selection and does not suffer from overfitting to the same extent as ML methods do.

It is also about as computationally demanding as the usual EM algorithm.

See: www.variational-bayes.org demos: mixture of Gaussians, hidden Markov models


[Figures: a plot of the marginal likelihood estimated by annealed importance sampling (AIS), roughly −2800 to −3700, as a function of the duration of annealing (10^2 to 10^5 samples), and a plot of the rank given to the true structure by AIS, VB and BIC as a function of the number of data points n.]

Comparison to Cheeseman-Stutz and BIC

[Figure: two panels, "True Structure has P>0.01" and "True Structure has P>0.5", plotted against the size of the data set (10^1 to 10^4) for BIC, BICprior, CS, CSnew and VB.]

• Averaging over about 100 samples.
• CS is much better than BIC, and under some measures as good as VB.

Note: BIC and CS require estimates of the effective number of parameters, which can be difficult to compute. We estimate the effective number of parameters using a variant of the procedure described in Geiger, Heckerman and Meek (1996).

Summary and Conclusions

• EM can be interpreted as a lower bound maximization algorithm.
• For many models of interest the E step is intractable.
• For such models an approximate E step can be used in a variational lower bound optimization algorithm.
• This lower bound idea can also be used to do variational Bayesian learning.
• Bayesian learning embodies an automatic Occam's razor via the marginal likelihood.
• This makes it possible to avoid overfitting and select models.

Appendix

Example: A Multiple Cause Model

[Figure: the shapes problem — training data images, the model architecture in which binary latent variables s1, ..., sK are combined through weight matrices W1, W2, W3 to produce the observed image y, and the learned output weight matrices.]

Example: A Multiple Cause Model

Model with binary latent variables si ∈ {0, 1}, a real-valued observed vector y, and parameters θ = {{μi, πi}_{i=1}^K, σ²}. [Graphical model: binary latents s1, ..., sK with common observed child y.]

p(s1, ..., sK|π) = ∏_{i=1}^K p(si|πi) = ∏_{i=1}^K πi^si (1 − πi)^(1−si)

p(y|s1, ..., sK, μ, σ²) = N(Σ_i si μi, σ² I)

EM optimizes a lower bound on the likelihood:

F(q, θ) = ⟨log p(s, y|θ)⟩_{q(s)} − ⟨log q(s)⟩_{q(s)}

where ⟨·⟩_q is the expectation under q. The optimum E step, q(s) = p(s|y, θ), is exponential in K.

Example: A Multiple Cause Model (cont)

F(q, θ) = ⟨log p(s, y|θ)⟩_{q(s)} − ⟨log q(s)⟩_{q(s)}    (1)

log p(s, y|θ) + c = Σ_{i=1}^K [ si log πi + (1 − si) log(1 − πi) ] − D log σ − (1/2σ²) (y − Σ_i si μi)ᵀ (y − Σ_i si μi)

                 = Σ_{i=1}^K [ si log πi + (1 − si) log(1 − πi) ] − D log σ − (1/2σ²) [ yᵀy − 2 Σ_i si μiᵀ y + Σ_i Σ_j si sj μiᵀ μj ]

We therefore need ⟨si⟩ and ⟨si sj⟩ to compute F. These are the expected sufficient statistics of the hidden variables.

Example: A Multiple Cause Model (cont)

Variational approximation:

q(s) = ∏_i qi(si) = ∏_{i=1}^K λi^si (1 − λi)^(1−si)    (2)

Under this approximation we know ⟨si⟩ = λi and ⟨si sj⟩ = λi λj + δij (λi − λi²).

F(λ, θ) = Σ_i [ λi log(πi/λi) + (1 − λi) log((1 − πi)/(1 − λi)) ]
          − D log σ − (1/2σ²) (y − Σ_i λi μi)ᵀ (y − Σ_i λi μi) + C(λ, μ) + c    (3)

where C(λ, μ) = −(1/2σ²) Σ_i (λi − λi²) μiᵀ μi, and c = −(D/2) log(2π) is a constant.

Fixed point equations for multiple cause model

Taking derivatives w.r.t. λi:

∂F/∂λi = log(πi/(1 − πi)) − log(λi/(1 − λi)) + (1/σ²) (y − Σ_{j≠i} λj μj)ᵀ μi − (1/2σ²) μiᵀ μi    (4)

Setting to zero we get fixed point equations:

λi = f( log(πi/(1 − πi)) + (1/σ²) (y − Σ_{j≠i} λj μj)ᵀ μi − (1/2σ²) μiᵀ μi )    (5)

where f(x) = 1/(1 + exp(−x)) is the logistic (sigmoid) function.

Learning algorithm:
E step: run the fixed point equations until convergence of λ for each data point.
M step: re-estimate θ given the λs.
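A minimal sketch of the E step defined by equation (5), run on a single data point drawn from made-up parameters (the values of μ, π, σ² and all variable names below are illustrative):

```python
# Sketch of the mean-field E step, eq. (5): fixed-point updates of lambda (illustrative).
import numpy as np

rng = np.random.default_rng(5)
K, D = 4, 10
mu = rng.normal(0, 1, (K, D))            # made-up feature vectors mu_i
pi = np.full(K, 0.3)                     # made-up prior probabilities pi_i
sigma2 = 0.5                             # made-up noise variance

s_true = np.array([1, 0, 1, 0])          # generate one data point from the model
y = s_true @ mu + rng.normal(0, np.sqrt(sigma2), D)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

lam = np.full(K, 0.5)                    # initialize q(s_i = 1) = lambda_i
for sweep in range(200):
    for i in range(K):
        resid = y - (lam @ mu - lam[i] * mu[i])          # y - sum_{j != i} lambda_j mu_j
        u = (np.log(pi[i] / (1 - pi[i]))
             + resid @ mu[i] / sigma2
             - mu[i] @ mu[i] / (2 * sigma2))
        lam[i] = sigmoid(u)              # eq. (5)

print("true s:", s_true, "  inferred lambda:", np.round(lam, 2))
```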

Structured Variational Approximations

q(s) need not be completely factorized. For example, suppose you can partition s into sets s1 and s2 such that computing the expected sufficient statistics under q(s1) and q(s2) is tractable. Then q(s) = q(s1) q(s2) is tractable.

If you have a graphical model, you may want to factorize q(s) into a product of trees, which are tractable distributions.

[Figure: a coupled dynamic Bayesian network over chains At, Bt, Ct, Dt unrolled over time.]

Scaling the parameter priors

How the parameter priors are scaled determines whether an Occam's hill is present or not.

[Figure: polynomial fits of order 0, 2, 4, 6, 8 and 10 and the corresponding posterior over model order (0–11), shown for unscaled and for scaled parameter priors.]

(Rasmussen & Ghahramani, 2000)

The Cheeseman-Stutz (CS) Approximation

The Cheeseman-Stutz approximation is based on:

p(y|m) = p(z|m) p(y|m)/p(z|m) = p(z|m) · [ ∫ dθ p(θ|m) p(y|θ, m) ] / [ ∫ dθ' p(θ'|m) p(z|θ', m) ]

which is true for any completion of the data: z = {ŝ, y}. We use the BIC approximation for both the top and bottom integrals:

ln p(y|m) ≈ ln p(ŝ, y|m) + ln p(θ̂|m) + ln p(y|θ̂) − (d/2) ln n − ln p(θ̂|m) − ln p(ŝ, y|θ̂) + (d'/2) ln n
          = ln p(ŝ, y|m) + ln p(y|θ̂) − ln p(ŝ, y|θ̂),

This can be corrected for d ≠ d'.

Cheeseman-Stutz: run MAP to get θ̂, complete the data with expectations under θ̂, and compute the CS approximation as above.
