PyTorch Forecasting

1.7.0 · active · verified Thu Apr 16

PyTorch Forecasting is a highly scalable open-source library for state-of-the-art time series forecasting with PyTorch. It provides common data structures like TimeSeriesDataSet, various forecasting models (e.g., TFT, DeepAR, N-BEATS), normalizers, and metrics, all integrated with PyTorch Lightning for efficient training. The current version is 1.7.0, with regular updates aligning with PyTorch and PyTorch Lightning developments. It requires Python versions >=3.10 and <3.15.
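`TimeSeriesDataSet` expects data in long format: one row per time step per series, with an integer `time_idx` column that increases monotonically within each group. A minimal pandas sketch of that layout (column names here are illustrative, not required by the library):

```python
import pandas as pd

# Long-format frame: one row per (series, time step).
df = pd.DataFrame({
    "group": ["a", "a", "a", "b", "b", "b"],
    "date": pd.to_datetime(["2020-01-01", "2020-01-02", "2020-01-03"] * 2),
    "value": [1.0, 2.0, 3.0, 10.0, 20.0, 30.0],
})

# Derive the integer time index required by TimeSeriesDataSet:
# days elapsed since each series' first observation.
df["time_idx"] = (
    df["date"] - df.groupby("group")["date"].transform("min")
).dt.days

print(df["time_idx"].tolist())  # → [0, 1, 2, 0, 1, 2]
```

Computing the index per group keeps series aligned even when they start on different dates; a single global offset (as in the quickstart below) also works when all series share one calendar.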

Quickstart

Demonstrates basic usage of PyTorch Forecasting with `TemporalFusionTransformer`: generating dummy data, preparing it with `TimeSeriesDataSet`, defining the model, training with `pytorch_lightning.Trainer`, and making predictions.

import pandas as pd
import pytorch_lightning as pl
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import MAE

# 1. Create dummy data
data = pd.DataFrame(dict(
    time_idx=pd.to_datetime(pd.date_range("2020-01-01", periods=100)),
    value=range(100),
    group=["a"] * 50 + ["b"] * 50,
    static_cat=["x"] * 100,
    known_cont=[i for i in range(100)]
))
data["time_idx"] = (data["time_idx"] - data["time_idx"].min()).dt.days

max_encoder_length = 20
max_prediction_length = 5
training_cutoff = data["time_idx"].max() - max_prediction_length

# 2. Define TimeSeriesDataSet
training = TimeSeriesDataSet(
    data[lambda x: x.time_idx <= training_cutoff],
    time_idx="time_idx",
    target="value",
    group_ids=["group"],
    min_encoder_length=max_encoder_length // 2,
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    static_categoricals=["static_cat"],
    time_varying_known_reals=["time_idx", "known_cont"],
    time_varying_unknown_reals=["value"],
    target_normalizer=GroupNormalizer(groups=["group"], transformation="softplus"),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)
# create validation set (predict=True) which means to predict the last max_prediction_length points in time
validation = TimeSeriesDataSet.from_dataset(training, data, predict=True, stop_randomization=True)
train_dataloader = training.to_dataloader(train=True, batch_size=4, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=4, num_workers=0)

# 3. Define model
tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,
    hidden_size=16,
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=8,
    output_size=1, # one value per step, matching the point metric below
    loss=MAE(), # for quantile forecasts, use QuantileLoss() with output_size=7
    log_interval=10,
    reduce_on_plateau_patience=4,
)

# 4. Train model
trainer = pl.Trainer(
    max_epochs=1, # Reduced for quickstart
    gradient_clip_val=0.1,
)
trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

# 5. Make predictions
best_model_path = trainer.checkpoint_callback.best_model_path
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)
predictions = best_tft.predict(val_dataloader, mode="raw", return_x=True)
raw_output, x = predictions.output, predictions.x
# raw_output["prediction"] has shape (n_series, max_prediction_length, output_size)
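When the model is trained with `QuantileLoss()` instead of a point metric, the raw prediction tensor carries one value per quantile along its last dimension. The default quantiles are (0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98), so index 3 is the median, i.e. the usual point forecast. A minimal sketch with a synthetic tensor standing in for the model output (shapes are illustrative only):

```python
import torch

# Synthetic stand-in for raw quantile output:
# (n_series, prediction_length, n_quantiles)
raw = torch.zeros(2, 5, 7)
raw[..., 3] = 1.0  # pretend the median quantile is 1.0 everywhere

# Index 3 selects the 0.5 quantile from the default quantile set,
# collapsing the quantile dimension into a point forecast.
median_forecast = raw[..., 3]
print(tuple(median_forecast.shape))  # → (2, 5)
```

The same indexing applies to the real `raw_output["prediction"]` tensor once the model above is trained with a quantile loss.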
