matchzoo.trainers
Submodules
Package Contents
Classes
MatchZoo trainer.
-
class
matchzoo.trainers.
Trainer
(model: BaseModel, optimizer: optim.Optimizer, trainloader: DataLoader, validloader: DataLoader, device: typing.Union[torch.device, int, list, None] = None, start_epoch: int = 1, epochs: int = 10, validate_interval: typing.Optional[int] = None, scheduler: typing.Any = None, clip_norm: typing.Union[float, int] = None, patience: typing.Optional[int] = None, key: typing.Any = None, checkpoint: typing.Union[str, Path] = None, save_dir: typing.Union[str, Path] = None, save_all: bool = False, verbose: int = 1, **kwargs)¶ MatchZoo trainer.
- Parameters
model – A BaseModel instance.
optimizer – An optim.Optimizer instance.
trainloader – A DataLoader instance. The dataloader is used for training the model.
validloader – A DataLoader instance. The dataloader is used for validating the model.
device – The device on which to run the model. If None, use the current device. If a torch.device or int, use the device specified by the user. If a list, use data parallel.
start_epoch – Int. The epoch number to start from. Defaults to 1.
epochs – The maximum number of epochs for training. Defaults to 10.
validate_interval – Int. Interval of validation.
scheduler – LR scheduler used to adjust the learning rate based on the number of epochs.
clip_norm – Max norm of the gradients to be clipped.
patience – Number of events to wait without improvement before stopping the training.
key – Key of metric to be compared.
checkpoint – A checkpoint from which to continue training. If None, training starts from scratch. Defaults to None. Should be a file-like object (has to implement read, readline, tell, and seek), or a string containing a file name.
save_dir – Directory to save trainer.
save_all – Bool. If True, save the Trainer instance; if False, save only the model. Defaults to False.
verbose – 0, 1, or 2. Verbosity mode. 0 = silent, 1 = verbose, 2 = one log line per epoch.
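The `patience` and `key` parameters together describe early stopping: training halts once the tracked metric has gone a given number of validation events without improving. The following is a minimal sketch of that idea in plain Python, not MatchZoo's implementation. The class `EarlyStoppingSketch`, its attributes, and the assumption that a higher metric value is better are all hypothetical.

```python
from typing import Dict, Optional


class EarlyStoppingSketch:
    """Hypothetical helper: stop after `patience` validations without improvement."""

    def __init__(self, patience: Optional[int], key: str):
        self.patience = patience          # None disables early stopping
        self.key = key                    # name of the metric to compare
        self.best_so_far = float("-inf")
        self.events_since_best = 0

    def update(self, result: Dict[str, float]) -> None:
        """Record one validation result (assumes higher is better)."""
        score = result[self.key]
        if score > self.best_so_far:
            self.best_so_far = score
            self.events_since_best = 0
        else:
            self.events_since_best += 1

    @property
    def should_stop(self) -> bool:
        return self.patience is not None and self.events_since_best >= self.patience


stopper = EarlyStoppingSketch(patience=2, key="map")
history = []
for value in [0.40, 0.45, 0.44, 0.43, 0.50]:
    stopper.update({"map": value})
    history.append(stopper.should_stop)
    if stopper.should_stop:
        break  # the final 0.50 is never seen because training stopped first
```

With `patience=2`, the two consecutive non-improving results (0.44 and 0.43) trigger the stop, so a later improvement is never reached. This is the usual trade-off when choosing a small patience.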
-
_load_dataloader
(self, trainloader: DataLoader, validloader: DataLoader, validate_interval: typing.Optional[int] = None)¶ Load trainloader and determine validate interval.
- Parameters
trainloader – A DataLoader instance. The dataloader is used to train the model.
validloader – A DataLoader instance. The dataloader is used to validate the model.
validate_interval – int. Interval of validation.
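How the validation interval might be determined can be sketched as follows. This assumes the interval is counted in training batches and that a missing value falls back to one validation per pass over the training data. Both points are assumptions, and `resolve_validate_interval` is a hypothetical name rather than part of the documented API.

```python
from typing import Optional, Sized


def resolve_validate_interval(trainloader: Sized,
                              validate_interval: Optional[int] = None) -> int:
    """Return the number of training batches between validations.

    Assumption: a missing interval means "validate once per epoch".
    """
    if not validate_interval:
        return len(trainloader)
    return validate_interval


batches = list(range(8))  # stand-in for a DataLoader that yields 8 batches
default_interval = resolve_validate_interval(batches)
custom_interval = resolve_validate_interval(batches, validate_interval=3)

# 0-based batch indices after which validation would run within one epoch.
validate_after = [i for i in range(len(batches)) if (i + 1) % custom_interval == 0]
```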
-
_load_model
(self, model: BaseModel, device: typing.Union[torch.device, int, list, None] = None)¶ Load model.
- Parameters
model – A BaseModel instance.
device – The device on which to run the model. If None, use the current device. If a torch.device or int, use the device specified by the user. If a list, use data parallel.
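The three cases for `device` can be illustrated as a small dispatch. This sketch only classifies the argument and does not touch PyTorch. The function name `describe_device_choice` and the returned labels are hypothetical.

```python
from typing import Any, Optional, Tuple


def describe_device_choice(device: Any) -> Tuple[str, Optional[Any]]:
    """Classify the `device` argument following the documented rule."""
    if device is None:
        # Nothing specified: fall back to whatever device is current.
        return ("current_device", None)
    if isinstance(device, list):
        # Several device ids: the model would be wrapped for data parallelism.
        return ("data_parallel", list(device))
    # A torch.device or an int selects one specific device.
    return ("single_device", device)


choices = [describe_device_choice(d) for d in (None, 0, [0, 1])]
```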
-
_load_path
(self, checkpoint: typing.Union[str, Path], save_dir: typing.Union[str, Path])¶ Load save_dir and restore from checkpoint.
- Parameters
checkpoint – A checkpoint from which to continue training. If None, training starts from scratch. Defaults to None. Should be a file-like object (has to implement read, readline, tell, and seek), or a string containing a file name.
save_dir – Directory to save trainer.
-
_backward
(self, loss)¶ Compute the gradient of the current loss with respect to graph leaves.
- Parameters
loss – Tensor. Loss of model.
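When `clip_norm` is set, gradient clipping is typically applied after the backward pass and before the optimizer step. The sketch below shows global-norm clipping in plain Python, the same idea as PyTorch's `torch.nn.utils.clip_grad_norm_`. Whether `_backward` calls that exact utility is an assumption, and `clip_by_global_norm` is a hypothetical name.

```python
import math
from typing import List


def clip_by_global_norm(grads: List[List[float]], max_norm: float,
                        eps: float = 1e-6) -> float:
    """Scale gradients in place so their global L2 norm is at most `max_norm`.

    Returns the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    scale = max_norm / (total_norm + eps)
    if scale < 1.0:  # only shrink gradients, never enlarge them
        for grad in grads:
            for i in range(len(grad)):
                grad[i] *= scale
    return total_norm


grads = [[3.0, 0.0], [0.0, 4.0]]  # two "parameters"; global norm is 5.0
norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for grad in grads for g in grad))
```

All gradients are scaled by the same factor, so the update direction is preserved and only its magnitude is bounded.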
-
_run_scheduler
(self)¶ Run scheduler.
-
run
(self)¶ Train model.
- The processes:
Run each epoch -> Run scheduler -> Should stop early?
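The flow above can be sketched as a plain loop. The callables passed in stand for the trainer's internal steps. The function `run_sketch` and the inclusive epoch range from `start_epoch` to `epochs` are assumptions made for illustration.

```python
from typing import Callable, List


def run_sketch(start_epoch: int,
               epochs: int,
               run_epoch: Callable[[int], None],
               run_scheduler: Callable[[], None],
               should_stop: Callable[[], bool]) -> List[int]:
    """Run epochs in order: epoch -> scheduler -> early-stop check."""
    completed = []
    for epoch in range(start_epoch, epochs + 1):
        run_epoch(epoch)
        run_scheduler()
        completed.append(epoch)
        if should_stop():
            break
    return completed


events = []
completed = run_sketch(
    start_epoch=1,
    epochs=10,
    run_epoch=lambda epoch: events.append(f"epoch {epoch}"),
    run_scheduler=lambda: events.append("scheduler"),
    should_stop=lambda: len(events) >= 6,  # pretend early stopping fires after 3 epochs
)
```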
-
_run_epoch
(self)¶ Run each epoch.
- The training steps:
Get batch and feed them into model
Get outputs. Calculate all losses and sum them up
Loss backwards and optimizer steps
Evaluation
Update and output result
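To make the training steps concrete, here is a toy version in plain Python for a one-parameter model `y = w * x`, with gradients worked out by hand instead of autograd. It is an illustration under stated assumptions, not the trainer's code. The names `run_epoch_sketch` and `evaluate_sketch`, and the two losses used (mean squared error plus an optional L2 penalty), are hypothetical.

```python
from typing import List, Tuple

Batch = Tuple[List[float], List[float]]


def run_epoch_sketch(w: float, batches: List[Batch],
                     lr: float = 0.05, l2: float = 0.0) -> Tuple[float, float]:
    """One pass over `batches`; returns the new weight and mean batch loss."""
    total_loss = 0.0
    for xs, ys in batches:                    # get a batch and feed it to the model
        preds = [w * x for x in xs]           # get outputs
        n = len(xs)
        mse = sum((p - y) ** 2 for p, y in zip(preds, ys)) / n
        penalty = l2 * w * w
        loss = mse + penalty                  # calculate all losses and sum them up
        # Hand-derived gradient of `loss` with respect to w (the backward pass).
        grad = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / n + 2 * l2 * w
        w -= lr * grad                        # optimizer step (plain SGD)
        total_loss += loss
    return w, total_loss / len(batches)


def evaluate_sketch(w: float, batches: List[Batch]) -> float:
    """Evaluation: mean squared error over every sample."""
    errors = [(w * x - y) ** 2 for xs, ys in batches for x, y in zip(xs, ys)]
    return sum(errors) / len(errors)


train = [([1.0, 2.0], [2.0, 4.0]), ([3.0, 4.0], [6.0, 8.0])]  # targets follow y = 2x
w = 0.0
before = evaluate_sketch(w, train)
for _ in range(20):
    w, mean_loss = run_epoch_sketch(w, train)
after = evaluate_sketch(w, train)             # update and output result
```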
-
evaluate
(self, dataloader: DataLoader)¶ Evaluate the model.
- Parameters
dataloader – A DataLoader object to iterate over the data.
-
classmethod
_eval_metric_on_data_frame
(cls, metric: BaseMetric, id_left: typing.Any, y_true: typing.Union[list, np.array], y_pred: typing.Union[list, np.array])¶ Eval metric on data frame.
This function is used to eval metrics for Ranking task.
- Parameters
metric – Metric for Ranking task.
id_left – ID of the left input. Samples with the same id_left are grouped for evaluation.
y_true – Labels of dataset.
y_pred – Outputs of model.
- Returns
Evaluation result.
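The grouping described for `id_left` can be sketched without a data frame: collect labels and predictions per `id_left`, apply the metric to each group, and combine the per-group scores. Averaging the scores is an assumption here. `eval_metric_by_group` and the toy metric `precision_at_1` are hypothetical stand-ins for the method and for a `BaseMetric`.

```python
from collections import defaultdict
from typing import Callable, Hashable, List, Sequence


def eval_metric_by_group(metric: Callable[[List[float], List[float]], float],
                         id_left: Sequence[Hashable],
                         y_true: Sequence[float],
                         y_pred: Sequence[float]) -> float:
    """Apply `metric` to each `id_left` group and average the results."""
    groups = defaultdict(lambda: ([], []))
    for qid, true, pred in zip(id_left, y_true, y_pred):
        groups[qid][0].append(true)
        groups[qid][1].append(pred)
    scores = [metric(trues, preds) for trues, preds in groups.values()]
    return sum(scores) / len(scores)


def precision_at_1(y_true: List[float], y_pred: List[float]) -> float:
    """Toy ranking metric: 1.0 if the top-scored item is relevant, else 0.0."""
    top = max(range(len(y_pred)), key=lambda i: y_pred[i])
    return 1.0 if y_true[top] > 0 else 0.0


id_left = ["q1", "q1", "q2", "q2", "q2"]
y_true = [1, 0, 0, 1, 0]
y_pred = [0.9, 0.2, 0.8, 0.3, 0.1]
score = eval_metric_by_group(precision_at_1, id_left, y_true, y_pred)
```

Grouping matters because a ranking metric is only meaningful among candidates for the same left input. Computing it over the whole flat list would mix candidates from unrelated queries.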
-
predict
(self, dataloader: DataLoader) → np.array¶ Generate output predictions for the input samples.
- Parameters
dataloader – input DataLoader
- Returns
predictions
-
_save
(self)¶ Save.
-
save_model
(self)¶ Save the model.
-
save
(self)¶ Save the trainer.
Trainer parameters such as epoch, best_so_far, model, optimizer and early_stopping will be saved to a specific file path.
- Parameters
path – Path to save trainer.
-
restore_model
(self, checkpoint: typing.Union[str, Path])¶ Restore model.
- Parameters
checkpoint – A checkpoint from which to continue training.
-
restore
(self, checkpoint: typing.Union[str, Path] = None)¶ Restore trainer.
- Parameters
checkpoint – A checkpoint from which to continue training.