sklearn-crfsuite
sklearn-crfsuite is a thin wrapper around the `python-crfsuite` library that exposes Conditional Random Field (CRF) models through a scikit-learn-compatible API. This makes scikit-learn's model selection utilities (such as cross-validation and hyperparameter optimization) usable with CRFs, and trained models can be saved and loaded with joblib. The library is actively maintained; its latest release, 0.5.0, shipped in June 2024.
Warnings
- breaking In version 0.5.0, the `CRF.predict()` and `CRF.predict_marginals()` methods now return a NumPy array instead of a list of lists, aligning with expectations from newer scikit-learn versions.
- breaking Version 0.4.0 dropped official support for Python 3.7 and lower, and explicitly added support for Python 3.8 and higher. It also increased minimum versions for dependencies like `python-crfsuite` (0.9.7) and `scikit-learn` (0.24.0).
- breaking In version 0.2, the `crf.tagger` attribute was renamed to `crf.tagger_`. Additionally, accessing `crf.tagger_` before training no longer raises an exception but returns `None`.
- gotcha `python-crfsuite` and `sklearn-crfsuite` do not natively support array-like features (e.g., word embeddings) directly. Attempting to pass a NumPy array as a single feature will result in errors.
- gotcha As with general scikit-learn practices, inconsistent preprocessing between training and test data (e.g., feature extraction functions) can lead to unexpected model performance.
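Because feature values must be scalars (strings, numbers, booleans), the usual workaround for dense vectors such as word embeddings is to unpack each dimension into its own key. A minimal sketch; the helper name `embedding_features` is illustrative, not part of the library:

```python
def embedding_features(vector, prefix='emb'):
    """Unpack a dense vector into one scalar feature per dimension,
    since CRFsuite feature values cannot be arrays."""
    return {f'{prefix}_{i}': float(v) for i, v in enumerate(vector)}

# Merge the unpacked embedding into a token's hand-crafted feature dict
features = {'bias': 1.0, 'word.lower()': 'fox'}
features.update(embedding_features([0.12, -0.5, 0.33]))
print(features)
# {'bias': 1.0, 'word.lower()': 'fox', 'emb_0': 0.12, 'emb_1': -0.5, 'emb_2': 0.33}
```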
Install
-
pip install sklearn-crfsuite
Imports
- CRF
from sklearn_crfsuite import CRF
- metrics
from sklearn_crfsuite import metrics
- scorers
from sklearn_crfsuite import scorers
Quickstart
import sklearn_crfsuite
from sklearn_crfsuite import metrics
# Dummy data for a simple sequence labeling task.
# Each sentence is a list of (word, pos_tag) pairs; features are extracted
# per token. Note: this toy example predicts the POS tag while also feeding
# it in as a feature, which is circular; in a real task the label
# (e.g. a NER tag) is distinct from the POS feature.
def word2features(sent, i):
    word = sent[i][0]
    postag = sent[i][1]
    features = {
        'bias': 1.0,
        'word.lower()': word.lower(),
        'word.isupper()': word.isupper(),
        'word.istitle()': word.istitle(),
        'word.isdigit()': word.isdigit(),
        'postag': postag,
        'postag[:2]': postag[:2],
    }
    if i > 0:
        word1 = sent[i-1][0]
        postag1 = sent[i-1][1]
        features['-1:word.lower()'] = word1.lower()
        features['-1:word.istitle()'] = word1.istitle()
        features['-1:word.isupper()'] = word1.isupper()
        features['-1:postag'] = postag1
        features['-1:postag[:2]'] = postag1[:2]
    else:
        features['BOS'] = True  # Beginning of Sentence
    if i < len(sent)-1:
        word1 = sent[i+1][0]
        postag1 = sent[i+1][1]
        features['+1:word.lower()'] = word1.lower()
        features['+1:word.istitle()'] = word1.istitle()
        features['+1:word.isupper()'] = word1.isupper()
        features['+1:postag'] = postag1
        features['+1:postag[:2]'] = postag1[:2]
    else:
        features['EOS'] = True  # End of Sentence
    return features
def sent2features(sent):
    return [word2features(sent, i) for i in range(len(sent))]

def sent2labels(sent):
    return [label for word, label in sent]
train_sents = [
    [('The', 'DT'), ('quick', 'JJ'), ('brown', 'JJ'), ('fox', 'NN'), ('jumps', 'VBZ'), ('over', 'IN'), ('the', 'DT'), ('lazy', 'JJ'), ('dog', 'NN')],
    [('I', 'PRP'), ('love', 'VBP'), ('Python', 'NNP')],
    [('Natural', 'JJ'), ('Language', 'NNP'), ('Processing', 'NNP'), ('is', 'VBZ'), ('fun', 'JJ')],
]
X_train = [sent2features(s) for s in train_sents]
y_train = [sent2labels(s) for s in train_sents]
# Initialize and train the CRF model
crf = sklearn_crfsuite.CRF(
    algorithm='lbfgs',
    c1=0.1,  # L1 regularization coefficient
    c2=0.1,  # L2 regularization coefficient
    max_iterations=100,
    all_possible_transitions=True
)
crf.fit(X_train, y_train)
# Make predictions on new data
test_sents = [
[('A', 'DT'), ('fast', 'JJ'), ('red', 'JJ'), ('car', 'NN'), ('drives', 'VBZ'), ('by', 'IN')],
]
X_test = [sent2features(s) for s in test_sents]
y_pred = crf.predict(X_test)
print("Predicted labels for test sentence:")
for sent_idx, labels in enumerate(y_pred):
    print(f"Sentence {sent_idx+1}: {labels}")
# Example of using metrics (requires scikit-learn)
y_true = [sent2labels(s) for s in test_sents]  # here the tags in test_sents serve as the ground truth
if y_true:
    print("\nClassification Report:")
    print(metrics.flat_classification_report(y_true, y_pred))