Keras

3.13.2 · active · verified Sun Mar 29

Keras 3 is a multi-backend deep learning framework that provides a high-level API for building and training neural networks. It runs on JAX, TensorFlow, or PyTorch, with OpenVINO available as an inference-only backend, so the same model code can be used across these frameworks. Its design emphasizes fast experimentation and ease of use, from prototyping through deployment. The current version is 3.13.2, and the project is actively maintained with frequent releases.

Quickstart

This quickstart builds, compiles, trains, and evaluates a small convolutional neural network that classifies MNIST handwritten digits with Keras 3. It first selects the backend through the KERAS_BACKEND environment variable. This must happen before keras is imported: Keras reads the variable at import time, and the backend cannot be changed after the package has been imported.

import os
os.environ.setdefault("KERAS_BACKEND", "tensorflow")  # Must run before importing keras; keeps any value already set
import keras
from keras import layers

# Load MNIST, add a trailing channel axis, and scale pixel values from [0, 255] to [0, 1]
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 28, 28, 1).astype("float32") / 255.0
x_test = x_test.reshape(-1, 28, 28, 1).astype("float32") / 255.0

# Define a simple Sequential model
model = keras.Sequential([
    keras.Input(shape=(28, 28, 1)),
    layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
    layers.MaxPooling2D(pool_size=(2, 2)),
    layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
    layers.MaxPooling2D(pool_size=(2, 2)),
    layers.Flatten(),
    layers.Dropout(0.5),
    layers.Dense(10, activation="softmax"),
])

# Compile the model
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=["accuracy"],
)

# Train the model
print("\nTraining model...")
model.fit(x_train, y_train, batch_size=128, epochs=3, validation_split=0.1)

# Evaluate the model
print("\nEvaluating model...")
loss, accuracy = model.evaluate(x_test, y_test)
print(f"Test Loss: {loss:.4f}, Test Accuracy: {accuracy:.4f}")
