Hyper-Connections

0.4.9 · active · verified Thu Apr 16

Hyper-Connections is a Python library implementing the 'Hyper-Connections' method, proposed by ByteDance, as an alternative to standard residual connections in neural networks. The method targets the seesaw between gradient vanishing and representation collapse by introducing learnable depth connections (which weight how a layer reads from and writes back to the residual) and width connections (which mix multiple parallel residual streams). This lets the network integrate features across depths and dynamically rearrange layers, and is reported to be particularly beneficial for large language models and vision tasks. The library is actively developed, with frequent releases; the current version is 0.4.9.
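To make the idea concrete, here is a minimal pure-PyTorch sketch of the *static* variant of the technique, written from the description above. It is not the library's implementation, and all names (`StaticHyperConnection`, `input_weights`, `stream_mix`, and so on) are illustrative.

```python
import torch
from torch import nn

class StaticHyperConnection(nn.Module):
    """Illustrative static hyper-connection around a branch (e.g. an FFN).

    Maintains `num_streams` parallel residual streams. Depth weights select
    the branch input from the streams and scatter its output back; width
    weights learn a linear mix among the streams themselves.
    """
    def __init__(self, branch, num_streams):
        super().__init__()
        self.branch = branch
        # depth connections: per-stream weights for the branch input / output
        self.input_weights = nn.Parameter(torch.zeros(num_streams))
        self.output_weights = nn.Parameter(torch.ones(num_streams))
        # width connections: stream-mixing matrix, initialised to identity
        # so the module starts out equivalent to ordinary residuals
        self.stream_mix = nn.Parameter(torch.eye(num_streams))
        with torch.no_grad():
            self.input_weights[0] = 1.  # stream 0 initially feeds the branch

    def forward(self, streams):
        # streams: (num_streams, batch, seq, dim)
        branch_input = torch.einsum('s, s b n d -> b n d', self.input_weights, streams)
        branch_output = self.branch(branch_input)
        mixed = torch.einsum('t s, s b n d -> t b n d', self.stream_mix, streams)
        return mixed + self.output_weights[:, None, None, None] * branch_output
```

With the zero/identity initialisation above, the module starts out as a plain residual block on stream 0, and training can then learn arbitrary depth and width mixings.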

Common errors

Warnings

Install
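The package is published on PyPI under the same name, so a standard pip workflow should apply:

```shell
pip install hyper-connections
```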

Imports

Quickstart

This quickstart demonstrates how to integrate `hyper-connections` into a PyTorch model. It defines a base neural network layer (the branch), then calls `get_init_and_expand_reduce_stream_functions` to obtain three utilities: an initializer that wraps a branch with hyper-connection logic, a function that expands a tensor into multiple residual streams, and a function that reduces those streams back to a single tensor. The input is first expanded, then processed by the wrapped branch, then reduced. The `num_streams` parameter (typically > 1) sets the number of parallel information pathways.

import torch
from torch import nn
from hyper_connections import get_init_and_expand_reduce_stream_functions

# Define a simple neural network layer (branch) that will be enhanced by hyper-connections
class SimpleFFN(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.GELU(),
            nn.Linear(dim * 2, dim)
        )
    def forward(self, x):
        return self.net(x)

# Example dimensions for a tensor and number of hyper-connection streams
dim = 512
batch_size = 2
seq_len = 1024
num_streams = 4 # Recommended to be > 1 for full benefits of hyper-connections

# Retrieve the utility functions for hyper-connections
init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(num_streams)

# Instantiate a base layer
branch_layer = SimpleFFN(dim)

# Wrap the base layer with hyper-connections logic
hyper_conn_branch = init_hyper_conn(dim=dim, branch=branch_layer)

# Create an initial input tensor (e.g., from a transformer layer's output)
input_tensor = torch.randn(batch_size, seq_len, dim)

print(f"Initial input shape: {input_tensor.shape}")

# 1. Expand the input into multiple residual streams
# The exact shape transformation depends on internal implementation, but typically adds a stream dimension.
expanded_input = expand_stream(input_tensor)

print(f"Shape after expansion (may vary internally): {expanded_input.shape}")

# 2. Forward pass through the wrapped branch function, which processes the multiple streams
output_streams = hyper_conn_branch(expanded_input)

print(f"Shape after hyper-connected branch (may vary internally): {output_streams.shape}")

# 3. Reduce the multiple streams back to a single output tensor
final_output = reduce_stream(output_streams)

print(f"Final output shape after reduction: {final_output.shape}")
