Skip to content

Commit

Permalink
feat: ExecuTorch bindings for android (#38)
Browse files Browse the repository at this point in the history
## Description
Added TurboModule(ETModule) which let's user use ExecuTorch Module
methods such as:
- loadMethod
- loadModule
- loadForward
- forward

### Type of change
- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- [ ] Documentation update (improves or adds clarity to existing
documentation)

### Tested on
- [ ] iOS
- [x] Android

### Testing instructions
<!-- Provide step-by-step instructions on how to test your changes.
Include setup details if necessary. -->

### Screenshots
<!-- Add screenshots here, if applicable -->

### Related issues
<!-- Link related issues here using #issue-number -->

### Checklist
- [x] I have performed a self-review of my code
- [x] I have commented my code, particularly in hard-to-understand areas
- [ ] I have updated the documentation accordingly
- [x] My changes generate no new warnings

### Additional notes
<!-- Include any additional information, assumptions, or context that
reviewers might need to understand this PR. -->
  • Loading branch information
NorbertKlockiewicz authored Nov 27, 2024
1 parent 8ff6076 commit e91ce11
Show file tree
Hide file tree
Showing 17 changed files with 480 additions and 59 deletions.
Binary file modified android/libs/executorch-llama.aar
Binary file not shown.
97 changes: 97 additions & 0 deletions android/src/main/java/com/swmansion/rnexecutorch/ETModule.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package com.swmansion.rnexecutorch

import com.facebook.react.bridge.Promise
import com.facebook.react.bridge.ReactApplicationContext
import com.facebook.react.bridge.ReadableArray
import com.swmansion.rnexecutorch.utils.ArrayUtils
import com.swmansion.rnexecutorch.utils.Fetcher
import com.swmansion.rnexecutorch.utils.ProgressResponseBody
import com.swmansion.rnexecutorch.utils.ResourceType
import com.swmansion.rnexecutorch.utils.TensorUtils
import okhttp3.OkHttpClient
import org.pytorch.executorch.Module
import org.pytorch.executorch.Tensor
import java.net.URL

class ETModule(reactContext: ReactApplicationContext) : NativeETModuleSpec(reactContext) {
private lateinit var module: Module
private val client = OkHttpClient()

override fun getName(): String {
return NAME
}

private fun downloadModel(
url: URL, resourceType: ResourceType, callback: (path: String?, error: Exception?) -> Unit
) {
Fetcher.downloadResource(reactApplicationContext,
client,
url,
resourceType,
{ path, error -> callback(path, error) },
object : ProgressResponseBody.ProgressListener {
override fun onProgress(bytesRead: Long, contentLength: Long, done: Boolean) {
}
})
}

override fun loadModule(modelPath: String, promise: Promise) {
try {
downloadModel(
URL(modelPath), ResourceType.MODEL
) { path, error ->
if (error != null) {
promise.reject(error.message!!, "-1")
return@downloadModel
}

module = Module.load(path)
promise.resolve(0)
return@downloadModel
}
} catch (e: Exception) {
promise.reject(e.message!!, "-1")
}
}

override fun loadMethod(methodName: String, promise: Promise) {
val result = module.loadMethod(methodName)
if (result != 0) {
promise.reject("Method loading failed", result.toString())
return
}

promise.resolve(result)
}

override fun forward(
input: ReadableArray,
shape: ReadableArray,
inputType: Double,
promise: Promise
) {
try {
val executorchInput =
TensorUtils.getExecutorchInput(input, ArrayUtils.createLongArray(shape), inputType.toInt())

lateinit var result: Tensor
module.forward(executorchInput)[0].toTensor().also { result = it }

promise.resolve(ArrayUtils.createReadableArray(result))
return
} catch (e: IllegalArgumentException) {
//The error is thrown when transformation to Tensor fails
promise.reject("Forward Failed Execution", "18")
return
} catch (e: Exception) {
//Executorch forward method throws an exception with a message: "Method forward failed with code XX"
val exceptionCode = e.message!!.substring(e.message!!.length - 2)
promise.reject("Forward Failed Execution", exceptionCode)
return
}
}

companion object {
const val NAME = "ETModule"
}
}
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
package com.swmansion.rnexecutorch

import com.facebook.react.bridge.ReactApplicationContext
import android.os.Build
import android.util.Log
import androidx.annotation.RequiresApi
import com.facebook.react.bridge.Promise
import com.facebook.react.bridge.ReactContextBaseJavaModule
import com.facebook.react.bridge.ReactMethod
import com.facebook.react.bridge.ReactApplicationContext
import com.swmansion.rnexecutorch.utils.Fetcher
import com.swmansion.rnexecutorch.utils.ProgressResponseBody
import com.swmansion.rnexecutorch.utils.ResourceType
import com.swmansion.rnexecutorch.utils.llms.ChatRole
import com.swmansion.rnexecutorch.utils.llms.ConversationManager
import com.swmansion.rnexecutorch.utils.llms.END_OF_TEXT_TOKEN
import okhttp3.OkHttpClient
import okhttp3.Request
import org.pytorch.executorch.LlamaModule
import org.pytorch.executorch.LlamaCallback
import java.io.File
import org.pytorch.executorch.LlamaModule
import java.net.URL

class RnExecutorchModule(reactContext: ReactApplicationContext) :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,40 @@ import com.facebook.react.module.model.ReactModuleInfo
import com.facebook.react.module.model.ReactModuleInfoProvider
import com.facebook.react.uimanager.ViewManager


class RnExecutorchPackage : TurboReactPackage() {
override fun createViewManagers(reactContext: ReactApplicationContext): List<ViewManager<*, *>> {
return listOf()
}

override fun getModule(name: String, reactContext: ReactApplicationContext): NativeModule? =
if (name == RnExecutorchModule.NAME) {
RnExecutorchModule(reactContext)
} else {
null
}
override fun getModule(name: String, reactContext: ReactApplicationContext): NativeModule? =
if (name == RnExecutorchModule.NAME) {
RnExecutorchModule(reactContext)
} else if (name == ETModule.NAME) {
ETModule(reactContext)
} else {
null
}

override fun getReactModuleInfoProvider(): ReactModuleInfoProvider {
return ReactModuleInfoProvider {
override fun getReactModuleInfoProvider(): ReactModuleInfoProvider {
return ReactModuleInfoProvider {
val moduleInfos: MutableMap<String, ReactModuleInfo> = HashMap()
moduleInfos[RnExecutorchModule.NAME] = ReactModuleInfo(
RnExecutorchModule.NAME,
RnExecutorchModule.NAME,
false, // canOverrideExistingModule
false, // needsEagerInit
true, // hasConstants
false, // isCxxModule
true // isTurboModule
true,
)
moduleInfos[ETModule.NAME] = ReactModuleInfo(
ETModule.NAME,
ETModule.NAME,
false, // canOverrideExistingModule
false, // needsEagerInit
false, // isCxxModule
true
)
moduleInfos
}
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package com.swmansion.rnexecutorch.utils

import com.facebook.react.bridge.Arguments
import com.facebook.react.bridge.ReadableArray
import org.pytorch.executorch.DType
import org.pytorch.executorch.Tensor

class ArrayUtils {
companion object {
fun createByteArray(input: ReadableArray): ByteArray {
val byteArray = ByteArray(input.size())
for (i in 0 until input.size()) {
byteArray[i] = input.getInt(i).toByte()
}
return byteArray
}

fun createIntArray(input: ReadableArray): IntArray {
val intArray = IntArray(input.size())
for (i in 0 until input.size()) {
intArray[i] = input.getInt(i)
}
return intArray
}

fun createFloatArray(input: ReadableArray): FloatArray {
val floatArray = FloatArray(input.size())
for (i in 0 until input.size()) {
floatArray[i] = input.getDouble(i).toFloat()
}
return floatArray
}

fun createLongArray(input: ReadableArray): LongArray {
val longArray = LongArray(input.size())
for (i in 0 until input.size()) {
longArray[i] = input.getInt(i).toLong()
}
return longArray
}

fun createDoubleArray(input: ReadableArray): DoubleArray {
val doubleArray = DoubleArray(input.size())
for (i in 0 until input.size()) {
doubleArray[i] = input.getDouble(i)
}
return doubleArray
}

fun createReadableArray(result: Tensor): ReadableArray {
val resultArray = Arguments.createArray()
when (result.dtype()) {
DType.UINT8 -> {
val byteArray = result.dataAsByteArray
for (i in byteArray) {
resultArray.pushInt(i.toInt())
}
}

DType.INT32 -> {
val intArray = result.dataAsIntArray
for (i in intArray) {
resultArray.pushInt(i)
}
}

DType.FLOAT -> {
val longArray = result.dataAsFloatArray
for (i in longArray) {
resultArray.pushDouble(i.toDouble())
}
}

DType.DOUBLE -> {
val floatArray = result.dataAsDoubleArray
for (i in floatArray) {
resultArray.pushDouble(i)
}
}

DType.INT64 -> {
val doubleArray = result.dataAsLongArray
for (i in doubleArray) {
resultArray.pushLong(i)
}
}

else -> {
throw IllegalArgumentException("Invalid dtype: ${result.dtype()}")
}
}

return resultArray
}
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.swmansion.rnexecutorch
package com.swmansion.rnexecutorch.utils

import android.content.Context
import okhttp3.Call
Expand Down Expand Up @@ -113,11 +113,17 @@ class Fetcher {

private fun resolveConfigUrlFromModelUrl(modelUrl: URL): URL {
// Create a new URL using the base URL and append the desired path
val baseUrl = modelUrl.protocol + "://" + modelUrl.host + modelUrl.path.substringBefore("resolve/")
val baseUrl =
modelUrl.protocol + "://" + modelUrl.host + modelUrl.path.substringBefore("resolve/")
return URL(baseUrl + "resolve/main/config.json")
}

private fun sendRequestToUrl(url: URL, method: String, body: RequestBody?, client: OkHttpClient): Response {
private fun sendRequestToUrl(
url: URL,
method: String,
body: RequestBody?,
client: OkHttpClient
): Response {
val request = Request.Builder()
.url(url)
.method(method, body)
Expand All @@ -134,18 +140,18 @@ class Fetcher {
onComplete: (String?, Exception?) -> Unit,
listener: ProgressResponseBody.ProgressListener? = null,
) {
/*
Fetching model and tokenizer file
1. Extract file name from provided URL
2. If file name contains / it means that the file is local and we should return the path
3. Check if the file has a valid extension
a. For tokenizer, the extension should be .bin
b. For model, the extension should be .pte
4. Check if models directory exists, if not create it
5. Check if the file already exists in the models directory, if yes return the path
6. If the file does not exist, and is a tokenizer, fetch the file
7. If the file is a model, fetch the file with ProgressResponseBody
*/
/*
Fetching model and tokenizer file
1. Extract file name from provided URL
2. If file name contains / it means that the file is local and we should return the path
3. Check if the file has a valid extension
a. For tokenizer, the extension should be .bin
b. For model, the extension should be .pte
4. Check if models directory exists, if not create it
5. Check if the file already exists in the models directory, if yes return the path
6. If the file does not exist, and is a tokenizer, fetch the file
7. If the file is a model, fetch the file with ProgressResponseBody
*/
val fileName: String

try {
Expand All @@ -165,7 +171,7 @@ class Fetcher {
return
}

var tempFile = File(context.filesDir, fileName)
val tempFile = File(context.filesDir, fileName)
if (tempFile.exists()) {
tempFile.delete()
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.swmansion.rnexecutorch
package com.swmansion.rnexecutorch.utils

import okhttp3.MediaType
import okhttp3.ResponseBody
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package com.swmansion.rnexecutorch.utils

import com.facebook.react.bridge.ReadableArray
import org.pytorch.executorch.EValue
import org.pytorch.executorch.Tensor

class TensorUtils {
companion object {
fun getExecutorchInput(input: ReadableArray, shape: LongArray, type: Int): EValue {
try {
when (type) {
0 -> {
val inputTensor = Tensor.fromBlob(ArrayUtils.createByteArray(input), shape)
return EValue.from(inputTensor)
}

1 -> {
val inputTensor = Tensor.fromBlob(ArrayUtils.createIntArray(input), shape)
return EValue.from(inputTensor)
}

2 -> {
val inputTensor = Tensor.fromBlob(ArrayUtils.createLongArray(input), shape)
return EValue.from(inputTensor)
}

3 -> {
val inputTensor = Tensor.fromBlob(ArrayUtils.createFloatArray(input), shape)
return EValue.from(inputTensor)
}

4 -> {
val inputTensor = Tensor.fromBlob(ArrayUtils.createDoubleArray(input), shape)
return EValue.from(inputTensor)
}

else -> {
throw IllegalArgumentException("Invalid input type: $type")
}
}
} catch (e: IllegalArgumentException) {
throw e
}
}
}
}
Loading

0 comments on commit e91ce11

Please sign in to comment.