# cntk.train.trainer module

A trainer encapsulates the overall training process and employs one or more learners to tune the parameters of a specified model using gradients of parameters w.r.t. a training objective.

class Trainer(model, criterion, parameter_learners, progress_writers=None)

Bases: cntk.cntk_py.Trainer

Class for training the model parameters of a model’s specified loss function, using the specified set of parameter_learners to update the model’s parameters with the computed gradients. An optional metric function, which can be non-differentiable, can be specified to track the trained model’s quality.
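To make the division of labor concrete, here is a plain-NumPy sketch (not CNTK code) of the same roles: a linear model, a differentiable squared-error loss, a non-differentiable error-rate metric that is only tracked, and a learner that applies gradient updates to the parameters it owns. The classes ToyTrainer and SGDLearner are hypothetical, and the gradient formulas apply only to this toy model.

```python
import numpy as np


class SGDLearner:
    """Hypothetical learner: plain SGD over the parameter arrays it owns."""

    def __init__(self, params, lr):
        self.params = params
        self.lr = lr

    def update(self, grads):
        # grads maps id(param) -> gradient; each learner touches only its own parameters.
        for p in self.params:
            p -= self.lr * grads[id(p)]  # in-place, so the model sees the new values


class ToyTrainer:
    """Hypothetical trainer for y = x @ w + b with a squared-error loss."""

    def __init__(self, w, b, learners):
        self.w, self.b = w, b
        self.learners = learners
        self.previous_minibatch_loss_average = None
        self.previous_minibatch_evaluation_average = None

    def train_minibatch(self, x, y):
        err = x @ self.w + self.b - y  # forward pass and residuals
        n = x.shape[0]
        # Differentiable loss: mean squared error per sample.
        self.previous_minibatch_loss_average = float(np.mean(err ** 2))
        # Non-differentiable metric: fraction of samples with |error| > 0.5 (tracked only).
        self.previous_minibatch_evaluation_average = float(np.mean(np.abs(err) > 0.5))
        # Gradients of the mean loss w.r.t. each parameter.
        grads = {id(self.w): 2.0 * x.T @ err / n, id(self.b): 2.0 * err.mean(axis=0)}
        for learner in self.learners:
            learner.update(grads)
        return True


rng = np.random.default_rng(0)
x = rng.normal(size=(64, 2))
y = x @ np.array([[2.0], [-1.0]]) + 0.5
w, b = np.zeros((2, 1)), np.zeros(1)
trainer = ToyTrainer(w, b, [SGDLearner([w, b], lr=0.1)])

trainer.train_minibatch(x, y)
first_loss = trainer.previous_minibatch_loss_average
for _ in range(200):
    trainer.train_minibatch(x, y)
last_loss = trainer.previous_minibatch_loss_average
```

The metric never feeds into the gradients, which is why it may be non-differentiable; only the loss drives the learner.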

Parameters:
- model (Function) – root node of the function to train.
- criterion (tuple of Function or Variable) – Function with one or two outputs, representing the loss and, if given, the evaluation metric (in this order). Alternatively, a tuple (loss Function, evaluation Function) is also accepted.
- parameter_learners (list) – list of learners from cntk.learners.
- progress_writers (progress writer or list of them) – optionally, a list of progress writers from cntk.logging to automatically track training progress.

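The two accepted criterion forms both reduce to a (loss, optional metric) pair. The sketch below shows that mapping in plain Python; split_criterion and FakeFunction are hypothetical stand-ins, not CNTK internals, and a real CNTK Function is only modelled here by its outputs attribute.

```python
from typing import Any, Optional, Tuple


class FakeFunction:
    """Hypothetical stand-in for a CNTK Function: only exposes an `outputs` sequence."""

    def __init__(self, *outputs):
        self.outputs = outputs


def split_criterion(criterion: Any) -> Tuple[Any, Optional[Any]]:
    """Hypothetical helper: map the accepted criterion forms to (loss, metric or None)."""
    if isinstance(criterion, tuple):
        # Explicit (loss, evaluation) pair.
        if len(criterion) != 2:
            raise ValueError("expected a (loss, evaluation) tuple")
        return criterion[0], criterion[1]
    # A single Function: first output is the loss, the optional second one the metric.
    outputs = list(getattr(criterion, "outputs", [criterion]))
    if len(outputs) == 1:
        return outputs[0], None
    if len(outputs) == 2:
        return outputs[0], outputs[1]
    raise ValueError("criterion must have one or two outputs")
```

The order matters in both forms: the loss always comes first, the evaluation metric second.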
evaluation_function

The evaluation function that the trainer is using.

loss_function

The loss function that the trainer is using.

model

The model that the trainer is training.

parameter_learners

The parameter learners that the trainer is using.

previous_minibatch_evaluation_average

The average evaluation criterion value per sample for the last minibatch trained.

previous_minibatch_loss_average

The average training loss per sample for the last minibatch trained.

previous_minibatch_sample_count

The number of samples in the last minibatch trained with.

print_node_timing()

Prints the per-node average timing per minibatch for each primitive function. The statistics are reset after printing.

restore_from_checkpoint(filename)

Restores a checkpoint of the model and Trainer state from the specified file location.

Parameters:
- filename (str) – filename to restore the checkpoint from.

save_checkpoint(filename, external_state={})

Saves a checkpoint of the model and other Trainer state at the specified file location.

In a distributed environment, checkpointing is done by the main worker.

Parameters:
- filename (str) – filename to store the checkpoint.
- external_state (dict) – additional external state; default is empty.

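A common use of external_state is to record where training stopped (for example the epoch index) so a restarted job can resume. The sketch below illustrates that pattern with a hypothetical pickle-based stand-in, CheckpointStub; real CNTK checkpoints are written by the library in its own format, and whether the real restore method hands external_state back is an assumption made only for the stub.

```python
import os
import pickle
import tempfile


class CheckpointStub:
    """Hypothetical stand-in for a trainer; NOT the CNTK checkpoint format."""

    def __init__(self):
        self.weights = [0.0, 0.0]

    def save_checkpoint(self, filename, external_state=None):
        state = {"weights": list(self.weights),
                 "external_state": dict(external_state or {})}
        with open(filename, "wb") as f:
            pickle.dump(state, f)

    def restore_from_checkpoint(self, filename):
        with open(filename, "rb") as f:
            state = pickle.load(f)
        self.weights = state["weights"]
        # Assumption: the stub returns the saved external state to the caller.
        return state["external_state"]


def resume_epoch(trainer, filename):
    """Return the epoch to start from: 0 without a checkpoint, else saved epoch + 1."""
    if not os.path.exists(filename):
        return 0
    state = trainer.restore_from_checkpoint(filename)
    return state.get("epoch", -1) + 1


ckpt = os.path.join(tempfile.mkdtemp(), "model.ckpt")
first_run = CheckpointStub()
start_before = resume_epoch(first_run, ckpt)  # no checkpoint yet, so 0
first_run.weights = [1.5, -2.0]
first_run.save_checkpoint(ckpt, {"epoch": 2})

second_run = CheckpointStub()
start_after = resume_epoch(second_run, ckpt)  # restores weights, resumes at epoch 3
```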
summarize_test_progress()

Updates the progress writers with the summary of test progress since start and resets the internal accumulators.

summarize_training_progress()

Updates the progress writers with the summary of training progress since start and resets the internal accumulators.
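The summarize methods imply accumulate-then-summarize-and-reset bookkeeping. A minimal sketch of that pattern, assuming the accumulators hold sample-weighted totals; ProgressAccumulator is a hypothetical class, not part of cntk.logging.

```python
class ProgressAccumulator:
    """Hypothetical sketch of accumulate, summarize, and reset bookkeeping."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.loss_sum = 0.0
        self.sample_count = 0

    def update(self, loss_average, sample_count):
        # Store totals rather than averages so minibatch sizes are weighted correctly.
        self.loss_sum += loss_average * sample_count
        self.sample_count += sample_count

    def summarize(self):
        # Average over everything accumulated since the last reset, then start over.
        average = self.loss_sum / self.sample_count if self.sample_count else None
        self.reset()
        return average


acc = ProgressAccumulator()
acc.update(2.0, 10)
acc.update(1.0, 30)
epoch1 = acc.summarize()  # (2.0 * 10 + 1.0 * 30) / 40
acc.update(4.0, 5)
epoch2 = acc.summarize()  # only the minibatch seen after the first summary
```

Because the accumulators are reset, calling a summarize method once per epoch yields per-epoch figures rather than a running total, assuming the real accumulators behave like this sketch.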

test_minibatch(arguments, device=None)

Test the model on the specified batch of samples using the evaluation Function specified during construction of the Trainer.

Parameters:
- arguments – maps variables to their input data. The interpretation depends on the input type:
  - dict: keys are input variables or names, and values are the input data. See forward() for details on passing input data.
  - any other type: if the node has a unique input, arguments is mapped to this input. For nodes with more than one input, only dict is allowed.

  In both cases, every sample in the data will be interpreted as a new sequence. To mark samples as continuations of the previous sequence, specify arguments as a tuple: the first element will be used as arguments, and the second one will be used as a list of bools denoting whether a sequence is a new one (True) or a continuation of the previous one (False). Data should be either NumPy arrays or a MinibatchData instance.
- device (DeviceDescriptor) – the device descriptor that contains the type and id of the device on which the computation is to be performed.

Note

See forward() for examples on passing input data.

Returns: the average evaluation criterion value per sample for the tested minibatch.
Return type: float

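Since test_minibatch returns a per-sample average, evaluating a whole test set in minibatches of different sizes requires weighting each result by its sample count; a plain mean of the averages is biased toward small minibatches. A sketch, where evaluate_in_minibatches is a hypothetical helper and test_fn would typically wrap trainer.test_minibatch:

```python
def evaluate_in_minibatches(test_fn, minibatches):
    """Hypothetical helper: combine per-minibatch averages into one per-sample average.

    test_fn: callable that, like test_minibatch, returns the average evaluation
             value per sample for one minibatch.
    minibatches: iterable of (arguments, sample_count) pairs.
    """
    total, count = 0.0, 0
    for arguments, sample_count in minibatches:
        # Multiply back by the sample count so larger minibatches weigh more.
        total += test_fn(arguments) * sample_count
        count += sample_count
    if count == 0:
        raise ValueError("no samples to evaluate")
    return total / count


# Stand-in evaluator: minibatch "a" has a 50% error rate, minibatch "b" has 0%.
error_rates = {"a": 0.5, "b": 0.0}
result = evaluate_in_minibatches(error_rates.get, [("a", 2), ("b", 6)])
# result is 0.125, whereas a plain mean of the two averages would give 0.25
```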
total_number_of_samples_seen

The number of samples seen globally across all workers since the beginning of training.

train_minibatch(arguments, outputs=None, device=None, is_sweep_end=None)

Optimize model parameters using the specified ‘arguments’ minibatch of training samples.

Parameters:
- arguments – maps variables to their input data. An empty map signifies the end of local training data. The interpretation depends on the input type:
  - dict: keys are input variables or names, and values are the input data.
  - any other type: if the node has a unique input, arguments is mapped to this input. For nodes with more than one input, only dict is allowed.

  In both cases, every sample in the data will be interpreted as a new sequence. To mark samples as continuations of the previous sequence, specify arguments as a tuple: the first element will be used as arguments, and the second one will be used as a list of bools denoting whether a sequence is a new one (True) or a continuation of the previous one (False). Data should be either NumPy arrays or a MinibatchData instance.
- outputs (iterable) – outputs to fetch values for.
- device (DeviceDescriptor) – the device descriptor that contains the type and id of the device on which the computation is to be performed.
- is_sweep_end (bool) – indicates whether this minibatch is at the end of a sweep (of an epoch); defaults to None. This is used in combination with arguments being fed with NumPy array data; when the data comes from a MinibatchData instance, is_sweep_end is provided by MinibatchData, so there is no need to specify it manually.

Note

See forward() for examples on passing input data.
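The tuple form (data plus a list of new/continuation flags) lets a long sequence be fed in consecutive pieces. The hypothetical helper chunk_sequences below builds such pairs. It assumes each sequence is a NumPy array of shape (steps, features), that all sequences in the group have equal length so position i always refers to the same sequence, and that the flag list is positional; how the pair combines with a dict of inputs is not covered here.

```python
import numpy as np


def chunk_sequences(sequences, chunk_len):
    """Hypothetical helper: yield (chunk_batch, is_new_sequence) pairs.

    The first chunk of every sequence is flagged True (new sequence); every
    later chunk is flagged False (continuation of the previous chunk).
    """
    lengths = {len(s) for s in sequences}
    if len(lengths) != 1:
        raise ValueError("this sketch assumes equal-length sequences")
    (length,) = lengths
    for start in range(0, length, chunk_len):
        batch = [s[start:start + chunk_len] for s in sequences]
        flags = [start == 0] * len(batch)
        yield batch, flags


seqs = [np.arange(15, dtype=np.float32).reshape(5, 3),
        np.ones((5, 3), dtype=np.float32)]
chunks = list(chunk_sequences(seqs, chunk_len=2))  # pieces of 2, 2 and 1 steps
```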

Returns: If outputs have not been provided, the returned value is True if updates have been performed, or False if all parameter learners indicate the end of learning (through their update). Otherwise, the return value is a tuple of that bool and a dictionary that maps the variables in outputs to their respective NumPy arrays.
Return type: bool or tuple
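When feeding NumPy arrays directly, the caller supplies is_sweep_end and has to handle both return shapes. The sketch below shows one sweep driven that way; numpy_minibatches, unpack_train_result, and run_sweep are hypothetical helpers, and the train_step callable stands in for a call such as trainer.train_minibatch({x: fx, y: fy}, is_sweep_end=end) with hypothetical input variables x and y.

```python
import numpy as np


def numpy_minibatches(features, labels, minibatch_size):
    """Hypothetical generator over one sweep of in-memory NumPy data.

    Yields (features, labels, is_sweep_end); the flag is True only for the last minibatch.
    """
    n = features.shape[0]
    for start in range(0, n, minibatch_size):
        end = min(start + minibatch_size, n)
        yield features[start:end], labels[start:end], end == n


def unpack_train_result(result):
    """Normalize the two documented return shapes into (updated, outputs_dict)."""
    if isinstance(result, tuple):
        updated, outputs = result
        return updated, outputs
    return result, {}


def run_sweep(train_step, features, labels, minibatch_size):
    """Drive one sweep; stop early once train_step reports that no update was performed."""
    steps = 0
    for fx, fy, sweep_end in numpy_minibatches(features, labels, minibatch_size):
        updated, _ = unpack_train_result(train_step(fx, fy, sweep_end))
        steps += 1
        if not updated:
            break
    return steps


features = np.zeros((10, 3), dtype=np.float32)
labels = np.zeros((10, 1), dtype=np.float32)
seen_flags = []


def fake_step(fx, fy, sweep_end):
    # Stand-in for train_minibatch called without outputs: always reports an update.
    seen_flags.append(sweep_end)
    return True


steps = run_sweep(fake_step, features, labels, minibatch_size=4)  # minibatches of 4, 4, 2
```

Stopping on False mirrors the documented meaning of that value: all parameter learners have signalled the end of learning.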