test.js 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. import assert from "assert";
  2. import {promises as fs} from "fs";
  3. import path from "path";
  4. import http from "http";
  5. import N from "../index.js";
  6. describe("npyjs parser", function () {
  /**
   * Loads every fixture listed in test/records.json over a throwaway local
   * HTTP server and compares the last (up to five) parsed values with the
   * values recorded for that fixture.
   */
  it("should correctly parse npy files", async function () {
    // Serve files resolved against the current working directory so the
    // loader can fetch fixtures by URL.
    const server = http.createServer(async function (req, res) {
      try {
        const fpath = path.resolve(req.url.slice(1));
        const data = await fs.readFile(fpath);
        res.writeHead(200);
        res.end(data);
      } catch (err) {
        // Always answer: an unanswered request would hang the client and the
        // rejection would otherwise go unhandled.
        res.writeHead(404);
        res.end(String(err.message));
      }
    });

    // server.address() is only guaranteed to be populated once the
    // 'listening' event has fired, so wait for it before reading the port.
    await new Promise((resolve, reject) => {
      server.once("error", reject);
      server.listen(0, resolve);
    });

    try {
      const {port} = server.address();
      const records = JSON.parse(await fs.readFile("test/records.json", "utf8"));
      const n = new N();

      for (const [fname, expectedValues] of Object.entries(records)) {
        const fpath = path.join("test", `${fname}.npy`);
        const data = await n.load(`http://localhost:${port}/${fpath}`);

        // Get the last 5 values for comparison (copied into a plain array so
        // typed-array element types do not leak into the callbacks below).
        const resultValues = Array.from(data.data.slice(-5));

        // Compare with expected values
        resultValues.forEach((actual, j) => {
          const expected = expectedValues[j];

          if (data.dtype.includes('float')) {
            // Use approximate equality for floating point comparisons
            assert.ok(
              Math.abs(actual - expected) < 1e-5,
              `${fname}: Expected ${expected} but got ${actual} at index ${j}`
            );
          } else if (data.dtype.includes('int64') || data.dtype.includes('uint64')) {
            // Compare as strings so 64-bit values can be checked against the
            // JSON-recorded expectation regardless of its JS type.
            assert.strictEqual(
              actual.toString(),
              expected.toString(),
              `${fname}: Expected ${expected} but got ${actual} at index ${j}`
            );
          } else {
            assert.strictEqual(
              actual,
              expected,
              `${fname}: Expected ${expected} but got ${actual} at index ${j}`
            );
          }
        });
      }
    } finally {
      // Release the listening socket even when an assertion or load fails.
      server.close();
    }
  });
  /**
   * Checks the static half-precision decoder against hand-computed values.
   * Each `input` is the raw 16-bit pattern (1 sign bit, 5 exponent bits,
   * 10 mantissa bits); `expected` is the number it encodes.
   */
  it("should correctly convert float16 to float32", function() {
    // Test some known float16 to float32 conversions
    const testCases = [
      { input: 0x0000, expected: 0 },         // Zero
      { input: 0x8000, expected: -0 },        // Negative zero
      { input: 0x3C00, expected: 1 },         // One
      { input: 0xBC00, expected: -1 },        // Negative one
      { input: 0x7C00, expected: Infinity },  // Infinity
      { input: 0xFC00, expected: -Infinity }, // Negative infinity
      { input: 0x7E00, expected: NaN },       // NaN
      { input: 0x3200, expected: 0.1875 },    // 1.5 * 2^-3 (exponent 12 - bias 15)
    ];

    testCases.forEach(({input, expected}) => {
      const result = N.float16ToFloat32(input);

      if (Number.isNaN(expected)) {
        // NaN never compares equal to itself, so it needs its own check.
        assert.ok(Number.isNaN(result), `Expected NaN for input 0x${input.toString(16)}`);
      } else {
        // strictEqual uses Object.is semantics, so 0 and -0 are told apart.
        assert.strictEqual(
          result,
          expected,
          `Failed converting 0x${input.toString(16)}: expected ${expected}, got ${result}`
        );
      }
    });
  });
  /**
   * Loads the same float16 fixture twice over a throwaway local HTTP server:
   * once with a default-constructed parser and once with
   * `convertFloat16: false`, and checks the typed-array class of the result.
   */
  it("should handle float16 data based on conversion flag", async function() {
    // Serve files resolved against the current working directory.
    const server = http.createServer(async function (req, res) {
      try {
        const fpath = path.resolve(req.url.slice(1));
        const data = await fs.readFile(fpath);
        res.writeHead(200);
        res.end(data);
      } catch (err) {
        // Always answer: an unanswered request would hang the client and the
        // rejection would otherwise go unhandled.
        res.writeHead(404);
        res.end(String(err.message));
      }
    });

    // server.address() is only guaranteed to be populated once the
    // 'listening' event has fired, so wait for it before reading the port.
    await new Promise((resolve, reject) => {
      server.once("error", reject);
      server.listen(0, resolve);
    });

    try {
      const {port} = server.address();
      const url = `http://localhost:${port}/test/data/10-float16.npy`;

      // Test with conversion enabled (default)
      const nWithConversion = new N();
      const dataConverted = await nWithConversion.load(url);
      assert.ok(dataConverted.data instanceof Float32Array,
        "With conversion enabled, should return Float32Array");

      // Test with conversion disabled
      const nWithoutConversion = new N({ convertFloat16: false });
      const dataRaw = await nWithoutConversion.load(url);
      assert.ok(dataRaw.data instanceof Uint16Array,
        "With conversion disabled, should return Uint16Array");
    } finally {
      // Release the listening socket even when an assertion or load fails.
      server.close();
    }
  });
  104. });