Diffusion Deep Dive Part 3: Coding a DDPM from Scratch
In Part 1 we derived the DDPM training objective and the DDPM sampler. In Part 2 we extended that sampler to DDIM so the same network can generate in a few dozen steps instead of $T = 1000$. This post turns the math into running code.
We build a DDPM that trains on MNIST in about an hour on a single GPU (or a few hours on CPU, if you are patient). The point is not state-of-the-art image quality; it is to have a minimal, readable reference where every line maps to an equation from Part 1 and every sampler change maps to Part 2.
What we will build:
- The noise schedule $\beta_t$, $\alpha_t$, $\bar\alpha_t$.
- The closed-form forward process $q(x_t \mid x_0)$.
- A small UNet with sinusoidal time embeddings.
- The training loop (three lines of math, about ten of code).
- The iterative DDPM sampler (Algorithm 2 from Ho et al.).
- Practical stability tricks (EMA, gradient clipping, amp).
Reference: Ho, Jain, Abbeel, "Denoising Diffusion Probabilistic Models" (NeurIPS 2020). All equation numbers in this post refer to Part 1 of this series.
Setup
We will use PyTorch. Nothing else is required.
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
device = "cuda" if torch.cuda.is_available() else "cpu"

For the dataset, MNIST is the right starting point: it is small, the images are $28 \times 28$, and you can see whether the model is working within a few epochs. We resize to $32 \times 32$ (a power of two, which makes the UNet's down and up sampling clean) and scale pixels to $[-1, 1]$ so the data sits on roughly the same scale as the unit-variance noise we will add.
def get_loader(batch_size=128):
    transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),                 # [0, 1]
        transforms.Normalize((0.5,), (0.5,)),  # [-1, 1]
    ])
    dataset = datasets.MNIST("./data", train=True, download=True, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True,
                      num_workers=2, drop_last=True, pin_memory=True)

The Noise Schedule
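Recall the definitions from Part 1: $\alpha_t = 1 - \beta_t$ and $\bar\alpha_t = \prod_{s=1}^{t} \alpha_s$. The noise schedule $\beta_t$ controls how aggressively we destroy the signal at each step. Ho et al. use a linear schedule from $\beta_1 = 10^{-4}$ to $\beta_T = 0.02$ with $T = 1000$. It is simple and works well enough for images.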
We precompute every quantity we will ever need, for every timestep, once at init. This avoids repeated cumulative products inside the training loop.
@dataclass
class Schedule:
    T: int
    betas: torch.Tensor                      # (T,)
    alphas: torch.Tensor                     # (T,)
    alphas_bar: torch.Tensor                 # (T,)
    sqrt_alphas_bar: torch.Tensor
    sqrt_one_minus_alphas_bar: torch.Tensor
    sqrt_recip_alphas: torch.Tensor
    posterior_variance: torch.Tensor         # tilde-beta_t

def make_schedule(T=1000, beta_start=1e-4, beta_end=2e-2, device="cpu"):
    betas = torch.linspace(beta_start, beta_end, T, device=device)
    alphas = 1.0 - betas
    alphas_bar = torch.cumprod(alphas, dim=0)
    alphas_bar_prev = F.pad(alphas_bar[:-1], (1, 0), value=1.0)
    posterior_variance = betas * (1.0 - alphas_bar_prev) / (1.0 - alphas_bar)
    return Schedule(
        T=T,
        betas=betas,
        alphas=alphas,
        alphas_bar=alphas_bar,
        sqrt_alphas_bar=torch.sqrt(alphas_bar),
        sqrt_one_minus_alphas_bar=torch.sqrt(1.0 - alphas_bar),
        sqrt_recip_alphas=torch.sqrt(1.0 / alphas),
        posterior_variance=posterior_variance,
    )

Two notes worth calling out:
- alphas_bar_prev is $\bar\alpha_{t-1}$. We set $\bar\alpha_0 = 1$ by convention (nothing has been corrupted at $t = 0$), which is the pad(..., value=1.0) trick.
- posterior_variance is $\tilde\beta_t = \beta_t (1 - \bar\alpha_{t-1}) / (1 - \bar\alpha_t)$ from Part 1. This is what we use as $\sigma_t^2$ at sampling time.
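A quick numeric check of the schedule can catch indexing mistakes early. The sketch below recomputes the same quantities in plain Python (no PyTorch, so it runs anywhere). The helper name `linear_schedule` is made up for this illustration; its defaults mirror `make_schedule` above.

```python
def linear_schedule(T=1000, beta_start=1e-4, beta_end=2e-2):
    """Plain-Python mirror of make_schedule (hypothetical helper, for checking only)."""
    # Evenly spaced betas, endpoints included, like torch.linspace.
    betas = [beta_start + (beta_end - beta_start) * i / (T - 1) for i in range(T)]
    alphas_bar, running = [], 1.0
    for b in betas:
        running *= 1.0 - b                      # cumulative product of alpha_t = 1 - beta_t
        alphas_bar.append(running)
    alphas_bar_prev = [1.0] + alphas_bar[:-1]   # the pad(..., value=1.0) trick
    posterior_variance = [
        b * (1.0 - ab_prev) / (1.0 - ab)
        for b, ab_prev, ab in zip(betas, alphas_bar_prev, alphas_bar)
    ]
    return betas, alphas_bar, posterior_variance

betas, alphas_bar, posterior_variance = linear_schedule()

# Signal remaining at the last step: small enough that x_T is nearly pure noise.
print(f"alpha_bar at the last step: {alphas_bar[-1]:.2e}")
# The posterior variance is zero at index 0 and never exceeds beta_t.
print(posterior_variance[0])
print(all(pv <= b for pv, b in zip(posterior_variance, betas)))
```

If the last `alphas_bar` value is not tiny, or the first posterior variance is not zero, the schedule or the shift by one index is likely off.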
The Forward Process (Closed Form)
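From Part 1, equation (5), we can sample $x_t$ from $q(x_t \mid x_0)$ in one shot:

$$x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)$$

No loop over $t$. This is the single most important efficiency property of diffusion training.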
def gather(buf, t):
    """buf: (T,), t: (B,) -> (B, 1, 1, 1) for broadcasting over (B, C, H, W)."""
    return buf.gather(0, t).view(-1, 1, 1, 1)

def q_sample(x0, t, noise, sched):
    """Sample x_t from q(x_t | x_0) in closed form."""
    sqrt_ab = gather(sched.sqrt_alphas_bar, t)
    sqrt_one_minus_ab = gather(sched.sqrt_one_minus_alphas_bar, t)
    return sqrt_ab * x0 + sqrt_one_minus_ab * noise

gather looks trivial but it is load-bearing: each sample in the batch has its own $t$, so we need per-sample coefficients broadcast over the spatial dimensions.
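To see the closed form at work without a GPU, the sketch below draws many $x_t$ values for a single scalar "pixel" and compares the empirical mean and variance with $\sqrt{\bar\alpha_t}\,x_0$ and $1 - \bar\alpha_t$. It is plain Python rather than PyTorch; `alpha_bar` is a hypothetical helper, and the pixel value, timestep, and sample count are arbitrary choices for illustration.

```python
import math
import random

def alpha_bar(t, T=1000, beta_start=1e-4, beta_end=2e-2):
    """Cumulative product of (1 - beta_s) for s = 0..t under the linear schedule."""
    prod = 1.0
    for s in range(t + 1):
        prod *= 1.0 - (beta_start + (beta_end - beta_start) * s / (T - 1))
    return prod

rng = random.Random(0)
x0, t, n = 0.8, 499, 20_000          # one pixel value, a mid-range timestep index
ab = alpha_bar(t)

# Scalar version of q_sample: sqrt(ab) * x0 + sqrt(1 - ab) * noise
draws = [math.sqrt(ab) * x0 + math.sqrt(1.0 - ab) * rng.gauss(0.0, 1.0)
         for _ in range(n)]

mean = sum(draws) / n
var = sum((d - mean) ** 2 for d in draws) / n
print(f"mean {mean:.3f} vs sqrt(alpha_bar) * x0 = {math.sqrt(ab) * x0:.3f}")
print(f"var  {var:.3f} vs 1 - alpha_bar        = {1.0 - ab:.3f}")
```

The two pairs should agree up to sampling error, which is the one-shot property the training loop relies on.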
The Network: A Small UNet with Time Embeddings
The network takes a noisy image and a timestep and outputs a prediction of the noise, with the same shape as the input. A UNet is the canonical choice: encoder-decoder with skip connections, which preserves spatial detail that would otherwise be lost during downsampling.
Sinusoidal Time Embeddings
The timestep $t$ is an integer in $\{0, \dots, T-1\}$, but the network needs a continuous, information-rich representation. We use the same sinusoidal embedding as the Transformer (and for the same reason: it encodes position on many frequency scales at once). This embedding is then projected through an MLP and injected additively into every residual block.
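Before the PyTorch module, here is a dependency-free sketch of the sin/cos features on their own, using the same frequency formula as `TimeEmbedding.forward`. The function names are illustrative, and the chosen timesteps are arbitrary.

```python
import math

def sinusoidal_embedding(t, dim=128):
    """Plain-Python version of the sin/cos features (before the MLP)."""
    half = dim // 2
    freqs = [math.exp(-math.log(10000) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    return [math.sin(a) for a in args] + [math.cos(a) for a in args]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

e100, e101, e600 = (sinusoidal_embedding(t) for t in (100, 101, 600))

# sin^2 + cos^2 = 1 per frequency, so every embedding has squared norm dim / 2.
print(round(dot(e100, e100), 6))
# Neighbouring timesteps are more similar than distant ones.
print(dot(e100, e101) > dot(e100, e600))
```

Every timestep maps to a vector of the same norm, and similarity falls off with distance in $t$, which gives the MLP a smooth signal to work with.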
class TimeEmbedding(nn.Module):
    def __init__(self, dim, hidden):
        super().__init__()
        assert dim % 2 == 0
        self.dim = dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
        )

    def forward(self, t):
        half = self.dim // 2
        freqs = torch.exp(
            -math.log(10000) * torch.arange(half, device=t.device) / half
        )
        args = t[:, None].float() * freqs[None]
        emb = torch.cat([args.sin(), args.cos()], dim=-1)  # (B, dim)
        return self.mlp(emb)                               # (B, hidden)

Residual Block with Time Conditioning
Each residual block does: group-norm, SiLU, conv; inject time embedding (one scalar per channel, added as bias); group-norm, SiLU, conv; plus a skip connection. This is a simplified version of the block in Ho et al.
class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch, t_dim):
        super().__init__()
        self.norm1 = nn.GroupNorm(8, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.t_proj = nn.Linear(t_dim, out_ch)
        self.norm2 = nn.GroupNorm(8, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, t_emb):
        h = self.conv1(F.silu(self.norm1(x)))
        h = h + self.t_proj(F.silu(t_emb))[:, :, None, None]
        h = self.conv2(F.silu(self.norm2(h)))
        return h + self.skip(x)

The UNet
Three resolutions: $32 \to 16 \to 8$. Each level has one residual block going down, one at the bottleneck, and one going up, with skip connections concatenating the encoder features into the decoder. Keep it small. This is enough for MNIST and will train in minutes on a modern GPU.
class UNet(nn.Module):
    def __init__(self, in_ch=1, base=64, t_dim=128):
        super().__init__()
        self.time = TimeEmbedding(dim=t_dim, hidden=t_dim * 4)
        t_out = t_dim * 4
        self.in_conv = nn.Conv2d(in_ch, base, 3, padding=1)
        # Encoder
        self.down1 = ResBlock(base, base, t_out)
        self.pool1 = nn.Conv2d(base, base * 2, 3, stride=2, padding=1)
        self.down2 = ResBlock(base * 2, base * 2, t_out)
        self.pool2 = nn.Conv2d(base * 2, base * 4, 3, stride=2, padding=1)
        # Bottleneck
        self.mid = ResBlock(base * 4, base * 4, t_out)
        # Decoder (skip-connect by concatenation, so input channels double)
        self.up2 = nn.ConvTranspose2d(base * 4, base * 2, 4, stride=2, padding=1)
        self.dec2 = ResBlock(base * 4, base * 2, t_out)
        self.up1 = nn.ConvTranspose2d(base * 2, base, 4, stride=2, padding=1)
        self.dec1 = ResBlock(base * 2, base, t_out)
        self.out_norm = nn.GroupNorm(8, base)
        self.out_conv = nn.Conv2d(base, in_ch, 3, padding=1)

    def forward(self, x, t):
        t_emb = self.time(t)
        h0 = self.in_conv(x)
        h1 = self.down1(h0, t_emb)                 # (B, base, 32, 32)
        h2 = self.down2(self.pool1(h1), t_emb)     # (B, 2*base, 16, 16)
        hb = self.mid(self.pool2(h2), t_emb)       # (B, 4*base, 8, 8)
        u2 = self.up2(hb)                          # (B, 2*base, 16, 16)
        u2 = self.dec2(torch.cat([u2, h2], dim=1), t_emb)
        u1 = self.up1(u2)                          # (B, base, 32, 32)
        u1 = self.dec1(torch.cat([u1, h1], dim=1), t_emb)
        return self.out_conv(F.silu(self.out_norm(u1)))

A real implementation (Ho et al., OpenAI guided-diffusion, Stable Diffusion) adds self-attention at low resolutions, multiple residual blocks per level, and more channels. For the purposes of this post, the architecture above is intentionally stripped down.
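The resolution bookkeeping in the UNet above can be checked with the standard output-size formulas for strided and transposed convolutions (dilation 1, no output padding). The sketch below is plain Python with hypothetical helper names; its defaults match the kernel, stride, and padding used in `pool1`/`pool2` and `up1`/`up2`.

```python
def conv_out(size, kernel=3, stride=2, padding=1):
    """Output size of a strided convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size, kernel=4, stride=2, padding=1):
    """Output size of a transposed convolution along one spatial dimension."""
    return (size - 1) * stride - 2 * padding + kernel

def round_trip(size, levels):
    """Downsample `levels` times, then upsample back; return (encoder, decoder) sizes."""
    enc = [size]
    for _ in range(levels):
        enc.append(conv_out(enc[-1]))
    dec = [enc[-1]]
    for _ in range(levels):
        dec.append(conv_transpose_out(dec[-1]))
    return enc, dec

print(round_trip(32, 2))   # ([32, 16, 8], [8, 16, 32])
print(round_trip(28, 3))   # ([28, 14, 7, 4], [4, 8, 16, 32])
```

With 32 pixels the decoder sizes mirror the encoder sizes, so the skip concatenations line up. Native 28-pixel MNIST would still round-trip through two levels (28, 14, 7), but a third level produces 8 against 7 at the first skip connection, which is one reason to prefer a power of two if you plan to deepen the network.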
Training
The training step is the one you came here for. Sample $t$ and $\epsilon$, form $x_t$ in closed form, predict the noise, compute MSE. Four lines.
def train_step(model, x0, sched, optimizer):
    B = x0.size(0)
    t = torch.randint(0, sched.T, (B,), device=x0.device)
    noise = torch.randn_like(x0)
    x_t = q_sample(x0, t, noise, sched)
    noise_pred = model(x_t, t)
    loss = F.mse_loss(noise_pred, noise)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    return loss.item()

Gradient clipping at norm $1.0$ is a cheap safety net: occasional outlier batches (very small $\bar\alpha_t$, where $x_t$ is almost pure noise and the loss surface is flat) can produce huge gradients. Without clipping, a single bad step can destabilize training.
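For intuition about what clipping by norm does, here is a minimal plain-Python sketch of global-norm clipping on a flat list of gradient values. It illustrates the idea (one joint L2 norm over all gradients, then a single rescale) rather than reproducing PyTorch's implementation; the function name and the small `eps` guard are assumptions of this sketch.

```python
import math

def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Rescale all gradients together so their joint L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (total_norm + eps))   # never scales gradients up
    return [g * scale for g in grads], total_norm

clipped, norm_before = clip_by_global_norm([3.0, 4.0])
print(norm_before, clipped)    # norm 5.0 is scaled down to roughly norm 1.0

small, _ = clip_by_global_norm([0.3, 0.4])
print(small)                   # norm 0.5 is under the threshold, left untouched
```

Because every gradient is multiplied by the same factor, the update direction is preserved; only the step length is capped.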
EMA of Weights
One practical detail that matters a lot for sample quality: keep an exponential moving average (EMA) of the model weights, and sample from the EMA copy instead of the live model. Diffusion losses are noisy (different $t$ and $\epsilon$ every step), so the live weights oscillate; the EMA smooths this out. Typical decay is $0.9999$.
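It helps to know how long a memory a given decay implies. The scalar sketch below tracks one "weight" whose live value sits at a constant target while the EMA starts from its initial value. The step count assumes the defaults used later in this post (30 epochs, batch size 128, `drop_last=True`, 60,000 MNIST training images); `ema_trace` is a hypothetical helper.

```python
def ema_trace(decay, steps, init=0.0, target=1.0):
    """Scalar EMA: shadow starts at `init` while the live value stays at `target`."""
    shadow = init
    for _ in range(steps):
        shadow = decay * shadow + (1.0 - decay) * target
    return shadow

decay = 0.9999
steps = 30 * (60_000 // 128)     # optimizer steps in the default run
shadow = ema_trace(decay, steps)

# The weight still carried by the initial value after `steps` updates is decay ** steps.
print(steps, round(decay ** steps, 3), round(1.0 - shadow, 3))
```

Under these assumptions, roughly a quarter of the EMA still reflects the starting point at the end of the run. For short runs a smaller decay (or a warm-up on the decay) may be worth trying; treat that as a suggestion to experiment with, not a measured result.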
class EMA:
    def __init__(self, model, decay=0.9999):
        self.decay = decay
        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}

    @torch.no_grad()
    def update(self, model):
        for k, v in model.state_dict().items():
            if v.dtype.is_floating_point:
                self.shadow[k].mul_(self.decay).add_(v.detach(), alpha=1.0 - self.decay)
            else:
                self.shadow[k].copy_(v)

    def copy_to(self, model):
        model.load_state_dict(self.shadow)

The Full Training Loop
def train(epochs=30, batch_size=128, lr=2e-4, T=1000):
    loader = get_loader(batch_size)
    sched = make_schedule(T=T, device=device)
    model = UNet(in_ch=1, base=64).to(device)
    ema = EMA(model, decay=0.9999)
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    step = 0
    for epoch in range(epochs):
        for x0, _ in loader:
            x0 = x0.to(device, non_blocking=True)
            loss = train_step(model, x0, sched, opt)
            ema.update(model)
            step += 1
            if step % 200 == 0:
                print(f"epoch {epoch} step {step}: loss {loss:.4f}")
    return model, ema, sched

On an RTX 4090, 30 epochs of MNIST takes roughly 15 to 20 minutes and produces legible digits. You should see the training loss drop quickly within the first epoch and then crawl down slowly. Do not expect it to go to zero: the loss is an expectation over all noise levels, and part of it is irreducible, because the network only sees $x_t$ and cannot recover the exact noise draw from it.
Sampling
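Sampling is where the iterative structure comes back. We cannot skip timesteps the way we can in training: to sample $x_{t-1}$ we need $x_t$. Algorithm 2 from DDPM:

Starting from $x_T \sim \mathcal{N}(0, I)$, for $t = T, \dots, 1$:

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}}\, \epsilon_\theta(x_t, t) \right) + \sigma_t z$$

where $z \sim \mathcal{N}(0, I)$ for $t > 1$ and $z = 0$ at the last step. We use $\sigma_t^2 = \tilde\beta_t$ (the posterior variance we precomputed).

The formula is exactly the mean $\mu_\theta(x_t, t)$ from Part 1, Step 4 of the noise-prediction parameterization, plus a Gaussian noise injection.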
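One way to convince yourself the update is right before involving a network: if every data point equals a constant $c$, the ideal noise prediction is known in closed form, because $x_t = \sqrt{\bar\alpha_t}\,c + \sqrt{1-\bar\alpha_t}\,\epsilon$ can be solved for $\epsilon$. The plain-Python sketch below runs the sampler on one scalar with that oracle standing in for the model. The function name and the toy setup are assumptions made for this check.

```python
import math
import random

def ddpm_sample_point_mass(c, T=1000, beta_start=1e-4, beta_end=2e-2, seed=0):
    """Run the DDPM update on one scalar, with an exact noise oracle for data == c."""
    rng = random.Random(seed)
    betas = [beta_start + (beta_end - beta_start) * i / (T - 1) for i in range(T)]
    alphas_bar, running = [], 1.0
    for b in betas:
        running *= 1.0 - b
        alphas_bar.append(running)

    x = rng.gauss(0.0, 1.0)                        # start from x_T ~ N(0, 1)
    for t in reversed(range(T)):
        ab = alphas_bar[t]
        ab_prev = alphas_bar[t - 1] if t > 0 else 1.0
        # Oracle noise prediction (replaces model(x, t) in the real sampler).
        eps = (x - math.sqrt(ab) * c) / math.sqrt(1.0 - ab)
        mean = (x - betas[t] / math.sqrt(1.0 - ab) * eps) / math.sqrt(1.0 - betas[t])
        if t > 0:
            var = betas[t] * (1.0 - ab_prev) / (1.0 - ab)   # posterior variance
            x = mean + math.sqrt(var) * rng.gauss(0.0, 1.0)
        else:
            x = mean                               # no noise at the last step
    return x

for c in (-0.7, 0.0, 0.5):
    print(c, abs(ddpm_sample_point_mass(c) - c) < 1e-6)
```

With a perfect noise predictor the chain lands on the data point. Flipping the sign in front of the noise term, or using the wrong index for `ab_prev`, typically breaks this immediately.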
@torch.no_grad()
def sample(model, sched, n=16, img_size=32, channels=1, device=device):
    model.eval()
    x = torch.randn(n, channels, img_size, img_size, device=device)
    for t in reversed(range(sched.T)):
        t_batch = torch.full((n,), t, device=device, dtype=torch.long)
        eps = model(x, t_batch)
        coef = sched.betas[t] / sched.sqrt_one_minus_alphas_bar[t]
        mean = sched.sqrt_recip_alphas[t] * (x - coef * eps)
        if t > 0:
            noise = torch.randn_like(x)
            x = mean + torch.sqrt(sched.posterior_variance[t]) * noise
        else:
            x = mean  # no noise at the last step
    model.train()
    return x.clamp(-1, 1)

To sample from the EMA weights:
@torch.no_grad()
def sample_from_ema(model, ema, sched, **kwargs):
    # swap in EMA weights, sample, swap back
    backup = {k: v.clone() for k, v in model.state_dict().items()}
    ema.copy_to(model)
    imgs = sample(model, sched, **kwargs)
    model.load_state_dict(backup)
    return imgs

Why This Sampler Is Slow
This loop runs the network $T = 1000$ times per batch of samples. That is the fundamental cost of DDPM sampling and the largest practical drawback of vanilla diffusion. Two families of fixes:
- DDIM (Song et al., 2021). Same trained network, but a different reverse process that can skip timesteps: a few dozen network evaluations instead of $T$, with near-identical quality. Derived in Part 2. We implement it below.
- Latent diffusion (Rombach et al., 2022). Do diffusion in a lower-dimensional latent space learned by a VAE. Each step is cheaper and you need fewer of them. Out of scope here.
Sampling Faster: DDIM
The derivation is in Part 2; the code is shorter than the math. Two additions to the file we already have:
- A helper that picks the sub-sampled timestep schedule $\tau$ from Part 2.
- The DDIM reverse step from Part 2, which plugs $\hat{x}_0$ from equation (5) and the $\eta$-parameterized $\sigma_t$ from equation (7) into a single loop.
def make_ddim_timesteps(T: int, num_steps: int) -> list[int]:
    """Ascending sub-sequence of timestep indices 0..T-1 of length ~num_steps+1.
    Includes T-1 as the last index so sampling starts from pure noise."""
    stride = max(1, T // num_steps)
    ts = list(range(0, T, stride))
    if ts[-1] != T - 1:
        ts.append(T - 1)
    return ts

The DDIM sampler itself:
@torch.no_grad()
def sample_ddim(model, sched, n=16, img_size=32, channels=1,
                num_steps=50, eta=0.0, device=device):
    """DDIM sampler (Song, Meng, Ermon 2021).
    num_steps: number of network evaluations (~50 is usually enough).
    eta=0.0: fully deterministic sampler (probability-flow ODE).
    eta=1.0: recovers DDPM on the sub-sampled grid.
    """
    model.eval()
    x = torch.randn(n, channels, img_size, img_size, device=device)
    ts = make_ddim_timesteps(sched.T, num_steps)  # ascending
    for i in reversed(range(len(ts))):
        t = ts[i]
        t_prev = ts[i - 1] if i > 0 else -1
        t_batch = torch.full((n,), t, device=device, dtype=torch.long)
        eps = model(x, t_batch)
        ab_t = sched.alphas_bar[t]
        ab_prev = sched.alphas_bar[t_prev] if t_prev >= 0 else torch.tensor(1.0, device=device)
        # (a) predict the clean image [Part 2 eq. (5)]
        x0_pred = (x - torch.sqrt(1.0 - ab_t) * eps) / torch.sqrt(ab_t)
        x0_pred = x0_pred.clamp(-1, 1)  # stability, optional
        # (b) sigma_t from the eta knob [Part 2 eq. (7)]
        sigma_sq = (eta ** 2) * ((1.0 - ab_prev) / (1.0 - ab_t)) * (1.0 - ab_t / ab_prev)
        sigma = torch.sqrt(sigma_sq.clamp(min=0.0))
        # (c) deterministic drift along the predicted noise [Part 2 eq. (6)]
        dir_coef = torch.sqrt((1.0 - ab_prev - sigma_sq).clamp(min=0.0))
        # (d) stochastic noise injection; zero on the final step
        z = torch.randn_like(x) if i > 0 else torch.zeros_like(x)
        x = torch.sqrt(ab_prev) * x0_pred + dir_coef * eps + sigma * z
    model.train()
    return x.clamp(-1, 1)

Four labeled blocks, each corresponding to one term of Part 2's boxed update equation. That is the entire new sampler.
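It is worth printing the timestep grid once to see exactly which indices the sampler visits. The snippet below repeats `make_ddim_timesteps` so it runs on its own, then lists the `(t, t_prev)` pairs in the order `sample_ddim` walks them.

```python
def make_ddim_timesteps(T, num_steps):
    """Copy of the helper above so this snippet is self-contained."""
    stride = max(1, T // num_steps)
    ts = list(range(0, T, stride))
    if ts[-1] != T - 1:
        ts.append(T - 1)
    return ts

ts = make_ddim_timesteps(1000, 50)
print(len(ts), ts[:3], ts[-3:])

# (t, t_prev) pairs in sampling order; -1 marks the final jump to x_0.
pairs = [(ts[i], ts[i - 1] if i > 0 else -1) for i in reversed(range(len(ts)))]
print(pairs[:2], pairs[-1])
```

With `T=1000` and `num_steps=50` the grid has 51 entries, because `T - 1` is appended to the 50 strided indices. So the network is called 51 times, and the first jump (999 to 980) is one index shorter than the rest.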
Sanity Check: Must Recover DDPM
If we take num_steps = sched.T (so $\tau$ is the full $\{0, \dots, T-1\}$) and set eta = 1.0, DDIM and DDPM should produce statistically identical samples. Worth asserting in a unit test:

# With eta=1 and the full grid, DDIM reduces to DDPM (Part 2 §4.1).
sigma_sq_eta1 = ((1 - ab_prev) / (1 - ab_t)) * (1 - ab_t / ab_prev)
# This should equal sched.posterior_variance[t] (= tilde_beta_t from Part 1)

If that equality does not hold to within floating-point slop, something is wrong with your schedule.
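A runnable version of that check, written in plain Python so it needs no trained model or GPU, might look like the sketch below. The helper name is hypothetical; the schedule values match `make_schedule`'s defaults.

```python
def linear_betas_and_alphas_bar(T=1000, beta_start=1e-4, beta_end=2e-2):
    """Linear schedule and its cumulative products (plain-Python stand-in)."""
    betas = [beta_start + (beta_end - beta_start) * i / (T - 1) for i in range(T)]
    alphas_bar, running = [], 1.0
    for b in betas:
        running *= 1.0 - b
        alphas_bar.append(running)
    return betas, alphas_bar

betas, alphas_bar = linear_betas_and_alphas_bar()

max_gap = 0.0
for t in range(len(betas)):
    ab_t = alphas_bar[t]
    ab_prev = alphas_bar[t - 1] if t > 0 else 1.0
    # DDIM variance with eta = 1 on the full grid.
    sigma_sq_eta1 = ((1 - ab_prev) / (1 - ab_t)) * (1 - ab_t / ab_prev)
    # DDPM posterior variance (tilde_beta_t).
    posterior_var = betas[t] * (1 - ab_prev) / (1 - ab_t)
    max_gap = max(max_gap, abs(sigma_sq_eta1 - posterior_var))

print(max_gap < 1e-12)
```

The variances agree because $\bar\alpha_t / \bar\alpha_{t-1} = \alpha_t$ on the full grid. One caveat when comparing actual samples: `sample_ddim` clamps `x0_pred` to $[-1, 1]$ while `sample` does not, so the two code paths can still differ slightly unless that clamp is disabled for the test.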
EMA Wrapper
To sample from the EMA copy with either sampler:
@torch.no_grad()
def sample_with_ema(model, ema, sched, fn=sample, **kwargs):
    """fn is `sample` (DDPM) or `sample_ddim`."""
    backup = {k: v.clone() for k, v in model.state_dict().items()}
    ema.copy_to(model)
    imgs = fn(model, sched, **kwargs)
    model.load_state_dict(backup)
    return imgs

DDPM vs DDIM: What to Expect
| Sampler | num_steps | Wall-clock in s (RTX 4090, n=64) | Visual quality |
|---|---|---|---|
| sample (DDPM) | $T = 1000$ | | baseline |
| sample_ddim(eta=1.0) | | | near-baseline |
| sample_ddim(eta=0.0) | | | near-baseline, deterministic |
| sample_ddim(eta=0.0) | | | mild loss of detail |
| sample_ddim(eta=0.0) | | | visible artifacts on MNIST |
The speed-up comes from doing fewer network forward passes; the UNet cost per step is unchanged.
Putting It All Together
if __name__ == "__main__":
    torch.manual_seed(0)
    model, ema, sched = train(epochs=30)
    # Slow, stochastic DDPM sampling (T=1000 network calls)
    ddpm_imgs = sample_with_ema(model, ema, sched, fn=sample, n=64)
    # Fast, deterministic DDIM sampling (50 network calls)
    ddim_imgs = sample_with_ema(model, ema, sched, fn=sample_ddim, n=64,
                                num_steps=50, eta=0.0)
    from torchvision.utils import save_image
    save_image((ddpm_imgs + 1) / 2, "samples_ddpm.png", nrow=8)
    save_image((ddim_imgs + 1) / 2, "samples_ddim.png", nrow=8)

After 30 epochs you should get clean, legible MNIST digits from both samplers. If the samples look like noise or mostly uniform gray, the usual suspects are:
- You forgot the [-1, 1] normalization; the model trains against unit-variance Gaussian noise, so unscaled pixels will produce visually dim samples.
- You are sampling from the live model instead of EMA (quality will be visibly worse, especially early in training).
- You are gathering by $t$ but forgot the .view(-1, 1, 1, 1) broadcast shape, so the schedule coefficients broadcast wrong.
- You swapped the sign somewhere: x - coef * eps is correct (we are removing predicted noise), not x + coef * eps.
- For DDIM: a NaN in torch.sqrt(sigma_sq) usually means a negative argument (an out-of-order timestep pair, so that $\bar\alpha_t > \bar\alpha_{t_{\text{prev}}}$) or a schedule bug; clamp to $\geq 0$.
How Each Piece Maps to Parts 1 and 2
| Math | Code |
|---|---|
| $\beta_t$, $\alpha_t$, $\bar\alpha_t$ (Part 1) | Schedule / make_schedule |
| $\tilde\beta_t$ (Part 1) | posterior_variance |
| $q(x_t \mid x_0)$ (Part 1) | q_sample |
| $\epsilon_\theta(x_t, t)$ | UNet.forward |
| $L_{\text{simple}}$ (Part 1) | F.mse_loss(noise_pred, noise) in train_step |
| $\mu_\theta(x_t, t)$ (Part 1) | mean = sqrt_recip_alphas[t] * (x - coef * eps) in sample |
| $x_{t-1} = \mu_\theta + \sigma_t z$ (Part 1) | x = mean + sqrt(posterior_variance[t]) * noise |
| $\hat{x}_0$ (Part 2) | x0_pred = (x - sqrt(1-ab_t) * eps) / sqrt(ab_t) in sample_ddim |
| $\sigma_t^2(\eta)$ (Part 2) | sigma_sq = eta**2 * ((1-ab_prev)/(1-ab_t)) * (1 - ab_t/ab_prev) |
| DDIM update (Part 2) | x = sqrt(ab_prev)*x0_pred + dir_coef*eps + sigma*z |
| Sub-sequence $\tau$ (Part 2) | make_ddim_timesteps |
If any of these rows confuses you, go back to the corresponding section; the translation is line-for-line.
Full Source: A Single Runnable File
Everything above condensed into one file. Save as ddpm_ddim.py, run with python ddpm_ddim.py, and it will train on MNIST and write samples_ddpm.png and samples_ddim.png.
"""
ddpm_ddim.py — Minimal DDPM training + DDPM/DDIM sampling, from scratch.
Companion code for:
Part 1: /blog/diffusion-series-1-math-of-diffusion
Part 2: /blog/diffusion-series-2-ddim
Part 3: /blog/diffusion-series-3-coding-ddpm
Run:
python ddpm_ddim.py
Dependencies: torch, torchvision.
"""
from __future__ import annotations
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
device = "cuda" if torch.cuda.is_available() else "cpu"
# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
def get_loader(batch_size: int = 128) -> DataLoader:
    transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),                 # [0, 1]
        transforms.Normalize((0.5,), (0.5,)),  # [-1, 1]
    ])
    dataset = datasets.MNIST("./data", train=True, download=True, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True,
                      num_workers=2, drop_last=True, pin_memory=True)
# ---------------------------------------------------------------------------
# Schedule
# ---------------------------------------------------------------------------
@dataclass
class Schedule:
    T: int
    betas: torch.Tensor
    alphas: torch.Tensor
    alphas_bar: torch.Tensor
    sqrt_alphas_bar: torch.Tensor
    sqrt_one_minus_alphas_bar: torch.Tensor
    sqrt_recip_alphas: torch.Tensor
    posterior_variance: torch.Tensor

def make_schedule(T: int = 1000, beta_start: float = 1e-4,
                  beta_end: float = 2e-2, device: str = "cpu") -> Schedule:
    betas = torch.linspace(beta_start, beta_end, T, device=device)
    alphas = 1.0 - betas
    alphas_bar = torch.cumprod(alphas, dim=0)
    alphas_bar_prev = F.pad(alphas_bar[:-1], (1, 0), value=1.0)
    posterior_variance = betas * (1.0 - alphas_bar_prev) / (1.0 - alphas_bar)
    return Schedule(
        T=T,
        betas=betas,
        alphas=alphas,
        alphas_bar=alphas_bar,
        sqrt_alphas_bar=torch.sqrt(alphas_bar),
        sqrt_one_minus_alphas_bar=torch.sqrt(1.0 - alphas_bar),
        sqrt_recip_alphas=torch.sqrt(1.0 / alphas),
        posterior_variance=posterior_variance,
    )

def gather(buf: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    return buf.gather(0, t).view(-1, 1, 1, 1)

def q_sample(x0: torch.Tensor, t: torch.Tensor,
             noise: torch.Tensor, sched: Schedule) -> torch.Tensor:
    """x_t from q(x_t | x_0) in closed form. [Part 1 eq. (41)]"""
    return gather(sched.sqrt_alphas_bar, t) * x0 + \
        gather(sched.sqrt_one_minus_alphas_bar, t) * noise
# ---------------------------------------------------------------------------
# UNet
# ---------------------------------------------------------------------------
class TimeEmbedding(nn.Module):
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        assert dim % 2 == 0
        self.dim = dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
        )

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        half = self.dim // 2
        freqs = torch.exp(
            -math.log(10000) * torch.arange(half, device=t.device) / half
        )
        args = t[:, None].float() * freqs[None]
        emb = torch.cat([args.sin(), args.cos()], dim=-1)
        return self.mlp(emb)

class ResBlock(nn.Module):
    def __init__(self, in_ch: int, out_ch: int, t_dim: int):
        super().__init__()
        self.norm1 = nn.GroupNorm(8, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.t_proj = nn.Linear(t_dim, out_ch)
        self.norm2 = nn.GroupNorm(8, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        h = self.conv1(F.silu(self.norm1(x)))
        h = h + self.t_proj(F.silu(t_emb))[:, :, None, None]
        h = self.conv2(F.silu(self.norm2(h)))
        return h + self.skip(x)

class UNet(nn.Module):
    def __init__(self, in_ch: int = 1, base: int = 64, t_dim: int = 128):
        super().__init__()
        self.time = TimeEmbedding(dim=t_dim, hidden=t_dim * 4)
        t_out = t_dim * 4
        self.in_conv = nn.Conv2d(in_ch, base, 3, padding=1)
        self.down1 = ResBlock(base, base, t_out)
        self.pool1 = nn.Conv2d(base, base * 2, 3, stride=2, padding=1)
        self.down2 = ResBlock(base * 2, base * 2, t_out)
        self.pool2 = nn.Conv2d(base * 2, base * 4, 3, stride=2, padding=1)
        self.mid = ResBlock(base * 4, base * 4, t_out)
        self.up2 = nn.ConvTranspose2d(base * 4, base * 2, 4, stride=2, padding=1)
        self.dec2 = ResBlock(base * 4, base * 2, t_out)
        self.up1 = nn.ConvTranspose2d(base * 2, base, 4, stride=2, padding=1)
        self.dec1 = ResBlock(base * 2, base, t_out)
        self.out_norm = nn.GroupNorm(8, base)
        self.out_conv = nn.Conv2d(base, in_ch, 3, padding=1)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        t_emb = self.time(t)
        h0 = self.in_conv(x)
        h1 = self.down1(h0, t_emb)
        h2 = self.down2(self.pool1(h1), t_emb)
        hb = self.mid(self.pool2(h2), t_emb)
        u2 = self.up2(hb)
        u2 = self.dec2(torch.cat([u2, h2], dim=1), t_emb)
        u1 = self.up1(u2)
        u1 = self.dec1(torch.cat([u1, h1], dim=1), t_emb)
        return self.out_conv(F.silu(self.out_norm(u1)))
# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
def train_step(model: UNet, x0: torch.Tensor, sched: Schedule,
               optimizer: torch.optim.Optimizer) -> float:
    B = x0.size(0)
    t = torch.randint(0, sched.T, (B,), device=x0.device)
    noise = torch.randn_like(x0)
    x_t = q_sample(x0, t, noise, sched)
    noise_pred = model(x_t, t)
    loss = F.mse_loss(noise_pred, noise)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    return loss.item()

class EMA:
    def __init__(self, model: nn.Module, decay: float = 0.9999):
        self.decay = decay
        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}

    @torch.no_grad()
    def update(self, model: nn.Module) -> None:
        for k, v in model.state_dict().items():
            if v.dtype.is_floating_point:
                self.shadow[k].mul_(self.decay).add_(v.detach(), alpha=1.0 - self.decay)
            else:
                self.shadow[k].copy_(v)

    def copy_to(self, model: nn.Module) -> None:
        model.load_state_dict(self.shadow)

def train(epochs: int = 30, batch_size: int = 128, lr: float = 2e-4,
          T: int = 1000) -> tuple[UNet, EMA, Schedule]:
    loader = get_loader(batch_size)
    sched = make_schedule(T=T, device=device)
    model = UNet(in_ch=1, base=64).to(device)
    ema = EMA(model, decay=0.9999)
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    step = 0
    for epoch in range(epochs):
        for x0, _ in loader:
            x0 = x0.to(device, non_blocking=True)
            loss = train_step(model, x0, sched, opt)
            ema.update(model)
            step += 1
            if step % 200 == 0:
                print(f"epoch {epoch} step {step}: loss {loss:.4f}")
    return model, ema, sched
# ---------------------------------------------------------------------------
# Samplers
# ---------------------------------------------------------------------------
@torch.no_grad()
def sample(model: UNet, sched: Schedule, n: int = 16,
           img_size: int = 32, channels: int = 1) -> torch.Tensor:
    """DDPM ancestral sampler (Ho et al. 2020, Algorithm 2). [Part 1 eq. (51)]"""
    model.eval()
    x = torch.randn(n, channels, img_size, img_size, device=device)
    for t in reversed(range(sched.T)):
        t_batch = torch.full((n,), t, device=device, dtype=torch.long)
        eps = model(x, t_batch)
        coef = sched.betas[t] / sched.sqrt_one_minus_alphas_bar[t]
        mean = sched.sqrt_recip_alphas[t] * (x - coef * eps)
        if t > 0:
            noise = torch.randn_like(x)
            x = mean + torch.sqrt(sched.posterior_variance[t]) * noise
        else:
            x = mean
    model.train()
    return x.clamp(-1, 1)

def make_ddim_timesteps(T: int, num_steps: int) -> list[int]:
    """Ascending sub-sequence of timestep indices 0..T-1."""
    stride = max(1, T // num_steps)
    ts = list(range(0, T, stride))
    if ts[-1] != T - 1:
        ts.append(T - 1)
    return ts

@torch.no_grad()
def sample_ddim(model: UNet, sched: Schedule, n: int = 16,
                img_size: int = 32, channels: int = 1,
                num_steps: int = 50, eta: float = 0.0) -> torch.Tensor:
    """DDIM sampler (Song, Meng, Ermon 2021). [Part 2 eq. (6), (7), (11)]
    eta=0.0 is deterministic (probability-flow ODE); eta=1.0 recovers DDPM
    on the sub-sampled grid."""
    model.eval()
    x = torch.randn(n, channels, img_size, img_size, device=device)
    ts = make_ddim_timesteps(sched.T, num_steps)
    for i in reversed(range(len(ts))):
        t = ts[i]
        t_prev = ts[i - 1] if i > 0 else -1
        t_batch = torch.full((n,), t, device=device, dtype=torch.long)
        eps = model(x, t_batch)
        ab_t = sched.alphas_bar[t]
        ab_prev = sched.alphas_bar[t_prev] if t_prev >= 0 else torch.tensor(1.0, device=device)
        x0_pred = ((x - torch.sqrt(1.0 - ab_t) * eps) / torch.sqrt(ab_t)).clamp(-1, 1)
        sigma_sq = (eta ** 2) * ((1.0 - ab_prev) / (1.0 - ab_t)) * (1.0 - ab_t / ab_prev)
        sigma = torch.sqrt(sigma_sq.clamp(min=0.0))
        dir_coef = torch.sqrt((1.0 - ab_prev - sigma_sq).clamp(min=0.0))
        z = torch.randn_like(x) if i > 0 else torch.zeros_like(x)
        x = torch.sqrt(ab_prev) * x0_pred + dir_coef * eps + sigma * z
    model.train()
    return x.clamp(-1, 1)

@torch.no_grad()
def sample_with_ema(model: UNet, ema: EMA, sched: Schedule,
                    fn=sample, **kwargs) -> torch.Tensor:
    backup = {k: v.clone() for k, v in model.state_dict().items()}
    ema.copy_to(model)
    imgs = fn(model, sched, **kwargs)
    model.load_state_dict(backup)
    return imgs
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    model, ema, sched = train(epochs=30)
    ddpm_imgs = sample_with_ema(model, ema, sched, fn=sample, n=64)
    ddim_imgs = sample_with_ema(model, ema, sched, fn=sample_ddim, n=64,
                                num_steps=50, eta=0.0)
    save_image((ddpm_imgs + 1) / 2, "samples_ddpm.png", nrow=8)
    save_image((ddim_imgs + 1) / 2, "samples_ddim.png", nrow=8)
    print("Wrote samples_ddpm.png and samples_ddim.png")

The file is only a few hundred lines. The UNet accounts for roughly a quarter of it; the DDPM side (schedule, forward process, training step, sampler) is well under a hundred lines of substantive code, and DDIM adds a few dozen more.
What's Next
With this scaffolding, natural next steps are:
- Better schedules. The cosine schedule of Nichol and Dhariwal (2021) keeps more signal at high $t$ and trains faster.
- Classifier-free guidance. The one trick that made text-to-image diffusion actually work. A tiny change to training (drop the conditioning with 10% probability) and sampling (combine conditional and unconditional predictions).
- Higher-order solvers. Heun, PLMS, and DPM-Solver push the step count down further without quality loss, using the same trained network.
- DDIM inversion and interpolation. Because DDIM at $\eta = 0$ is an invertible map, you can encode a real image to a latent and slerp between latents to get semantic interpolations. See Part 2 §8.
- Scaling up. CIFAR-10 or CelebA: increase UNet channels, add self-attention at a low resolution, train longer. Code changes are modest; training budget is not.
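As a taste of the first item, here is a plain-Python sketch that compares $\bar\alpha_t$ under the linear schedule used in this post with the cosine form from Nichol and Dhariwal, $\bar\alpha_t = f(t)/f(0)$ with $f(t) = \cos^2\!\big(\tfrac{t/T + s}{1 + s} \cdot \tfrac{\pi}{2}\big)$ and offset $s = 0.008$. The function names are illustrative, and this only computes the curve; a usable schedule would also derive $\beta_t$ from consecutive ratios and clip it below 1.

```python
import math

def cosine_alphas_bar(T=1000, s=0.008):
    """alpha_bar_t = f(t) / f(0) for t = 1..T, cosine schedule."""
    def f(t):
        return math.cos((t / T + s) / (1 + s) * math.pi / 2) ** 2
    return [f(t) / f(0) for t in range(1, T + 1)]

def linear_alphas_bar(T=1000, beta_start=1e-4, beta_end=2e-2):
    """Cumulative product of (1 - beta_t) for the linear schedule."""
    out, running = [], 1.0
    for i in range(T):
        running *= 1.0 - (beta_start + (beta_end - beta_start) * i / (T - 1))
        out.append(running)
    return out

cos_ab, lin_ab = cosine_alphas_bar(), linear_alphas_bar()
for t in (250, 500, 750):
    print(t, f"cosine {cos_ab[t - 1]:.3f}", f"linear {lin_ab[t - 1]:.3f}")
```

At each of these timesteps the cosine curve retains noticeably more signal than the linear one, so fewer steps are spent on inputs that are already almost pure noise.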
The fundamentals do not change. Every diffusion model you will encounter, from Stable Diffusion to video diffusion to molecular generation, is a variation on this series: an ELBO that reduces to a noise-prediction MSE, a UNet (or Transformer) that minimizes it, and a reverse-process sampler with a stochastic/deterministic knob.
Written by Suchinthaka Wanninayaka
AI/ML Researcher exploring semantic communications, diffusion models, and language model systems. Writing about deep learning from theory to production.