extensible.train_manager
Classes
- class extensible.train_manager.TrainManager(model: torch.nn.modules.module.Module, loss: Callable[[Any, Any], torch.Tensor], epochs: int, train_data: Iterator[Any], eval_data: Union[Dict[str, Iterator[Any]], NoneType] = None, extensions: OrderedDict[str, extensible.extensions.extension.Extension] = <factory>, train_dir: pathlib.Path = PosixPath('2024-04-21 06:19:50.519299'), writer: ploteries.writer.Writer = <class 'extensible.defs.Unassigned'>, device: Union[torch.device, str] = device(type='cpu', index=0), optimizer: Union[torch.optim.optimizer.Optimizer, Callable[[Union[Iterator[torch.nn.parameter.Parameter], Dict[str, Iterator[torch.nn.parameter.Parameter]]]], torch.optim.optimizer.Optimizer]] = <class 'torch.optim.adam.Adam'>, lr_schedule: Union[Any, NoneType] = None)
Bases: Extensible
- model: Module
The model applied to each batch. Overload any of model_forward(), eval_model_forward(), or train_model_forward() if needed to call the model correctly. If the model implements an initialize_params method, initialize_params() calls it.
- loss: Callable[[Any, Any], Tensor]
The loss function applied to the model's output. Overload any of loss_forward(), eval_loss_forward(), or train_loss_forward() if needed to call the loss correctly.
- epochs: int
The number of training epochs.
- train_data: Iterator[Any]
The training data.
- eval_data: Dict[str, Iterator[Any]] | None = None
Optional validation datasets, keyed by name.
- extensions: OrderedDict[str, Extension]
User-provided extensions. If the keys 'eval_state' and 'ckpt_saver' are not included, default EvalState and CheckpointSaver extensions are added at those keys.
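A minimal plain-Python sketch of the default-extension rule above. The stand-in classes `EvalState` and `CheckpointSaver` and the helper `add_default_extensions` are hypothetical placeholders, not the library's implementation; in particular, appending the defaults after the user-provided entries is an assumption about ordering.

```python
from collections import OrderedDict


class EvalState:  # hypothetical stand-in for the real EvalState extension
    pass


class CheckpointSaver:  # hypothetical stand-in for the real CheckpointSaver extension
    pass


def add_default_extensions(extensions: "OrderedDict[str, object]") -> "OrderedDict[str, object]":
    """Return a copy of `extensions` with defaults filled in for missing keys.

    Assumption: defaults are appended after the user-provided entries.
    """
    result = OrderedDict(extensions)
    if "eval_state" not in result:
        result["eval_state"] = EvalState()
    if "ckpt_saver" not in result:
        result["ckpt_saver"] = CheckpointSaver()
    return result


# A user-supplied checkpoint saver is kept as-is; only 'eval_state' is filled in.
custom_saver = CheckpointSaver()
exts = add_default_extensions(OrderedDict(ckpt_saver=custom_saver))
print(list(exts))  # ['ckpt_saver', 'eval_state']
```

Supplying your own extension under one of the reserved keys is therefore how you would replace a default rather than add a second one.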
- train_dir: Path = PosixPath('2024-04-21 06:19:50.519299')
Directory that holds the writer output and the default checkpoints directory during training. When evaluating, checkpoints are loaded from here.
- writer
A visualization writer. Defaults to a ploteries.Writer object with path f"{train_dir}/ploteries.plts"; the Unassigned value shown in the signature is a placeholder for this default.
alias of Unassigned
- device: device | str = device(type='cpu', index=0)
The main device used. The model is moved to this device by setup().
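- optimizer
The optimizer: either an Optimizer instance, or a callable (such as an optimizer class) that accepts the model parameters, given as an iterator of parameters or a dict mapping names to such iterators, and returns an Optimizer.
alias of Adam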
- lr_schedule: Any | None = None
An optional learning-rate scheduler.
- mode_name: str
The current mode, either 'train' or 'eval'.
- run_stage(*args, **kwargs)
The run stage is entered at the start of train() or, for standalone evaluations, at the start of eval(). It supports adding hooks to the top-level method being executed.
- get_true_batch_size(batch: Any) → int
Note
This method must be implemented when keeping track of the total number of samples. It should return the actual batch size, which can differ from the nominal size, particularly for the last batch of an epoch. By default, it returns len(batch) if batch supports len(), otherwise 1.
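The default rule, and why it can mislead for some batch layouts, can be sketched in plain Python. `default_true_batch_size` mirrors the documented behavior but is not the library's code, and `dict_batch_size` is a hypothetical override for an assumed `{'x': ..., 'y': ...}` batch layout; in a TrainManager subclass it would be written as the `get_true_batch_size(self, batch)` method.

```python
from typing import Any


def default_true_batch_size(batch: Any) -> int:
    """Mirror the documented default: len(batch) if supported, else 1."""
    try:
        return len(batch)
    except TypeError:
        # len() raises TypeError for objects without __len__.
        return 1


def dict_batch_size(batch: dict) -> int:
    """Hypothetical override for batches shaped like {'x': [...], 'y': [...]}.

    len() of a dict counts its keys rather than its samples, so the
    default would report 2 here; count the samples in one field instead.
    """
    return len(batch["x"])


batch = {"x": [1, 2, 3], "y": [4, 5, 6]}
print(default_true_batch_size(batch))  # 2 (number of keys, not samples)
print(dict_batch_size(batch))          # 3
```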
- initialize_params()
By default, this method attempts to call self.model.initialize_params(). If self.model does not implement that method, it applies Xavier uniform initialization to all parameters with more than one dimension.
Note
Consider adding an initialize_params() method to your model or overloading this method in your train manager. Overloads of this method can take fixture arguments.
- model_forward(batch)
Note
This method likely needs to be overloaded to extract the correct model input from the batch.
Note
Unless a derived class provides an explicit implementation, this method takes the place of both train_model_forward() and eval_model_forward().
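The fallback can be illustrated with a small stand-in class. `ManagerSketch` is hypothetical and only mimics the documented dispatch (the training and evaluation forwards defaulting to `model_forward()`); it is not the TrainManager source. The dict-shaped batch with an `'x'` key is likewise an assumption.

```python
from typing import Any, Callable, List


class ManagerSketch:
    """Hypothetical stand-in mimicking TrainManager's forward dispatch."""

    def __init__(self, model: Callable[[Any], Any]):
        self.model = model

    def model_forward(self, batch):
        # Default: pass the whole batch to the model.
        return self.model(batch)

    def train_model_forward(self, batch):
        # Falls back to model_forward() unless a subclass overrides it.
        return self.model_forward(batch)

    def eval_model_forward(self, batch):
        return self.model_forward(batch)


class DictBatchManager(ManagerSketch):
    """Overrides only model_forward(); both train and eval paths pick it up."""

    def model_forward(self, batch):
        # Assumed batch layout: {'x': inputs, 'y': targets}.
        return self.model(batch["x"])


def double(xs: List[int]) -> List[int]:
    return [2 * v for v in xs]


mgr = DictBatchManager(double)
batch = {"x": [1, 2], "y": [2, 4]}
print(mgr.train_model_forward(batch))  # [2, 4]
print(mgr.eval_model_forward(batch))   # [2, 4]
```

Overriding the single shared method keeps training and evaluation consistent; override train_model_forward() or eval_model_forward() only when the two modes genuinely need different calls.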
- loss_forward(batch, prediction) → Tensor
Note
This method likely needs to be overloaded to extract the correct loss inputs from the batch and the model's prediction.
Note
Unless a derived class provides an explicit implementation, this method takes the place of both train_loss_forward() and eval_loss_forward().
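A sketch of a typical loss_forward() overload, in plain Python so it runs without torch. `LossForwardSketch` and `mse` are hypothetical stand-ins, and the `(prediction, target)` argument order for the loss callable is an assumption borrowed from the usual PyTorch loss convention.

```python
from typing import Any, Callable, Sequence


def mse(prediction: Sequence[float], target: Sequence[float]) -> float:
    """Plain-Python mean squared error (stand-in for a torch loss)."""
    return sum((p - t) ** 2 for p, t in zip(prediction, target)) / len(target)


class LossForwardSketch:
    """Hypothetical stand-in showing loss_forward() and its documented fallbacks."""

    def __init__(self, loss: Callable[[Any, Any], float]):
        self.loss = loss

    def loss_forward(self, batch, prediction):
        # Assumed batch layout {'x': inputs, 'y': targets}: pass only the targets.
        return self.loss(prediction, batch["y"])

    def train_loss_forward(self, batch, prediction):
        return self.loss_forward(batch, prediction)

    def eval_loss_forward(self, batch, prediction):
        return self.loss_forward(batch, prediction)


mgr = LossForwardSketch(mse)
print(mgr.train_loss_forward({"x": [1, 2], "y": [2, 5]}, [2, 4]))  # 0.5
```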
- eval(eval_data: Dict[str, Iterator[Any]] | None = None)
Runs evaluation over all evaluation datasets.
- train()
Trains the model for the configured number of epochs, evaluating on all evaluation datasets before the first epoch and after every epoch.
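The evaluation schedule of train() can be sketched as a plain loop that only records the order of calls. `run_eval` and `train_schedule` are hypothetical helpers; they omit extensions, checkpointing, the optimizer step, and the run stage, and they assume evaluation visits the named datasets in dict order.

```python
from typing import Dict, Iterable, List


def run_eval(eval_data: Dict[str, Iterable], log: List[str]) -> None:
    # One pass over every named evaluation dataset.
    for name, data in eval_data.items():
        for _batch in data:
            pass  # model and loss forward on the batch would go here
        log.append(f"eval:{name}")


def train_schedule(epochs: int, train_data: Iterable, eval_data: Dict[str, Iterable]) -> List[str]:
    log: List[str] = []
    run_eval(eval_data, log)  # evaluate once before the first epoch
    for epoch in range(epochs):
        for _batch in train_data:
            pass  # forward, loss, backward and optimizer step would go here
        log.append(f"train:{epoch}")
        run_eval(eval_data, log)  # evaluate again after every epoch
    return log


print(train_schedule(2, [1, 2, 3], {"val": [1, 2]}))
# ['eval:val', 'train:0', 'eval:val', 'train:1', 'eval:val']
```

With N epochs, each evaluation dataset is therefore visited N + 1 times. Because the same data is iterated once per epoch, re-iterable containers such as lists or data loaders fit this loop better than one-shot iterators.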