Temporal PyTorch Complex Tensor Class
torch-complex is a Python library that provides a custom `ComplexTensor` class and related functional operations for PyTorch. It is meant as a temporary ("temporal", in the project's own wording) stopgap for complex-valued tensor computation. It was written because PyTorch historically lacked comprehensive native support for complex tensors. The project states that it should be superseded, and eventually 'thrown away', once PyTorch's native complex tensor support is mature and performant. The current version is 0.4.4. New releases are made when specific complex-tensor features are needed.
Common errors
- TypeError: 'ComplexTensor' object is not callable
  Cause: calling a `ComplexTensor` instance as if it were a function, e.g. `x = ComplexTensor(real_part, imag_part)()` instead of `x = ComplexTensor(real_part, imag_part)`.
  Fix: call the class once to create the instance and remove the extra `()`. The resulting object is a tensor, not a function, so calling it again raises this error.
- RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
  Cause: an operation mixes a `ComplexTensor` on one device (e.g. CPU) with a `torch.Tensor` or another `ComplexTensor` on a different device (e.g. GPU).
  Fix: move every `ComplexTensor` and `torch.Tensor` involved in the operation to the same device with `.to(device)`, `.cuda()`, or `.cpu()`. For example, call `x = x.cuda()` before combining `x` with tensors that already live on CUDA.
- AttributeError: 'ComplexTensor' object has no attribute 'some_native_pytorch_method'
  Cause: `ComplexTensor` is a custom class that reimplements many `torch.Tensor` methods. It does not necessarily cover every specialized method, nor methods added in newer PyTorch releases.
  Fix: check the `torch-complex` source or documentation to see whether the method is implemented. If it is not, apply the native method to the real and imaginary `torch.Tensor` components yourself. This is valid only for operations that act on each component independently. Alternatively, check whether native PyTorch complex tensors now cover your use case.
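The component fallback in the last fix is only safe for operations that act on the real and imaginary parts independently. These are linear or purely structural operations such as transpose, flip, reshape, or sum. The pure-Python sketch below checks that rule, with nested lists standing in for the two real `torch.Tensor` components. `transpose`, `row_sums`, `apply_componentwise`, and `combine` are hypothetical helpers, not `torch-complex` API. With an actual `ComplexTensor`, the analogous pattern would presumably be `ComplexTensor(op(x.real), op(x.imag))`, assuming the instance exposes its parts as `.real` and `.imag`.

```python
import math
from typing import Callable, List, Tuple

Matrix = List[List[float]]


def transpose(m: Matrix) -> Matrix:
    """Structural op: swap rows and columns."""
    return [list(col) for col in zip(*m)]


def row_sums(m: Matrix) -> List[float]:
    """Linear op: sum each row."""
    return [sum(row) for row in m]


def apply_componentwise(op: Callable, real, imag) -> Tuple:
    """Hypothetical helper: run a real-valued op on each component separately."""
    return op(real), op(imag)


def combine(real, imag):
    """Rebuild complex values from matching (possibly nested) real/imag lists."""
    if isinstance(real, list):
        return [combine(r, i) for r, i in zip(real, imag)]
    return complex(real, imag)


real = [[1.0, 2.0], [3.0, 4.0]]
imag = [[0.5, -1.0], [2.0, 0.0]]
values = combine(real, imag)  # the same data as native Python complex numbers

# Linear/structural ops give the same answer whether applied per component or to complex values.
assert combine(*apply_componentwise(transpose, real, imag)) == transpose(values)
assert combine(*apply_componentwise(row_sums, real, imag)) == row_sums(values)

# Nonlinear ops do not split: |z| = sqrt(re^2 + im^2) needs both components at once,
# so it cannot be computed as op(real) and op(imag) independently.
z = values[0][1]  # (2-1j)
assert math.isclose(math.hypot(z.real, z.imag), abs(z))
```

For anything nonlinear (magnitude, phase, exp, division), the two components must be combined explicitly, as in the last assertion.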
Warnings
- breaking: This library is explicitly a 'temporal' (that is, temporary) solution. Its maintainer intends to 'throw away' the project once native PyTorch complex tensor support is fully developed. Plan a long-term migration to PyTorch's native complex dtypes (`torch.complex64`, `torch.complex128`), which are now stable and actively maintained.
- gotcha: Operations in `torch-complex` are implemented in Python by combining real-valued tensor computations. They can therefore be significantly slower than PyTorch's native operations, which run as optimized C++/CUDA kernels. The library prioritizes functionality over raw performance.
- gotcha: This library requires PyTorch >= 1.0. Older PyTorch versions may lack some of the real-tensor operations that `torch-complex` builds on.
- gotcha: Several similarly named but distinct Python packages provide complex numbers for PyTorch (e.g. `complexPyTorch`, `pytorch-complex`, `complextorch`). Each may have a different `ComplexTensor` implementation and import path. Make sure any example you follow targets `torch-complex`, which is imported as `torch_complex`.
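The performance note above follows from how complex arithmetic decomposes into real arithmetic. When a complex value is held as two separate real parts, every elementwise complex operation expands into several real ones. The pure-Python sketch below writes out the standard formulas for multiplication and division. `cmul` and `cdiv` are illustrative names, and this is the textbook technique, not code copied from `torch-complex`. In a Python-level tensor library, each real `*`, `+`, or `/` would typically be a separate tensor operation. One complex product therefore costs four real multiplications and two additions, and a division costs more.

```python
from typing import Tuple

# A complex value stored as separate (real, imag) parts, standing in for two real tensors.
Pair = Tuple[float, float]


def cmul(a: Pair, b: Pair) -> Pair:
    """(ar + i*ai) * (br + i*bi) = (ar*br - ai*bi) + i*(ar*bi + ai*br)."""
    ar, ai = a
    br, bi = b
    # Four real multiplications and two real additions/subtractions.
    return ar * br - ai * bi, ar * bi + ai * br


def cdiv(a: Pair, b: Pair) -> Pair:
    """a / b = a * conj(b) / |b|^2, using only real arithmetic."""
    br, bi = b
    denom = br * br + bi * bi            # |b|^2 is purely real
    num_re, num_im = cmul(a, (br, -bi))  # multiply by the conjugate of b
    return num_re / denom, num_im / denom


x = (1.5, -2.0)
y = (0.5, 3.0)
print("product:", cmul(x, y), "reference:", complex(*x) * complex(*y))
print("quotient:", cdiv(x, y), "reference:", complex(*x) / complex(*y))
```

Python's built-in `complex` type serves as the reference here. Native complex dtypes do the same arithmetic inside a single fused kernel instead of several separate ones.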
Install
pip install torch-complex
Imports
- ComplexTensor
from torch_complex.tensor import ComplexTensor
- functional (as F)
import torch_complex.nn as nn
import torch_complex.functional as F
Quickstart
import numpy as np
import torch
from torch_complex.tensor import ComplexTensor
import torch_complex.functional as F
# Create ComplexTensor from real and imaginary parts
real_part = np.random.randn(3, 10, 10)
imag_part = np.random.randn(3, 10, 10)
x = ComplexTensor(real_part, imag_part)
# Perform basic mathematical operations
y = x + x
z = F.matmul(x, x)  # Batched matrix product over the leading dimension; equivalent to x @ x
w = x.conj()
print(f"Original ComplexTensor shape: {x.shape}")
print(f"Result of addition (y) shape: {y.shape}")
print(f"Result of matrix multiplication (z) shape: {z.shape}")
print(f"Conjugate (w) shape: {w.shape}")
# Move to CUDA if available
if torch.cuda.is_available():
    x_cuda = x.cuda()
    print(f"ComplexTensor moved to CUDA: {x_cuda.device}")
else:
    print("CUDA not available, running on CPU.")
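In the quickstart, `F.matmul(x, x)` multiplies a batch of three 10x10 complex matrices. With only real tensors available, each complex matrix product has to be assembled from four real matrix products: (A + iB)(C + iD) = (AC - BD) + i(AD + BC). The pure-Python sketch below shows that identity on a one-item batch of 2x2 matrices, with nested lists standing in for real tensors. The helper names are hypothetical. The explicit loop over the batch is only for readability; a tensor implementation would presumably use batched real products instead.

```python
from typing import List, Tuple

Matrix = List[List[float]]
Batch = List[Matrix]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain real matrix product."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def add(a: Matrix, b: Matrix, sign: float = 1.0) -> Matrix:
    """Elementwise a + sign * b."""
    return [[x + sign * y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def complex_batch_matmul(a_re: Batch, a_im: Batch,
                         b_re: Batch, b_im: Batch) -> Tuple[Batch, Batch]:
    """Batched (A + iB) @ (C + iD) built from four real matrix products per item."""
    out_re, out_im = [], []
    for ar, ai, br, bi in zip(a_re, a_im, b_re, b_im):  # loop over the batch dimension
        out_re.append(add(matmul(ar, br), matmul(ai, bi), sign=-1.0))  # AC - BD
        out_im.append(add(matmul(ar, bi), matmul(ai, br)))             # AD + BC
    return out_re, out_im


# One batch item with real part x_re and imaginary part x_im; compute x @ x.
x_re = [[[1.0, 2.0], [0.0, -1.0]]]
x_im = [[[0.5, 0.0], [1.0, 2.0]]]
z_re, z_im = complex_batch_matmul(x_re, x_im, x_re, x_im)

# Cross-check against Python's built-in complex numbers.
x = [[complex(r, i) for r, i in zip(rr, ir)] for rr, ir in zip(x_re[0], x_im[0])]
ref = [[sum(x[i][k] * x[k][j] for k in range(2)) for j in range(2)] for i in range(2)]
assert all(abs(complex(z_re[0][i][j], z_im[0][i][j]) - ref[i][j]) < 1e-12
           for i in range(2) for j in range(2))
```

The four real products per batch item make this roughly four times the work of a single real matmul of the same size. The separate additions add further overhead on top of that.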