-
Notifications
You must be signed in to change notification settings - Fork 354
/
Copy pathmain.rs
153 lines (144 loc) · 5.95 KB
/
main.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
// The pre-trained weights can be downloaded here:
// https://github.com/LaurentMazare/ocaml-torch/releases/download/v0.1-unstable/yolo-v3.ot
//
// These weights have been generated by using the Python version at:
// https://github.com/ayooshkathuria/YOLO_v3_tutorial_from_scratch
// The weights are exported in npz format by adding the following code in detect.py after the
// model.load_weights(...) call.
//
// ```python
// def remove_prefix(text, prefix):
//     return text[text.startswith(prefix) and len(prefix):]
// nps = {}
// for k, v in model.state_dict().items():
//     k = remove_prefix(k, 'module_list.')
//     nps[k] = v.detach().numpy()
// np.savez('yolo-v3.ot', **nps)
// ```
//
// Then the tch-rs tensor-tools example can be used to convert the .npz weights
// to the requested .ot file (np.savez appends the .npz extension, hence the input name):
// cargo run --example tensor-tools cp yolo-v3.ot.npz yolo-v3.ot
mod coco_classes;
mod darknet;
use anyhow::{ensure, Result};
use tch::nn::ModuleT;
use tch::vision::image;
use tch::Tensor;
/// Path of the Darknet configuration file handed to `darknet::parse_config`.
const CONFIG_NAME: &str = "examples/yolo/yolo-v3.cfg";
/// A prediction row is kept only when its confidence (column 4) is strictly above this.
const CONFIDENCE_THRESHOLD: f64 = 0.5;
/// Within one class, a box whose IoU with a higher-confidence kept box is strictly above
/// this value is suppressed.
const NMS_THRESHOLD: f64 = 0.4;
/// A detection box stored by its corner coordinates, plus the confidence it was kept with.
#[derive(Clone, Copy, Debug)]
struct Bbox {
    /// Minimum x coordinate.
    xmin: f64,
    /// Minimum y coordinate.
    ymin: f64,
    /// Maximum x coordinate.
    xmax: f64,
    /// Maximum y coordinate.
    ymax: f64,
    /// Confidence score; boxes are ranked by it during non-maximum suppression.
    confidence: f64,
}
/// Intersection over union of two bounding boxes.
///
/// Coordinates are treated as inclusive, so a box spanning `xmin..=xmax` counts as
/// `xmax - xmin + 1` units wide. Boxes that do not overlap have zero intersection.
fn iou(b1: &Bbox, b2: &Bbox) -> f64 {
    let area = |b: &Bbox| (b.xmax - b.xmin + 1.) * (b.ymax - b.ymin + 1.);
    // Extent of the overlap along each axis, clamped at zero for disjoint boxes.
    let overlap_w = (b1.xmax.min(b2.xmax) - b1.xmin.max(b2.xmin) + 1.).max(0.);
    let overlap_h = (b1.ymax.min(b2.ymax) - b1.ymin.max(b2.ymin) + 1.).max(0.);
    let intersection = overlap_w * overlap_h;
    intersection / (area(b1) + area(b2) - intersection)
}
/// Fills columns `x1..x2` (dim 2) and rows `y1..y2` (dim 1) of `t` with the color
/// `[0, 0, 1]`. The color is shaped `[3, 1, 1]` so it broadcasts over the region; it is
/// presumably blue if the channel order is RGB. End coordinates are exclusive.
///
/// Assumes `x1 <= x2` and `y1 <= y2`.
pub fn draw_rect(t: &mut Tensor, x1: i64, x2: i64, y1: i64, y2: i64) {
    let (width, height) = (x2 - x1, y2 - y1);
    let color = Tensor::from_slice(&[0., 0., 1.]).view([3, 1, 1]);
    // `narrow` yields a view of `t`, so the in-place copy paints `t` itself.
    let mut region = t.narrow(2, x1, width).narrow(1, y1, height);
    region.copy_(&color);
}
pub fn report(pred: &Tensor, img: &Tensor, w: i64, h: i64) -> Result<Tensor> {
let (npreds, pred_size) = pred.size2()?;
let nclasses = (pred_size - 5) as usize;
// The bounding boxes grouped by (maximum) class index.
let mut bboxes: Vec<Vec<Bbox>> = (0..nclasses).map(|_| vec![]).collect();
// Extract the bounding boxes for which confidence is above the threshold.
for index in 0..npreds {
let pred = Vec::<f64>::try_from(pred.get(index))?;
let confidence = pred[4];
if confidence > CONFIDENCE_THRESHOLD {
let mut class_index = 0;
for i in 0..nclasses {
if pred[5 + i] > pred[5 + class_index] {
class_index = i
}
}
if pred[class_index + 5] > 0. {
let bbox = Bbox {
xmin: pred[0] - pred[2] / 2.,
ymin: pred[1] - pred[3] / 2.,
xmax: pred[0] + pred[2] / 2.,
ymax: pred[1] + pred[3] / 2.,
confidence,
};
bboxes[class_index].push(bbox)
}
}
}
// Perform non-maximum suppression.
for bboxes_for_class in bboxes.iter_mut() {
bboxes_for_class.sort_by(|b1, b2| b2.confidence.partial_cmp(&b1.confidence).unwrap());
let mut current_index = 0;
for index in 0..bboxes_for_class.len() {
let mut drop = false;
for prev_index in 0..current_index {
let iou = iou(&bboxes_for_class[prev_index], &bboxes_for_class[index]);
if iou > NMS_THRESHOLD {
drop = true;
break;
}
}
if !drop {
bboxes_for_class.swap(current_index, index);
current_index += 1;
}
}
bboxes_for_class.truncate(current_index);
}
// Annotate the original image and print boxes information.
let (_, initial_h, initial_w) = img.size3()?;
let mut img = img.to_kind(tch::Kind::Float) / 255.;
let w_ratio = initial_w as f64 / w as f64;
let h_ratio = initial_h as f64 / h as f64;
for (class_index, bboxes_for_class) in bboxes.iter().enumerate() {
for b in bboxes_for_class.iter() {
println!("{}: {:?}", coco_classes::NAMES[class_index], b);
let xmin = ((b.xmin * w_ratio) as i64).clamp(0, initial_w - 1);
let ymin = ((b.ymin * h_ratio) as i64).clamp(0, initial_h - 1);
let xmax = ((b.xmax * w_ratio) as i64).clamp(0, initial_w - 1);
let ymax = ((b.ymax * h_ratio) as i64).clamp(0, initial_h - 1);
draw_rect(&mut img, xmin, xmax, ymin, ymax.min(ymin + 2));
draw_rect(&mut img, xmin, xmax, ymin.max(ymax - 2), ymax);
draw_rect(&mut img, xmin, xmax.min(xmin + 2), ymin, ymax);
draw_rect(&mut img, xmin.max(xmax - 2), xmax, ymin, ymax);
}
}
Ok((img * 255.).to_kind(tch::Kind::Uint8))
}
/// Entry point, invoked as `main yolo-v3.ot img.jpg ...`.
///
/// Builds the Darknet model described by `CONFIG_NAME` and loads the weights named by the
/// first argument. It then runs detection on every following image and saves each annotated
/// result as `output-NNNNN.jpg`, numbered from 0 in argument order.
pub fn main() -> Result<()> {
    let args: Vec<_> = std::env::args().collect();
    ensure!(args.len() >= 3, "usage: main yolo-v3.ot img.jpg ...");
    // Create the model and load the weights from the file.
    let mut vs = tch::nn::VarStore::new(tch::Device::Cpu);
    let darknet = darknet::parse_config(CONFIG_NAME)?;
    let model = darknet.build_model(&vs.root())?;
    vs.load(&args[1])?;
    // `darknet` is not modified by the loop below, so read its input size once.
    let net_width = darknet.width()?;
    let net_height = darknet.height()?;
    for (index, path) in args.iter().skip(2).enumerate() {
        // Load the image file and resize it to the network input size.
        let original_image = image::load(path)?;
        let image = image::resize(&original_image, net_width, net_height)?;
        let image = image.unsqueeze(0).to_kind(tch::Kind::Float) / 255.;
        // Inference only: run the forward pass under `no_grad` so no autograd graph is kept.
        let predictions = tch::no_grad(|| model.forward_t(&image, false)).squeeze();
        let image = report(&predictions, &original_image, net_width, net_height)?;
        image::save(&image, format!("output-{index:05}.jpg"))?;
        println!("Converted {index}");
    }
    Ok(())
}