Getting Started
Below we present a simplified training routine where we focus on the lines that create stages and inject fixtures:
1class TrainManager(Extensible):
2
3 def train(...):
4
5 with self.staged(
6 "train", {"epoch_num": 0}
7 ):
8 # Train
9 for _ in range(self.fixtures["epoch_num"], self.epochs):
10
11 with self.staged("train_step_epoch"):
12
13 for batch in ... :
14 with self.staged(
15 "train_step_batch",
16 {
17 "batch": batch,
18 "true_batch_size": self.get_true_batch_size(batch),
19 },
20 ):
21 ...
22
23 prediction = self.model.forward(batch)
24 self.fixtures["prediction"] = prediction
25
26 loss = self.loss(batch, prediction)
27 self.fixtures["loss"] = loss
28
29 ...
30
31 self.fixtures.modify("epoch_num", self.fixtures["epoch_num"] + 1)
The routine above creates the following stages, each listed with the fixtures it injects:
train: epoch_num
train_step_epoch
train_step_batch: batch, true_batch_size, prediction, loss