Source code for cntk.misc.converter

# ==============================================================================
# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================
import cntk as C

def convert(root_func, filter, converter):
    '''
    Clones the graph underlying root_func and in the clone substitutes
    all Functions obtained by applying 'filter', with a new Function
    obtained by calling the specified 'converter'.

    The work happens in two passes:

    1. Every block Function found under root_func has its block root
       converted recursively. If the converted root compares unequal to the
       original one, the block is rebuilt around it (same op name and name,
       same argument bindings) and substituted into the graph.
    2. Every Function selected by 'filter' is replaced, one at a time, by
       the result of 'converter'. The graph is rescanned before each
       replacement because the previous replacement produced a new graph.

    If there are no blocks and nothing matches 'filter', root_func itself is
    returned.

    Args:
        root_func: a root function of a graph to be cloned and converted
        filter: a lambda for filtering out the Functions to be converted
        converter: a lambda for obtaining the substitute for each of the
         Functions to be converted

    Returns:
        Cloned and converted Function (graph)

    Raises:
        RuntimeError: if a rescan finds more Functions matching 'filter'
         than the initial scan did, i.e. the conversion keeps creating new
         candidates.
    '''
    def is_block(x):
        # NOTE(review): the exact type comparison is kept from the original;
        # isinstance() would additionally match Function subclasses, so
        # confirm that is wanted before changing it.
        return type(x) == C.Function and x.root_function.is_block

    # recursively convert for blocks in root_func
    blocks = C.logging.graph.depth_first_search(root_func, is_block, depth=0)
    for i in range(len(blocks)):
        # search for blocks again in case block input/output has been modified
        current_blocks = C.logging.graph.depth_first_search(root_func, is_block, depth=0)
        # assuming depth_first_search order to be stable, so use the old
        # index on new search results
        block = current_blocks[i]
        block_root = C.as_composite(block.block_root)
        new_block_root = convert(block_root, filter, converter)
        if new_block_root != block_root:
            # Rebind each argument of the converted block root to whatever
            # the matching argument of the old block root was bound to.
            old_arguments_mapping = dict(block.block_arguments_mapping)
            new_arguments_mapping = [
                (new_arg, old_arguments_mapping[arg])
                for arg, new_arg in zip(block_root.arguments, new_block_root.arguments)
            ]
            new_block = C.as_block(new_block_root, new_arguments_mapping,
                                   block.op_name, block.name)

            block_outputs = block.outputs
            new_block_outputs = new_block.outputs
            root_outputs = root_func.outputs
            replacements = dict(zip(block_outputs, new_block_outputs))

            if (all(x not in root_outputs for x in block_outputs)
                    or all(x in block_outputs for x in root_outputs)):
                # The block's outputs are either all internal to the graph or
                # make up all of its outputs: a single clone does the swap.
                root_func = root_func.clone(C.CloneMethod.share, replacements)
            else:
                # Only some of the graph's outputs come from the block. Put
                # the new block's outputs in the matching positions, leave a
                # None placeholder for every other output, ...
                new_outputs = [
                    new_block_outputs[block_outputs.index(x)] if x in block_outputs else None
                    for x in root_outputs
                ]
                # ... clone the remaining outputs with the block replaced ...
                root_func_nonreplaced = C.combine(
                    [x for x in root_outputs if x not in block_outputs])
                root_func_nonreplaced_clone = root_func_nonreplaced.clone(
                    C.CloneMethod.share, replacements)
                # ... and fill the placeholders with them, in order.
                idx = 0
                for nonreplaced_output in root_func_nonreplaced_clone.outputs:
                    # Explicit None test: a slot is free only if it still
                    # holds the placeholder, whatever the truthiness of the
                    # output objects already stored.
                    while new_outputs[idx] is not None:
                        idx += 1
                    new_outputs[idx] = nonreplaced_output
                root_func = C.combine(new_outputs)

    # replace all Function instances under root_func that pass the specified 'filter'
    functions_to_convert = C.logging.graph.depth_first_search(root_func, filter, depth=0)
    initial_count = len(functions_to_convert)
    for i in range(initial_count):
        # The graph could be modified already by this function, so we need to
        # rescan to the new set.
        remaining = C.logging.graph.depth_first_search(root_func, filter, depth=0)
        # We are using a filter passed in by the caller. So once a function
        # is converted, we may not get the same number of functions again, so
        # we need to use correct index depending on the new size.
        if initial_count > len(remaining):
            # Converted functions no longer match the filter.
            assert initial_count - len(remaining) == i  # Only one conversion at a time.
            # We are picking the first function from the new list.
            index = 0
        elif initial_count == len(remaining):
            # Converted functions still match; pick the current index of the
            # for loop.
            index = i
        else:
            raise RuntimeError("The conversion adds another possible conversion(s). Stopping infinite conversions.")

        function_to_convert = remaining[index]
        converted = converter(function_to_convert)
        old_output = function_to_convert.output

        if old_output not in root_func.outputs:
            root_func = root_func.clone(C.CloneMethod.share,
                                        {old_output: converted.output})
        elif len(root_func.outputs) > 1:
            # The converted function feeds one of several graph outputs:
            # recombine, swapping it in at that position.
            root_func = C.combine([converted if x == old_output else x
                                   for x in root_func.outputs])
        else:
            # if cudnn_rnn output is the root_func output, just use converted
            # as root_func and no clone needed
            root_func = converted

    return root_func