Merge batch norm with convolution
To the best of my knowledge, there is no built-in feature in TensorFlow for folding batch normalization. That said, it's not that hard to do it manually. One note: there is no such thing as folding dropout, since dropout is simply deactivated at inference time.
Folding batch normalization basically takes three steps:
- Given a TensorFlow graph, filter the variables that need folding,
- Fold the variables,
- Create a new graph with the folded variables.
First, we need to filter the variables that require folding. Batch normalization creates variables whose names contain moving_mean and moving_variance, which makes it fairly easy to extract the variables of the layers that used batch norm.
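For instance, a minimal sketch of that filtering step (assuming, as in the full snippet below, that the batch norm variables live directly under each layer's scope):

```python
import tensorflow as tf

# Collect the name scope of every layer that owns a `moving_variance`
# variable, i.e. every layer that used batch norm.
bn_layers = {v.name.rsplit('/', 1)[0]
             for v in tf.global_variables()
             if v.name.endswith('moving_variance:0')}
```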
Now that you know which layers used batch norm, you can, for every such layer, extract its weights W, bias b, batch norm moving variance var, moving mean mean, and the gamma and beta parameters. You then need to create new variables to store the folded weights and biases, as follows:
std = sqrt(var + eps)
W_new = gamma * W / std
b_new = gamma * (b - mean) / std + beta

Note that you divide by the standard deviation sqrt(var + eps), not by the raw variance, where eps is the epsilon used by the batch norm layer (1e-3 by default in TensorFlow).
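If you want to convince yourself that the fold is exact, here is a quick NumPy check on a dense layer (all names and shapes here are made up for the example):

```python
import numpy as np

rng = np.random.RandomState(0)
x = rng.randn(5, 3)                       # batch of 5 inputs with 3 features
W, b = rng.randn(3, 4), rng.randn(4)      # layer weights and bias
gamma, beta = rng.randn(4), rng.randn(4)  # batch norm scale and shift
mean, var = rng.randn(4), rng.rand(4) + 0.5
eps = 1e-3

std = np.sqrt(var + eps)
bn_out = gamma * (x @ W + b - mean) / std + beta  # layer followed by batch norm
W_new = gamma * W / std                           # folded weights
b_new = gamma * (b - mean) / std + beta           # folded bias
assert np.allclose(bn_out, x @ W_new + b_new)     # identical outputs
```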
The last step consists of creating a new graph in which batch norm is deactivated and bias variables are added where necessary, which should be the case for every foldable layer, since using a bias together with batch norm is pointless.
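For illustration, a block that originally read conv -> batch norm -> relu would be rebuilt roughly like this (the scope and layer names are hypothetical; note that tf.layers names its variables kernel and bias rather than weights):

```python
import tensorflow as tf

def folded_block(x):
    with tf.variable_scope('model'):
        # Originally: conv2d(use_bias=False) -> batch_normalization -> relu.
        # After folding: conv2d WITH a bias -> relu, and no batch norm.
        x = tf.layers.conv2d(x, filters=64, kernel_size=3, padding='same',
                             use_bias=True, name='conv1')
    return tf.nn.relu(x)
```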
The whole code should look something like the snippet below. Depending on the parameters you used for the batch norm layers (e.g. center=False or scale=False), your graph may not have gamma or beta.
```python
import numpy as np
import tensorflow as tf

# ****** (1) Get variables ******
# `session` is the session holding the original, trained model.
variables = {v.name: session.run(v) for v in tf.global_variables()}

# ****** (2) Fold variables ******
folded_variables = {}
eps = 1e-3  # the epsilon used by your batch norm layers

for name in variables:
    if not name.endswith('moving_variance:0'):
        continue
    n = name.rsplit('/', 1)[0]         # 'model/conv1/moving_variance:0' --> 'model/conv1'
    W = variables[n + '/weights:0']    # or '/kernel:0', etc.
    b = variables[n + '/bias:0']       # if a bias existed before
    gamma = variables[n + '/gamma:0']
    beta = variables[n + '/beta:0']
    mean = variables[n + '/moving_mean:0']
    var = variables[n + '/moving_variance:0']

    # Fold the batch norm into the weights and bias.
    std = np.sqrt(var + eps)
    W_new = gamma * W / std
    b_new = gamma * (b - mean) / std + beta  # drop `b` if there was no bias
    folded_variables[n + '/weights:0'] = W_new
    folded_variables[n + '/bias:0'] = b_new

# ****** (3) Create new graph ******
new_graph = tf.Graph()
new_session = tf.Session(graph=new_graph)
with new_graph.as_default():
    network = ...  # instantiate the batch-norm-free graph, with biases added;
                   # careful, the names should match the original model
    for v in tf.global_variables():
        try:
            new_session.run(v.assign(folded_variables[v.name]))
        except KeyError:  # not a folded variable, copy it unchanged
            new_session.run(v.assign(variables[v.name]))
```
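As a final sanity check, run the same input batch through the original and the folded sessions and verify that the outputs match up to numerical precision; a mismatch usually points to a variable-name mismatch between the two graphs or to a wrong epsilon in the fold.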