# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================
import tensorflow as tf
import numpy as np
from cntk.contrib import crosstalk as cstk
# Type tags passed to ``register_funcs`` to select a setter/getter pair.
# ``VariableType`` entries are read-only (getter only) and are the only
# entries ``TensorFlowCrosstalk.is_trainable`` reports as non-trainable.
VariableType, TrainableType, DictTrainableType = 'Variable', 'Trainable', 'DictTrainable'
def find_trainable(name, scope=None):
    '''
    Find a single trainable variable by (partial) name when several exist.

    A trainable matches when ``name`` occurs anywhere in its ``.name``
    (substring match), so the search string must be specific enough to
    select exactly one variable.

    Args:
        name (`str`): Substring to look for in the trainable variable names
        scope (`str`): Optional scope filter passed to ``tf.get_collection``
    Returns:
        The single trainable variable whose name contains ``name``
    Raises:
        Exception: if no trainable matches, or if more than one matches
    '''
    found = [tp for tp in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
             if name in tp.name]
    if not found:
        raise Exception('not found: %r in scope %r' % (name, scope))
    if len(found) > 1:
        # List the candidates so the caller can tighten `name` or `scope`.
        raise Exception('more than 1 found: %r in scope %r matched %s'
                        % (name, scope, [tp.name for tp in found]))
    return found[0]
def _trainable_setter(sess):
    '''
    Build a setter that writes a numpy value into a TensorFlow variable.

    Args:
        sess: The tensorflow session in which the assignment is evaluated
    Returns:
        A function ``_set(p, raw_value, attr=None)`` that assigns
        ``raw_value`` to ``p``, first reshaping it to ``p``'s static shape
        when the two shapes differ.
    '''
    def _set(p, raw_value, attr=None):
        target_shape = p.get_shape()
        value = raw_value.reshape(target_shape) if target_shape != raw_value.shape else raw_value
        tf.assign(p, value).eval(session=sess)
    return _set
def _trainable_getter(sess):
    '''
    Build a getter that evaluates a trainable variable in ``sess``.

    Args:
        sess: The tensorflow session used for evaluation
    Returns:
        A function ``_get(p, attr=None)`` returning ``p`` evaluated in ``sess``
    '''
    def _get(p, attr=None):
        # Pass the session by keyword: ``Tensor.eval`` takes ``feed_dict`` as
        # its first positional argument, so ``p.eval(sess)`` would treat the
        # session as a feed dict for anything that is not a plain Variable.
        return p.eval(session=sess)
    return _get
def _dict_trainable_setter(sess):
    '''
    Build a setter that writes a dict of numpy values into a dict of trainables.

    Args:
        sess: The tensorflow session in which assignments are evaluated
    Returns:
        A function ``_set(td, raw_value, attr=None)`` that assigns
        ``raw_value[k]`` to ``td[k]`` for every key ``k``.
    Raises (from the returned function):
        Exception('mismatch len'): if ``td`` and ``raw_value`` differ in size
        Exception('mismatch keys'): if their key sets differ
    '''
    def _set(td, raw_value, attr=None):
        if len(td) != len(raw_value):
            raise Exception('mismatch len')
        # Compare key *sets*: under Python 2 ``dict.keys()`` returns lists, so a
        # direct comparison would spuriously fail on differing iteration order.
        if set(td) != set(raw_value):
            raise Exception('mismatch keys')
        setter = _trainable_setter(sess)
        for key, param in td.items():
            setter(param, raw_value[key])
    return _set
def _dict_trainable_getter(sess):
    '''
    Build a getter that evaluates every trainable in a dict.

    Args:
        sess: The tensorflow session used for evaluation
    Returns:
        A function ``_get(td, attr=None)`` returning a new dict with the same
        keys as ``td``, each mapped to its evaluated value.
    '''
    def _get(td, attr=None):
        values = {}
        for key, param in td.items():
            values[key] = _trainable_getter(sess)(param)
        return values
    return _get
def _variable_getter(sess, data):
    '''
    Build a read-only getter that evaluates a tensor with a fixed feed.

    Args:
        sess: The tensorflow session used for evaluation
        data: The feed dict passed positionally to ``p.eval``
    Returns:
        A function ``(p, attr=None)`` returning ``p.eval(data, session=sess)``
    '''
    return lambda p, attr=None: p.eval(data, session=sess)
def _conv2d_tf_to_cntk_perm(rank):
    '''
    Axis permutation taking a TensorFlow-layout conv weight to CNTK's layout.

    The leading ``rank - 3`` axes are kept in place; the trailing three axes
    ([H, W, C] in TensorFlow order) are reordered to [C, H, W].

    Args:
        rank (`int`): Rank of the weight array
    Returns:
        list of `int` suitable for ``numpy.ndarray.transpose``
    '''
    lead = rank - 3
    return list(range(lead)) + [lead + i for i in (2, 0, 1)]


def _conv2d_getter(sess):
    '''
    Build a getter that reads a conv2d layer's weight and optional bias.

    Args:
        sess: The tensorflow session used for evaluation
    Returns:
        A function ``_get(pd, attr)`` returning ``cstk.Conv2DArgs`` with ``W``
        transposed into CNTK axis order and ``b`` reshaped to
        ``(attr.num_filters,)``, or ``b=None`` when ``pd.b`` is None.
    '''
    def _get(pd, attr):
        read = _trainable_getter(sess)
        W = read(pd.W)
        # Previously a layer without bias crashed on ``None.reshape``.
        if pd.b is not None:
            b = read(pd.b).reshape(attr.num_filters,)
        else:
            b = None
        return cstk.Conv2DArgs(W=W.transpose(_conv2d_tf_to_cntk_perm(len(W.shape))), b=b)
    return _get
def _conv2d_cntk_to_tf_perm(rank):
    '''
    Axis permutation taking a CNTK-layout conv weight back to TensorFlow's layout.

    The leading ``rank - 3`` axes are kept in place; the trailing three axes
    ([C, H, W]) are reordered to [H, W, C]. For rank 3 this is
    ``(1, 2, 0)``.

    Args:
        rank (`int`): Rank of the weight array
    Returns:
        list of `int` suitable for ``numpy.ndarray.transpose``
    '''
    lead = rank - 3
    return list(range(lead)) + [lead + i for i in (1, 2, 0)]


def _conv2d_setter(sess):
    '''
    Build a setter that writes a conv2d layer's weight and optional bias.

    Args:
        sess: The tensorflow session in which assignments are evaluated
    Returns:
        A function ``_set(pd, raw_value, attr)`` that transposes
        ``raw_value.W`` from CNTK axis order back to TensorFlow order, assigns
        it to ``pd.W``, and assigns ``raw_value.b`` to ``pd.b`` when ``pd.b``
        is not None.
    '''
    def _set(pd, raw_value, attr):
        write = _trainable_setter(sess)
        W = raw_value.W
        write(pd.W, W.transpose(_conv2d_cntk_to_tf_perm(len(W.shape))))
        if pd.b is not None:
            write(pd.b, raw_value.b)
    return _set
def _adjust_forget_bias(all_bias, hidden_dim, forget_bias):
    '''
    Return a copy of an LSTM bias vector with ``forget_bias`` added to one gate.

    ``all_bias`` is split into four equal chunks (named i, m, f, o here) and
    ``forget_bias`` is added to the third chunk ``f``. Pass a negative value
    to undo a previous adjustment.

    Args:
        all_bias: 1-D bias values whose length is divisible by 4
        hidden_dim: Unused; kept for signature compatibility
        forget_bias: Amount added to the ``f`` chunk
    Returns:
        A new numpy array; ``all_bias`` itself is left unchanged
    '''
    i, m, f, o = np.split(np.asarray(all_bias), 4)
    # np.split returns views into all_bias; use ``f + forget_bias`` rather than
    # ``f += forget_bias`` so the caller's array is never modified in place
    # (and integer input can be promoted instead of failing the in-place cast).
    return np.concatenate((i, m, f + forget_bias, o))
def _rnn_variable_names(version):
    '''
    Map a TensorFlow version string to its bidirectional-RNN variable naming.

    Args:
        version (`str`): A TensorFlow version string such as ``'1.12.0'``
    Returns:
        ``(weight_name, bias_name, fw_scope, bw_scope)``
    Raises:
        Exception: for versions other than 0.12.* and 1.*
    '''
    import re
    match = re.match(r'(\d+)\.(\d+)', version)
    if match is None:
        raise Exception('only supports 0.12.* and 1.*')
    major, minor = int(match.group(1)), int(match.group(2))
    if (major, minor) == (0, 12):
        return 'Matrix', 'Bias', 'FW', 'BW'
    if major == 1:
        # TF 1.2 renamed RNN cell variables weights/biases -> kernel/bias.
        # Comparing parsed numbers (not string prefixes) keeps 1.10+ from
        # being mistaken for 1.1.
        if minor < 2:
            return 'weights', 'biases', 'fw', 'bw'
        return 'kernel', 'bias', 'fw', 'bw'
    raise Exception('only supports 0.12.* and 1.*')


def _rnn_trainable_in_scope(scope):
    '''
    Look up the forward/backward weight matrices and biases of a bidirectional RNN.

    Args:
        scope (`str`): Scope containing the forward and backward sub-scopes
    Returns:
        ``(fw_M, fw_b, bw_M, bw_b)`` trainable variables
    '''
    weight_name, bias_name, fw, bw = _rnn_variable_names(tf.VERSION)
    fw_M = find_trainable(weight_name, scope=scope + '/' + fw)
    fw_b = find_trainable(bias_name, scope=scope + '/' + fw)
    bw_M = find_trainable(weight_name, scope=scope + '/' + bw)
    bw_b = find_trainable(bias_name, scope=scope + '/' + bw)
    return fw_M, fw_b, bw_M, bw_b
def _rnn_getter(sess):
    '''
    Build a getter that reads a bidirectional RNN's parameters from a scope.

    Args:
        sess: The tensorflow session used for evaluation
    Returns:
        A function ``_get(scope, attr)`` returning ``cstk.RnnArgs``. Each
        direction's weight matrix is split after ``attr.input_dim`` rows into
        ``W`` and ``H``. Each bias has ``attr.forget_bias`` added through
        ``_adjust_forget_bias``. Raises NotImplementedError unless
        ``attr.bidirectional``.
    '''
    def _get(scope, attr):
        if not attr.bidirectional:
            raise NotImplementedError()
        read = _trainable_getter(sess)
        fw_M, fw_b, bw_M, bw_b = _rnn_trainable_in_scope(scope)

        def unpack(matrix, bias):
            W, H = np.split(read(matrix), [attr.input_dim])
            return W, H, _adjust_forget_bias(read(bias), attr.hidden_dim, attr.forget_bias)

        fw_W, fw_H, fw_bias = unpack(fw_M, fw_b)
        bw_W, bw_H, bw_bias = unpack(bw_M, bw_b)
        return cstk.RnnArgs(fw_W=fw_W, fw_H=fw_H, fw_b=fw_bias,
                            bw_W=bw_W, bw_H=bw_H, bw_b=bw_bias)
    return _get
def _rnn_setter(sess):
    '''
    Build a setter that writes a bidirectional RNN's parameters into a scope.

    Args:
        sess: The tensorflow session in which assignments are evaluated
    Returns:
        A function ``_set(scope, raw_value, attr)``. For each direction it
        writes the row-concatenation of ``W`` and ``H`` into the weight
        matrix. It writes the bias with ``attr.forget_bias`` subtracted,
        which undoes the getter's adjustment.
        Raises NotImplementedError unless ``attr.bidirectional``.
    '''
    def _set(scope, raw_value, attr):
        # Check support first (as the getter does) so an unsupported config
        # is reported as such instead of as a variable-lookup failure.
        if not attr.bidirectional:
            raise NotImplementedError()
        write = _trainable_setter(sess)
        fw_M, fw_b, bw_M, bw_b = _rnn_trainable_in_scope(scope)
        write(fw_M, np.concatenate((raw_value.fw_W, raw_value.fw_H)))
        # np.array copies, so the caller's bias arrays are never modified even
        # if the bias adjustment works in place.
        write(fw_b, _adjust_forget_bias(np.array(raw_value.fw_b), attr.hidden_dim, -attr.forget_bias))
        write(bw_M, np.concatenate((raw_value.bw_W, raw_value.bw_H)))
        write(bw_b, _adjust_forget_bias(np.array(raw_value.bw_b), attr.hidden_dim, -attr.forget_bias))
    return _set
def _embed_getter(sess):
    '''
    Build a getter that reads an embedding matrix as a word -> vector dict.

    Args:
        sess: The tensorflow session used for evaluation
    Returns:
        A function ``_get(p, attr)`` returning a dict mapping ``attr.dict[i]``
        to row ``i`` of the evaluated matrix, for ``i`` in
        ``range(attr.input_dim)``.
    '''
    def _get(p, attr):
        table = _trainable_getter(sess)(p)
        return {attr.dict[row]: table[row, :] for row in range(attr.input_dim)}
    return _get
def _embed_setter(sess):
    '''
    Build a setter that writes a word -> vector dict into an embedding matrix.

    Args:
        sess: The tensorflow session in which the assignment is evaluated
    Returns:
        A function ``_set(p, raw_value, attr)`` that places ``raw_value[w]``
        at the row given by the first position of ``w`` in ``attr.dict``. It
        then assigns the resulting ``attr.input_dim``-row matrix to ``p``.
    Raises (from the returned function):
        ValueError: if a word is not in ``attr.dict``, or if some row receives
        no vector
    '''
    def _set(p, raw_value, attr):
        # Build the word -> row lookup once instead of a linear list.index per
        # word. setdefault keeps the *first* occurrence, matching list.index.
        position = {}
        for row, word in enumerate(attr.dict):
            position.setdefault(word, row)
        out = [None] * attr.input_dim
        for word, vector in raw_value.items():
            if word not in position:
                raise ValueError('%r is not in the embedding dictionary' % (word,))
            out[position[word]] = vector
        missing = [row for row, vector in enumerate(out) if vector is None]
        if missing:
            raise ValueError('no embedding provided for rows %s' % (missing,))
        _trainable_setter(sess)(p, np.asarray(out))
    return _set
class TensorFlowCrosstalk(cstk.Crosstalk):
    '''
    TensorFlow implementation for crosstalk.

    ``set_data`` registers TensorFlow-backed setter/getter pairs for
    trainables, dicts of trainables, read-only variables, conv2d, RNN and
    embedding parameters.
    '''
    def __init__(self):
        super(TensorFlowCrosstalk, self).__init__()

    def set_data(self, sess, data):
        '''
        Set session and mapped data for setter/getters

        Args:
            sess : The tensorflow session
            data : The input data feed dict for eval
        '''
        base = super(TensorFlowCrosstalk, self)
        base.register_funcs(TrainableType,
                            setter=_trainable_setter(sess),
                            getter=_trainable_getter(sess))
        base.register_funcs(DictTrainableType,
                            setter=_dict_trainable_setter(sess),
                            getter=_dict_trainable_getter(sess))
        # Plain variables are read-only: only a getter is registered.
        base.register_funcs(VariableType,
                            getter=_variable_getter(sess, data))
        base.register_funcs(cstk.Conv2DAttr,
                            setter=_conv2d_setter(sess),
                            getter=_conv2d_getter(sess))
        base.register_funcs(cstk.RnnAttr,
                            setter=_rnn_setter(sess),
                            getter=_rnn_getter(sess))
        base.register_funcs(cstk.EmbedAttr,
                            setter=_embed_setter(sess),
                            getter=_embed_getter(sess))

    def is_trainable(self, name):
        '''
        Check if variable with name is a trainable

        Args:
            name (`str`): Variable name to check
        Returns:
            True unless the variable was registered with ``VariableType``
        '''
        return self.vars[name].type != VariableType

    def load_all_trainables(self):
        '''
        Load all trainables from files in working directory
        '''
        trainables = list(filter(self.is_trainable, self.vars.keys()))
        self.load(trainables)

    def save_all_trainables(self):
        '''
        Save all trainables to files in working directory
        '''
        trainables = list(filter(self.is_trainable, self.vars.keys()))
        self.save(trainables)
# Shared module-level TensorFlowCrosstalk instance, constructed at import time.
# set_data(sess, data) must be called on it before any setter/getter is used.
instance = TensorFlowCrosstalk()