Training
```python
from yann.train import Trainer

train = Trainer(
    model='resnet18',
    dataset='MNIST',
    optimizer='AdamW',
    loss='cross_entropy'
)
```
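Passing components as strings (`'resnet18'`, `'AdamW'`, `'cross_entropy'`) presumably means the trainer looks each name up and builds the matching object. The sketch below shows one common way to support that: a plain dictionary registry keyed by component kind and name. `REGISTRY`, `register` and `resolve` are hypothetical names used only for illustration, not yann's API, and the factories return placeholder strings instead of real models or optimizers.

```python
from typing import Any, Callable, Dict

# Hypothetical registry: component kind -> name -> factory function.
REGISTRY: Dict[str, Dict[str, Callable[..., Any]]] = {}


def register(kind: str, name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Decorator that records a factory under REGISTRY[kind][name]."""
    def decorator(factory: Callable[..., Any]) -> Callable[..., Any]:
        REGISTRY.setdefault(kind, {})[name] = factory
        return factory
    return decorator


def resolve(kind: str, value: Any, **kwargs: Any) -> Any:
    """Build the object for a string name; pass already-built objects through unchanged."""
    if not isinstance(value, str):
        return value
    try:
        factory = REGISTRY[kind][value]
    except KeyError:
        raise KeyError(f"unknown {kind}: {value!r}") from None
    return factory(**kwargs)


# Placeholder factories; a real trainer would construct torch modules and optimizers here.
@register("loss", "cross_entropy")
def make_cross_entropy() -> str:
    return "CrossEntropyLoss()"


@register("optimizer", "AdamW")
def make_adamw(lr: float = 1e-3) -> str:
    return f"AdamW(lr={lr})"


print(resolve("optimizer", "AdamW", lr=3e-4))  # AdamW(lr=0.0003)
```

Passing non-string values straight through lets the same argument accept either a name or an object you constructed yourself.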
Register callbacks
```python
@train.on('epoch_end')
def sync_data():
    pass
```
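A decorator such as `train.on('epoch_end')` typically appends the function to a per-event list that the training loop calls at the matching point. The plain-Python sketch below illustrates that pattern under the assumption that callbacks receive keyword arguments; `Events`, `on` and `emit` are hypothetical names, not yann's internals.

```python
from collections import defaultdict
from typing import Any, Callable, DefaultDict, List


class Events:
    """Minimal event registry: on() registers handlers, emit() calls them in registration order."""

    def __init__(self) -> None:
        self._handlers: DefaultDict[str, List[Callable[..., Any]]] = defaultdict(list)

    def on(self, event: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
        def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
            self._handlers[event].append(fn)
            return fn  # return the function unchanged so it can still be called directly
        return decorator

    def emit(self, event: str, **kwargs: Any) -> None:
        # Events with no registered handlers are silently ignored.
        for fn in self._handlers[event]:
            fn(**kwargs)


events = Events()
synced: List[int] = []


@events.on("epoch_end")
def sync_data(epoch: int) -> None:
    synced.append(epoch)


# A training loop would emit the event once at the end of each epoch.
for epoch in range(3):
    events.emit("epoch_end", epoch=epoch)

print(synced)  # [0, 1, 2]
```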
Implement custom step logic
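```python
for e in train.epochs(4):
    for inputs, targets in train.batches():
        train.optimizer.zero_grad()
        outputs = train.model(inputs)
        loss = train.loss(outputs, targets)
        loss.backward()         # PyTorch tensors use backward(), not backwards()
        train.optimizer.step()  # apply the computed gradients to the parameters

    train.checkpoint()          # save once per epoch
```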
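The order inside the batch loop matters: clear old gradients, run the forward pass, compute the loss, backpropagate, then let the optimizer update the parameters. Without the final `step()` the model never changes. The sketch below reproduces that order for a one-parameter linear model with hand-written gradients, so it runs without PyTorch; `Param` and `SGD` are illustrative stand-ins, not PyTorch or yann classes.

```python
from typing import List, Tuple


class Param:
    """A scalar parameter with an accumulated gradient, mimicking a tensor's .grad."""

    def __init__(self, value: float) -> None:
        self.value = value
        self.grad = 0.0


class SGD:
    """Plain stochastic gradient descent over a list of Params."""

    def __init__(self, params: List[Param], lr: float) -> None:
        self.params = params
        self.lr = lr

    def zero_grad(self) -> None:
        for p in self.params:
            p.grad = 0.0

    def step(self) -> None:
        for p in self.params:
            p.value -= self.lr * p.grad


w = Param(0.0)
optimizer = SGD([w], lr=0.05)
# Toy data sampled from y = 2x, split into two batches of two examples.
batches: List[List[Tuple[float, float]]] = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0), (-1.0, -2.0)]]

for epoch in range(50):
    for batch in batches:
        optimizer.zero_grad()                           # 1. clear gradients from the previous batch
        outputs = [w.value * x for x, _ in batch]       # 2. forward pass
        errors = [o - y for o, (_, y) in zip(outputs, batch)]
        loss = sum(e * e for e in errors) / len(batch)  # 3. mean squared error
        # 4. backward: d(loss)/dw = mean of 2 * error * x, accumulated into .grad
        w.grad += sum(2 * e * x for e, (x, _) in zip(errors, batch)) / len(batch)
        optimizer.step()                                # 5. update the parameter

print(round(w.value, 4))  # 2.0
```

Because gradients are accumulated with `+=`, as PyTorch does, skipping `zero_grad()` would mix the previous batch's gradient into the next update.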
Checkpointing
```python
train.checkpoint()
```
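yann's checkpoint format and file layout are not described here. As a rough illustration of what saving a checkpoint involves, the sketch below writes training state to an epoch-numbered file, writing to a temporary file first and then moving it into place with `os.replace` so an interrupted save never leaves a half-written checkpoint. It uses JSON only to stay dependency-free (a PyTorch trainer would more likely pass model and optimizer `state_dict()`s to `torch.save`), and `save_checkpoint` and the `checkpoint-<epoch>.json` naming are hypothetical.

```python
import json
import os
import tempfile
from typing import Any, Dict


def save_checkpoint(directory: str, epoch: int, state: Dict[str, Any]) -> str:
    """Atomically write state to <directory>/checkpoint-<epoch>.json and return the path."""
    os.makedirs(directory, exist_ok=True)
    path = os.path.join(directory, f"checkpoint-{epoch:04d}.json")
    # Write to a temp file in the same directory so os.replace is a same-filesystem rename.
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    with os.fdopen(fd, "w") as f:
        json.dump({"epoch": epoch, **state}, f)
    os.replace(tmp_path, path)
    return path


with tempfile.TemporaryDirectory() as run_dir:
    saved = save_checkpoint(run_dir, 3, {"weights": [0.5, -1.25], "lr": 0.001})
    with open(saved) as f:
        print(json.load(f))  # {'epoch': 3, 'weights': [0.5, -1.25], 'lr': 0.001}
```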
Continue where you left off
```python
train.load_checkpoint('latest')
```
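Resolving `'latest'` requires deciding which saved checkpoint is newest. One simple approach, sketched below under the assumption of a hypothetical `checkpoint-<epoch>.json` naming scheme, is to parse the epoch number out of each file name and pick the largest. Comparing the numbers as integers keeps `checkpoint-10` after `checkpoint-2`, which plain string sorting would get wrong. `latest_checkpoint` and `load_saved_state` are illustrative names, not yann's implementation.

```python
import json
import os
import re
import tempfile
from typing import Any, Dict, Optional

# Hypothetical naming scheme: checkpoint-<epoch>.json, with or without zero padding.
CHECKPOINT_NAME = re.compile(r"checkpoint-(\d+)\.json")


def latest_checkpoint(directory: str) -> Optional[str]:
    """Return the path of the checkpoint with the highest epoch number, or None if there is none."""
    best_epoch = -1
    best_path: Optional[str] = None
    for name in os.listdir(directory):
        match = CHECKPOINT_NAME.fullmatch(name)
        # Compare epochs as integers so checkpoint-10 ranks above checkpoint-2.
        if match and int(match.group(1)) > best_epoch:
            best_epoch = int(match.group(1))
            best_path = os.path.join(directory, name)
    return best_path


def load_saved_state(directory: str, which: str = "latest") -> Dict[str, Any]:
    """Load a checkpoint by file name; 'latest' picks the highest epoch on disk."""
    path = latest_checkpoint(directory) if which == "latest" else os.path.join(directory, which)
    if path is None:
        raise FileNotFoundError(f"no checkpoints found in {directory}")
    with open(path) as f:
        return json.load(f)


with tempfile.TemporaryDirectory() as run_dir:
    for epoch in (1, 2, 10):
        with open(os.path.join(run_dir, f"checkpoint-{epoch}.json"), "w") as f:
            json.dump({"epoch": epoch}, f)
    print(load_saved_state(run_dir)["epoch"])  # 10
```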
History
```python
train.history.plot()
```
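What `train.history` stores is not spelled out here. The sketch below shows a minimal history object under the assumption that metrics are logged per step and summarized per epoch. `History`, `record` and `epoch_means` are hypothetical names. Plotting is left out to keep the example dependency-free; the per-epoch list could be passed to a plotting function such as matplotlib's `plt.plot`.

```python
from collections import defaultdict
from statistics import mean
from typing import DefaultDict, List


class History:
    """Stores scalar metrics logged at each step, grouped by epoch."""

    def __init__(self) -> None:
        # metric name -> epoch -> values logged during that epoch
        self._values: DefaultDict[str, DefaultDict[int, List[float]]] = defaultdict(
            lambda: defaultdict(list)
        )

    def record(self, epoch: int, **metrics: float) -> None:
        for name, value in metrics.items():
            self._values[name][epoch].append(float(value))

    def epoch_means(self, name: str) -> List[float]:
        """Mean of a metric for each epoch, in epoch order (empty if never recorded)."""
        per_epoch = self._values[name]
        return [mean(per_epoch[e]) for e in sorted(per_epoch)]


history = History()
for epoch, losses in enumerate([[1.0, 0.5], [0.5, 0.25], [0.25, 0.125]]):
    for loss in losses:
        history.record(epoch, loss=loss)

print(history.epoch_means("loss"))  # [0.75, 0.375, 0.1875]
```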
Functional Interface
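```python
from yann.train import train

# model, loader and optimizer are assumed to be an already constructed model, data loader and optimizer
for epoch in range(10):
    for _ in train(model, loader, optimizer, loss='cross_entropy', device='cuda'):
        pass
```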
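In the functional form, `train(...)` is iterated like a generator. Each iteration presumably runs one batch, so the caller drives the loop and can log, evaluate or stop between steps; the `pass` body simply runs it to completion. The plain-Python sketch below shows that generator pattern with a hypothetical `train_steps` function that fits a single weight and yields each batch loss. It is not yann's implementation, and its signature is invented for illustration.

```python
from typing import Iterable, Iterator, List, Tuple

Batch = List[Tuple[float, float]]


def train_steps(weight: List[float], loader: Iterable[Batch], lr: float = 0.05) -> Iterator[float]:
    """Make one pass over loader, updating weight[0] in place and yielding each batch's loss.

    weight is a one-element list so the caller sees the updates (a stand-in for model parameters).
    """
    for batch in loader:
        errors = [(weight[0] * x - y, x) for x, y in batch]
        loss = sum(e * e for e, _ in errors) / len(batch)   # mean squared error
        grad = sum(2 * e * x for e, x in errors) / len(batch)
        weight[0] -= lr * grad                              # SGD update
        yield loss                                          # hand control back to the caller


weight = [0.0]
loader = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0), (-1.0, -2.0)]]  # samples of y = 2x

for epoch in range(10):
    for _ in train_steps(weight, loader):
        pass  # the loop body could log the loss or break early

print(round(weight[0], 3))
```

Because generators are lazy, nothing trains until the caller iterates, which is why the original example needs the inner `for` loop even though its body is just `pass`.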