A flexible PyTorch template for ML experiments with configuration management, logging, and hyperparameter optimization.
This project provides a flexible template for PyTorch-based machine learning experiments. It includes configuration management, logging with Weights & Biases (wandb), hyperparameter optimization with Optuna, and a modular structure for easy customization and experimentation.
- `config.py`: Defines the `RunConfig` and `OptimizeConfig` classes for managing experiment configurations and optimization settings.
- `main.py`: The entry point of the project, handling command-line arguments and experiment execution.
- `model.py`: Contains the model architecture (currently an MLP).
- `util.py`: Utility functions for data loading, device selection, training, and analysis.
- `run_template.yaml`: Template for run configuration.
- `optimize_template.yaml`: Template for optimization configuration.
- `analyze.py`: Script for analyzing completed runs and optimizations, utilizing functions from `util.py`.

Clone the repository:
```sh
git clone https://github.com/yourusername/pytorch_template.git
cd pytorch_template
```
Install the required packages:
```sh
# Use pip
pip install torch wandb survey polars numpy optuna matplotlib scienceplots

# Or use uv with requirements.txt (recommended)
uv pip sync requirements.txt

# Or use uv (fresh install)
uv pip install -U torch wandb survey polars numpy optuna matplotlib scienceplots
```
(Optional) Set up a Weights & Biases account for experiment tracking.
Configure your experiment by modifying `run_template.yaml` or creating a new YAML file based on it.
(Optional) Configure hyperparameter optimization by modifying `optimize_template.yaml` or creating a new YAML file based on it.
Run the experiment:
```sh
python main.py --run_config path/to/run_config.yaml [--optimize_config path/to/optimize_config.yaml]
```
If `--optimize_config` is provided, the script performs hyperparameter optimization using Optuna.
Analyze the results:
```sh
python analyze.py
```
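As a concrete illustration, a run configuration based on `run_template.yaml` might look like the sketch below. Every value, class name, and `net_config` key here is hypothetical; consult the template files for the actual fields and defaults.

```yaml
# Hypothetical run configuration; values are illustrative only
project: cosine_regression
device: cuda:0
net: MLP
optimizer: Adam
scheduler: CosineAnnealingLR
epochs: 100
batch_size: 64
seeds: [42, 43, 44]
net_config:
  hidden_size: 128
  num_layers: 3
optimizer_config:
  lr: 0.001
scheduler_config:
  T_max: 100
```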
Run configuration (`run_template.yaml`):

- `project`: Project name for wandb logging
- `device`: Device to run on (e.g., 'cpu', 'cuda:0')
- `net`: Model class to use
- `optimizer`: Optimizer class
- `scheduler`: Learning rate scheduler class
- `epochs`: Number of training epochs
- `batch_size`: Batch size for training
- `seeds`: List of random seeds for multiple runs
- `net_config`: Model-specific configuration
- `optimizer_config`: Optimizer-specific configuration
- `scheduler_config`: Scheduler-specific configuration

Optimization configuration (`optimize_template.yaml`):

- `study_name`: Name of the optimization study
- `trials`: Number of optimization trials
- `seed`: Random seed for optimization
- `metric`: Metric to optimize
- `direction`: Direction of optimization ('minimize' or 'maximize')
- `sampler`: Optuna sampler configuration
- `pruner`: (Optional) Optuna pruner configuration
- `search_space`: Definition of the hyperparameter search space

Custom model: Modify or add models in `model.py`. Models should accept an `hparams` argument as a dictionary, with keys matching the `net_config` parameters in the run configuration YAML file.
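A model following this convention could be sketched as below; the class name, layer sizes, and `hparams` keys are assumptions for illustration, not part of the template's actual API:

```python
import torch
import torch.nn as nn

class MLP(nn.Module):
    """Minimal MLP sketch; the hparams keys used here are hypothetical."""
    def __init__(self, hparams: dict):
        super().__init__()
        hidden = hparams["hidden_size"]  # assumed net_config key
        layers = hparams["num_layers"]   # assumed net_config key
        blocks = [nn.Linear(1, hidden), nn.ReLU()]
        for _ in range(layers - 1):
            blocks += [nn.Linear(hidden, hidden), nn.ReLU()]
        blocks.append(nn.Linear(hidden, 1))
        self.net = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
```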
Custom data: Modify the `load_data` function in `util.py`. The current example uses cosine regression. The `load_data` function should return train and validation datasets compatible with PyTorch's `DataLoader`.
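A `load_data` replacement for your own dataset could follow this shape. The cosine-regression example here mirrors the template's default task, but the signature, sizes, and noise level are assumptions:

```python
import torch
from torch.utils.data import TensorDataset

def load_data(n_train: int = 1000, n_val: int = 200):
    """Return (train_ds, val_ds) usable with torch.utils.data.DataLoader.

    Sketch of a cosine-regression dataset; all parameters are
    illustrative, not the template's actual defaults.
    """
    def make(n):
        x = torch.rand(n, 1) * 2 * torch.pi          # inputs in [0, 2*pi)
        y = torch.cos(x) + 0.01 * torch.randn(n, 1)  # noisy targets
        return TensorDataset(x, y)
    return make(n_train), make(n_val)
```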
Custom training: Customize the `Trainer` class in `util.py` by modifying the `step`, `train_epoch`, `val_epoch`, and `train` methods to suit your task. Ensure that `train` returns `val_loss` or a custom metric for proper hyperparameter optimization.
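The expected shape of such a `Trainer` can be sketched as follows; the method signatures, loss function, and attribute names are assumptions about the template, not its actual code:

```python
import torch

class Trainer:
    """Skeleton mirroring the step/train_epoch/val_epoch/train layout."""
    def __init__(self, model, optimizer, criterion=None):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion or torch.nn.MSELoss()

    def step(self, x, y):
        # One forward pass + loss computation
        return self.criterion(self.model(x), y)

    def train_epoch(self, loader):
        self.model.train()
        total = 0.0
        for x, y in loader:
            self.optimizer.zero_grad()
            loss = self.step(x, y)
            loss.backward()
            self.optimizer.step()
            total += loss.item()
        return total / len(loader)

    @torch.no_grad()
    def val_epoch(self, loader):
        self.model.eval()
        return sum(self.step(x, y).item() for x, y in loader) / len(loader)

    def train(self, train_loader, val_loader, epochs):
        val_loss = float("inf")
        for _ in range(epochs):
            self.train_epoch(train_loader)
            val_loss = self.val_epoch(val_loader)
        return val_loss  # returned so Optuna can optimize it
```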
The `analyze.py` script uses functions from `util.py` to analyze completed runs and optimizations. Key functions include:

- `select_group`: Select a run group for analysis
- `select_seed`: Select a specific seed from a run group
- `select_device`: Choose a device for analysis
- `load_model`: Load a trained model and its configuration
- `load_study`: Load an Optuna study
- `load_best_model`: Load the best model from an optimization study
To use the analysis tools, run the `analyze.py` script:

```sh
python analyze.py
```

Follow the prompts to select the project, run group, and seed (if applicable). The script will load the selected model and perform basic analysis, such as calculating the validation loss. You can extend the `main()` function in `analyze.py` to add custom analysis as needed, using the utility functions from `util.py`.
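For instance, a custom analysis step could evaluate a loaded model on a dense grid and plot its predictions against the target function. The helper below is hypothetical; in practice the model would come from `load_model` or `load_best_model` in `util.py`:

```python
import torch
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, safe for headless runs
import matplotlib.pyplot as plt

@torch.no_grad()
def plot_predictions(model, out_path="predictions.png"):
    """Hypothetical analysis helper: compare model output to cos(x)."""
    model.eval()
    x = torch.linspace(0, 2 * torch.pi, 200).unsqueeze(1)
    pred = model(x)
    plt.plot(x.squeeze(), torch.cos(x).squeeze(), label="cos(x)")
    plt.plot(x.squeeze(), pred.squeeze(), label="model")
    plt.legend()
    plt.savefig(out_path)
    plt.close()
```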
Contributions are welcome! Please feel free to submit a Pull Request.
This project is provided as a template and is intended to be freely used, modified, and distributed. Users of this template are encouraged to choose a license that best suits their specific project needs.
For the template itself:
When using this template for your own project, please remember to:
For more information on choosing a license, visit choosealicense.com.