PyTorch Gradient Reversal Layer

0.2.0 · maintenance · verified Thu Apr 16

pytorch-revgrad is a minimal PyTorch package that provides a gradient reversal layer (GRL) as both an `nn.Module` and a function. A GRL acts as the identity during the forward pass but multiplies gradients by -1 during the backward pass. It is commonly used in domain adaptation techniques, most notably Domain-Adversarial Neural Networks (DANN), where it sits between a feature extractor and a domain classifier: reversing the gradient flowing back from the domain classifier pushes the feature extractor toward domain-invariant representations. The current version, `0.2.0`, was released in January 2021; the low release cadence reflects a stable, focused API.
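The mechanism behind a GRL is a custom autograd function that is transparent on the forward pass and flips the sign of incoming gradients on the backward pass. A minimal sketch of the idea (illustrative names, not the package's internal code):

```python
import torch

# A gradient reversal layer as a custom autograd Function.
# pytorch_revgrad wraps the same idea in an nn.Module.
class GradReverse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha=1.0):
        ctx.alpha = alpha
        return x.view_as(x)  # identity on the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        # Negate (and optionally scale) the gradient; None for the
        # non-tensor alpha argument.
        return -ctx.alpha * grad_output, None

x = torch.ones(3, requires_grad=True)
GradReverse.apply(x, 1.0).sum().backward()
print(x.grad)  # tensor([-1., -1., -1.])
```

Without the reversal, the gradient of `x.sum()` with respect to `x` would be all ones; the GRL flips it to all negative ones, which is exactly what makes the feature extractor work *against* the domain classifier in DANN.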

Common errors

Warnings

Install

pip install pytorch-revgrad

Imports

import torch
from torch import nn
from pytorch_revgrad import RevGrad

Quickstart

This quickstart demonstrates how to integrate `RevGrad` into a simple PyTorch model architecture, typical for domain adaptation. It shows a `FeatureExtractor` and a `DomainClassifier` where `RevGrad` is placed before the classifier's layers to reverse gradients for domain classification.

import torch
from torch import nn
from pytorch_revgrad import RevGrad

# Define a simple feature extractor
class FeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)

    def forward(self, x):
        return torch.relu(self.fc1(x))

# Define a domain classifier with a RevGrad layer
class DomainClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.revgrad = RevGrad()  # identity in forward; negates gradients in backward
        self.fc1 = nn.Linear(5, 5)
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        x = self.revgrad(x)
        x = torch.relu(self.fc1(x))
        return torch.sigmoid(self.fc2(x))

# Example usage
feature_extractor = FeatureExtractor()
domain_classifier = DomainClassifier()

input_data = torch.randn(64, 10, requires_grad=True)

# Forward pass
features = feature_extractor(input_data)
domain_output = domain_classifier(features)

print(f"Input shape: {input_data.shape}")
print(f"Features shape: {features.shape}")
print(f"Domain output shape: {domain_output.shape}")

# Simulate a loss and backward pass (conceptual)
# In a real scenario, you'd define a combined loss for source and target, 
# and optimize both feature_extractor and domain_classifier.
# For demonstration, we'll just show a dummy backward pass.

dummy_loss = domain_output.mean()
dummy_loss.backward()

# Check that gradients reached every part of the model
# (RevGrad flips their sign upstream of the domain classifier)
print(f"Gradient for input data exists: {input_data.grad is not None}")
print(f"Gradient for feature_extractor.fc1.weight exists: {feature_extractor.fc1.weight.grad is not None}")
print(f"Gradient for domain_classifier.fc1.weight exists: {domain_classifier.fc1.weight.grad is not None}")
