Google Tunix

0.1.6 · active · verified Fri Apr 17

Google Tunix (current version 0.1.6) is a lightweight, JAX-native framework for post-training Large Language Models (LLMs) with both reinforcement learning (RL) and supervised fine-tuning (SFT). It gives researchers and production teams fine-grained control and scalability when aligning and improving foundation models, particularly on accelerators such as TPUs. Releases are frequent, focusing on new model support, API stability, and performance improvements.

Common errors

Warnings

Install

Imports

Quickstart

This quickstart demonstrates how to initialize a basic `AgenticGRPOConfig`, which defines the training parameters for agentic GRPO (Group Relative Policy Optimization), Tunix's reinforcement-learning post-training recipe. This config would typically be passed to an `AgenticGRPOLearner` along with actual JAX/Flax models and data for a full training workflow.

from tunix import AgenticGRPOConfig

# Configure Agentic GRPO for LLM post-training
# This is a minimal configuration; a real setup would require more specific parameters
# like model_config, optimizers, and potentially a tokenizer.
agentic_grpo_config = AgenticGRPOConfig(
    num_generations=2, # Number of generations per iteration
    num_iterations=10, # Total training iterations
    max_response_length=512, # Maximum length for generated responses
    beta=0.1, # KL-divergence coefficient
    # Placeholders for complex objects; in a real scenario these would be actual config objects
    model_config=None, # e.g., Llama2Config, GemmaConfig
    optimizer_config_factory=lambda: None, # Factory for optimizer configs
)

print(f"AgenticGRPOConfig initialized with num_generations: {agentic_grpo_config.num_generations}")
print(f"Max response length: {agentic_grpo_config.max_response_length}")

# Note: To run a full training loop, you would also need to instantiate
# AgenticGRPOLearner with actual JAX/Flax models, a tokenizer, and a dataset.
