Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add erf and atan2 #2842

Merged
merged 5 commits into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -2344,6 +2344,24 @@ default boolean allClose(NDArray other, double rtol, double atol, boolean equalN
*/
NDArray atan();

/**
* Returns the element-wise arc-tangent of this/other choosing the quadrant correctly.
*
* <p>Examples
*
* <pre>
* jshell&gt; NDArray x = manager.create(new float[] {0f, 1f});
* jshell&gt; NDArray y = manager.create(new float[] {0f, -6f});
* jshell&gt; x.atan2(y);
* ND: (2) cpu() float64
* [0. , 2.9764]
* </pre>
*
* @param other The other {@code NDArray}
* @return the result {@code NDArray}
*/
NDArray atan2(NDArray other);

/**
* Returns the hyperbolic sine of this {@code NDArray} element-wise.
*
Expand Down Expand Up @@ -4922,6 +4940,22 @@ default NDArray countNonzero(int axis) {
*/
NDArray erfinv();

/**
* Returns element-wise gauss error function of the {@code NDArray}.
*
* <p>Examples
*
* <pre>
* jshell&gt; NDArray array = manager.create(new float[] {0f, 0.4769f, Float.NEGATIVE_INFINITY});
* jshell&gt; array.erf();
* ND: (3) cpu() float32
* [0., 0.5, -1]
* </pre>
*
* @return The gauss error of the {@code NDArray}, element-wise
*/
NDArray erf();

/** {@inheritDoc} */
@Override
default List<NDArray> getResourceNDArrays() {
Expand Down
12 changes: 12 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,12 @@ public NDArray atan() {
return getAlternativeArray().atan();
}

/** {@inheritDoc} */
@Override
public NDArray atan2(NDArray other) {
return getAlternativeArray().atan2(other);
}

/** {@inheritDoc} */
@Override
public NDArray sinh() {
Expand Down Expand Up @@ -1188,6 +1194,12 @@ public NDArray erfinv() {
return getAlternativeArray().erfinv();
}

/** {@inheritDoc} */
@Override
public NDArray erf() {
return getAlternativeArray().erf();
}

/** {@inheritDoc} */
@Override
public NDArray inverse() {
Expand Down
19 changes: 19 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArrays.java
Original file line number Diff line number Diff line change
Expand Up @@ -1996,4 +1996,23 @@ public static NDArray logicalXor(NDArray a, NDArray b) {
public static NDArray erfinv(NDArray input) {
return input.erfinv();
}

/**
* Returns element-wise gauss error function of the {@code NDArray}.
*
* <p>Examples
*
* <pre>
* jshell&gt; NDArray array = manager.create(new float[] {0f, 0.4769f, Float.NEGATIVE_INFINITY});
* jshell&gt; array.erf();
* ND: (3) cpu() float32
* [0., 0.5, -1]
* </pre>
*
* @param input The input {@code NDArray}
* @return The gauss error of the {@code NDArray}, element-wise
*/
public static NDArray erf(NDArray input) {
return input.erf();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,13 @@ public NDArray atan() {
return manager.invoke("_npi_arctan", this, null);
}

/** {@inheritDoc} */
@Override
public NDArray atan2(NDArray other) {
other = manager.from(other);
return manager.invoke("_npi_arctan2", new NDArray[] {this, other}, null);
}

/** {@inheritDoc} */
@Override
public NDArray sinh() {
Expand Down Expand Up @@ -1601,6 +1608,12 @@ public NDArray erfinv() {
return manager.invoke("erfinv", this, null);
}

/** {@inheritDoc} */
@Override
public NDArray erf() {
return manager.invoke("erf", this, null);
}

/** {@inheritDoc} */
@Override
public NDArray norm(boolean keepDims) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,12 @@ public PtNDArray atan() {
return JniUtils.atan(this);
}

/** {@inheritDoc} */
@Override
public PtNDArray atan2(NDArray other) {
return JniUtils.atan2(this, manager.from(other));
}

/** {@inheritDoc} */
@Override
public PtNDArray sinh() {
Expand Down Expand Up @@ -1539,6 +1545,12 @@ public PtNDArray erfinv() {
return JniUtils.erfinv(this);
}

/** {@inheritDoc} */
@Override
public PtNDArray erf() {
return JniUtils.erf(this);
}

/** {@inheritDoc} */
@Override
public PtNDArray inverse() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1145,6 +1145,12 @@ public static PtNDArray atan(PtNDArray ndArray) {
ndArray.getManager(), PyTorchLibrary.LIB.torchAtan(ndArray.getHandle()));
}

public static PtNDArray atan2(PtNDArray self, PtNDArray other) {
return new PtNDArray(
self.getManager(),
PyTorchLibrary.LIB.torchAtan2(self.getHandle(), other.getHandle()));
}

public static PtNDArray sqrt(PtNDArray ndArray) {
return new PtNDArray(
ndArray.getManager(), PyTorchLibrary.LIB.torchSqrt(ndArray.getHandle()));
Expand Down Expand Up @@ -1334,6 +1340,11 @@ public static PtNDArray erfinv(PtNDArray ndArray) {
ndArray.getManager(), PyTorchLibrary.LIB.torchErfinv(ndArray.getHandle()));
}

public static PtNDArray erf(PtNDArray ndArray) {
return new PtNDArray(
ndArray.getManager(), PyTorchLibrary.LIB.torchErf(ndArray.getHandle()));
}

public static PtNDArray inverse(PtNDArray ndArray) {
return new PtNDArray(
ndArray.getManager(), PyTorchLibrary.LIB.torchInverse(ndArray.getHandle()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,8 @@ native long[] torchUnique(

native long torchAtan(long handle);

native long torchAtan2(long self, long other);

native long torchSqrt(long handle);

native long torchSinh(long handle);
Expand Down Expand Up @@ -405,6 +407,8 @@ native long tensorUniform(

native long torchErfinv(long handle);

native long torchErf(long handle);

native long torchInverse(long self);

native long torchNNInterpolate(long handle, long[] size, int mode, boolean alignCorners);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,16 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchAtan(JNIEnv*
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchAtan2(
JNIEnv* env, jobject jthis, jlong jself, jlong jother) {
API_BEGIN()
const auto* self_ptr = reinterpret_cast<torch::Tensor*>(jself);
const auto* other_ptr = reinterpret_cast<torch::Tensor*>(jother);
const auto* result_ptr = new torch::Tensor(self_ptr->atan2(*other_ptr));
return reinterpret_cast<uintptr_t>(result_ptr);
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchSqrt(JNIEnv* env, jobject jthis, jlong jhandle) {
API_BEGIN()
const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jhandle);
Expand Down Expand Up @@ -496,6 +506,14 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchErfinv(JNIEn
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchErf(JNIEnv* env, jobject jthis, jlong jhandle) {
API_BEGIN()
const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jhandle);
const auto* result_ptr = new torch::Tensor(tensor_ptr->erf());
return reinterpret_cast<uintptr_t>(result_ptr);
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchInverse(JNIEnv* env, jobject jthis, jlong jself) {
API_BEGIN()
const auto* self_ptr = reinterpret_cast<torch::Tensor*>(jself);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,12 @@ public NDArray erfinv() {
return manager.opExecutor("Erfinv").addInput(this).buildSingletonOrThrow();
}

/** {@inheritDoc} */
@Override
public NDArray erf() {
return manager.opExecutor("Erf").addInput(this).buildSingletonOrThrow();
}

/** {@inheritDoc} */
@Override
public NDArray norm(boolean keepDims) {
Expand Down Expand Up @@ -911,6 +917,12 @@ public NDArray atan() {
return manager.opExecutor("Atan").addInput(this).buildSingletonOrThrow();
}

/** {@inheritDoc} */
@Override
public NDArray atan2(NDArray other) {
return manager.opExecutor("Atan2").addInput(this).addInput(other).buildSingletonOrThrow();
}

/** {@inheritDoc} */
@Override
public NDArray sinh() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.testng.annotations.Test;

import java.util.stream.DoubleStream;
import java.util.stream.IntStream;

public class NDArrayNumericOpTest {

Expand Down Expand Up @@ -499,6 +500,42 @@ public void testAtan() {
}
}

@Test
public void testAtan2() {
try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) {
double[] x1 = {1.0, -1.0, -1.0, 0.0, 0.0, 0.0};
NDArray array = manager.create(x1);
double[] y1 = {1.0, 0.0, -1.0, 1.0, -1.0, 0.0};
NDArray other = manager.create(y1);
double[] output =
IntStream.range(0, x1.length)
.mapToDouble(i -> Math.atan2(x1[i], y1[i]))
.toArray();
NDArray expected = manager.create(output);
Assertions.assertAlmostEquals(array.atan2(other), expected);
// test multi-dim
double[] x2 = {-1.0, -0.5, 0, 0.5, 1.0};
array = manager.create(x2, new Shape(5, 1));
double[] y2 = {-2.0, 3.0, 6.0, 0.0, -0.3};
other = manager.create(y2, new Shape(5, 1));
output =
IntStream.range(0, x2.length)
.mapToDouble(i -> Math.atan2(x2[i], y2[i]))
.toArray();
expected = manager.create(output, new Shape(5, 1));
Assertions.assertAlmostEquals(array.atan2(other), expected);
// test scalar
array = manager.create(0f);
other = manager.create(0f);
expected = manager.create(0f);
Assertions.assertAlmostEquals(array.atan2(other), expected);
// test zero-dim
array = manager.create(new Shape(1, 0));
other = manager.create(new Shape(1, 0));
Assert.assertEquals(array.atan2(other), array);
}
}

@Test
public void testToDegrees() {
try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,40 @@ public void testErfinv() {
}
}

@Test
public void testErf() {
try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) {
// test 1-D
NDArray array = manager.create(new float[] {0f, 0.4769f, Float.NEGATIVE_INFINITY});
NDArray expected = manager.create(new float[] {0f, 0.5f, -1f});
Assertions.assertAlmostEquals(NDArrays.erf(array), expected);
// test 3-D
array =
manager.create(
new float[] {
Float.NEGATIVE_INFINITY,
-0.8134f,
-0.4769f,
-0.2253f,
0f,
0.2253f,
0.4769f,
0.8134f,
Float.POSITIVE_INFINITY
})
.reshape(3, 1, 3);
expected = manager.linspace(-1.0f, 1.0f, 9).reshape(3, 1, 3);
Assertions.assertAlmostEquals(array.erf(), expected);
// test scalar
array = manager.create(Float.POSITIVE_INFINITY);
expected = manager.create(1f);
Assertions.assertAlmostEquals(array.erf(), expected);
// test zero-dim
array = manager.create(new Shape(2, 0));
Assertions.assertAlmostEquals(array.erf(), array);
}
}

@Test
public void testInverse() {
try (NDManager manager = NDManager.newBaseManager(TestUtils.getEngine())) {
Expand Down
Loading