Training

from yann.train import Trainer

train = Trainer(
  model='resnet18',
  dataset='MNIST',
  optimizer='AdamW',
  loss='cross_entropy'
)
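
The Trainer above is configured with plain strings ('resnet18', 'AdamW', 'cross_entropy') rather than objects. One common way a library can support this is a name registry. The sketch below is a hypothetical, dependency-free illustration of that idea; the `Registry` class and the toy `cross_entropy` are assumptions for illustration, not yann's implementation.

```python
import math
from typing import Any, Callable, Dict


class Registry:
    """Hypothetical name registry: maps short strings to factory functions."""

    def __init__(self) -> None:
        self._factories: Dict[str, Callable[..., Any]] = {}

    def register(self, name: str) -> Callable:
        # Decorator form: @registry.register('name') stores the factory under that name.
        def decorator(factory: Callable[..., Any]) -> Callable[..., Any]:
            self._factories[name] = factory
            return factory
        return decorator

    def resolve(self, spec: Any, **kwargs: Any) -> Any:
        # Strings are looked up and built; anything else is returned unchanged,
        # so callers can pass either a name or an already-constructed object.
        if isinstance(spec, str):
            if spec not in self._factories:
                raise KeyError(f"unknown name: {spec!r}")
            return self._factories[spec](**kwargs)
        return spec


losses = Registry()


@losses.register('cross_entropy')
def make_cross_entropy() -> Callable[[list, int], float]:
    def cross_entropy(probs: list, target: int) -> float:
        # Negative log-probability assigned to the target class.
        return -math.log(probs[target])
    return cross_entropy


loss_fn = losses.resolve('cross_entropy')
```

Passing objects straight through `resolve` keeps both styles available: a short name for common choices, or a fully configured object when the defaults are not enough.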

Register callbacks

@train.on('epoch_end')
def sync_data():
  pass
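
Callbacks like this are typically stored per event name and invoked when the trainer reaches that point in the loop. A minimal sketch of the pattern, assuming a hypothetical `EventEmitter` class (not yann's own callback machinery) and handlers that take no arguments:

```python
from collections import defaultdict
from typing import Callable, DefaultDict, List


class EventEmitter:
    """Hypothetical stand-in for a trainer's callback registry."""

    def __init__(self) -> None:
        self._handlers: DefaultDict[str, List[Callable]] = defaultdict(list)

    def on(self, event: str) -> Callable:
        # Used as @emitter.on('epoch_end'); stores the function and returns it unchanged.
        def decorator(fn: Callable) -> Callable:
            self._handlers[event].append(fn)
            return fn
        return decorator

    def emit(self, event: str, *args, **kwargs) -> None:
        # Call every handler registered for this event, in registration order.
        for fn in self._handlers[event]:
            fn(*args, **kwargs)


emitter = EventEmitter()
calls = []


@emitter.on('epoch_end')
def sync_data():
    calls.append('synced')


for epoch in range(3):
    # ... the training steps for this epoch would run here ...
    emitter.emit('epoch_end')
```

Because the decorator returns the function unchanged, `sync_data` can still be called directly outside the training loop.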

Implement custom step logic

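for e in train.epochs(4):
  for inputs, targets in train.batches():
    train.optimizer.zero_grad()          # clear gradients from the previous step
    outputs = train.model(inputs)
    loss = train.loss(outputs, targets)
    loss.backward()                      # backpropagate
    train.optimizer.step()               # apply the parameter update

  train.checkpoint()                     # save a checkpoint after each epoch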
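
Each batch in a custom loop follows the same cycle: clear old gradients, run the forward pass, compute the loss, backpropagate, then let the optimizer apply the update. To make each stage explicit without PyTorch, the hypothetical sketch below writes that cycle out by hand for a one-parameter linear model y = w * x with squared-error loss; `Param` and `train_linear` are illustrative names only.

```python
from typing import List, Tuple


class Param:
    """A scalar parameter with a gradient slot, standing in for a tensor."""

    def __init__(self, value: float) -> None:
        self.value = value
        self.grad = 0.0


def train_linear(data: List[Tuple[float, float]], epochs: int = 50,
                 lr: float = 0.05) -> Tuple[Param, float]:
    w = Param(0.0)
    loss = float("inf")
    for _ in range(epochs):
        for x, y in data:
            w.grad = 0.0                   # zero_grad(): discard the previous gradient
            pred = w.value * x             # forward pass
            loss = (pred - y) ** 2         # squared-error loss for this example
            w.grad += 2 * (pred - y) * x   # backward(): d(loss)/dw, accumulated into .grad
            w.value -= lr * w.grad         # step(): move w against the gradient
    return w, loss


w, final_loss = train_linear([(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)])
```

Dropping the update line leaves `w` at its initial value no matter how many epochs run, which is why an optimizer step has to follow the backward pass.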

Checkpointing

train.checkpoint()
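
A checkpoint needs enough state to resume training exactly: at minimum the model weights, the optimizer state (for AdamW, its running moment estimates), and the epoch counter. In PyTorch these usually come from `model.state_dict()` and `optimizer.state_dict()` and are written with `torch.save`; the sketch below uses plain dicts and `pickle` so it runs without dependencies. The file naming (`epoch-0003.pkl`) and the `save_checkpoint` helper are assumptions for illustration, not yann's on-disk format.

```python
import os
import pickle
import tempfile


def save_checkpoint(directory: str, epoch: int, model_state: dict, optimizer_state: dict) -> str:
    """Hypothetical helper: write one checkpoint file per epoch."""
    os.makedirs(directory, exist_ok=True)
    path = os.path.join(directory, f"epoch-{epoch:04d}.pkl")
    state = {"epoch": epoch, "model": model_state, "optimizer": optimizer_state}
    # Write to a temporary file first, then rename it into place, so a crash
    # mid-write never leaves a truncated file under the final name.
    fd, tmp_path = tempfile.mkstemp(dir=directory)
    with os.fdopen(fd, "wb") as f:
        pickle.dump(state, f)
    os.replace(tmp_path, path)
    return path


ckpt_dir = tempfile.mkdtemp()
path = save_checkpoint(ckpt_dir, 3, {"w": 2.0}, {"lr": 0.05})
with open(path, "rb") as f:
    restored = pickle.load(f)
```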

Continue where you left off

train.load_checkpoint('latest')
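
Resuming from 'latest' requires picking the newest checkpoint on disk. A hypothetical sketch of that lookup, assuming checkpoints are named `epoch-<N>.pkl` (an illustrative convention, not necessarily yann's): it parses the epoch number out of each file name and takes the maximum, because comparing the names as strings would rank `epoch-2.pkl` above `epoch-10.pkl`.

```python
import os
import re
import tempfile

_CKPT = re.compile(r"epoch-(\d+)\.pkl")


def resolve_checkpoint(directory: str, name: str) -> str:
    """Map 'latest' to the highest-numbered checkpoint; other names are file names."""
    if name != "latest":
        return os.path.join(directory, name)
    found = []
    for fname in os.listdir(directory):
        m = _CKPT.fullmatch(fname)
        if m:
            found.append((int(m.group(1)), fname))
    if not found:
        raise FileNotFoundError(f"no checkpoints in {directory}")
    # Compare parsed epoch numbers, not raw file names.
    return os.path.join(directory, max(found)[1])


ckpt_dir = tempfile.mkdtemp()
for epoch in (1, 2, 10):
    open(os.path.join(ckpt_dir, f"epoch-{epoch}.pkl"), "wb").close()
latest = resolve_checkpoint(ckpt_dir, "latest")
```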

History

train.history.plot()
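
A training history is essentially a per-metric log of values indexed by step. The hypothetical `History` class below records metrics and computes a trailing moving average, the kind of series a loss plot would draw against the step number; plotting itself is omitted so the sketch has no matplotlib dependency.

```python
from collections import defaultdict
from statistics import mean


class History:
    """Hypothetical metric log: one list of (step, value) pairs per metric name."""

    def __init__(self) -> None:
        self.metrics = defaultdict(list)
        self.step = 0

    def log(self, **values: float) -> None:
        # Record every metric passed in under the current step, then advance.
        for name, value in values.items():
            self.metrics[name].append((self.step, value))
        self.step += 1

    def smoothed(self, name: str, window: int) -> list:
        # Trailing moving average, a common way to make noisy loss curves readable.
        values = [v for _, v in self.metrics[name]]
        return [mean(values[max(0, i - window + 1): i + 1]) for i in range(len(values))]


history = History()
for loss in [4.0, 2.0, 3.0, 1.0]:
    history.log(loss=loss)
```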

Functional Interface

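from yann.train import train

# model, loader and optimizer are assumed to be defined already
# (e.g. a PyTorch model, data loader and optimizer)
for epoch in range(10):
  for _ in train(model, loader, optimizer, loss='cross_entropy', device='cuda'):
    pass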
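
The functional interface is driven by iterating over `train(...)`, which suggests a generator that performs one update per batch and yields control back to the caller between steps. A minimal sketch of that pattern, assuming a hypothetical `train_steps` generator over a one-parameter linear model (not yann's actual `train` function):

```python
from typing import Iterable, Iterator, List, Tuple


def train_steps(params: List[float], loader: Iterable[Tuple[float, float]],
                lr: float = 0.05) -> Iterator[float]:
    """Hypothetical generator: one SGD step per batch, yielding that batch's loss."""
    for x, y in loader:
        pred = params[0] * x
        loss = (pred - y) ** 2
        grad = 2 * (pred - y) * x
        params[0] -= lr * grad   # update in place so the caller sees the new weight
        yield loss               # hand control back to the caller between steps


params = [0.0]
loader = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]
for epoch in range(10):
    for _ in train_steps(params, loader):
        pass
```

Because work only happens as the loop pulls values, the `pass` body is where a caller can log the yielded loss, stop early, or run extra logic between steps.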