Conformer (lucidrains' PyTorch implementation)
This library provides a PyTorch implementation of the Conformer model, an architecture that augments Transformer self-attention with convolution modules. Self-attention captures global dependencies across a sequence, while convolution captures local patterns, which makes the model well suited to tasks such as speech recognition. The latest release is version 0.3.2.
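To make the "convolution plus self-attention" description concrete, here is a simplified, self-contained PyTorch sketch of one block in the macaron layout described in the Conformer paper (Gulati et al., 2020): x ← x + ½·FFN(x), then x ← x + MHSA(x), then x ← x + Conv(x), then x ← x + ½·FFN(x), followed by a final LayerNorm. This is an illustration only, not this library's code. It uses `nn.MultiheadAttention` without the relative positional encoding that Conformer implementations typically add, and the class names `FeedForward`, `ConvModule` and `ConformerBlockSketch` are hypothetical.

```python
import torch
from torch import nn


class FeedForward(nn.Module):
    """Pre-norm position-wise feed-forward module with Swish (SiLU) activation."""

    def __init__(self, dim: int, mult: int = 4, dropout: float = 0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim * mult),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class ConvModule(nn.Module):
    """Pointwise conv + GLU, depthwise conv, BatchNorm, SiLU, pointwise conv."""

    def __init__(self, dim: int, kernel_size: int = 31, dropout: float = 0.0):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.pointwise_in = nn.Conv1d(dim, 2 * dim, kernel_size=1)
        self.glu = nn.GLU(dim=1)  # gates and halves the channels back to dim
        # groups=dim makes this a depthwise conv; odd kernel + this padding keeps the length
        self.depthwise = nn.Conv1d(dim, dim, kernel_size, padding=kernel_size // 2, groups=dim)
        self.bn = nn.BatchNorm1d(dim)
        self.act = nn.SiLU()
        self.pointwise_out = nn.Conv1d(dim, dim, kernel_size=1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.norm(x).transpose(1, 2)  # (batch, dim, time) for Conv1d
        x = self.glu(self.pointwise_in(x))
        x = self.act(self.bn(self.depthwise(x)))
        x = self.dropout(self.pointwise_out(x))
        return x.transpose(1, 2)  # back to (batch, time, dim)


class ConformerBlockSketch(nn.Module):
    """Macaron-style block: 1/2 FFN -> MHSA -> Conv -> 1/2 FFN -> LayerNorm."""

    def __init__(self, dim: int, heads: int = 4, kernel_size: int = 31):
        super().__init__()
        self.ff1 = FeedForward(dim)
        self.attn_norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.conv = ConvModule(dim, kernel_size)
        self.ff2 = FeedForward(dim)
        self.final_norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor, key_padding_mask=None) -> torch.Tensor:
        x = x + 0.5 * self.ff1(x)  # first half-step feed-forward
        h = self.attn_norm(x)
        attn_out, _ = self.attn(h, h, h, key_padding_mask=key_padding_mask, need_weights=False)
        x = x + attn_out  # global context via self-attention
        x = x + self.conv(x)  # local context via depthwise convolution
        x = x + 0.5 * self.ff2(x)  # second half-step feed-forward
        return self.final_norm(x)


block = ConformerBlockSketch(dim=144, heads=4)
out = block(torch.randn(2, 50, 144))
print(out.shape)  # torch.Size([2, 50, 144])
```

Every sub-module is residual and shape-preserving, which is why Conformer blocks can be stacked to any depth on a `(batch, time, dim)` tensor.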
Warnings
- breaking The original GitHub repository for this library (https://github.com/lucidrains/conformer) is no longer available, and the author account associated with it is inactive. The library is effectively unmaintained: no further updates, bug fixes, or official support should be expected.
- gotcha Because the Conformer architecture is popular, several Python libraries and repositories implement it, and their constructor and forward-pass signatures differ. This PyPI package, `conformer`, is the implementation originally written by 'lucidrains'. Distinguish it from other implementations (e.g., `torchaudio.models.Conformer`, sooftware's `conformer` repository, `conformer-tf`) when consulting documentation or examples; code written against one will generally not run against another.
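Since several packages export a class named `Conformer`, a quick check can catch a mix-up before it surfaces as a `TypeError`. The sketch below is a heuristic of our own, not part of any library: the hypothetical helper `guess_conformer_flavor` compares constructor parameter names against lucidrains' `Conformer(dim, *, depth, ...)`, `torchaudio.models.Conformer(input_dim, num_heads, ffn_dim, num_layers, depthwise_conv_kernel_size, ...)` and sooftware's `Conformer(num_classes, input_dim, encoder_dim, num_encoder_layers, ...)`; anything else is reported as unknown. `installed_version` uses the standard-library `importlib.metadata`.

```python
import inspect
from importlib import metadata


def installed_version(dist_name: str = "conformer"):
    """Return the installed distribution version, or None if it is not installed."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def guess_conformer_flavor(cls) -> str:
    """Hypothetical helper: label a Conformer class by its constructor parameter names."""
    params = set(inspect.signature(cls).parameters)
    if {"dim", "depth"} <= params:
        return "lucidrains"
    if {"input_dim", "num_heads", "ffn_dim", "num_layers", "depthwise_conv_kernel_size"} <= params:
        return "torchaudio"
    if {"num_classes", "encoder_dim", "num_encoder_layers"} <= params:
        return "sooftware"
    return "unknown"


if __name__ == "__main__":
    print("installed conformer version:", installed_version())
    try:
        from conformer import Conformer  # whichever package owns the name here
        print("flavor:", guess_conformer_flavor(Conformer))
    except ImportError:
        print("the conformer package is not importable")
```

Given the unmaintained status, pinning the exact version (`pip install conformer==0.3.2`) also keeps environments reproducible.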
Install
pip install conformer
Imports
- Conformer
from conformer import Conformer
Quickstart
import torch
from torch import nn
from conformer import Conformer
# Define model parameters (example values)
# Full self-attention is quadratic in sequence length, so keep the length moderate
batch_size, sequence_length, input_dim = 3, 512, 80  # e.g. 80-bin mel spectrogram frames
dim = 144  # model dimension used throughout the encoder
depth = 3  # number of stacked Conformer blocks
conv_kernel_size = 31  # depthwise convolution kernel size; 31 is a common choice
# This package provides the encoder stack only: it expects (batch, sequence_length, dim)
# inputs and has no input projection or classification head, so project the features first
input_proj = nn.Linear(input_dim, dim)
# Instantiate the Conformer model
model = Conformer(
    dim=dim,
    depth=depth,
    heads=4,
    dim_head=36,
    conv_kernel_size=conv_kernel_size
)
# Create dummy input data (e.g., a batch of mel spectrograms)
inputs = torch.rand(batch_size, sequence_length, input_dim)  # (batch, sequence_length, input_dim)
# Perform a forward pass; the output keeps the (batch, sequence_length, dim) shape
outputs = model(input_proj(inputs))
print(f"Output features shape: {outputs.shape}")  # torch.Size([3, 512, 144])
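In lucidrains' documented examples, `Conformer` maps a `(batch, sequence_length, dim)` tensor to a tensor of the same shape; it does not take per-item sequence lengths or produce class scores. If a task needs those (for example, variable-length utterances and `num_classes` outputs), one option is a thin wrapper such as the hypothetical `ConformerClassifier` below. It is a sketch under stated assumptions: it builds a boolean padding mask from the lengths, zeroes padded frames before and after the encoder, and mean-pools only the valid frames. The demo passes `nn.Identity()` as a stand-in encoder so it runs without the package; in practice you would pass something like `Conformer(dim=144, depth=3)`.

```python
import torch
from torch import nn


def lengths_to_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    """Boolean mask of shape (batch, max_len); True marks valid (non-padded) frames."""
    positions = torch.arange(max_len, device=lengths.device)
    return positions.unsqueeze(0) < lengths.unsqueeze(1)


class ConformerClassifier(nn.Module):
    """Hypothetical wrapper: input projection -> encoder -> masked mean pooling -> linear head."""

    def __init__(self, encoder: nn.Module, input_dim: int, dim: int, num_classes: int):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, dim)
        self.encoder = encoder  # any module mapping (batch, time, dim) -> (batch, time, dim)
        self.head = nn.Linear(dim, num_classes)

    def forward(self, features: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        mask = lengths_to_mask(lengths, features.shape[1]).unsqueeze(-1)  # (batch, time, 1)
        x = self.input_proj(features) * mask  # zero out padded frames
        x = self.encoder(x) * mask
        # Mean over valid frames only; clamp avoids division by zero for empty items
        pooled = x.sum(dim=1) / lengths.clamp(min=1).unsqueeze(-1).to(x.dtype)
        return self.head(pooled)  # (batch, num_classes)


# nn.Identity() stands in for the package's Conformer so the sketch is self-contained
model = ConformerClassifier(nn.Identity(), input_dim=80, dim=144, num_classes=10)
features = torch.rand(3, 200, 80)
lengths = torch.tensor([200, 180, 150])
logits = model(features, lengths)
print(logits.shape)  # torch.Size([3, 10])
```

Note that in this sketch padded frames are only masked around the encoder; inside it they can still influence attention and convolution over the valid frames.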