Skip to content

Commit

Permalink
feat: add android object detection
Browse files Browse the repository at this point in the history
  • Loading branch information
chmjkb committed Dec 17, 2024
1 parent 002c7bf commit 2c79aea
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ class RnExecutorchPackage : TurboReactPackage() {
StyleTransfer(reactContext)
} else if (name == Classification.NAME) {
Classification(reactContext)
}
} else if (name == ObjectDetection.NAME) {
ObjectDetection(reactContext)
}
else {
null
}
Expand Down Expand Up @@ -63,6 +65,15 @@ class RnExecutorchPackage : TurboReactPackage() {
false, // isCxxModule
true
)

moduleInfos[ObjectDetection.NAME] = ReactModuleInfo(
ObjectDetection.NAME,
ObjectDetection.NAME,
false, // canOverrideExistingModule
false, // needsEagerInit
false, // isCxxModule
true
)
moduleInfos
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,19 @@ import com.swmansion.rnexecutorch.utils.ImageProcessor
import org.opencv.core.Mat
import org.opencv.core.Size
import org.opencv.imgproc.Imgproc
import org.pytorch.executorch.Tensor
import com.swmansion.rnexecutorch.models.BaseModel
import com.swmansion.rnexecutorch.utils.Bbox
import com.swmansion.rnexecutorch.utils.CocoLabel
import com.swmansion.rnexecutorch.utils.Detection
import com.swmansion.rnexecutorch.utils.nms
import org.pytorch.executorch.EValue

const val detectionScoreThreshold = .7f
const val iouThreshold = .55f

class SSDLiteLargeModel(reactApplicationContext: ReactApplicationContext) : BaseModel<Mat, Array<Detection>>(reactApplicationContext) {
private var heightRatio: Float = 1.0f
private var widthRatio: Float = 1.0f

private fun getModelImageSize(): Size {
val inputShape = module.getInputShape(0)
Expand All @@ -23,36 +27,45 @@ class SSDLiteLargeModel(reactApplicationContext: ReactApplicationContext) : Base
return Size(height.toDouble(), width.toDouble())
}

override fun preprocess(input: Mat): Mat {
override fun preprocess(input: Mat): EValue {
this.widthRatio = (input.size().width / getModelImageSize().width).toFloat()
this.heightRatio = (input.size().height / getModelImageSize().height).toFloat()
Imgproc.resize(input, input, getModelImageSize())
return input
return ImageProcessor.matToEValue(input, module.getInputShape(0))
}

fun postprocessFromEValue(eValues: Array<EValue>) : Array<Detection> {
val scoresTensor = eValues[1].toTensor()
val numel = scoresTensor.numel() // bboxes is 4 * numel, labels is the same length
val bboxes = eValues[0].toTensor().dataAsFloatArray
override fun runModel(input: Mat): Array<Detection> {
val modelInput = preprocess(input)
val modelOutput = forward(modelInput)
return postprocess(modelOutput)
}

override fun postprocess(output: Array<EValue>): Array<Detection> {
val scoresTensor = output[1].toTensor()
val numel = scoresTensor.numel()
val bboxes = output[0].toTensor().dataAsFloatArray
val scores = scoresTensor.dataAsFloatArray
val labels = eValues[2].toTensor().dataAsFloatArray
val labels = output[2].toTensor().dataAsFloatArray

val detections: MutableList<Detection> = mutableListOf();
for (idx in 0..numel.toInt()) {
val bbox = Bbox(bboxes[idx], bboxes[idx + 1], bboxes[idx + 1], bboxes[idx + 1])
for (idx in 0 until numel.toInt()) {
val score = scores[idx]
if (score < detectionScoreThreshold) {
continue
}
val bbox = Bbox(
bboxes[idx * 4 + 0] * this.widthRatio,
bboxes[idx * 4 + 1] * this.heightRatio,
bboxes[idx * 4 + 2] * this.widthRatio,
bboxes[idx * 4 + 3] * this.heightRatio
)
val label = labels[idx]
detections.plus(
detections.add(
Detection(bbox, score, CocoLabel.fromId(label.toInt())!!)
)
}
return detections.toTypedArray()
}

override fun postprocess(input: Tensor) {
}

override fun runModel(input: Mat): Array<Detection> {
val inputTensor = ImageProcessor.matToEValue(preprocess(input), module.getInputShape(0))
val modelOutput = forward(inputTensor)
return postprocessFromEValue(modelOutput)
val detectionsPostNms = nms(detections, iouThreshold);
return detectionsPostNms.toTypedArray()
}
}
Original file line number Diff line number Diff line change
@@ -1,18 +1,101 @@
package com.swmansion.rnexecutorch.utils

import com.facebook.react.bridge.Arguments
import com.facebook.react.bridge.WritableMap

fun nms(
detections: MutableList<Detection>,
iouThreshold: Float
): List<Detection> {
if (detections.isEmpty()) {
return emptyList()
}

// Sort detections first by label, then by score (descending)
val sortedDetections = detections.sortedWith(compareBy({ it.label }, { -it.score }))

val result = mutableListOf<Detection>()

// Process NMS for each label group
var i = 0
while (i < sortedDetections.size) {
val currentLabel = sortedDetections[i].label

// Collect detections for the current label
val labelDetections = mutableListOf<Detection>()
while (i < sortedDetections.size && sortedDetections[i].label == currentLabel) {
labelDetections.add(sortedDetections[i])
i++
}

// Filter out detections with high IoU
val filteredLabelDetections = mutableListOf<Detection>()
while (labelDetections.isNotEmpty()) {
val current = labelDetections.removeAt(0)
filteredLabelDetections.add(current)

// Remove detections that overlap with the current detection above the IoU threshold
val iterator = labelDetections.iterator()
while (iterator.hasNext()) {
val other = iterator.next()
if (calculateIoU(current.bbox, other.bbox) > iouThreshold) {
iterator.remove() // Remove detection if IoU is above threshold
}
}
}

// Add the filtered detections to the result
result.addAll(filteredLabelDetections)
}

return result
}

fun calculateIoU(bbox1: Bbox, bbox2: Bbox): Float {
val x1 = maxOf(bbox1.x1, bbox2.x1)
val y1 = maxOf(bbox1.y1, bbox2.y1)
val x2 = minOf(bbox1.x2, bbox2.x2)
val y2 = minOf(bbox1.y2, bbox2.y2)

val intersectionArea = maxOf(0f, x2 - x1) * maxOf(0f, y2 - y1)
val bbox1Area = (bbox1.x2 - bbox1.x1) * (bbox1.y2 - bbox1.y1)
val bbox2Area = (bbox2.x2 - bbox2.x1) * (bbox2.y2 - bbox2.y1)

val unionArea = bbox1Area + bbox2Area - intersectionArea
return if (unionArea == 0f) 0f else intersectionArea / unionArea
}


data class Bbox(
val x1: Float,
val x2: Float,
val y1: Float,
val x2: Float,
val y2: Float
)
) {
fun toWritableMap(): WritableMap {
val map = Arguments.createMap()
map.putDouble("x1", x1.toDouble())
map.putDouble("x2", x2.toDouble())
map.putDouble("y1", y1.toDouble())
map.putDouble("y2", y2.toDouble())
return map
}
}


data class Detection(
val bbox: Bbox,
val score: Float,
val label: CocoLabel,
)
) {
fun toWritableMap(): WritableMap {
val map = Arguments.createMap()
map.putMap("bbox", bbox.toWritableMap())
map.putDouble("score", score.toDouble())
map.putString("label", label.name)
return map
}
}

enum class CocoLabel(val id: Int) {
PERSON(1),
Expand Down Expand Up @@ -107,10 +190,7 @@ enum class CocoLabel(val id: Int) {
HAIR_BRUSH(91);

companion object {
// A map to store the mapping from id to CocoLabel
private val idToLabelMap = values().associateBy(CocoLabel::id)

// Function to retrieve a CocoLabel by its id
fun fromId(id: Int): CocoLabel? = idToLabelMap[id]
}
}

0 comments on commit 2c79aea

Please sign in to comment.