diff --git a/api/src/main/java/ai/djl/ndarray/NDList.java b/api/src/main/java/ai/djl/ndarray/NDList.java index e48c243a3ec..f0069d3f3f3 100644 --- a/api/src/main/java/ai/djl/ndarray/NDList.java +++ b/api/src/main/java/ai/djl/ndarray/NDList.java @@ -100,12 +100,12 @@ public static NDList decode(NDManager manager, byte[] byteArray) { try { if (byteArray[0] == 'P' && byteArray[1] == 'K') { return decodeNumpy(manager, new ByteArrayInputStream(byteArray)); - } else if (byteArray[0] == (byte) 0x39 + } else if (byteArray[0] == (byte) 0x93 && byteArray[1] == 'N' && byteArray[2] == 'U' && byteArray[3] == 'M') { return new NDList( - NDSerializer.decode(manager, new ByteArrayInputStream(byteArray))); + NDSerializer.decodeNumpy(manager, new ByteArrayInputStream(byteArray))); } else if (byteArray[8] == '{') { return decodeSafetensors(manager, new ByteArrayInputStream(byteArray)); } @@ -144,11 +144,11 @@ public static NDList decode(NDManager manager, InputStream is) { if (magic[0] == 'P' && magic[1] == 'K') { // assume this is npz file return decodeNumpy(manager, pis); - } else if (magic[0] == (byte) 0x39 + } else if (magic[0] == (byte) 0x93 && magic[1] == 'N' && magic[2] == 'U' && magic[3] == 'M') { - return new NDList(NDSerializer.decode(manager, pis)); + return new NDList(NDSerializer.decodeNumpy(manager, pis)); } else if (magic[8] == '{') { return decodeSafetensors(manager, pis); } diff --git a/api/src/test/java/ai/djl/ndarray/NDSerializerTest.java b/api/src/test/java/ai/djl/ndarray/NDSerializerTest.java index 0e38c2d8be6..e89f2244203 100644 --- a/api/src/test/java/ai/djl/ndarray/NDSerializerTest.java +++ b/api/src/test/java/ai/djl/ndarray/NDSerializerTest.java @@ -107,7 +107,7 @@ private static byte[] encode(NDArray array) throws IOException { private static NDArray decode(NDManager manager, byte[] data) throws IOException { try (ByteArrayInputStream bis = new ByteArrayInputStream(data)) { - return NDSerializer.decodeNumpy(manager, bis); + return NDList.decode(manager, bis).get(0); } }