Despite the title, this is my by-no-means exhaustive collection of known ways to decompose the ELBO, focusing on the role of the aggregate KL regularizer, its relation to the aggregate posterior, and the mutual information between the data and the latent variable.
Setup
We assume the data distribution has a density, \(p^*(x)\), and we would like to estimate it by that of a latent variable model \(p(x) := \int p(x|z) p(z) dz\), via maximum likelihood. Everything works the same in the discrete setting by replacing the word "density" with "PMF", and integrals with sums. In reality, the data distribution may not have a probability density, as is likely the case for real-world data such as natural images; then maximum likelihood density estimation runs into severe trouble, as I explain here. However, let's just stick with the standard textbook setting, and suppose the data distribution really does have a density \(p^*(x)\).¹
The ELBO (evidence lower bound) is a tractable lower bound on the model log density \(\log p(x)\):
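\[
\log p(x) \;\ge\; \mathrm{ELBO}(x) := \mathbb{E}_{q(z|x)}[\log p(x|z)] - KL(q(z|x)\| p(z)) = \log p(x) - KL(q(z|x)\| p(z|x)),
\]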
where \(q(z|x)\) is any distribution² over the latent space, and \(p(z|x) \propto p(z,x)\) is the Bayesian posterior of \(z\) given \(x\) under our latent variable model. Note that all the distributions without the * superscript (i.e., everything except the true data density \(p^*(x)\)) are ones we can potentially control and parameterize.
For maximum-likelihood training, ideally we would like to maximize the population log-likelihood \(\mathbb{E}_{p^*(x)}[ \log p(x)]\), or its sample approximation based on a training set. Since \(\log p(x)\) is intractable, we replace it with the ELBO, resulting in the aggregate ELBO objective,
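\[
\mathbb{E}_{p^*(x)}[\mathrm{ELBO}(x)] = \mathbb{E}_{p^*(x)}\Big[\mathbb{E}_{q(z|x)}[\log p(x|z)] - KL(q(z|x)\| p(z))\Big],
\]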
which can be decomposed into an aggregate conditional log-likelihood, minus an aggregate ELBO regularizer:
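\[
\mathbb{E}_{p^*(x)}[\mathrm{ELBO}(x)] = \underbrace{\mathbb{E}_{p^*(x)}\mathbb{E}_{q(z|x)}[\log p(x|z)]}_{\text{aggregate conditional log-likelihood}} \;-\; \underbrace{\mathbb{E}_{p^*(x)}[ KL(q(z|x)\| p(z)) ]}_{\text{aggregate ELBO regularizer}}.
\]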
Of particular interest to me is the aggregate ELBO regularizer. Many of the decompositions of the aggregate ELBO are simply based on rewriting this regularizer term. Moreover, the aggregate ELBO regularizer arises outside the context of latent variable modeling and density estimation, and has a nice connection to information theory.
Decomposing the Aggregate ELBO Regularizer
Besides the standard interpretation of the aggregate ELBO regularizer as penalizing the complexity of the approximate posterior \(q(z|x)\) by pulling it closer to the prior \(p(z)\), here are some other ways to decompose and interpret it:
1. As the cross entropy between the aggregate posterior and prior, plus the negative average entropy of approximate posteriors:
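\[
\mathbb{E}_{p^*(x)}[ KL(q(z|x)\| p(z)) ] = CE(q(z)\| p(z)) - \mathbb{E}_{p^*(x)}[ H(q(z|x)) ].
\]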
The aggregate posterior is defined as \(q(z) = \mathbb{E}_{p^*(x)}[ q(z|x) ]\), and the cross-entropy \(CE(q(z)\| p(z)) = \mathbb{E}_{q(z)} [-\log p(z)]\). This form is easily obtained by writing out \(KL(q(z|x)\| p(z)) = \mathbb{E}_{q(z|x)}[- \log p(z) + \log q(z|x) ]\), taking the expectation w.r.t. \(p^*(x)\), and simplifying.
This form reveals that the optimal prior, in the sense of minimizing the aggregate KL regularizer (and hence the aggregate ELBO), should be equal to the aggregate posterior, i.e., \(p(z) = q(z)\), which is when the cross-entropy term is minimized. See the ELBO surgery paper and the VampPrior paper, which popularized this decomposition.
2. As the KL divergence between the aggregate posterior and prior, plus the mutual information of \((x, z) \sim p^*(x) q(z|x)\):
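\[
\mathbb{E}_{p^*(x)}[ KL(q(z|x)\| p(z)) ] = KL(q(z)\| p(z)) + I_q(x; z),
\]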
where the mutual information is equal to
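\[
I_q(x; z) = \mathbb{E}_{p^*(x) q(z|x)}\left[ \log \frac{q(z|x)}{q(z)} \right] = \mathbb{E}_{p^*(x)}[ KL(q(z|x)\| q(z)) ].
\]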
This form can be derived from the first decomposition ("cross-entropy plus negative posterior entropy"), by adding and subtracting \(\mathbb{E}_{q(z)}[\log q(z)] = \mathbb{E}_{p^*(x)q(z|x)}[\log q(z)]\) from the two terms. As in the first decomposition, the prior \(p(z)\) only appears in one term -- its divergence to the aggregate posterior, which immediately reveals that the optimal \(p(z) = q(z)\). This decomposition also nicely reveals the minimum value of the ELBO regularizer (when \(p(z) = q(z)\)) as the mutual information.
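As a sanity check, here is a small numerical verification of this decomposition in the discrete setting (an illustrative sketch I'm adding; the random distributions and variable names are my own):

```python
import numpy as np

rng = np.random.default_rng(0)
nx, nz = 4, 3
p_star = rng.dirichlet(np.ones(nx))           # data distribution p*(x)
q_zx = rng.dirichlet(np.ones(nz), size=nx)    # approximate posteriors q(z|x), one row per x
p_z = rng.dirichlet(np.ones(nz))              # prior p(z)

kl = lambda a, b: float(np.sum(a * np.log(a / b)))

# aggregate ELBO regularizer: E_{p*(x)}[ KL(q(z|x) || p(z)) ]
agg_reg = sum(p_star[i] * kl(q_zx[i], p_z) for i in range(nx))

q_z = p_star @ q_zx                           # aggregate posterior q(z)
mi = sum(p_star[i] * kl(q_zx[i], q_z) for i in range(nx))  # I_q(x; z)

# decomposition 2: aggregate regularizer = KL(q(z) || p(z)) + I_q(x; z)
print(np.isclose(agg_reg, kl(q_z, p_z) + mi))  # True
```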
Looked at another way, given any joint distribution \(p^*(x) q(z|x)\), the aggregate ELBO regularizer is a variational upper bound on its mutual information, with \(p(z)\) being any arbitrary variational distribution,
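\[
I_q(x; z) \;\le\; \mathbb{E}_{p^*(x)}[ KL(q(z|x)\| p(z)) ], \quad \text{with equality iff } p(z) = q(z).
\]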
This is quite a useful upper bound on mutual information, a quantity that often appears in machine learning but is typically challenging to compute; see Poole et al. 2019. This upper bound also plays a key role in the Blahut-Arimoto algorithm, a well-established algorithm from information theory for computing rate-distortion functions and channel capacity (see Section 10.8 of the Elements of Information Theory textbook).
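For concreteness, here is a minimal Blahut-Arimoto sketch in Python for computing the capacity of a discrete channel (my own illustrative code, not taken from the references above). The variational output marginal \(r(y)\) plays exactly the role of the prior \(p(z)\), and its optimal choice is again the aggregate:

```python
import numpy as np

def blahut_arimoto_capacity(p_y_given_x, n_iter=500):
    """Channel capacity max_{p(x)} I(X;Y) of a discrete memoryless channel.

    p_y_given_x: array of shape (nx, ny); row i is the conditional PMF p(y|x=i).
    Assumes strictly positive channel entries for simplicity.
    Returns the capacity in nats and the capacity-achieving input distribution.
    """
    nx, _ = p_y_given_x.shape
    p_x = np.full(nx, 1.0 / nx)         # initial input distribution
    for _ in range(n_iter):
        r_y = p_x @ p_y_given_x         # variational marginal ("prior"); optimum is the aggregate
        # d[i] = KL(p(y|x=i) || r(y)); at convergence, E_{p(x)}[d] is the capacity
        d = np.sum(p_y_given_x * np.log(p_y_given_x / r_y), axis=1)
        p_x *= np.exp(d)                # multiplicative update p(x) ∝ p(x) e^{d(x)}
        p_x /= p_x.sum()
    r_y = p_x @ p_y_given_x
    d = np.sum(p_y_given_x * np.log(p_y_given_x / r_y), axis=1)
    return float(p_x @ d), p_x

# Example: binary symmetric channel with crossover probability 0.1
bsc = np.array([[0.9, 0.1], [0.1, 0.9]])
cap, p_x = blahut_arimoto_capacity(bsc)
print(cap)  # ≈ log 2 - H_b(0.1) ≈ 0.368 nats
```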
3. As the KL divergence between joint distributions \(p^*(x) q(z|x)\) and \(p^*(x) p(z)\):
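\[
\mathbb{E}_{p^*(x)}[ KL(q(z|x)\| p(z)) ] = KL\big(p^*(x) q(z|x)\,\|\, p^*(x) p(z)\big).
\]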
This is an intuitive restatement of the effect of the aggregate ELBO regularizer: it pulls the approximate posteriors towards the prior, in an "aggregate" sense. Writing the regularizer as a KL divergence between two distributions defined on the product space \(\mathcal{X} \times \mathcal{Z}\) also has some technical advantages, especially in the convergence proof of the Blahut-Arimoto algorithm.
4 (bonus content). This is a decomposition of the whole aggregate ELBO, rather than just the aggregate ELBO regularizer term, but I'm including it given its popularity in the literature:
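\[
\mathbb{E}_{p^*(x)}[\mathrm{ELBO}(x)] = - KL\big(p^*(x) q(z|x)\,\|\, p(z) p(x|z)\big) - H(p^*(x)),
\]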
which is very much analogous to how the (true) maximum likelihood objective decomposes into the negative KL between the data and the model distributions, minus the data entropy:
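\[
\mathbb{E}_{p^*(x)}[\log p(x)] = - KL(p^*(x)\| p(x)) - H(p^*(x)).
\]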
To obtain this decomposition, first integrate the following equality w.r.t. \(p^*(x)\)
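\[
\mathrm{ELBO}(x) = \log p(x) - KL(q(z|x)\| p(z|x))
\]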
to get
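\[
\mathbb{E}_{p^*(x)}[\mathrm{ELBO}(x)] = \mathbb{E}_{p^*(x)}[\log p(x)] - \mathbb{E}_{p^*(x)}[ KL(q(z|x)\| p(z|x)) ].
\]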
As a sidenote: this tells us maximizing the aggregate ELBO objective (LHS) is equivalent to maximizing the marginal log-likelihood \(\mathbb{E}_{p^*(x)}[\log p(x)]\), modulo a gap equal to an aggregate mismatch between the variational and true posterior \(\mathbb{E}_{p^*(x)} [ KL(q(z|x)\|p(z|x)) ]\). Next, replace the marginal log-likelihood (negative cross-entropy) term with
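\[
\mathbb{E}_{p^*(x)}[\log p(x)] = - KL(p^*(x)\| p(x)) - H(p^*(x)),
\]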
and finally, replace the marginal KL divergence using the chain rule of KL divergence,
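\[
KL(p^*(x)\| p(x)) = KL\big(p^*(x) q(z|x)\,\|\, p(x) p(z|x)\big) - \mathbb{E}_{p^*(x)}[ KL(q(z|x)\| p(z|x)) ],
\]
where \(p(x) p(z|x) = p(z) p(x|z) = p(x,z)\) is just the model joint.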
Substituting these back in, the \(\mathbb{E}_{p^*(x)} [ KL(q(z|x)\|p(z|x)) ]\) terms cancel and the desired result follows.
This form is discussed at length in, e.g., Structured Disentangled Representations (note that in their definition, \(p^*(x)\) is the empirical distribution defined by training samples, rather than the true data distribution).
1. It's possible to relax this assumption if we consider a more general notion of "density" than the standard Lebesgue density on \(\mathbb{R}^n\). That is, if we know in advance the "manifold" on which the data is concentrated, then the data distribution has a density function (Radon–Nikodym derivative) w.r.t. the reference measure of that manifold, which we can still model/learn with a latent variable model.
2. Throughout this note I use the words "distribution" and "density" interchangeably.