In the previous post, we introduced the fundamental mechanism behind diffusion models: the forward process and the reverse process. In this post, we go into more detail about how to train the diffusion-based denoising network and how to sample from it.
# Training diffusion models

Interestingly, the forward and reverse processes of diffusion closely resemble those of the well-known Variational Autoencoder (VAE). Recall that in a VAE, we leverage an encoder to map the data into a latent (Gaussian) distribution, and then use a decoder to reconstruct the original data from the latent. In a diffusion model, however, we only learn the reverse process (the decoder), while the forward process has no learnable parameters. In particular, we can use a derivation similar to the variational lower bound (VLB) of the VAE to optimize the likelihood:
$$
\log p_\theta(x_0) \geq \log p_\theta(x_0) - D_\text{KL}\big(q(x_{1:T} \vert x_0) \,\|\, p_\theta(x_{1:T} \vert x_0)\big) \tag{7}
$$

Since the $D_\text{KL}$ term is always non-negative, to maximize the log-likelihood of the data $\log p_\theta(x_0)$, we aim to maximize the VLB, i.e., the RHS of the above equation. This is equivalent to minimizing the negation of the RHS, which will later lead to our $L_\text{vlb}$ loss:
$$
\begin{aligned}
-\text{RHS} &= -\log p_\theta(x_0) + D_\text{KL}\big(q(x_{1:T}\vert x_0) \,\|\, p_\theta(x_{1:T}\vert x_0)\big) \\
&= -\log p_\theta(x_0) + \mathbb{E}_{x_{1:T}\sim q(x_{1:T} \vert x_0)} \Big[ \log\frac{q(x_{1:T}\vert x_0)}{p_\theta(x_{0:T}) / p_\theta(x_0)} \Big] \\
&= -\log p_\theta(x_0) + \mathbb{E}_q \Big[ \log\frac{q(x_{1:T}\vert x_0)}{p_\theta(x_{0:T})} + \log p_\theta(x_0) \Big] \\
&= -\log p_\theta(x_0) + \mathbb{E}_q \Big[ \log \frac{q(x_{1:T}\vert x_0)}{p_\theta(x_{0:T})} \Big] + \mathbb{E}_q \big[\log p_\theta(x_0)\big]
\end{aligned}
$$

Since $\log p_\theta(x_0)$ does not depend on $x_{1:T}$, the last term equals $\log p_\theta(x_0)$ and cancels with the first, so we can define our training loss $L_\text{vlb}$ as:
$$
L_\text{vlb} = \mathbb{E}_q \Big[ \log \frac{q(x_{1:T}\vert x_0)}{p_\theta(x_{0:T})} \Big] \tag{8}
$$

This training objective can be further decomposed into a sum of KL terms indexed by the timestep $t$:
$$
\begin{aligned}
L_\text{vlb} &= \mathbb{E}_{q(x_{0:T})} \Big[ \log\frac{q(x_{1:T}\vert x_0)}{p_\theta(x_{0:T})} \Big] \\
&= \mathbb{E}_q \Big[ \log\frac{\prod_{t=1}^T q(x_t\vert x_{t-1})}{ p_\theta(x_T) \prod_{t=1}^T p_\theta(x_{t-1} \vert x_t) } \Big] \\
&= \mathbb{E}_q \Big[ -\log p_\theta(x_T) + \sum_{t=1}^T \log \frac{q(x_t\vert x_{t-1})}{p_\theta(x_{t-1} \vert x_t)} \Big] \\
&= \mathbb{E}_q \Big[ -\log p_\theta(x_T) + \color{green}{\sum_{t=2}^T \log \frac{q(x_t\vert x_{t-1})}{p_\theta(x_{t-1} \vert x_t)}} + \log\frac{q(x_1 \vert x_0)}{p_\theta(x_0 \vert x_1)} \Big] \quad \scriptsize (q(x_t \vert x_{t-1}) = q(x_t \vert x_{t-1}, x_0) \text{ by the Markov property}) \\
&= \mathbb{E}_q \Big[ -\log p_\theta(x_T) + \color{green}{\sum_{t=2}^T \log \Big( \frac{q(x_{t-1} \vert x_t, x_0)}{p_\theta(x_{t-1} \vert x_t)}\cdot \frac{q(x_t \vert x_0)}{q(x_{t-1}\vert x_0)} \Big)} + \log \frac{q(x_1 \vert x_0)}{p_\theta(x_0 \vert x_1)} \Big] \quad \scriptsize (\text{Bayes' rule}) \\
&= \mathbb{E}_q \Big[ -\log p_\theta(x_T) + \color{green}{\sum_{t=2}^T \log \frac{q(x_{t-1} \vert x_t, x_0)}{p_\theta(x_{t-1} \vert x_t)} + \sum_{t=2}^T \log \frac{q(x_t \vert x_0)}{q(x_{t-1} \vert x_0)}} + \log\frac{q(x_1 \vert x_0)}{p_\theta(x_0 \vert x_1)} \Big] \\
&= \mathbb{E}_q \Big[ -\log p_\theta(x_T) + \color{green}{\sum_{t=2}^T \log \frac{q(x_{t-1} \vert x_t, x_0)}{p_\theta(x_{t-1} \vert x_t)} + \log \prod_{t=2}^T \frac{q(x_t \vert x_0)}{q(x_{t-1} \vert x_0)}} + \log\frac{q(x_1 \vert x_0)}{p_\theta(x_0 \vert x_1)} \Big] \\
&= \mathbb{E}_q \Big[ -\log p_\theta(x_T) + \color{green}{\sum_{t=2}^T \log \frac{q(x_{t-1} \vert x_t, x_0)}{p_\theta(x_{t-1} \vert x_t)} + \log\frac{q(x_T \vert x_0)}{q(x_1 \vert x_0)}} + \log \frac{q(x_1 \vert x_0)}{p_\theta(x_0 \vert x_1)} \Big] \quad \scriptsize (\text{telescoping product}) \\
&= \mathbb{E}_q \Big[ -\log \frac{p_\theta(x_T)}{\color{green}{q(x_T \vert x_0)}} + \color{green}{\sum_{t=2}^T \log \frac{q(x_{t-1} \vert x_t, x_0)}{p_\theta(x_{t-1} \vert x_t)}} + \log \frac{q(x_1 \vert x_0)}{p_\theta(x_0 \vert x_1)\, \color{green}{q(x_1 \vert x_0)}} \Big] \\
&= \mathbb{E}_q \Big[ \log\frac{q(x_T \vert x_0)}{p_\theta(x_T)} + \sum_{t=2}^T \log \frac{q(x_{t-1} \vert x_t, x_0)}{p_\theta(x_{t-1} \vert x_t)} - \log p_\theta(x_0 \vert x_1) \Big] \\
&= \mathbb{E}_q \Big[ \underbrace{D_\text{KL}\big(q(x_T \vert x_0) \,\|\, p_\theta(x_T)\big)}_{L_T} + \sum_{t=2}^T \underbrace{D_\text{KL}\big(q(x_{t-1} \vert x_t, x_0) \,\|\, p_\theta(x_{t-1} \vert x_t)\big)}_{L_{t-1}} \underbrace{- \log p_\theta(x_0 \vert x_1)}_{L_0} \Big]
\end{aligned}
$$

In short, $L_\text{vlb}$ can be written as:
$$
L_\text{vlb} = L_T + L_{T-1} + \dots + L_1 + L_0
$$

where $L_t = D_\text{KL}\big(q(x_t \vert x_{t+1}, x_0) \,\|\, p_\theta(x_t \vert x_{t+1})\big)$ for all $1 \leq t \leq T-1$. These KL terms (except $L_0$) are between two Gaussians and can therefore be computed in closed form, whereas the $L_T$ term can be ignored during training since it contains no learnable parameters.
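To make the closed-form claim concrete, below is a minimal NumPy sketch (illustrative, not code from [1]) of the KL divergence between two diagonal Gaussians, which is what each $L_{t-1}$ term reduces to; all names here are hypothetical.

```python
import numpy as np

def gaussian_kl(mu_q, var_q, mu_p, var_p):
    """KL( N(mu_q, diag(var_q)) || N(mu_p, diag(var_p)) ), summed over dimensions.

    This closed form is what each L_{t-1} term evaluates to, with q playing the
    role of the forward posterior q(x_{t-1} | x_t, x_0) and p the learned
    reverse distribution p_theta(x_{t-1} | x_t).
    """
    return 0.5 * np.sum(
        np.log(var_p / var_q)                   # log-determinant ratio
        + (var_q + (mu_q - mu_p) ** 2) / var_p  # trace and mean-difference terms
        - 1.0
    )

# Identical Gaussians have zero divergence:
print(gaussian_kl(np.zeros(4), np.ones(4), np.zeros(4), np.ones(4)))  # 0.0
```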
Recall that we need to train a neural network to approximate the conditional distributions in the reverse process, $p_\theta(x_{t-1} \vert x_t) = \mathcal{N}\big(x_{t-1};\, \mu_\theta(x_t, t),\, \Sigma_\theta(x_t, t)\big)$. Moreover, the posterior mean $\tilde{\mu}_t$ (see Equations 5 and 6 in the previous post) of the reverse distribution $q(x_{t-1} \vert x_t, x_0)$ is:
$$
\color{orange}{\tilde{\mu}_t(x_t, x_0)} = \frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} x_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1 - \bar{\alpha}_t} x_0
$$

Using the property of Equation 2, i.e., $x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon_t$, we can re-express $\tilde{\mu}_t$ by replacing $x_0$:
$$
\color{orange}{\tilde{\mu}_t} = \frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} x_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1 - \bar{\alpha}_t} \cdot \frac{1}{\sqrt{\bar{\alpha}_t}}\big(x_t - \sqrt{1 - \bar{\alpha}_t}\,\epsilon_t\big) = \frac{1}{\sqrt{\alpha_t}} \Big( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}}\, \epsilon_t \Big) \tag{9}
$$

Our ultimate goal is to train the network $\theta$ to approximate the mean $\color{blue}{\mu_\theta(x_t, t)}$, where the loss term $L_t$ measures the KL divergence between two Gaussians. Furthermore, Ho et al. [1] suggested predicting the noise $\epsilon_\theta(x_t, t)$ using the reparameterization in Equation 9, while keeping the variance $\Sigma$ fixed to the noise schedule.
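As a quick numerical sanity check of this identity (an illustration with an arbitrary toy schedule, not code from [1]), the following NumPy snippet confirms that the two forms of $\tilde{\mu}_t$ coincide once $x_0$ is replaced via Equation 2:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy linear beta schedule (illustrative values only).
T = 10
betas = np.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)

t = 5                              # an arbitrary timestep with t >= 1
x0 = rng.normal(size=4)
eps = rng.normal(size=4)
x_t = np.sqrt(alpha_bars[t]) * x0 + np.sqrt(1 - alpha_bars[t]) * eps  # Equation 2

# Posterior mean written in terms of (x_t, x_0):
mu_from_x0 = (np.sqrt(alphas[t]) * (1 - alpha_bars[t - 1]) / (1 - alpha_bars[t]) * x_t
              + np.sqrt(alpha_bars[t - 1]) * betas[t] / (1 - alpha_bars[t]) * x0)

# The same mean written in terms of (x_t, eps), i.e., Equation 9:
mu_from_eps = (x_t - (1 - alphas[t]) / np.sqrt(1 - alpha_bars[t]) * eps) / np.sqrt(alphas[t])

assert np.allclose(mu_from_x0, mu_from_eps)  # both forms agree
```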
Therefore, we can calculate the loss as follows:

$$
\begin{aligned}
L_t &= \mathbb{E}_{x_0, \epsilon} \Big[\frac{1}{2 \| \Sigma_t \|^2_2} \big\| \color{orange}{\tilde{\mu}_t} - \color{blue}{\mu_\theta(x_t, t)} \big\|^2 \Big] \\
&= \mathbb{E}_{x_0, \epsilon} \Big[\frac{1}{2 \|\Sigma_t \|^2_2} \Big\| \color{orange}{\frac{1}{\sqrt{\alpha_t}} \Big( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}}\, \epsilon_t \Big)} - \color{blue}{\frac{1}{\sqrt{\alpha_t}} \Big( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}}\, \epsilon_\theta(x_t, t) \Big)} \Big\|^2 \Big] \\
&= \mathbb{E}_{x_0, \epsilon} \Big[\frac{ (1 - \alpha_t)^2 }{2 \alpha_t (1 - \bar{\alpha}_t) \| \Sigma_t \|^2_2} \big\| \color{orange}{\epsilon_t} - \color{blue}{\epsilon_\theta(x_t, t)} \big\|^2 \Big]
\end{aligned}
$$

Ho et al. [1] also suggested dropping the weighting term entirely, as they found training works better without it, which leads to the "simple" objective:
$$
L_\text{simple} = \mathbb{E}_{x_0, \epsilon} \Big[ \big\| \color{orange}{\epsilon_t} - \color{blue}{\epsilon_\theta(x_t, t)} \big\|^2 \Big] \tag{10}
$$

In summary, each training iteration proceeds as follows:
Repeat:
1. Sample data $x_0 \sim q(x_0)$
2. Sample $t \sim \text{Uniform}[1, T]$
3. Sample noise $\epsilon \sim \mathcal{N}(0, I)$
4. Compute the loss $L_\text{simple} = \mathbb{E}_{x_0, \epsilon} \big[ \| \epsilon_t - \epsilon_\theta(x_t, t) \|^2 \big]$, with $x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon$
5. Backpropagate and update the network parameters $\theta$

Until convergence. A minimal code sketch of this loop is given below.
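The loop above translates almost line-for-line into code. Here is a minimal PyTorch sketch, assuming a hypothetical noise-prediction network `model(x_t, t)` (e.g., a U-Net) and a precomputed $\bar{\alpha}$ schedule `alpha_bars`; it illustrates the procedure rather than reproducing any reference implementation.

```python
import torch

def train_step(model, optimizer, x0, alpha_bars):
    """One training iteration on a batch x0, using the simple objective (Equation 10)."""
    B, T = x0.shape[0], alpha_bars.shape[0]

    # Sample t ~ Uniform[1, T] (0-indexed here) and noise eps ~ N(0, I).
    t = torch.randint(0, T, (B,), device=x0.device)
    eps = torch.randn_like(x0)

    # Forward process in one shot: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps.
    a_bar = alpha_bars[t].view(B, *([1] * (x0.dim() - 1)))
    x_t = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * eps

    # L_simple = || eps - eps_theta(x_t, t) ||^2, averaged over the batch.
    loss = torch.mean((eps - model(x_t, t)) ** 2)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```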
# Sampling diffusion models

To obtain a sample from the original data distribution, we start by sampling from the noise distribution $q(x_T)$ and then gradually remove the noise until we reach $x_0$, following the reverse process. At each step, we sample from the approximated reverse distribution:

$$
p_\theta(x_{t-1} \vert x_t) = \mathcal{N}\big(x_{t-1};\, \mu_\theta(x_t, t),\, \Sigma_t\big) = \mathcal{N}\Big(x_{t-1};\, \frac{1}{\sqrt{\alpha_t}} \Big( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}}\, \epsilon_\theta(x_t, t) \Big),\, \Sigma_t\Big)
$$

The sampling process can be summarized as follows:
1. Sample $x_T \sim \mathcal{N}(0, I)$
2. For $t = T, \dots, 1$, sample from the posterior distribution:
   - $z \sim \mathcal{N}(0, I)$ if $t > 1$, else $z = 0$
   - $x_{t-1} = \mu_\theta(x_t, t) + \sigma_t z$, where $\sigma_t^2 = \Sigma_t$ (reparameterization trick)
3. Return $x_0$

A code sketch of this loop is given at the end of this section.

In the original Denoising Diffusion Probabilistic Model (DDPM), obtaining a sample is usually quite slow. This is because we need to follow the whole chain of the reverse diffusion process from $T$ down to $0$, where the number of steps is typically up to $T = 1000$. For example, it takes around 20 hours to sample 50k images of size 32×32 from a DDPM, but less than a minute to do so from a GAN [2]. Many recent works have proposed strategies to overcome this limitation and speed up sampling [2], [3], [4]. They can generate samples within only a few steps (e.g., 50) while keeping the sample quality nearly as high as that of the full sampling process.
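For completeness, here is a matching PyTorch sketch of the ancestral sampling loop above, under the same assumptions (a trained noise-prediction network `model(x_t, t)`) and using the fixed-variance choice $\sigma_t^2 = \beta_t$ from [1]:

```python
import torch

@torch.no_grad()
def sample(model, shape, betas, device="cpu"):
    """DDPM ancestral sampling: start from pure noise and denoise for T steps."""
    alphas = 1.0 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)
    T = betas.shape[0]

    x_t = torch.randn(shape, device=device)  # x_T ~ N(0, I)
    for t in reversed(range(T)):
        t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
        eps_hat = model(x_t, t_batch)

        # mu_theta(x_t, t) = (x_t - (1 - a_t) / sqrt(1 - abar_t) * eps_hat) / sqrt(a_t)
        mu = (x_t - (1 - alphas[t]) / (1 - alpha_bars[t]).sqrt() * eps_hat) / alphas[t].sqrt()

        z = torch.randn_like(x_t) if t > 0 else torch.zeros_like(x_t)
        x_t = mu + betas[t].sqrt() * z  # sigma_t = sqrt(beta_t); no noise at the last step
    return x_t  # x_0
```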
# References

[1] Ho J, Jain A, Abbeel P. Denoising diffusion probabilistic models. NeurIPS 2020.

[2] Song J, Meng C, Ermon S. Denoising diffusion implicit models. ICLR 2021.
[3] Liu L, Ren Y, Lin Z, Zhao Z. Pseudo numerical methods for diffusion models on manifolds. ICLR 2022.
[4] Salimans T, Ho J. Progressive distillation for fast sampling of diffusion models. ICLR 2022.