Hyper-Connections
Hyper-Connections is a Python library that implements the 'Hyper-Connections' method, proposed by the ByteDance AI lab as an alternative to traditional residual connections in neural networks. The method aims to address the seesaw effect between gradient vanishing and representation collapse by introducing learnable depth and width connections, allowing features to be mixed flexibly across depths and layers to be dynamically rearranged. This is particularly beneficial for large language models and vision tasks. The library is actively developed, with frequent releases; its current version is 0.4.9.
Common errors
- ModuleNotFoundError: No module named 'hyper_connections'
  Cause: The `hyper-connections` library is not installed in the current Python environment, or the environment is not activated.
  Fix: Install the library with `pip install hyper-connections`. If using a virtual environment, activate it before running your script.
- RuntimeError: gradient explosion or NaN loss encountered during training
  Cause: The library introduces flexible connections, and without careful implementation or regularization the repeated mixing of signals can destabilize gradients, especially in deep models or with certain initialization schemes.
  Fix: Apply gradient clipping (e.g., `torch.nn.utils.clip_grad_norm_`) in your training loop, and consider normalizing inputs or activations. Investigate whether 'Manifold-Constrained Hyper-Connections' (mHC) offers a more stable variant of the architecture, or whether the library provides built-in regularization options for stability. Ensure `num_streams > 1`.
- TypeError: init_hyper_conn() got an unexpected keyword argument 'dim'
  Cause: The API for `init_hyper_conn` may have changed, or `dim` is not an expected parameter in the installed version; the example may target an older or newer release of the library.
  Fix: Check the exact function signature of `init_hyper_conn` in the `hyper-connections` GitHub repository's source code or latest README for your installed version, and adjust the arguments to match the current API.
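The gradient-clipping fix above can be sketched in plain PyTorch. This is a minimal, illustrative training step: the toy `nn.Sequential` model stands in for a hyper-connected network, and only standard PyTorch APIs are used.

```python
import torch
from torch import nn

# Toy stand-in for a deep hyper-connected model (illustrative only)
model = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 1))

x = torch.randn(8, 16)
loss = model(x).pow(2).mean()
loss.backward()

# Clip the global gradient norm before the optimizer step. The call
# returns the pre-clipping total norm, which is useful for logging
# instability during training.
total_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# After clipping, the global gradient norm is at most max_norm
clipped = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))
print(float(total_norm), float(clipped))
```

In a real loop, the clipping call goes between `loss.backward()` and `optimizer.step()`.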
Warnings
- breaking: Major architectural changes were introduced around version 0.4.0, which jumped from 0.3.16. While specific breaking API changes are not explicitly detailed in release notes, research-heavy libraries like this often introduce significant shifts in API or underlying behavior between minor versions due to rapid development and integration of new research findings.
- gotcha: Unconstrained hyper-connections can lead to training instability, such as exploding signals, due to repeated and unregularized mixing operations across multiple streams in very deep networks. This issue has been identified and led to the development of 'Manifold-Constrained Hyper-Connections' (mHC).
- gotcha: The full benefits of Hyper-Connections in mitigating the seesaw effect (between vanishing gradients and representation collapse) are observed when the expansion rate (`num_streams`) is greater than 1. Using `num_streams=1` does not significantly improve performance, and the seesaw effect persists.
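Because both the `TypeError` above and the breaking-change warning are version-dependent, it helps to confirm which release is actually installed before debugging. A small stdlib-only helper (the function name `installed_version` is illustrative, not part of the library):

```python
from importlib import metadata

def installed_version(dist_name):
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None

# Guard version-sensitive code paths on the detected release
version = installed_version("hyper-connections")
if version is None:
    print("hyper-connections is not installed")
else:
    print(f"hyper-connections {version}")
```

Comparing this against the tag in the GitHub repository tells you which README to read for the correct `init_hyper_conn` signature.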
Install
- pip install hyper-connections
Imports
- get_init_and_expand_reduce_stream_functions
from hyper_connections import get_init_and_expand_reduce_stream_functions
Quickstart
import torch
from torch import nn
from hyper_connections import get_init_and_expand_reduce_stream_functions
# Define a simple neural network layer (branch) that will be enhanced by hyper-connections
class SimpleFFN(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.GELU(),
            nn.Linear(dim * 2, dim)
        )

    def forward(self, x):
        return self.net(x)
# Example dimensions for a tensor and number of hyper-connection streams
dim = 512
batch_size = 2
seq_len = 1024
num_streams = 4 # Recommended to be > 1 for full benefits of hyper-connections
# Retrieve the utility functions for hyper-connections
init_hyper_conn, expand_stream, reduce_stream = get_init_and_expand_reduce_stream_functions(num_streams)
# Instantiate a base layer
branch_layer = SimpleFFN(dim)
# Wrap the base layer with hyper-connections logic
hyper_conn_branch = init_hyper_conn(dim=dim, branch=branch_layer)
# Create an initial input tensor (e.g., from a transformer layer's output)
input_tensor = torch.randn(batch_size, seq_len, dim)
print(f"Initial input shape: {input_tensor.shape}")
# 1. Expand the input into multiple residual streams
# The exact shape transformation depends on internal implementation, but typically adds a stream dimension.
expanded_input = expand_stream(input_tensor)
print(f"Shape after expansion (may vary internally): {expanded_input.shape}")
# 2. Forward pass through the wrapped branch function, which processes the multiple streams
output_streams = hyper_conn_branch(expanded_input)
print(f"Shape after hyper-connected branch (may vary internally): {output_streams.shape}")
# 3. Reduce the multiple streams back to a single output tensor
final_output = reduce_stream(output_streams)
print(f"Final output shape after reduction: {final_output.shape}")
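To make the expand/process/reduce flow above concrete without depending on the library, here is a plain-PyTorch sketch of the multi-stream idea. This is NOT the library's implementation: real hyper-connections learn depth and width mixing weights, whereas this toy uses a fixed expand (copy into streams) and reduce (sum over streams).

```python
import torch
from torch import nn

num_streams, dim = 4, 32

def expand(x):
    # (batch, seq, dim) -> (streams, batch, seq, dim): copy into n streams
    return x.unsqueeze(0).expand(num_streams, *x.shape)

def reduce(streams):
    # Collapse the stream dimension back to a single tensor
    return streams.sum(dim=0)

# A small stack of layers standing in for transformer blocks
layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(3)])

x = torch.randn(2, 8, dim)
h = expand(x)                # expand once, before the stack
for layer in layers:
    h = h + layer(h)         # residual-style update applied per stream
out = reduce(h)              # reduce once, after the stack
print(out.shape)
```

As in the quickstart, expansion happens once before the layer stack and reduction once after it; the per-stream mixing in between is where the library's learnable connections would act.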