From fe0a4d3a43289ca0381e5484cc5ff992b729bb00 Mon Sep 17 00:00:00 2001 From: Tal Date: Wed, 8 Nov 2023 14:50:58 +0200 Subject: [PATCH 1/5] Added element-wise gauss error function (ERF) --- api/src/main/java/ai/djl/ndarray/NDArray.java | 16 +++++++++ .../java/ai/djl/ndarray/NDArrayAdapter.java | 6 ++++ .../main/java/ai/djl/ndarray/NDArrays.java | 18 ++++++++++ .../java/ai/djl/mxnet/engine/MxNDArray.java | 6 ++++ .../java/ai/djl/pytorch/engine/PtNDArray.java | 6 ++++ .../java/ai/djl/pytorch/jni/JniUtils.java | 5 +++ .../ai/djl/pytorch/jni/PyTorchLibrary.java | 2 ++ ...orch_jni_PyTorchLibrary_torch_pointwise.cc | 8 +++++ .../ai/djl/tensorflow/engine/TfNDArray.java | 6 ++++ .../tests/ndarray/NDArrayOtherOpTest.java | 34 +++++++++++++++++++ 10 files changed, 107 insertions(+) diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index 385c32e88e3..036a87d1a4a 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -4922,6 +4922,22 @@ default NDArray countNonzero(int axis) { */ NDArray erfinv(); + /** + * Returns element-wise gauss error function of the {@code NDArray}. + * + *

Examples + * + *

+     * jshell> NDArray array = manager.create(new float[] {0f, 0.4769f, Float.NEGATIVE_INFINITY});
+     * jshell> array.erf();
+     * ND: (3) cpu() float32
+     * [0., 0.5, -1]
+     * 
+ * + * @return The of gauss error of the {@code NDArray}, element-wise + */ + NDArray erf(); + /** {@inheritDoc} */ @Override default List getResourceNDArrays() { diff --git a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java index 59047e688c8..e80245c8ac0 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java +++ b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java @@ -1188,6 +1188,12 @@ public NDArray erfinv() { return getAlternativeArray().erfinv(); } + /** {@inheritDoc} */ + @Override + public NDArray erf() { + return getAlternativeArray().erf(); + } + /** {@inheritDoc} */ @Override public NDArray inverse() { diff --git a/api/src/main/java/ai/djl/ndarray/NDArrays.java b/api/src/main/java/ai/djl/ndarray/NDArrays.java index 304b803939c..dfdb21c7509 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArrays.java +++ b/api/src/main/java/ai/djl/ndarray/NDArrays.java @@ -1996,4 +1996,22 @@ public static NDArray logicalXor(NDArray a, NDArray b) { public static NDArray erfinv(NDArray input) { return input.erfinv(); } + + /** + * Returns element-wise gauss error function of the {@code NDArray}. + * + *

Examples + * + *

+     * jshell> NDArray array = manager.create(new float[] {0f, 0.4769f, Float.NEGATIVE_INFINITY});
+     * jshell> array.erf();
+     * ND: (3) cpu() float32
+     * [0., 0.5, -1]
+     * 
+ * + * @return The of gauss error of the {@code NDArray}, element-wise + */ + public static NDArray erf(NDArray input) { + return input.erf(); + } } diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java index 87ccba78e96..dce2fd11075 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java @@ -1601,6 +1601,12 @@ public NDArray erfinv() { return manager.invoke("erfinv", this, null); } + /** {@inheritDoc} */ + @Override + public NDArray erf() { + return manager.invoke("erf", this, null); + } + /** {@inheritDoc} */ @Override public NDArray norm(boolean keepDims) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index 9e36ec35884..ed459ef562e 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -1539,6 +1539,12 @@ public PtNDArray erfinv() { return JniUtils.erfinv(this); } + /** {@inheritDoc} */ + @Override + public PtNDArray erf() { + return JniUtils.erf(this); + } + /** {@inheritDoc} */ @Override public PtNDArray inverse() { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index aad38ae8f0c..beb752f64a1 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -1334,6 +1334,11 @@ public static PtNDArray erfinv(PtNDArray ndArray) { ndArray.getManager(), PyTorchLibrary.LIB.torchErfinv(ndArray.getHandle())); } + public static PtNDArray erf(PtNDArray ndArray) { + return new PtNDArray( + ndArray.getManager(), PyTorchLibrary.LIB.torchErf(ndArray.getHandle())); + } + public static PtNDArray inverse(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchInverse(ndArray.getHandle())); diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java index c0f7b553ab2..13179ae3ef6 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java @@ -405,6 +405,8 @@ native long tensorUniform( native long torchErfinv(long handle); + native long torchErf(long handle); + native long torchInverse(long self); native long torchNNInterpolate(long handle, long[] size, int mode, boolean alignCorners); diff --git a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc index 28e40e916be..740e5309995 100644 --- a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc +++ b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc @@ -496,6 +496,14 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchErfinv(JNIEn API_END_RETURN() } +JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchErf(JNIEnv* env, jobject jthis, jlong jhandle) { + API_BEGIN() + const auto* tensor_ptr = reinterpret_cast(jhandle); + const auto* result_ptr = new torch::Tensor(tensor_ptr->erf()); + return reinterpret_cast(result_ptr); + API_END_RETURN() +} + JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchInverse(JNIEnv* env, jobject jthis, jlong jself) { API_BEGIN() const auto* self_ptr = reinterpret_cast(jself); diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java index 07c31bacd99..24c0244f1f4 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java @@ -457,6 +457,12 @@ public NDArray erfinv() { return manager.opExecutor("Erfinv").addInput(this).buildSingletonOrThrow(); } + /** {@inheritDoc} */ + @Override + public NDArray erf() { + return manager.opExecutor("Erf").addInput(this).buildSingletonOrThrow(); + } + /** {@inheritDoc} */ @Override public NDArray norm(boolean keepDims) { diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayOtherOpTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayOtherOpTest.java index 6788a405f22..b87135717de 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayOtherOpTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayOtherOpTest.java @@ -875,6 +875,40 @@ public void testErfinv() { } } + @Test + public void testErf() { + try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) { + // test 1-D + NDArray array = manager.create(new float[] {0f, 0.4769f, Float.NEGATIVE_INFINITY}); + NDArray expected = manager.create(new float[] {0f, 0.5f, -1f}); + Assertions.assertAlmostEquals(NDArrays.erf(array), expected); + // test 3-D + array = + manager.create( + new float[] { + Float.NEGATIVE_INFINITY, + -0.8134f, + -0.4769f, + -0.2253f, + 0f, + 0.2253f, + 0.4769f, + 0.8134f, + Float.POSITIVE_INFINITY + }) + .reshape(3, 1, 3); + expected = manager.linspace(-1.0f, 1.0f, 9).reshape(3, 1, 3); + Assertions.assertAlmostEquals(array.erf(), expected); + // test scalar + array = manager.create(Float.POSITIVE_INFINITY); + expected = manager.create(1f); + Assertions.assertAlmostEquals(array.erf(), expected); + // test zero-dim + array = manager.create(new Shape(2, 0)); + Assertions.assertAlmostEquals(array.erf(), array); + } + } + @Test public void testInverse() { try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) { From a1f34160bfa8bf081158388b31bea34f3ace4abd Mon Sep 17 00:00:00 2001 From: Tal Date: Wed, 8 Nov 2023 16:19:30 +0200 Subject: [PATCH 2/5] Added element-wise arctan2 --- api/src/main/java/ai/djl/ndarray/NDArray.java | 18 +++++++++++ .../java/ai/djl/ndarray/NDArrayAdapter.java | 6 ++++ .../main/java/ai/djl/ndarray/NDArrays.java | 1 + .../java/ai/djl/mxnet/engine/MxNDArray.java | 7 +++++ .../java/ai/djl/pytorch/engine/PtNDArray.java | 6 ++++ .../java/ai/djl/pytorch/jni/JniUtils.java | 5 +++ .../ai/djl/pytorch/jni/PyTorchLibrary.java | 2 ++ ...orch_jni_PyTorchLibrary_torch_pointwise.cc | 10 ++++++ .../ai/djl/tensorflow/engine/TfNDArray.java | 6 ++++ .../tests/ndarray/NDArrayNumericOpTest.java | 31 +++++++++++++++++++ 10 files changed, 92 insertions(+) diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index 036a87d1a4a..4decee2d4b2 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -2344,6 +2344,24 @@ default boolean allClose(NDArray other, double rtol, double atol, boolean equalN */ NDArray atan(); + /** + * Returns the element-wise arc-tangent of this/other choosing the quadrant correctly. + * + *

Examples + * + *

+     * jshell> NDArray x = manager.create(new float[] {0f, 1f});
+     * jshell> NDArray y = manager.create(new float[] {0f, -6f});
+     * jshell> x.atan2(y);
+     * ND: (2) cpu() float64
+     * [0.    , 2.9764]
+     * 
+ * + * @param other The other {@code NDArray} + * @return the result {@code NDArray} + */ + NDArray atan2(NDArray other); + /** * Returns the hyperbolic sine of this {@code NDArray} element-wise. * diff --git a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java index e80245c8ac0..855c7183003 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java +++ b/api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java @@ -726,6 +726,12 @@ public NDArray atan() { return getAlternativeArray().atan(); } + /** {@inheritDoc} */ + @Override + public NDArray atan2(NDArray other) { + return getAlternativeArray().atan2(other); + } + /** {@inheritDoc} */ @Override public NDArray sinh() { diff --git a/api/src/main/java/ai/djl/ndarray/NDArrays.java b/api/src/main/java/ai/djl/ndarray/NDArrays.java index dfdb21c7509..dcc98ee4154 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArrays.java +++ b/api/src/main/java/ai/djl/ndarray/NDArrays.java @@ -2009,6 +2009,7 @@ public static NDArray erfinv(NDArray input) { * [0., 0.5, -1] * * + * @param input The input {@code NDArray} * @return The of gauss error of the {@code NDArray}, element-wise */ public static NDArray erf(NDArray input) { diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java index dce2fd11075..19eec837259 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArray.java @@ -888,6 +888,13 @@ public NDArray atan() { return manager.invoke("_npi_arctan", this, null); } + /** {@inheritDoc} */ + @Override + public NDArray atan2(NDArray other) { + other = manager.from(other); + return manager.invoke("_npi_arctan2", new NDArray[] {this, other}, null); + } + /** {@inheritDoc} */ @Override public NDArray sinh() { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java index ed459ef562e..1e9ac83c173 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArray.java @@ -888,6 +888,12 @@ public PtNDArray atan() { return JniUtils.atan(this); } + /** {@inheritDoc} */ + @Override + public PtNDArray atan2(NDArray other) { + return JniUtils.atan2(this, manager.from(other)); + } + /** {@inheritDoc} */ @Override public PtNDArray sinh() { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index beb752f64a1..5ff5851019c 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -1145,6 +1145,11 @@ public static PtNDArray atan(PtNDArray ndArray) { ndArray.getManager(), PyTorchLibrary.LIB.torchAtan(ndArray.getHandle())); } + public static PtNDArray atan2(PtNDArray self, PtNDArray other) { + return new PtNDArray( + self.getManager(), PyTorchLibrary.LIB.torchAtan2(self.getHandle(), other.getHandle())); + } + public static PtNDArray sqrt(PtNDArray ndArray) { return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchSqrt(ndArray.getHandle())); diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java index 13179ae3ef6..a1829306d20 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java @@ -332,6 +332,8 @@ native long[] torchUnique( native long torchAtan(long handle); + native long torchAtan2(long self, long other); + native long torchSqrt(long handle); native long torchSinh(long handle); diff --git a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc index 740e5309995..37bc2028b98 100644 --- a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc +++ b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc @@ -355,6 +355,16 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchAtan(JNIEnv* API_END_RETURN() } +JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchAtan2( +JNIEnv* env, jobject jthis, jlong jself, jlong jother) { + API_BEGIN() + const auto* self_ptr = reinterpret_cast(jself); + const auto* other_ptr = reinterpret_cast(jother); + const auto* result_ptr = new torch::Tensor(self_ptr->atan2(other_ptr)); + return reinterpret_cast(result_ptr); + API_END_RETURN() +} + JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchSqrt(JNIEnv* env, jobject jthis, jlong jhandle) { API_BEGIN() const auto* tensor_ptr = reinterpret_cast(jhandle); diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java index 24c0244f1f4..44488537765 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArray.java @@ -917,6 +917,12 @@ public NDArray atan() { return manager.opExecutor("Atan").addInput(this).buildSingletonOrThrow(); } + /** {@inheritDoc} */ + @Override + public NDArray atan2(NDArray other) { + return manager.opExecutor("Atan2").addInput(this).addInput(other).buildSingletonOrThrow(); + } + /** {@inheritDoc} */ @Override public NDArray sinh() { diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayNumericOpTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayNumericOpTest.java index 410b4009a6d..f9f65a50488 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayNumericOpTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayNumericOpTest.java @@ -22,6 +22,7 @@ import org.testng.annotations.Test; import java.util.stream.DoubleStream; +import java.util.stream.IntStream; public class NDArrayNumericOpTest { @@ -499,6 +500,36 @@ public void testAtan() { } } + @Test + public void testAtan2() { + try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) { + double[] x1 = {1.0, -1.0, -1.0, 0.0, 0.0, 0.0}; + NDArray array = manager.create(x1); + double[] y1 = {1.0, 0.0, -1.0, 1.0, -1.0, 0.0}; + NDArray other = manager.create(y1); + double[] output = IntStream.range(0, x1.length).mapToDouble(i->Math.atan2(x1[i], y1[i])).toArray(); + NDArray expected = manager.create(output); + Assertions.assertAlmostEquals(array.atan2(other), expected); + // test multi-dim + double[] x2 = {-1.0, -0.5, 0, 0.5, 1.0}; + array = manager.create(x2, new Shape(5, 1)); + double[] y2 = {-2.0, 3.0, 6.0, 0.0, -0.3}; + other = manager.create(y2, new Shape(5, 1)); + output = IntStream.range(0, x2.length).mapToDouble(i->Math.atan2(x2[i], y2[i])).toArray(); + expected = manager.create(output, new Shape(5, 1)); + Assertions.assertAlmostEquals(array.atan2(other), expected); + // test scalar + array = manager.create(0f); + other = manager.create(0f); + expected = manager.create(0f); + Assertions.assertAlmostEquals(array.atan2(other), expected); + // test zero-dim + array = manager.create(new Shape(1, 0)); + other = manager.create(new Shape(1, 0)); + Assert.assertEquals(array.atan2(other), array); + } + } + @Test public void testToDegrees() { try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) { From 1c16ce88a1ae0566fb1d5d81fb6d6fd370d12b43 Mon Sep 17 00:00:00 2001 From: Tal Date: Wed, 8 Nov 2023 17:17:24 +0200 Subject: [PATCH 3/5] Format java --- .../main/java/ai/djl/pytorch/jni/JniUtils.java | 3 ++- .../tests/ndarray/NDArrayNumericOpTest.java | 10 ++++++++-- .../tests/ndarray/NDArrayOtherOpTest.java | 18 +++++++++--------- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index 5ff5851019c..8e6be7b8d15 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -1147,7 +1147,8 @@ public static PtNDArray atan(PtNDArray ndArray) { public static PtNDArray atan2(PtNDArray self, PtNDArray other) { return new PtNDArray( - self.getManager(), PyTorchLibrary.LIB.torchAtan2(self.getHandle(), other.getHandle())); + self.getManager(), + PyTorchLibrary.LIB.torchAtan2(self.getHandle(), other.getHandle())); } public static PtNDArray sqrt(PtNDArray ndArray) { diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayNumericOpTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayNumericOpTest.java index f9f65a50488..04779187267 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayNumericOpTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayNumericOpTest.java @@ -507,7 +507,10 @@ public void testAtan2() { NDArray array = manager.create(x1); double[] y1 = {1.0, 0.0, -1.0, 1.0, -1.0, 0.0}; NDArray other = manager.create(y1); - double[] output = IntStream.range(0, x1.length).mapToDouble(i->Math.atan2(x1[i], y1[i])).toArray(); + double[] output = + IntStream.range(0, x1.length) + .mapToDouble(i -> Math.atan2(x1[i], y1[i])) + .toArray(); NDArray expected = manager.create(output); Assertions.assertAlmostEquals(array.atan2(other), expected); // test multi-dim @@ -515,7 +518,10 @@ public void testAtan2() { array = manager.create(x2, new Shape(5, 1)); double[] y2 = {-2.0, 3.0, 6.0, 0.0, -0.3}; other = manager.create(y2, new Shape(5, 1)); - output = IntStream.range(0, x2.length).mapToDouble(i->Math.atan2(x2[i], y2[i])).toArray(); + output = + IntStream.range(0, x2.length) + .mapToDouble(i -> Math.atan2(x2[i], y2[i])) + .toArray(); expected = manager.create(output, new Shape(5, 1)); Assertions.assertAlmostEquals(array.atan2(other), expected); // test scalar diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayOtherOpTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayOtherOpTest.java index b87135717de..00e7465f745 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayOtherOpTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDArrayOtherOpTest.java @@ -886,15 +886,15 @@ public void testErf() { array = manager.create( new float[] { - Float.NEGATIVE_INFINITY, - -0.8134f, - -0.4769f, - -0.2253f, - 0f, - 0.2253f, - 0.4769f, - 0.8134f, - Float.POSITIVE_INFINITY + Float.NEGATIVE_INFINITY, + -0.8134f, + -0.4769f, + -0.2253f, + 0f, + 0.2253f, + 0.4769f, + 0.8134f, + Float.POSITIVE_INFINITY }) .reshape(3, 1, 3); expected = manager.linspace(-1.0f, 1.0f, 9).reshape(3, 1, 3); From ff4d3b0c9230e19a94533964f3de5060ce3e6617 Mon Sep 17 00:00:00 2001 From: Tal Date: Wed, 8 Nov 2023 17:22:38 +0200 Subject: [PATCH 4/5] Fixed docs --- api/src/main/java/ai/djl/ndarray/NDArray.java | 2 +- api/src/main/java/ai/djl/ndarray/NDArrays.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index 4decee2d4b2..23d452be367 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -4952,7 +4952,7 @@ default NDArray countNonzero(int axis) { * [0., 0.5, -1] * * - * @return The of gauss error of the {@code NDArray}, element-wise + * @return The gauss error of the {@code NDArray}, element-wise */ NDArray erf(); diff --git a/api/src/main/java/ai/djl/ndarray/NDArrays.java b/api/src/main/java/ai/djl/ndarray/NDArrays.java index dcc98ee4154..0e1c0922a7b 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArrays.java +++ b/api/src/main/java/ai/djl/ndarray/NDArrays.java @@ -2010,7 +2010,7 @@ public static NDArray erfinv(NDArray input) { * * * @param input The input {@code NDArray} - * @return The of gauss error of the {@code NDArray}, element-wise + * @return The gauss error of the {@code NDArray}, element-wise */ public static NDArray erf(NDArray input) { return input.erf(); From 7bf4063b1b08b5e037acb8767f5eecd25c01d167 Mon Sep 17 00:00:00 2001 From: Tal Date: Wed, 8 Nov 2023 23:57:20 +0200 Subject: [PATCH 5/5] added * to other_ptr in Atan2 --- .../native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc index 37bc2028b98..ccf2616dc65 100644 --- a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc +++ b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_torch_pointwise.cc @@ -360,7 +360,7 @@ JNIEnv* env, jobject jthis, jlong jself, jlong jother) { API_BEGIN() const auto* self_ptr = reinterpret_cast(jself); const auto* other_ptr = reinterpret_cast(jother); - const auto* result_ptr = new torch::Tensor(self_ptr->atan2(other_ptr)); + const auto* result_ptr = new torch::Tensor(self_ptr->atan2(*other_ptr)); return reinterpret_cast(result_ptr); API_END_RETURN() }