diff --git a/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt b/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt index fc4ba2f..6a4fdc2 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt @@ -21,7 +21,9 @@ class RnExecutorchPackage : TurboReactPackage() { StyleTransfer(reactContext) } else if (name == Classification.NAME) { Classification(reactContext) - } + } else if (name == ObjectDetection.NAME) { + ObjectDetection(reactContext) + } else { null } @@ -63,6 +65,15 @@ class RnExecutorchPackage : TurboReactPackage() { false, // isCxxModule true ) + + moduleInfos[ObjectDetection.NAME] = ReactModuleInfo( + ObjectDetection.NAME, + ObjectDetection.NAME, + false, // canOverrideExistingModule + false, // needsEagerInit + false, // isCxxModule + true + ) moduleInfos } } diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/object_detection/SSDLiteLargeModel.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/object_detection/SSDLiteLargeModel.kt index 9af9ab8..f8c5748 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/models/object_detection/SSDLiteLargeModel.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/object_detection/SSDLiteLargeModel.kt @@ -5,15 +5,19 @@ import com.swmansion.rnexecutorch.utils.ImageProcessor import org.opencv.core.Mat import org.opencv.core.Size import org.opencv.imgproc.Imgproc -import org.pytorch.executorch.Tensor import com.swmansion.rnexecutorch.models.BaseModel import com.swmansion.rnexecutorch.utils.Bbox import com.swmansion.rnexecutorch.utils.CocoLabel import com.swmansion.rnexecutorch.utils.Detection +import com.swmansion.rnexecutorch.utils.nms import org.pytorch.executorch.EValue +const val detectionScoreThreshold = .7f +const val iouThreshold = .55f class SSDLiteLargeModel(reactApplicationContext: ReactApplicationContext) : BaseModel>(reactApplicationContext) { + private var heightRatio: Float = 1.0f + private var widthRatio: Float = 1.0f private fun getModelImageSize(): Size { val inputShape = module.getInputShape(0) @@ -23,36 +27,45 @@ class SSDLiteLargeModel(reactApplicationContext: ReactApplicationContext) : Base return Size(height.toDouble(), width.toDouble()) } - override fun preprocess(input: Mat): Mat { + override fun preprocess(input: Mat): EValue { + this.widthRatio = (input.size().width / getModelImageSize().width).toFloat() + this.heightRatio = (input.size().height / getModelImageSize().height).toFloat() Imgproc.resize(input, input, getModelImageSize()) - return input + return ImageProcessor.matToEValue(input, module.getInputShape(0)) } - fun postprocessFromEValue(eValues: Array) : Array { - val scoresTensor = eValues[1].toTensor() - val numel = scoresTensor.numel() // bboxes is 4 * numel, labels is the same length - val bboxes = eValues[0].toTensor().dataAsFloatArray + override fun runModel(input: Mat): Array { + val modelInput = preprocess(input) + val modelOutput = forward(modelInput) + return postprocess(modelOutput) + } + + override fun postprocess(output: Array): Array { + val scoresTensor = output[1].toTensor() + val numel = scoresTensor.numel() + val bboxes = output[0].toTensor().dataAsFloatArray val scores = scoresTensor.dataAsFloatArray - val labels = eValues[2].toTensor().dataAsFloatArray + val labels = output[2].toTensor().dataAsFloatArray val detections: MutableList = mutableListOf(); - for (idx in 0..numel.toInt()) { - val bbox = Bbox(bboxes[idx], bboxes[idx + 1], bboxes[idx + 1], bboxes[idx + 1]) + for (idx in 0 until numel.toInt()) { val score = scores[idx] + if (score < detectionScoreThreshold) { + continue + } + val bbox = Bbox( + bboxes[idx * 4 + 0] * this.widthRatio, + bboxes[idx * 4 + 1] * this.heightRatio, + bboxes[idx * 4 + 2] * this.widthRatio, + bboxes[idx * 4 + 3] * this.heightRatio + ) val label = labels[idx] - detections.plus( + detections.add( Detection(bbox, score, CocoLabel.fromId(label.toInt())!!) ) } - return detections.toTypedArray() - } - - override fun postprocess(input: Tensor) { - } - override fun runModel(input: Mat): Array { - val inputTensor = ImageProcessor.matToEValue(preprocess(input), module.getInputShape(0)) - val modelOutput = forward(inputTensor) - return postprocessFromEValue(modelOutput) + val detectionsPostNms = nms(detections, iouThreshold); + return detectionsPostNms.toTypedArray() } } diff --git a/android/src/main/java/com/swmansion/rnexecutorch/utils/ObjectDetectionUtils.kt b/android/src/main/java/com/swmansion/rnexecutorch/utils/ObjectDetectionUtils.kt index f58b385..00fd3fb 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/utils/ObjectDetectionUtils.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/utils/ObjectDetectionUtils.kt @@ -1,18 +1,101 @@ package com.swmansion.rnexecutorch.utils +import com.facebook.react.bridge.Arguments +import com.facebook.react.bridge.WritableMap + +fun nms( + detections: MutableList, + iouThreshold: Float +): List { + if (detections.isEmpty()) { + return emptyList() + } + + // Sort detections first by label, then by score (descending) + val sortedDetections = detections.sortedWith(compareBy({ it.label }, { -it.score })) + + val result = mutableListOf() + + // Process NMS for each label group + var i = 0 + while (i < sortedDetections.size) { + val currentLabel = sortedDetections[i].label + + // Collect detections for the current label + val labelDetections = mutableListOf() + while (i < sortedDetections.size && sortedDetections[i].label == currentLabel) { + labelDetections.add(sortedDetections[i]) + i++ + } + + // Filter out detections with high IoU + val filteredLabelDetections = mutableListOf() + while (labelDetections.isNotEmpty()) { + val current = labelDetections.removeAt(0) + filteredLabelDetections.add(current) + + // Remove detections that overlap with the current detection above the IoU threshold + val iterator = labelDetections.iterator() + while (iterator.hasNext()) { + val other = iterator.next() + if (calculateIoU(current.bbox, other.bbox) > iouThreshold) { + iterator.remove() // Remove detection if IoU is above threshold + } + } + } + + // Add the filtered detections to the result + result.addAll(filteredLabelDetections) + } + + return result +} + +fun calculateIoU(bbox1: Bbox, bbox2: Bbox): Float { + val x1 = maxOf(bbox1.x1, bbox2.x1) + val y1 = maxOf(bbox1.y1, bbox2.y1) + val x2 = minOf(bbox1.x2, bbox2.x2) + val y2 = minOf(bbox1.y2, bbox2.y2) + + val intersectionArea = maxOf(0f, x2 - x1) * maxOf(0f, y2 - y1) + val bbox1Area = (bbox1.x2 - bbox1.x1) * (bbox1.y2 - bbox1.y1) + val bbox2Area = (bbox2.x2 - bbox2.x1) * (bbox2.y2 - bbox2.y1) + + val unionArea = bbox1Area + bbox2Area - intersectionArea + return if (unionArea == 0f) 0f else intersectionArea / unionArea +} + + data class Bbox( val x1: Float, - val x2: Float, val y1: Float, + val x2: Float, val y2: Float -) +) { + fun toWritableMap(): WritableMap { + val map = Arguments.createMap() + map.putDouble("x1", x1.toDouble()) + map.putDouble("x2", x2.toDouble()) + map.putDouble("y1", y1.toDouble()) + map.putDouble("y2", y2.toDouble()) + return map + } +} data class Detection( val bbox: Bbox, val score: Float, val label: CocoLabel, -) +) { + fun toWritableMap(): WritableMap { + val map = Arguments.createMap() + map.putMap("bbox", bbox.toWritableMap()) + map.putDouble("score", score.toDouble()) + map.putString("label", label.name) + return map + } +} enum class CocoLabel(val id: Int) { PERSON(1), @@ -107,10 +190,7 @@ enum class CocoLabel(val id: Int) { HAIR_BRUSH(91); companion object { - // A map to store the mapping from id to CocoLabel private val idToLabelMap = values().associateBy(CocoLabel::id) - - // Function to retrieve a CocoLabel by its id fun fromId(id: Int): CocoLabel? = idToLabelMap[id] } }