shadow.js 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. function init_shadow() {
  /**
   * Fetch a binary weight file and wrap it as a typed-array record.
   *
   * @param {string} path - URL of the binary file.
   * @param {*} shape - Shape stored alongside the data as-is (not validated here).
   * @param {boolean} is_integer - Decode as Int32Array when truthy, Float32Array otherwise.
   * @returns {Promise} jQuery promise resolving to {d: TypedArray, t: "i"|"f", s: shape}.
   *   Rejected on network errors/aborts, non-200 responses, an empty response, or a
   *   byte length the typed-array view cannot represent.
   */
  function load_weight(path, shape, is_integer) {
    const deferred = new $.Deferred();
    const oReq = new XMLHttpRequest();
    const fail = function (reason) {
      deferred.reject(new Error(`load_weight: ${reason} for ${path}`));
    };
    oReq.open("GET", path, true);
    oReq.responseType = "arraybuffer";
    oReq.onload = function () {
      const arrayBuffer = oReq.response; // Note: not oReq.responseText
      // An HTTP error page still arrives as a non-null ArrayBuffer, so the status
      // must be checked or the error body would be decoded as weights.
      if (oReq.status !== 200 || !arrayBuffer) {
        fail(`HTTP ${oReq.status}`);
        return;
      }
      let data;
      try {
        // Typed-array views throw RangeError when byteLength is not a multiple of 4.
        data = is_integer ? new Int32Array(arrayBuffer) : new Float32Array(arrayBuffer);
      } catch (err) {
        deferred.reject(err);
        return;
      }
      // Resolve outside the try so exceptions from downstream callbacks are not swallowed.
      deferred.resolve({ d: data, t: is_integer ? "i" : "f", s: shape });
    };
    // Network failures and aborts never fire onload; without these handlers the
    // returned promise would stay pending forever.
    oReq.onerror = function () {
      fail("network error");
    };
    oReq.onabort = function () {
      fail("request aborted");
    };
    oReq.send(null);
    return deferred.promise();
  }
  /**
   * Recursively load a JSON weight descriptor and every file it references.
   *
   * Each entry of the parsed descriptor has a "t" field. Entries with t === "n"
   * are nested descriptors loaded recursively from entry.path. Every other entry
   * is a binary file handed to load_weight(entry.path, entry.shape, t === "i").
   *
   * @param {string} path - URL of the descriptor JSON.
   * @returns {Promise} jQuery promise resolving to an array with one element per
   *   descriptor entry (in key order): a weight record, or a nested result array.
   *   Rejected on network/HTTP errors, malformed JSON, or any child failure.
   */
  function load_descriptor(path) {
    const deferred = new $.Deferred();
    const oReq = new XMLHttpRequest();
    oReq.open("GET", path, true);
    oReq.responseType = "text";
    oReq.onload = function () {
      if (oReq.readyState !== 4 || oReq.status !== 200) {
        deferred.reject(new Error(`load_descriptor: HTTP ${oReq.status} for ${path}`));
        return;
      }
      let children;
      try {
        // An exception escaping this handler would leave the promise pending forever.
        const weight_set = JSON.parse(oReq.responseText);
        // Object.keys follows for-in order but skips inherited properties.
        const keys = weight_set == null ? [] : Object.keys(weight_set);
        children = keys.map(function (key) {
          const weight = weight_set[key];
          return weight["t"] === "n"
            ? load_descriptor(weight["path"])
            : load_weight(weight["path"], weight["shape"], weight["t"] === "i");
        });
      } catch (err) {
        deferred.reject(err);
        return;
      }
      // $.when passes one argument per child, in order; collect them into an array.
      $.when.apply($, children).then(function (...results) {
        deferred.resolve(results);
      }).fail(function (err) {
        deferred.reject(err);
      });
    };
    // Network failures and aborts never fire onload.
    oReq.onerror = function () {
      deferred.reject(new Error(`load_descriptor: network error for ${path}`));
    };
    oReq.onabort = function () {
      deferred.reject(new Error(`load_descriptor: request aborted for ${path}`));
    };
    oReq.send(null);
    return deferred.promise();
  }
  // Start loading the model. on_ready is a hoisted function declaration further
  // down, and it is what defines window.classify_contour. If loading fails,
  // classification never becomes available, so report the failure instead of
  // dropping it silently.
  load_descriptor("model/model.json").then(on_ready).fail(function (err) {
    console.error("init_shadow: failed to load model/model.json", err);
  });
  /**
   * Convert a list of weight records into TensorFlow.js tensors.
   *
   * @param {ArrayLike<{d: ArrayLike<number>, s: number[]}>} ary - Records as produced by load_weight.
   * @param {boolean} two_ds - When truthy build tf.tensor2d(d, s); otherwise tf.tensor1d(d)
   *   (the stored shape is not used in the 1-D case).
   * @returns {Array} A new array with one tensor per record, in the same order.
   */
  function transform_to_tensor(ary, two_ds) {
    // Build the result locally: a shared implicit global would clash with other helpers.
    return Array.from(ary, function (record) {
      return two_ds ? tf.tensor2d(record["d"], record["s"]) : tf.tensor1d(record["d"]);
    });
  }
  73. function on_ready(weights) {
  // Final projection used by compute(): elu(a · v + b).
  const v = tf.tensor2d(weights[0].d, weights[0].s);
  const b = tf.tensor1d(weights[1].d);
  // Per-layer parameter lists indexed by layer in compute()'s residue_layer:
  // 2-D matrices W[i] and U[i], and 1-D biases Ub[i] and Wb[i].
  const W = transform_to_tensor(weights[2], true);
  const U = transform_to_tensor(weights[3], true);
  const Ub = transform_to_tensor(weights[4], false);
  const Wb = transform_to_tensor(weights[5], false);
  /**
   * Build the network input row from polar contour statistics.
   *
   * r is divided by its own mean. The scaled radii and t * (scaled radii) are
   * stacked along axis 0 ([2, n]) and then flattened to a single row.
   *
   * @param {ArrayLike<number>} r - Per-point radii (as returned by get_polar_stat).
   * @param {ArrayLike<number>} t - Per-point angle deltas, same length as r.
   * @returns {tf.Tensor} Tensor of shape [1, 2 * r.length]; intermediates are freed by tf.tidy.
   */
  function normalize(r, t) {
    const size = r.length;
    return tf.tidy(() => {
      const radii = tf.tensor2d(r, [1, size]);
      const turns = tf.tensor2d(t, [1, size]);
      // Pass axis/keepDims positionally: JS has no keyword arguments, and
      // `axis=1` inside a call would just assign an implicit global.
      const mean = tf.mean(radii, 1, true);
      const scaled = tf.div(radii, mean);
      const stacked = tf.concat([scaled, tf.mul(turns, scaled)], 0);
      return tf.reshape(stacked, [-1, size * 2]);
    });
  }
  /**
   * Run the residual network on an input row.
   *
   * For each layer i (0 .. W.length - 1):
   *   residue = elu(x · U[i] + Ub[i]) * x               (element-wise gate)
   *   x       = elu(residue · W[i] + Wb[i]) + residue
   * followed by a final elu(x · v + b).
   *
   * @param {tf.Tensor} input - Network input (the output of normalize()).
   * @returns {tf.Tensor} Output tensor; all intermediates are freed by tf.tidy.
   */
  function compute(input) {
    return tf.tidy(() => {
      const residue_layer = (x, i) => {
        const residue = tf.mul(tf.elu(tf.add(tf.matMul(x, U[i]), Ub[i])), x);
        return tf.add(tf.elu(tf.add(tf.matMul(residue, W[i]), Wb[i])), residue);
      };
      // Local accumulator (previously an implicit global).
      let a = input;
      for (let i = 0; i < W.length; ++i) {
        a = residue_layer(a, i);
      }
      return tf.elu(tf.add(tf.matMul(a, v), b));
    });
  }
  /**
   * Pairwise similarity exp(-||f0_i - f1_j||^2) between rows of f0 and rows of f1.
   *
   * f0 is expanded on axis 1 and f1 on axis 0, so the squared differences
   * broadcast over every pair. Summing over axis 2 gives squared distances.
   * Presumably f0 is [N, D] and f1 is [M, D], giving an [N, M] result (confirm).
   */
  function compare(f0, f1) {
    return tf.tidy(() => {
      const lhs = f0.expandDims(1);
      const rhs = f1.expandDims(0);
      const squared_dist = tf.sum(tf.squaredDifference(lhs, rhs), 2);
      const negated = squared_dist.neg();
      return tf.exp(negated);
    });
  }
  /**
   * Reduce per-template similarities to one score per class.
   *
   * For each class key in class_indices (for-in order), gather the listed
   * columns of `response` (axis 1), take their max along axis 1, and stack the
   * per-class results along axis 1. Presumably response is [N, templates],
   * giving [N, classes] (confirm).
   *
   * @param {tf.Tensor} response - Similarity tensor (e.g. from compare()).
   * @param {Object} class_indices - class key -> list of column indices.
   * @returns {tf.Tensor} Stacked per-class maxima; intermediates freed by tf.tidy.
   */
  function get_class_score(response, class_indices) {
    return tf.tidy(() => {
      const columns = [];
      for (const label in class_indices) {
        const member_cols = response.gather(class_indices[label], 1);
        columns.push(tf.max(member_cols, 1));
      }
      return tf.stack(columns, 1);
    });
  }
  /** Euclidean distance between two points, using their first two components. */
  function point_dist(p0, p1) {
    const dx = p0[0] - p1[0];
    const dy = p0[1] - p1[1];
    return Math.sqrt(dx * dx + dy * dy);
  }
  /** Angle in radians (Math.atan2 convention) of the vector pointing from p1 to p0. */
  function point_radian(p0, p1) {
    const dy = p0[1] - p1[1];
    const dx = p0[0] - p1[0];
    return Math.atan2(dy, dx);
  }
  /**
   * Signed angle difference r0 - r1, wrapped into [-π, π].
   *
   * |delta| is reduced modulo 2π. Magnitudes below π keep their sign; larger ones
   * are folded to the other side, so exactly +π maps to -π and -π to +π.
   * Non-finite input yields NaN.
   */
  function radian_diff(r0, r1) {
    const TWO_PI = 2 * Math.PI;
    const delta = r0 - r1;
    const sign = delta < 0 ? -1.0 : 1.0;
    // `%` is exact for finite doubles and, unlike a subtract-2π loop,
    // terminates (with NaN) for an infinite delta.
    const abs_delta = Math.abs(delta) % TWO_PI;
    return abs_delta < TWO_PI - abs_delta ? sign * abs_delta : sign * (abs_delta - TWO_PI);
  }
  /**
   * Describe a closed contour in polar form around its centroid.
   *
   * @param {Array<ArrayLike<number>>} contour - Points read as [x, y]; treated as a
   *   closed loop (the last point is followed by the first).
   * @returns {Array<Float32Array>} [r, t]. r[i] is the distance of point i from the
   *   centroid. t[i] is radian_diff(angle of point i+1, angle of point i), with
   *   both angles measured around the centroid. Both arrays are empty for an
   *   empty contour.
   */
  function get_polar_stat(contour) {
    const count = contour.length;
    const r = new Float32Array(count);
    const t = new Float32Array(count);
    // An empty contour has no centroid; return empty statistics rather than
    // indexing a point that does not exist.
    if (count === 0) {
      return [r, t];
    }
    let sx = 0;
    let sy = 0;
    for (let i = 0; i < count; ++i) {
      sx += contour[i][0];
      sy += contour[i][1];
    }
    const centroid = [sx / count, sy / count];
    for (let i = 0; i < count; ++i) {
      const next = contour[(i + 1) % count];
      r[i] = point_dist(contour[i], centroid);
      t[i] = radian_diff(point_radian(next, centroid), point_radian(contour[i], centroid));
    }
    return [r, t];
  }
  // Templates compared against each network output in compare().
  const templates = tf.tensor2d(weights[6]["d"], weights[6]["s"]);
  // class_lut[j] is the class label for template j (presumably one entry per template row).
  const class_lut = weights[7]["d"];
  // Map each class label to the ascending list of template indices carrying it.
  // This is built in a single pass, instead of rescanning the whole table once
  // per distinct label. Plain index arrays are accepted by Tensor.gather.
  const class_indices = {};
  for (let idx = 0; idx < class_lut.length; ++idx) {
    const label = class_lut[idx];
    if (class_indices[label] == null) {
      class_indices[label] = [];
    }
    class_indices[label].push(idx);
  }
  // Class names (string keys) in the same enumeration order get_class_score uses
  // for its output columns, so names and scores line up.
  const class_list = Object.keys(class_indices);
  /**
   * Return the indices of `ary` ordered by descending value.
   * Ties keep ascending index order (Array.prototype.sort is stable per ES2019).
   *
   * @param {ArrayLike<number>} ary - Scores (plain or typed array).
   * @returns {number[]} A permutation of 0 .. ary.length - 1.
   */
  function sort_indices(ary) {
    // Array.from builds the range without spreading ary.length values as call
    // arguments, which Array.apply does and which can exceed the engine's limit.
    const indices = Array.from({ length: ary.length }, (_, i) => i);
    indices.sort((a, b) => ary[b] - ary[a]);
    return indices;
  }
  /**
   * Permute `ary` by `indices`: element k of the result is ary[indices[k]].
   *
   * @param {ArrayLike} ary - Values to permute (plain or typed array).
   * @param {ArrayLike<number>} indices - Positions into `ary`, e.g. from sort_indices.
   * @returns {Array} A new plain array, even when `ary` is a typed array.
   */
  function re_order(ary, indices) {
    return Array.from(indices, (idx) => ary[idx]);
  }
  /**
   * Classify one contour against the loaded templates and report asynchronously.
   * Installed on `window` so code outside init_shadow can call it.
   *
   * @param {Object} contour_obj - Must provide re_contour(n), which is called with 256
   *   and whose result is read as a list of [x, y] points, and an `id` that is
   *   passed back to the callback.
   * @param {Function} on_inferred_callback - Invoked as (id, class_names, scores),
   *   with both arrays ordered by descending score.
   */
  window.classify_contour = function (contour_obj, on_inferred_callback) {
    const eqi_length = contour_obj.re_contour(256);
    const r_t = get_polar_stat(eqi_length);
    // Only class_scores is needed afterwards. Letting tidy free the intermediate
    // tensors now keeps them from staying alive while data() is pending, and it
    // also frees them if any step throws.
    const class_scores = tf.tidy(() => {
      const input = normalize(r_t[0], r_t[1]);
      const r0 = compute(input);
      const raw = compare(r0, templates);
      return get_class_score(raw, class_indices);
    });
    class_scores.data().then(function (class_scores_cpu) {
      try {
        const indices = sort_indices(class_scores_cpu);
        on_inferred_callback(contour_obj.id, re_order(class_list, indices), re_order(class_scores_cpu, indices));
      } finally {
        // Release the tensor even if the callback throws.
        class_scores.dispose();
      }
    }, function (err) {
      class_scores.dispose();
      throw err;
    });
  };
  199. }
  200. };