extensible.extensions.utils
Classes
Avg – Computes the average of a scalar fixture across an entire evaluation epoch.
class extensible.extensions.utils.Avg(target, accum=None, visualize: bool = True, fig: str = 'avgs/{datasource_name}/{target}', counter='num_batches')
Bases: Extension
Computes the average of a scalar fixture across an entire evaluation epoch.
Optionally, it also plots a visualization of the average at the end of the evaluation epoch.
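To make the averaging concrete, here is a minimal, library-independent sketch of the idea: keep a running sum of the target fixture over the evaluation batches, then divide it by the counter fixture (by default `num_batches`) at the end of the epoch. The class `RunningAverageSketch` and its method names are hypothetical illustrations, not the actual `Avg` implementation or the `Extension` hook signatures.

```python
from typing import Dict, Optional


class RunningAverageSketch:
    """Hypothetical stand-in for Avg's averaging logic (not the library class)."""

    def __init__(self, target: str, accum: Optional[str] = None, counter: str = "num_batches"):
        self.target = target
        # Mirrors the documented default accumulator name: 'accum_' + target.
        self.accum = accum if accum is not None else "accum_" + target
        self.counter = counter
        self._running_sum = 0.0

    def pre_eval(self) -> None:
        # Reset the running sum at the start of each evaluation epoch.
        self._running_sum = 0.0

    def post_eval_step_batch(self, fixtures: Dict[str, float]) -> None:
        # Add this batch's scalar value of the target fixture to the running sum.
        self._running_sum += float(fixtures[self.target])

    def post_eval(self, fixtures: Dict[str, float]) -> float:
        # Divide by the counter fixture and expose the average under self.accum.
        avg = self._running_sum / fixtures[self.counter]
        fixtures[self.accum] = avg
        return avg


# Usage: three evaluation batches with per-batch losses 0.5, 1.0 and 1.5.
avg = RunningAverageSketch("loss")
avg.pre_eval()
fixtures = {"num_batches": 0}
for batch_loss in [0.5, 1.0, 1.5]:
    fixtures["loss"] = batch_loss
    fixtures["num_batches"] += 1
    avg.post_eval_step_batch(fixtures)
print(avg.post_eval(fixtures))  # 1.0
```

Dividing by a separate counter fixture, rather than counting calls internally, lets another extension decide what "one step" means (for example, counting samples instead of batches).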
To use this extension, add an instance of it to a train manager after the extension that generates the fixture to average. Alternatively, derive from it, e.g.:
    class ThresholdedLoss(Avg):
        def __init__(self, *args, **kwargs):
            # Average the "thresh_loss" fixture computed in the hook below.
            super().__init__("thresh_loss", *args, **kwargs)

        def post_eval_step_batch(self, train_manager, fixtures, batch, prediction):
            # Add the fixture to average.
            _, targets = batch  # assumes batch is an (inputs, targets) pair
            hard_prediction = prediction > 0.5
            fixtures["thresh_loss"] = train_manager.loss(targets, hard_prediction)
            # Call the super's hook to carry out the averaging.
            fixtures(super().post_eval_step_batch)
Exposes the following fixtures:
self.accum, which by default is 'accum_' + self.target
Depends on
EvalState
Parameters:
target – The name of the fixture to accumulate
accum – The name of the fixture to create as an accumulator
visualize – Whether to create a scalar graph
fig – The name of the figure in which the visualization will be placed; can include formatting placeholders for any fixture, as well as target and accum.
counter – The name of the counter fixture used to divide the running sum
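The fig parameter accepts formatting placeholders, as in the default 'avgs/{datasource_name}/{target}'. A plausible way to resolve such a name is Python's str.format with the fixtures plus target and accum as keyword values. The helper `format_fig_name` below is hypothetical, and the precedence of target/accum over same-named fixtures is an assumption, not documented behavior.

```python
from typing import Any, Mapping


def format_fig_name(fig: str, fixtures: Mapping[str, Any], target: str, accum: str) -> str:
    # Expose every fixture as a placeholder, plus target and accum.
    # Assumption: target/accum override any fixtures with the same names.
    values = {**fixtures, "target": target, "accum": accum}
    # Unused keys are ignored by str.format; a missing placeholder raises KeyError.
    return fig.format(**values)


name = format_fig_name(
    "avgs/{datasource_name}/{target}",
    {"datasource_name": "validation"},  # hypothetical fixture value
    target="loss",
    accum="accum_loss",
)
print(name)  # avgs/validation/loss
```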