cntk.train.training_session module

A training session encapsulates a typical training loop and binds together a minibatch source that is used for training, a Trainer and an optional cross validation minibatch source. A training session takes care of consistent checkpointing and progress printing with specified frequencies.

class CheckpointConfig(filename, frequency=None, restore=True, preserve_all=False)

Bases: cntk.cntk_py.CheckpointConfig

A checkpoint configuration for the training session.

Parameters:
  • filename (str) – checkpoint file name.
  • frequency (int, tuple) – checkpointing period (number of samples between checkpoints). If None, no checkpointing takes place. If sys.maxsize, a single checkpoint is taken at the end of training. If a tuple of (frequency, DataUnit), the frequency is in terms of either DataUnit.sample, DataUnit.minibatch or DataUnit.sweep. See DataUnit for more information on the frequency data unit.
  • restore (bool) – flag indicating whether to restore from an available checkpoint before training starts.
  • preserve_all (bool) – if True, saves all checkpoints, using filename as a prefix and the checkpoint index as a suffix.
class CrossValidationConfig(minibatch_source=None, frequency=None, minibatch_size=32, callback=None, max_samples=None, model_inputs_to_streams=None, criterion=None, source=None, mb_size=None)

Bases: cntk.cntk_py.CrossValidationConfig

A cross validation configuration for the training session.

Parameters:
  • minibatch_source (MinibatchSource) – minibatch source used for cross validation
  • frequency (int, tuple) – frequency in samples for cross validation. If None or sys.maxsize, a single cross validation is performed at the end of training. If a tuple of (frequency, DataUnit), the frequency is in terms of either DataUnit.sample, DataUnit.minibatch or DataUnit.sweep. See DataUnit for more information on the frequency data unit.
  • minibatch_size (int or minibatch_size_schedule, defaults to 32) – minibatch schedule for cross validation
  • callback (func (index, average_error, cv_num_samples, cv_num_minibatches)) – callback invoked at the cross validation frequency. It can implement custom cross validation logic and returns False if training should be stopped.
  • max_samples (int, default None) – number of samples to perform cross validation on. If None, all samples are taken.
  • model_inputs_to_streams (dict) – mapping between input variables and input streams. If None, the mapping provided to the training session constructor is used. Don’t specify this if minibatch_source is a tuple of numpy/scipy arrays.
  • criterion (Function) – criterion function. Must be specified if minibatch_source is a tuple of numpy/scipy arrays.
  • source (MinibatchSource) – DEPRECATED, use minibatch_source instead
  • mb_size (int or minibatch_size_schedule, defaults to 32) – DEPRECATED, use minibatch_size instead
class DataUnit

Bases: enum.IntEnum

Indicates whether the processing steps on the training data are counted by samples, minibatches or sweeps (epochs).

minibatch = 2

Steps on data are counted by minibatches.

sample = 1

Steps on data are counted by samples.

sweep = 0

Steps on data are counted by sweeps (epochs).
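
The (frequency, DataUnit) tuple form accepted by the frequency parameters can be illustrated with a small stand-alone sketch. The enum below is a local stand-in that mirrors the member values listed above rather than importing CNTK, and normalize_frequency is a hypothetical helper, not part of the library. It assumes that a bare integer means a number of samples, as the parameter descriptions state.

```python
import enum


class DataUnit(enum.IntEnum):
    """Local stand-in mirroring the documented values (not the CNTK class)."""
    sweep = 0
    sample = 1
    minibatch = 2


def normalize_frequency(frequency):
    """Hypothetical helper: turn a frequency argument into (count, unit) or None."""
    if frequency is None:
        return None
    if isinstance(frequency, tuple):
        count, unit = frequency
        return int(count), DataUnit(unit)
    # Assumption: a bare integer is a number of samples.
    return int(frequency), DataUnit.sample


print(normalize_frequency(5000))                 # every 5000 samples
print(normalize_frequency((2, DataUnit.sweep)))  # every 2 sweeps
```

In this sketch, 5000 and (5000, DataUnit.sample) normalize to the same pair, which is the equivalence the parameter descriptions suggest.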

class TestConfig(minibatch_source=None, minibatch_size=32, model_inputs_to_streams=None, criterion=None, source=None, mb_size=None)

Bases: cntk.cntk_py.TestConfig

A test configuration for the training session.

Parameters:
  • minibatch_source (MinibatchSource) – minibatch source used for testing
  • minibatch_size (int or minibatch_size_schedule, defaults to 32) – minibatch schedule for testing
  • model_inputs_to_streams (dict) – mapping between input variables and input streams. If None, the mapping provided to the training session constructor is used. Don’t specify this if minibatch_source is a tuple of numpy/scipy arrays.
  • criterion (Function) – criterion function. Must be specified if minibatch_source is a tuple of numpy/scipy arrays.
  • source (MinibatchSource) – DEPRECATED, use minibatch_source instead
  • mb_size (int or minibatch_size_schedule, defaults to 32) – DEPRECATED, use minibatch_size instead
class TrainingSession(trainer, mb_source, mb_size, model_inputs_to_streams, max_samples, progress_frequency, checkpoint_config, cv_config, test_config)

Bases: cntk.cntk_py.TrainingSession

Instances of this class should be created with the training_session() function.

A training session trains a model using the specified trainer and configs. Different aspects of training, such as data sources, checkpointing, cross validation and progress printing, can be configured using the corresponding config classes.

Parameters:
  • trainer (Trainer) – trainer
  • mb_source (MinibatchSource) – minibatch source used for training
  • mb_size (minibatch_size_schedule or int) – minibatch size schedule for training
  • model_inputs_to_streams (dict) – mapping between input variables and input streams
  • max_samples (int) – maximum number of samples used for training
  • progress_frequency (int, tuple) – the number of samples, minibatches or sweeps after which aggregated progress is printed. If a tuple of (frequency, DataUnit), the frequency is in terms of either DataUnit.sample, DataUnit.minibatch or DataUnit.sweep. See DataUnit for more information on the frequency data unit.
  • checkpoint_config (CheckpointConfig) – checkpoint configuration
  • cv_config (CrossValidationConfig) – cross validation configuration
  • test_config (TestConfig) – test configuration
on_cross_validation_end(index, average_error, num_samples, num_minibatches)

Callback that gets executed at the end of cross validation.

Parameters:
  • index (int) – index of the current callback.
  • average_error (float) – average error for the cross validation
  • num_samples (int) – number of samples in cross validation
  • num_minibatches (int) – number of minibatches in cross validation
Returns:

True if training should continue, False otherwise.
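
As an illustration of the return-value contract, here is a minimal sketch of an early-stopping callback with the same four-argument signature. The make_early_stopping_callback factory and its patience parameter are hypothetical names introduced for this example. The callback is exercised here with made-up error values rather than by a real training session.

```python
def make_early_stopping_callback(patience=3):
    """Hypothetical factory: stop after `patience` rounds without improvement."""
    state = {"best": float("inf"), "bad_rounds": 0}

    def callback(index, average_error, num_samples, num_minibatches):
        if average_error < state["best"]:
            # New best error: remember it and reset the counter.
            state["best"] = average_error
            state["bad_rounds"] = 0
        else:
            state["bad_rounds"] += 1
        # True keeps training going; False asks for training to stop.
        return state["bad_rounds"] < patience

    return callback


# Simulate four cross validation rounds with illustrative error values.
cb = make_early_stopping_callback(patience=2)
errors = [0.30, 0.25, 0.26, 0.27]
decisions = [cb(i, e, 1000, 32) for i, e in enumerate(errors)]
print(decisions)  # [True, True, True, False]
```

A function of this shape could presumably be passed as the callback argument of CrossValidationConfig, since that parameter is documented with the same argument list.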

train(device=None)

Perform training on a specified device.

Parameters:
  • device (DeviceDescriptor) – the device descriptor containing the type and id of the device where training takes place.
minibatch_size_schedule(schedule, epoch_size=1)

Creates a minibatch size schedule.

Examples

>>> # Use a fixed value 32 for all minibatches
>>> s = minibatch_size_schedule(32)
>>> s[0], s[1]
(32, 32)
>>> # Use minibatches of size 32 for the first 1000 samples, then 64 for the remaining ones
>>> s = minibatch_size_schedule([32, 64], 1000)
>>> s[0], s[1], s[1000], s[1001]
(32, 32, 64, 64)
>>> # Use 32 for the first 12 epochs, then 64 for the next 15,
>>> # followed by 128 for the remaining ones, with 100 samples in an epoch
>>> s = minibatch_size_schedule([(12, 32), (15, 64), (1, 128)], 100)
>>> s[0], s[1199], s[1200], s[2699], s[2700], s[5000]
(32, 32, 64, 64, 128, 128)
Parameters:
  • schedule (int or list) – if an integer, this minibatch size is used for the whole training. If a list of integers, each element is used as the value for epoch_size samples. If the list contains pairs, the second element of each pair is used as the value for (epoch_size x first element) samples.
  • epoch_size (int) – number of samples as a scheduling unit.
Returns:

training parameter schedule
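
The lookup rule described by the parameters can be sketched in plain Python. This is not CNTK's implementation. schedule_value is a hypothetical helper that mirrors the documented behaviour, and it assumes that the last value applies to every sample index past the end of the list, as the examples above indicate.

```python
def schedule_value(schedule, sample_index, epoch_size=1):
    """Hypothetical helper: minibatch size in effect at `sample_index`."""
    if isinstance(schedule, int):
        return schedule  # a fixed size for the whole training
    # Treat a plain value v as the pair (1, v): one unit of epoch_size samples.
    pairs = [item if isinstance(item, tuple) else (1, item) for item in schedule]
    boundary = 0
    for epochs, value in pairs:
        boundary += epochs * epoch_size
        if sample_index < boundary:
            return value
    # Past the end of the list, the last value stays in effect.
    return pairs[-1][1]


sched = [(12, 32), (15, 64), (1, 128)]
print([schedule_value(sched, i, 100) for i in (0, 1199, 1200, 2699, 2700, 5000)])
# [32, 32, 64, 64, 128, 128]
```

The printed values match the third doctest above: 12 epochs of 100 samples cover indices 0 to 1199, and the next 15 epochs cover 1200 to 2699.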

training_session(trainer, mb_source, mb_size, model_inputs_to_streams, progress_frequency=None, max_samples=None, checkpoint_config=None, cv_config=None, test_config=None)

A factory function to create a training session object.

Parameters:
  • trainer (Trainer) – trainer
  • mb_source (MinibatchSource) – minibatch source used for training
  • mb_size (minibatch_size_schedule) – minibatch schedule for training
  • model_inputs_to_streams (dict) – mapping between input variables and input streams
  • progress_frequency (int, tuple) – frequency in samples for aggregated progress printing. If a tuple of (frequency, DataUnit), the frequency is in terms of either DataUnit.sample, DataUnit.minibatch or DataUnit.sweep. See DataUnit for more information on the frequency data unit.
  • max_samples (int) – maximum number of samples used for training
  • checkpoint_config (CheckpointConfig) – checkpoint configuration
  • cv_config (CrossValidationConfig) – cross validation configuration
  • test_config (TestConfig) – test configuration
Returns:

Instance of TrainingSession
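
Putting the pieces together, a session might be wired up roughly as follows. This is a sketch based only on the signatures documented in this module. The run_training wrapper, the checkpoint file name and the numeric values are illustrative assumptions, and the trainer, minibatch source and input map are assumed to be created elsewhere.

```python
def run_training(trainer, mb_source, input_map):
    """Sketch: bind a trainer and a minibatch source into a training session.

    Assumes `trainer` is a Trainer, `mb_source` a MinibatchSource and
    `input_map` a dict from input variables to input streams.
    """
    # Imported lazily so the function can be defined without CNTK installed.
    from cntk.train.training_session import (
        CheckpointConfig, DataUnit, minibatch_size_schedule, training_session)

    checkpoint_config = CheckpointConfig(
        filename="model.checkpoint",    # hypothetical file name
        frequency=(1, DataUnit.sweep),  # checkpoint once per sweep
        restore=True)

    session = training_session(
        trainer=trainer,
        mb_source=mb_source,
        mb_size=minibatch_size_schedule(64),
        model_inputs_to_streams=input_map,
        progress_frequency=(100, DataUnit.minibatch),
        max_samples=100000,             # illustrative value
        checkpoint_config=checkpoint_config)
    session.train()
    return session
```

Cross validation and test configurations could be passed the same way through the cv_config and test_config arguments.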