PyTorch Wavelets (DTCWT)

Version 1.3.0 | verified Fri May 01 | auth: no | language: Python | status: maintenance

A port of the Dual-Tree Complex Wavelet Transform (DTCWT) toolbox to PyTorch, enabling wavelet-based image processing, denoising, and feature extraction with GPU support. The package also provides the standard 2D discrete wavelet transform (DWTForward, DWTInverse). The current version is 1.3.0 (released 2021); there have been no releases since June 2021.

pip install pytorch-wavelets
error ModuleNotFoundError: No module named 'pytorch_wavelets'
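cause Package not installed in the active Python environment. The pip distribution name uses a hyphen (pytorch-wavelets), while the import name uses an underscore (pytorch_wavelets).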
fix
pip install pytorch-wavelets
error AttributeError: module 'pytorch_wavelets' has no attribute 'DWTForward'
cause Incorrect import path, an outdated version, or a local file or folder named pytorch_wavelets that shadows the installed package.
fix
Import with 'from pytorch_wavelets import DWTForward' and upgrade with pip install --upgrade pytorch-wavelets.
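A minimal diagnostic sketch, using only the standard library (importlib.util.find_spec and importlib.metadata, Python 3.8+), can separate the two errors above: it reports whether the pytorch-wavelets distribution is installed and which file the pytorch_wavelets module would be loaded from. An origin that points into your own project directory suggests a local file is shadowing the package. The helper name check_pytorch_wavelets is hypothetical.

```python
import importlib.util
from importlib import metadata


def check_pytorch_wavelets():
    """Hypothetical helper: report the installed distribution version and module location."""
    try:
        # The pip distribution name uses a hyphen
        version = metadata.version("pytorch-wavelets")
    except metadata.PackageNotFoundError:
        version = None
    # The import name uses an underscore; find_spec locates the module without importing it
    spec = importlib.util.find_spec("pytorch_wavelets")
    origin = spec.origin if spec is not None else None
    return {"version": version, "origin": origin}


print(check_pytorch_wavelets())
```

If version is None, the package is missing from this environment; if origin is not inside site-packages, check for a shadowing file.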
error RuntimeError: Input tensor must be of shape (B, C, H, W) and H, W must be multiples of 2^J
cause Input dimensions not compatible with the number of wavelet decomposition levels J.
fix
Pad or resize the input so that H and W are multiples of 2**J.
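If this error appears, one option is to pad the input before the forward transform and crop the reconstruction back afterwards. The sketch below assumes reflect padding applied to the bottom and right edges; pad_to_multiple is a hypothetical helper, not part of pytorch_wavelets. Reflect padding requires each pad amount to be smaller than the corresponding input dimension, so very small inputs may need mode='replicate' or mode='constant' instead.

```python
import torch
import torch.nn.functional as F


def pad_to_multiple(x: torch.Tensor, J: int, mode: str = "reflect"):
    """Hypothetical helper: pad H and W of a (B, C, H, W) tensor up to multiples of 2**J.

    Returns the padded tensor and the original (H, W) so the output can be cropped back.
    """
    m = 2 ** J
    h, w = x.shape[-2:]
    pad_h = (-h) % m  # rows needed to reach the next multiple of m
    pad_w = (-w) % m  # columns needed to reach the next multiple of m
    # F.pad takes (left, right, top, bottom) for the last two dimensions
    x_padded = F.pad(x, (0, pad_w, 0, pad_h), mode=mode)
    return x_padded, (h, w)


x = torch.randn(2, 3, 60, 45)
xp, (h, w) = pad_to_multiple(x, J=3)
print(xp.shape)  # torch.Size([2, 3, 64, 48])

# After the forward/inverse transforms, crop back to the original size
cropped = xp[..., :h, :w]
```

Because padding is only added at the bottom and right, cropping with [..., :h, :w] recovers the original region exactly.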
breaking Wavelet names such as 'db1' are resolved through PyWavelets (pywt), which this package uses for its wavelet filter banks. If PyWavelets is missing, the wavelet lookup fails.
fix Install PyWavelets: pip install PyWavelets. Note that 'haar' is also looked up through PyWavelets, so switching to it is not a workaround for a missing PyWavelets install.
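Because string wavelet names go through PyWavelets, you can check them with pywt directly before building a transform. This sketch uses only PyWavelets (pywt.Wavelet, pywt.wavelist) and shows that 'db1' and 'haar' name the same two-tap filters. DWTForward's docstring also lists a pywt.Wavelet object as an accepted value for wave; treat that as something to confirm for your installed version.

```python
import numpy as np
import pywt

db1 = pywt.Wavelet("db1")
haar = pywt.Wavelet("haar")

# Both names resolve to the same two-tap Haar analysis filters
same_filters = np.allclose(db1.dec_lo, haar.dec_lo) and np.allclose(db1.dec_hi, haar.dec_hi)
print(same_filters)  # True
print(db1.dec_len)   # 2

# Discrete wavelet names PyWavelets recognizes (valid string values for wave)
print("db1" in pywt.wavelist(kind="discrete"))  # True
```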
gotcha The DTCWT modules (DTCWTForward, DTCWTInverse) require the 'dtcwt' backend, which is not always installed by default and uses scipy for some operations.
fix Install scipy: pip install scipy.
deprecated The 'require_grad' argument in DWTForward/DTCWTForward is deprecated and may be removed. The default behavior already supports gradients.
fix Remove require_grad=True from calls. Gradients propagate through the transforms via autograd whenever the input tensor has requires_grad=True.
gotcha The output of DWTForward is a tuple (yl, yh). yh is a list of highpass subbands per level. Many users mistakenly treat yh as a single tensor.
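fix Access the highpass bands for level j as yh[j], where j=0 is the finest scale. Each yh[j] is a tensor of shape (B, C, 3, H_j, W_j), where dimension 2 stacks the three orientation subbands.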

Basic DWT forward and inverse on a random image batch

import torch
from pytorch_wavelets import DWTForward, DWTInverse

# Create a random image batch (B, C, H, W)
x = torch.randn(2, 3, 64, 64)

# Forward DWT with wavelet 'db1' (Haar) and 3 levels
xfm = DWTForward(J=3, wave='db1', mode='zero')
yl, yh = xfm(x)  # yl: lowpass, yh: list of highpass components

# Inverse DWT
ifm = DWTInverse(wave='db1', mode='zero')
recon = ifm((yl, yh))

# Maximum absolute reconstruction error; expect a value near 0 (float32 round-off)
print(torch.abs(x - recon).max().item())