dump_rnn.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. #!/usr/bin/python
  2. from __future__ import print_function
  3. from keras.models import Sequential
  4. from keras.layers import Dense
  5. from keras.layers import LSTM
  6. from keras.layers import GRU
  7. from keras.models import load_model
  8. from keras import backend as K
  9. import sys
  10. import re
  11. import numpy as np
  12. def printVector(f, ft, vector, name):
  13. v = np.reshape(vector, (-1));
  14. #print('static const float ', name, '[', len(v), '] = \n', file=f)
  15. f.write('static const rnn_weight {}[{}] = {{\n '.format(name, len(v)))
  16. for i in range(0, len(v)):
  17. f.write('{}'.format(min(127, int(round(256*v[i])))))
  18. ft.write('{}'.format(min(127, int(round(256*v[i])))))
  19. if (i!=len(v)-1):
  20. f.write(',')
  21. else:
  22. break;
  23. ft.write(" ")
  24. if (i%8==7):
  25. f.write("\n ")
  26. else:
  27. f.write(" ")
  28. #print(v, file=f)
  29. f.write('\n};\n\n')
  30. ft.write("\n")
  31. return;
  32. def printLayer(f, ft, layer):
  33. weights = layer.get_weights()
  34. activation = re.search('function (.*) at', str(layer.activation)).group(1).upper()
  35. if len(weights) > 2:
  36. ft.write('{} {} '.format(weights[0].shape[0], weights[0].shape[1]/3))
  37. else:
  38. ft.write('{} {} '.format(weights[0].shape[0], weights[0].shape[1]))
  39. if activation == 'SIGMOID':
  40. ft.write('1\n')
  41. elif activation == 'RELU':
  42. ft.write('2\n')
  43. else:
  44. ft.write('0\n')
  45. printVector(f, ft, weights[0], layer.name + '_weights')
  46. if len(weights) > 2:
  47. printVector(f, ft, weights[1], layer.name + '_recurrent_weights')
  48. printVector(f, ft, weights[-1], layer.name + '_bias')
  49. name = layer.name
  50. if len(weights) > 2:
  51. f.write('static const GRULayer {} = {{\n {}_bias,\n {}_weights,\n {}_recurrent_weights,\n {}, {}, ACTIVATION_{}\n}};\n\n'
  52. .format(name, name, name, name, weights[0].shape[0], weights[0].shape[1]/3, activation))
  53. else:
  54. f.write('static const DenseLayer {} = {{\n {}_bias,\n {}_weights,\n {}, {}, ACTIVATION_{}\n}};\n\n'
  55. .format(name, name, name, weights[0].shape[0], weights[0].shape[1], activation))
  56. def structLayer(f, layer):
  57. weights = layer.get_weights()
  58. name = layer.name
  59. if len(weights) > 2:
  60. f.write(' {},\n'.format(weights[0].shape[1]/3))
  61. else:
  62. f.write(' {},\n'.format(weights[0].shape[1]))
  63. f.write(' &{},\n'.format(name))
  64. def foo(c, name):
  65. return None
  66. def mean_squared_sqrt_error(y_true, y_pred):
  67. return K.mean(K.square(K.sqrt(y_pred) - K.sqrt(y_true)), axis=-1)
  68. model = load_model(sys.argv[1], custom_objects={'msse': mean_squared_sqrt_error, 'mean_squared_sqrt_error': mean_squared_sqrt_error, 'my_crossentropy': mean_squared_sqrt_error, 'mycost': mean_squared_sqrt_error, 'WeightClip': foo})
  69. weights = model.get_weights()
  70. f = open(sys.argv[2], 'w')
  71. ft = open(sys.argv[3], 'w')
  72. f.write('/*This file is automatically generated from a Keras model*/\n\n')
  73. f.write('#ifdef HAVE_CONFIG_H\n#include "config.h"\n#endif\n\n#include "rnn.h"\n#include "rnn_data.h"\n\n')
  74. ft.write('rnnoise-nu model file version 1\n')
  75. layer_list = []
  76. for i, layer in enumerate(model.layers):
  77. if len(layer.get_weights()) > 0:
  78. printLayer(f, ft, layer)
  79. if len(layer.get_weights()) > 2:
  80. layer_list.append(layer.name)
  81. f.write('const struct RNNModel model_{} = {{\n'.format(sys.argv[4]))
  82. for i, layer in enumerate(model.layers):
  83. if len(layer.get_weights()) > 0:
  84. structLayer(f, layer)
  85. f.write('};\n')
  86. #hf.write('struct RNNState {\n')
  87. #for i, name in enumerate(layer_list):
  88. # hf.write(' float {}_state[{}_SIZE];\n'.format(name, name.upper()))
  89. #hf.write('};\n')
  90. f.close()