Getting Started

Below we present a simplified training routine where we focus on the lines that create stages and inject fixtures:

 1class TrainManager(Extensible):
 2
 3  def train(...):
 4
 5    with self.staged(
 6        "train", {"epoch_num": 0}
 7    ):
 8      # Train
 9      for _ in range(self.fixtures["epoch_num"], self.epochs):
10
11        with self.staged("train_step_epoch"):
12
13          for batch in ... :
14            with self.staged(
15                "train_step_batch",
16                {
17                    "batch": batch,
18                    "true_batch_size": self.get_true_batch_size(batch),
19                },
20            ):
21              ...
22
23              prediction = self.model.forward(batch)
24              self.fixtures["prediction"] = prediction
25
26              loss = self.loss(batch, prediction)
27              self.fixtures["loss"] = loss
28
29              ...
30
31          self.fixtures.modify("epoch_num", self.fixtures["epoch_num"] + 1)
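The routine above creates the following nested stages; the fixtures injected in each stage are listed after its name:

  • train: epoch_num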

    • train_step_epoch

      • train_step_batch: batch, true_batch_size, prediction, loss