Auraloss
Auraloss is a collection of audio-focused loss functions implemented in PyTorch, designed for tasks such as audio synthesis, source separation, and speech enhancement. It provides specialized losses including multi-resolution STFT, Mel-scaled STFT, and perceptual losses. The current stable version is 0.4.0.
Common errors
- `RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0!`
  - Cause: the input or target audio tensors are on a different device (CPU/GPU) than the initialized auraloss module.
  - Fix: move all tensors and the loss module to the same device: `loss_fn = loss_fn.to(device)`, `input_audio = input_audio.to(device)`, `target_audio = target_audio.to(device)`.
- `ValueError: Expected input to be a 3D tensor, got 2D tensor`
  - Cause: auraloss losses expect audio in (batch, channels, samples) format, but received a 2D tensor (e.g., (batch, samples)).
  - Fix: for mono audio, add a channel dimension with `tensor.unsqueeze(1)`: `input_audio = input_audio.unsqueeze(1)`.
- `TypeError: 'module' object is not callable`
  - Cause: calling the auraloss module directly (e.g., `auraloss.freq(input, target)`) instead of an instantiated loss class.
  - Fix: import and instantiate a specific loss class, then call the instance: `from auraloss.freq import MultiResolutionSTFTLoss; mr_loss = MultiResolutionSTFTLoss(); loss = mr_loss(input, target)`.
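The shape error above can be reproduced and fixed with plain PyTorch (no auraloss required); a minimal sketch:

```python
import torch

# Mono audio commonly loads as (batch, samples), but auraloss
# losses expect (batch, channels, samples).
mono = torch.randn(4, 16000)     # 2D: (batch, samples)
mono_3d = mono.unsqueeze(1)      # 3D: (batch, 1, samples)

print(tuple(mono_3d.shape))      # (4, 1, 16000)
```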
Warnings
- gotcha All auraloss functions expect input and target tensors to be 3-dimensional (Batch, Channels, Samples). A common mistake is to pass 2D (Batch, Samples) or 1D (Samples) tensors.
- gotcha Ensure input and target tensors are on the same device (CPU/GPU) as the loss function instance. Mismatched devices will lead to `RuntimeError: Expected all tensors to be on the same device`.
- gotcha The STFT-based losses (e.g., MultiResolutionSTFTLoss) are built on PyTorch's `torch.stft`, which expects floating-point inputs (typically `torch.float32` or `torch.float64`). Other dtypes such as `torch.float16` may raise errors or behave unexpectedly.
- breaking Prior to v0.3.0, some loss functions such as `STFTLoss` and `MelSTFTLoss` lived under `auraloss.loss` (e.g., `auraloss.loss.STFTLoss`). They were later refactored into `auraloss.freq` and, in some cases, renamed.
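A defensive dtype cast before computing any STFT-based loss is cheap insurance; this sketch uses only plain PyTorch:

```python
import torch

# Half-precision audio (e.g., from a mixed-precision pipeline).
audio = torch.randn(2, 1, 16000, dtype=torch.float16)

# Cast to float32 before passing to an STFT-based loss.
if audio.dtype not in (torch.float32, torch.float64):
    audio = audio.float()

print(audio.dtype)  # torch.float32
```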
Install
- pip install auraloss
- pip install git+https://github.com/csteinmetz1/auraloss.git
Imports
- MultiResolutionSTFTLoss
from auraloss.freq import MultiResolutionSTFTLoss
- MelSTFTLoss
from auraloss.freq import MelSTFTLoss
- FIRFilter (perceptual pre-emphasis filtering)
from auraloss.perceptual import FIRFilter
- SpectralConvergenceLoss
from auraloss.freq import SpectralConvergenceLoss
Quickstart
import torch
from auraloss.freq import MultiResolutionSTFTLoss
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Dummy input/target tensors (e.g., 10 seconds of mono audio at 16kHz)
# Batch size B, Channels C, Samples S
input_audio = torch.randn(2, 1, 160000, device=device)
target_audio = torch.randn(2, 1, 160000, device=device)
# Initialize Multi-Resolution STFT Loss
# MultiResolutionSTFTLoss defaults to computing the STFT at three
# resolutions, with different FFT, hop, and window sizes.
mr_stft_loss = MultiResolutionSTFTLoss().to(device)
# Compute the loss
loss = mr_stft_loss(input_audio, target_audio)
print(f"Computed MR-STFT Loss: {loss.item()}")