Optuna

4.8.0 · active · verified Sun Apr 05

Optuna is an automatic hyperparameter optimization framework for machine learning. Its user API is imperative and define-by-run, so the search space is built dynamically by the code that runs inside each trial. It supports Python 3.9 or newer. The current version is 4.8.0. Development is active with a regular release cadence; major versions often introduce significant improvements and deprecate older features after a few releases.

Warnings

Install

pip install optuna

Imports

import optuna

Quickstart

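This quickstart defines an objective function that trains either an SVR or a RandomForestRegressor, using hyperparameters sampled through Optuna's `Trial` object, and returns the mean squared error on a held-out validation split. It then creates a study that minimizes this error over 100 trials. The search space is built dynamically: `svr_c` is only suggested when the SVR branch is taken, and `rf_max_depth` only for the random forest. scikit-learn is not an Optuna dependency and must be installed separately; `fetch_california_housing` downloads the dataset on first use.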

import optuna
from sklearn.datasets import fetch_california_housing
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.svm import SVR  # import the submodule explicitly; `import sklearn` alone does not guarantee sklearn.svm is loaded

# Load and split the data once so that every trial reuses the same split.
X, y = fetch_california_housing(return_X_y=True)
X_train, X_val, y_train, y_val = train_test_split(X, y, random_state=0)


def objective(trial: optuna.Trial) -> float:
    # Invoke suggest methods of a Trial object to generate hyperparameters.
    regressor_name = trial.suggest_categorical('regressor', ['SVR', 'RandomForest'])

    # Each branch suggests only the hyperparameters its own model needs.
    if regressor_name == 'SVR':
        svr_c = trial.suggest_float('svr_c', 1e-10, 1e10, log=True)
        regressor_obj = SVR(C=svr_c)
    else:
        rf_max_depth = trial.suggest_int('rf_max_depth', 2, 32)
        regressor_obj = RandomForestRegressor(max_depth=rf_max_depth, random_state=0)

    regressor_obj.fit(X_train, y_train)
    y_pred = regressor_obj.predict(X_val)
    # The validation MSE is the value the study minimizes.
    return mean_squared_error(y_val, y_pred)


if __name__ == '__main__':
    study = optuna.create_study(direction='minimize')  # Create a new study
    study.optimize(objective, n_trials=100)  # Invoke optimization

    print(f"Best trial value: {study.best_value:.4f}")
    print(f"Best params: {study.best_params}")
