EconML

0.16.0 · active · verified Mon Apr 13

EconML is a Python library for estimating Conditional Average Treatment Effects (CATEs) from observational or experimental data. It implements machine-learning-based causal estimators, including Double Machine Learning (DML) and Causal Forests, to estimate how treatment effects vary across individuals. Version 0.16.0 is current, and the project is under active development, with feature and bugfix releases every few months.
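The partialling-out idea behind DML can be sketched in a few lines of NumPy. This sketch rests on simplifying assumptions. It assumes a constant treatment effect theta in the partially linear model Y = theta * T + g(W) + noise. It uses ordinary least squares in place of the flexible ML nuisance models EconML would fit. It ignores the heterogeneity in X that EconML's estimators target. The helpers `ols_fit_predict` and `dml_partialling_out` are hypothetical names and are not part of EconML's API. Both Y and T are residualized on W with cross-fitting, and then the outcome residuals are regressed on the treatment residuals.

```python
import numpy as np


def ols_fit_predict(A_train, y_train, A_test):
    """Hypothetical helper: least squares with an intercept, fit on one fold, predicted on another."""
    A_train1 = np.column_stack([np.ones(len(A_train)), A_train])
    A_test1 = np.column_stack([np.ones(len(A_test)), A_test])
    coef, *_ = np.linalg.lstsq(A_train1, y_train, rcond=None)
    return A_test1 @ coef


def dml_partialling_out(Y, T, W, n_folds=2, seed=0):
    """Hypothetical helper: cross-fitted residual-on-residual estimate of a constant effect."""
    n = len(Y)
    rng = np.random.default_rng(seed)
    folds = rng.permutation(n) % n_folds  # balanced random fold label per sample
    y_res = np.empty(n)
    t_res = np.empty(n)
    for k in range(n_folds):
        test = folds == k
        train = ~test
        # Nuisance models are fit on the other folds and evaluated out-of-fold
        y_res[test] = Y[test] - ols_fit_predict(W[train], Y[train], W[test])
        t_res[test] = T[test] - ols_fit_predict(W[train], T[train], W[test])
    # Final stage: regress outcome residuals on treatment residuals (no intercept)
    return np.sum(t_res * y_res) / np.sum(t_res ** 2)


# Usage on synthetic data where W confounds both T and Y and the true effect is 2.0
rng = np.random.default_rng(1)
n = 2000
W = rng.normal(size=(n, 3))
T = W @ np.array([1.0, 0.5, 0.0]) + rng.normal(size=n)
Y = 2.0 * T + W @ np.array([1.0, 0.0, 1.0]) + rng.normal(size=n)
theta_hat = dml_partialling_out(Y, T, W)
print(f"estimated effect: {theta_hat:.3f}")  # should be close to 2.0
```

Cross-fitting means each residual comes from a model that never saw that sample. This is what keeps overfitting in flexible nuisance learners from biasing the final-stage estimate.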

Install

pip install econml

Quickstart

This quickstart demonstrates how to use `CausalForestDML` to estimate Conditional Average Treatment Effects (CATEs) using synthetic data. It involves defining confounders (W), features for heterogeneity (X), treatment (T), and outcome (Y), then fitting the model and predicting CATEs for new feature values. `RandomForestRegressor` and `RandomForestClassifier` are used as base learners for Y and T models respectively.

import numpy as np
from econml.dml import CausalForestDML
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier

# Simulate data
np.random.seed(42)
n_samples = 1000
W = np.random.normal(0, 1, size=(n_samples, 3)) # Confounders
X = np.random.normal(0, 1, size=(n_samples, 2)) # Features for heterogeneity
T = (W[:, 0] + W[:, 1] + np.random.normal(0, 1, n_samples) > 0).astype(int) # Binary treatment; W[:, 0] also drives Y
Y = W[:, 0] + W[:, 2] + T * (X[:, 0] + np.random.normal(0, 0.1, n_samples)) + np.random.normal(0, 1, n_samples) # Outcome; true CATE is X[:, 0]

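# Initialize and fit the CausalForestDML model.
# discrete_treatment=True tells EconML that T is binary, so model_t is fit as a
# classifier and its predicted probabilities are used; cv=5 sets 5-fold cross-fitting.
est = CausalForestDML(
    model_y=RandomForestRegressor(min_samples_leaf=5, n_estimators=100, random_state=42),
    model_t=RandomForestClassifier(min_samples_leaf=5, n_estimators=100, random_state=42),
    discrete_treatment=True,
    cv=5,
    random_state=42
)
est.fit(Y, T, X=X, W=W)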

# Estimate CATE for new data (or original X)
X_test = np.array([[0.5, 0.5], [-0.5, -0.5]])
cate_estimates = est.effect(X_test)
print(f"CATE estimates for X_test: {cate_estimates}")

# The simulated true CATE is X[:, 0], so the estimates should be roughly 0.5 and -0.5;
# exact values depend on the EconML version and random seeds.
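Passing W to the estimator matters because W[:, 0] raises both the chance of treatment and the outcome. A NumPy-only check on the same data-generating process compares a naive difference in means with a linear regression that adjusts for W. This check is plain ordinary least squares, not EconML, and it recovers only an average effect, whereas CausalForestDML targets the heterogeneous effect X[:, 0].

```python
import numpy as np

# Same data-generating process as the quickstart
np.random.seed(42)
n_samples = 1000
W = np.random.normal(0, 1, size=(n_samples, 3))  # Confounders
X = np.random.normal(0, 1, size=(n_samples, 2))  # Features for heterogeneity
T = (W[:, 0] + W[:, 1] + np.random.normal(0, 1, n_samples) > 0).astype(int)
Y = W[:, 0] + W[:, 2] + T * (X[:, 0] + np.random.normal(0, 0.1, n_samples)) + np.random.normal(0, 1, n_samples)

# Sample average of the true CATE X[:, 0]
true_ate = X[:, 0].mean()

# Naive contrast: ignores that treated units tend to have higher W[:, 0]
naive = Y[T == 1].mean() - Y[T == 0].mean()

# Linear adjustment: regress Y on an intercept, T, and W; take the coefficient on T
design = np.column_stack([np.ones(n_samples), T, W])
coef, *_ = np.linalg.lstsq(design, Y, rcond=None)
adjusted = coef[1]

print(f"true ATE ~ {true_ate:.2f}, naive: {naive:.2f}, W-adjusted: {adjusted:.2f}")
```

Under this simulation the naive contrast is biased upward by roughly 0.9 in expectation, because W[:, 0] shifts both T and Y. The W-adjusted coefficient should land near the average effect.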
