{"id":8563,"library":"pytorch-wpe","title":"PyTorch Weighted Prediction Error (WPE)","description":"A PyTorch implementation of the Weighted Prediction Error (WPE) algorithm, primarily for speech dereverberation. It serves as a proof of concept that closely mirrors the WPE implementation in `nara_wpe`. The current version is 0.0.1, released in March 2021. Given its 'proof of concept' nature and age, it does not appear to have an active release cadence or dedicated maintenance.","status":"maintenance","version":"0.0.1","language":"en","source_language":"en","source_url":"https://github.com/nttcslab-sp/dnn_wpe","tags":["pytorch","audio","speech processing","dereverberation","wpe","signal processing"],"install":[{"cmd":"pip install pytorch-wpe","lang":"bash","label":"PyPI"},{"cmd":"conda install -c conda-forge pytorch-wpe","lang":"bash","label":"Conda"}],"dependencies":[{"reason":"Core PyTorch dependency for tensor operations.","package":"torch","optional":false},{"reason":"Provides `ComplexTensor`, the complex-valued tensor type the WPE functions operate on.","package":"torch_complex","optional":false},{"reason":"Common numerical computing library often used alongside PyTorch.","package":"numpy","optional":false}],"imports":[{"note":"Splits a signal into (optionally overlapping) windows along its last (time) axis.","symbol":"signal_framing","correct":"from pytorch_wpe import signal_framing"},{"note":"Computes the power of a complex signal, averaged over the channel axis (dim=-2 by default).","symbol":"get_power","correct":"from pytorch_wpe import get_power"},{"note":"Computes the power-weighted correlation matrix and vector over a window of `taps` frames, the core statistics of WPE.","symbol":"get_correlations","correct":"from pytorch_wpe import get_correlations"}],"quickstart":{"code":"import torch\nfrom torch_complex.tensor import ComplexTensor\nfrom pytorch_wpe import signal_framing, get_power\n\n# 1. 
Create a dummy complex signal (e.g., from STFT output)\n# Shape: (batch, channels, time_frames); for WPE the leading axis is typically frequency bins, i.e. (F, C, T)\ndummy_signal_real = torch.randn(1, 4, 200)\ndummy_signal_imag = torch.randn(1, 4, 200)\ndummy_complex_signal = ComplexTensor(dummy_signal_real, dummy_signal_imag)\n\nprint(f\"Original complex signal shape (Real, Imag): {dummy_complex_signal.real.shape}, {dummy_complex_signal.imag.shape}\")\n\n# 2. Use signal_framing function\nframe_length = 64  # window length along the last (time-frame) axis\nframe_step = 32    # hop between consecutive windows, in frames\nframed_signal = signal_framing(dummy_complex_signal, frame_length, frame_step)\nprint(f\"Framed signal shape: {framed_signal.shape} (batch, channels, num_frames, frame_length)\")\n\n# 3. Use get_power function (averages |x|^2 over the channel axis, dim=-2)\npower = get_power(dummy_complex_signal)\nprint(f\"Power of signal shape: {power.shape} (batch, time_frames)\")\n","lang":"python","description":"This quickstart demonstrates basic usage of the key functions `signal_framing` and `get_power` on a synthetic `ComplexTensor`. A real-world application would feed Short-Time Fourier Transform (STFT) outputs of audio signals into these functions."},"warnings":[{"fix":"Pin specific versions of `torch` and `torch_complex` in your project to avoid unexpected breakage. Be prepared to adapt the code manually when migrating to newer environments.","message":"As a 'proof of concept' library frozen at version 0.0.1 with no development since 2021, it is not kept in step with its dependencies. Compatibility with newer PyTorch or `torch_complex` releases is not guaranteed, and upstream changes can break it without warning.","severity":"breaking","affected_versions":"<=0.0.1"},{"fix":"Benchmark performance thoroughly for your specific use case. 
For production or performance-critical systems, benchmark alternatives such as `nara_wpe` (the implementation this library mirrors) or a compiled C++/CUDA implementation.","message":"The GitHub repository explicitly states that the implementation 'may be slow' as it is 'not optimized in terms of computational efficiency'. This library is not designed for high-performance, real-time applications.","severity":"gotcha","affected_versions":"0.0.1"},{"fix":"Read the source code (`pytorch_wpe.py` in the GitHub repo) for function signatures and internal logic. Familiarity with the original WPE algorithm and with `nara_wpe` helps.","message":"Documentation, examples, and community support are minimal. The library is a direct implementation of the core WPE functions without high-level wrappers or tutorials.","severity":"gotcha","affected_versions":"0.0.1"}],"env_vars":null,"last_verified":"2026-04-16T00:00:00.000Z","next_check":"2026-07-15T00:00:00.000Z","problems":[{"fix":"Install the `torch_complex` library: `pip install torch_complex`.","cause":"The `torch_complex` library, which `pytorch-wpe` depends on for complex-valued tensors, is not installed.","error":"ModuleNotFoundError: No module named 'torch_complex'"},{"fix":"Move all input tensors to the same device (CPU or GPU) before passing them to `pytorch-wpe` functions, e.g., `tensor.to('cuda')` or `tensor.to('cpu')`.","cause":"Tensors passed to `pytorch-wpe` functions are on different devices (e.g., one on CPU, another on GPU).","error":"RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!"},{"fix":"Verify the exact function signature of `signal_framing` in your installed `pytorch_wpe` version. 
If `pad_value` is not in the signature (check with `inspect.signature(signal_framing)`), drop the argument. 0.0.1 is the current release, so upgrading will not add it.","cause":"The installed `signal_framing` does not accept `pad_value`. This typically happens when code written against a different WPE implementation, or against a modified copy of `pytorch_wpe.py`, runs with the installed package. The `pytorch_wpe.py` in the `dnn_wpe` repo lists `pad_value` as an argument, so the error points to a different or locally modified installation.","error":"TypeError: signal_framing() got an unexpected keyword argument 'pad_value'"}]}