Skip to content

Commit

Permalink
Add Classification model (android)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakmro committed Dec 15, 2024
1 parent 58e35b4 commit c0e021e
Show file tree
Hide file tree
Showing 7 changed files with 1,140 additions and 10 deletions.
59 changes: 59 additions & 0 deletions android/src/main/java/com/swmansion/rnexecutorch/Classification.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package com.swmansion.rnexecutorch

import android.util.Log
import com.facebook.react.bridge.Promise
import com.facebook.react.bridge.ReactApplicationContext
import com.swmansion.rnexecutorch.models.classification.ClassificationModel
import com.swmansion.rnexecutorch.utils.ETError
import com.swmansion.rnexecutorch.utils.ImageProcessor
import org.opencv.android.OpenCVLoader
import com.facebook.react.bridge.Arguments
import com.facebook.react.bridge.WritableMap

class Classification(reactContext: ReactApplicationContext) :
NativeClassificationSpec(reactContext) {

private lateinit var classificationModel: ClassificationModel

companion object {
const val NAME = "Classification"
init {
if(!OpenCVLoader.initLocal()){
Log.d("rn_executorch", "OpenCV not loaded")
} else {
Log.d("rn_executorch", "OpenCV loaded")
}
}
}

override fun loadModule(modelSource: String, promise: Promise) {
try {
classificationModel = ClassificationModel(reactApplicationContext)
classificationModel.loadModel(modelSource)
promise.resolve(0)
} catch (e: Exception) {
promise.reject(e.message!!, ETError.InvalidModelPath.toString())
}
}

override fun forward(input: String, promise: Promise) {
try {
val image = ImageProcessor.readImage(input)
val output = classificationModel.runModel(image)

val writableMap: WritableMap = Arguments.createMap()

for ((key, value) in output) {
writableMap.putDouble(key, value.toDouble())
}

promise.resolve(writableMap)
}catch(e: Exception){
promise.reject(e.message!!, e.message)
}
}

override fun getName(): String {
return NAME
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ class RnExecutorchPackage : TurboReactPackage() {
ETModule(reactContext)
} else if (name == StyleTransfer.NAME) {
StyleTransfer(reactContext)
} else {
} else if (name == Classification.NAME) {
Classification(reactContext)
}
else {
null
}

Expand Down Expand Up @@ -51,6 +54,15 @@ class RnExecutorchPackage : TurboReactPackage() {
false, // isCxxModule
true
)

moduleInfos[Classification.NAME] = ReactModuleInfo(
Classification.NAME,
Classification.NAME,
false, // canOverrideExistingModule
false, // needsEagerInit
false, // isCxxModule
true
)
moduleInfos
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ abstract class BaseModel<Input, Output>(val context: Context) {

abstract fun runModel(input: Input): Output

protected abstract fun preprocess(input: Input): Input
protected abstract fun preprocess(input: Input): EValue

protected abstract fun postprocess(input: Tensor): Output
protected abstract fun postprocess(output: Array<EValue>): Output
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import org.opencv.core.Mat
import org.opencv.core.Size
import org.opencv.imgproc.Imgproc
import org.pytorch.executorch.Tensor
import org.pytorch.executorch.EValue


class StyleTransferModel(reactApplicationContext: ReactApplicationContext) : BaseModel<Mat, Mat>(reactApplicationContext) {
Expand All @@ -19,22 +20,23 @@ class StyleTransferModel(reactApplicationContext: ReactApplicationContext) : Bas
return Size(height.toDouble(), width.toDouble())
}

override fun preprocess(input: Mat): Mat {
override fun preprocess(input: Mat): EValue {
originalSize = input.size()
Imgproc.resize(input, input, getModelImageSize())
return input
return ImageProcessor.matToEValue(input, module.getInputShape(0))
}

override fun postprocess(input: Tensor): Mat {
override fun postprocess(output: Array<EValue>): Mat {
val tensor = output[0].toTensor()
val modelShape = getModelImageSize()
val result = ImageProcessor.EValueToMat(input.dataAsFloatArray, modelShape.width.toInt(), modelShape.height.toInt())
val result = ImageProcessor.EValueToMat(tensor.dataAsFloatArray, modelShape.width.toInt(), modelShape.height.toInt())
Imgproc.resize(result, result, originalSize)
return result
}

override fun runModel(input: Mat): Mat {
val inputTensor = ImageProcessor.matToEValue(preprocess(input), module.getInputShape(0))
val outputTensor = forward(inputTensor)
return postprocess(outputTensor[0].toTensor())
val modelInput = preprocess(input)
val modelOutput = forward(modelInput)
return postprocess(modelOutput)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package com.swmansion.rnexecutorch.models.classification

import com.facebook.react.bridge.ReactApplicationContext
import com.swmansion.rnexecutorch.utils.ImageProcessor
import org.opencv.core.Mat
import org.opencv.core.Size
import org.opencv.imgproc.Imgproc
import org.pytorch.executorch.Tensor
import org.pytorch.executorch.EValue
import com.swmansion.rnexecutorch.models.BaseModel


class ClassificationModel(reactApplicationContext: ReactApplicationContext) : BaseModel<Mat, Map<String, Float>>(reactApplicationContext) {
private fun getModelImageSize(): Size {
val inputShape = module.getInputShape(0)
val width = inputShape[inputShape.lastIndex]
val height = inputShape[inputShape.lastIndex - 1]

return Size(height.toDouble(), width.toDouble())
}

override fun preprocess(input: Mat): EValue {
Imgproc.resize(input, input, getModelImageSize())
return ImageProcessor.matToEValue(input, module.getInputShape(0))
}

override fun postprocess(output: Array<EValue>): Map<String, Float> {
val tensor = output[0].toTensor()
val probabilities = tensor.dataAsFloatArray

val result = mutableMapOf<String, Float>()

for (i in probabilities.indices) {
result[imagenet1k_v1_labels_map[i]!!] = probabilities[i]
}

return result
}

override fun runModel(input: Mat): Map<String, Float> {
val modelInput = preprocess(input)
val modelOutput = forward(modelInput)
return postprocess(modelOutput)
}
}
Loading

0 comments on commit c0e021e

Please sign in to comment.