Probabilistic generative models are powerful tools for generating data across many modalities, including images, video, and even genomic sequences.
Note: These notes are still in an evolving state.
Before we talk about diffusion models we should first talk about VAEs. Historically, VAEs were introduced first with the work of Kingma and Welling [1].
The important thing to understand here is the general setup of probabilistic modeling.
<aside>
Probabilistic Modeling with Latent Variables
Suppose we have some data $X$ that is sampled via a random process. The process is as follows. First, a latent variable $z$ is generated according to some distribution $p(z)$. Then, data point $x$ is sampled from some conditional distribution $p(x \mid z)$.
These latent variables can be thought of as an unobserved, internal representation of the dataset $X$. For instance, if $X$ consists of pictures of dogs, the latent variables could be thought of as living in a highly specialized submanifold that encodes the characteristics of dogs.
</aside>
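As a concrete, entirely toy illustration of this two-step sampling process, the sketch below assumes a standard normal prior $p(z)$ and a Gaussian conditional $p(x\mid z)$ centred at $\sin(z)$. These specific choices are not from the notes; they are simply the smallest instance of the setup:

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_prior(n):
    """Step 1: draw latents z ~ p(z) = N(0, 1)."""
    return rng.standard_normal(n)

def sample_conditional(z, noise_std=0.1):
    """Step 2: draw data x ~ p(x | z) = N(sin(z), noise_std^2).

    The mean function sin(z) is an arbitrary illustrative choice.
    """
    return np.sin(z) + noise_std * rng.standard_normal(z.shape)

z = sample_prior(5)          # unobserved latent variables
x = sample_conditional(z)    # observed data points
print(x.shape)               # (5,)
```

We only ever observe `x`; the `z` that produced each data point stays hidden, which is exactly what makes the posterior $p(z\mid x)$ interesting.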
There are four important actors to consider in this regime:
$p(z)$: the distribution of the latent variable.
$p(x\mid z)$: the conditional likelihood - expresses the probability of observing a data point given a specific latent.
$p(z \mid x)$: the posterior. It is a conditional distribution over latents given the data point.
$p(x)$: the distribution of the data points that nature provides (e.g. a distribution over images of dogs). Our goal is ultimately to sample from this distribution! This marginal distribution is intractable (we don’t know it and cannot calculate it), though we can think of it like:
$$ p_\theta(x) = \int p(z)\,p_\theta(x\mid z)\,dz $$
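To see why this integral is a problem, here is a naive Monte Carlo estimate of the marginal for a 1-D toy model (assumed, not from the notes: $p(z)=\mathcal{N}(0,1)$, $p(x\mid z)=\mathcal{N}(\sin z,\,0.1^2)$). Averaging $p(x\mid z_i)$ over prior samples $z_i$ works in one dimension, but in high dimensions almost all prior samples contribute negligibly, which is what makes the marginal intractable in practice:

```python
import numpy as np

rng = np.random.default_rng(0)

def gaussian_pdf(x, mu, std):
    """Density of N(mu, std^2) evaluated at x."""
    return np.exp(-0.5 * ((x - mu) / std) ** 2) / (std * np.sqrt(2 * np.pi))

def marginal_mc(x, n_samples=100_000, noise_std=0.1):
    """Estimate p(x) = ∫ p(z) p(x|z) dz ≈ (1/N) Σ_i p(x | z_i), z_i ~ p(z).

    Toy model (assumed): p(z) = N(0,1), p(x|z) = N(sin(z), noise_std^2).
    """
    z = rng.standard_normal(n_samples)
    return gaussian_pdf(x, np.sin(z), noise_std).mean()

print(marginal_mc(0.0))
```

Even here, a good estimate needs many prior samples; the variational machinery below replaces this brute-force average with samples from a learned proposal $q_\phi(z\mid x)$ that concentrates on the latents that actually explain $x$.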
<aside>
Probability Distribution Proxy
The key idea is that because we don’t know the distributions involved with $X$, we approximate them via proxies. These will be parameterized using deep neural networks:
$$ p(z\mid x)\approx q_\phi(z\mid x),\quad p(x\mid z)\approx p_\theta(x\mid z) $$
The parameters $\phi$ and $\theta$ will be learned eventually.
These two proxy distributions are associated with an encoder and a decoder respectively. $q_\phi(z\mid x)$ is realized by a neural network that compresses the data $x$ into a latent representation; $p_\theta(x \mid z)$ does the opposite.
To parameterize the distributions via neural networks, we only ask the networks to predict the parameters of the distributions. This will work as follows: given $x$ the encoder network predicts $(\mu_\phi(x),\sigma_\phi(x)^2)=\text{Enc}_\phi(x)$. Then:
$$ q_\phi(z\mid x) = \mathcal{N}\!\left(z \mid \mu_\phi(x),\ \sigma_\phi(x)^2 \cdot I\right). $$
Similarly, if we fix some hyperparameter $\sigma_\text{dec}$, the decoder network gives:
$$ p_\theta(x\mid z)=\mathcal{N}(x\mid f_\theta(z),\sigma^2_\text{dec}I) $$
</aside>
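The interface described in the aside can be sketched in a few lines. The "networks" below are single random linear maps standing in for $\text{Enc}_\phi$ and $f_\theta$ (a real VAE would use deep networks; the dimensions and weights here are arbitrary). The point is only that the networks output *distribution parameters*, not samples:

```python
import numpy as np

rng = np.random.default_rng(0)

D, LATENT = 4, 2  # data and latent dimensions (illustrative)

# Stand-ins for the encoder Enc_phi and decoder mean f_theta.
W_mu = 0.1 * rng.standard_normal((LATENT, D))
W_logvar = 0.1 * rng.standard_normal((LATENT, D))
W_dec = 0.1 * rng.standard_normal((D, LATENT))
SIGMA_DEC = 0.5  # the fixed hyperparameter sigma_dec

def encode(x):
    """q_phi(z|x): predict (mu_phi(x), sigma_phi(x)^2)."""
    mu = W_mu @ x
    var = np.exp(W_logvar @ x)  # predict log-variance so the variance stays positive
    return mu, var

def decode(z):
    """p_theta(x|z): predict the mean f_theta(z); the std is the fixed SIGMA_DEC."""
    return W_dec @ z

x = rng.standard_normal(D)
mu, var = encode(x)
z = mu + np.sqrt(var) * rng.standard_normal(LATENT)  # one sample from q_phi(z|x)
x_mean = decode(z)
print(mu.shape, var.shape, x_mean.shape)  # (2,) (2,) (4,)
```

Predicting the log-variance (rather than the variance directly) is a common trick to keep $\sigma_\phi(x)^2$ positive without constraining the network output.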
Training the encoder and decoder networks is done by simple maximum likelihood estimation. We try to set the parameters such that the likelihood $p_\theta(x)$ of observing data $x$ is maximized. The issue is that we cannot calculate $p_\theta(x)$ directly. To do so we’d have to consider all possible values of the latent variable $z$ and write:
$$ p_\theta(x) = \int p_\theta(x \mid z)\cdot p(z)dz $$
That is intractable. Instead, we come up with a different trick. First, we write:
$$ \log p_\theta(x) = \log \int q_\phi(z\mid x)\frac{p_\theta(x,z)}{q_\phi(z \mid x)}dz=\log\mathbb{E}_{z\sim q_\phi(z \mid x)}\left[ \frac{p_{\theta}(x,z)}{q_{\phi}(z\mid x)}\right] $$
Estimating this directly is possible (in fact it gives a model called the Importance Weighted Autoencoder (IWAE)). However, it requires great care to avoid high variance. A simpler approach is to use a lower bound. Jensen’s inequality on the concave $\log$ function gives:
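The direct (importance-weighted) estimate mentioned above can be written down for a 1-D toy model. Everything about the model is assumed for illustration: $p(z)=\mathcal{N}(0,1)$, decoder $p(x\mid z)=\mathcal{N}(2z,\,0.5^2)$, and a hand-picked proposal $q(z\mid x)=\mathcal{N}(0.5,\,0.1)$. With $k$ samples this is the IWAE-style estimator:

```python
import numpy as np

rng = np.random.default_rng(0)

def log_normal(v, mu, var):
    """Log-density of N(mu, var) at v."""
    return -0.5 * (np.log(2 * np.pi * var) + (v - mu) ** 2 / var)

def iwae_log_px(x, mu_q=0.5, var_q=0.1, k=64):
    """log p(x) ≈ log (1/k) Σ_i p(x, z_i) / q(z_i | x), with z_i ~ q(z|x).

    Toy model (assumed): p(z) = N(0,1), p(x|z) = N(2z, 0.25).
    """
    z = mu_q + np.sqrt(var_q) * rng.standard_normal(k)
    log_w = (log_normal(x, 2.0 * z, 0.25)      # log p(x|z)
             + log_normal(z, 0.0, 1.0)         # log p(z)
             - log_normal(z, mu_q, var_q))     # - log q(z|x)
    # log-sum-exp for numerical stability of log-mean of weights
    m = log_w.max()
    return m + np.log(np.mean(np.exp(log_w - m)))

print(iwae_log_px(1.0))
```

The log-sum-exp step hints at the "great care" the text mentions: the raw weights $p(x,z)/q(z\mid x)$ can span many orders of magnitude, and a poor proposal $q$ makes their variance explode.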
$$ \log p_\theta(x) \geq \mathbb{E}_{z \sim q_\phi(z\mid x)}\left[\log\frac{p_\theta(x,z)}{q_\phi(z\mid x)}\right] =: \text{ELBO}(x) = \mathcal{L}(x;\theta,\phi) $$
The quantity $\mathcal{L}(x;\theta,\phi)$ is called the Evidence Lower Bound (ELBO). It is a quantity we can easily compute and differentiate, meaning we can maximize it. It serves as a proxy to the true likelihood $p_\theta(x)$ which we cannot directly estimate.
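A Monte Carlo estimate of the ELBO is exactly the expectation above with the average moved inside the $\log$. The sketch below uses the same kind of assumed 1-D toy model as before ($p(z)=\mathcal{N}(0,1)$, $p(x\mid z)=\mathcal{N}(2z,\,0.5^2)$, proposal $q(z\mid x)=\mathcal{N}(0.5,\,0.1)$), for which the exact $\log p(x)$ is available in closed form, so the bound can be checked directly:

```python
import numpy as np

rng = np.random.default_rng(0)

def log_normal(v, mu, var):
    """Log-density of N(mu, var) at v."""
    return -0.5 * (np.log(2 * np.pi * var) + (v - mu) ** 2 / var)

def elbo_estimate(x, mu_q, var_q, sigma_dec=0.5, n_samples=1000):
    """ELBO(x) ≈ mean over z ~ q of [log p(x|z) + log p(z) - log q(z|x)].

    Toy model (assumed): p(z) = N(0,1), p(x|z) = N(2z, sigma_dec^2).
    """
    z = mu_q + np.sqrt(var_q) * rng.standard_normal(n_samples)
    log_px_z = log_normal(x, 2.0 * z, sigma_dec ** 2)
    log_pz = log_normal(z, 0.0, 1.0)
    log_qz = log_normal(z, mu_q, var_q)
    return np.mean(log_px_z + log_pz - log_qz)

# For this linear-Gaussian toy model, x = 2z + noise is N(0, 4 + 0.25),
# so log p(x) is known exactly and the ELBO should sit below it.
exact_log_px = log_normal(1.0, 0.0, 4.25)
print(elbo_estimate(x=1.0, mu_q=0.5, var_q=0.1), "<=", exact_log_px)
```

The gap between the two printed numbers is the KL divergence $\mathrm{KL}(q_\phi(z\mid x)\,\|\,p_\theta(z\mid x))$: the bound is tight exactly when the proposal matches the true posterior.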