torchflow package

Submodules

torchflow.callbacks module

Callbacks for training lifecycle events.

Provides a base Callback class and several ready-made implementations: EarlyStopping, ModelCheckpoint, LambdaCallback, LearningRateScheduler, ReduceLROnPlateau, CSVLogger and TensorBoardCallback. They are deliberately simple and lightweight so they can be used in unit tests and examples.

torchflow.callbacks.make_summary_writer(log_dir=None)

Safely create a SummaryWriter if TensorBoard is available.

Returns None when TensorBoard or its dependencies are not installed.

Return type:

Optional[Any]

Parameters:

log_dir (str | None)
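The lazy-import behavior described above can be sketched as follows. This is an illustrative stand-in, not torchflow's source: make_summary_writer_sketch is a hypothetical name, and the sketch assumes the function simply attempts the PyTorch TensorBoard import and returns None if that import fails.

```python
from typing import Any, Optional


def make_summary_writer_sketch(log_dir: Optional[str] = None) -> Optional[Any]:
    """Return a SummaryWriter, or None if TensorBoard support is missing.

    Hypothetical stand-in for torchflow.callbacks.make_summary_writer.
    """
    try:
        # Raises ImportError if torch or the tensorboard package is absent.
        from torch.utils.tensorboard import SummaryWriter
    except ImportError:
        return None
    return SummaryWriter(log_dir=log_dir)
```

Callers can then guard logging calls with a simple None check, which keeps TensorBoard an optional dependency.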

class torchflow.callbacks.Callback

Bases: object

Base callback. Subclass and override the desired methods.

Each hook receives the trainer instance as its trainer argument, along with context such as the epoch or batch index and a logs dict, so callbacks can inspect or control training (for example, by stopping early).

on_train_begin(trainer)
Return type:

None

Parameters:

trainer (Any)

on_train_end(trainer)
Return type:

None

Parameters:

trainer (Any)

on_epoch_begin(trainer, epoch)
Return type:

None

Parameters:
  • trainer (Any)

  • epoch (int)

on_epoch_end(trainer, epoch, logs=None)
Return type:

None

Parameters:
  • trainer (Any)

  • epoch (int)

  • logs (dict | None)

on_batch_end(trainer, batch, logs=None)
Return type:

None

Parameters:
  • trainer (Any)

  • batch (int)

  • logs (dict | None)

on_validation_end(trainer, epoch, logs=None)
Return type:

None

Parameters:
  • trainer (Any)

  • epoch (int)

  • logs (dict | None)
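A custom callback only needs to override the hooks it cares about. To keep the example self-contained, the sketch re-declares a stand-in base class with the hook signatures listed above; in real code you would subclass torchflow.callbacks.Callback instead. How a callback asks the trainer to stop is not documented here, so the stop_training attribute below is a hypothetical flag.

```python
from typing import Any, Optional


class Callback:
    """Stand-in mirroring the documented hooks; every hook is a no-op by default."""

    def on_train_begin(self, trainer: Any) -> None: ...
    def on_train_end(self, trainer: Any) -> None: ...
    def on_epoch_begin(self, trainer: Any, epoch: int) -> None: ...
    def on_epoch_end(self, trainer: Any, epoch: int, logs: Optional[dict] = None) -> None: ...
    def on_batch_end(self, trainer: Any, batch: int, logs: Optional[dict] = None) -> None: ...
    def on_validation_end(self, trainer: Any, epoch: int, logs: Optional[dict] = None) -> None: ...


class StopWhenLossBelow(Callback):
    """Ask the trainer to stop once val_loss drops below a threshold."""

    def __init__(self, threshold: float) -> None:
        self.threshold = threshold

    def on_validation_end(self, trainer: Any, epoch: int, logs: Optional[dict] = None) -> None:
        if logs and logs.get("val_loss", float("inf")) < self.threshold:
            trainer.stop_training = True  # hypothetical flag checked by the trainer
```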

class torchflow.callbacks.EarlyStopping(monitor='val_loss', patience=3, min_delta=0.0, mode='min')

Bases: Callback

Stop training when a monitored metric has stopped improving.

Parameters:
  • monitor (str) – metric name to monitor (e.g. ‘val_loss’).

  • patience (int) – number of epochs with no improvement after which training is stopped.

  • min_delta (float) – minimum change to qualify as improvement.

  • mode (str) – ‘min’ if lower values are better, ‘max’ if higher values are better.

on_validation_end(trainer, epoch, logs=None)
Return type:

None

Parameters:
  • trainer (Any)

  • epoch (int)

  • logs (dict | None)
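A minimal sketch of the patience logic, assuming Keras-style semantics: a value counts as an improvement only if it beats the best value so far by more than min_delta, and training should stop once patience consecutive epochs pass without one. EarlyStoppingSketch and should_stop are hypothetical names; the real callback reads the metric from logs[monitor] inside on_validation_end and may differ in details.

```python
import math


class EarlyStoppingSketch:
    def __init__(self, patience: int = 3, min_delta: float = 0.0, mode: str = "min") -> None:
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        # Start from the worst possible value so the first epoch always improves.
        self.best = math.inf if mode == "min" else -math.inf
        self.wait = 0  # consecutive epochs without improvement

    def _improved(self, current: float) -> bool:
        if self.mode == "min":
            return current < self.best - self.min_delta
        return current > self.best + self.min_delta

    def should_stop(self, current: float) -> bool:
        """Record one epoch's metric value; return True when training should stop."""
        if self._improved(current):
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience
```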

class torchflow.callbacks.ModelCheckpoint(filepath, monitor='val_loss', save_best_only=True, mode='min')

Bases: Callback

Save the model after every epoch, or only when the monitored metric improves (save_best_only=True).

Parameters:
  • filepath (str)

  • monitor (str)

  • save_best_only (bool)

  • mode (str)

on_validation_end(trainer, epoch, logs=None)
Return type:

None

Parameters:
  • trainer (Any)

  • epoch (int)

  • logs (dict | None)
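The save decision can be sketched as below, assuming that save_best_only=False writes a checkpoint every epoch and save_best_only=True writes one only when the monitored value beats the best seen so far. So that it runs without PyTorch, the sketch takes a hypothetical save_fn callable; in practice that step would typically be torch.save(model.state_dict(), filepath). CheckpointSketch is likewise a hypothetical name.

```python
import math
from typing import Callable, Optional


class CheckpointSketch:
    def __init__(self, save_fn: Callable[[int], None], save_best_only: bool = True,
                 mode: str = "min") -> None:
        self.save_fn = save_fn  # hypothetical: receives the epoch number
        self.save_best_only = save_best_only
        self.mode = mode
        self.best = math.inf if mode == "min" else -math.inf

    def on_validation_end(self, epoch: int, current: Optional[float]) -> bool:
        """Return True if a checkpoint was written for this epoch."""
        if not self.save_best_only:
            self.save_fn(epoch)
            return True
        if current is None:
            return False  # monitored metric missing from logs: nothing to compare
        better = current < self.best if self.mode == "min" else current > self.best
        if better:
            self.best = current
            self.save_fn(epoch)
        return better
```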

class torchflow.callbacks.LambdaCallback(on_train_begin=None, on_train_end=None, on_epoch_begin=None, on_epoch_end=None, on_batch_end=None, on_validation_end=None)

Bases: Callback

Create a callback from simple callables.

Usage: LambdaCallback(on_epoch_end=lambda trainer, epoch, logs: …)

Parameters:
  • on_train_begin (Optional[Callable])

  • on_train_end (Optional[Callable])

  • on_epoch_begin (Optional[Callable])

  • on_epoch_end (Optional[Callable])

  • on_batch_end (Optional[Callable])

  • on_validation_end (Optional[Callable])

on_train_begin(trainer)
Return type:

None

Parameters:

trainer (Any)

on_train_end(trainer)
Return type:

None

Parameters:

trainer (Any)

on_epoch_begin(trainer, epoch)
Return type:

None

Parameters:
  • trainer (Any)

  • epoch (int)

on_epoch_end(trainer, epoch, logs=None)
Return type:

None

Parameters:
  • trainer (Any)

  • epoch (int)

  • logs (dict | None)

on_batch_end(trainer, batch, logs=None)
Return type:

None

Parameters:
  • trainer (Any)

  • batch (int)

  • logs (dict | None)

on_validation_end(trainer, epoch, logs=None)
Return type:

None

Parameters:
  • trainer (Any)

  • epoch (int)

  • logs (dict | None)
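One straightforward way to build such a class is to store each optional callable and forward the matching hook to it, treating a missing callable as a no-op. The sketch below assumes exactly that; LambdaCallbackSketch is a hypothetical name, not torchflow's implementation.

```python
from typing import Any, Callable, Optional


class LambdaCallbackSketch:
    def __init__(
        self,
        on_train_begin: Optional[Callable] = None,
        on_train_end: Optional[Callable] = None,
        on_epoch_begin: Optional[Callable] = None,
        on_epoch_end: Optional[Callable] = None,
        on_batch_end: Optional[Callable] = None,
        on_validation_end: Optional[Callable] = None,
    ) -> None:
        # User functions live in a dict so they do not shadow the hook methods.
        self._fns = {
            "on_train_begin": on_train_begin,
            "on_train_end": on_train_end,
            "on_epoch_begin": on_epoch_begin,
            "on_epoch_end": on_epoch_end,
            "on_batch_end": on_batch_end,
            "on_validation_end": on_validation_end,
        }

    def _call(self, name: str, *args: Any) -> None:
        fn = self._fns[name]
        if fn is not None:  # missing callables are silently skipped
            fn(*args)

    def on_train_begin(self, trainer: Any) -> None:
        self._call("on_train_begin", trainer)

    def on_train_end(self, trainer: Any) -> None:
        self._call("on_train_end", trainer)

    def on_epoch_begin(self, trainer: Any, epoch: int) -> None:
        self._call("on_epoch_begin", trainer, epoch)

    def on_epoch_end(self, trainer: Any, epoch: int, logs: Optional[dict] = None) -> None:
        self._call("on_epoch_end", trainer, epoch, logs)

    def on_batch_end(self, trainer: Any, batch: int, logs: Optional[dict] = None) -> None:
        self._call("on_batch_end", trainer, batch, logs)

    def on_validation_end(self, trainer: Any, epoch: int, logs: Optional[dict] = None) -> None:
        self._call("on_validation_end", trainer, epoch, logs)
```

Keeping the user functions in a separate dict, rather than assigning them to attributes named on_epoch_end and so on, avoids overwriting the hook methods with the raw callables.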

class torchflow.callbacks.LearningRateScheduler(scheduler=None, schedule_fn=None, step_at='epoch')

Bases: Callback

Adjust learning rate according to a schedule.

You may provide either a PyTorch lr_scheduler instance (object with a .step(…) method) or a simple schedule function that accepts the epoch number and returns either a single lr or a list/tuple of lrs (one per param_group).

By default (step_at='epoch'), the schedule is stepped at the end of each epoch.

Parameters:
  • scheduler (Optional[Any])

  • schedule_fn (Optional[Callable])

  • step_at (str)

on_epoch_end(trainer, epoch, logs=None)
Return type:

None

Parameters:
  • trainer (Any)

  • epoch (int)

  • logs (dict | None)
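When a schedule_fn is supplied, its return value has to be mapped onto the optimizer's parameter groups. A minimal sketch of that step, assuming a single value is applied to every group and a list/tuple is matched to groups by position (the length check is also an assumption). It works on any list of dicts with an "lr" key, which is how torch.optim optimizers expose param_groups; apply_schedule and step_decay are hypothetical names.

```python
from typing import Callable, List, Sequence, Union

LR = Union[float, Sequence[float]]


def apply_schedule(param_groups: List[dict], schedule_fn: Callable[[int], LR], epoch: int) -> None:
    """Set each param group's 'lr' from schedule_fn(epoch)."""
    lrs = schedule_fn(epoch)
    if isinstance(lrs, (list, tuple)):
        if len(lrs) != len(param_groups):
            raise ValueError(f"expected {len(param_groups)} learning rates, got {len(lrs)}")
    else:
        lrs = [lrs] * len(param_groups)  # broadcast a single value to every group
    for group, lr in zip(param_groups, lrs):
        group["lr"] = float(lr)


def step_decay(epoch: int) -> float:
    """Example schedule: start at 0.1 and halve every 10 epochs."""
    return 0.1 * (0.5 ** (epoch // 10))
```

With a real optimizer the call would be apply_schedule(optimizer.param_groups, step_decay, epoch) at the end of each epoch.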

class torchflow.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=2, min_lr=0.0, mode='min', min_delta=1e-08, cooldown=0)

Bases: Callback

Reduce learning rate when a monitored metric has stopped improving.

A lightweight, built-in implementation of the common reduce-on-plateau behavior. If you already have a torch.optim.lr_scheduler.ReduceLROnPlateau instance, you can use LearningRateScheduler(scheduler=your_scheduler) instead.

Parameters:
  • monitor (str)

  • factor (float)

  • patience (int)

  • min_lr (float)

  • mode (str)

  • min_delta (float)

  • cooldown (int)

on_validation_end(trainer, epoch, logs=None)
Return type:

None

Parameters:
  • trainer (Any)

  • epoch (int)

  • logs (dict | None)
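The interaction of patience, cooldown, factor and min_lr can be sketched as below. The sketch follows the Keras-style rule (reduce once patience epochs pass without an improvement larger than min_delta, then pause counting for cooldown epochs); torchflow's class may count epochs slightly differently, and PlateauSketch is a hypothetical name. It updates a list of dicts with an "lr" key, the layout torch.optim optimizers use for param_groups.

```python
import math
from typing import List


class PlateauSketch:
    def __init__(self, param_groups: List[dict], factor: float = 0.1, patience: int = 2,
                 min_lr: float = 0.0, mode: str = "min", min_delta: float = 1e-8,
                 cooldown: int = 0) -> None:
        self.param_groups = param_groups
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.mode = mode
        self.min_delta = min_delta
        self.cooldown = cooldown
        self.best = math.inf if mode == "min" else -math.inf
        self.wait = 0              # epochs without improvement
        self.cooldown_counter = 0  # epochs left before counting resumes

    def _improved(self, current: float) -> bool:
        if self.mode == "min":
            return current < self.best - self.min_delta
        return current > self.best + self.min_delta

    def step(self, current: float) -> None:
        """Call once per validation with the monitored value."""
        if self.cooldown_counter > 0:
            self.cooldown_counter -= 1
            self.wait = 0
        if self._improved(current):
            self.best = current
            self.wait = 0
        elif self.cooldown_counter == 0:
            self.wait += 1
            if self.wait >= self.patience:
                for group in self.param_groups:
                    # Scale down, but never below min_lr.
                    group["lr"] = max(group["lr"] * self.factor, self.min_lr)
                self.cooldown_counter = self.cooldown
                self.wait = 0
```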

class torchflow.callbacks.CSVLogger(filename, append=True, separator=',')

Bases: Callback

Log epoch-level metrics to a CSV file.

Writes a header row the first time, then appends one row per epoch. Each row is built from the key/value pairs of the logs dict passed to on_epoch_end.

Parameters:
  • filename (str)

  • append (bool)

  • separator (str)

on_epoch_end(trainer, epoch, logs=None)
Return type:

None

Parameters:
  • trainer (Any)

  • epoch (int)

  • logs (dict | None)
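A minimal sketch of this behavior using the standard csv module. Several details are assumptions rather than documented behavior: an epoch column is prepended, the column order is fixed by the first row, the header is written only when the file is new or empty, and append=False truncates the file when the logger is created. CSVLoggerSketch is a hypothetical name.

```python
import csv
import os
from typing import Any, List, Optional


class CSVLoggerSketch:
    def __init__(self, filename: str, append: bool = True, separator: str = ",") -> None:
        self.filename = filename
        self.separator = separator
        self.fieldnames: Optional[List[str]] = None
        if not append:
            # Assumed meaning of append=False: start from an empty file.
            open(filename, "w").close()

    def on_epoch_end(self, trainer: Any, epoch: int, logs: Optional[dict] = None) -> None:
        row = {"epoch": epoch, **(logs or {})}
        if self.fieldnames is None:
            self.fieldnames = list(row)  # column order fixed by the first row
        needs_header = not os.path.exists(self.filename) or os.path.getsize(self.filename) == 0
        with open(self.filename, "a", newline="") as f:
            # extrasaction="ignore" drops keys that first appear in later epochs;
            # keys missing from a later row are written as empty cells.
            writer = csv.DictWriter(f, fieldnames=self.fieldnames,
                                    delimiter=self.separator, extrasaction="ignore")
            if needs_header:
                writer.writeheader()
            writer.writerow(row)
```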

class torchflow.callbacks.TensorBoardCallback(writer=None, log_dir=None)

Bases: Callback

Log metrics to TensorBoard through a SummaryWriter.

Provide either an existing SummaryWriter instance or a log_dir to create one. The callback writes scalars for train/val losses and any metrics present in logs.

Parameters:
  • writer (Optional[Any])

  • log_dir (Optional[str])

on_epoch_end(trainer, epoch, logs=None)
Return type:

None

Parameters:
  • trainer (Any)

  • epoch (int)

  • logs (dict | None)

on_train_end(trainer)
Return type:

None

Parameters:

trainer (Any)
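A minimal sketch of what the two hooks can do with a SummaryWriter. It assumes each numeric entry in logs is written as a scalar tagged with its key and stepped by the epoch number, and that the writer is closed at the end of training; add_scalar(tag, scalar_value, global_step) and close() belong to PyTorch's SummaryWriter API, while TensorBoardSketch and the tagging scheme are hypothetical.

```python
from numbers import Number
from typing import Any, Optional


class TensorBoardSketch:
    def __init__(self, writer: Any) -> None:
        # e.g. torch.utils.tensorboard.SummaryWriter(log_dir=...), or None
        self.writer = writer

    def on_epoch_end(self, trainer: Any, epoch: int, logs: Optional[dict] = None) -> None:
        if self.writer is None:
            return  # TensorBoard unavailable: do nothing
        for key, value in (logs or {}).items():
            # Only plain numbers are logged; tensors would need .item() first.
            if isinstance(value, Number) and not isinstance(value, bool):
                self.writer.add_scalar(key, value, global_step=epoch)

    def on_train_end(self, trainer: Any) -> None:
        if self.writer is not None:
            self.writer.close()  # flush pending events to disk
```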

torchflow.trainer module

A Trainer class for managing the training and validation of PyTorch models.

Includes support for metrics, TensorBoard logging, and MLflow tracking.

class torchflow.trainer.Trainer(model, criterion, optimizer, device='cpu', metrics=None, writer=None, mlflow_tracking=False, callbacks=None)

Bases: object

Trainer class for managing the training and validation of PyTorch models.

Parameters:
  • model (torch.nn.Module) – The model to be trained.

  • criterion (torch.nn.Module) – The loss function.

  • optimizer (torch.optim.Optimizer) – The optimizer.

  • device (str, optional) – Device to run the training on. Defaults to ‘cpu’.

  • metrics (torchmetrics.Metric or list, optional) – Metrics for evaluation. Can be a single metric or a list of metrics. Defaults to None.

  • writer (SummaryWriter, optional) – TensorBoard writer. Must implement the TensorBoard SummaryWriter API. Defaults to None.

  • mlflow_tracking (bool, optional) – Whether to use MLflow for tracking. Defaults to False.

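  • callbacks (Iterable[Any] | None) – Callback instances (see torchflow.callbacks) invoked at training lifecycle events. Defaults to None.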

train_one_epoch(dataloader)

Train the model for one epoch.

Parameters:

dataloader (torch.utils.data.DataLoader) – The data loader for the training data.

Returns:

The average loss for the epoch.

Return type:

float

Raises:

RuntimeError – If there is an issue during training.

Notes

This method sets the model to training mode, iterates over the data loader, computes the loss, performs backpropagation, and updates the model parameters.
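These steps follow the standard PyTorch training loop. A minimal sketch, assuming each batch is an (inputs, targets) pair, the model is already on the target device, and the reported value is the mean of the per-batch losses (the real method might instead weight by batch size); train_one_epoch_sketch is a hypothetical stand-alone function, not the Trainer method itself.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch_sketch(model: nn.Module, criterion: nn.Module,
                           optimizer: torch.optim.Optimizer, dataloader: DataLoader,
                           device: str = "cpu") -> float:
    model.train()  # training-mode behavior for dropout, batch norm, etc.
    total, num_batches = 0.0, 0
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()   # backpropagation
        optimizer.step()  # parameter update
        total += loss.item()
        num_batches += 1
    return total / max(num_batches, 1)  # mean per-batch loss


# Usage on random regression data.
torch.manual_seed(0)
loader = DataLoader(TensorDataset(torch.randn(64, 4), torch.randn(64, 1)), batch_size=16)
model = nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
epoch_loss = train_one_epoch_sketch(model, nn.MSELoss(), opt, loader)
```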

validate(dataloader)

Validate the model.

Parameters:

dataloader (torch.utils.data.DataLoader) – The data loader for the validation data.

Returns:

The average loss and optional metrics for the validation data.

Return type:

Tuple[float, Optional[dict]]

Raises:

RuntimeError – If there is an issue during validation.

Notes

This method sets the model to evaluation mode, iterates over the data loader, and computes the loss and metrics without updating the model parameters.
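The evaluation pass can be sketched the same way, under the same assumptions as a plain PyTorch loop: (inputs, targets) batches and a mean of per-batch losses. validate_sketch is a hypothetical name, and it always returns None for the metrics; with torchmetrics the usual pattern is metric.update(preds, targets) per batch and metric.compute() at the end, which is omitted here.

```python
from typing import Optional, Tuple

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()  # disable gradient tracking; no optimizer step happens here
def validate_sketch(model: nn.Module, criterion: nn.Module, dataloader: DataLoader,
                    device: str = "cpu") -> Tuple[float, Optional[dict]]:
    model.eval()  # evaluation-mode behavior for dropout, batch norm, etc.
    total, num_batches = 0.0, 0
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        total += criterion(model(inputs), targets).item()
        num_batches += 1
    return total / max(num_batches, 1), None  # metrics omitted in this sketch


# Usage on random regression data.
torch.manual_seed(0)
loader = DataLoader(TensorDataset(torch.randn(32, 4), torch.randn(32, 1)), batch_size=8)
model = nn.Linear(4, 1)
before = [p.clone() for p in model.parameters()]
val_loss, val_metrics = validate_sketch(model, nn.MSELoss(), loader)
```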

train(train_loader, val_loader=None, num_epochs=5)

Fit the model to the training data and validate on the validation data.

Parameters:
  • train_loader (torch.utils.data.DataLoader) – The data loader for the training data.

  • val_loader (Optional[torch.utils.data.DataLoader], optional) – The data loader for the validation data. Defaults to None.

  • num_epochs (int, optional) – The number of epochs to train. Defaults to 5.

Returns:

The trained model, training history, and the Trainer instance.

Return type:

Tuple[torch.nn.Module, dict, Trainer]

Raises:

RuntimeError – If there is an issue during training or validation.

Notes

This method manages the overall training process, including logging to TensorBoard and MLflow if enabled.
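This reference does not spell out how train sequences epochs and callbacks, so the sketch below shows one plausible orchestration. Everything specific in it is an assumption: the history keys (train_loss, val_loss), the order in which hooks fire, and the hypothetical stop_training flag. The per-epoch work is passed in as plain callables so the control flow runs without PyTorch (on_batch_end would fire inside that training callable and is not shown), and fit_sketch, a hypothetical name, returns only the history rather than the (model, history, trainer) tuple.

```python
from typing import Any, Callable, Dict, Iterable, List, Optional


def fit_sketch(trainer: Any, train_epoch: Callable[[], float],
               validate: Optional[Callable[[], float]], callbacks: Iterable[Any],
               num_epochs: int = 5) -> Dict[str, List[float]]:
    callbacks = list(callbacks)
    history: Dict[str, List[float]] = {"train_loss": [], "val_loss": []}
    trainer.stop_training = False  # hypothetical flag a callback may set
    for cb in callbacks:
        cb.on_train_begin(trainer)
    for epoch in range(num_epochs):
        for cb in callbacks:
            cb.on_epoch_begin(trainer, epoch)
        logs = {"train_loss": train_epoch()}
        history["train_loss"].append(logs["train_loss"])
        if validate is not None:
            logs["val_loss"] = validate()
            history["val_loss"].append(logs["val_loss"])
            for cb in callbacks:
                cb.on_validation_end(trainer, epoch, logs)
        for cb in callbacks:
            cb.on_epoch_end(trainer, epoch, logs)
        if trainer.stop_training:
            break  # a callback asked to stop early
    for cb in callbacks:
        cb.on_train_end(trainer)
    return history
```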

save_model(path)

Save the model to the specified path and log it to MLflow if enabled.

Parameters:

path (str) – The path to save the model.

Return type:

None

torchflow.tuner module

Hyperparameter tuning utilities using Optuna.

This module provides a tiny helper to run Optuna studies that build models via a user-supplied build_fn(trial) and train them using torchflow.Trainer. Optuna is imported lazily so importing this module doesn’t require Optuna to be installed.

torchflow.tuner.tune(build_fn, train_loader, val_loader, num_epochs=5, n_trials=20, direction='min', study_name=None, storage=None, n_jobs=1, show_progress=False)

Run an Optuna study over n_trials.

build_fn should accept an Optuna trial and return a dict with keys 'model', 'optimizer', 'criterion' and optionally 'device', 'callbacks', 'writer', 'metrics', 'mlflow_tracking'.

The objective value returned to Optuna is the last validation loss recorded by the trainer. If no validation loss was recorded, it falls back to the last training loss, and to inf if neither is available.
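The fallback rule can be written as a small helper. It assumes the trainer's history is a dict holding lists under val_loss and train_loss keys; those key names and objective_from_history are hypothetical.

```python
import math
from typing import Dict, List


def objective_from_history(history: Dict[str, List[float]]) -> float:
    """Last validation loss, else last training loss, else +inf."""
    for key in ("val_loss", "train_loss"):
        values = history.get(key)
        if values:  # skip missing or empty lists
            return float(values[-1])
    return math.inf
```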

Parameters:
  • build_fn (Callable[[Any], dict])

  • train_loader (torch.utils.data.DataLoader)

  • val_loader (torch.utils.data.DataLoader)

  • num_epochs (int)

  • n_trials (int)

  • direction (str)

  • study_name (str | None)

  • storage (str | None)

  • n_jobs (int)

  • show_progress (bool)

torchflow.tuner.example_build_fn(trial)

Tiny example build_fn for docs/tests.

Samples a learning rate and a hidden size, constructs a tiny MLP, and returns the dict expected by tune.
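A build_fn in the same spirit might look like the sketch below. trial.suggest_float(..., log=True) and trial.suggest_int(...) are standard Optuna trial methods; the parameter names, search ranges, layer sizes, and the choice of Adam and MSELoss are illustrative assumptions, not what example_build_fn actually uses, and my_build_fn is a hypothetical name.

```python
from typing import Any, Dict

import torch
from torch import nn


def my_build_fn(trial: Any) -> Dict[str, Any]:
    # Illustrative search space: log-uniform learning rate and hidden width.
    lr = trial.suggest_float("lr", 1e-4, 1e-1, log=True)
    hidden = trial.suggest_int("hidden", 8, 64)
    model = nn.Sequential(nn.Linear(4, hidden), nn.ReLU(), nn.Linear(hidden, 1))
    return {
        "model": model,
        "optimizer": torch.optim.Adam(model.parameters(), lr=lr),
        "criterion": nn.MSELoss(),
        "device": "cpu",  # optional key
    }
```

It would then be passed as tune(my_build_fn, train_loader, val_loader, n_trials=10).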

torchflow.utils module

Small runtime utilities used by examples and tests.

This module is intentionally lightweight and avoids importing heavy dependencies (like torch or numpy) at import time. Functions that use those libraries perform lazy imports so the module can be imported in environments where those extras are not installed.

Utilities included:

  • set_seed(seed): set RNG seeds for reproducibility (random, numpy, torch when available)

  • to_device(obj, device): move tensors / modules / nested containers to a device

  • ensure_list(x): ensure the value is a list (wraps non-iterables)

torchflow.utils.set_seed(seed)

Set random seeds for reproducible runs.

This seeds Python’s random module, and also NumPy and PyTorch when they are installed. Passing None skips explicit seeding.

Parameters:

seed (Optional[int]) – integer seed or None to skip.

Return type:

None
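A sketch of this behavior with lazy imports, so the function still works when NumPy or PyTorch is missing; set_seed_sketch is a hypothetical name. Seeding alone does not make every GPU kernel deterministic; PyTorch exposes torch.use_deterministic_algorithms for that.

```python
import random
from typing import Optional


def set_seed_sketch(seed: Optional[int]) -> None:
    if seed is None:
        return  # explicit opt-out: leave RNG state untouched
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed
    try:
        import torch
        torch.manual_seed(seed)  # seeds the RNG on all devices
    except ImportError:
        pass  # PyTorch not installed
```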

torchflow.utils.ensure_list(x)

Return x as a list. If x is None, return an empty list.

Useful for normalizing API inputs that accept a single value or a list.

Return type:

List[Any]

Parameters:

x (Any)
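A plausible implementation, assuming lists pass through unchanged, tuples are converted, and everything else (including strings and single objects such as one metric) is wrapped. ensure_list_sketch is a hypothetical name, and the real function's handling of other iterables may differ.

```python
from typing import Any, List


def ensure_list_sketch(x: Any) -> List[Any]:
    if x is None:
        return []
    if isinstance(x, list):
        return x  # already a list: returned as-is
    if isinstance(x, tuple):
        return list(x)
    return [x]  # wrap a single value; strings count as single values
```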

torchflow.utils.to_device(obj, device)

Move PyTorch tensors or modules (and nested containers) to device.

Works with tensors, torch.nn.Module, lists, tuples and dicts. If PyTorch is not available this is a no-op.

Parameters:
  • obj (Any)

  • device (Any)
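The recursive traversal described above can be sketched as follows. to_device_sketch is a hypothetical name; the sketch assumes non-tensor leaves are returned unchanged and, as a simplification, rebuilds namedtuples as plain tuples.

```python
from typing import Any


def to_device_sketch(obj: Any, device: Any) -> Any:
    try:
        import torch
    except ImportError:
        return obj  # no PyTorch: behave as a no-op
    if isinstance(obj, (torch.Tensor, torch.nn.Module)):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: to_device_sketch(v, device) for k, v in obj.items()}
    if isinstance(obj, list):
        return [to_device_sketch(v, device) for v in obj]
    if isinstance(obj, tuple):
        return tuple(to_device_sketch(v, device) for v in obj)
    return obj  # leave other values (ints, strings, ...) unchanged
```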

Module contents

A high-level deep learning library built on PyTorch for fast prototyping and experimentation. Provides modules for building, training, and evaluating neural networks with ease.

version: 0.1.0

author: Muyiwa Obadara

license: MIT License

copyright: Copyright 2024, Muyiwa Obadara

repository: https://github.com/maobadara/torchflow

twitter: https://twitter.com/m_aobadara