PyTorch TabNet
Version: 4.1.0 | Verified: Fri May 01 | Auth: no | Language: Python
PyTorch implementation of TabNet (Google's attention-based tabular network). Current version 4.1.0, with semi-annual releases. Supports classification, regression, and unsupervised pre-training.
pip install pytorch-tabnet
Common errors
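error ModuleNotFoundError: No module named 'pytorch_tabnet'
cause The package is not installed in the interpreter you are running, or the import path is wrong. The PyPI distribution is named `pytorch-tabnet` (hyphen), but the importable module is `pytorch_tabnet` (underscore).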
fix
Install with: pip install pytorch-tabnet. Import as: from pytorch_tabnet.tab_model import TabNetClassifier
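A frequent cause is that pip installed the package into a different environment from the one running your code. The stdlib-only sketch below reports whether the current interpreter can locate the module. The helper name `check_tabnet_install` is hypothetical, not part of pytorch-tabnet:

```python
import importlib.util
import sys


def check_tabnet_install() -> dict:
    """Report whether `pytorch_tabnet` is importable from this interpreter."""
    # find_spec returns None when the top-level module cannot be located
    spec = importlib.util.find_spec("pytorch_tabnet")
    return {
        "python": sys.executable,  # the interpreter pip must install into
        "importable": spec is not None,
    }


if __name__ == "__main__":
    info = check_tabnet_install()
    print(info)
    if not info["importable"]:
        # Calling pip through the same interpreter avoids environment mismatches
        print(f"Try: {info['python']} -m pip install pytorch-tabnet")
```

Running pip as `python -m pip` ties the installation to that specific interpreter, so the import and the install cannot silently target different environments.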
error TypeError: __init__() got an unexpected keyword argument 'cat_idxs'
cause Old version; categorical embedding was not supported before v1.1.0.
fix
Upgrade to latest version: pip install --upgrade pytorch-tabnet
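Before upgrading, you can check which version is installed. The sketch below uses the standard library's `importlib.metadata`. The helper names are hypothetical, and `parse_version` assumes plain numeric release strings such as `4.1.0` (pre-release suffixes like `rc1` are not handled):

```python
from importlib.metadata import PackageNotFoundError, version


def parse_version(v: str) -> tuple:
    """Turn '4.1.0' into (4, 1, 0); assumes a plain numeric release string."""
    return tuple(int(part) for part in v.split(".")[:3])


def installed_tabnet_version():
    """Return the installed pytorch-tabnet version as a tuple, or None."""
    try:
        return parse_version(version("pytorch-tabnet"))
    except PackageNotFoundError:
        return None


def supports_cat_idxs(v) -> bool:
    # Per the cause listed above, cat_idxs requires v1.1.0 or later
    return v is not None and v >= (1, 1, 0)


if __name__ == "__main__":
    v = installed_tabnet_version()
    print("installed:", v, "| cat_idxs supported:", supports_cat_idxs(v))
```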
error RuntimeError: expected scalar type Long but found Float
cause Target labels for classification are float instead of integer/long.
fix
Convert y_train to integer type: y_train = y_train.astype(np.int64)
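If the labels arrive as floats (for example `0.0`/`1.0` read from a CSV) or as strings, one option is to encode them into contiguous int64 class indices with NumPy before calling `fit`. The sketch below uses illustrative data and variable names:

```python
import numpy as np

# Float labels, e.g. as loaded from a CSV file (illustrative data)
y_raw = np.array([0.0, 1.0, 1.0, 2.0, 0.0])

# return_inverse gives, for each label, its index into the sorted `classes`
classes, y_train = np.unique(y_raw, return_inverse=True)
y_train = y_train.astype(np.int64)

print(classes)        # [0. 1. 2.]
print(y_train.dtype)  # int64
print(y_train)        # [0 1 1 2 0]
```

Keep `classes` around if you need to map integer indices back to the original label values (`classes[indices]`).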
Warnings
breaking In v4.0, the unsupervised pretraining loss was changed to match the original paper. Models trained with pretraining in v3.x cannot be directly resumed or fine-tuned in v4.x without retraining.
fix Retrain any models that used unsupervised pretraining after upgrading to v4.x.
breaking The default metric for regression changed from 'mse' to 'rmse' in v3.0.0. If you relied on the default metric, the evaluation scores you log and compare will differ.
fix Pass the metric explicitly to `fit()` to retain the old behavior: `eval_metric=['mse']` (the parameter takes a list of metric names).
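RMSE is the square root of MSE, so the reported numbers change while their ordering across epochs stays the same. The NumPy check below uses made-up predictions and hypothetical per-epoch values:

```python
import numpy as np

y_true = np.array([3.0, -0.5, 2.0, 7.0])
y_pred = np.array([2.5, 0.0, 2.0, 8.0])

mse = np.mean((y_true - y_pred) ** 2)  # squared errors 0.25, 0.25, 0, 1 -> 0.375
rmse = np.sqrt(mse)                    # same quantity on the scale of the target
print(mse, rmse)

# Hypothetical per-epoch validation MSE values
epoch_mse = np.array([0.90, 0.52, 0.375, 0.41])
epoch_rmse = np.sqrt(epoch_mse)
# sqrt is monotonic, so the lowest-error epoch is the same under both metrics
print(epoch_mse.argmin(), epoch_rmse.argmin())
```

The practical consequence is that scores logged before and after the change are not directly comparable. You would need to square RMSE values, or take the square root of MSE values, before comparing them.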
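gotcha When using categorical features, `cat_idxs` and `cat_dims` must be consistent. Only v4.0 and later raise an explicit error when they are not.
fix Ensure `len(cat_idxs) == len(cat_dims)`, that each index points to a valid column, and that `cat_dims[i]` equals the number of distinct values in column `cat_idxs[i]`, with those values label-encoded as integers from 0 to `cat_dims[i] - 1`.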
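One way to keep the two lists consistent is to derive both from the same label-encoding pass. The sketch below uses only NumPy. The helpers `encode_categoricals` and `check_cat_args` are hypothetical, and the code assumes that each categorical column's values are mutually comparable (so `np.unique` can sort them) and that all other columns are numeric:

```python
import numpy as np


def encode_categoricals(X: np.ndarray, cat_cols: list):
    """Label-encode `cat_cols` on a copy of X and return
    (X_encoded, cat_idxs, cat_dims), built together so they always agree."""
    X_enc = X.copy()
    cat_idxs, cat_dims = [], []
    for col in sorted(cat_cols):
        # Map each distinct value to an integer code in 0..n_unique-1
        uniques, codes = np.unique(X[:, col], return_inverse=True)
        X_enc[:, col] = codes
        cat_idxs.append(col)
        cat_dims.append(len(uniques))
    return X_enc.astype(np.float64), cat_idxs, cat_dims


def check_cat_args(X: np.ndarray, cat_idxs: list, cat_dims: list) -> None:
    """Raise ValueError if cat_idxs/cat_dims do not match the data."""
    if len(cat_idxs) != len(cat_dims):
        raise ValueError("cat_idxs and cat_dims must have the same length")
    for idx, dim in zip(cat_idxs, cat_dims):
        if not 0 <= idx < X.shape[1]:
            raise ValueError(f"cat_idx {idx} is not a valid column")
        col = X[:, idx]
        # Embedding lookups need codes in [0, dim)
        if col.min() < 0 or col.max() >= dim:
            raise ValueError(f"column {idx} has codes outside [0, {dim})")


if __name__ == "__main__":
    X = np.array([["red", 1.5, "a"], ["blue", 2.0, "b"], ["red", 0.5, "a"]],
                 dtype=object)
    X_enc, cat_idxs, cat_dims = encode_categoricals(X, [0, 2])
    check_cat_args(X_enc, cat_idxs, cat_dims)
    print(cat_idxs, cat_dims)
```

The resulting lists can then be passed as `TabNetClassifier(cat_idxs=cat_idxs, cat_dims=cat_dims)`. Apply the same fitted mapping to validation and test data, so that a given category always receives the same code.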
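gotcha Save and load models with the `save_model` and `load_model` methods. Pickling the model object directly is not supported and may break.
fix `saved_path = clf.save_model('model')` writes a `model.zip` archive and returns its path. Restore it into a fresh instance with `clf2 = TabNetClassifier()` followed by `clf2.load_model(saved_path)`.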
Imports
- TabNetClassifier
from pytorch_tabnet.tab_model import TabNetClassifier
Quickstart
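from pytorch_tabnet.tab_model import TabNetClassifier
import numpy as np

# Toy binary-classification data: 100 samples, 10 features
np.random.seed(0)
X_train = np.random.rand(100, 10)
y_train = np.random.randint(0, 2, 100)  # integer class labels
clf = TabNetClassifier(device_name='cpu')
# The fit defaults (batch_size=1024, drop_last=True) target larger datasets;
# with only 100 rows, use batches smaller than the dataset
clf.fit(X_train, y_train, max_epochs=10, batch_size=32, virtual_batch_size=16)
print(clf.predict(X_train))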