Torch-STOI
Torch-STOI is a Python library that provides a PyTorch implementation of the Short-Time Objective Intelligibility (STOI) metric, designed primarily for use as a loss function when training deep learning models for tasks such as speech enhancement and source separation. It builds on the `pystoi` package and supports both classical and extended STOI. The current version is 0.2.3. Releases are infrequent and focus on functional improvements and on keeping results closely correlated with the reference `pystoi` implementation.
Common errors
-
ModuleNotFoundError: No module named 'pystoi'
cause The `torch-stoi` library depends on `pystoi` for its underlying STOI calculation, but `pystoi` is not installed.
fix Install the `pystoi` dependency: `pip install pystoi`.
-
AttributeError: 'Vocab' object has no attribute 'stoi'
cause This error comes from `torchtext`, not from `torch-stoi`. In older `torchtext` versions, `stoi` (string-to-integer) was a direct attribute of the `Vocab` object; newer releases replaced it with `get_stoi()`. The two libraries are often confused because they share the 'stoi' acronym.
fix If working with `torchtext`, update your code from `vocab.stoi` to `vocab.get_stoi()`.
-
RuntimeError: The size of tensor a (X) must match the size of tensor b (Y) at non-singleton dimension Z
cause The `est_targets` and `targets` tensors passed to `NegSTOILoss` do not have matching shapes. The loss compares the two signals sample by sample, so their shapes must be identical.
fix Ensure that the `est_targets` (predicted speech) and `targets` (clean reference speech) tensors have identical shapes (e.g., `[batch_size, num_samples]`).
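The shape requirement from the last error can be checked before calling the loss, which gives a clearer message than the `RuntimeError` raised deep inside the computation. The following is a minimal sketch, not part of `torch-stoi`: `check_same_shape` is a hypothetical helper name, and it only assumes that both inputs expose a `.shape` attribute, as `torch.Tensor` does.

```python
def check_same_shape(est_targets, targets):
    """Raise a descriptive ValueError if the two inputs differ in shape.

    Hypothetical helper (not part of torch-stoi). Works with any objects
    that expose a ``.shape`` attribute, such as ``torch.Tensor``.
    """
    est_shape = tuple(est_targets.shape)
    target_shape = tuple(targets.shape)
    if est_shape != target_shape:
        raise ValueError(
            f"Shape mismatch: est_targets {est_shape} vs targets {target_shape}. "
            "Crop or pad both signals to the same number of samples first."
        )
    # Return the common shape so callers can log or reuse it.
    return est_shape
```

Calling `check_same_shape(est_speech, clean_speech)` immediately before the loss call surfaces a length mismatch (for example, a model that trims a few samples at the edges) with both shapes in the message.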
Warnings
- gotcha The `NegSTOILoss` provided by `torch-stoi` is primarily intended as a loss function for optimization and does not always perfectly replicate the exact values of the 'real' STOI metric. For objective evaluation, it is recommended to use the original `pystoi` library or `torchmetrics.audio.stoi.ShortTimeObjectiveIntelligibility` (which wraps `pystoi`).
- gotcha Calculations within `torch-stoi` (and `torchmetrics`'s STOI wrapper) are performed on the CPU. Input tensors will automatically be moved to the CPU for processing and then potentially moved back to their original device, which can introduce overhead, especially with large batches or frequent calls on GPU-accelerated workflows.
- gotcha Setting the `use_vad` parameter to `False` in `NegSTOILoss` can lead to results that are 'substantially different' from the standard STOI metric, as it bypasses the silent frame detection mechanism.
Install
-
pip install torch-stoi
Imports
- NegSTOILoss
from torch_stoi import NegSTOILoss
Quickstart
import torch
from torch_stoi import NegSTOILoss
sample_rate = 16000
loss_func = NegSTOILoss(sample_rate=sample_rate)
# Example dummy data
clean_speech = torch.randn(2, sample_rate) # Batch of 2, 1 second audio
noisy_speech = torch.randn(2, sample_rate) # Batch of 2, 1 second audio
# In a real scenario, noisy_speech would be passed through a neural network
# to produce an estimated clean speech signal.
# For quickstart, let's assume `noisy_speech` is our `est_speech` for demonstration.
est_speech = noisy_speech # Replace with your model's output
# Compute loss
loss_batch = loss_func(est_speech, clean_speech)
print(f"Computed STOI loss: {loss_batch.mean().item()}")