From c080bb3f9ec6a4950fd9789b9ad7464d6f63a5fe Mon Sep 17 00:00:00 2001 From: Piotr Libera Date: Sat, 21 Dec 2024 22:22:09 +0100 Subject: [PATCH] Enable using WVN with non-square input --- .../feature_extractor/dino_interface.py | 2 +- .../feature_extractor/stego_interface.py | 2 +- .../image_projector/image_projector.py | 9 +++++---- .../scripts/wvn_feature_extractor_node.py | 14 +++++++------- 4 files changed, 14 insertions(+), 13 deletions(-) diff --git a/wild_visual_navigation/feature_extractor/dino_interface.py b/wild_visual_navigation/feature_extractor/dino_interface.py index 816ad25e..49590ae4 100644 --- a/wild_visual_navigation/feature_extractor/dino_interface.py +++ b/wild_visual_navigation/feature_extractor/dino_interface.py @@ -85,7 +85,7 @@ def inference(self, img: torch.tensor): # resize and interpolate features B, D, H, W = img.shape - new_features_size = (H, H) + new_features_size = (H, W) # pad = int((W - H) / 2) features = F.interpolate(features, new_features_size, mode="bilinear", align_corners=True) # features = F.pad(features, pad=[pad, pad, 0, 0]) diff --git a/wild_visual_navigation/feature_extractor/stego_interface.py b/wild_visual_navigation/feature_extractor/stego_interface.py index d399603a..586da655 100644 --- a/wild_visual_navigation/feature_extractor/stego_interface.py +++ b/wild_visual_navigation/feature_extractor/stego_interface.py @@ -102,7 +102,7 @@ def inference(self, img: torch.tensor): # resize and interpolate features # with Timer("interpolate output"): B, D, H, W = img.shape - new_features_size = (H, H) + new_features_size = (H, W) # pad = int((W - H) / 2) self._code = F.interpolate(self._code, new_features_size, mode="bilinear", align_corners=True) self._cluster_pred = F.interpolate(self._cluster_pred[None].float(), new_features_size, mode="nearest").int() diff --git a/wild_visual_navigation/image_projector/image_projector.py b/wild_visual_navigation/image_projector/image_projector.py index f1d8a3c2..e7346f09 100644 --- a/wild_visual_navigation/image_projector/image_projector.py +++ b/wild_visual_navigation/image_projector/image_projector.py @@ -43,17 +43,18 @@ def __init__(self, K: torch.tensor, h: int, w: int, new_h: int = None, new_w: in self.width = w new_h = self.height.item() if new_h is None else new_h + new_w = self.width.item() if new_w is None else new_w # Compute scale sy = new_h / h - sx = (new_w / w) if (new_w is not None) else sy + sx = new_w / w # Compute scaled parameters sh = new_h - sw = new_w if new_w is not None else sh + sw = new_w # Prepare image cropper - if new_w is None or new_w == new_h: + if new_w == new_h: self.image_crop = T.Compose([T.Resize(new_h, T.InterpolationMode.NEAREST), T.CenterCrop(new_h)]) else: self.image_crop = T.Resize([new_h, new_w], T.InterpolationMode.NEAREST) @@ -61,7 +62,7 @@ def __init__(self, K: torch.tensor, h: int, w: int, new_h: int = None, new_w: in # Adjust camera matrix # Fill values sK = K.clone() - if new_w is None or new_w == new_h: + if new_w == new_h: sK[:, 0, 0] = K[:, 1, 1] * sy sK[:, 0, 2] = K[:, 1, 2] * sy sK[:, 1, 1] = K[:, 1, 1] * sy diff --git a/wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py b/wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py index e6d4caa5..ac7c86a9 100644 --- a/wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py +++ b/wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py @@ -53,7 +53,7 @@ def __init__(self, node_name): feature_type=self._ros_params.feature_type, patch_size=self._ros_params.dino_patch_size, backbone_type=self._ros_params.dino_backbone, - input_size=self._ros_params.network_input_image_height, + input_size=(self._ros_params.network_input_image_height, self._ros_params.network_input_image_width), slic_num_components=self._ros_params.slic_num_components, ) @@ -339,8 +339,8 @@ def image_callback(self, image_msg: Image, cam: str): # info_msg: CameraInfo msg = rc.numpy_to_ros_image(out_trav.cpu().numpy(), "passthrough") msg.header = image_msg.header - msg.width = out_trav.shape[0] - msg.height = out_trav.shape[1] + msg.width = out_trav.shape[1] + msg.height = out_trav.shape[0] self._camera_handler[cam]["trav_pub"].publish(msg) msg = self._camera_handler[cam]["camera_info_msg_out"] @@ -354,8 +354,8 @@ def image_callback(self, image_msg: Image, cam: str): # info_msg: CameraInfo "rgb8", ) msg.header = image_msg.header - msg.width = torch_image.shape[1] - msg.height = torch_image.shape[2] + msg.width = torch_image.shape[2] + msg.height = torch_image.shape[1] self._camera_handler[cam]["input_pub"].publish(msg) # Publish confidence @@ -365,8 +365,8 @@ def image_callback(self, image_msg: Image, cam: str): # info_msg: CameraInfo out_confidence = confidence.reshape(H, W) msg = rc.numpy_to_ros_image(out_confidence.cpu().numpy(), "passthrough") msg.header = image_msg.header - msg.width = out_confidence.shape[0] - msg.height = out_confidence.shape[1] + msg.width = out_confidence.shape[1] + msg.height = out_confidence.shape[0] self._camera_handler[cam]["conf_pub"].publish(msg) # Publish features and feature_segments