FFT Conv PyTorch

Version 1.2.0 | verified Mon Apr 27 | auth: no | language: Python

Implementation of 1D, 2D, and 3D FFT convolutions in PyTorch. The current version, 1.2.0, supports padding='same' and half-precision input. Releases are irregular; the latest updates date from 2023.
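FFT convolution rests on the convolution theorem: convolving in the signal domain is a pointwise product in the frequency domain. The following is a minimal sketch of that idea for the 1D, stride-1, 'valid' case. It is not fft-conv-pytorch's implementation, and the function name fft_conv1d_valid is hypothetical. Because torch.nn.functional.conv1d actually computes cross-correlation, the kernel is flipped before the theorem is applied.

```python
import torch
import torch.nn.functional as F


def fft_conv1d_valid(signal: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
    """'valid' 1D cross-correlation (what F.conv1d computes) via FFTs.

    signal: (batch, in_channels, length); kernel: (out_channels, in_channels, k).
    """
    length, k = signal.shape[-1], kernel.shape[-1]
    # Zero-pad to the full linear-convolution length so the FFT's
    # circular convolution does not wrap around.
    n = length + k - 1
    sig_f = torch.fft.rfft(signal, n=n)                           # (B, C_in, n//2 + 1)
    # Flip the kernel: F.conv1d is cross-correlation, the theorem is for convolution.
    ker_f = torch.fft.rfft(torch.flip(kernel, dims=(-1,)), n=n)   # (C_out, C_in, n//2 + 1)
    # Multiply per frequency, then sum over input channels.
    out_f = (sig_f.unsqueeze(1) * ker_f.unsqueeze(0)).sum(dim=2)  # (B, C_out, n//2 + 1)
    full = torch.fft.irfft(out_f, n=n)                            # full convolution, length n
    # Keep only positions where the kernel fully overlaps the signal.
    return full[..., k - 1 : length]


print(fft_conv1d_valid(torch.randn(2, 3, 64), torch.randn(4, 3, 5)).shape)  # torch.Size([2, 4, 60])
```

The result should match F.conv1d on the same inputs up to float32 rounding. The payoff is asymptotic. A direct sliding-window sum costs O(L x K) per channel pair, while the FFT route costs O(n log n) with n = L + K - 1, so it mainly helps with large kernels.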

pip install fft-conv-pytorch
error ImportError: cannot import name 'FFTConv2d' from 'fft_conv_pytorch'
cause Incorrect import path or library not installed.
fix Install the package with pip install fft-conv-pytorch, then import from the underscore module name: from fft_conv_pytorch import FFTConv2d
error RuntimeError: Expected all tensors to be on the same device, but found at least two devices
cause Input tensor and kernel not on the same device (CPU/GPU).
fix Move both the input and the layer to the same device, e.g. conv.to(device) and x = x.to(device).
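If it is not obvious where a layer lives, one defensive pattern is to read the device from the module's own parameters and move the input there before calling it. This sketch assumes FFTConv2d behaves like any torch.nn.Module, as the conv.to(device) advice implies. It uses torch.nn.Conv2d as a stand-in so it runs without the library, and forward_on_module_device is a hypothetical helper name.

```python
import torch
import torch.nn as nn


def forward_on_module_device(module: nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Look up where the module's parameters live and move the input there first.
    device = next(module.parameters()).device
    return module(x.to(device))


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# nn.Conv2d stands in for FFTConv2d here (assumed to be an nn.Module as well).
conv = nn.Conv2d(3, 16, kernel_size=3, padding="same").to(device)
x = torch.randn(1, 3, 32, 32)  # created on the CPU
y = forward_on_module_device(conv, x)
print(y.shape)  # torch.Size([1, 16, 32, 32])
```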
error ValueError: padding must be 'valid' or 'same'
cause Padding argument is not one of the allowed values.
fix Set padding='valid' or padding='same' (padding='same' requires v1.2.0 or later).
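At stride 1, padding='valid' shrinks each spatial dimension by dilation x (kernel_size - 1), and padding='same' pads by exactly that total so the size is unchanged. The following sketch spells out that arithmetic in 1D. It checks the result against PyTorch's own padding='same' in torch.nn.functional.conv1d, which puts the extra element on the right when the total is odd (even kernels). The helper names are hypothetical, and they cover only the two string modes named above. Whether fft-conv-pytorch splits odd totals the same way as PyTorch is an assumption not confirmed here.

```python
import torch
import torch.nn.functional as F


def same_padding_1d(kernel_size: int, dilation: int = 1) -> tuple:
    # Total padding that keeps the output length equal to the input length at stride 1.
    total = dilation * (kernel_size - 1)
    left = total // 2
    return left, total - left  # any odd remainder goes on the right, as in PyTorch


def output_length(length: int, kernel_size: int, padding: str, dilation: int = 1) -> int:
    # Output length at stride 1 for the two string padding modes.
    if padding == "valid":
        return length - dilation * (kernel_size - 1)
    if padding == "same":
        return length
    raise ValueError("padding must be 'valid' or 'same'")


# Explicit left/right padding followed by a 'valid' conv reproduces padding='same'.
x = torch.randn(1, 2, 20)
w = torch.randn(5, 2, 4)  # even kernel: total padding 3, split (1, 2)
manual = F.conv1d(F.pad(x, same_padding_1d(4)), w)
print(manual.shape)  # torch.Size([1, 5, 20])
```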
breaking Requires PyTorch >= 1.8 for gradient support on dilated convolutions; earlier versions may break gradient computation.
fix Upgrade PyTorch to >= 1.8.
breaking Version 1.0.1 removed einsum for efficiency; custom gradient code written against earlier versions may no longer work.
fix Ensure your code does not rely on gradient computation via einsum; re-implement custom operations if needed.
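If you re-implement a custom FFT-based operation, torch.autograd.gradcheck offers a quick way to confirm its gradients. It compares autograd's analytic gradients with finite differences and expects float64 inputs. The following sketch applies it to a hypothetical circular convolution built only from torch.fft ops, which support autograd. It illustrates the checking technique and is not a stand-in for the library's internals.

```python
import torch
from torch.autograd import gradcheck


def fft_circular_conv(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Circular convolution via the convolution theorem, using only differentiable torch.fft ops.
    n = x.shape[-1]
    return torch.fft.irfft(torch.fft.rfft(x, n=n) * torch.fft.rfft(w, n=n), n=n)


# gradcheck compares analytic and numerical gradients; float64 keeps the comparison stable.
x = torch.randn(16, dtype=torch.float64, requires_grad=True)
w = torch.randn(16, dtype=torch.float64, requires_grad=True)
ok = gradcheck(fft_circular_conv, (x, w))  # returns True or raises on mismatch
print(ok)  # True
```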
breaking Support for padding='same' and half-precision input was added in 1.2.0; using either with older versions will raise an error.
fix Upgrade to 1.2.0 or later, or use padding='valid' and full-precision tensors.
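To choose a code path at runtime instead of failing on older installs, you can read installed versions with the standard library's importlib.metadata (Python 3.8+). The following minimal sketch uses hypothetical helper names and compares only the leading numeric release segment. As a result, a pre-release such as 1.2.0rc1 counts as 1.2.0. For full PEP 440 ordering, packaging.version.Version is the more robust choice.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def release_tuple(v: str, width: int = 3) -> tuple:
    # Keep the leading numeric part: '2.1.0+cu118' -> (2, 1, 0), '1.8' -> (1, 8, 0).
    m = re.match(r"\d+(?:\.\d+)*", v)
    if m is None:
        raise ValueError(f"unrecognized version string: {v!r}")
    nums = [int(p) for p in m.group(0).split(".")]
    nums += [0] * (width - len(nums))
    return tuple(nums[:width])


def installed_at_least(dist_name: str, minimum: tuple) -> bool:
    # False when the distribution is missing or older than `minimum`.
    try:
        installed = version(dist_name)
    except PackageNotFoundError:
        return False
    return release_tuple(installed) >= release_tuple(".".join(map(str, minimum)))


# Example: fall back to 'valid' padding on installs older than 1.2.0.
padding = "same" if installed_at_least("fft-conv-pytorch", (1, 2, 0)) else "valid"
```

The same helper covers the PyTorch requirement above, e.g. installed_at_least("torch", (1, 8)).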

Create a 2D FFT convolutional layer with 'same' padding and run a forward pass.

import torch
from fft_conv_pytorch import FFTConv2d

conv = FFTConv2d(in_channels=3, out_channels=16, kernel_size=3, padding='same')
x = torch.randn(1, 3, 32, 32)
y = conv(x)
print(y.shape)  # torch.Size([1, 16, 32, 32])
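To show concretely what this layer computes, here is a from-scratch 2D sketch of an FFT convolution with 'same' padding and a bias that reproduces the torch.Size([1, 16, 32, 32]) shape. This is not the library's code. The name fft_conv2d_same is hypothetical, and the sketch assumes stride 1, dilation 1 and groups 1. It also follows PyTorch's padding split for even kernels, which fft-conv-pytorch may or may not share. Its output can be checked against torch.nn.functional.conv2d(..., padding='same').

```python
import torch
import torch.nn.functional as F


def fft_conv2d_same(x, weight, bias=None):
    """2D convolution with 'same' padding via FFTs (stride 1, dilation 1, groups 1).

    x: (B, C_in, H, W); weight: (C_out, C_in, kH, kW); bias: (C_out,) or None.
    """
    kh, kw = weight.shape[-2:]
    # PyTorch-style 'same' split: for even kernels the extra row/column goes bottom/right.
    top, left = (kh - 1) // 2, (kw - 1) // 2
    xp = F.pad(x, (left, kw - 1 - left, top, kh - 1 - top))
    hp, wp = xp.shape[-2:]
    # FFT size that holds the full linear convolution, so nothing wraps around.
    s = (hp + kh - 1, wp + kw - 1)
    xf = torch.fft.rfftn(xp, s=s, dim=(-2, -1))
    # Flip the kernel: conv2d is cross-correlation, the theorem is for convolution.
    wf = torch.fft.rfftn(torch.flip(weight, dims=(-2, -1)), s=s, dim=(-2, -1))
    # Multiply per frequency and sum over input channels.
    yf = (xf.unsqueeze(1) * wf.unsqueeze(0)).sum(dim=2)
    full = torch.fft.irfftn(yf, s=s, dim=(-2, -1))
    # The 'valid' window of the padded input has the original H x W size.
    y = full[..., kh - 1 : hp, kw - 1 : wp]
    if bias is not None:
        y = y + bias.view(1, -1, 1, 1)
    return y


x = torch.randn(1, 3, 32, 32)
w = torch.randn(16, 3, 3, 3)
b = torch.randn(16)
print(fft_conv2d_same(x, w, b).shape)  # torch.Size([1, 16, 32, 32])
```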