PyTorch Weighted Prediction Error (WPE)
A PyTorch implementation of the Weighted Prediction Error (WPE) algorithm, primarily for speech dereverberation. It serves as a proof-of-concept, closely mirroring the WPE implementation found in `nara_wpe`. The current version is 0.0.1, released in March 2021. Due to its 'proof of concept' nature and age, it does not appear to have an active release cadence or dedicated maintenance.
Common errors
- ModuleNotFoundError: No module named 'torch_complex'
  cause: The `torch_complex` library, which `pytorch-wpe` depends on for handling complex-valued tensors, is not installed.
  fix: Install the `torch_complex` library: `pip install torch_complex`.
- RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
  cause: PyTorch tensors passed to `pytorch-wpe` functions are on different devices (e.g., one on CPU, another on GPU).
  fix: Move all input tensors to the same device before calling `pytorch-wpe` functions, e.g., `tensor.to('cuda')` or `tensor.to('cpu')`.
- TypeError: signal_framing() got an unexpected keyword argument 'pad_value'
  cause: The installed `pytorch-wpe` has a different `signal_framing` signature, or the calling code was adapted from another WPE implementation. The `pytorch_wpe.py` in the `dnn_wpe` repository does accept `pad_value`; an older or divergent install may not.
  fix: Check the `signal_framing` signature in your installed `pytorch_wpe`. If `pad_value` is not accepted, drop the argument, or update `pytorch-wpe` if a newer release adds it (unlikely given the project's status).
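For the device-mismatch error above, one simple pattern is to normalize all inputs onto one device before calling into the library. The helper below is hypothetical (not part of `pytorch-wpe`), sketched in plain PyTorch:

```python
import torch

def to_same_device(*tensors, device=None):
    # Hypothetical helper: move every tensor to one device (defaults to
    # the device of the first tensor) before calling pytorch-wpe functions.
    device = device if device is not None else tensors[0].device
    return tuple(t.to(device) for t in tensors)

signal = torch.randn(4, 200)
weights = torch.randn(4, 200)
signal, weights = to_same_device(signal, weights)
assert signal.device == weights.device
```

If a GPU is available, passing `device='cuda'` moves everything there in one call; otherwise the tensors simply stay on the first tensor's device.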
Warnings
- breaking As a 'proof of concept' library at version 0.0.1 with no active development since 2021, the API is highly unstable. Future compatibility with newer PyTorch or `torch_complex` versions is not guaranteed, and breaking changes are likely without warning if external dependencies update.
- gotcha The GitHub repository explicitly states that the implementation 'may be slow' as it is 'not optimized in terms of computational efficiency'. This library is not designed for high-performance, real-time applications.
- gotcha The library lacks comprehensive documentation, examples, and community support: it is a direct implementation of the core WPE functions without high-level wrappers or tutorials.
Install
- pip install pytorch-wpe
- conda install -c conda-forge pytorch-wpe
Imports
- signal_framing
from pytorch_wpe import signal_framing
- get_power
from pytorch_wpe import get_power
- get_correlations
from pytorch_wpe import get_correlations
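Before the Quickstart, it may help to see what framing does conceptually. Assuming `signal_framing` slides a window of `frame_length` samples with hop `frame_step` over the last axis (as in `nara_wpe`), its effect on a real tensor can be sketched with `torch.Tensor.unfold`; this is an illustrative re-implementation, not the library code:

```python
import torch

def frame_1d(signal: torch.Tensor, frame_length: int, frame_step: int) -> torch.Tensor:
    # Sketch of framing over the last axis: each output row is one
    # frame_length-sample window, advanced by frame_step samples.
    return signal.unfold(-1, frame_length, frame_step)

x = torch.arange(10.0)
frames = frame_1d(x, frame_length=4, frame_step=2)
print(frames.shape)  # torch.Size([4, 4])
```

A 10-sample signal with window 4 and hop 2 yields 4 frames; the library's `signal_framing` additionally handles `ComplexTensor` inputs by framing the real and imaginary parts separately.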
Quickstart
import torch
from torch_complex.tensor import ComplexTensor
from pytorch_wpe import signal_framing, get_power
# 1. Create a dummy complex signal (e.g., from STFT output)
# Shape: (batch_size, channels, time_frames)
dummy_signal_real = torch.randn(1, 4, 200)
dummy_signal_imag = torch.randn(1, 4, 200)
dummy_complex_signal = ComplexTensor(dummy_signal_real, dummy_signal_imag)
print(f"Original complex signal shape (Real, Imag): {dummy_complex_signal.real.shape}, {dummy_complex_signal.imag.shape}")
# 2. Use signal_framing function
frame_length = 64 # e.g., STFT window size
frame_step = 32 # e.g., STFT hop length
framed_signal = signal_framing(dummy_complex_signal, frame_length, frame_step)
print(f"Framed signal shape: {framed_signal.shape} (batch, channels, num_frames, frame_length)")
# 3. Use get_power function
power = get_power(dummy_complex_signal)  # averages |X|^2 over the channel axis
print(f"Power of signal shape: {power.shape} (batch, time_frames)")