Diffusion Model

Posted by zhaogs on January 29, 2023

Diffusion Model(扩散模型)是经过严格的数学推理得到的,主要用到了贝叶斯公式和马尔科夫链法则,其推导过程与VAE(变分自编码器)有些相像,其中都用到了参数重整化这样的技巧,通过一系列的化简,最终可以得到一个相对简单的形式,其实验效果在去噪,图像生成等领域取得了一定的进展。

Diffusion Mode分为两个过程,扩散过程与逆扩散过程,顾名思义,逆扩散过程就是扩散过程的逆过程。以下就来详细介绍一下Diffusion Model的数学推导过程。

1. 扩散过程

给定一个初始数据分布\(x_0\sim q(x)\) ,不断向其中添加高斯噪声,每次添加噪声的标准差是以固定值\(\beta_t\)而确定的,均值是以固定值\(\beta_t\)和当前的\(t\)时刻的数据\(x_t\)共同决定的,\(\beta_t\in(0,1)\)。

通过不停的添加高斯噪声,最终的数据分布就变成了一个各项独立的高斯分布。

其中每次添加高斯噪声后数据分布与联合概率分布可以表示为

\[\notag q({\bf x_t\vert x_{t-1}})={\cal N}({\bf x_t;\sqrt{1-\beta_t}x_{t-1},\beta_t I})\\ q({\bf x_{1:T}\vert x_0})=\prod_{t=1}^T q({\bf x_t\vert x_{t-1}})\]

从公式中可以看出,这里的添加高斯噪声不是普通的直接相加,若是普通的相加,\(t\)时刻的均值应该是\(x_{t-1}\),这里的相加是经过仿射变换后的相加,这样的好处是当\(t\rightarrow\infty\)时,此时的分布是高斯分布。这里的添加的方式是一种人为设定。

\(x_t\)使用参数重整化方式可以表示为\(x_t=\sqrt{1-\beta_t}x_{t-1}+\sqrt{\beta_t}z_{t-1},\quad z_{t-1}\sim {\cal N}(0,I)\)

令\(\alpha_t=1-\beta_t,\bar \alpha_t = \prod_{i=1}^T\alpha_i\),

\[\notag \begin{aligned} x_t &= \sqrt{\alpha_t}x_{t-1}+\sqrt{1-\alpha_t}z_{t-1}\\ &= \sqrt{\alpha_t\alpha_{t-1}}x_{t-2}+\sqrt{1-\alpha_t}z_{t-1}+\sqrt{\alpha_t-\alpha_t\alpha_{t-1}}z_{t-2} \end{aligned}\]

对于\(z_{t-1},z_{t-2}\),二者皆为正态分布,两个高斯分布的和依然是高斯分布,对于\(X\sim {\cal N}(\mu_1,\sigma_1^2),Y\sim {\cal N}(\mu_2,\sigma_2^2)\),叠加后的\(aX+bY\)服从\({\cal N}(a\mu_1+b\mu_2,\sqrt{a^2\mu_1^2+b^2\mu_2^2})\)。

因此上式可以进一步化简可得

\[\notag \begin{aligned} x_t &= \sqrt{\alpha_t\alpha_{t-1}}x_{t-2}+\sqrt{1-\alpha_t\alpha_{t-1}}\bar z_{t-2}\\ &= \cdots\\ &= \sqrt{\bar \alpha_t}x_0+\sqrt{1-\bar\alpha_t}\bar z_0 \end{aligned}\]

至此我们就得到的\(q(x_t\vert x_0)\),即只要给定了\(x_0\),我们就可以计算出对应的\(x_t\)。

2. 逆扩散过程

逆扩散过程是上述过程的逆过程,是要从噪声数据中恢复出原始数据的过程。

逆过程其实就是得到\(q(x_{t-1}\vert x_t )\),然而得到这样的真实分布是非常困难的,因此可以使用神经网络来拟合这个分布函数。这也就是diffusion model的目的。

由于\(x_T\sim {\cal N}(0,I)\),当\(\beta_t\)足够小时\(q(x_{t-1}\vert x_t)\)也是高斯分布,由于这个分布很复杂,因此可以使用神经网络来计算它的参数,即使用神经网络得到一个模型\(p_\theta\),使用这个模型来逼近\(q(x_{t-1}\vert x_t)\),即可得

\[\notag p_\theta(x_{t-1}\vert x_t)={\cal N}(x_{t-1};\mu_\theta(x_t;t),\Sigma_\theta(x_t,t))\]

要想求\(p_\theta\)的分布,首先需要求出\(q(x_{t-1}\vert x_t)\)的真实分布情况,根据目前已知先验分布,我们需要已知\(x_0\)来求出\(q(x_{t-1}\vert x_t, x_0)\),得到后后验分布。

\[\notag \begin{aligned} q(x_{t-1}\vert x_t,x_0) &=q(x_t \vert x_{t-1},x_0)\frac{q(x_{t-1}\vert x_0)}{q(x_t\vert x_0)}\\ &\propto \exp (-\frac{1}{2}(\frac{(x_t-\sqrt{\alpha_t}x_{t-1})^2}{\beta_t}+\frac{(x_{t-1}-\sqrt{\bar \alpha_{t-1}}x_0)^2}{1-\bar \alpha_{t-1}}-\frac{(x_t-\sqrt{\bar \alpha_{t}}x_0)^2}{1-\bar \alpha_t}))\\ & = \exp(-\frac{1}{2}((\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar \alpha_{t-1}})x_{t-1}^2-(\frac{2\sqrt{\alpha_t}}{\beta_t}x_t+\frac{2\sqrt{\bar \alpha_{t-1}}}{1-\bar \alpha_{t-1}}x_0)x_{t-1}+C(x_t,x_0)) \end{aligned}\]

至此可以得到一个关于\(x_{t-1}\)的一元二次方程,进一步化为高斯分布的形式,可以得到

\[\notag \begin{aligned} \widetilde \beta &= 1/(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar \alpha_{t-1}})=\frac{1-\bar \alpha_{t-1}}{1-\bar \alpha_t}\beta_t\\ \widetilde \mu &= (\frac{\sqrt{\alpha_t}}{\beta_t}x_t+\frac{\sqrt{\bar \alpha_{t-1}}}{1-\bar \alpha_{t-1}}x_0) /(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar \alpha_{t-1}})=\frac{\sqrt{\alpha_t}(1-\bar \alpha_{t-1})}{1-\bar \alpha_t}x_t+\frac{\sqrt{\bar \alpha_{t-1}}\beta_t}{1-\bar \alpha_t}x_0 \end{aligned}\]

由之前\(x_t,x0\)的关系可以将\(\widetilde\mu\)表示为\(x_t\)的函数

\[\notag \begin{aligned} x_0 &= \frac{1}{\sqrt{\bar \alpha_t}}(x_t-\sqrt{1-\bar \alpha_t}z_t) \\ \widetilde \mu &= \frac{\sqrt{\alpha_t}(1-\bar \alpha_{t-1})}{1-\bar \alpha_t}x_t+\frac{\beta_t}{\sqrt{\alpha_t}(1-\bar \alpha_t)}(x_t-\sqrt{1-\bar \alpha_t}z_t)\\ &= \frac{1}{\sqrt{\alpha_t}}(x_t-\frac{\beta_t}{\sqrt{1-\bar \alpha_t}}z_t) \end{aligned}\]

这里\(z_t\)是未知量,这里形式上可以从\(x_t\)推出\(x_0\),但实际上,这里的\(z_t\)是未知量,虽然已知它服从正态分布,但无法计算出每次的采样值,因此在diffusion model中,设计的网络就是为了求出每次生成的噪声。

3.损失函数

目标数据分布的似然函数上界可以表示为

\[\notag \begin{aligned} -\log p_{\theta}(x_0) &\le-\log p_\theta(x_0)+D_{KL}(q(x_{1:T}\vert x_0) \Vert p_\theta(x_{1:T}\vert x_0))\\ &=-\log p_\theta(x_0)+ {\Bbb E}_{x_{1:T}\sim q(x_{1:T}\vert x_0)}\left[ \log \frac{q(x_{1:T \vert x_0})}{p_\theta(x_{0:T})/p_\theta(x_0)}\right]\\ &=-\log p_\theta(x_0)+{\Bbb E}_{q}\left[\log \frac{q(x_{1:T}\vert x_0)}{p_\theta(x_{0:T})}+\log p_\theta(x_0)\right]\\ &= {\Bbb E}_{q}\left[\log \frac{q(x_{1:T}\vert x_0)}{p_\theta(x_{0:T})}\right] \end{aligned}\]

即得到了目标数据分布似然函数的上界,对其化简可得

\[\notag \begin{aligned} L_{VLB} &= {\Bbb E}_{q(x_{0:T})}\left[\log \frac{q(x_{1:T}\vert x_0)}{p_\theta(x_{0:T})}\right]\\ &= {\Bbb E}_{q}\left[\log \frac{\prod_{t=1}^Tq(x_t\vert x_{t-1})}{p_\theta(x_{T})\prod_{i=1}^Tp_\theta(x_{t-1}\vert x_{t})}\right]\\ &= {\Bbb E}_{q}\left[-\log p_\theta(x_T)+\sum_{t=1}^T\log \frac{q(x_t\vert x_{t-1})}{p_\theta(x_{t-1}\vert x_t)}\right]\\ &= {\Bbb E}_{q}\left[-\log p_\theta(x_T)+\sum_{t=2}^T\log \frac{q(x_t\vert x_{t-1})}{p_\theta(x_{t-1}\vert x_t)}+\log \frac{q(x_1\vert x_0)}{p_\theta(x_0|x_1)}\right]\\ &= {\Bbb E}_{q}\left[-\log p_\theta(x_T)+\sum_{t=2}^T\log \left(\frac{q(x_{t-1}\vert x_{t},x_0)}{p_\theta(x_{t-1}\vert x_t)}\cdot \frac{q(x_t\vert x_0)}{q(x_{t-1}\vert x_0)}\right)+\log \frac{q(x_1\vert x_0)}{p_\theta(x_0|x_1)}\right]\\ &= {\Bbb E}_{q}\left[-\log p_\theta(x_T)+\sum_{t=2}^T\log \left(\frac{q(x_{t-1}\vert x_{t},x_0)}{p_\theta(x_{t-1}\vert x_t)}\right)+\sum_{t=2}^T \left (\frac{q(x_t\vert x_0)}{q(x_{t-1}\vert x_0)}\right)+\log \frac{q(x_1\vert x_0)}{p_\theta(x_0|x_1)}\right]\\ &= {\Bbb E}_{q}\left[-\log p_\theta(x_T)+\sum_{t=2}^T\log \left(\frac{q(x_{t-1}\vert x_{t},x_0)}{p_\theta(x_{t-1}\vert x_t)}\right)+\log \frac{q(x_T\vert x_0)}{q(x_{1}\vert x_0)}+\log \frac{q(x_1\vert x_0)}{p_\theta(x_0|x_1)}\right]\\ &= {\Bbb E}_{q}\left[\log \frac{q(x_T\vert x_0)}{p_\theta(x_T)}+\sum_{t=2}^T\log \left(\frac{q(x_{t-1}\vert x_{t},x_0)}{p_\theta(x_{t-1}\vert x_t)}\right)-\log {p_\theta(x_0|x_1)}\right]\\ &= {\Bbb E}_q\left[\underbrace {D_{KL}(q(x_T\vert x_0)\Vert p_\theta(x_T))}_{L_T}+\sum_{t=2}^T \underbrace{D_{KL}(q(x_{t-1}\vert x_t,x_0)\Vert p_\theta(x_{t-1}\vert x_{t}))}_{L_{t-1}}-\underbrace{\log p_\theta(x_0\vert x_1)}_{L_0} \right] \end{aligned}\]

通过化简就得到了后验概率的真实数据分布于估计分布之间的KL散度。

\(q\)是真实分布,其均值上面已经求出,方差为固定常数,\(p\)是估计分布,其均值由网络计算得到,方差为常数。两个高斯分布的KL散度可以简化为一个公式\(KL(p,q)=\log \frac{\sigma_1}{\sigma_2}+\frac{\sigma_1^2+(\mu_1-\mu_2)^2}{2\sigma_2^2}-\frac{1}{2}\),去掉常数项并化简上式可得

\[\notag \begin{aligned} L_{t-1} &= {\Bbb E}_q\left[\frac{1}{2\sigma_t^2}\Vert\widetilde\mu_t(x_t,x_0)-\mu_\theta(x_t,t) \Vert^2\right]+C\\ L_{t-1}-C &= {\Bbb E}_{x_0,\epsilon}\left[\frac{1}{2\sigma_t^2}\Vert\widetilde\mu_t(x_t(x_0,\epsilon),\frac{1}{\sqrt{\bar \alpha_t}}(x_t(x_0,\epsilon)-\sqrt{1-{\bar \alpha_t}}\epsilon))-\mu_\theta(x_t(x_0,\epsilon),t) \Vert^2\right]\\ &= {\Bbb E}_{x_0,\epsilon}\left[\frac{1}{2\sigma_t^2}\Vert\frac{1}{\sqrt{\alpha_t}}(x_t(x_0,\epsilon)-\frac{\beta_t}{\sqrt{1-\bar \alpha_t}}\epsilon)-\mu_\theta(x_t(x_0,\epsilon),t) \Vert^2\right]\\ \end{aligned}\]

从上式中可以看出,设计出的网络的训练目标是让\(\mu_\theta\)逼近\(\widetilde \mu_t\),上式中,我们已知的是\(x_t\),将其看作是一个常数,其余量作为\(x_t\)的函数进行表示,即

\[\notag \begin{aligned} \mu_\theta(x_t,t)&=\widetilde \mu_t(x_t,\frac{1}{\sqrt{\bar \alpha_t}}(x_t-\sqrt{1-\bar \alpha_t}\epsilon_\theta(x_t)))\\ &=\frac{1}{\sqrt{\alpha_t}}(x_t-\frac{\beta_t}{\sqrt{1-\bar \alpha_t}}\epsilon_\theta(x_t,t))\\ L_{t-1}-C&= {\Bbb E}_{x_0,\epsilon}\left[\frac{\beta_t^2}{2\sigma_t^2\alpha_t(1-\bar\alpha_t)}\Vert \epsilon-\epsilon_\theta(\sqrt{\bar \alpha_t}x_0+\sqrt{1-\bar \alpha_t}\epsilon,t)\Vert^2\right]\\ L_{\rm simple} &= {\Bbb E}_{t,x_0,\epsilon}\left[\Vert \epsilon-\epsilon_\theta(\sqrt{\bar \alpha_t}x_0+\sqrt{1-\bar \alpha_t}\epsilon,t)\Vert^2\right] \end{aligned}\]

实际上diffusion model在做的事情就是估计出每次添加的高斯噪声值