K-Diffusion

0.1.1.post1 · active · verified Thu Apr 16

K-Diffusion is a PyTorch library implementing the diffusion model improvements from Karras et al. (2022), "Elucidating the Design Space of Diffusion-Based Generative Models". It provides a collection of samplers (e.g., Euler, Heun, and several DPM-Solver++ variants) and wrapper classes that adapt external denoisers, such as CompVis/Stable Diffusion UNets, to its interface. The current version is 0.1.1.post1; releases are community-driven and focus primarily on stability and integration with other generative AI projects.

Common errors

Warnings

Install
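The package is distributed on PyPI under the name `k-diffusion`; a typical install, assuming you already have a working PyTorch environment:

```shell
pip install k-diffusion
```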

Imports

Quickstart

This quickstart demonstrates how to set up a dummy UNet model, wrap it with `k_diffusion.external.CompVisDenoiser` to conform to the library's API, and run a basic sampling loop with `sample_dpmpp_2m`. Note that `CompVisDenoiser` expects the wrapped model to expose an `alphas_cumprod` buffer and an `apply_model(x, t, ...)` method that predicts noise (epsilon), as CompVis/Stable Diffusion UNets do. In a real application, the `DummyUNet` would be replaced by your actual pre-trained model (e.g., a Stable Diffusion UNet).

import torch
from k_diffusion import sampling, external

# 1. Define a dummy UNet-like model (replace with your actual pre-trained UNet)
# CompVisDenoiser requires the wrapped model to provide an `alphas_cumprod`
# buffer (used to derive its sigma schedule) and an `apply_model` method that
# returns an epsilon (noise) prediction.
class DummyUNet(torch.nn.Module):
    def __init__(self, in_channels=4, out_channels=4):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # Stable Diffusion's linear beta schedule (1000 steps).
        betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000) ** 2
        self.register_buffer('alphas_cumprod', torch.cumprod(1.0 - betas, dim=0))

    def apply_model(self, x, t, cond=None):
        # A real UNet would condition on the timestep `t` and `cond` here.
        return self.conv(x)

# Instantiate the dummy UNet
inner_model = DummyUNet()

# 2. Wrap the UNet with k-diffusion's external denoiser (e.g., for Stable Diffusion latents)
# This wrapper adapts the UNet's API to k-diffusion's expected (x, sigma) signature.
model_wrap = external.CompVisDenoiser(inner_model)
model_wrap.eval().cpu() # Set to eval mode and move to CPU for quickstart simplicity

# 3. Define the sampling schedule and prepare the initial noisy latents
batch_size = 1
channels = 4 # Stable Diffusion's latent space uses 4 channels
height, width = 64, 64 # Latent resolution (e.g., 512x512 image -> 64x64 latent)
# get_sigmas_karras returns n + 1 noise levels, descending from sigma_max to 0.
sigmas = sampling.get_sigmas_karras(n=40, sigma_min=0.1, sigma_max=8.0, device='cpu')
# Initial latents are pure noise scaled by the largest sigma in the schedule.
initial_noise = torch.randn(batch_size, channels, height, width, device='cpu') * sigmas[0]

# 4. Run the sampling process using a DPM++ 2M sampler
# The sampler takes the wrapped model, initial noise, and the sigma schedule.
with torch.no_grad():
    print("Starting K-Diffusion sampling (DPM++ 2M)...")
    denoised_latents = sampling.sample_dpmpp_2m(
        model_wrap,           # The wrapped model callable
        initial_noise,        # Initial noisy latents
        sigmas                # Sigma schedule
        # Optional: `extra_args` can pass conditioning, e.g., {'cond': text_embeddings}
    )
    print(f"Sampling complete. Denoised latents shape: {denoised_latents.shape}")
    # In a real pipeline, `denoised_latents` would then be decoded to an image.
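The Karras schedule used in step 3 spaces noise levels by interpolating linearly in sigma^(1/rho) space and raising back to the rho-th power. A minimal pure-Python sketch of that formula (rho=7 matches the library's default, and the trailing zero mirrors `get_sigmas_karras`'s convention of appending a final sigma of 0):

```python
def karras_sigmas(n, sigma_min=0.1, sigma_max=8.0, rho=7.0):
    """Noise levels descending from sigma_max to sigma_min, plus a final 0."""
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    ramp = [i / (n - 1) for i in range(n)]
    return [(max_inv_rho + t * (min_inv_rho - max_inv_rho)) ** rho for t in ramp] + [0.0]

sigmas = karras_sigmas(40)  # 41 values: 40 noise levels plus the terminal 0
```

Larger rho concentrates more steps at low noise levels, which is where fine image detail is resolved; rho=7 was found to be a good default in the Karras et al. (2022) paper.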
