extensible.train_manager

Classes

TrainManager(model, loss, epochs, train_data, ...)

class extensible.train_manager.TrainManager(model: torch.nn.modules.module.Module, loss: Callable[[Any, Any], torch.Tensor], epochs: int, train_data: Iterator[Any], eval_data: Union[Dict[str, Iterator[Any]], NoneType] = None, extensions: OrderedDict[str, extensible.extensions.extension.Extension] = <factory>, train_dir: pathlib.Path = PosixPath('2024-04-21 06:19:50.519299'), writer: ploteries.writer.Writer = <class 'extensible.defs.Unassigned'>, device: Union[torch.device, str] = device(type='cpu', index=0), optimizer: Union[torch.optim.optimizer.Optimizer, Callable[[Union[Iterator[torch.nn.parameter.Parameter], Dict[str, Iterator[torch.nn.parameter.Parameter]]]], torch.optim.optimizer.Optimizer]] = <class 'torch.optim.adam.Adam'>, lr_schedule: Union[Any, NoneType] = None)

Bases: Extensible

model: Module

The model applied to each batch. Override any of model_forward(), eval_model_forward(), or train_model_forward() if needed to call the model correctly. If the model implements an initialize_params() method, TrainManager.initialize_params() calls it.

loss: Callable[[Any, Any], Tensor]

The loss function applied to the model output. Override any of loss_forward(), eval_loss_forward(), or train_loss_forward() if needed to call the loss correctly.

epochs: int

The number of training epochs.

train_data: Iterator[Any]

The training data.

eval_data: Dict[str, Iterator[Any]] | None = None

Optional evaluation datasets, keyed by name.

extensions: OrderedDict[str, Extension]

User-provided extensions. If the keys 'eval_state' and 'ckpt_saver' are absent, default EvalState and CheckpointSaver extensions are added under those keys.
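The default-filling rule can be pictured with a plain OrderedDict. This is a minimal sketch, not extensible's implementation: the stub classes and the helper with_default_extensions are hypothetical, and appending the defaults after the user-provided entries is an assumption about ordering.

```python
from collections import OrderedDict


class EvalStateStub:
    """Hypothetical stand-in for extensible's EvalState extension."""


class CheckpointSaverStub:
    """Hypothetical stand-in for extensible's CheckpointSaver extension."""


def with_default_extensions(extensions):
    """Return a copy of `extensions` with defaults added under missing keys.

    Assumption: defaults are appended after the user-provided entries.
    """
    result = OrderedDict(extensions)
    # Only fill a key the user did not supply; a user entry always wins.
    if "eval_state" not in result:
        result["eval_state"] = EvalStateStub()
    if "ckpt_saver" not in result:
        result["ckpt_saver"] = CheckpointSaverStub()
    return result


custom_saver = object()  # e.g. a user-configured checkpoint saver
exts = with_default_extensions(OrderedDict(ckpt_saver=custom_saver))
print(list(exts))  # ['ckpt_saver', 'eval_state']
```

To replace a default extension, supply your own under the same key.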

train_dir: Path = PosixPath('2024-04-21 06:19:50.519299')

Directory that holds the writer output and the default checkpoints directory during training. When evaluating, checkpoints are loaded from this directory.

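writer

A visualization writer. Defaults to a ploteries.Writer object with path f"{train_dir}/ploteries.plts".

In the signature, the default is the sentinel extensible.defs.Unassigned, which indicates that the ploteries.Writer default above is used.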
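The documented default paths can be reproduced with the standard library. The rendered train_dir default, PosixPath('2024-04-21 06:19:50.519299'), has the format produced by str(datetime.datetime.now()), so this sketch assumes the default directory is named after a timestamp; the helpers default_train_dir and default_writer_path are hypothetical.

```python
from datetime import datetime
from pathlib import Path


def default_train_dir(now=None):
    """Hypothetical helper: a timestamp-named directory like the rendered default."""
    now = datetime.now() if now is None else now
    # str(datetime) gives 'YYYY-MM-DD HH:MM:SS.ffffff' (the microseconds
    # part is omitted when it is exactly zero).
    return Path(str(now))


def default_writer_path(train_dir):
    """Path of the default ploteries output file, per the writer description."""
    return Path(f"{train_dir}/ploteries.plts")


train_dir = default_train_dir(datetime(2024, 4, 21, 6, 19, 50, 519299))
print(train_dir)                       # 2024-04-21 06:19:50.519299
print(default_writer_path(train_dir))  # on POSIX: 2024-04-21 06:19:50.519299/ploteries.plts
```

Since checkpoints are loaded from train_dir when evaluating, passing an explicit, stable train_dir makes it easier to point a later evaluation at the same run.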

device: device | str = device(type='cpu', index=0)

The main device. setup() moves the model to this device.

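optimizer

The optimizer, given either as a torch.optim.Optimizer instance or as a callable (for example, an optimizer class) that takes the parameters to optimize (an iterator of parameters or a dict of such iterators) and returns an optimizer.

Default: torch.optim.Adam.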
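Because optimizer may be a callable that builds the optimizer from the parameters, functools.partial is a convenient way to change the hyperparameters of an optimizer class. The sketch below runs without PyTorch: StubOptimizer and resolve_optimizer are hypothetical stand-ins, not extensible code. With PyTorch, the corresponding argument would presumably be functools.partial(torch.optim.Adam, lr=1e-4).

```python
import functools


class StubOptimizer:
    """Hypothetical stand-in for a torch.optim.Optimizer subclass."""

    def __init__(self, params, lr=1e-3):
        self.params = list(params)
        self.lr = lr


def resolve_optimizer(optimizer, params):
    """Hypothetical helper: accept an optimizer instance or a factory callable."""
    if isinstance(optimizer, StubOptimizer):
        return optimizer  # already built by the caller
    return optimizer(params)  # a class or a partial: build it from the parameters


params = [0.5, -1.0]  # stands in for model.parameters()
default = resolve_optimizer(StubOptimizer, params)
tuned = resolve_optimizer(functools.partial(StubOptimizer, lr=1e-4), params)
print(default.lr, tuned.lr)  # 0.001 0.0001
```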

lr_schedule: Any | None = None

An optional learning rate scheduler.

mode_name: str

The current mode: either 'train' or 'eval'.

run_stage(*args, **kwargs)

The run stage is entered at the start of train() or, for standalone evaluations, at the start of eval(). It supports adding hooks to the top-level method being executed.

get_true_batch_size(batch: Any) -> int

Note

To keep an accurate count of the total number of samples, this method may need to be overridden. It should return the actual batch size, which can differ from the nominal size, particularly for the last batch of an epoch. By default, it returns len(batch) if batch supports len(), and 1 otherwise.
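The documented default, and the reason it often needs overriding, can be sketched in plain Python. The function names are hypothetical. A batch stored as an (inputs, targets) tuple supports len(), but len() then returns 2, the number of fields, rather than the number of samples; an override should measure one of the fields instead.

```python
def default_true_batch_size(batch):
    """Mirror the documented default: len(batch) if supported, else 1."""
    try:
        return len(batch)
    except TypeError:
        return 1


def tuple_true_batch_size(batch):
    """Override for (inputs, targets) batches: count the inputs."""
    inputs, _targets = batch
    return len(inputs)


# Ten samples split into batches of four: the last batch holds only two.
samples = list(range(10))
batches = [(samples[i:i + 4], samples[i:i + 4]) for i in range(0, 10, 4)]

print(sum(default_true_batch_size(b) for b in batches))  # 6 (3 batches x 2 fields)
print(sum(tuple_true_batch_size(b) for b in batches))    # 10
```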

initialize_params()

By default, this method calls self.model.initialize_params(). If self.model does not implement that method, it applies Xavier uniform initialization to all parameters with more than one dimension.

Note

Consider adding an initialize_params() method to your model or overriding this method in your train manager. Overrides of this method can take fixture arguments.
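For reference, Xavier (Glorot) uniform initialization, as implemented by PyTorch's torch.nn.init.xavier_uniform_, draws each weight from a symmetric uniform distribution whose bound depends on the tensor's fan-in and fan-out. The worked number assumes a weight of shape (256, 128), such as that of a torch.nn.Linear(128, 256) layer, and the default gain of 1:

```latex
W_{ij} \sim \mathcal{U}(-a, a), \qquad
a = \text{gain} \cdot \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}

\text{fan\_in} = 128,\ \text{fan\_out} = 256:\quad
a = \sqrt{\frac{6}{384}} = \sqrt{0.015625} = 0.125
```

torch.nn.init.xavier_uniform_ rejects tensors with fewer than two dimensions, which is consistent with the default skipping biases and other one-dimensional parameters.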

model_forward(batch)

Note

This method will likely need to be overridden to extract the correct model input from the batch.

Note

Unless a derived class provides its own train_model_forward() or eval_model_forward(), this method is used in place of both.
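The fallback described in these notes can be pictured as follows. This is a minimal sketch of the documented behavior, not extensible's source: ManagerSketch is a hypothetical stand-in for TrainManager, and the model is a plain Python function so that the example runs without PyTorch. The fixed +0.5 shift is a toy stand-in for a train-only transformation such as data augmentation.

```python
class ManagerSketch:
    """Hypothetical stand-in showing the model_forward fallback."""

    def __init__(self, model):
        self.model = model

    def model_forward(self, batch):
        # Default: feed the whole batch to the model.
        return self.model(batch)

    def train_model_forward(self, batch):
        # Used while training unless a subclass overrides it.
        return self.model_forward(batch)

    def eval_model_forward(self, batch):
        # Used while evaluating unless a subclass overrides it.
        return self.model_forward(batch)


class TupleBatchManager(ManagerSketch):
    def model_forward(self, batch):
        # Batches are (inputs, targets); only the inputs go to the model.
        inputs, _targets = batch
        return self.model(inputs)


class AugmentedTrainManager(TupleBatchManager):
    def train_model_forward(self, batch):
        # Train-only override; evaluation still goes through model_forward.
        inputs, _targets = batch
        return self.model([x + 0.5 for x in inputs])


def double(xs):
    return [2 * x for x in xs]


batch = ([1.0, 2.0], [2.0, 4.0])
m = AugmentedTrainManager(double)
print(m.train_model_forward(batch))  # [3.0, 5.0]
print(m.eval_model_forward(batch))   # [2.0, 4.0]
```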

loss_forward(batch, prediction) -> Tensor

Note

This method will likely need to be overridden to extract the correct loss inputs from the batch and the model’s prediction.

Note

Unless a derived class provides its own train_loss_forward() or eval_loss_forward(), this method is used in place of both.
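A typical override pulls the targets out of the batch and passes them, together with the prediction, to self.loss. The sketch assumes (inputs, targets) batches; LossSketch is a hypothetical stand-in for a TrainManager subclass, mean_squared_error is a plain-Python stand-in for a loss such as torch.nn.MSELoss() so that the example runs without PyTorch, and the (prediction, targets) argument order is an assumption.

```python
def mean_squared_error(prediction, target):
    """Plain-Python stand-in for a loss callable such as torch.nn.MSELoss()."""
    return sum((p - t) ** 2 for p, t in zip(prediction, target)) / len(target)


class LossSketch:
    """Hypothetical stand-in for a TrainManager subclass."""

    def __init__(self, loss):
        self.loss = loss

    def loss_forward(self, batch, prediction):
        # Batches are (inputs, targets); the loss only needs the targets.
        _inputs, targets = batch
        return self.loss(prediction, targets)


manager = LossSketch(mean_squared_error)
batch = ([1.0, 2.0], [2.0, 4.0])
print(manager.loss_forward(batch, prediction=[2.5, 3.0]))  # 0.625
```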

eval(eval_data: Dict[str, Iterator[Any]] | None = None)

Runs evaluation over all the evaluation datasets.

train()

Trains the model over multiple epochs, evaluating against all evaluation datasets before the first epoch and after every epoch.
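This schedule implies that a run with N epochs performs N + 1 evaluations: one baseline on the initialized model, then one after each epoch. The sketch records the order of calls; run_schedule and its callbacks are hypothetical stand-ins for the real epoch and evaluation logic.

```python
def run_schedule(epochs, train_epoch, evaluate):
    """Evaluate once before training, then train and evaluate each epoch."""
    evaluate(epoch=0)  # baseline on the initialized model
    for epoch in range(1, epochs + 1):
        train_epoch(epoch)
        evaluate(epoch=epoch)


log = []
run_schedule(
    epochs=2,
    train_epoch=lambda e: log.append(f"train {e}"),
    evaluate=lambda epoch: log.append(f"eval {epoch}"),
)
print(log)  # ['eval 0', 'train 1', 'eval 1', 'train 2', 'eval 2']
```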