Torch-STOI

0.2.3 · active · verified Thu Apr 16

Torch-STOI is a Python library that provides a PyTorch implementation of the Short-Time Objective Intelligibility (STOI) metric, designed primarily for use as a loss function in deep learning models for tasks such as speech enhancement and source separation. It builds on the reference `pystoi` package, reimplementing both classical and extended STOI in PyTorch so that the metric is differentiable. The current version is 0.2.3; releases are infrequent and focus on functional improvements and on closer agreement with the reference `pystoi` implementation.
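Because the metric is written in PyTorch, gradients can flow through it. The sketch below covers only the final step of classical STOI. Each short segment of the degraded signal's band envelope is energy-normalized, clipped against the clean envelope, and correlated with it, and the correlations are then averaged. It assumes the one-third-octave band envelopes have already been computed and grouped into segments of shape `(num_segments, num_bands, N)`. The function name `stoi_from_segments` is hypothetical, and this is not `torch-stoi`'s actual code. The clipping bound `beta = -15` dB and the sizes 15 bands by 30 frames follow the published STOI definition.

```python
import torch


def stoi_from_segments(x_seg: torch.Tensor, y_seg: torch.Tensor,
                       beta: float = -15.0, eps: float = 1e-8) -> torch.Tensor:
    """Hypothetical sketch of the classical STOI correlation step.

    x_seg, y_seg: clean and degraded band envelopes, shape
    (num_segments, num_bands, N), assumed to be precomputed.
    """
    # Scale each degraded segment to the energy of the matching clean segment
    alpha = x_seg.norm(dim=-1, keepdim=True) / (y_seg.norm(dim=-1, keepdim=True) + eps)
    y_norm = y_seg * alpha
    # Clip so that each bin's SDR is bounded below by beta dB
    clip = 10 ** (-beta / 20)
    y_clip = torch.minimum(y_norm, x_seg * (1 + clip))
    # Remove the means, scale to unit length, and take the dot product
    x_c = x_seg - x_seg.mean(dim=-1, keepdim=True)
    y_c = y_clip - y_clip.mean(dim=-1, keepdim=True)
    x_c = x_c / (x_c.norm(dim=-1, keepdim=True) + eps)
    y_c = y_c / (y_c.norm(dim=-1, keepdim=True) + eps)
    corr = (x_c * y_c).sum(dim=-1)  # one correlation per (segment, band)
    return corr.mean()              # average over all segments and bands


torch.manual_seed(0)
clean = torch.rand(10, 15, 30, dtype=torch.float64)
noisy = (clean + 0.5 * torch.rand(10, 15, 30, dtype=torch.float64)).requires_grad_()

score = stoi_from_segments(clean, noisy)
(-score).backward()  # negate the score, as a NegSTOI-style loss would
print(score.item(), noisy.grad.abs().sum().item() > 0)
```

Every operation here (norms, `torch.minimum`, means, sums) is differentiable almost everywhere. That is what allows a STOI-based objective to be minimized by gradient descent.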

Install

pip install torch_stoi

Imports

from torch_stoi import NegSTOILoss

Quickstart

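The example below initializes `NegSTOILoss` with the signal's sample rate and applies it to a batch of estimated and clean speech tensors of shape `(batch, time)`. The loss is the negated STOI, so minimizing it maximizes intelligibility. It returns one value per batch item, and these values are usually averaged before backpropagation. In practice, `torch-stoi` is integrated into a neural network training loop.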

import torch
from torch_stoi import NegSTOILoss

sample_rate = 16000
loss_func = NegSTOILoss(sample_rate=sample_rate)

# Dummy data: a batch of 2 one-second signals, shape (batch, time)
clean_speech = torch.randn(2, sample_rate)
noisy_speech = torch.randn(2, sample_rate)

# In a real setup, noisy_speech is passed through a neural network that
# estimates the clean speech. Here the noisy input stands in for that estimate.
est_speech = noisy_speech  # Replace with your model's output

# Compute the loss: one negated STOI value per batch item
loss_batch = loss_func(est_speech, clean_speech)

print(f"Mean STOI loss: {loss_batch.mean().item()}")
