{"library":"pytorch-forecasting","title":"PyTorch Forecasting","description":"PyTorch Forecasting is a highly scalable open-source library for state-of-the-art time series forecasting with PyTorch. It provides common data structures like TimeSeriesDataSet, various forecasting models (e.g., TFT, DeepAR, N-BEATS), normalizers, and metrics, all integrated with PyTorch Lightning (imported as `lightning.pytorch`) for efficient training. The current version is 1.7.0, with regular updates aligning with PyTorch and PyTorch Lightning developments. It requires Python >=3.10 and <3.15.","language":"python","status":"active","last_verified":"Thu Apr 16","install":{"commands":["pip install pytorch-forecasting"],"cli":null},"imports":["from pytorch_forecasting.data import TimeSeriesDataSet","from pytorch_forecasting.models import TemporalFusionTransformer","from pytorch_forecasting.models import DeepAR","from pytorch_forecasting.data import GroupNormalizer","from lightning.pytorch import Trainer"],"auth":{"required":false,"env_vars":[]},"quickstart":{"code":"import pandas as pd\nimport lightning.pytorch as pl\nfrom pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer\nfrom pytorch_forecasting.data import GroupNormalizer\nfrom pytorch_forecasting.metrics import MAE\n\n# 1. Create dummy data\ndata = pd.DataFrame(dict(\n    time_idx=pd.to_datetime(pd.date_range(\"2020-01-01\", periods=100)),\n    value=[float(i) for i in range(100)],  # float-valued target; the softplus target transformation works on floating-point values\n    group=[\"a\"] * 50 + [\"b\"] * 50,\n    static_cat=[\"x\"] * 100,\n    known_cont=[i for i in range(100)]\n))\ndata[\"time_idx\"] = (data[\"time_idx\"] - data[\"time_idx\"].min()).dt.days\n\nmax_encoder_length = 20\nmax_prediction_length = 5\ntraining_cutoff = data[\"time_idx\"].max() - max_prediction_length\n\n# 2. Define TimeSeriesDataSet\ntraining = TimeSeriesDataSet(\n    data[lambda x: x.time_idx <= training_cutoff],\n    time_idx=\"time_idx\",\n    target=\"value\",\n    group_ids=[\"group\"],\n    min_encoder_length=max_encoder_length // 2,\n    max_encoder_length=max_encoder_length,\n    min_prediction_length=1,\n    max_prediction_length=max_prediction_length,\n    static_categoricals=[\"static_cat\"],\n    time_varying_known_reals=[\"time_idx\", \"known_cont\"],\n    time_varying_unknown_reals=[\"value\"],\n    target_normalizer=GroupNormalizer(groups=[\"group\"], transformation=\"softplus\"),\n    add_relative_time_idx=True,\n    add_target_scales=True,\n    add_encoder_length=True,\n)\n# create validation set (predict=True) which means to predict the last max_prediction_length points in time for each series\nvalidation = TimeSeriesDataSet.from_dataset(training, data, predict=True, stop_randomization=True)\ntrain_dataloader = training.to_dataloader(train=True, batch_size=4, num_workers=0)\n# train=False: no shuffling and no dropping of the last batch when validating/predicting\nval_dataloader = validation.to_dataloader(train=False, batch_size=4, num_workers=0)\n\n# 3. Define model\ntft = TemporalFusionTransformer.from_dataset(\n    training,\n    learning_rate=0.03,\n    hidden_size=16,\n    attention_head_size=1,\n    dropout=0.1,\n    hidden_continuous_size=8,\n    loss=MAE(),  # point loss, so output_size is inferred as 1; use QuantileLoss() for 7 quantile outputs\n    log_interval=10,\n    reduce_on_plateau_patience=4,\n)\n\n# 4. Train model\ntrainer = pl.Trainer(\n    max_epochs=1, # Reduced for quickstart\n    gradient_clip_val=0.1,\n)\ntrainer.fit(\n    tft,\n    train_dataloaders=train_dataloader,\n    val_dataloaders=val_dataloader,\n)\n\n# 5. Make predictions\nbest_model_path = trainer.checkpoint_callback.best_model_path\nbest_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)\nraw_predictions = best_tft.predict(val_dataloader, mode=\"raw\", return_x=True)\n# print(raw_predictions.output[\"prediction\"].shape)\n# predictions = best_tft.predict(val_dataloader, return_y=True)\n# print(MAE()(predictions.output, predictions.y))","lang":"python","description":"Demonstrates basic usage of PyTorch Forecasting with `TemporalFusionTransformer`: generating dummy data, preparing it with `TimeSeriesDataSet`, defining the model, training it with `lightning.pytorch.Trainer`, and making predictions.","tag":null,"tag_description":null,"last_tested":null,"results":[]},"compatibility":null}