Skip to content

Commit

Permalink
Add erf and atan2 (#2842)
Browse files Browse the repository at this point in the history
* Added element-wise gauss error function (ERF)

* Added element-wise arctan2

* Format java

* Fixed docs

* added * to other_ptr in Atan2
  • Loading branch information
TalGrbr authored and frankfliu committed Apr 26, 2024
1 parent 425d2d9 commit 3114ee1
Show file tree
Hide file tree
Showing 11 changed files with 206 additions and 0 deletions.
34 changes: 34 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -2344,6 +2344,24 @@ default boolean allClose(NDArray other, double rtol, double atol, boolean equalN
*/
NDArray atan();

/**
* Returns the element-wise arc-tangent of this/other choosing the quadrant correctly.
*
* <p>Examples
*
* <pre>
* jshell&gt; NDArray x = manager.create(new float[] {0f, 1f});
* jshell&gt; NDArray y = manager.create(new float[] {0f, -6f});
* jshell&gt; x.atan2(y);
* ND: (2) cpu() float64
* [0. , 2.9764]
* </pre>
*
* @param other The other {@code NDArray}
* @return the result {@code NDArray}
*/
NDArray atan2(NDArray other);

/**
* Returns the hyperbolic sine of this {@code NDArray} element-wise.
*
Expand Down Expand Up @@ -4922,6 +4940,22 @@ default NDArray countNonzero(int axis) {
*/
NDArray erfinv();

/**
* Returns element-wise gauss error function of the {@code NDArray}.
*
* <p>Examples
*
* <pre>
* jshell&gt; NDArray array = manager.create(new float[] {0f, 0.4769f, Float.NEGATIVE_INFINITY});
* jshell&gt; array.erf();
* ND: (3) cpu() float32
* [0., 0.5, -1]
* </pre>
*
* @return The gauss error of the {@code NDArray}, element-wise
*/
NDArray erf();

/** {@inheritDoc} */
@Override
default List<NDArray> getResourceNDArrays() {
Expand Down
12 changes: 12 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,12 @@ public NDArray atan() {
return getAlternativeArray().atan();
}

/** {@inheritDoc} */
@Override
public NDArray atan2(NDArray other) {
return getAlternativeArray().atan2(other);
}

/** {@inheritDoc} */
@Override
public NDArray sinh() {
Expand Down Expand Up @@ -1188,6 +1194,12 @@ public NDArray erfinv() {
return getAlternativeArray().erfinv();
}

/** {@inheritDoc} */
@Override
public NDArray erf() {
return getAlternativeArray().erf();
}

/** {@inheritDoc} */
@Override
public NDArray inverse() {
Expand Down
19 changes: 19 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArrays.java
Original file line number Diff line number Diff line change
Expand Up @@ -1996,4 +1996,23 @@ public static NDArray logicalXor(NDArray a, NDArray b) {
public static NDArray erfinv(NDArray input) {
return input.erfinv();
}

/**
* Returns element-wise gauss error function of the {@code NDArray}.
*
* <p>Examples
*
* <pre>
* jshell&gt; NDArray array = manager.create(new float[] {0f, 0.4769f, Float.NEGATIVE_INFINITY});
* jshell&gt; array.erf();
* ND: (3) cpu() float32
* [0., 0.5, -1]
* </pre>
*
* @param input The input {@code NDArray}
* @return The gauss error of the {@code NDArray}, element-wise
*/
public static NDArray erf(NDArray input) {
return input.erf();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,13 @@ public NDArray atan() {
return manager.invoke("_npi_arctan", this, null);
}

/** {@inheritDoc} */
@Override
public NDArray atan2(NDArray other) {
other = manager.from(other);
return manager.invoke("_npi_arctan2", new NDArray[] {this, other}, null);
}

/** {@inheritDoc} */
@Override
public NDArray sinh() {
Expand Down Expand Up @@ -1601,6 +1608,12 @@ public NDArray erfinv() {
return manager.invoke("erfinv", this, null);
}

/** {@inheritDoc} */
@Override
public NDArray erf() {
return manager.invoke("erf", this, null);
}

/** {@inheritDoc} */
@Override
public NDArray norm(boolean keepDims) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,12 @@ public PtNDArray atan() {
return JniUtils.atan(this);
}

/** {@inheritDoc} */
@Override
public PtNDArray atan2(NDArray other) {
return JniUtils.atan2(this, manager.from(other));
}

/** {@inheritDoc} */
@Override
public PtNDArray sinh() {
Expand Down Expand Up @@ -1539,6 +1545,12 @@ public PtNDArray erfinv() {
return JniUtils.erfinv(this);
}

/** {@inheritDoc} */
@Override
public PtNDArray erf() {
return JniUtils.erf(this);
}

/** {@inheritDoc} */
@Override
public PtNDArray inverse() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1145,6 +1145,12 @@ public static PtNDArray atan(PtNDArray ndArray) {
ndArray.getManager(), PyTorchLibrary.LIB.torchAtan(ndArray.getHandle()));
}

public static PtNDArray atan2(PtNDArray self, PtNDArray other) {
return new PtNDArray(
self.getManager(),
PyTorchLibrary.LIB.torchAtan2(self.getHandle(), other.getHandle()));
}

public static PtNDArray sqrt(PtNDArray ndArray) {
return new PtNDArray(
ndArray.getManager(), PyTorchLibrary.LIB.torchSqrt(ndArray.getHandle()));
Expand Down Expand Up @@ -1334,6 +1340,11 @@ public static PtNDArray erfinv(PtNDArray ndArray) {
ndArray.getManager(), PyTorchLibrary.LIB.torchErfinv(ndArray.getHandle()));
}

public static PtNDArray erf(PtNDArray ndArray) {
return new PtNDArray(
ndArray.getManager(), PyTorchLibrary.LIB.torchErf(ndArray.getHandle()));
}

public static PtNDArray inverse(PtNDArray ndArray) {
return new PtNDArray(
ndArray.getManager(), PyTorchLibrary.LIB.torchInverse(ndArray.getHandle()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,8 @@ native long[] torchUnique(

native long torchAtan(long handle);

native long torchAtan2(long self, long other);

native long torchSqrt(long handle);

native long torchSinh(long handle);
Expand Down Expand Up @@ -405,6 +407,8 @@ native long tensorUniform(

native long torchErfinv(long handle);

native long torchErf(long handle);

native long torchInverse(long self);

native long torchNNInterpolate(long handle, long[] size, int mode, boolean alignCorners);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,16 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchAtan(JNIEnv*
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchAtan2(
JNIEnv* env, jobject jthis, jlong jself, jlong jother) {
API_BEGIN()
const auto* self_ptr = reinterpret_cast<torch::Tensor*>(jself);
const auto* other_ptr = reinterpret_cast<torch::Tensor*>(jother);
const auto* result_ptr = new torch::Tensor(self_ptr->atan2(*other_ptr));
return reinterpret_cast<uintptr_t>(result_ptr);
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchSqrt(JNIEnv* env, jobject jthis, jlong jhandle) {
API_BEGIN()
const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jhandle);
Expand Down Expand Up @@ -496,6 +506,14 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchErfinv(JNIEn
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchErf(JNIEnv* env, jobject jthis, jlong jhandle) {
API_BEGIN()
const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jhandle);
const auto* result_ptr = new torch::Tensor(tensor_ptr->erf());
return reinterpret_cast<uintptr_t>(result_ptr);
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchInverse(JNIEnv* env, jobject jthis, jlong jself) {
API_BEGIN()
const auto* self_ptr = reinterpret_cast<torch::Tensor*>(jself);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,12 @@ public NDArray erfinv() {
return manager.opExecutor("Erfinv").addInput(this).buildSingletonOrThrow();
}

/** {@inheritDoc} */
@Override
public NDArray erf() {
return manager.opExecutor("Erf").addInput(this).buildSingletonOrThrow();
}

/** {@inheritDoc} */
@Override
public NDArray norm(boolean keepDims) {
Expand Down Expand Up @@ -911,6 +917,12 @@ public NDArray atan() {
return manager.opExecutor("Atan").addInput(this).buildSingletonOrThrow();
}

/** {@inheritDoc} */
@Override
public NDArray atan2(NDArray other) {
return manager.opExecutor("Atan2").addInput(this).addInput(other).buildSingletonOrThrow();
}

/** {@inheritDoc} */
@Override
public NDArray sinh() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.testng.annotations.Test;

import java.util.stream.DoubleStream;
import java.util.stream.IntStream;

public class NDArrayNumericOpTest {

Expand Down Expand Up @@ -499,6 +500,42 @@ public void testAtan() {
}
}

@Test
public void testAtan2() {
try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) {
double[] x1 = {1.0, -1.0, -1.0, 0.0, 0.0, 0.0};
NDArray array = manager.create(x1);
double[] y1 = {1.0, 0.0, -1.0, 1.0, -1.0, 0.0};
NDArray other = manager.create(y1);
double[] output =
IntStream.range(0, x1.length)
.mapToDouble(i -> Math.atan2(x1[i], y1[i]))
.toArray();
NDArray expected = manager.create(output);
Assertions.assertAlmostEquals(array.atan2(other), expected);
// test multi-dim
double[] x2 = {-1.0, -0.5, 0, 0.5, 1.0};
array = manager.create(x2, new Shape(5, 1));
double[] y2 = {-2.0, 3.0, 6.0, 0.0, -0.3};
other = manager.create(y2, new Shape(5, 1));
output =
IntStream.range(0, x2.length)
.mapToDouble(i -> Math.atan2(x2[i], y2[i]))
.toArray();
expected = manager.create(output, new Shape(5, 1));
Assertions.assertAlmostEquals(array.atan2(other), expected);
// test scalar
array = manager.create(0f);
other = manager.create(0f);
expected = manager.create(0f);
Assertions.assertAlmostEquals(array.atan2(other), expected);
// test zero-dim
array = manager.create(new Shape(1, 0));
other = manager.create(new Shape(1, 0));
Assert.assertEquals(array.atan2(other), array);
}
}

@Test
public void testToDegrees() {
try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,40 @@ public void testErfinv() {
}
}

@Test
public void testErf() {
try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) {
// test 1-D
NDArray array = manager.create(new float[] {0f, 0.4769f, Float.NEGATIVE_INFINITY});
NDArray expected = manager.create(new float[] {0f, 0.5f, -1f});
Assertions.assertAlmostEquals(NDArrays.erf(array), expected);
// test 3-D
array =
manager.create(
new float[] {
Float.NEGATIVE_INFINITY,
-0.8134f,
-0.4769f,
-0.2253f,
0f,
0.2253f,
0.4769f,
0.8134f,
Float.POSITIVE_INFINITY
})
.reshape(3, 1, 3);
expected = manager.linspace(-1.0f, 1.0f, 9).reshape(3, 1, 3);
Assertions.assertAlmostEquals(array.erf(), expected);
// test scalar
array = manager.create(Float.POSITIVE_INFINITY);
expected = manager.create(1f);
Assertions.assertAlmostEquals(array.erf(), expected);
// test zero-dim
array = manager.create(new Shape(2, 0));
Assertions.assertAlmostEquals(array.erf(), array);
}
}

@Test
public void testInverse() {
try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) {
Expand Down

0 comments on commit 3114ee1

Please sign in to comment.