NVIDIA Megatron Core

0.16.1 · active · verified Thu Apr 16

Megatron Core is a Python library developed by NVIDIA for building highly efficient and scalable transformer-based models, especially for large-scale distributed training. It provides fundamental building blocks for tensor and pipeline parallelism. The current version is 0.16.1, and it generally follows an active release cadence with minor versions released frequently.

Common errors

Warnings

Install

Imports

Quickstart

This quickstart demonstrates how to initialize a basic distributed environment (required for Megatron-Core components) and instantiate a `ColumnParallelLinear` layer. It showcases the fundamental usage pattern of defining a parallelized model component. For actual distributed training, a launcher such as `torchrun` (the successor to the now-deprecated `torch.distributed.launch`) should be used to set up the environment variables.

import os
import torch
import torch.distributed as dist
from megatron.core.tensor_parallel.layers import ColumnParallelLinear
# BUG FIX: `megatron.core.dist_init` does not exist. The tensor-model-parallel
# state helpers (set_tensor_model_parallel_world_size / _rank) live in
# `megatron.core.parallel_state`.
from megatron.core import parallel_state

# Minimal distributed setup for demonstration purposes.
# In a real scenario these env vars are set by a launcher (e.g. torchrun)
# and dist.init_process_group is called once, globally.
if not dist.is_initialized():
    # setdefault keeps launcher-provided values and only fills in the gaps.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    os.environ.setdefault('RANK', '0')
    # WORLD_SIZE=1 allows a single-GPU smoke test without a full launcher.
    os.environ.setdefault('WORLD_SIZE', '1')

    if torch.cuda.is_available() and int(os.environ['WORLD_SIZE']) > 0:
        try:
            dist.init_process_group(
                backend='nccl',
                rank=int(os.environ['RANK']),
                world_size=int(os.environ['WORLD_SIZE']),
            )
            print("PyTorch distributed group initialized with NCCL.")
        except Exception as e:
            print(f"Warning: Could not initialize NCCL backend: {e}. Falling back to CPU/non-distributed.")
            os.environ['WORLD_SIZE'] = '1'
            if dist.is_initialized():  # destroy if a partial init succeeded
                dist.destroy_process_group()
    else:
        print("Warning: CUDA not available or WORLD_SIZE=0. Skipping torch.distributed init.")
        os.environ['WORLD_SIZE'] = '1'

# Register the tensor-model-parallel layout with Megatron-Core so its layers
# can resolve their shard rank/world size.
# NOTE(review): for real multi-rank runs the documented entry point is
# parallel_state.initialize_model_parallel(tensor_model_parallel_size=...),
# which also builds the TP/PP/DP process groups — verify against your version.
if dist.is_initialized():
    parallel_state.set_tensor_model_parallel_world_size(int(os.environ['WORLD_SIZE']))
    parallel_state.set_tensor_model_parallel_rank(int(os.environ['RANK']))
else:
    # Fallback for CPU-only / non-distributed runs: a degenerate (size-1)
    # tensor-parallel layout, effectively no parallelism.
    parallel_state.set_tensor_model_parallel_world_size(1)
    parallel_state.set_tensor_model_parallel_rank(0)

# Instantiate a small tensor-parallel linear layer and run one forward pass.
hidden_size = 128
output_size = 256

try:
    # ColumnParallelLinear shards the output dimension across tensor-parallel
    # ranks; gather_output=True all-gathers the shards so every rank ends up
    # holding the full (..., output_size) result.
    # NOTE(review): recent Megatron-Core releases also require `config=` and
    # `init_method=` arguments here — verify against the installed version.
    linear_layer = ColumnParallelLinear(
        input_size=hidden_size,
        output_size=output_size,
        gather_output=True,
    )

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        linear_layer.cuda()

    # Dummy activation of shape (batch=2, seq=4, hidden). Only the trailing
    # dimension must match input_size; batch and sequence length are free.
    input_tensor = torch.randn(2, 4, hidden_size)
    if use_cuda:
        input_tensor = input_tensor.cuda()

    # Single forward pass through the (possibly sharded) layer.
    output_tensor = linear_layer(input_tensor)

    print(f"\nMegatron-Core ColumnParallelLinear initialized successfully.")
    print(f"Input shape: {input_tensor.shape}")
    print(f"Output shape (gathered): {output_tensor.shape}")
    print(f"Output device: {output_tensor.device}")

except Exception as e:
    print(f"An error occurred during Megatron-Core layer execution: {e}")

finally:
    # Tear down the process group (if one was created) so repeated runs in
    # the same interpreter do not trip over a stale group.
    if dist.is_initialized():
        dist.destroy_process_group()

view raw JSON →