torchflow package
Submodules
torchflow.callbacks module
Callbacks for training lifecycle events.
Provides a base Callback class and a couple of useful implementations like EarlyStopping and ModelCheckpoint. These are designed to be simple and lightweight so they can be used in unit tests and examples.
- torchflow.callbacks.make_summary_writer(log_dir=None)
Safely create a SummaryWriter if tensorboard is available.
Returns None when tensorboard or its dependencies are not installed.
- Return type:
Optional[Any]
- Parameters:
log_dir (str | None)
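The guarded-import behavior described for make_summary_writer can be sketched as follows. This is an illustrative stand-in, not torchflow's source: the name `make_writer_sketch` is hypothetical, and the sketch assumes the writer class is `torch.utils.tensorboard.SummaryWriter`.

```python
from typing import Any, Optional


def make_writer_sketch(log_dir: Optional[str] = None) -> Optional[Any]:
    """Return a SummaryWriter, or None when it cannot be imported or created."""
    try:
        # Imported lazily so this code loads even without torch/tensorboard.
        from torch.utils.tensorboard import SummaryWriter
    except Exception:
        # Missing or broken optional dependency: signal "no writer" to the caller.
        return None
    try:
        return SummaryWriter(log_dir=log_dir)
    except Exception:
        return None
```

Callers can then write `if writer is not None: writer.add_scalar(...)` and run unchanged in environments without TensorBoard.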
- class torchflow.callbacks.Callback
Bases:
object
Base callback. Subclass and override the desired methods.
Methods receive the trainer instance (the trainer passes itself) or context information, so callbacks can control training (for example, by stopping early).
- on_epoch_end(trainer, epoch, logs=None)
- Return type:
None
- Parameters:
trainer (Any)
epoch (int)
logs (dict | None)
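A minimal sketch of subclassing and overriding on_epoch_end is shown below. To stay runnable without torchflow installed, it defines a stand-in `Callback` base with the documented signature. The `StopAtLossThreshold` and `FakeTrainer` classes, the `stop_training` flag, and the `'val_loss'` key in logs are all assumptions made for illustration, not confirmed parts of the library.

```python
from typing import Any, Iterable, Optional


class Callback:
    """Stand-in for torchflow.callbacks.Callback (same on_epoch_end signature)."""

    def on_epoch_end(self, trainer: Any, epoch: int, logs: Optional[dict] = None) -> None:
        pass


class StopAtLossThreshold(Callback):
    """Hypothetical callback: request a stop once val_loss drops below a threshold."""

    def __init__(self, threshold: float) -> None:
        self.threshold = threshold

    def on_epoch_end(self, trainer: Any, epoch: int, logs: Optional[dict] = None) -> None:
        loss = (logs or {}).get("val_loss")
        if loss is not None and loss < self.threshold:
            # Assumption: the trainer checks a `stop_training` flag between epochs.
            trainer.stop_training = True


class FakeTrainer:
    """Minimal stand-in trainer that replays pre-computed validation losses."""

    def __init__(self, callbacks: Iterable[Callback]) -> None:
        self.callbacks = list(callbacks)
        self.stop_training = False

    def run(self, val_losses: Iterable[float]) -> int:
        epochs_run = 0
        for epoch, loss in enumerate(val_losses):
            epochs_run += 1
            for cb in self.callbacks:
                cb.on_epoch_end(self, epoch, {"val_loss": loss})
            if self.stop_training:
                break
        return epochs_run


trainer = FakeTrainer([StopAtLossThreshold(0.5)])
epochs_run = trainer.run([0.9, 0.7, 0.4, 0.3])  # stops after the third epoch
```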
- class torchflow.callbacks.EarlyStopping(monitor='val_loss', patience=3, min_delta=0.0, mode='min')
Bases:
Callback
Stop training when a monitored metric has stopped improving.
- Parameters:
monitor (str) – metric name to monitor (e.g. ‘val_loss’).
patience (int) – number of epochs with no improvement after which training is stopped.
min_delta (float) – minimum change to qualify as an improvement.
mode (str) – ‘min’ or ‘max’, depending on whether lower or higher values are better.
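How patience, min_delta, and mode interact can be sketched with a small standalone class. This is one common reading of those parameters, not torchflow's actual implementation, and `EarlyStoppingSketch` and `should_stop` are hypothetical names.

```python
class EarlyStoppingSketch:
    """Track a metric and report when it has stalled for `patience` epochs."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0, mode: str = "min") -> None:
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.best = None  # best metric value seen so far
        self.wait = 0     # consecutive epochs without improvement

    def _improved(self, value: float) -> bool:
        if self.best is None:
            return True
        if self.mode == "min":
            return value < self.best - self.min_delta
        return value > self.best + self.min_delta

    def should_stop(self, value: float) -> bool:
        """Record one epoch's metric; return True once patience is exhausted."""
        if self._improved(value):
            self.best = value
            self.wait = 0
            return False
        self.wait += 1
        return self.wait >= self.patience


stopper = EarlyStoppingSketch(patience=2, min_delta=0.01)
# 0.799 and 0.798 beat 0.8 by less than min_delta, so neither counts as an improvement.
decisions = [stopper.should_stop(v) for v in [1.0, 0.8, 0.799, 0.798]]
```

With patience=2, the two sub-threshold epochs in a row trigger the stop on the fourth value.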
- class torchflow.callbacks.ModelCheckpoint(filepath, monitor='val_loss', save_best_only=True, mode='min')
Bases:
Callback
Save the model after every epoch, or only when the monitored metric improves.
- Parameters:
filepath (str)
monitor (str)
save_best_only (bool)
mode (str)
- class torchflow.callbacks.LambdaCallback(on_train_begin=None, on_train_end=None, on_epoch_begin=None, on_epoch_end=None, on_batch_end=None, on_validation_end=None)
Bases:
Callback
Create a callback from simple callables.
Usage: LambdaCallback(on_epoch_end=lambda trainer, epoch, logs: …)
- Parameters:
on_train_begin (Optional[Callable])
on_train_end (Optional[Callable])
on_epoch_begin (Optional[Callable])
on_epoch_end (Optional[Callable])
on_batch_end (Optional[Callable])
on_validation_end (Optional[Callable])
- on_epoch_end(trainer, epoch, logs=None)
- Return type:
None
- Parameters:
trainer (Any)
epoch (int)
logs (dict | None)
- class torchflow.callbacks.LearningRateScheduler(scheduler=None, schedule_fn=None, step_at='epoch')
Bases:
Callback
Adjust learning rate according to a schedule.
You may provide either a PyTorch lr_scheduler instance (object with a .step(…) method) or a simple schedule function that accepts the epoch number and returns either a single lr or a list/tuple of lrs (one per param_group).
By default, scheduler.step() is called at the end of each epoch.
- Parameters:
scheduler (Optional[Any])
schedule_fn (Optional[Callable])
step_at (str)
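A schedule function of the kind described above might look like the following sketch. `step_decay` is a hypothetical schedule, and `apply_lrs` only illustrates how a scalar or per-group result could be written into param_groups. The list of dicts stands in for `optimizer.param_groups`, and nothing here is taken from torchflow's source.

```python
from typing import List, Sequence, Union


def step_decay(epoch: int) -> float:
    """Hypothetical schedule: start at 0.1 and halve the rate every 10 epochs."""
    return 0.1 * (0.5 ** (epoch // 10))


def apply_lrs(param_groups: List[dict], lrs: Union[float, Sequence[float]]) -> None:
    """Write a scalar lr to every group, or one lr per group for a list/tuple."""
    if isinstance(lrs, (list, tuple)):
        if len(lrs) != len(param_groups):
            raise ValueError("expected one learning rate per param_group")
        for group, lr in zip(param_groups, lrs):
            group["lr"] = lr
    else:
        for group in param_groups:
            group["lr"] = lrs


# Plain dicts standing in for optimizer.param_groups.
groups = [{"lr": 0.1}, {"lr": 0.1}]
apply_lrs(groups, step_decay(20))  # two halvings have happened by epoch 20
```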
- class torchflow.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=2, min_lr=0.0, mode='min', min_delta=1e-08, cooldown=0)
Bases:
Callback
Reduce learning rate when a monitored metric has stopped improving.
A lightweight wrapper around the common behavior. If you already have a torch.optim.lr_scheduler.ReduceLROnPlateau instance you can use LearningRateScheduler(scheduler=your_scheduler) instead.
- Parameters:
monitor (str)
factor (float)
patience (int)
min_lr (float)
mode (str)
min_delta (float)
cooldown (int)
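The usual reduce-on-plateau behavior can be sketched in plain Python as below. It is a simplified illustration under stated assumptions: `PlateauSketch` is a hypothetical name, the reduction fires once `patience` stalled epochs accumulate, and the exact counting in torchflow or in `torch.optim.lr_scheduler.ReduceLROnPlateau` may differ.

```python
class PlateauSketch:
    """Multiply the learning rate by `factor` when the metric stops improving."""

    def __init__(self, lr, factor=0.1, patience=2, min_lr=0.0,
                 mode="min", min_delta=1e-8, cooldown=0):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.mode = mode
        self.min_delta = min_delta
        self.cooldown = cooldown
        self.best = None        # best metric value seen so far
        self.wait = 0           # stalled epochs since the last improvement or reduction
        self.cooldown_left = 0  # epochs to ignore after a reduction

    def step(self, value):
        """Update state with one epoch's metric and return the current lr."""
        improved = (
            self.best is None
            or (self.mode == "min" and value < self.best - self.min_delta)
            or (self.mode == "max" and value > self.best + self.min_delta)
        )
        if improved:
            self.best = value
            self.wait = 0
        elif self.cooldown_left > 0:
            # Still cooling down: do not count this epoch against patience.
            self.cooldown_left -= 1
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                # Never drop below the configured floor.
                self.lr = max(self.lr * self.factor, self.min_lr)
                self.cooldown_left = self.cooldown
                self.wait = 0
        return self.lr


plateau = PlateauSketch(lr=0.1, factor=0.5, patience=2, min_lr=0.03)
lrs = [plateau.step(1.0) for _ in range(5)]  # the metric never improves
```

In this run the rate is halved once and then clamped at min_lr on the second reduction.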
- class torchflow.callbacks.CSVLogger(filename, append=True, separator=',')
Bases:
Callback
Log epoch-level metrics to a CSV file.
Writes a header the first time and appends rows for each epoch. Accepts a logs dict (as passed to on_epoch_end) and writes key/value pairs.
- Parameters:
filename (str)
append (bool)
separator (str)
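The header-once, append-per-epoch behavior can be sketched with the standard csv module. `append_epoch_row` is a hypothetical helper, not the CSVLogger implementation. It assumes the logs dict has the same keys every epoch, and adding an `epoch` column is a choice made for this example.

```python
import csv
import os


def append_epoch_row(filename, epoch, logs, separator=","):
    """Append one row of metrics; write the header only if the file is new or empty."""
    row = {"epoch": epoch, **logs}
    needs_header = not os.path.exists(filename) or os.path.getsize(filename) == 0
    # newline="" lets the csv module control line endings itself.
    with open(filename, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(row), delimiter=separator)
        if needs_header:
            writer.writeheader()
        writer.writerow(row)
```

Calling it once per epoch with the dict passed to on_epoch_end yields a file that can be loaded directly into a spreadsheet or dataframe.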
- class torchflow.callbacks.TensorBoardCallback(writer=None, log_dir=None)
Bases:
Callback
Log metrics to a TensorBoard SummaryWriter.
Provide either an existing SummaryWriter instance or a log_dir to create one. The callback writes scalars for train/val losses and any metrics present in logs.
- Parameters:
writer (Optional[Any])
log_dir (Optional[str])
torchflow.trainer module
A Trainer class for managing the training and validation of PyTorch models.
Includes support for metrics, TensorBoard logging, and MLflow tracking.
- class torchflow.trainer.Trainer(model, criterion, optimizer, device='cpu', metrics=None, writer=None, mlflow_tracking=False, callbacks=None)
Bases:
object
Trainer class for managing the training and validation of PyTorch models.
- Parameters:
model (torch.nn.Module) – The model to be trained.
criterion (torch.nn.Module) – The loss function.
optimizer (torch.optim.Optimizer) – The optimizer.
device (str, optional) – Device to run the training on. Defaults to ‘cpu’.
metrics (torchmetrics.Metric or list, optional) – Metrics for evaluation. Can be a single metric or a list of metrics. Defaults to None.
writer (SummaryWriter, optional) – TensorBoard writer. Must implement the TensorBoard SummaryWriter API. Defaults to None.
mlflow_tracking (bool, optional) – Whether to use MLflow for tracking. Defaults to False.
callbacks (Iterable[Any] | None)
- train_one_epoch(dataloader)
Train the model for one epoch.
- Parameters:
dataloader (torch.utils.data.DataLoader) – The data loader for the training data.
- Returns:
The average loss for the epoch.
- Return type:
float
- Raises:
RuntimeError – If there is an issue during training.
Notes
This method sets the model to training mode, iterates over the data loader, computes the loss, performs backpropagation, and updates the model parameters.
- validate(dataloader)
Validate the model.
- Parameters:
dataloader (torch.utils.data.DataLoader) – The data loader for the validation data.
- Returns:
The average loss and optional metrics for the validation data.
- Return type:
Tuple[float, Optional[dict]]
- Raises:
RuntimeError – If there is an issue during validation.
Notes
This method sets the model to evaluation mode, iterates over the data loader, computes the loss and metrics without updating model parameters.
- train(train_loader, val_loader=None, num_epochs=5)
Fit the model to the training data and validate on the validation data.
- Parameters:
train_loader (torch.utils.data.DataLoader) – The data loader for the training data.
val_loader (Optional[torch.utils.data.DataLoader], optional) – The data loader for the validation data. Defaults to None.
num_epochs (int, optional) – The number of epochs to train. Defaults to 5.
- Returns:
The trained model, training history, and the Trainer instance.
- Return type:
Tuple[torch.nn.Module, dict, Trainer]
- Raises:
RuntimeError – If there is an issue during training or validation.
Notes
This method manages the overall training process, including logging to TensorBoard and MLflow if enabled.
torchflow.tuner module
Hyperparameter tuning utilities using Optuna.
This module provides a tiny helper to run Optuna studies that build models via a user-supplied build_fn(trial) and train them using torchflow.Trainer. Optuna is imported lazily, so importing this module doesn’t require Optuna to be installed.
- torchflow.tuner.tune(build_fn, train_loader, val_loader, num_epochs=5, n_trials=20, direction='min', study_name=None, storage=None, n_jobs=1, show_progress=False)
Run an Optuna study over n_trials.
build_fn should accept an Optuna trial and return a dict with keys 'model', 'optimizer', 'criterion' and optionally 'device', 'callbacks', 'writer', 'metrics', 'mlflow_tracking'.
The objective returned to Optuna is the last validation loss recorded by the trainer (falls back to the last training loss, or inf).
- Parameters:
build_fn (Callable[[Any], dict])
num_epochs (int)
n_trials (int)
direction (str)
study_name (str | None)
storage (str | None)
n_jobs (int)
show_progress (bool)
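The fallback rule for the objective (last validation loss, then last training loss, then inf) can be sketched as a small function. The history layout and the key names `'val_loss'` and `'train_loss'` are assumptions for illustration, and `pick_objective` is a hypothetical name rather than part of torchflow.

```python
import math
from typing import Dict, List


def pick_objective(history: Dict[str, List[float]]) -> float:
    """Return the last validation loss, else the last training loss, else inf."""
    for key in ("val_loss", "train_loss"):  # key names are assumed, not confirmed
        values = history.get(key) or []
        if values:
            return float(values[-1])
    # Nothing was recorded: return the worst possible value for a minimized objective.
    return math.inf
```

Returning inf for an empty history means a failed or empty trial can never be selected as the best one when the study minimizes.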
torchflow.utils module
Small runtime utilities used by examples and tests.
This module is intentionally lightweight and avoids importing heavy dependencies (like torch or numpy) at import time. Functions that use those libraries perform lazy imports so the module can be imported in environments where those extras are not installed.
Utilities included:
- set_seed(seed): set RNG seeds for reproducibility (random, numpy, torch when available)
- to_device(obj, device): move tensors / modules / nested containers to a device
- ensure_list(x): ensure the value is a list (wraps non-iterables)
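As an illustration of the ensure_list idea, a minimal sketch follows. The treatment of None (empty list) and of strings (wrapped whole rather than split into characters) are assumptions, since this page does not specify them, and `ensure_list_sketch` is a hypothetical name.

```python
from typing import Any, List


def ensure_list_sketch(x: Any) -> List[Any]:
    """Return x as a list, wrapping single values."""
    if x is None:
        return []  # assumption: None means "no items"
    if isinstance(x, (list, tuple, set)):
        return list(x)
    # Strings and other scalars are wrapped rather than iterated.
    return [x]
```

A helper like this lets an argument such as metrics accept either a single object or a collection.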
- torchflow.utils.set_seed(seed)
Set random seeds for reproducible runs.
This sets the seed for Python’s random, optionally NumPy and PyTorch if they are available. Passing None disables explicit seeding.
- Parameters:
seed (Optional[int]) – integer seed or None to skip.
- Return type:
None
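The lazy-import seeding pattern described for set_seed can be sketched as follows. This is a stand-in written for illustration, not the library's source. `set_seed_sketch` is a hypothetical name, and the choice of `np.random.seed` and `torch.manual_seed` is an assumption about which calls such a helper would make.

```python
import random
from typing import Optional


def set_seed_sketch(seed: Optional[int]) -> None:
    """Seed Python's RNG, plus NumPy and PyTorch when they can be imported."""
    if seed is None:
        return  # explicit seeding disabled
    random.seed(seed)
    try:
        import numpy as np  # optional dependency, imported lazily
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch  # optional dependency, imported lazily
        torch.manual_seed(seed)
    except ImportError:
        pass


set_seed_sketch(42)
first = random.random()
set_seed_sketch(42)
second = random.random()  # same value as `first` because the seed was reset
```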
Module contents
A high-level deep learning library built on PyTorch for fast prototyping and experimentation. Provides modules for building, training, and evaluating neural networks with ease.
version: 0.1.0
author: Muyiwa Obadara
license: MIT License
copyright: Copyright 2024, Muyiwa Obadara
repository: https://github.com/maobadara/torchflow
twitter: https://twitter.com/m_aobadara