diff --git a/src/NumSharp.Core/APIs/np.load.cs b/src/NumSharp.Core/APIs/np.load.cs index b74addeb..f8790a14 100644 --- a/src/NumSharp.Core/APIs/np.load.cs +++ b/src/NumSharp.Core/APIs/np.load.cs @@ -33,8 +33,7 @@ public static NDArray load(Stream stream) if (!parseReader(reader, out bytes, out type, out shape)) throw new FormatException(); - Array array = Arrays.Create(type, shape.Aggregate((dims, dim) => dims * dim)); - + Array array = Arrays.Create(type, shape.Aggregate(1, (dims, dim) => dims * dim)); var result = new NDArray(readValueMatrix(reader, array, bytes, type, shape)); return result.reshape(shape); } @@ -165,6 +164,10 @@ public static Array LoadMatrix(Stream stream) int[] shape; if (!parseReader(reader, out bytes, out type, out shape)) throw new FormatException(); + + // Read scalar as a single element array + if (shape.Length == 0) + shape = new int[] { 1 }; Array matrix = Arrays.Create(type, shape); @@ -188,6 +191,10 @@ public static Array LoadJagged(Stream stream, bool trim = true) int[] shape; if (!parseReader(reader, out bytes, out type, out shape)) throw new FormatException(); + + // Read scalar as a single element array + if (shape.Length == 0) + shape = new int[] { 1 }; Array matrix = Arrays.Create(type, shape); @@ -357,7 +364,7 @@ private static bool parseReader(BinaryReader reader, out int bytes, out Type t, mark = "'shape': ("; s = header.IndexOf(mark) + mark.Length; - e = header.IndexOf(")", s + 1); + e = header.IndexOf(")", s); shape = header.Substring(s, e - s).Split(',').Where(v => !String.IsNullOrEmpty(v)).Select(Int32.Parse).ToArray(); return true; diff --git a/test/NumSharp.UnitTest/APIs/np.load.Test.cs b/test/NumSharp.UnitTest/APIs/np.load.Test.cs index 190c5803..a0118e7f 100644 --- a/test/NumSharp.UnitTest/APIs/np.load.Test.cs +++ b/test/NumSharp.UnitTest/APIs/np.load.Test.cs @@ -19,16 +19,6 @@ public void NumpyLoadTest() int[] b = np.Load(mem); } - [TestMethod] - public void NumpyLoad1DimTest() - { - int[] arr = np.Load(@"data/1-dim-int32_4_comma_empty.npy"); - Assert.IsTrue(arr[0] == 0); - Assert.IsTrue(arr[1] == 1); - Assert.IsTrue(arr[2] == 2); - Assert.IsTrue(arr[3] == 3); - } - [TestMethod] public void NumpyNPZRoundTripTest() { @@ -42,5 +32,71 @@ public void NumpyNPZRoundTripTest() var d2 = np.Load_Npz(ms); Assert.IsTrue(d2.Count == 2); } + + [TestMethod] + [DataRow(@"data/arange_f4_le.npy")] + [DataRow(@"data/arange_f8_le.npy")] + [DataRow(@"data/arange_i1.npy")] + [DataRow(@"data/arange_i2_le.npy")] + [DataRow(@"data/arange_i4_le.npy")] + [DataRow(@"data/arange_i8_le.npy")] + [DataRow(@"data/arange_u1.npy")] + [DataRow(@"data/arange_u2_le.npy")] + [DataRow(@"data/arange_u4_le.npy")] + [DataRow(@"data/arange_u8_le.npy")] + public void load_Arange(string path) + { + NDArray arr = np.load(path); + + for (int i = 0; i < arr.shape[0]; ++i) + { + int value = (int)Convert.ChangeType(arr.GetValue(i), typeof(int)); + Assert.AreEqual(i, value); + } + } + + [TestMethod] + [DataRow(@"data/hello_S5.npy")] + public void load_HelloWorld(string path) + { + string[] arr = np.Load(path); + Assert.AreEqual("Hello", arr[0]); + Assert.AreEqual("World", arr[1]); + } + + [TestMethod] + [DataRow(@"data/mgrid_i4.npy")] + public void load_Mgrid(string path) + { + NDArray arr = np.load(path); + + for (int i = 0; i < arr.shape[0]; i++) + { + for (int j = 0; j < arr.shape[1]; j++) + { + Assert.AreEqual(i, (int)arr.GetValue(0, i, j)); + Assert.AreEqual(j, (int)arr.GetValue(1, i, j)); + } + } + } + + [TestMethod] + [DataRow(@"data/scalar_b1.npy", false)] + [DataRow(@"data/scalar_i4_le.npy", 42)] + public void load_Scalar(string path, object expected) + { + NDArray arr = np.load(path); + Assert.AreEqual(Shape.Scalar, arr.shape); + Assert.AreEqual(expected, arr.GetValue(0)); + } + + [TestMethod] + [DataRow(@"data/scalar_b1.npy", false)] + [DataRow(@"data/scalar_i4_le.npy", 42)] + public void LoadMatrix_Scalar(string path, object expected) + { + Array arr = np.LoadMatrix(path); + Assert.AreEqual(expected, arr.GetValue(0)); + } } } diff --git a/test/NumSharp.UnitTest/data/arange_f4_le.npy b/test/NumSharp.UnitTest/data/arange_f4_le.npy new file mode 100644 index 00000000..86f46753 Binary files /dev/null and b/test/NumSharp.UnitTest/data/arange_f4_le.npy differ diff --git a/test/NumSharp.UnitTest/data/arange_f8_le.npy b/test/NumSharp.UnitTest/data/arange_f8_le.npy new file mode 100644 index 00000000..724acbe4 Binary files /dev/null and b/test/NumSharp.UnitTest/data/arange_f8_le.npy differ diff --git a/test/NumSharp.UnitTest/data/arange_i1.npy b/test/NumSharp.UnitTest/data/arange_i1.npy new file mode 100644 index 00000000..a61d940c Binary files /dev/null and b/test/NumSharp.UnitTest/data/arange_i1.npy differ diff --git a/test/NumSharp.UnitTest/data/arange_i2_le.npy b/test/NumSharp.UnitTest/data/arange_i2_le.npy new file mode 100644 index 00000000..837037d4 Binary files /dev/null and b/test/NumSharp.UnitTest/data/arange_i2_le.npy differ diff --git a/test/NumSharp.UnitTest/data/arange_i4_le.npy b/test/NumSharp.UnitTest/data/arange_i4_le.npy new file mode 100644 index 00000000..a95c6a63 Binary files /dev/null and b/test/NumSharp.UnitTest/data/arange_i4_le.npy differ diff --git a/test/NumSharp.UnitTest/data/arange_i8_le.npy b/test/NumSharp.UnitTest/data/arange_i8_le.npy new file mode 100644 index 00000000..a3ff5af9 Binary files /dev/null and b/test/NumSharp.UnitTest/data/arange_i8_le.npy differ diff --git a/test/NumSharp.UnitTest/data/arange_u1.npy b/test/NumSharp.UnitTest/data/arange_u1.npy new file mode 100644 index 00000000..69a32813 Binary files /dev/null and b/test/NumSharp.UnitTest/data/arange_u1.npy differ diff --git a/test/NumSharp.UnitTest/data/arange_u2_le.npy b/test/NumSharp.UnitTest/data/arange_u2_le.npy new file mode 100644 index 00000000..4d16a3e3 Binary files /dev/null and b/test/NumSharp.UnitTest/data/arange_u2_le.npy differ diff --git a/test/NumSharp.UnitTest/data/arange_u4_le.npy b/test/NumSharp.UnitTest/data/arange_u4_le.npy new file mode 100644 index 00000000..55a3bee4 Binary files /dev/null and b/test/NumSharp.UnitTest/data/arange_u4_le.npy differ diff --git a/test/NumSharp.UnitTest/data/arange_u8_le.npy b/test/NumSharp.UnitTest/data/arange_u8_le.npy new file mode 100644 index 00000000..b41bc956 Binary files /dev/null and b/test/NumSharp.UnitTest/data/arange_u8_le.npy differ diff --git a/test/NumSharp.UnitTest/data/hello_S5.npy b/test/NumSharp.UnitTest/data/hello_S5.npy new file mode 100644 index 00000000..b5032f02 Binary files /dev/null and b/test/NumSharp.UnitTest/data/hello_S5.npy differ diff --git a/test/NumSharp.UnitTest/data/mgrid_i4.npy b/test/NumSharp.UnitTest/data/mgrid_i4.npy new file mode 100644 index 00000000..81ff6138 Binary files /dev/null and b/test/NumSharp.UnitTest/data/mgrid_i4.npy differ diff --git a/test/NumSharp.UnitTest/data/scalar_S1.npy b/test/NumSharp.UnitTest/data/scalar_S1.npy new file mode 100644 index 00000000..2c4ad0b5 Binary files /dev/null and b/test/NumSharp.UnitTest/data/scalar_S1.npy differ diff --git a/test/NumSharp.UnitTest/data/scalar_b1.npy b/test/NumSharp.UnitTest/data/scalar_b1.npy new file mode 100644 index 00000000..185568c8 Binary files /dev/null and b/test/NumSharp.UnitTest/data/scalar_b1.npy differ diff --git a/test/NumSharp.UnitTest/data/scalar_i4_le.npy b/test/NumSharp.UnitTest/data/scalar_i4_le.npy new file mode 100644 index 00000000..4faa14c5 Binary files /dev/null and b/test/NumSharp.UnitTest/data/scalar_i4_le.npy differ