PyTorch Weighted Prediction Error (WPE)

0.0.1 · maintenance · verified Thu Apr 16

A PyTorch implementation of the Weighted Prediction Error (WPE) algorithm, primarily for speech dereverberation. It is a proof of concept that closely mirrors the WPE implementation in `nara_wpe`. The current version is 0.0.1, released in March 2021; given its proof-of-concept nature and age, the project does not appear to have an active release cadence or dedicated maintenance.
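As a brief reminder of what WPE computes (this is the standard multi-channel linear-prediction formulation from the `nara_wpe` literature; the notation here is illustrative, not taken from this package's docstrings): for each frequency bin $f$, the dereverberated STFT frame is the observation minus a linear prediction from delayed past frames,

```latex
\hat{x}_{t,f} = y_{t,f} - G_f^{\mathsf{H}} \, \tilde{y}_{t-\Delta,f},
\qquad
\tilde{y}_{t-\Delta,f} =
\begin{bmatrix}
y_{t-\Delta,f}^{\mathsf{T}} & y_{t-\Delta-1,f}^{\mathsf{T}} & \cdots & y_{t-\Delta-K+1,f}^{\mathsf{T}}
\end{bmatrix}^{\mathsf{T}}
```

where $\Delta$ is the prediction delay, $K$ the number of filter taps, and the filter $G_f$ is re-estimated iteratively by minimizing the power-normalized prediction error $\sum_t \lVert \hat{x}_{t,f} \rVert^2 / \lambda_{t,f}$, with $\lambda_{t,f}$ the estimated signal power.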

Common errors

Warnings

Install

Imports

Quickstart

This quickstart demonstrates basic usage of key functions like `signal_framing` and `get_power` using a synthetic `ComplexTensor` as input. A real-world application would involve feeding Short-Time Fourier Transform (STFT) outputs of audio signals into these functions.
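Before running the snippet, it helps to know how many frames to expect from the framing step. Assuming `signal_framing` follows the usual sliding-window convention with no padding (an assumption; check the source for the exact padding behavior), the frame count can be computed with a small helper (`expected_num_frames` is a name introduced here for illustration):

```python
def expected_num_frames(num_samples: int, frame_length: int, frame_step: int) -> int:
    """Frame count for a sliding window with no padding (assumed convention)."""
    if num_samples < frame_length:
        return 0
    return 1 + (num_samples - frame_length) // frame_step

# For the quickstart values: 200 samples, 64-sample frames, hop of 32.
print(expected_num_frames(200, 64, 32))  # → 5
```

If the implementation pads the signal tail, the actual frame count may be one or two larger than this no-padding estimate.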

import torch
from torch_complex.tensor import ComplexTensor
from pytorch_wpe import signal_framing, get_power

# 1. Create a dummy complex signal (e.g., from STFT output)
# Shape: (batch_size, channels, time_frames)
dummy_signal_real = torch.randn(1, 4, 200)
dummy_signal_imag = torch.randn(1, 4, 200)
dummy_complex_signal = ComplexTensor(dummy_signal_real, dummy_signal_imag)

print(f"Original complex signal shape (Real, Imag): {dummy_complex_signal.real.shape}, {dummy_complex_signal.imag.shape}")

# 2. Use signal_framing function
frame_length = 64  # e.g., STFT window size
frame_step = 32    # e.g., STFT hop length
framed_signal = signal_framing(dummy_complex_signal, frame_length, frame_step)
print(f"Framed signal shape: {framed_signal.shape} (batch, channels, num_frames, frame_length)")

# 3. Use get_power to compute the power (mean squared magnitude over channels)
power = get_power(dummy_complex_signal)
print(f"Power of signal shape: {power.shape}")
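For intuition, the power used by WPE is essentially the squared magnitude averaged across channels, mirroring the `nara_wpe` convention (the exact reduction axis in `pytorch_wpe` is worth confirming against the source). A dependency-free sketch of that computation, with `power_from_complex` as a hypothetical helper name:

```python
def power_from_complex(real_rows, imag_rows):
    """Mean |.|^2 across channels for a (channels, time) complex signal.

    real_rows / imag_rows are per-channel lists of samples; the
    channel-averaging axis is assumed from the nara_wpe convention.
    """
    num_channels = len(real_rows)
    num_samples = len(real_rows[0])
    return [
        sum(real_rows[c][t] ** 2 + imag_rows[c][t] ** 2
            for c in range(num_channels)) / num_channels
        for t in range(num_samples)
    ]

# Two channels, three time steps.
real = [[1.0, 0.0, 2.0], [0.0, 1.0, 0.0]]
imag = [[0.0, 0.0, 0.0], [0.0, 0.0, 2.0]]
print(power_from_complex(real, imag))  # → [0.5, 0.5, 4.0]
```

In the real pipeline this per-bin power estimate normalizes the prediction error during the iterative filter update.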
