TorchGeo
TorchGeo is a Python library that provides datasets, samplers, transforms, and pre-trained models for geospatial data in the PyTorch ecosystem. It aims to simplify building deep learning models for Earth observation and remote sensing tasks. At the time of writing the current version is 0.9.0. Releases are frequent, typically every 2-3 months, and add new features and datasets.
Common errors
- ModuleNotFoundError: No module named 'torchgeo.datasets.utils'
  cause: The `utils` submodule within `torchgeo.datasets` (or another module) may have been refactored or moved, or its contents may have been merged into other classes or functions in the TorchGeo version you have installed.
  fix: Check the official TorchGeo documentation or GitHub repository for the correct import path of the utility you need. The functionality may now be available directly on a class or in a different submodule.
- UserWarning: `torchgeo` is being used with an unsupported `lightning` version. This might lead to unexpected behavior.
  cause: The installed PyTorch Lightning (`lightning`) version is not tested or supported by your TorchGeo release, which can cause incompatibilities.
  fix: Check the `lightning` version constraint declared in TorchGeo's package metadata (`pyproject.toml` or `setup.py`) for your release, then upgrade or downgrade `lightning` to match, e.g. `pip install 'lightning<2.5'` if a 2.5.x release is causing problems. `pip check` reports installed packages whose declared requirements are not satisfied.
- TypeError: __init__() missing 1 required positional argument: 'root'
  cause: Dataset classes that load data from disk take a `root` directory argument (GeoDatasets such as `RasterDataset` call it `paths`). The error occurs when a class whose `__init__` declares `root` without a default, such as a custom subclass, is constructed without it.
  fix: Always pass `root`, pointing to the directory where the dataset is stored or should be downloaded, e.g. `dataset = EuroSAT(root='./data', download=True)`. Passing it explicitly also avoids relying on a default location.
- RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0!
  cause: A standard PyTorch error: the tensors in one operation live on different devices (CPU vs. GPU), typically because the model was moved to CUDA but the input batch was not, or vice versa.
  fix: Move the model and every tensor it consumes to the same device with `.to(device)`, where `device = 'cuda' if torch.cuda.is_available() else 'cpu'`. By default a `DataLoader` yields CPU tensors, so move each batch (for TorchGeo datasets, every tensor in the sample dictionary, such as `batch['image']` and `batch['label']`) to the device inside the training loop. Lightning trainers perform this transfer automatically.
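TorchGeo datasets return samples as dictionaries, so moving a batch to the GPU means moving each tensor inside the dictionary. The sketch below shows one way to do that in a plain PyTorch loop. `move_to_device` is a hypothetical helper, not a TorchGeo API; it assumes that anything with a `.to()` method (such as a `torch.Tensor`) should be moved, while other values (CRS objects, strings, numbers) pass through unchanged.

```python
from typing import Any


def move_to_device(batch: Any, device: Any) -> Any:
    """Recursively move every tensor in a (possibly nested) batch to `device`.

    Hypothetical helper, not a TorchGeo API: objects with a `.to()` method
    (e.g. torch.Tensor) are moved, dicts/lists/tuples are traversed, and all
    other values are returned unchanged.
    """
    if isinstance(batch, dict):
        return {key: move_to_device(value, device) for key, value in batch.items()}
    if isinstance(batch, list):
        return [move_to_device(value, device) for value in batch]
    if isinstance(batch, tuple):
        # Note: namedtuples come back as plain tuples in this simple sketch.
        return tuple(move_to_device(value, device) for value in batch)
    if hasattr(batch, "to"):
        return batch.to(device)
    return batch
```

Call `model.to(device)` once before training, then `batch = move_to_device(batch, device)` at the start of each iteration. Lightning does the equivalent through its `transfer_batch_to_device` hook.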
Warnings
- breaking: TorchGeo 0.8.0 introduced a complete rewrite of the `GeoDataset` and `GeoSampler` internals. Code that relies on direct manipulation or specific internal structures of these base classes might break.
- gotcha: TorchGeo 0.9.0 and later require Python 3.12 or newer. Users on older Python versions will encounter installation errors.
- gotcha: Incompatibilities have been reported with specific versions of the `lightning` library (e.g., 2.5.5 was not supported in TorchGeo 0.7.2). Using an unsupported `lightning` version can lead to errors or unexpected behavior during training.
- gotcha: Specific versions of `rasterio` (e.g., 1.4.0 and 1.4.1) caused issues in earlier TorchGeo versions, including crashes or incorrect data loading when processing raster data.
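Because several of the problems above depend on which versions are installed, a quick pre-flight check can save debugging time. The sketch below is not a TorchGeo tool: `parse_version`, `installed_version`, `environment_warnings`, and the `KNOWN_BAD` table are hypothetical names, the table only lists the versions mentioned in these warnings, and the simplified parser ignores pre-release and local suffixes (the `packaging` library handles full PEP 440 versions).

```python
from __future__ import annotations

import sys
from importlib.metadata import PackageNotFoundError, version

# Versions called out in the warnings above (illustrative, not exhaustive).
KNOWN_BAD: dict[str, set[tuple[int, ...]]] = {
    "lightning": {(2, 5, 5)},  # reported unsupported by TorchGeo 0.7.2
    "rasterio": {(1, 4, 0), (1, 4, 1)},  # caused issues in earlier TorchGeo releases
}


def parse_version(text: str) -> tuple[int, ...]:
    """Return the leading numeric components, e.g. '2.5.5' -> (2, 5, 5).

    Simplified (assumption): local ('+cu121') and pre-release ('rc1')
    suffixes are dropped; use `packaging.version` for full PEP 440 parsing.
    """
    parts: list[int] = []
    for piece in text.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
        if len(digits) < len(piece):  # e.g. '0rc1': the suffix ends the release part
            break
    return tuple(parts)


def installed_version(package: str) -> tuple[int, ...] | None:
    """Parsed version of an installed distribution, or None if it is absent."""
    try:
        return parse_version(version(package))
    except PackageNotFoundError:
        return None


def environment_warnings(min_python: tuple[int, int] = (3, 12)) -> list[str]:
    """Collect human-readable problems; the default assumes TorchGeo >= 0.9."""
    problems: list[str] = []
    current = sys.version_info[:2]
    if current < min_python:
        problems.append(
            f"Python {current[0]}.{current[1]} is older than "
            f"{min_python[0]}.{min_python[1]}"
        )
    for package, bad_versions in KNOWN_BAD.items():
        found = installed_version(package)
        if found in bad_versions:
            problems.append(f"{package} {'.'.join(map(str, found))} is known to cause problems")
    return problems


if __name__ == "__main__":
    print(environment_warnings())
```

An empty list means none of the listed problems apply. Extend `KNOWN_BAD` from the release notes of the TorchGeo version you use.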
Install
- pip install torchgeo
- pip install 'torchgeo[all]'
Imports
- EuroSAT
from torchgeo.datasets import EuroSAT
- RasterDataset
from torchgeo.datasets import RasterDataset
- RandomGeoSampler
from torchgeo.samplers import RandomGeoSampler
- AugmentationSequential
from torchgeo.transforms import AugmentationSequential
- ResNet18_Weights
from torchgeo.models import ResNet18_Weights
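Import paths can change between releases (see the `ModuleNotFoundError` entry under Common errors), so it can help to confirm that every name you rely on is importable from the installed TorchGeo. The script below is a hypothetical diagnostic, not part of TorchGeo; the `TORCHGEO_IMPORTS` table simply mirrors the imports listed in this section.

```python
from __future__ import annotations

from importlib import import_module

# Names listed in this section; adjust to the TorchGeo version you target.
TORCHGEO_IMPORTS: dict[str, list[str]] = {
    "torchgeo.datasets": ["EuroSAT", "RasterDataset"],
    "torchgeo.samplers": ["RandomGeoSampler"],
    "torchgeo.transforms": ["AugmentationSequential"],
    "torchgeo.models": ["ResNet18_Weights"],
}


def missing_names(spec: dict[str, list[str]]) -> list[str]:
    """Return the 'module.name' entries that cannot be imported."""
    missing: list[str] = []
    for module_name, names in spec.items():
        try:
            module = import_module(module_name)
        except ImportError:  # also covers ModuleNotFoundError
            missing.extend(f"{module_name}.{name}" for name in names)
            continue
        missing.extend(
            f"{module_name}.{name}" for name in names if not hasattr(module, name)
        )
    return missing


if __name__ == "__main__":
    print(missing_names(TORCHGEO_IMPORTS) or "all listed imports are available")
```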
Quickstart
import tempfile

from torch.utils.data import DataLoader
from torchgeo.datasets import EuroSAT

# Use a temporary directory for the dataset to avoid polluting the user's system.
# Note: the default EuroSAT download (all 13 Sentinel-2 bands) is large.
with tempfile.TemporaryDirectory() as tmpdir:
    # Initialize the EuroSAT dataset (downloads it if not present)
    dataset = EuroSAT(root=tmpdir, split="train", download=True)

    # EuroSAT is a NonGeoDataset of fixed-size image chips, so it is batched with a
    # regular DataLoader. GeoSamplers such as RandomBatchGeoSampler only accept
    # GeoDatasets (e.g. RasterDataset) and would fail here.
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

    # Iterate through one batch and print shapes. The loop stays inside the
    # `with` block because the temporary directory is deleted on exit.
    for batch in dataloader:
        image = batch["image"]
        label = batch["label"]
        print(f"Batch image shape: {image.shape}, label shape: {label.shape}")
        break  # Just one batch for the quickstart
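The quickstart does not need a GeoSampler because EuroSAT already consists of fixed-size chips. For GeoDatasets such as `RasterDataset`, a random sampler like `RandomGeoSampler` produces bounding boxes of a fixed size inside the dataset's spatial extent, and the dataset then reads the data for each box. The plain-Python sketch below illustrates only that idea. `BBox`, `random_bbox`, and `random_patches` are hypothetical names, sizes are assumed to be in CRS units, and TorchGeo's real samplers do more (for example, they sample within the footprints of indexed files and also support sizes in pixels).

```python
import random
from typing import List, NamedTuple


class BBox(NamedTuple):
    """Axis-aligned box in CRS units (hypothetical stand-in for a real bbox type)."""

    minx: float
    maxx: float
    miny: float
    maxy: float


def random_bbox(bounds: BBox, size: float, rng: random.Random) -> BBox:
    """Draw a size x size window uniformly at random that lies inside `bounds`."""
    if size > bounds.maxx - bounds.minx or size > bounds.maxy - bounds.miny:
        raise ValueError("patch size is larger than the dataset extent")
    minx = rng.uniform(bounds.minx, bounds.maxx - size)
    miny = rng.uniform(bounds.miny, bounds.maxy - size)
    return BBox(minx, minx + size, miny, miny + size)


def random_patches(bounds: BBox, size: float, length: int, seed: int = 0) -> List[BBox]:
    """Return `length` random windows; a fixed seed makes the draw reproducible."""
    rng = random.Random(seed)
    return [random_bbox(bounds, size, rng) for _ in range(length)]


if __name__ == "__main__":
    for box in random_patches(BBox(0.0, 1000.0, 0.0, 500.0), size=64.0, length=3):
        print(box)
```

Sampling the lower-left corner from the shrunken range `[min, max - size]` guarantees every window lies fully inside the extent, which is why the sketch rejects patch sizes larger than the extent.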