pmdarima (Auto-ARIMA)

2.1.1 · active · verified Thu Apr 09

pmdarima is a Python library that provides an equivalent to R's `auto.arima` function, automating the process of selecting optimal ARIMA (AutoRegressive Integrated Moving Average) models for time series forecasting. It builds on `statsmodels` but offers a scikit-learn-like API, simplifying complex time series analysis. The library is currently at version 2.1.1 and receives regular updates for Python version compatibility and dependency support.

Warnings

Install

Imports

Quickstart

This quickstart uses `pmdarima.auto_arima` to select and fit an ARIMA model automatically, then forecasts 10 future periods with confidence intervals. The input is a synthetic series (a linear trend plus random noise). `start_p` and `start_q` set the orders where the search begins, `max_p` and `max_q` cap the orders it may try, and `seasonal` controls whether seasonal terms are considered at all.

import pmdarima as pm
import numpy as np
# import matplotlib.pyplot as plt  # only needed for the optional plot below

# Generate sample data: linear trend plus uniform noise (seeded for reproducibility)
rng = np.random.default_rng(42)
y = rng.random(100) * 10 + np.arange(100)

# Fit a stepwise auto_arima model
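model = pm.auto_arima(y,
                      start_p=1, start_q=1,   # orders where the search starts
                      test='adf',             # ADF test chooses the differencing order d
                      max_p=3, max_q=3,       # upper bounds for p and q
                      m=1,                    # seasonal period; 1 means non-seasonal
                      d=None,                 # let the test determine d
                      seasonal=False,         # no seasonal terms
                      start_P=0,              # seasonal settings; unused when seasonal=False
                      D=0,
                      trace=False,            # do not print each candidate model
                      error_action='ignore',  # skip candidate orders that fail to fit
                      suppress_warnings=True, # silence fitting warnings
                      stepwise=True)          # stepwise search instead of a full grid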

# Make predictions
forecast, conf_int = model.predict(n_periods=10, return_conf_int=True)

print("Forecast:", forecast)
print("Confidence Interval:", conf_int)

# Optional: plot results
# plt.plot(y, label='Actual')
# plt.plot(np.arange(len(y), len(y) + len(forecast)), forecast, label='Forecast')
# plt.fill_between(np.arange(len(y), len(y) + len(forecast)),
#                  conf_int[:, 0], conf_int[:, 1], alpha=0.1)
# plt.legend()
# plt.show()
