PyTorch TabNet

4.1.0 verified Fri May 01 auth: no python

PyTorch implementation of TabNet (Google's attention-based tabular network). Current version 4.1.0, with semi-annual releases. Supports classification, regression, and unsupervised pre-training.

pip install pytorch-tabnet
error ModuleNotFoundError: No module named 'pytorch_tabnet'
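cause The package is not installed in the active environment, or the wrong name was used. The pip distribution name is `pytorch-tabnet` (hyphen), while the import name is `pytorch_tabnet` (underscore).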
fix
Install with: pip install pytorch-tabnet. Import as: from pytorch_tabnet.tab_model import TabNetClassifier
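To tell a missing install apart from a misspelled import, you can ask the interpreter whether it can locate the module before importing it. The following is a minimal sketch using only the standard library; the helper name `is_importable` is made up for illustration and is not part of pytorch-tabnet.

```python
import importlib.util


def is_importable(module_name: str) -> bool:
    """Return True if the current interpreter can locate `module_name`."""
    return importlib.util.find_spec(module_name) is not None


# The distribution is named with a hyphen (pytorch-tabnet),
# but the importable module uses an underscore (pytorch_tabnet).
if is_importable("pytorch_tabnet"):
    print("pytorch_tabnet is available")
else:
    print("pytorch_tabnet not found; run: pip install pytorch-tabnet")
```

If the check fails right after a successful `pip install`, the install most likely went into a different environment than the one running the script.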
error TypeError: __init__() got an unexpected keyword argument 'cat_idxs'
cause Old version; categorical embedding was not supported before v1.1.0.
fix
Upgrade to latest version: pip install --upgrade pytorch-tabnet
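Before upgrading, it can help to confirm which version is actually installed. The sketch below relies on the standard-library `importlib.metadata` (Python 3.8+); `installed_version` is a hypothetical helper name used only for this example.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name: str):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


# importlib.metadata looks up the distribution name used with pip.
found = installed_version("pytorch-tabnet")
print(found if found is not None else "pytorch-tabnet is not installed")
```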
error RuntimeError: expected scalar type Long but found Float
cause Target labels for classification are float instead of integer/long.
fix
Convert y_train to integer type: y_train = y_train.astype(np.int64)
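As an illustration of the conversion, the sketch below casts float labels to `int64` and first checks that no label has a fractional part, since `astype` would otherwise truncate it silently. The sample array is invented for the example.

```python
import numpy as np

# Labels that arrive as floats, e.g. after loading from a CSV.
y_float = np.array([0.0, 1.0, 1.0, 0.0])

# Guard against silently truncating non-integer values such as 0.5.
if not np.all(np.mod(y_float, 1) == 0):
    raise ValueError("labels contain non-integer values")

y_train = y_float.astype(np.int64)
print(y_train.dtype)
```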
breaking In v4.0, the unsupervised pretraining loss was changed to match the original paper. Models trained with pretraining in v3.x cannot be directly resumed or fine-tuned in v4.x without retraining.
fix Retrain any models that used unsupervised pretraining after upgrading to v4.x.
breaking The default metric for regression changed from 'mse' to 'rmse' in v3.0.0. If you relied on default metric behavior, your training/evaluation results may differ.
fix Pass the metric explicitly to `fit` as a list, e.g. `eval_metric=['mse']`, to retain the old behavior.
gotcha When using categorical features, `cat_idxs` and `cat_dims` must be consistent with each other. Only v4.0+ raises an error when they are not.
fix Ensure `len(cat_idxs) == len(cat_dims)`, that each index refers to a valid column, and that each `cat_dims` entry is the number of distinct values in the corresponding column.
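On versions that do not validate these arguments, a small pre-flight check can catch mismatches early. This is a sketch under stated assumptions: `check_cat_config` is a hypothetical helper, not part of pytorch-tabnet, and the rules it enforces are only the ones described above plus a duplicate-index check.

```python
def check_cat_config(cat_idxs, cat_dims, n_features):
    """Raise ValueError if the categorical configuration is inconsistent.

    Hypothetical helper, not part of pytorch-tabnet.
    """
    if len(cat_idxs) != len(cat_dims):
        raise ValueError(f"{len(cat_idxs)} indices but {len(cat_dims)} dims")
    if len(set(cat_idxs)) != len(cat_idxs):
        raise ValueError("cat_idxs contains duplicate column indices")
    for idx, dim in zip(cat_idxs, cat_dims):
        if not 0 <= idx < n_features:
            raise ValueError(f"column index {idx} is outside 0..{n_features - 1}")
        if dim < 1:
            raise ValueError(f"column {idx} has invalid cardinality {dim}")


# Columns 0 and 3 of a 5-column matrix are categorical,
# with 4 and 7 distinct values respectively.
check_cat_config(cat_idxs=[0, 3], cat_dims=[4, 7], n_features=5)
print("categorical configuration looks consistent")
```

Calling the helper just before constructing the model keeps the failure close to the place where the lists are built.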
gotcha Saving and loading models: use `save_model` and `load_model` methods. Directly pickling the model object is not supported and may break.
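fix Save with `saved_path = clf.save_model('model')`; the method writes a zip archive and returns its path (`model.zip`). To restore, create a fresh instance and call `load_model` on that path: `clf = TabNetClassifier(); clf.load_model(saved_path)`.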

Minimal example of fitting a TabNetClassifier on random data.

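import numpy as np
from pytorch_tabnet.tab_model import TabNetClassifier

# Random binary-classification data: 100 rows, 10 numeric features.
rng = np.random.default_rng(0)
X_train = rng.random((100, 10))
y_train = rng.integers(0, 2, size=100)  # integer labels (int64 by default)

clf = TabNetClassifier(device_name='cpu')
# Batch sizes kept below the 100-row dataset so each epoch contains complete batches.
clf.fit(X_train, y_train, max_epochs=10, batch_size=32, virtual_batch_size=16)
print(clf.predict(X_train))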