SiT: Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers


We present Scalable Interpolant Transformers (SiT), a family of generative models built on the backbone of Diffusion Transformers (DiT). The interpolant framework, which allows for connecting two distributions in a more flexible way than standard diffusion models, makes possible a modular study of various design choices impacting generative models built on dynamical transport: using discrete vs. continuous time learning, deciding the objective for the model to learn, choosing the interpolant connecting the distributions, and deploying a deterministic or stochastic sampler.

By carefully introducing the above ingredients, SiT surpasses DiT uniformly across model sizes on the conditional ImageNet 256x256 benchmark using the exact same backbone, number of parameters, and GFLOPs. By exploring various diffusion coefficients, which can be tuned separately from learning, SiT achieves an FID-50K score of 2.06.

Selected samples from our largest SiT-XL models trained on ImageNet at \(512\times512\) and \(256\times256\) resolution, with a classifier-free guidance scale of \(4.0\).

Flow and Diffusion

In recent years, a family of flexible generative models based on transforming pure noise \(\varepsilon \sim \mathcal{N}(0, \mathbf{I})\) into data \(x_* \sim p(x)\) has emerged. This transformation can be described by a simple time-dependent process $$ x_t = \alpha_t x_* + \sigma_t \varepsilon $$ with \(t\) defined on \([0, T]\), and \(\alpha_t, \sigma_t\) time-dependent functions chosen such that \(x_0 \sim p(x)\) and \(x_T \sim \mathcal{N}(0, \mathbf{I})\). At each \(t\), \(x_t\) has a conditional density \(p_t(x \mid x_*) = \mathcal{N}(\alpha_t x_*, \sigma_t^2\mathbf{I})\), and our goal is to estimate the marginal density \(p_t(x) = \int p_t(x \mid x_*)\, p(x_*)\, \mathrm{d}x_*\). In most cases, this marginal density is intractable. Some previous methods focus on maximizing the likelihood \(\log p_t(x)\), whereas modern approaches take advantage of differential equations whose solutions have marginal density \(p_t(x)\) to directly generate samples from \(p(x)\).
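As a concrete reference, here is a minimal sketch of drawing \(x_t\) from this process, assuming `x_star` is a batch of data tensors and `alpha_t`, `sigma_t` are precomputed values of the interpolating functions at a chosen time (placeholder names, not the paper's code):

```python
import torch

def sample_xt(x_star, alpha_t, sigma_t):
    """Draw x_t = alpha_t * x_star + sigma_t * eps with eps ~ N(0, I).

    alpha_t and sigma_t are the interpolating functions evaluated at some
    time t; eps is returned so it can be reused in a training loss.
    """
    eps = torch.randn_like(x_star)
    return alpha_t * x_star + sigma_t * eps, eps
```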

Diffusion-Based Models. Diffusion-based models are the most commonly used framework for this transformation. Here \(\alpha_t\) and \(\sigma_t\) are set indirectly by a forward-time stochastic differential equation (SDE) with \(\mathcal{N}(0, \mathbf{I})\) as its equilibrium distribution, $$ \mathrm{d}X_t = f(X_t, t) \mathrm{d}t + g(t) \mathrm{d} W_t $$ where \(W_t\) is a standard Brownian motion.
In practice, the model samples from this process by learning the gradient of the log-density \(\nabla \log p_t(x)\) (the score) with a generative model \(s_\theta(x_t, t)\) under the score-matching objective $$ \mathcal{L}_s(\theta) = \int \mathbb{E}[\Vert \sigma_t s_\theta(x_t, t) + \varepsilon \Vert^2] \mathrm{d}t $$ Inference is done by solving either the probability flow ODE $$\mathrm{d}X_t = \Big[f(X_t, t) - \frac{1}{2} g^2(t) \nabla \log p_t(X_t)\Big] \mathrm{d}t $$ or the reverse-time SDE $$ \mathrm{d}X_t = \Big[f(X_t, t) - g^2(t) \nabla \log p_t(X_t)\Big] \mathrm{d}t + g(t) \mathrm{d} \bar{W}_t $$ where \(\bar{W}_t\) is a standard Brownian motion running backward in time. Integrating either equation from \(t = T\) to \(t = 0\), starting from Gaussian noise, produces a sample approximately distributed according to \(p(x)\).
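To make the objective concrete, here is a minimal Monte Carlo sketch of \(\mathcal{L}_s\) for batched image tensors, with the time axis rescaled to \([0, 1]\); `s_theta`, `alpha`, and `sigma` are placeholder callables (the score network and the interpolating functions), not the paper's code:

```python
import torch

def score_matching_loss(s_theta, x_star, alpha, sigma):
    """Monte Carlo estimate of E_t E_eps || sigma_t s_theta(x_t, t) + eps ||^2
    with t ~ U(0, 1); x_star is a batch of data, e.g. shape (B, C, H, W)."""
    b = x_star.shape[0]
    t = torch.rand(b, device=x_star.device)            # one time per sample
    shape = (b,) + (1,) * (x_star.dim() - 1)           # for broadcasting
    eps = torch.randn_like(x_star)
    x_t = alpha(t).view(shape) * x_star + sigma(t).view(shape) * eps
    resid = sigma(t).view(shape) * s_theta(x_t, t) + eps
    return resid.pow(2).flatten(1).sum(dim=1).mean()
```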

Stochastic Interpolant and Flow-Based Models. Stochastic interpolants and other flow-based models are more recent additions to this family, where \(\alpha_t\) and \(\sigma_t\) are restricted to the time interval \([0,1]\) with \(\alpha_0 = \sigma_1 = 1\) and \(\alpha_1 = \sigma_0 = 0\), so that \(x_t\) exactly interpolates between \(x_*\) and \(\varepsilon\). We note that this gives more flexibility in the choice of the interpolating functions, as they are no longer tied to a forward SDE.
Furthermore, these models use a simpler probability flow ODE for inference \[\begin{align*} \mathrm{d}X_t &= \underbrace{\Big[f(X_t, t) - \frac{1}{2} g^2(t) \nabla \log p_t(X_t)\Big]}_{\text{directly learn this}} \mathrm{d}t \\ \implies \mathrm{d}X_t &= v(X_t, t) \mathrm{d}t \end{align*}\] where the velocity \(v(X_t, t)\) is estimated with the flow-matching objective $$ \mathcal{L}_v(\theta) = \int \mathbb{E}[\Vert v_\theta(x_t, t) - \dot \alpha_t x_* - \dot \sigma_t \varepsilon \Vert^2] \mathrm{d}t $$ Intuitively, this can be viewed as predicting, at time \(t\), the velocity of a particle that started moving from some \(\varepsilon\).
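A matching Monte Carlo sketch of \(\mathcal{L}_v\), with `v_theta` the velocity network and `d_alpha`, `d_sigma` placeholder callables for \(\dot\alpha_t, \dot\sigma_t\) (again an illustration under these naming assumptions, not the paper's code):

```python
import torch

def flow_matching_loss(v_theta, x_star, alpha, sigma, d_alpha, d_sigma):
    """Monte Carlo estimate of
    E_t E_eps || v_theta(x_t, t) - (alpha_dot_t x_* + sigma_dot_t eps) ||^2."""
    b = x_star.shape[0]
    t = torch.rand(b, device=x_star.device)
    shape = (b,) + (1,) * (x_star.dim() - 1)
    eps = torch.randn_like(x_star)
    x_t = alpha(t).view(shape) * x_star + sigma(t).view(shape) * eps
    target = d_alpha(t).view(shape) * x_star + d_sigma(t).view(shape) * eps
    return (v_theta(x_t, t) - target).pow(2).flatten(1).sum(dim=1).mean()
```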

We summarize the components of the above models in the following table:

| | Diffusion-Based | Flow-Based |
| --- | --- | --- |
| \( t \) | \(\{0, \cdots, T\}\) (DDPM) / \([0,1]\) (SBDM) | \([0,1]\) |
| \( \mathcal{L}(\theta) \) | \( \mathcal{L}_s \sim \Vert \sigma_t s_\theta(x_t, t) + \varepsilon\Vert^2 \) | \( \mathcal{L}_v \sim \Vert v_\theta(x_t, t) - \dot \alpha_t x_* - \dot \sigma_t \varepsilon \Vert^2 \) |
| \( x_t \) | \( \alpha_t x_* + \sigma_t \varepsilon \) | \( \alpha_t x_* + \sigma_t \varepsilon \) |
| ODE | \( \mathrm{d}X_t = [f(X_t, t) - \frac{1}{2} g^2(t) \nabla \log p_t(X_t)] \mathrm{d}t \) | \( \mathrm{d}X_t = v(X_t, t) \mathrm{d}t \) |
| SDE | \( \mathrm{d}X_t = [f(X_t, t) - g^2(t) \nabla \log p_t(X_t)] \mathrm{d}t + g(t) \mathrm{d} \bar{W}_t \) | ? |
It has been shown that, under the same \(\alpha_t\) and \(\sigma_t\), diffusion- and flow-based methods share the same time-evolving process: the marginal density \(p_t(x)\) of the flow-based ODE coincides with that of the diffusion-based ODE and SDE.
In our work, we demonstrate the mathematical equivalences and performance impacts of the other components in the above table. We also fill in the question mark by showing that flow-based methods can likewise be sampled with a reverse-time SDE, despite the lack of a forward SDE.

Scalable Interpolant Transformers

From the above table, we summarize the design space into four components:

- time domain: discrete vs. continuous-time learning;
- model prediction: learning the score vs. the velocity;
- interpolant: the choice of \(\alpha_t\) and \(\sigma_t\) connecting the two distributions;
- sampler: deterministic (ODE) vs. stochastic (SDE) sampling, and the choice of diffusion coefficient.

Systematically varying these components in the design space, we construct our SiT model, which consistently outperforms the DiT model in generating \(256 \times 256\) ImageNet images.
| Model | Params (M) | Training Steps | FID \(\downarrow\) |
| --- | --- | --- | --- |
| DiT-S | 33 | 400K | 68.4 |
| SiT-S | 33 | 400K | 57.6 |
| DiT-B | 130 | 400K | 43.5 |
| SiT-B | 130 | 400K | 33.5 |
| DiT-L | 458 | 400K | 19.5 |
| SiT-L | 458 | 400K | 17.2 |
| DiT-XL | 675 | 400K | 9.6 |
| SiT-XL | 675 | 400K | 8.6 |
| DiT-XL (cfg=1.5) | 675 | 7M | 2.27 |
| SiT-XL (cfg=1.5) | 675 | 7M | 2.06 |
Across all model sizes, SiT performs better than DiT. What's the source of the performance gain?

To investigate this performance improvement, we gradually transition from a DiT model, a typical denoising diffusion model (discrete-time, denoising objective, variance-preserving interpolant, SDE sampling), to our SiT model via a series of orthogonal steps in the design space. As we progress, we carefully evaluate how each move away from the diffusion model impacts performance in the following sections.

To conduct this study, we use the same DiT-B backbone architecture and training hyperparameters for all models. We also keep the number of parameters, GFLOPs, and training schedule (400K steps) identical across models. All numbers presented in the tables are FID-50K scores evaluated against the ImageNet \(256\times256\) training set, produced by a 250-step Heun ODE solver unless otherwise specified. For solving the SDE, we use an Euler-Maruyama integrator.
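For reference, a minimal sketch of a Heun integrator for the probability flow ODE \(\mathrm{d}X_t = v(X_t, t)\mathrm{d}t\) under the interpolant convention above (data at \(t = 0\), noise at \(t = 1\)); `v_theta` is a placeholder for a trained velocity model, and this is an illustration rather than our evaluation code:

```python
import torch

@torch.no_grad()
def heun_sampler(v_theta, x1, n_steps=250, t_min=0.0):
    """Integrate dX_t = v(X_t, t) dt from t = 1 (noise) down to t = t_min."""
    x = x1
    ts = torch.linspace(1.0, t_min, n_steps + 1)
    for i in range(n_steps):
        t, t_next = ts[i], ts[i + 1]
        dt = t_next - t                                # dt < 0: reverse time
        d1 = v_theta(x, t.expand(x.shape[0]))
        x_pred = x + d1 * dt                           # Euler predictor
        d2 = v_theta(x_pred, t_next.expand(x.shape[0]))
        x = x + 0.5 * (d1 + d2) * dt                   # trapezoidal corrector
    return x
```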

Timespace

The first move away is well studied: we switch from a discrete-time denoising model to a continuous-time score model. A marginal performance improvement is observed.

| Model | Objective | FID \(\downarrow\) |
| --- | --- | --- |
| DDPM | \( \mathcal{L}_s^\dagger \) | 44.2 |
| SBDM | \( \mathcal{L}_s \) | 43.6 |

DDPM vs. SBDM. \( \dagger \): DDPM uses the discretized objective.

Model Prediction

We claim that the velocity model is related to the score model by a time-dependent weighting function. Specifically, we show that \[ v(x_t, t) = \frac{\dot \alpha_t}{\alpha_t} x_t - \lambda_t\sigma_t s(x_t, t) \] with \( \lambda_t = \dot \sigma_t - \frac{\dot \alpha_t \sigma_t}{\alpha_t} \). Plugging this linear relation into \(\mathcal{L}_v\), we obtain \[ \begin{align*} \mathcal{L}_{v}(\theta) &= \int_0^T \lambda_t^2 \mathbb{E}[\Vert \sigma_t s_\theta(x_t, t) + \varepsilon \Vert^2] \mathrm{d} t \\ &= \mathcal{L}_{s_\lambda}(\theta) \end{align*} \] This aligns with the observation made in prior work that different model predictions for diffusion models correspond to a vanilla denoising objective weighted by different time-dependent functions. We trained all three models and present the results below.
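The relation is straightforward to implement; a sketch with placeholder callables `alpha`, `d_alpha`, `sigma`, `d_sigma` for the interpolating functions and their time derivatives, and `score` for a trained score model:

```python
def velocity_from_score(score, x_t, t, alpha, d_alpha, sigma, d_sigma):
    """Convert a score estimate into a velocity estimate via
    v(x_t, t) = (alpha_dot/alpha) x_t - lambda_t sigma_t s(x_t, t),
    with lambda_t = sigma_dot_t - alpha_dot_t sigma_t / alpha_t."""
    a, da = alpha(t), d_alpha(t)
    sig, dsig = sigma(t), d_sigma(t)
    lam = dsig - da * sig / a
    return (da / a) * x_t - lam * sig * score(x_t, t)
```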

| Interpolant | Objective | FID \(\downarrow\) |
| --- | --- | --- |
| SBDM-VP | \( \mathcal{L}_s \) | 43.6 |
| SBDM-VP | \( \mathcal{L}_{s_\lambda} \) | 39.1 |
| SBDM-VP | \( \mathcal{L}_v \) | 39.8 |

Effect of different model predictions.

As \(t \to 0\), \(\dot \sigma_t\) diverges under the variance-preserving setting, making \(\lambda_t^2\) blow up as well. In practice, we follow the SBDM setting and clip both training and sampling to the interval \([\varepsilon, 1]\) to avoid numerical instability. As a result, the large \(\lambda_\varepsilon\) is able to compensate for the vanishing gradient of \(\mathcal{L}_s\), but in turn makes \(\mathcal{L}_v\) harder to optimize.
Plot of \( \lambda_t^2 \), with \(t\) clipped to \([10^{-5}, 1]\).

Interpolant

We mainly experiment with three choices of interpolants:

- SBDM-VP: \( \alpha_t = e^{-\frac{1}{2}\int_0^t \beta(s) \mathrm{d}s} \), \( \sigma_t = \sqrt{1 - \alpha_t^2} \), the variance-preserving interpolant implied by the forward SDE of score-based diffusion;
- GVP: \( \alpha_t = \cos(\frac{\pi}{2} t) \), \( \sigma_t = \sin(\frac{\pi}{2} t) \), a generalized variance-preserving interpolant with \( \alpha_t^2 + \sigma_t^2 = 1 \) for all \(t\);
- Linear: \( \alpha_t = 1 - t \), \( \sigma_t = t \).

A code sketch of these choices follows the list.
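Here is a minimal sketch of the three interpolants returning \((\alpha_t, \sigma_t, \dot\alpha_t, \dot\sigma_t)\); the SBDM-VP schedule is assumed to be the standard linear \(\beta\) schedule with \(\beta_{\min} = 0.1\), \(\beta_{\max} = 20\), which is an assumption rather than something specified above:

```python
import math
import torch

BETA_MIN, BETA_MAX = 0.1, 20.0  # assumed linear beta schedule for SBDM-VP

def linear_interp(t):
    """Linear: alpha_t = 1 - t, sigma_t = t."""
    one = torch.ones_like(t)
    return 1.0 - t, t, -one, one

def gvp_interp(t):
    """GVP: alpha_t = cos(pi t / 2), sigma_t = sin(pi t / 2)."""
    c = math.pi / 2
    return (torch.cos(c * t), torch.sin(c * t),
            -c * torch.sin(c * t), c * torch.cos(c * t))

def sbdm_vp_interp(t):
    """SBDM-VP: alpha_t = exp(-0.5 int_0^t beta), sigma_t = sqrt(1 - alpha_t^2)."""
    log_alpha = -0.25 * t**2 * (BETA_MAX - BETA_MIN) - 0.5 * t * BETA_MIN
    alpha = torch.exp(log_alpha)
    sigma = torch.sqrt(1.0 - alpha**2)
    beta = BETA_MIN + t * (BETA_MAX - BETA_MIN)
    d_alpha = -0.5 * beta * alpha
    d_sigma = -alpha * d_alpha / sigma  # diverges as t -> 0: the blow-up above
    return alpha, sigma, d_alpha, d_sigma
```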

Below are some examples demonstrating the effects of the above interpolants on simple 1D distributions. Over an aligned time interval starting from a standard Gaussian distribution, we note that the VP interpolant changes most drastically in both cases, while the GVP and Linear interpolants evolve more smoothly. Intuitively, such an abrupt change in the VP interpolant increases the Lipschitz constant of the velocity field, making it more difficult to learn.
Density transformation using the above three interpolants. Left: from a standard Gaussian to a Bernoulli distribution supported at \(\pm 1\). Right: from a standard Gaussian to a Gaussian mixture with two modes at \(-1\) and \(2\).
Moving from toy examples to image generation tasks, we again observe large performance differences across different interpolants. We present the results below.

| Interpolant | Objective | FID \(\downarrow\) |
| --- | --- | --- |
| SBDM-VP | \( \mathcal{L}_v \) | 39.8 |
| GVP | \( \mathcal{L}_v \) | 34.6 |
| Linear | \( \mathcal{L}_v \) | 34.8 |

Effect of interpolants.

One possible explanation for this observation is given in the following figure, where we see that the path length (transport cost) is reduced when changing from SBDM-VP to GVP or Linear. Numerically, we also note that the singularity of \(\dot \sigma_t\) at \(t = 0\) does not appear with the GVP and Linear interpolants, making the model easier to learn near the data distribution.
Path length. The path length \(\mathcal{C}(v) = \int \mathbb{E}[\Vert v(x_t, t) \Vert^2] \mathrm{d}t\) arising from the velocity field at different training steps for the various interpolants; each curve is approximated with 10000 datapoints.

Sampler

In this section we introduce more flexibility into sampling the velocity model. First, under the SBDM setting, a reverse-time SDE for the velocity model can be constructed as \[ \mathrm{d}X_t = \Big[v(X_t, t) - \frac12 g^2(t)s(X_t, t)\Big]\mathrm{d}t + g(t) \mathrm{d} \bar{W}_t \] where we take advantage of the linear relationship between velocity and score to construct the drift term. We refer to \(g(t)\) as the SBDM diffusion coefficient. Following the previous sections, we can also construct such an SDE for the GVP and Linear interpolants via the relationship \(g^2(t) = 2\lambda_t\sigma_t\). The results are presented below.
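A minimal Euler-Maruyama sketch of this sampler, with `v_theta`, `s_theta`, and `w` as placeholder callables (the velocity model, the score obtained from it via the linear relation above, and the diffusion coefficient); an illustration under these assumptions, not our exact integrator:

```python
import torch

@torch.no_grad()
def em_sampler(v_theta, s_theta, w, x1, n_steps=250, t_min=1e-3):
    """Euler-Maruyama integration of
    dX_t = [v(X_t, t) - 0.5 w(t) s(X_t, t)] dt + sqrt(w(t)) dW_t,
    run backward from t = 1 (noise) to t = t_min (data side)."""
    x = x1
    ts = torch.linspace(1.0, t_min, n_steps + 1)
    for i in range(n_steps):
        t, dt = ts[i], ts[i + 1] - ts[i]               # dt < 0: reverse time
        tb = t.expand(x.shape[0])
        drift = v_theta(x, tb) - 0.5 * w(t) * s_theta(x, tb)
        x = x + drift * dt
        if i < n_steps - 1:                            # last step is noise-free
            x = x + (w(t) * dt.abs()).sqrt() * torch.randn_like(x)
    return x
```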

| Interpolant | Objective | ODE | SDE |
| --- | --- | --- | --- |
| SBDM-VP | \( \mathcal{L}_v \) | 39.8 | 37.8 |
| GVP | \( \mathcal{L}_v \) | 34.6 | 32.9 |
| Linear | \( \mathcal{L}_v \) | 34.8 | 33.6 |

ODE vs. SDE with the SBDM diffusion coefficient.

We further propose that the diffusion coefficient \(g(t)\) can be tuned separately from the learning process. In fact, any non-negative function \(w(t)\) (not necessarily monotone) is eligible as a diffusion coefficient, and the reverse-time SDE generalizes to \[ \mathrm{d}X_t = \Big[v(X_t, t) - \frac12 w(t)s(X_t, t)\Big]\mathrm{d}t + \sqrt{w(t)} \mathrm{d} \bar{W}_t \] Apart from the SBDM coefficient, we also experiment with the choices \(w(t) = \sigma_t\) (to eliminate the singularity of the score near the data distribution) and \(w(t) = \sin^2(\pi t)\), as well as their effects on both velocity and score models.
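For the Linear interpolant (\(\alpha_t = 1 - t\), \(\sigma_t = t\)), these coefficients take simple closed forms (the SBDM expression is our own derivation from \(g^2(t) = 2\lambda_t\sigma_t\)); a hypothetical usage with the `em_sampler` sketch from the previous section:

```python
import math
import torch

def w_sbdm(t):   # g^2(t) = 2 lambda_t sigma_t = 2t / (1 - t); singular at t = 1,
    return 2.0 * t / (1.0 - t)  # so start integration slightly below 1 in practice

def w_sigma(t):  # w(t) = sigma_t = t for the Linear interpolant
    return t

def w_sin2(t):   # w(t) = sin^2(pi t)
    return torch.sin(math.pi * t) ** 2

# e.g., with the em_sampler sketch above (hypothetical model callables):
# samples = em_sampler(v_theta, s_theta, w_sigma, torch.randn(8, 4, 32, 32))
```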

| Interpolant | Objective | \( w(t) = g(t) \) | \( w(t) = \sigma_t \) | \( w(t) = \sin^2(\pi t) \) |
| --- | --- | --- | --- | --- |
| SBDM-VP | \( \mathcal{L}_v \) | 37.8 | 38.7 | 39.2 |
| SBDM-VP | \( \mathcal{L}_{s_\lambda} \) | 35.7 | 37.1 | 37.7 |
| GVP | \( \mathcal{L}_v \) | 32.9 | 33.4 | 33.6 |
| GVP | \( \mathcal{L}_s \) | 38.0 | 33.5 | 33.2 |
| Linear | \( \mathcal{L}_v \) | 33.6 | 33.5 | 33.3 |
| Linear | \( \mathcal{L}_s \) | 41.0 | 35.3 | 34.4 |

Evaluation of the SDE with different diffusion coefficients. To make SBDM-VP competitive, we use the weighted objective \( \mathcal{L}_{s_\lambda} \) instead of \( \mathcal{L}_s \).

We note that the optimal choice of diffusion coefficient depends on the interpolant and the objective and, in our experiments, also varies considerably with model size. Empirically, the best choice for our SiT-XL is a continuous-time velocity model with the Linear interpolant, sampled with the SDE using the \(w(t) = \sigma_t\) coefficient.

Lastly, we note that the performance of the ODE and SDE integrators differs under different computation budgets. As shown below, the ODE converges faster with fewer function evaluations, while the SDE reaches a much lower final FID score given a larger computational budget.

Comparison of ODE and SDE samplers with different choices of diffusion coefficients. Each point is evaluated with 10K samples.

Classifier-free Guidance

In this section, we give a concise justification for adopting classifier-free guidance with the velocity model, and then empirically show that the drastic performance gains observed for DiT carry over to SiT.

Guidance for a velocity field means that: (i) the velocity model \(v_\theta(x_t, t, y)\) takes class labels \(y\) during training, where \(y\) is occasionally masked with a null token \(\emptyset\); and (ii) during sampling, the velocity used is \(v_\theta^\zeta(x_t, t, y) = \zeta v_\theta(x_t, t, y) + (1 - \zeta) v_\theta(x_t, t, \emptyset)\) for a fixed \(\zeta > 0\). Given this observation, one can leverage the usual argument for classifier-free guidance on score-based models. At a CFG scale of \(\zeta = 1.5\), DiT-XL's FID improves from 9.6 (no CFG) to 2.27 (CFG). We observe a similar improvement with our largest SiT-XL model under an identical computation budget and CFG scale: sampled with an ODE, the FID-50K score improves from 9.4 to 2.15; with an SDE, from 8.6 to 2.06. This shows that SiT benefits from the same training and sampling choices explored previously and can surpass DiT's performance in each training setting, not only with respect to model size but also with respect to sampling choices.
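A sketch of the guided velocity, with `v_theta` a placeholder for the class-conditional velocity network; in practice the conditional and null evaluations are usually batched together:

```python
def guided_velocity(v_theta, x, t, y, null_y, zeta=1.5):
    """Classifier-free guidance on a velocity model:
    v_zeta(x, t, y) = zeta * v(x, t, y) + (1 - zeta) * v(x, t, null).
    zeta = 1.0 recovers the unguided conditional velocity."""
    return zeta * v_theta(x, t, y) + (1.0 - zeta) * v_theta(x, t, null_y)
```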


Class-Conditional ImageNet \(256\times256\)

| Model | FID \(\downarrow\) | sFID \(\downarrow\) | IS \(\uparrow\) | Precision \(\uparrow\) | Recall \(\uparrow\) |
| --- | --- | --- | --- | --- | --- |
| BigGAN-deep | 6.95 | 7.36 | 171.4 | 0.87 | 0.28 |
| StyleGAN-XL | 2.30 | 4.02 | 265.12 | 0.78 | 0.53 |
| Mask-GIT | 6.18 | - | 182.1 | - | - |
| ADM | 10.94 | 6.02 | 100.98 | 0.69 | 0.63 |
| ADM-G, ADM-U | 3.94 | 6.14 | 215.84 | 0.83 | 0.53 |
| CDM | 4.88 | - | 158.71 | - | - |
| RIN | 3.42 | - | 182.0 | - | - |
| Simple Diffusion (U-Net) | 3.76 | - | 171.6 | - | - |
| Simple Diffusion (U-ViT, L) | 2.77 | - | 211.8 | - | - |
| VDM++ | 2.12 | - | 267.7 | - | - |
| DiT-XL (cfg=1.5) | 2.27 | 4.60 | 278.24 | 0.83 | 0.57 |
| SiT-XL (cfg=1.5, ODE) | 2.15 | 4.60 | 258.09 | 0.81 | 0.60 |
| SiT-XL (cfg=1.5, SDE: \(\sigma_t\)) | 2.06 | 4.50 | 270.27 | 0.82 | 0.59 |

Benchmarking class-conditional image generation on ImageNet \(256\times256\). SiT-XL surpasses DiT-XL in FID with either sampler, ODE- or SDE-based.

Conclusion

In this work, we have presented Scalable Interpolant Transformers, a simple and powerful framework for image generation. Within this framework, we explored the tradeoffs among a number of key design choices: continuous vs. discrete-time models, the choice of interpolant, the choice of model prediction, and the choice of sampler. We highlighted the advantages and disadvantages of each choice and demonstrated how careful decisions can lead to significant performance improvements. Many concurrent works explore similar approaches in a wide variety of downstream tasks, and we leave the application of SiT to these tasks for future work.

BibTeX

@article{ma2024sit,
  title={SiT: Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers},
  author={Nanye Ma and Mark Goldstein and Michael S. Albergo and Nicholas M. Boffi and Eric Vanden-Eijnden and Saining Xie},
  year={2024},
  eprint={2401.08740},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}