From a119d239f61a5f2e747efaeb32a4ddb4ee2957a1 Mon Sep 17 00:00:00 2001 From: ziqi-jin Date: Wed, 27 Jul 2022 07:29:01 +0000 Subject: [PATCH 1/2] fix export bugs --- export.py | 2 +- models/yolo.py | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/export.py b/export.py index 306dcb9d..231035ce 100644 --- a/export.py +++ b/export.py @@ -23,7 +23,7 @@ parser = argparse.ArgumentParser() parser.add_argument('--weights', type=str, default=r'C:\Users\chen\Desktop\Model_Zoo\model_zoo/v5lite-e.pt', help='weights path') # from yolov5/models/ parser.add_argument('--img-size', nargs='+', type=int, default=[320, 320], help='image size') # height, width - parser.add_argument('--concat', type=str, default=True, help='concat or not') + parser.add_argument('--concat', action='store_false', help='concat or not') parser.add_argument('--batch-size', type=int, default=1, help='batch size') parser.add_argument('--dynamic', action='store_true', help='dynamic ONNX axes') parser.add_argument('--grid', action='store_true', help='export Detect() layer grid') diff --git a/models/yolo.py b/models/yolo.py index d6a92df8..fd35faee 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -53,8 +53,13 @@ def forward(self, x): logits = x[i][..., 5:] y = x[i].sigmoid() - y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy - y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh + if not torch.onnx.is_in_onnx_export(): + y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy + y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh + else: + xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy + wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].data # wh + y = torch.cat((xy, wh, y[..., 4:]), -1) z.append(y.view(bs, -1, self.no)) logits_.append(logits.view(bs, -1, self.no - 5)) From a7952c954dca7db9fe394444dbd97d037e96135c Mon Sep 17 00:00:00 2001 From: ziqi-jin Date: Wed, 27 Jul 2022 08:04:02 +0000 Subject: [PATCH 2/2] fix export problem --- models/yolo.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/models/yolo.py b/models/yolo.py index fd35faee..8b923e81 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -48,7 +48,9 @@ def forward(self, x): x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous() if not self.training: # inference - if self.grid[i].shape[2:4] != x[i].shape[2:4]: + if torch.onnx.is_in_onnx_export(): + self.grid[i] = self._make_grid(nx, ny).to(x[i].device) + elif self.grid[i].shape[2:4] != x[i].shape[2:4]: self.grid[i] = self._make_grid(nx, ny).to(x[i].device) logits = x[i][..., 5:]