cntk.train.training_session module
A training session encapsulates a typical training loop and binds together a minibatch source that is used for training, a Trainer
and an optional cross validation minibatch source. A training session takes care of consistent checkpointing and progress printing with specified frequencies.
-
class
CheckpointConfig
(filename, frequency=None, restore=True, preserve_all=False) Bases:
cntk.cntk_py.CheckpointConfig
A checkpoint configuration for the training session.
Parameters: - filename (str) – checkpoint file name.
- frequency (int, tuple) – checkpointing period (number of samples between checkpoints).
If None, no checkpointing takes place.
If
sys.maxsize
, a single checkpoint is taken at the end of training. If a tuple of (frequency, DataUnit
), the frequency is in terms of either DataUnit.sample, DataUnit.minibatch or DataUnit.sweep. See DataUnit
for more information on the frequency data unit. - restore (bool) – flag indicating whether to restore from an available checkpoint before training starts.
- preserve_all (bool) – saves all checkpoints, using
filename
as prefix and checkpoint index as a suffix.
-
class
CrossValidationConfig
(minibatch_source=None, frequency=None, minibatch_size=32, callback=None, max_samples=None, model_inputs_to_streams=None, criterion=None, source=None, mb_size=None) Bases:
cntk.cntk_py.CrossValidationConfig
A cross validation configuration for the training session.
Parameters: - minibatch_source (
MinibatchSource
) – minibatch source used for cross validation - frequency (int, tuple) – frequency in samples for cross validation.
If None or
sys.maxsize
, a single cross validation is performed at the end of training. If a tuple of (frequency, DataUnit
), the frequency is in terms of either DataUnit.sample, DataUnit.minibatch or DataUnit.sweep. See DataUnit
for more information on frequency data unit. - minibatch_size (int or
minibatch_size_schedule
, defaults to 32) – minibatch schedule for cross validation - callback (func (index, average_error, cv_num_samples, cv_num_minibatches)) – callback invoked at the configured frequency. It can implement custom cross validation logic and returns False if training should be stopped.
- max_samples (int, default None) – number of samples to perform cross-validation on. If None, all samples are taken.
- model_inputs_to_streams (dict) – mapping between input variables and input streams. If None, the mapping provided to the training session constructor is used. Don’t specify this if minibatch_source is a tuple of numpy/scipy arrays.
- criterion (
Function
) – criterion function. Must be specified if minibatch_source is a tuple of numpy/scipy arrays. - source (
MinibatchSource
) – DEPRECATED, use minibatch_source instead - mb_size (int or
minibatch_size_schedule
, defaults to 32) – DEPRECATED, use minibatch_size instead
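The callback contract described above (four arguments, return False to stop training) can be sketched in plain Python. The helper name make_early_stopping_callback and the patience parameter are hypothetical and not part of CNTK; the sketch only assumes the documented signature and return convention.

```python
def make_early_stopping_callback(patience=3):
    """Build a callback with the documented cross-validation signature.

    `patience` is a hypothetical knob: training is asked to stop after
    that many consecutive evaluations without improvement.
    """
    state = {"best": float("inf"), "stale": 0}

    def callback(index, average_error, cv_num_samples, cv_num_minibatches):
        if average_error < state["best"]:
            # New best error: remember it and reset the stale counter.
            state["best"] = average_error
            state["stale"] = 0
        else:
            state["stale"] += 1
        # False asks the session to stop; True lets training continue.
        return state["stale"] < patience

    return callback


# Simulate four cross-validation rounds without any CNTK objects.
cb = make_early_stopping_callback(patience=2)
results = [cb(i, err, 1000, 32) for i, err in enumerate([0.30, 0.25, 0.26, 0.27])]
```

Such a function could then be passed as the callback argument of CrossValidationConfig.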
-
class
DataUnit
Bases:
enum.IntEnum
Indicates whether processing steps on the training data are counted in samples, minibatches or sweeps (epochs).
-
minibatch
= 2 Steps on data are counted by minibatches.
-
sample
= 1 Steps on data are counted by samples.
-
sweep
= 0 Steps on data are counted by sweeps (epochs).
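The way an int or a (frequency, DataUnit) tuple can be read is sketched below in plain Python. The DataUnit class here is a local mirror of the values listed above so the snippet runs without CNTK, and normalize_frequency is a hypothetical helper, not a library function. Treating a bare int as a sample count follows the parameter descriptions on this page.

```python
import enum


class DataUnit(enum.IntEnum):
    """Local mirror of the documented enum values (not the CNTK class)."""
    sweep = 0
    sample = 1
    minibatch = 2


def normalize_frequency(frequency, default_unit=DataUnit.sample):
    """Return (count, unit) for an int or a (count, DataUnit) tuple.

    Hypothetical helper: None is passed through unchanged, and a bare
    int is interpreted as a number of samples.
    """
    if frequency is None:
        return None
    if isinstance(frequency, tuple):
        count, unit = frequency
        return int(count), DataUnit(unit)
    return int(frequency), default_unit


every_1000_samples = normalize_frequency(1000)
every_2_sweeps = normalize_frequency((2, DataUnit.sweep))
```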
-
-
class
TestConfig
(minibatch_source=None, minibatch_size=32, model_inputs_to_streams=None, criterion=None, source=None, mb_size=None) Bases:
cntk.cntk_py.TestConfig
A test configuration for the training session.
Parameters: - minibatch_source (
MinibatchSource
) – minibatch source used for testing - minibatch_size (int or
minibatch_size_schedule
, defaults to 32) – minibatch schedule for testing - model_inputs_to_streams (dict) – mapping between input variables and input streams. If None, the mapping provided to the training session constructor is used. Don’t specify this if minibatch_source is a tuple of numpy/scipy arrays.
- criterion (
Function
) – criterion function. Must be specified if minibatch_source is a tuple of numpy/scipy arrays. - source (
MinibatchSource
) – DEPRECATED, use minibatch_source instead - mb_size (int or
minibatch_size_schedule
, defaults to 32) – DEPRECATED, use minibatch_size instead
-
class
TrainingSession
(trainer, mb_source, mb_size, model_inputs_to_streams, max_samples, progress_frequency, checkpoint_config, cv_config, test_config) Bases:
cntk.cntk_py.TrainingSession
Instances of this class should be created using the
training_session()
function. A training session trains a model using the specified
trainer
and configs. Different aspects of training, such as data sources, checkpointing, cross validation and progress printing, can be configured using the corresponding config classes. Parameters: - trainer (
Trainer
) – trainer - mb_source (
MinibatchSource
) – minibatch source used for training - mb_size (
minibatch_size_schedule
or int) – minibatch size schedule for training - model_inputs_to_streams (dict) – mapping between input variables and input streams
- max_samples (int) – maximum number of samples used for training
- progress_frequency (int, tuple) – interval, in samples, minibatches or sweeps, at which aggregated progress is printed.
If a tuple of (frequency,
DataUnit
), the frequency is in terms of either DataUnit.sample, DataUnit.minibatch or DataUnit.sweep. See DataUnit
for more information on frequency data unit. - checkpoint_config (
CheckpointConfig
) – checkpoint configuration - cv_config (
CrossValidationConfig
) – cross validation configuration - test_config (
TestConfig
) – test configuration
-
on_cross_validation_end
(index, average_error, num_samples, num_minibatches) Callback that gets executed at the end of cross validation.
Parameters: - index (int) – index of the current callback.
- average_error (float) – average error for the cross validation
- num_samples (int) – number of samples in cross validation
- num_minibatches (int) – number of minibatches in cross validation
Returns: True if training should continue, False otherwise.
-
train
(device=None) Perform training on a specified device.
Parameters: device ( DeviceDescriptor
) – the device descriptor containing the type and id of the device where training takes place.
-
minibatch_size_schedule
(schedule, epoch_size=1) Creates a minibatch size schedule.
Examples
>>> # Use a fixed value 32 for all minibatches
>>> s = minibatch_size_schedule(32)
>>> s[0], s[1]
(32, 32)

>>> # Use minibatches of size 32 for the first 1000 samples, then 64 for the remaining ones
>>> s = minibatch_size_schedule([32, 64], 1000)
>>> s[0], s[1], s[1000], s[1001]
(32, 32, 64, 64)

>>> # Use 32 for the first 12 epochs, then 64 for the next 15,
>>> # followed by 128 for the remaining ones, with 100 samples in an epoch
>>> s = minibatch_size_schedule([(12, 32), (15, 64), (1, 128)], 100)
>>> s[0], s[1199], s[1200], s[2699], s[2700], s[5000]
(32, 32, 64, 64, 128, 128)
Parameters: - schedule (int or list) – if integer, this minibatch size will be used for the whole training.
For a list of integers, each element is used as the value for
epoch_size
samples. If the list contains pairs, the second element of each pair is used as the value for (epoch_size
x first element) samples. - epoch_size (int) – number of samples as a scheduling unit.
Returns: training parameter schedule
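The lookup rule described above can be modelled in a few lines of plain Python. This is only a sketch of the documented semantics, not CNTK's implementation; minibatch_size_at is a hypothetical name, and the assumption that the last entry applies to all remaining samples is taken from the examples above.

```python
def minibatch_size_at(schedule, sample_index, epoch_size=1):
    """Return the minibatch size in effect at `sample_index`.

    `schedule` is an int, a list of ints, or a list of (count, size)
    pairs, mirroring the documented argument forms.
    """
    if isinstance(schedule, int):
        return schedule  # one fixed size for the whole training

    boundary = 0
    size = None
    for entry in schedule:
        # A bare int behaves like the pair (1, value).
        count, size = entry if isinstance(entry, tuple) else (1, entry)
        boundary += count * epoch_size
        if sample_index < boundary:
            return size
    # Past the last boundary the final value keeps applying.
    return size


pairs = [(12, 32), (15, 64), (1, 128)]
sizes = tuple(minibatch_size_at(pairs, i, 100) for i in (0, 1199, 1200, 2699, 2700, 5000))
```

With the third documented example as input, sizes matches the tuple shown in that example.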
-
training_session
(trainer, mb_source, mb_size, model_inputs_to_streams, progress_frequency=None, max_samples=None, checkpoint_config=None, cv_config=None, test_config=None) A factory function to create a training session object.
Parameters: - trainer (
Trainer
) – trainer - mb_source (
MinibatchSource
) – minibatch source used for training - mb_size (
minibatch_size_schedule
) – minibatch schedule for training - model_inputs_to_streams (dict) – mapping between input variables and input streams
- progress_frequency (int, tuple) – frequency in samples for aggregated progress printing.
If a tuple of (frequency,
DataUnit
), the frequency is in terms of either DataUnit.sample, DataUnit.minibatch or DataUnit.sweep. See DataUnit
for more information on frequency data unit. - max_samples (int) – maximum number of samples used for training
- checkpoint_config (
CheckpointConfig
) – checkpoint configuration - cv_config (
CrossValidationConfig
) – cross validation configuration - test_config (
TestConfig
) – test configuration
Returns: Instance of
TrainingSession
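A minimal sketch of how the pieces on this page might be wired together is shown below. The function build_session and its parameter names are hypothetical, the numeric values are arbitrary placeholders, and the trainer, minibatch sources and input mapping are assumed to have been constructed elsewhere. Only the signatures listed on this page are used.

```python
def build_session(trainer, train_source, cv_source, input_map, checkpoint_path):
    """Assemble a training session from already-constructed CNTK objects.

    Hypothetical helper; all numeric values below are placeholders.
    """
    # Imported lazily so this sketch can be loaded without CNTK installed.
    from cntk.train.training_session import (
        CheckpointConfig,
        CrossValidationConfig,
        DataUnit,
        minibatch_size_schedule,
        training_session,
    )

    return training_session(
        trainer=trainer,
        mb_source=train_source,
        mb_size=minibatch_size_schedule(64),
        model_inputs_to_streams=input_map,
        progress_frequency=(1, DataUnit.sweep),   # print once per sweep
        max_samples=100000,
        checkpoint_config=CheckpointConfig(
            checkpoint_path, frequency=10000, restore=True),
        cv_config=CrossValidationConfig(
            minibatch_source=cv_source,
            frequency=(1, DataUnit.sweep),        # validate once per sweep
            minibatch_size=64),
    )
```

Calling build_session(...).train() would then run the loop on the default device, since train takes device=None by default.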