matchzoo.trainers.trainer

Base Trainer.

Module Contents

class matchzoo.trainers.trainer.Trainer(model:BaseModel, optimizer:optim.Optimizer, trainloader:DataLoader, validloader:DataLoader, device:typing.Optional[torch.device]=None, start_epoch:int=1, epochs:int=10, validate_interval:typing.Optional[int]=None, scheduler:typing.Any=None, clip_norm:typing.Union[float, int]=None, patience:typing.Optional[int]=None, key:typing.Any=None, data_parallel:bool=True, checkpoint:typing.Union[str, Path]=None, save_dir:typing.Union[str, Path]=None, save_all:bool=False, verbose:int=1, **kwargs)

MatchZoo trainer.

Parameters:
  • model – A BaseModel instance.
  • optimizer – An optim.Optimizer instance.
  • trainloader – A DataLoader instance. The dataloader is used for training the model.
  • validloader – A DataLoader instance. The dataloader is used for validating the model.
  • device – The desired device of returned tensor. Default: if None, uses the current device for the default tensor type (see torch.set_default_tensor_type()). device will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
  • start_epoch – Int. Number of the starting epoch. Defaults to 1.
  • epochs – The maximum number of epochs for training. Defaults to 10.
  • validate_interval – Int. Interval of validation.
  • scheduler – LR scheduler used to adjust the learning rate based on the number of epochs.
  • clip_norm – Max norm of the gradients to be clipped.
  • patience – Number of events to wait without improvement before stopping the training.
  • key – Key of metric to be compared.
  • data_parallel – Bool. Whether to support data parallel.
  • checkpoint – A checkpoint from which to continue training. If None, training starts from scratch. Defaults to None. Should be a file-like object (has to implement read, readline, tell, and seek), or a string containing a file name.
  • save_dir – Directory to save trainer.
  • save_all – Bool. If True, save Trainer instance; If False, only save model. Defaults to False.
  • verbose – 0, 1, or 2. Verbosity mode. 0 = silent, 1 = verbose, 2 = one log line per epoch.
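How patience and key interact can be pictured with a small early-stopping helper. This is a minimal sketch in plain Python, not MatchZoo's implementation: the class name PatienceTracker, its methods, the metric name "ndcg", and the assumption that a larger metric value is better are all hypothetical.

```python
class PatienceTracker:
    """Hypothetical helper: stop after `patience` evaluations without improvement."""

    def __init__(self, patience, key):
        self.patience = patience      # evaluations to wait without improvement
        self.key = key                # name of the metric to compare
        self.best = None              # best metric value seen so far
        self.bad_rounds = 0           # consecutive evaluations without improvement

    def update(self, result):
        """Record one evaluation result (a dict of metric name -> value)."""
        value = result[self.key]
        # Assumption: larger values are better.
        if self.best is None or value > self.best:
            self.best = value
            self.bad_rounds = 0
        else:
            self.bad_rounds += 1

    @property
    def should_stop(self):
        return self.bad_rounds >= self.patience


tracker = PatienceTracker(patience=2, key="ndcg")
history = [{"ndcg": 0.50}, {"ndcg": 0.55}, {"ndcg": 0.54},
           {"ndcg": 0.53}, {"ndcg": 0.60}]
stopped_at = None
for step, result in enumerate(history, start=1):
    tracker.update(result)
    if tracker.should_stop:
        stopped_at = step             # training would end here
        break
```

With patience=2, the loop stops at the second consecutive evaluation that fails to beat the best value, so the final entry in history is never reached.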
_load_dataloader(self, trainloader:DataLoader, validloader:DataLoader, validate_interval:typing.Optional[int]=None)

Load trainloader and determine validate interval.

Parameters:
  • trainloader – A DataLoader instance. The dataloader is used to train the model.
  • validloader – A DataLoader instance. The dataloader is used to validate the model.
  • validate_interval – int. Interval of validation.
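One plausible way to determine the validate interval is sketched below, under the assumption that a None interval means "validate once per pass over the training data". The function name resolve_validate_interval is hypothetical, and a plain list stands in for the DataLoader.

```python
from collections.abc import Sized
from typing import Optional


def resolve_validate_interval(trainloader: Sized,
                              validate_interval: Optional[int] = None) -> int:
    """Hypothetical helper: number of training batches between validations."""
    if validate_interval is None:
        # Assumption: fall back to the number of batches in one epoch.
        return len(trainloader)
    return validate_interval


batches = [[1, 2], [3, 4], [5, 6]]    # stand-in for a DataLoader with 3 batches
default_interval = resolve_validate_interval(batches)
explicit_interval = resolve_validate_interval(batches, validate_interval=2)
```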
_load_model(self, model:BaseModel, device:typing.Optional[torch.device], data_parallel:bool=True)

Load model.

Parameters:
  • model – A BaseModel instance.
  • device – the desired device of returned tensor. Default: if None, uses the current device for the default tensor type (see torch.set_default_tensor_type()). device will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.
  • data_parallel – bool. Whether to support data parallel.
_load_path(self, checkpoint:typing.Union[str, Path], save_dir:typing.Union[str, Path])

Load save_dir and restore from checkpoint.

Parameters:
  • checkpoint – A checkpoint from which to continue training. If None, training starts from scratch. Defaults to None. Should be a file-like object (has to implement read, readline, tell, and seek), or a string containing a file name.
  • save_dir – Directory to save trainer.
_backward(self, loss)

Computes the gradient of the current loss with respect to the graph leaves.

Parameters: loss – Tensor. Loss of model.
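When a clip_norm is configured, gradients are typically rescaled after the backward pass so that their overall norm does not exceed it. The arithmetic behind norm clipping can be sketched in plain Python. The function clip_by_norm is hypothetical and works on a flat list of floats rather than tensors; in PyTorch the corresponding utility is torch.nn.utils.clip_grad_norm_.

```python
import math


def clip_by_norm(grads, max_norm):
    """Hypothetical helper: rescale grads so their L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        return list(grads)            # already within the limit, leave unchanged
    scale = max_norm / total_norm
    return [g * scale for g in grads]


clipped = clip_by_norm([3.0, 4.0], max_norm=1.0)     # original norm is 5.0
unchanged = clip_by_norm([0.3, 0.4], max_norm=1.0)   # original norm is 0.5
clipped_norm = math.sqrt(sum(g * g for g in clipped))
```

Clipping preserves the direction of the gradient and only shrinks its length, which is why every component is multiplied by the same scale.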
_run_scheduler(self)

Run scheduler.

run(self)

Train model.

The processes:
Run each epoch -> Run scheduler -> Should stop early?
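The process above can be sketched as a plain loop. This is an illustrative outline, not the library's code: the function run_training and the three callables it takes are hypothetical stand-ins for the trainer's internal steps, and the inclusive epoch range is an assumption.

```python
def run_training(start_epoch, epochs, run_epoch, run_scheduler, should_stop):
    """Hypothetical outline: run epochs until the limit or an early stop."""
    completed = []
    # Assumption: epochs are numbered start_epoch..epochs, inclusive.
    for epoch in range(start_epoch, epochs + 1):
        run_epoch(epoch)              # train (and validate) for one epoch
        run_scheduler()               # adjust the learning rate
        completed.append(epoch)
        if should_stop():             # early stopping check
            break
    return completed


log = []
epochs_run = run_training(
    start_epoch=1,
    epochs=10,
    run_epoch=lambda epoch: log.append(f"epoch {epoch}"),
    run_scheduler=lambda: log.append("scheduler"),
    should_stop=lambda: len(log) >= 6,   # pretend early stopping fires after 3 epochs
)
```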
_run_epoch(self)

Run each epoch.

The training steps:
  • Get batch and feed them into model
  • Get outputs. Calculate all losses and sum them up
  • Loss backwards and optimizer steps
  • Evaluation
  • Update and output result
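The steps above can be sketched with a toy one-parameter model in plain Python. Everything here is an assumption made for illustration: the function run_epoch, the model y = weight * x, the hand-computed gradient, and the evaluate callable are hypothetical and do not reflect how the trainer is implemented on top of PyTorch.

```python
def run_epoch(batches, weight, lr, validate_interval, evaluate):
    """Hypothetical sketch of one epoch for a one-parameter model y = weight * x."""
    results = []
    for step, (xs, ys) in enumerate(batches, start=1):
        # Get a batch and feed it into the model.
        preds = [weight * x for x in xs]
        # Calculate the loss (here a single mean squared error term).
        n = len(xs)
        loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / n
        # Loss backward and optimizer step (gradient computed by hand).
        grad = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / n
        weight -= lr * grad
        # Evaluate every validate_interval steps and record the result.
        if step % validate_interval == 0:
            results.append((step, loss, evaluate(weight)))
    return weight, results


def valid_error(weight):
    """Toy validation metric: distance from the true weight 2.0."""
    return abs(weight - 2.0)


train_batches = [([1.0, 2.0], [2.0, 4.0])] * 20   # data generated by y = 2x
final_weight, results = run_epoch(train_batches, weight=0.0, lr=0.1,
                                  validate_interval=5, evaluate=valid_error)
```

Each entry in results holds the step number, the training loss at that step, and the validation value, which is the kind of record a trainer would log or compare for early stopping.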
evaluate(self, dataloader:DataLoader)

Evaluate the model.

Parameters: dataloader – A DataLoader object to iterate over the data.
classmethod _eval_metric_on_data_frame(cls, metric:BaseMetric, id_left:typing.Any, y_true:typing.Union[list, np.array], y_pred:typing.Union[list, np.array])

Evaluate metric on data frame.

This function is used to evaluate metrics for the ranking task.

Parameters:
  • metric – Metric for the ranking task.
  • id_left – ID of the left input. Samples with the same id_left should be grouped for evaluation.
  • y_true – Labels of dataset.
  • y_pred – Outputs of model.
Returns:

Evaluation result.
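The grouping described for id_left can be sketched in plain Python. This is a hedged illustration rather than the library's code: the function eval_metric_grouped, the toy metric precision_at_1, and the choice to average the per-group scores are all assumptions.

```python
from collections import defaultdict


def eval_metric_grouped(metric, id_left, y_true, y_pred):
    """Hypothetical sketch: apply `metric` per id_left group, then average."""
    groups = defaultdict(lambda: ([], []))
    for qid, t, p in zip(id_left, y_true, y_pred):
        groups[qid][0].append(t)      # labels for this left id
        groups[qid][1].append(p)      # model outputs for this left id
    scores = [metric(trues, preds) for trues, preds in groups.values()]
    # Assumption: the final result is the mean over groups.
    return sum(scores) / len(scores)


def precision_at_1(trues, preds):
    """Toy ranking metric: 1.0 if the top-scored item is relevant, else 0.0."""
    top = max(range(len(preds)), key=lambda i: preds[i])
    return 1.0 if trues[top] > 0 else 0.0


score = eval_metric_grouped(
    precision_at_1,
    id_left=["q1", "q1", "q2", "q2"],
    y_true=[1, 0, 0, 1],
    y_pred=[0.9, 0.2, 0.8, 0.3],
)
```

Grouping matters because a ranking metric compares candidates for the same left input; mixing samples from different groups would rank unrelated items against each other.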

predict(self, dataloader:DataLoader)

Generate output predictions for the input samples.

Parameters: dataloader – Input DataLoader.
Returns: Predictions.
_save(self)

Save.

save_model(self)

Save the model.

save(self)

Save the trainer.

Trainer parameters like epoch, best_so_far, model, optimizer and early_stopping will be saved to a specific file path.

Parameters: path – Path to save trainer.
restore_model(self, checkpoint:typing.Union[str, Path])

Restore model.

Parameters: checkpoint – A checkpoint from which to continue training.
restore(self, checkpoint:typing.Union[str, Path]=None)

Restore trainer.

Parameters: checkpoint – A checkpoint from which to continue training.