lightning-ai/pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on 1 or 10,000+ GPUs with zero code changes.

31,058 stars · Python · 8 components

Abstracts PyTorch training loops into reusable components for single-GPU to multi-thousand-GPU deployment


Under the hood, the system uses 3 feedback loops, 3 data pools, and 5 control points to manage its runtime behavior.

An 8-component ML training system. 648 files analyzed. Data flows through 6 distinct pipeline stages.

How Data Flows Through the System

Training starts when Trainer.fit() loads the LightningModule and DataLoaders, sets up the selected distributed strategy and hardware configuration. Each training epoch iterates through batches from the DataLoader, routes each batch to LightningModule.training_step() which computes loss, then executes backpropagation and optimizer steps. Metrics are collected, aggregated across devices, and sent to loggers. Validation runs periodically using validation_step(), and checkpoints are saved based on configured criteria. The process continues for the specified number of epochs or until early stopping conditions are met.

  1. Initialize training setup — Trainer.__init__() configures distributed strategy, precision plugin, accelerator, and callbacks. Strategy.setup() initializes process groups and device placement. DataConnector prepares DataLoaders with appropriate samplers for distributed training
  2. Setup model and optimizers — Trainer calls LightningModule.configure_optimizers() to get optimizer and scheduler configuration. Strategy wraps the model for distributed training (e.g., DDP, FSDP). PrecisionPlugin configures mixed precision if enabled [LightningModule → OptimizerConfig]
  3. Execute training step — TrainingLoop iterates through DataLoader batches. Each batch is passed to LightningModule.training_step(batch, batch_idx) which computes the loss. PrecisionPlugin handles forward pass precision and loss scaling [TrainingBatch → LoggerMetrics]
  4. Compute gradients and optimize — Loss.backward() computes gradients with PrecisionPlugin handling gradient scaling. Strategy synchronizes gradients across devices. Optimizer.step() updates model parameters, then optimizer.zero_grad() clears gradients [LoggerMetrics]
  5. Aggregate and log metrics — ResultCollection gathers metrics from training_step across all devices and processes. Metrics are reduced (averaged, summed) and synchronized. Logger instances (TensorBoard, WandB) receive formatted metrics for visualization [LoggerMetrics → LoggerMetrics]
  6. Run validation and checkpointing — ValidationLoop executes periodically, running LightningModule.validation_step() on validation data. ModelCheckpoint monitors validation metrics and saves CheckpointDict to disk when criteria are met (best loss, every N epochs) [TrainingBatch → CheckpointDict]
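The six stages above can be sketched in framework-free Python. This is a minimal illustration of the loop Trainer.fit() automates, not Lightning's actual API: names like `SketchModule` and the closure-based "optimizer" are stand-ins, and real Lightning wraps each hook with distributed strategies, precision plugins, callbacks, and logging.

```python
# Framework-free sketch of the loop Trainer.fit() automates.
# All names here are illustrative, not Lightning's API.

class SketchModule:
    """Stands in for a LightningModule: the user supplies the hooks."""
    def __init__(self):
        self.weight = 1.0

    def training_step(self, batch, batch_idx):
        # "Loss" is squared error of a 1-parameter linear model.
        x, y = batch
        return (self.weight * x - y) ** 2

    def configure_optimizers(self, lr=0.1):
        # Return a closure that applies one gradient-descent update.
        def step(batch):
            x, y = batch
            grad = 2 * (self.weight * x - y) * x   # analytic gradient
            self.weight -= lr * grad
        return step

def fit(module, dataloader, max_epochs=1):
    """Mimics Trainer.fit(): epochs -> batches -> training_step -> optimize."""
    optimizer_step = module.configure_optimizers()
    history = []
    for epoch in range(max_epochs):
        for batch_idx, batch in enumerate(dataloader):
            loss = module.training_step(batch, batch_idx)
            history.append(loss)
            optimizer_step(batch)   # stands in for backward() + optimizer.step()
    return history

data = [(1.0, 2.0)] * 20            # target weight is 2.0
losses = fit(SketchModule(), data, max_epochs=3)
print(losses[0] > losses[-1])       # loss decreases over training -> True
```

Everything Lightning adds (device placement, gradient sync, logging, checkpointing) hangs off the same skeleton: the hooks stay user-owned, the loop stays framework-owned.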

Data Models

The data structures that flow between stages — the contracts that hold the system together.

LightningModule src/lightning/pytorch/core/module.py
PyTorch nn.Module subclass with training_step(batch, batch_idx) -> loss: Tensor, validation_step(batch, batch_idx) -> metrics: Dict, configure_optimizers() -> Optimizer | List[Optimizer]
User defines model with required hooks, Trainer calls methods during training loop execution, outputs logged metrics and checkpoints
TrainingBatch src/lightning/pytorch/trainer/connectors/data_connector.py
Generic container for batch data - typically Tensor or Dict[str, Tensor] with batch dimension first, exact shape depends on dataset
DataLoader produces batches, Trainer routes to LightningModule.training_step(), consumed for loss computation and backpropagation
LoggerMetrics src/lightning/pytorch/trainer/connectors/logger_connector/result.py
Dict[str, Union[Tensor, float]] with metric names as keys and scalar values, includes metadata like step, epoch, dataloader_idx
Generated in training/validation steps, aggregated by ResultCollection, formatted and sent to configured loggers (TensorBoard, Weights & Biases, etc.)
CheckpointDict src/lightning/pytorch/trainer/connectors/checkpoint_connector.py
Dict with keys: state_dict (model weights), optimizer_states, lr_schedulers, epoch, global_step, pytorch-lightning_version, hyper_parameters
Created during training at specified intervals, saved to disk with ModelCheckpoint callback, loaded to restore training state or for inference
OptimizerConfig src/lightning/pytorch/utilities/types.py
Dict with scheduler: LRScheduler, name: Optional[str], interval: 'epoch'|'step', frequency: int, reduce_on_plateau: bool, monitor: Optional[str]
Returned from LightningModule.configure_optimizers(), used by Trainer to manage learning rate scheduling throughout training
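The scheduler portion of an OptimizerConfig can be illustrated as a plain dict mirroring the keys listed above. The values and the `validate_scheduler_config` helper are illustrative; in real code `scheduler` would be a `torch.optim.lr_scheduler` instance.

```python
# Illustrative shape of the scheduler config described above. A placeholder
# object stands in for a real torch.optim.lr_scheduler instance.
lr_scheduler_config = {
    "scheduler": object(),      # placeholder for an LRScheduler instance
    "name": "cosine",           # optional display name for logging
    "interval": "epoch",        # step the scheduler per 'epoch' or per 'step'
    "frequency": 1,             # step every N intervals
    "reduce_on_plateau": False, # True for ReduceLROnPlateau-style schedulers
    "monitor": "val_loss",      # metric watched when reduce_on_plateau=True
}

def validate_scheduler_config(cfg):
    """Minimal check of the contract sketched above (hypothetical helper)."""
    assert cfg["interval"] in ("epoch", "step"), "interval must be 'epoch'|'step'"
    assert isinstance(cfg["frequency"], int) and cfg["frequency"] > 0
    if cfg["reduce_on_plateau"]:
        assert cfg["monitor"] is not None, "plateau schedulers need a monitor"
    return True

print(validate_scheduler_config(lr_scheduler_config))  # True
```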

Hidden Assumptions

Things this code relies on but never validates. These are the things that cause silent failures when the system changes.

critical Contract weakly guarded

LightningModule passed to fit() implements training_step(batch, batch_idx) returning a loss Tensor and optionally validation_step(batch, batch_idx) returning metrics dict, but never validates these method signatures or return types

If this fails: If training_step returns wrong type (e.g., dict instead of Tensor) or validation_step returns non-dict, trainer silently fails or produces cryptic errors during backward pass

examples/fabric/build_your_own_trainer/trainer.py:MyCustomTrainer.__init__
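One way to guard this contract is a pre-flight check run on a sample batch before training begins. The helper below is hypothetical, not part of Lightning's API; it accepts anything float-convertible as a loss (a scalar Tensor also passes `float()`).

```python
# Hypothetical pre-flight check for the training_step contract: verify the
# hook exists and returns a loss-like scalar before handing the module to a
# trainer, instead of failing cryptically during backward().
def check_training_contract(module, sample_batch):
    if not callable(getattr(module, "training_step", None)):
        raise TypeError("module must implement training_step(batch, batch_idx)")
    out = module.training_step(sample_batch, 0)
    try:
        float(out)   # scalar Tensors are also float-convertible
    except (TypeError, ValueError):
        raise TypeError(
            f"training_step must return a scalar loss, got {type(out).__name__}"
        )
    return out

class GoodModule:
    def training_step(self, batch, batch_idx):
        return 0.5

class BadModule:
    def training_step(self, batch, batch_idx):
        return {"loss": 0.5}   # dict instead of a scalar -> caught early

check_training_contract(GoodModule(), None)   # passes
try:
    check_training_contract(BadModule(), None)
except TypeError as e:
    print(e)   # training_step must return a scalar loss, got dict
```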
critical Shape unguarded

Input tensor has shape (batch_size, 1, 28, 28) for MNIST data, with fc1 layer expecting exactly 9216 features (64 * 12 * 12 after conv/pool operations), but never validates input dimensions

If this fails: If input has different spatial dimensions or channels, fc1 receives wrong tensor size causing RuntimeError about mismatched dimensions during forward pass

examples/fabric/image_classifier/train_fabric.py:Net.forward
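The 9216 figure can be verified with the standard convolution output-size formula. The arithmetic below restates that example's architecture (two 3x3 convs with stride 1 and no padding, one 2x2 max-pool, 64 output channels) and shows why any other input size breaks fc1:

```python
# Why fc1 expects exactly 9216 features: follow MNIST's 28x28 input through
# two 3x3 convolutions (stride 1, no padding) and one 2x2 max-pool.
def conv_out(size, kernel, stride=1, padding=0):
    """Standard convolution output-size formula."""
    return (size + 2 * padding - kernel) // stride + 1

def fc1_in_features(h=28, w=28):
    h = conv_out(conv_out(h, 3), 3) // 2   # conv 3x3 -> conv 3x3 -> pool /2
    w = conv_out(conv_out(w, 3), 3) // 2
    return 64 * h * w                      # 64 output channels

print(fc1_in_features())        # 9216 for a 28x28 input
print(fc1_in_features(32, 32))  # 12544 -- a size fc1 cannot accept
```

An `assert x.shape[-2:] == (28, 28)` at the top of `forward` would turn the cryptic matmul RuntimeError into an immediate, legible failure.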
critical Environment unguarded

os.cpu_count() returns a valid integer for DataLoader workers, but os.cpu_count() can return None on some systems where CPU count is undetermined

If this fails: DataLoader fails with TypeError when trying to use None as num_workers, causing training to crash at data loading stage

examples/fabric/dcgan/train_fabric.py:workers assignment
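The guard is a one-liner; the fallback value below is an arbitrary choice:

```python
import os

# Defensive num_workers: os.cpu_count() may return None on platforms where
# the CPU count cannot be determined, and DataLoader rejects None.
workers = os.cpu_count() or 2   # fall back to a small fixed worker pool
assert isinstance(workers, int) and workers >= 1
```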
critical Domain weakly guarded

All linear layer inner dimensions are divisible by 16 for Float8 conversion (except decoder which is filtered out), but never validates this mathematical constraint

If this fails: Float8 conversion fails silently or produces incorrect results when linear layers have dimensions not divisible by 16, leading to subtle numerical errors

examples/fabric/fp8_distributed_transformer/train.py:convert_to_float8_training
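The constraint can be made explicit with a filter that partitions layers into eligible and rejected before conversion. The layer specs below are illustrative `(name, in_features, out_features)` tuples, not the example's real module names:

```python
# Pre-flight check for the Float8 constraint above: skip the decoder (as the
# example does) and reject linear layers whose in/out features are not
# multiples of 16, instead of letting the conversion corrupt results silently.
def float8_eligible(layers, skip=("decoder",)):
    eligible, rejected = [], []
    for name, fan_in, fan_out in layers:
        if name in skip:
            continue                  # filtered out upstream, as in the example
        if fan_in % 16 == 0 and fan_out % 16 == 0:
            eligible.append(name)
        else:
            rejected.append(name)     # would break the Float8 kernel's layout
    return eligible, rejected

layers = [("attn", 512, 512), ("mlp", 512, 1000), ("decoder", 512, 50257)]
ok, bad = float8_eligible(layers)
print(ok, bad)   # ['attn'] ['mlp'] -- 1000 is not divisible by 16
```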
warning Resource unguarded

Checkpoint directory is writable and has sufficient disk space for model state_dict serialization, but never checks filesystem permissions or available space

If this fails: Checkpoint saving fails mid-training with disk full or permission errors, losing training progress without graceful recovery

examples/fabric/build_your_own_trainer/trainer.py:_save_checkpoint
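A check at trainer startup converts a mid-run disk-full crash into an immediate error. The helper and the 1 GiB threshold below are illustrative:

```python
import os
import shutil
import tempfile

# Verify the checkpoint directory is writable and has headroom before
# training starts, instead of discovering disk-full mid-run. The 1 GiB
# default threshold is an arbitrary illustration.
def check_checkpoint_dir(path, min_free_bytes=1 << 30):
    os.makedirs(path, exist_ok=True)
    if not os.access(path, os.W_OK):
        raise PermissionError(f"checkpoint dir not writable: {path}")
    free = shutil.disk_usage(path).free
    if free < min_free_bytes:
        raise OSError(f"only {free} bytes free in {path}")
    return free

with tempfile.TemporaryDirectory() as d:
    print(check_checkpoint_dir(d, min_free_bytes=1) > 0)   # True
```

This does not replace graceful handling of errors during `torch.save` itself (the disk can still fill up later), but it catches the common misconfigurations up front.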
warning Contract weakly guarded

Validation step returns a dict with string keys for metric names, but doesn't validate dict structure or key types before logging

If this fails: Non-string keys or nested dicts cause logger failures or metric aggregation errors, breaking validation monitoring

examples/fabric/build_your_own_trainer/trainer.py:_run_validation
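A small validator run on the dict before it reaches the logger makes the contract explicit. The helper below is hypothetical and checks only the two failure modes named above (non-string keys, nested dicts):

```python
# Validate validation_step output before it reaches a logger: string keys
# and flat (non-nested) values only. Hypothetical helper, not Lightning API.
def validate_metrics(metrics):
    if not isinstance(metrics, dict):
        raise TypeError(f"expected dict of metrics, got {type(metrics).__name__}")
    for key, value in metrics.items():
        if not isinstance(key, str):
            raise TypeError(f"metric key must be str, got {key!r}")
        if isinstance(value, dict):
            raise TypeError(f"nested dict under {key!r}; loggers expect scalars")
    return metrics

validate_metrics({"val_loss": 0.31, "val_acc": 0.91})   # passes
for bad in ({1: 0.5}, {"extra": {"nested": 1}}):
    try:
        validate_metrics(bad)
    except TypeError as e:
        print(e)
```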
warning Scale unguarded

GPU memory can handle batch_size=128 with 64x64x3 images plus generator/discriminator models, roughly 200MB+ per batch, but never checks available VRAM

If this fails: Training crashes with CUDA out of memory errors when GPU has insufficient memory, requiring manual batch size tuning

examples/fabric/dcgan/train_fabric.py:batch_size=128
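A back-of-envelope estimate shows where the memory actually goes. The raw input batch is small; the larger figure comes from activations and gradients inside the two networks, so any pre-flight check needs a generous safety factor:

```python
# Back-of-envelope memory estimate before allocating on the GPU. The raw
# input batch is only ~6 MiB; activations and gradients in the generator and
# discriminator multiply it many times over. Numbers illustrate scale only.
def tensor_mib(batch, channels, h, w, dtype_bytes=4):
    """Size of one float32 tensor in MiB."""
    return batch * channels * h * w * dtype_bytes / 2**20

inputs = tensor_mib(128, 3, 64, 64)   # the DCGAN image batch
print(inputs)                         # 6.0
# Before training, compare an estimate like this (times a safety factor for
# activations, gradients, and both models) against free VRAM reported by
# torch.cuda.mem_get_info(), and lower batch_size if it does not fit.
```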
warning Temporal unguarded

Validation frequency counting is based on completed epochs, but doesn't account for early stopping or interrupted training affecting validation timing

If this fails: Validation may not run at expected intervals when training is interrupted and resumed, potentially missing important metric checkpoints

examples/fabric/build_your_own_trainer/trainer.py:validation_frequency
warning Environment unguarded

CelebA dataset exists in 'data/' directory and is properly formatted, but never validates dataset integrity or file permissions

If this fails: Training fails with cryptic errors during data loading if dataset is corrupted, missing, or has wrong file structure

examples/fabric/dcgan/train_fabric.py:dataroot='data/'
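A cheap existence-and-non-emptiness check before constructing the DataLoader fails fast with a legible message. The helper and its `min_files` threshold are illustrative; it does not verify image integrity, only that the root exists and contains files:

```python
import pathlib
import tempfile

# Fail fast if the dataset root is missing or empty, instead of surfacing a
# cryptic loader error mid-epoch. Threshold and paths are illustrative.
def check_dataset_root(root, min_files=1):
    root = pathlib.Path(root)
    if not root.is_dir():
        raise FileNotFoundError(f"dataset root missing: {root}")
    n = sum(1 for p in root.rglob("*") if p.is_file())
    if n < min_files:
        raise RuntimeError(f"dataset root {root} has only {n} files")
    return n

with tempfile.TemporaryDirectory() as d:
    (pathlib.Path(d) / "img_000.jpg").write_bytes(b"\xff\xd8")  # stub file
    print(check_dataset_root(d))   # 1
```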
info Ordering unguarded

Model setup, optimizer configuration, and data preparation happen in specific order before training loop starts, but doesn't enforce or validate this sequencing

If this fails: If components are accessed before proper initialization (e.g., calling backward before fabric.setup), training produces confusing errors about uninitialized state

examples/fabric/build_your_own_trainer/trainer.py:fit method
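One way to enforce such sequencing is a tiny state gate that refuses out-of-order calls. The class and stage names below are illustrative, not Fabric's API:

```python
# Illustrative setup-ordering guard: each stage must be marked in the fixed
# order, so calling a later stage early raises immediately instead of
# producing confusing uninitialized-state errors downstream.
class SetupGate:
    ORDER = ("setup_model", "setup_optimizers", "setup_data", "fit")

    def __init__(self):
        self.done = []

    def mark(self, stage):
        expected = self.ORDER[len(self.done)]
        if stage != expected:
            raise RuntimeError(f"called {stage!r} before {expected!r}")
        self.done.append(stage)

gate = SetupGate()
gate.mark("setup_model")
gate.mark("setup_optimizers")
try:
    gate.mark("fit")             # skipped setup_data
except RuntimeError as e:
    print(e)                     # called 'fit' before 'setup_data'
```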

System Behavior

How the system operates at runtime — where data accumulates, what loops, what waits, and what controls what.

Data Pools

CheckpointStore (file-store)
Persists model weights, optimizer states, hyperparameters, and training metadata to disk at configured intervals or when validation metrics improve
MetricsBuffer (buffer)
Accumulates metrics from training/validation steps before aggregation and logging, handles synchronization across distributed processes
GradientState (in-memory)
Holds computed gradients in model parameters, synchronized across devices in distributed training before optimizer step
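The MetricsBuffer pool can be sketched as an accumulate-then-reduce container. This is a minimal single-process illustration; real Lightning's ResultCollection also synchronizes the reduction across distributed processes:

```python
from collections import defaultdict

# Minimal sketch of a metrics buffer: accumulate per-step values, then
# reduce (mean) and clear on flush, as happens before metrics reach loggers.
class MetricsBuffer:
    def __init__(self):
        self._values = defaultdict(list)

    def log(self, name, value):
        self._values[name].append(float(value))

    def flush(self):
        """Reduce to per-metric means and clear the buffer."""
        reduced = {k: sum(v) / len(v) for k, v in self._values.items()}
        self._values.clear()
        return reduced

buf = MetricsBuffer()
for loss in (1.0, 0.5, 1.5):
    buf.log("train_loss", loss)
print(buf.flush())   # {'train_loss': 1.0}
```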

Feedback Loops

Delays

Control Points

Technology Stack

PyTorch (framework)
Core tensor operations, neural network primitives, autograd, and CUDA integration for GPU acceleration
PyTorch Distributed (library)
Multi-process training coordination, gradient synchronization, and process group management for DDP/FSDP
TorchMetrics (library)
Standardized metric computation and aggregation across distributed training processes
TensorBoard (library)
Training visualization, metric logging, and hyperparameter tracking integration
CUDA (runtime)
GPU computation backend for accelerated training on NVIDIA hardware
Hydra (library)
Configuration management for complex training experiments with multiple parameter sweeps
DeepSpeed (library)
Microsoft's optimization library for large model training with ZeRO memory optimization

Key Components

Explore the interactive analysis

See the full architecture map, data flow, and code patterns visualization.

Analyze on CodeSea

Compare pytorch-lightning

Related ML Training Repositories

Frequently Asked Questions

What is pytorch-lightning used for?

pytorch-lightning abstracts PyTorch training loops into reusable components for single-GPU to multi-thousand-GPU deployment. lightning-ai/pytorch-lightning is an 8-component ML training system written in Python. Data flows through 6 distinct pipeline stages. The codebase contains 648 files.

How is pytorch-lightning architected?

pytorch-lightning is organized into 4 architecture layers: User APIs, Training Orchestration, Distributed Strategies, Hardware Abstraction. Data flows through 6 distinct pipeline stages. This layered structure keeps concerns separated and modules independent.

How does data flow through pytorch-lightning?

Data moves through 6 stages: Initialize training setup → Setup model and optimizers → Execute training step → Compute gradients and optimize → Aggregate and log metrics → .... Training starts when Trainer.fit() loads the LightningModule and DataLoaders and sets up the selected distributed strategy and hardware configuration. Each epoch then routes batches to training_step(), runs backpropagation and optimizer steps, aggregates and logs metrics across devices, runs periodic validation via validation_step(), and saves checkpoints when the configured criteria are met, until the epoch limit or an early-stopping condition is reached.

What technologies does pytorch-lightning use?

The core stack includes PyTorch (Core tensor operations, neural network primitives, autograd, and CUDA integration for GPU acceleration), PyTorch Distributed (Multi-process training coordination, gradient synchronization, and process group management for DDP/FSDP), TorchMetrics (Standardized metric computation and aggregation across distributed training processes), TensorBoard (Training visualization, metric logging, and hyperparameter tracking integration), CUDA (GPU computation backend for accelerated training on NVIDIA hardware), Hydra (Configuration management for complex training experiments with multiple parameter sweeps), and 1 more. A focused set of dependencies that keeps the build manageable.

What system dynamics does pytorch-lightning have?

pytorch-lightning exhibits 3 data pools (CheckpointStore, MetricsBuffer, GradientState), 3 feedback loops, 5 control points, and 3 delays. The feedback loops handle training-loop and convergence behavior. These runtime behaviors shape how the system responds to load, failures, and configuration changes.

What design patterns does pytorch-lightning use?

4 design patterns detected: Strategy Pattern, Hook-based Training Loop, Plugin Architecture, Connector Pattern.

How does pytorch-lightning compare to alternatives?

CodeSea has side-by-side architecture comparisons of pytorch-lightning with transformers, deepspeed, composer. These comparisons show tech stack differences, pipeline design, system behavior, and code patterns. See the comparison pages above for detailed analysis.

Analyzed on April 20, 2026 by CodeSea.