cntk.train.trainer module

A trainer encapsulates the overall training process and employs one or more learners to tune the parameters of a specified model, using gradients of the parameters with respect to a training objective.
class Trainer(model, criterion, parameter_learners, progress_writers=None)

    Bases: cntk.cntk_py.Trainer

    Trains the parameters of a model against its specified loss function, using the specified set of parameter_learners to update the model's parameters from computed gradients. An optional metric function, which may be non-differentiable, can be used to track the quality of the trained model.

    Parameters:
    - model (Function) – root node of the function to train
    - criterion (tuple of Function or Variable) – Function with one or two outputs, representing the loss and, if given, the evaluation metric (in this order). Alternatively, a tuple (loss Function, evaluation Function) is also accepted.
    - parameter_learners (list) – list of learners from cntk.learners
    - progress_writers (progress writer or list of them) – optionally, a list of progress writers from cntk.logging to automatically track training progress
evaluation_function
    The evaluation function that the trainer is using.

loss_function
    The loss function that the trainer is using.

model
    The model that the trainer is training.

parameter_learners
    The parameter learners that the trainer is using.

previous_minibatch_evaluation_average
    The average evaluation criterion value per sample for the last minibatch trained.

previous_minibatch_loss_average
    The average training loss per sample for the last minibatch trained.

previous_minibatch_sample_count
    The number of samples in the last minibatch trained with.
print_node_timing()
    Prints the per-minibatch average timing per node for each primitive function. The timing statistics are reset after printing.
restore_from_checkpoint(filename)
    Restores a checkpoint of the model and Trainer state from the specified file location.

    Parameters: filename (str) – filename to restore the checkpoint from
save_checkpoint(filename, external_state={})
    Saves a checkpoint of the model and other Trainer state at the specified file location. In a distributed environment, checkpointing is performed by the main worker.

    Parameters:
    - filename (str) – filename to store the checkpoint
    - external_state (dict) – additional external state, default is empty
summarize_test_progress()
    Updates the progress writers with the summary of test progress since start and resets the internal accumulators.

summarize_training_progress()
    Updates the progress writers with the summary of training progress since start and resets the internal accumulators.
test_minibatch(arguments, device=None)
    Tests the model on the specified batch of samples using the evaluation Function specified during construction of the Trainer.

    Parameters:
    - arguments – maps variables to their input data. The interpretation depends on the input type:
      - dict: keys are input variables or names, and values are the input data. See forward() for details on passing input data.
      - any other type: if the node has a unique input, arguments is mapped to this input. For nodes with more than one input, only a dict is allowed.
      In both cases, every sample in the data will be interpreted as a new sequence. To mark samples as continuations of the previous sequence, specify arguments as a tuple: the first element will be used as arguments, and the second one will be used as a list of bools denoting whether a sequence is a new one (True) or a continuation of the previous one (False). Data should be either NumPy arrays or a MinibatchData instance.
    - device (DeviceDescriptor) – the device descriptor that contains the type and id of the device on which the computation is to be performed.

    Note: See forward() for examples on passing input data.

    Returns: the average evaluation criterion value per sample for the tested minibatch

    Return type: float
total_number_of_samples_seen
    The number of samples seen globally, across all workers, from the beginning of training.
train_minibatch(arguments, outputs=None, device=None, is_sweep_end=None)
    Optimizes the model parameters using the specified arguments minibatch of training samples.

    Parameters:
    - arguments – maps variables to their input data. An empty map signifies the end of local training data. The interpretation depends on the input type:
      - dict: keys are input variables or names, and values are the input data.
      - any other type: if the node has a unique input, arguments is mapped to this input. For nodes with more than one input, only a dict is allowed.
      In both cases, every sample in the data will be interpreted as a new sequence. To mark samples as continuations of the previous sequence, specify arguments as a tuple: the first element will be used as arguments, and the second one will be used as a list of bools denoting whether a sequence is a new one (True) or a continuation of the previous one (False). Data should be either NumPy arrays or a MinibatchData instance.
    - outputs (iterable) – outputs to fetch values for.
    - device (DeviceDescriptor) – the device descriptor that contains the type and id of the device on which the computation is to be performed.
    - is_sweep_end (bool) – indicates whether this minibatch is at the end of a sweep (an epoch); defaults to None. This is used when arguments is fed NumPy array data; when the data comes from MinibatchData, is_sweep_end is provided by MinibatchData, so there is no need to specify it manually.

    Note: See forward() for examples on passing input data.

    Returns: If outputs has not been provided, the returned value is True if updates have been performed, or False if all parameter learners indicate end of learning (through their update). Otherwise, the return value is a tuple of that bool and a dictionary that maps the variables in outputs to their respective NumPy arrays.

    Return type: bool or tuple