lightning-ai/pytorch-lightning
Pretrain, finetune ANY AI model of ANY size on 1 or 10,000+ GPUs with zero code changes.
Abstracts PyTorch training loops into reusable components for single-GPU to multi-thousand-GPU deployment
Training starts when Trainer.fit() loads the LightningModule and DataLoaders, sets up the selected distributed strategy and hardware configuration. Each training epoch iterates through batches from the DataLoader, routes each batch to LightningModule.training_step() which computes loss, then executes backpropagation and optimizer steps. Metrics are collected, aggregated across devices, and sent to loggers. Validation runs periodically using validation_step(), and checkpoints are saved based on configured criteria. The process continues for the specified number of epochs or until early stopping conditions are met.
Under the hood, the system uses 3 feedback loops, 3 data pools, and 5 control points to manage its runtime behavior.
An 8-component ML training framework. 648 files analyzed. Data flows through 6 distinct pipeline stages.
How Data Flows Through the System
- Initialize training setup — Trainer.__init__() configures distributed strategy, precision plugin, accelerator, and callbacks. Strategy.setup() initializes process groups and device placement. DataConnector prepares DataLoaders with appropriate samplers for distributed training
- Setup model and optimizers — Trainer calls LightningModule.configure_optimizers() to get optimizer and scheduler configuration. Strategy wraps the model for distributed training (e.g., DDP, FSDP). PrecisionPlugin configures mixed precision if enabled [LightningModule → OptimizerConfig]
- Execute training step — TrainingLoop iterates through DataLoader batches. Each batch is passed to LightningModule.training_step(batch, batch_idx) which computes the loss. PrecisionPlugin handles forward pass precision and loss scaling [TrainingBatch → LoggerMetrics]
- Compute gradients and optimize — loss.backward() computes gradients, with PrecisionPlugin handling gradient scaling. The strategy synchronizes gradients across devices (under DDP, the all-reduce happens during the backward pass). optimizer.step() then updates the model parameters, and optimizer.zero_grad() clears gradients before the next backward pass [LoggerMetrics]
- Aggregate and log metrics — ResultCollection gathers metrics from training_step across all devices and processes. Metrics are reduced (averaged, summed) and synchronized. Logger instances (TensorBoard, WandB) receive formatted metrics for visualization [LoggerMetrics → LoggerMetrics]
- Run validation and checkpointing — ValidationLoop executes periodically, running LightningModule.validation_step() on validation data. ModelCheckpoint monitors validation metrics and saves CheckpointDict to disk when criteria are met (best loss, every N epochs) [TrainingBatch → CheckpointDict]
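The six stages can be compressed into a toy, dependency-free sketch of the hook-based calling order. This is not Lightning's implementation: `ToyModule`, `fit`, and the plain-SGD update are hypothetical stand-ins, gradients are computed analytically instead of by autograd, and there is no strategy, precision plugin, or logger. It only illustrates the sequence: `training_step` per batch, an optimizer update, periodic `validation_step`, and a ModelCheckpoint-style "keep the best" rule.

```python
class ToyModule:
    """Hypothetical stand-in for a LightningModule that fits y = w * x."""

    def __init__(self, lr: float = 0.1) -> None:
        self.w = 0.0
        self.grad = 0.0
        self.lr = lr

    def training_step(self, batch, batch_idx):
        xs, ys = batch
        n = len(xs)
        # Mean squared error plus its analytic gradient (replaces autograd here).
        loss = sum((self.w * x - y) ** 2 for x, y in zip(xs, ys)) / n
        self.grad = sum(2 * (self.w * x - y) * x for x, y in zip(xs, ys)) / n
        return loss

    def validation_step(self, batch, batch_idx):
        xs, ys = batch
        val_loss = sum((self.w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)
        return {"val_loss": val_loss}

    def optimizer_step(self):
        # Plain SGD update, then clear the gradient for the next batch.
        self.w -= self.lr * self.grad
        self.grad = 0.0


def fit(module, train_batches, val_batches, max_epochs=20, check_val_every_n_epoch=1):
    best_val = float("inf")
    checkpoint = None
    history = []
    for epoch in range(max_epochs):
        losses = []
        for batch_idx, batch in enumerate(train_batches):
            losses.append(module.training_step(batch, batch_idx))
            module.optimizer_step()
        record = {"epoch": epoch, "train_loss": sum(losses) / len(losses)}
        if (epoch + 1) % check_val_every_n_epoch == 0:
            outputs = [module.validation_step(b, i) for i, b in enumerate(val_batches)]
            record["val_loss"] = sum(o["val_loss"] for o in outputs) / len(outputs)
            # ModelCheckpoint-style rule: keep the state with the best val_loss.
            if record["val_loss"] < best_val:
                best_val = record["val_loss"]
                checkpoint = {"state_dict": {"w": module.w}, "epoch": epoch}
        history.append(record)
    return history, checkpoint


train = [([1.0, 2.0], [2.0, 4.0]), ([3.0], [6.0])]  # samples of y = 2x
val = [([4.0], [8.0])]
history, ckpt = fit(ToyModule(), train, val)
print(round(ckpt["state_dict"]["w"], 3))  # 2.0
```

The point of the hook design is visible even at this scale: the module only defines what happens per batch, while the loop owns ordering, validation cadence, and checkpoint selection.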
Data Models
The data structures that flow between stages — the contracts that hold the system together.
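LightningModule (src/lightning/pytorch/core/module.py): PyTorch nn.Module subclass with training_step(batch, batch_idx) -> loss: Tensor (a dict containing a "loss" key is also accepted), validation_step(batch, batch_idx) -> metrics: Dict (optional; metrics are usually recorded with self.log), configure_optimizers() -> Optimizer | List[Optimizer] (or a dict that also carries an "lr_scheduler" config)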
User defines model with required hooks, Trainer calls methods during training loop execution, outputs logged metrics and checkpoints
TrainingBatch (src/lightning/pytorch/trainer/connectors/data_connector.py): Generic container for batch data, typically a Tensor, a tuple such as (inputs, targets), or a Dict[str, Tensor], with the batch dimension first; the exact shape depends on the dataset
DataLoader produces batches, Trainer routes to LightningModule.training_step(), consumed for loss computation and backpropagation
LoggerMetrics (src/lightning/pytorch/trainer/connectors/logger_connector/result.py): Dict[str, Union[Tensor, float]] with metric names as keys and scalar values; includes metadata such as step, epoch, and dataloader_idx
Generated in training/validation steps, aggregated by ResultCollection, formatted and sent to configured loggers (TensorBoard, Weights & Biases, etc.)
CheckpointDict (src/lightning/pytorch/trainer/connectors/checkpoint_connector.py): Dict with keys state_dict (model weights), optimizer_states, lr_schedulers, epoch, global_step, pytorch-lightning_version, and hyper_parameters
Created during training at specified intervals, saved to disk with ModelCheckpoint callback, loaded to restore training state or for inference
OptimizerConfig (src/lightning/pytorch/utilities/types.py): Dict with scheduler: LRScheduler, name: Optional[str], interval: 'epoch' | 'step', frequency: int, reduce_on_plateau: bool, monitor: Optional[str]
Returned from LightningModule.configure_optimizers(), used by Trainer to manage learning rate scheduling throughout training
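The scheduler-config contract can be made explicit with a small validating dataclass. This is a hypothetical sketch, not Lightning's own type: the class name and checks are illustrative, and the defaults (`interval="epoch"`, `frequency=1`) are assumed to match Lightning's commonly documented defaults. In a real LightningModule the equivalent dict is returned from `configure_optimizers()` under the `"lr_scheduler"` key.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class SchedulerConfig:
    """Hypothetical typed mirror of the scheduler-config dict."""

    scheduler: Any
    name: Optional[str] = None
    interval: str = "epoch"  # step the scheduler per "epoch" or per "step"
    frequency: int = 1  # step once every `frequency` intervals
    reduce_on_plateau: bool = False
    monitor: Optional[str] = None  # metric name a plateau scheduler watches

    def __post_init__(self) -> None:
        if self.interval not in ("epoch", "step"):
            raise ValueError(f"interval must be 'epoch' or 'step', got {self.interval!r}")
        if self.frequency < 1:
            raise ValueError(f"frequency must be >= 1, got {self.frequency}")
        # A plateau scheduler cannot work without a metric to monitor.
        if self.reduce_on_plateau and self.monitor is None:
            raise ValueError("reduce_on_plateau=True requires a 'monitor' metric name")


cfg = SchedulerConfig(scheduler=object(), reduce_on_plateau=True, monitor="val_loss")
print(cfg.interval, cfg.frequency)  # epoch 1
```

Validating at construction time turns a misconfigured dict into an immediate, readable error rather than a failure partway through training.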
Hidden Assumptions
Things this code relies on but never validates. These are the things that cause silent failures when the system changes.
LightningModule passed to fit() implements training_step(batch, batch_idx) returning a loss Tensor and optionally validation_step(batch, batch_idx) returning metrics dict, but never validates these method signatures or return types
If this fails: If training_step returns the wrong type (e.g., a dict instead of a Tensor) or validation_step returns a non-dict, the trainer fails silently or produces cryptic errors during the backward pass
examples/fabric/build_your_own_trainer/trainer.py:MyCustomTrainer.__init__
Input tensor has shape (batch_size, 1, 28, 28) for MNIST data, with fc1 layer expecting exactly 9216 features (64 * 12 * 12 after conv/pool operations), but never validates input dimensions
If this fails: If the input has different spatial dimensions, fc1 receives a tensor of the wrong size and raises a RuntimeError about mismatched shapes during the forward pass; a different channel count fails even earlier, at the first convolution
examples/fabric/image_classifier/train_fabric.py:Net.forward
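The 9216 figure follows from standard output-size arithmetic for convolution and pooling layers: out = floor((in + 2 * padding - kernel) / stride) + 1. The sketch below assumes the classic MNIST layer stack that yields 64 * 12 * 12: two 3x3 convolutions with stride 1 and no padding, then a 2x2 max pool, with 64 output channels. `conv2d_out` and `fc1_in_features` are illustrative helpers, not functions from the example file.

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a Conv2d or MaxPool2d layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def fc1_in_features(height: int = 28, width: int = 28, channels_out: int = 64) -> int:
    # Assumed stack: conv 3x3 -> conv 3x3 -> max_pool 2x2 (stride 2).
    h = conv2d_out(conv2d_out(height, 3), 3)
    w = conv2d_out(conv2d_out(width, 3), 3)
    h, w = conv2d_out(h, 2, stride=2), conv2d_out(w, 2, stride=2)
    return channels_out * h * w


print(fc1_in_features())        # 9216 = 64 * 12 * 12
print(fc1_in_features(32, 32))  # 12544: a 32x32 input would not fit fc1
```

Computing the expected flattened size from the input shape makes the hidden coupling between the data and fc1 explicit, and the same helper could back an assertion at the top of the forward pass.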
os.cpu_count() is assumed to return a valid integer for DataLoader workers, but it returns None on systems where the CPU count cannot be determined
If this fails: DataLoader fails with TypeError when trying to use None as num_workers, causing training to crash at data loading stage
examples/fabric/dcgan/train_fabric.py:workers assignment
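A defensive version of the worker assignment handles `None` explicitly. This is a sketch: the helper name `workers_from_cpu_count` and the cap of 8 workers are assumptions, not values from the DCGAN example. Falling back to 0 is safe because `num_workers=0` makes PyTorch's DataLoader load batches in the main process.

```python
import os
from typing import Optional


def workers_from_cpu_count(cpu_count: Optional[int], cap: int = 8) -> int:
    """Turn an os.cpu_count() result into a valid DataLoader num_workers value."""
    # os.cpu_count() returns None when the CPU count cannot be determined;
    # 0 workers means "load data in the main process", which always works.
    if cpu_count is None:
        return 0
    # Illustrative upper bound to avoid oversubscribing large machines.
    return max(0, min(cpu_count, cap))


workers = workers_from_cpu_count(os.cpu_count())
```

Taking the CPU count as a parameter, rather than calling `os.cpu_count()` inside the helper, keeps the `None` branch easy to exercise in a test.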
All linear layer inner dimensions are divisible by 16 for Float8 conversion (except the decoder, which is filtered out), but this mathematical constraint is never validated
If this fails: Float8 conversion fails silently or produces incorrect results when linear layers have dimensions not divisible by 16, leading to subtle numerical errors
examples/fabric/fp8_distributed_transformer/train.py:convert_to_float8_training
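The divisibility constraint can be checked explicitly before conversion instead of being assumed. In this hypothetical sketch, layer shapes are passed in as a plain `{name: (in_features, out_features)}` dict (in a real model they could be collected from the `nn.Linear` modules yielded by `named_modules()`); the helper names and example shapes are illustrative. Raising on an incompatible layer turns a silent numerical problem into an explicit error. If the installed torchao version accepts a `module_filter_fn` in `convert_to_float8_training`, a predicate like `float8_compatible` could instead be used there to skip such layers.

```python
def float8_compatible(in_features: int, out_features: int, multiple: int = 16) -> bool:
    """True if both matrix dimensions of a linear layer are multiples of 16."""
    return in_features % multiple == 0 and out_features % multiple == 0


def select_float8_layers(linear_shapes, skip=("decoder",)):
    """Return names of linear layers to convert; raise on incompatible ones.

    linear_shapes maps a fully qualified layer name to (in_features, out_features).
    """
    selected = []
    for fqn, (in_features, out_features) in linear_shapes.items():
        # Skip excluded modules and their children, mirroring the decoder filter.
        if any(fqn == name or fqn.startswith(name + ".") for name in skip):
            continue
        if not float8_compatible(in_features, out_features):
            raise ValueError(f"{fqn}: shape ({in_features}, {out_features}) is not divisible by 16")
        selected.append(fqn)
    return selected


shapes = {"layers.0.attn": (512, 512), "layers.0.ff": (512, 2048), "decoder": (512, 50257)}
print(select_float8_layers(shapes))  # ['layers.0.attn', 'layers.0.ff']
```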
Checkpoint directory is writable and has sufficient disk space for model state_dict serialization, but never checks filesystem permissions or available space
If this fails: Checkpoint saving fails mid-training with disk full or permission errors, losing training progress without graceful recovery
examples/fabric/build_your_own_trainer/trainer.py:_save_checkpoint
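A pre-flight check at startup can surface permission and disk-space problems before any training time is spent. This sketch uses a hypothetical helper name; `required_bytes` is a caller-supplied estimate, for example the fp32 weight size times roughly three when Adam's two per-parameter state tensors are saved alongside the weights. Passing the check does not guarantee a later save succeeds, since other processes can fill the disk in the meantime.

```python
import os
import shutil
from pathlib import Path


def check_checkpoint_dir(path: str, required_bytes: int) -> None:
    """Fail fast before training if checkpoints cannot be written to `path`."""
    directory = Path(path)
    # Create the directory up front so permission errors show up now, not mid-run.
    directory.mkdir(parents=True, exist_ok=True)
    if not os.access(directory, os.W_OK):
        raise PermissionError(f"checkpoint directory {directory} is not writable")
    free = shutil.disk_usage(directory).free
    if free < required_bytes:
        raise OSError(f"only {free} bytes free in {directory}, need about {required_bytes}")
```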
Validation step returns a dict with string keys for metric names, but doesn't validate dict structure or key types before logging
If this fails: Non-string keys or nested dicts cause logger failures or metric aggregation errors, breaking validation monitoring
examples/fabric/build_your_own_trainer/trainer.py:_run_validation
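The failure modes listed above (non-string keys, nested dicts, non-scalar values) can be rejected before anything reaches a logger. This is a hypothetical validator, not part of the example trainer; it converts one-element tensors and NumPy scalars via their `.item()` method and returns plain floats.

```python
from numbers import Real


def validate_metrics(metrics) -> dict:
    """Check a validation-step result before it reaches a logger; return plain floats."""
    if not isinstance(metrics, dict):
        raise TypeError(f"expected a dict of metrics, got {type(metrics).__name__}")
    clean = {}
    for key, value in metrics.items():
        if not isinstance(key, str):
            raise TypeError(f"metric name {key!r} must be a str")
        if isinstance(value, dict):
            raise TypeError(f"metric {key!r} is a nested dict; flatten it first")
        if hasattr(value, "item"):  # e.g. a one-element torch.Tensor or NumPy scalar
            value = value.item()
        if not isinstance(value, Real):
            raise TypeError(f"metric {key!r} must be a scalar, got {type(value).__name__}")
        clean[key] = float(value)
    return clean
```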
GPU memory can handle batch_size=128 with 64x64x3 images plus generator/discriminator models, roughly 200MB+ per batch, but never checks available VRAM
If this fails: Training crashes with CUDA out of memory errors when GPU has insufficient memory, requiring manual batch size tuning
examples/fabric/dcgan/train_fabric.py:batch_size=128
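Part of that footprint can be estimated with simple arithmetic. The sketch below computes only the float32 input batch; weights, gradients, optimizer state, and activations of both networks are not modeled, and the helper name is illustrative. It still shows that this term scales linearly with batch_size. On a CUDA machine, `torch.cuda.mem_get_info()` reports free and total device memory, which could be compared against a fuller estimate before choosing a batch size.

```python
def tensor_bytes(*shape: int, bytes_per_element: int = 4) -> int:
    """Bytes occupied by a dense tensor of the given shape (float32 by default)."""
    total = bytes_per_element
    for dim in shape:
        total *= dim
    return total


batch_bytes = tensor_bytes(128, 3, 64, 64)  # (batch, channels, height, width)
print(batch_bytes, f"{batch_bytes / 2**20:.1f} MiB")  # 6291456 6.0 MiB
```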
Validation frequency counting is based on completed epochs, but doesn't account for early stopping or interrupted training affecting validation timing
If this fails: Validation may not run at expected intervals when training is interrupted and resumed, potentially missing important metric checkpoints
examples/fabric/build_your_own_trainer/trainer.py:validation_frequency
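One way to keep the schedule stable across interruptions is to derive it from the absolute epoch index restored from the checkpoint rather than from a counter local to the current run. Both functions below are hypothetical illustrations, not code from the example trainer; they assume validation runs after epoch e when (e + 1) is a multiple of the frequency.

```python
def validation_epochs(start_epoch: int, max_epochs: int, every_n: int) -> list:
    """0-based epochs that end with validation, keyed on the absolute epoch index."""
    return [e for e in range(start_epoch, max_epochs) if (e + 1) % every_n == 0]


def naive_validation_epochs(start_epoch: int, max_epochs: int, every_n: int) -> list:
    """Same rule, but counted from a local counter that restarts at 0 on resume."""
    return [start_epoch + k for k in range(max_epochs - start_epoch) if (k + 1) % every_n == 0]


print(validation_epochs(0, 10, 3))        # [2, 5, 8]  uninterrupted run
print(validation_epochs(4, 10, 3))        # [5, 8]     resumed at epoch 4, still aligned
print(naive_validation_epochs(4, 10, 3))  # [6, 9]     schedule drifts after resume
```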
CelebA dataset exists in 'data/' directory and is properly formatted, but never validates dataset integrity or file permissions
If this fails: Training fails with cryptic errors during data loading if dataset is corrupted, missing, or has wrong file structure
examples/fabric/dcgan/train_fabric.py:dataroot='data/'
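A short pre-flight check turns a missing or empty dataset directory into a readable error. The sketch assumes the data is read with torchvision's `ImageFolder`, which expects images inside at least one subdirectory of the root (for example `data/celeba/*.jpg`); the helper name and accepted file suffixes are assumptions. It does not detect corrupted image files, which would require opening each one.

```python
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def check_image_folder(root: str) -> int:
    """Return the number of images under root, or raise a descriptive error."""
    root_path = Path(root)
    if not root_path.is_dir():
        raise FileNotFoundError(f"dataset root {root_path} does not exist")
    # ImageFolder treats each subdirectory of root as one class.
    class_dirs = [p for p in root_path.iterdir() if p.is_dir()]
    if not class_dirs:
        raise FileNotFoundError(f"{root_path} has no subdirectories; expected {root_path}/<folder>/<image>")
    n_images = sum(
        1
        for d in class_dirs
        for f in d.rglob("*")
        if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
    )
    if n_images == 0:
        raise FileNotFoundError(f"no .jpg/.jpeg/.png files found under {root_path}")
    return n_images
```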
Model setup, optimizer configuration, and data preparation happen in specific order before training loop starts, but doesn't enforce or validate this sequencing
If this fails: If components are accessed before proper initialization (e.g., calling backward before fabric.setup), training produces confusing errors about uninitialized state
examples/fabric/build_your_own_trainer/trainer.py:fit method
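The expected ordering can be made explicit with a small guard that a custom trainer calls at each stage. The stage names mirror the order Fabric's documentation uses (`fabric.launch()`, `fabric.setup(model, optimizer)`, `fabric.setup_dataloaders(...)`, then `fabric.backward(loss)` inside the loop), but `SetupOrder` is hypothetical, and its strict sequence is this sketch's own policy rather than something Fabric enforces.

```python
class SetupOrder:
    """Hypothetical guard that reports out-of-order setup calls explicitly."""

    STAGES = ("launch", "setup", "setup_dataloaders", "backward")

    def __init__(self) -> None:
        self.completed = set()

    def enter(self, stage: str) -> None:
        # Every stage listed before `stage` must already have run.
        earlier = self.STAGES[: self.STAGES.index(stage)]
        missing = [s for s in earlier if s not in self.completed]
        if missing:
            raise RuntimeError(f"{stage!r} called before {', '.join(missing)}")
        self.completed.add(stage)


order = SetupOrder()
for stage in ("launch", "setup", "setup_dataloaders", "backward", "backward"):
    order.enter(stage)  # correct order; backward may repeat on every step
```

An error such as `'backward' called before launch, setup` names the missing step directly, which is the kind of message the "confusing errors about uninitialized state" lack.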
System Behavior
How the system operates at runtime — where data accumulates, what loops, what waits, and what controls what.
Data Pools
Persists model weights, optimizer states, hyperparameters, and training metadata to disk at configured intervals or when validation metrics improve
Accumulates metrics from training/validation steps before aggregation and logging, handles synchronization across distributed processes
Holds computed gradients in model parameters, synchronized across devices in distributed training before optimizer step
Feedback Loops
- Training Loop (training-loop, reinforcing) — Trigger: Trainer.fit() called with max_epochs > 0. Action: Execute training_step on each batch, compute gradients, update parameters, log metrics, run validation periodically. Exit: Reaches max_epochs, max_steps, or early stopping callback triggers.
- Learning Rate Scheduling (convergence, balancing) — Trigger: LR scheduler configured with ReduceLROnPlateau and validation metric monitored. Action: Reduce learning rate when validation metric plateaus, monitor for patience epochs. Exit: Reaches minimum learning rate or training completes.
- Gradient Accumulation (gradient-accumulation, reinforcing) — Trigger: accumulate_grad_batches > 1 in Trainer configuration. Action: Accumulate gradients over multiple batches before calling optimizer.step(), scales effective batch size. Exit: Accumulated specified number of batches or end of epoch.
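The reduce-on-plateau rule in the convergence loop is easy to trace by hand. Below is a simplified, dependency-free re-implementation of the core logic of PyTorch's `ReduceLROnPlateau` in `"min"` mode; it is not PyTorch's code and omits the `threshold` and `cooldown` options. In Lightning, the real scheduler is wired up by returning it from `configure_optimizers()` in an `"lr_scheduler"` config whose `"monitor"` names the validation metric.

```python
def plateau_schedule(val_losses, lr=1.0, factor=0.5, patience=2, min_lr=0.0):
    """Learning rate after each epoch under a simplified reduce-on-plateau rule."""
    best = float("inf")
    bad_epochs = 0
    lrs = []
    for loss in val_losses:
        if loss < best:  # strict improvement resets the counter
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        # Reduce only once patience is exceeded, never going below min_lr.
        if bad_epochs > patience:
            lr = max(lr * factor, min_lr)
            bad_epochs = 0
        lrs.append(lr)
    return lrs


print(plateau_schedule([1.0, 0.8, 0.8, 0.8, 0.8, 0.7]))  # [1.0, 1.0, 1.0, 1.0, 0.5, 0.5]
```

Once the rate reaches `min_lr`, further plateaus leave it unchanged, which corresponds to the loop's "reaches minimum learning rate" exit condition.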
Delays
- Distributed Synchronization (async-processing, ~varies by network and model size) — All processes wait for gradient synchronization before optimizer step, ensuring consistent parameter updates
- Checkpoint Saving (checkpoint-save, ~varies by model size and storage speed) — Training pauses while model state is serialized and written to disk, frequency controlled by save configuration
- Validation Interval (batch-window, ~depends on validation dataset size) — Training loop pauses to run a full validation pass; frequency is controlled by check_val_every_n_epoch, and by val_check_interval for checks within an epoch
Control Points
- Distributed Strategy (architecture-switch) — Controls: How model and gradients are distributed across devices: DDP, FSDP, DeepSpeed, model parallel. Default: auto (resolved from the accelerator and number of devices)
- Precision Mode (precision-mode) — Controls: Numerical precision for forward/backward passes (fp32, fp16, bf16, int8 quantization), selected with flags such as "32-true", "16-mixed", and "bf16-mixed". Default: 32-true
- Accelerator Selection (device-selection) — Controls: Hardware backend for computation - CPU, CUDA GPU, MPS, TPU. Default: auto
- Gradient Clipping (threshold) — Controls: Maximum gradient norm (or maximum value, with gradient_clip_algorithm="value") to prevent exploding gradients, applied before the optimizer step. Default: None (no clipping)
- Accumulate Grad Batches (hyperparameter) — Controls: Number of batches to accumulate gradients over before optimizer step, effectively increases batch size. Default: 1
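The interaction between accumulation and device count reduces to simple arithmetic. The sketch below assumes data-parallel training where each device draws its own per-device batch (as with DDP and a distributed sampler), and it mirrors the gradient-accumulation exit condition in the feedback loops: step after every `accumulate_grad_batches` batches, and flush any partial accumulation at the end of the epoch. Both helper names are hypothetical.

```python
def effective_batch_size(per_device_batch: int, num_devices: int, accumulate_grad_batches: int) -> int:
    """Samples contributing to each optimizer step in data-parallel training."""
    return per_device_batch * num_devices * accumulate_grad_batches


def optimizer_step_batches(num_batches: int, accumulate_grad_batches: int) -> list:
    """0-based batch indices within one epoch after which optimizer.step() runs."""
    steps = []
    for batch_idx in range(num_batches):
        accumulated = (batch_idx + 1) % accumulate_grad_batches == 0
        last_batch = batch_idx == num_batches - 1
        # Step once enough batches are accumulated, or flush at the end of the epoch.
        if accumulated or last_batch:
            steps.append(batch_idx)
    return steps


print(effective_batch_size(32, 8, 4))  # 1024
print(optimizer_step_batches(10, 4))   # [3, 7, 9]
```

In the 10-batch example the final step covers only batches 8 and 9, so its effective batch is half the size of the others.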
Technology Stack
PyTorch: Core tensor operations, neural network primitives, autograd, and CUDA integration for GPU acceleration
PyTorch Distributed: Multi-process training coordination, gradient synchronization, and process group management for DDP/FSDP
TorchMetrics: Standardized metric computation and aggregation across distributed training processes
TensorBoard: Training visualization, metric logging, and hyperparameter tracking integration
CUDA: GPU computation backend for accelerated training on NVIDIA hardware
Hydra: Configuration management for complex training experiments with multiple parameter sweeps
DeepSpeed: Microsoft's optimization library for large model training with ZeRO memory optimization
Key Components
- Trainer (orchestrator) — Orchestrates the entire training process: manages training/validation loops, callbacks, checkpointing, logging, and distributed coordination across devices and strategies. Source: src/lightning/pytorch/trainer/trainer.py
- Fabric (adapter) — Provides minimal abstractions for custom training loops: handles device placement, distributed setup, mixed precision, and gradient synchronization without enforcing a training structure. Source: src/lightning/fabric/fabric.py
- DDPStrategy (processor) — Implements PyTorch's DistributedDataParallel for multi-GPU training: handles process group setup, gradient synchronization, and model replication across devices. Source: src/lightning/pytorch/strategies/ddp.py
- ModelCheckpoint (monitor) — Monitors training metrics and automatically saves model checkpoints based on configured criteria (best validation loss, every N epochs, etc.). Source: src/lightning/pytorch/callbacks/model_checkpoint.py
- ResultCollection (processor) — Aggregates metrics from training/validation steps across devices and dataloaders; handles synchronization in distributed training and metric reduction. Source: src/lightning/pytorch/trainer/connectors/logger_connector/result.py
- TrainingLoop (executor) — Executes the training epoch loop: iterates through batches, calls training_step, and handles gradient accumulation, optimizer stepping, and learning rate scheduling. Source: src/lightning/pytorch/loops/fit_loop.py
- PrecisionPlugin (adapter) — Handles mixed precision training (fp16, bf16, int8) by wrapping forward passes, loss scaling, and gradient unscaling for different hardware accelerators. Source: src/lightning/pytorch/plugins/precision/
- DataConnector (adapter) — Manages DataLoader setup, distributed sampling, and batch iteration; handles train/val/test split management and ensures proper data distribution across processes. Source: src/lightning/pytorch/trainer/connectors/data_connector.py
Frequently Asked Questions
What is pytorch-lightning used for?
Abstracts PyTorch training loops into reusable components for single-GPU to multi-thousand-GPU deployment. lightning-ai/pytorch-lightning is an 8-component ML training framework written in Python. Data flows through 6 distinct pipeline stages. The codebase contains 648 files.
How is pytorch-lightning architected?
pytorch-lightning is organized into 4 architecture layers: User APIs, Training Orchestration, Distributed Strategies, Hardware Abstraction. Data flows through 6 distinct pipeline stages. This layered structure keeps concerns separated and modules independent.
How does data flow through pytorch-lightning?
Data moves through 6 stages: Initialize training setup → Setup model and optimizers → Execute training step → Compute gradients and optimize → Aggregate and log metrics → .... Training starts when Trainer.fit() loads the LightningModule and DataLoaders, sets up the selected distributed strategy and hardware configuration. Each training epoch iterates through batches from the DataLoader, routes each batch to LightningModule.training_step() which computes loss, then executes backpropagation and optimizer steps. Metrics are collected, aggregated across devices, and sent to loggers. Validation runs periodically using validation_step(), and checkpoints are saved based on configured criteria. The process continues for the specified number of epochs or until early stopping conditions are met. This pipeline design reflects a complex multi-stage processing system.
What technologies does pytorch-lightning use?
The core stack includes PyTorch (Core tensor operations, neural network primitives, autograd, and CUDA integration for GPU acceleration), PyTorch Distributed (Multi-process training coordination, gradient synchronization, and process group management for DDP/FSDP), TorchMetrics (Standardized metric computation and aggregation across distributed training processes), TensorBoard (Training visualization, metric logging, and hyperparameter tracking integration), CUDA (GPU computation backend for accelerated training on NVIDIA hardware), Hydra (Configuration management for complex training experiments with multiple parameter sweeps), and 1 more. A focused set of dependencies that keeps the build manageable.
What system dynamics does pytorch-lightning have?
pytorch-lightning exhibits 3 data pools (CheckpointStore, MetricsBuffer), 3 feedback loops, 5 control points, and 3 delays. The feedback loops cover the training loop, learning-rate convergence, and gradient accumulation. These runtime behaviors shape how the system responds to load, failures, and configuration changes.
What design patterns does pytorch-lightning use?
4 design patterns detected: Strategy Pattern, Hook-based Training Loop, Plugin Architecture, Connector Pattern.
How does pytorch-lightning compare to alternatives?
CodeSea has side-by-side architecture comparisons of pytorch-lightning with transformers, deepspeed, and composer. These comparisons show tech stack differences, pipeline design, system behavior, and code patterns.
Analyzed on April 20, 2026 by CodeSea. Written by Karolina Sarna.