Coqui TTS Trainer

0.4.0 · active · verified Fri Apr 17

Coqui TTS Trainer is a general-purpose model trainer for PyTorch, designed to be flexible and extensible. It is part of the wider Coqui AI ecosystem and provides core training utilities for deep learning models, including Text-to-Speech models. The current version is 0.4.0; releases follow an irregular, feature-driven cadence.

Quickstart

This quickstart demonstrates how to set up a basic training loop using the `Trainer` class. It defines a simple PyTorch model and dataset, then configures the `Trainer` with a `TrainerConfig`, optimizer, scheduler, criterion, and data loaders. The `trainer.train()` method then executes the training process.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from trainer.trainer import Trainer
from trainer.generic_model_config import TrainerConfig

# 1. Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear(x)

# 2. Define a dummy dataset
class DummyDataset(Dataset):
    def __len__(self):
        return 100

    def __getitem__(self, idx):
        return torch.randn(10), torch.randint(0, 2, ())

# 3. Create a TrainerConfig
config = TrainerConfig()
config.num_epochs = 2
config.output_path = "./trainer_output"
config.batch_size = 4

# 4. Instantiate model, optimizer, criterion, dataloaders
model = SimpleModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
criterion = nn.CrossEntropyLoss()

train_dataset = DummyDataset()
eval_dataset = DummyDataset()
train_loader = DataLoader(train_dataset, batch_size=config.batch_size)
eval_loader = DataLoader(eval_dataset, batch_size=config.batch_size)

# 5. Initialize and run the Trainer
trainer = Trainer(
    config=config,
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=criterion,
    data_loader_train=train_loader,
    data_loader_eval=eval_loader,
    grad_scaler=None,  # for mixed precision, pass torch.cuda.amp.GradScaler()
    output_path=config.output_path
)

trainer.train()
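The quickstart passes `grad_scaler=None`. To sketch what the mixed-precision path hinted at in the comment could look like, here is a minimal, standalone AMP training step using standard PyTorch `GradScaler`/`autocast` calls. This shows the generic PyTorch pattern only; exactly how `Trainer` drives the scaler internally is not specified here, so treat the wiring as an assumption. The scaler is disabled when CUDA is unavailable, which makes it a no-op so the snippet also runs on CPU.

```python
import torch
import torch.nn as nn

# Small stand-in model and data, mirroring the quickstart shapes.
model = nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# GradScaler(enabled=False) is an identity pass-through, so this
# step behaves like ordinary full-precision training on CPU.
use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
model.to(device)
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

x = torch.randn(4, 10, device=device)
y = torch.randint(0, 2, (4,), device=device)

optimizer.zero_grad()
with torch.autocast(device_type=device, enabled=use_cuda):
    loss = criterion(model(x), y)  # forward pass (fp16 under autocast)
scaler.scale(loss).backward()      # scale loss before backward
scaler.step(optimizer)             # unscale grads, then optimizer.step()
scaler.update()                    # adjust the scale factor for next step
```

A `GradScaler` constructed this way is what the quickstart's `grad_scaler` argument would receive in place of `None`.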
