-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: ExecuTorch bindings for android (#38)
## Description Added TurboModule(ETModule) which let's user use ExecuTorch Module methods such as: - loadMethod - loadModule - loadForward - forward ### Type of change - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Documentation update (improves or adds clarity to existing documentation) ### Tested on - [ ] iOS - [x] Android ### Testing instructions <!-- Provide step-by-step instructions on how to test your changes. Include setup details if necessary. --> ### Screenshots <!-- Add screenshots here, if applicable --> ### Related issues <!-- Link related issues here using #issue-number --> ### Checklist - [x] I have performed a self-review of my code - [x] I have commented my code, particularly in hard-to-understand areas - [ ] I have updated the documentation accordingly - [x] My changes generate no new warnings ### Additional notes <!-- Include any additional information, assumptions, or context that reviewers might need to understand this PR. -->
- Loading branch information
1 parent
8ff6076
commit e91ce11
Showing
17 changed files
with
480 additions
and
59 deletions.
There are no files selected for viewing
Binary file not shown.
97 changes: 97 additions & 0 deletions
97
android/src/main/java/com/swmansion/rnexecutorch/ETModule.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
package com.swmansion.rnexecutorch | ||
|
||
import com.facebook.react.bridge.Promise | ||
import com.facebook.react.bridge.ReactApplicationContext | ||
import com.facebook.react.bridge.ReadableArray | ||
import com.swmansion.rnexecutorch.utils.ArrayUtils | ||
import com.swmansion.rnexecutorch.utils.Fetcher | ||
import com.swmansion.rnexecutorch.utils.ProgressResponseBody | ||
import com.swmansion.rnexecutorch.utils.ResourceType | ||
import com.swmansion.rnexecutorch.utils.TensorUtils | ||
import okhttp3.OkHttpClient | ||
import org.pytorch.executorch.Module | ||
import org.pytorch.executorch.Tensor | ||
import java.net.URL | ||
|
||
class ETModule(reactContext: ReactApplicationContext) : NativeETModuleSpec(reactContext) { | ||
private lateinit var module: Module | ||
private val client = OkHttpClient() | ||
|
||
override fun getName(): String { | ||
return NAME | ||
} | ||
|
||
private fun downloadModel( | ||
url: URL, resourceType: ResourceType, callback: (path: String?, error: Exception?) -> Unit | ||
) { | ||
Fetcher.downloadResource(reactApplicationContext, | ||
client, | ||
url, | ||
resourceType, | ||
{ path, error -> callback(path, error) }, | ||
object : ProgressResponseBody.ProgressListener { | ||
override fun onProgress(bytesRead: Long, contentLength: Long, done: Boolean) { | ||
} | ||
}) | ||
} | ||
|
||
override fun loadModule(modelPath: String, promise: Promise) { | ||
try { | ||
downloadModel( | ||
URL(modelPath), ResourceType.MODEL | ||
) { path, error -> | ||
if (error != null) { | ||
promise.reject(error.message!!, "-1") | ||
return@downloadModel | ||
} | ||
|
||
module = Module.load(path) | ||
promise.resolve(0) | ||
return@downloadModel | ||
} | ||
} catch (e: Exception) { | ||
promise.reject(e.message!!, "-1") | ||
} | ||
} | ||
|
||
override fun loadMethod(methodName: String, promise: Promise) { | ||
val result = module.loadMethod(methodName) | ||
if (result != 0) { | ||
promise.reject("Method loading failed", result.toString()) | ||
return | ||
} | ||
|
||
promise.resolve(result) | ||
} | ||
|
||
override fun forward( | ||
input: ReadableArray, | ||
shape: ReadableArray, | ||
inputType: Double, | ||
promise: Promise | ||
) { | ||
try { | ||
val executorchInput = | ||
TensorUtils.getExecutorchInput(input, ArrayUtils.createLongArray(shape), inputType.toInt()) | ||
|
||
lateinit var result: Tensor | ||
module.forward(executorchInput)[0].toTensor().also { result = it } | ||
|
||
promise.resolve(ArrayUtils.createReadableArray(result)) | ||
return | ||
} catch (e: IllegalArgumentException) { | ||
//The error is thrown when transformation to Tensor fails | ||
promise.reject("Forward Failed Execution", "18") | ||
return | ||
} catch (e: Exception) { | ||
//Executorch forward method throws an exception with a message: "Method forward failed with code XX" | ||
val exceptionCode = e.message!!.substring(e.message!!.length - 2) | ||
promise.reject("Forward Failed Execution", exceptionCode) | ||
return | ||
} | ||
} | ||
|
||
companion object { | ||
const val NAME = "ETModule" | ||
} | ||
} |
14 changes: 8 additions & 6 deletions
14
android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchModule.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
96 changes: 96 additions & 0 deletions
96
android/src/main/java/com/swmansion/rnexecutorch/utils/ArrayUtils.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
package com.swmansion.rnexecutorch.utils | ||
|
||
import com.facebook.react.bridge.Arguments | ||
import com.facebook.react.bridge.ReadableArray | ||
import org.pytorch.executorch.DType | ||
import org.pytorch.executorch.Tensor | ||
|
||
class ArrayUtils { | ||
companion object { | ||
fun createByteArray(input: ReadableArray): ByteArray { | ||
val byteArray = ByteArray(input.size()) | ||
for (i in 0 until input.size()) { | ||
byteArray[i] = input.getInt(i).toByte() | ||
} | ||
return byteArray | ||
} | ||
|
||
fun createIntArray(input: ReadableArray): IntArray { | ||
val intArray = IntArray(input.size()) | ||
for (i in 0 until input.size()) { | ||
intArray[i] = input.getInt(i) | ||
} | ||
return intArray | ||
} | ||
|
||
fun createFloatArray(input: ReadableArray): FloatArray { | ||
val floatArray = FloatArray(input.size()) | ||
for (i in 0 until input.size()) { | ||
floatArray[i] = input.getDouble(i).toFloat() | ||
} | ||
return floatArray | ||
} | ||
|
||
fun createLongArray(input: ReadableArray): LongArray { | ||
val longArray = LongArray(input.size()) | ||
for (i in 0 until input.size()) { | ||
longArray[i] = input.getInt(i).toLong() | ||
} | ||
return longArray | ||
} | ||
|
||
fun createDoubleArray(input: ReadableArray): DoubleArray { | ||
val doubleArray = DoubleArray(input.size()) | ||
for (i in 0 until input.size()) { | ||
doubleArray[i] = input.getDouble(i) | ||
} | ||
return doubleArray | ||
} | ||
|
||
fun createReadableArray(result: Tensor): ReadableArray { | ||
val resultArray = Arguments.createArray() | ||
when (result.dtype()) { | ||
DType.UINT8 -> { | ||
val byteArray = result.dataAsByteArray | ||
for (i in byteArray) { | ||
resultArray.pushInt(i.toInt()) | ||
} | ||
} | ||
|
||
DType.INT32 -> { | ||
val intArray = result.dataAsIntArray | ||
for (i in intArray) { | ||
resultArray.pushInt(i) | ||
} | ||
} | ||
|
||
DType.FLOAT -> { | ||
val longArray = result.dataAsFloatArray | ||
for (i in longArray) { | ||
resultArray.pushDouble(i.toDouble()) | ||
} | ||
} | ||
|
||
DType.DOUBLE -> { | ||
val floatArray = result.dataAsDoubleArray | ||
for (i in floatArray) { | ||
resultArray.pushDouble(i) | ||
} | ||
} | ||
|
||
DType.INT64 -> { | ||
val doubleArray = result.dataAsLongArray | ||
for (i in doubleArray) { | ||
resultArray.pushLong(i) | ||
} | ||
} | ||
|
||
else -> { | ||
throw IllegalArgumentException("Invalid dtype: ${result.dtype()}") | ||
} | ||
} | ||
|
||
return resultArray | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 1 addition & 1 deletion
2
...sion/rnexecutorch/ProgressResponseBody.kt → ...nexecutorch/utils/ProgressResponseBody.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
46 changes: 46 additions & 0 deletions
46
android/src/main/java/com/swmansion/rnexecutorch/utils/TensorUtils.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
package com.swmansion.rnexecutorch.utils | ||
|
||
import com.facebook.react.bridge.ReadableArray | ||
import org.pytorch.executorch.EValue | ||
import org.pytorch.executorch.Tensor | ||
|
||
class TensorUtils { | ||
companion object { | ||
fun getExecutorchInput(input: ReadableArray, shape: LongArray, type: Int): EValue { | ||
try { | ||
when (type) { | ||
0 -> { | ||
val inputTensor = Tensor.fromBlob(ArrayUtils.createByteArray(input), shape) | ||
return EValue.from(inputTensor) | ||
} | ||
|
||
1 -> { | ||
val inputTensor = Tensor.fromBlob(ArrayUtils.createIntArray(input), shape) | ||
return EValue.from(inputTensor) | ||
} | ||
|
||
2 -> { | ||
val inputTensor = Tensor.fromBlob(ArrayUtils.createLongArray(input), shape) | ||
return EValue.from(inputTensor) | ||
} | ||
|
||
3 -> { | ||
val inputTensor = Tensor.fromBlob(ArrayUtils.createFloatArray(input), shape) | ||
return EValue.from(inputTensor) | ||
} | ||
|
||
4 -> { | ||
val inputTensor = Tensor.fromBlob(ArrayUtils.createDoubleArray(input), shape) | ||
return EValue.from(inputTensor) | ||
} | ||
|
||
else -> { | ||
throw IllegalArgumentException("Invalid input type: $type") | ||
} | ||
} | ||
} catch (e: IllegalArgumentException) { | ||
throw e | ||
} | ||
} | ||
} | ||
} |
Oops, something went wrong.