diff --git a/.gitignore b/.gitignore
index cb32c17b..88488fd1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,6 +7,8 @@ assets/stego/cocostuff27_vit_base_5.ckpt
 .pytest_cache/**
 .vscode/**
 notebooks/**
+*.code-workspace
+
 # DrawIO
 **.dtmp
 **.bkp
@@ -143,4 +145,5 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
-assets/virutal_env/*
+assets/virtual_env/*
+
diff --git a/wild_visual_navigation/cfg/experiment_params.py b/wild_visual_navigation/cfg/experiment_params.py
index 1ed29b97..63412abf 100644
--- a/wild_visual_navigation/cfg/experiment_params.py
+++ b/wild_visual_navigation/cfg/experiment_params.py
@@ -97,7 +97,7 @@ class AblationDataModuleParams:
 @dataclass
 class ModelParams:
-    name: str = "LinearRnvp"  # LinearRnvp, SimpleMLP, SimpleGCN, DoubleMLP
+    name: str = "SimpleMLP"  # LinearRnvp, SimpleMLP, SimpleGCN, DoubleMLP
     load_ckpt: Optional[str] = None
 
 
 @dataclass
diff --git a/wild_visual_navigation/feature_extractor/feature_extractor.py b/wild_visual_navigation/feature_extractor/feature_extractor.py
index 6bf1be9b..055dd758 100644
--- a/wild_visual_navigation/feature_extractor/feature_extractor.py
+++ b/wild_visual_navigation/feature_extractor/feature_extractor.py
@@ -31,36 +31,50 @@ def __init__(
         self._segmentation_type = segmentation_type
         self._feature_type = feature_type
         self._input_size = input_size
+
         # Prepare segment extractor
         self.segment_extractor = SegmentExtractor().to(self._device)
 
         # Prepare extractor depending on the type
         if self._feature_type == "stego":
             self._feature_dim = 90
-            self.extractor = StegoInterface(device=device, input_size=input_size)
-        elif self._feature_type == "dino":
-            self._feature_dim = 90
+            self._extractor = StegoInterface(
+                device=device,
+                input_size=input_size,
+                n_image_clusters=kwargs.get("n_image_clusters", 20),
+                run_clustering=kwargs.get("run_clustering", True),
+                run_crf=kwargs.get("run_crf", False),
+            )
-            self.extractor = DinoInterface(
+        elif "dino" in self._feature_type:
+            self._feature_dim = 90
+            self._extractor = DinoInterface(
                 device=device,
                 input_size=input_size,
                 patch_size=kwargs.get("patch_size", 8),
+                backbone=kwargs.get("backbone", "dino"),
                 dim=kwargs.get("dino_dim", 384),
             )
+
         elif self._feature_type == "sift":
             self._feature_dim = 128
-            self.extractor = DenseSIFTDescriptor().to(device)
+            self._extractor = DenseSIFTDescriptor().to(device)
+
         elif self._feature_type == "torchvision":
             self._extractor = TorchVisionInterface(
                 device=device, model_type=kwargs["model_type"], input_size=input_size
             )
+
         elif self._feature_type == "histogram":
             self._feature_dim = 90
+
         elif self._feature_type == "none":
             pass
+
         else:
             raise f"Extractor[{self._feature_type}] not supported!"
 
+        # Segmentation
         if self.segmentation_type == "slic":
             from fast_slic import Slic
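Editor's note on the constructor hunk above: every backend now lands in `self._extractor`, and optional behavior is passed through `**kwargs` with explicit defaults. A minimal usage sketch follows; the import path, positional argument order, and the 448-pixel input size are assumptions, not taken from the diff:

```python
import torch

# Assumed import path for the class patched above.
from wild_visual_navigation.feature_extractor import FeatureExtractor

device = "cuda" if torch.cuda.is_available() else "cpu"

# STEGO backend: clustering and CRF behavior are now tunable through kwargs;
# omitted keys fall back to the kwargs.get(...) defaults shown in __init__.
fe = FeatureExtractor(
    device,
    segmentation_type="stego",
    feature_type="stego",
    input_size=448,  # assumed value, matching the test at the end of the file
    n_image_clusters=20,
    run_clustering=True,
    run_crf=False,
)
```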
@@ -87,7 +101,7 @@ def extract(self, img, **kwargs):
         if kwargs.get("return_dense_features", False):
             return None, feat, seg, None, dense_feat
 
-        return None, feat, seg, None
+        return None, feat, seg, None, None
 
         # Compute segments, their centers, and edges connecting them (graph structure)
         # with Timer("feature_extractor - compute_segments"):
@@ -104,7 +118,7 @@ def extract(self, img, **kwargs):
         if kwargs.get("return_dense_features", False):
             return edges, feat, seg, center, dense_feat
 
-        return edges, feat, seg, center
+        return edges, feat, seg, center, None
 
     @property
     def feature_type(self):
@@ -125,7 +139,7 @@ def change_device(self, device):
             device (str): new device
         """
         self._device = device
-        self.extractor.change_device(device)
+        self._extractor.change_device(device)
 
     def compute_segments(self, img: torch.tensor, **kwargs):
         if self._segmentation_type == "none" or self._segmentation_type is None:
@@ -149,9 +163,9 @@ def compute_segments(self, img: torch.tensor, **kwargs):
         # Compute edges and centers
         if self._segmentation_type != "none" and self._segmentation_type is not None:
             # Extract adjacency_list based on segments
-            edges = self.segment_extractor.adjacency_list(seg[None, None])
+            edges = self.segment_extractor.adjacency_list(seg)
             # Extract centers
-            centers = self.segment_extractor.centers(seg[None, None])
+            centers = self.segment_extractor.centers(seg)
 
         return edges.T, seg, centers
 
@@ -187,20 +201,20 @@ def segment_grid(self, img, **kwargs):
         for i in range(patches.shape[1]):
             patches[:, i, :, :, :] = i
 
-        combine_patch_size = (int(H / cell_size), int(W / cell_size))
+        # combine_patch_size = (int(H / cell_size), int(W / cell_size))
 
         seg = combine_tensor_patches(
             patches=patches,
             original_size=(H, W),
-            window_size=combine_patch_size,
-            stride=combine_patch_size,
+            window_size=patch_size,
+            stride=patch_size,
         )
 
-        return seg[0, 0].to(self._device)
+        return seg.to(self._device)
 
     def segment_slic(self, img, **kwargs):
         # Get slic clusters
         img_np = kornia.utils.tensor_to_image(img)
-        seg = self.slic.iterate(np.uint8(np.ascontiguousarray(img_np) * 255))
+        seg = self.slic.iterate(np.uint8(np.ascontiguousarray(img_np) * 255))[None, None]
 
         return torch.from_numpy(seg).to(self._device).type(torch.long)
 
     def segment_random(self, img, **kwargs):
@@ -210,19 +224,19 @@ def segment_random(self, img, **kwargs):
         seg = torch.full((H * W,), -1, dtype=torch.long, device=self._device)
         indices = torch.randperm(H * W, device=self._device)[:nr]
         seg[indices] = torch.arange(0, nr, device=self._device)
-        seg = seg.reshape(H, W)
+        seg = seg.reshape(H, W)[None, None]
         return seg
 
     def segment_stego(self, img, **kwargs):
         # Prepare input image
         img_internal = img.clone()
-        self.extractor.inference(img_internal)
-        seg = torch.from_numpy(self.extractor.cluster_segments).to(self._device)
+        self._extractor.inference(img_internal)
+        seg = self._extractor.cluster_segments.to(self._device)
+        # seg = torch.from_numpy(self._extractor.cluster_segments).to(self._device)
 
         # Change the segment indices by numbers from 0 to N
         for i, k in enumerate(seg.unique()):
             seg[seg == k.item()] = i
-
         return seg
 
     def compute_features(self, img: torch.tensor, seg: torch.tensor, center: torch.tensor, **kwargs):
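The segmentation hunks above all converge on one shape contract: each `segment_*` method now returns the segment map with batch and channel dimensions attached, so `adjacency_list(seg)` and `centers(seg)` no longer need the old `seg[None, None]` unsqueeze. A shape-only sketch of that contract, with no WVN imports:

```python
import torch

H, W = 224, 224
# Mirrors segment_random / segment_slic: a [B=1, C=1, H, W] segment map.
seg = torch.randint(0, 100, (H, W))[None, None]
assert seg.shape == (1, 1, H, W)

# Downstream consumers take the 4D map directly; pixel masks come from the
# 2D slice, exactly as sparsify_features does with m[0, 0]:
mask = seg[0, 0] == 3
x, y = torch.where(mask)
```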
@@ -254,18 +268,18 @@ def compute_histogram(self, img: torch.tensor, seg: torch.tensor, **kwargs):
 
     def compute_sift(self, img: torch.tensor, seg: torch.tensor, center: torch.tensor, **kwargs):
         B, C, H, W = img.shape
         if C == 3:
-            feat_r = self.extractor(img[:, 0, :, :][None])
-            feat_g = self.extractor(img[:, 1, :, :][None])
-            feat_b = self.extractor(img[:, 2, :, :][None])
+            feat_r = self._extractor(img[:, 0, :, :][None])
+            feat_g = self._extractor(img[:, 1, :, :][None])
+            feat_b = self._extractor(img[:, 2, :, :][None])
             features = torch.cat([feat_r, feat_g, feat_b], dim=1)
         else:
-            features = self.extractor(img)
+            features = self._extractor(img)
         return features
 
     @torch.no_grad()
     def compute_dino(self, img: torch.tensor, seg: torch.tensor, center: torch.tensor, **kwargs):
         img_internal = img.clone()
-        features = self.extractor.inference(img_internal)
+        features = self._extractor.inference(img_internal)
         return features
 
     @torch.no_grad()
@@ -276,7 +290,7 @@ def compute_torchvision(self, img: torch.tensor, seg: torch.tensor, center: torc
 
     @torch.no_grad()
     def compute_stego(self, img: torch.tensor, seg: torch.tensor, center: torch.tensor, **kwargs):
-        return self.extractor.features
+        return self._extractor.features
 
     def sparsify_features(self, dense_features: torch.tensor, seg: torch.tensor, cumsum_trick=False):
         if self._feature_type not in ["histogram"] and self._segmentation_type not in ["none"]:
@@ -361,9 +375,56 @@ def sparsify_features(self, dense_features: torch.tensor, seg: torch.tensor, cum
             sparse_features = []
             for i in range(seg.max() + 1):
                 m = seg == i
-                x, y = torch.where(m)
+                x, y = torch.where(m[0, 0])
                 feat = dense_features[0, :, x, y].mean(dim=1)
                 sparse_features.append(feat)
             return torch.stack(sparse_features, dim=1).T
         else:
             return dense_features
+
+
+def run_feature_extractor():
+    """Tests feature extractor"""
+    import os
+    import cv2
+    from os.path import join
+    from pytictac import Timer
+    from torchvision import transforms as T
+    from wild_visual_navigation import WVN_ROOT_DIR
+
+    # Create test directory
+    os.makedirs(join(WVN_ROOT_DIR, "results", "test_feature_extractor"), exist_ok=True)
+
+    # Inference model
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    p = join(WVN_ROOT_DIR, "assets/images/forest_clean.png")
+    np_img = cv2.imread(p)
+    np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
+    img = torch.from_numpy(np_img).to(device)
+    img = img.permute(2, 0, 1)
+    img = (img.type(torch.float32) / 255)[None]
+    transform = T.Compose(
+        [
+            T.Resize(448, T.InterpolationMode.NEAREST),
+            T.CenterCrop(448),
+        ]
+    )
+    img = transform(img)
+
+    # Create feature extractor
+    fe = FeatureExtractor(device=device, segmentation_type="slic", feature_type="dino")
+    with Timer("SLIC-DINO"):
+        edges, feat, seg, center, dense_feat = fe.extract(img)
+
+    fe = FeatureExtractor(device=device, segmentation_type="grid", feature_type="dino")
+    with Timer("GRID-DINO"):
+        edges, feat, seg, center, dense_feat = fe.extract(img)
+
+    fe = FeatureExtractor(device=device, segmentation_type="stego", feature_type="stego")
+    with Timer("STEGO-STEGO"):
+        edges, feat, seg, center, dense_feat = fe.extract(img)
+
+
+if __name__ == "__main__":
+    run_feature_extractor()
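For reference, the per-segment averaging loop at the end of `sparsify_features` can also be written without a Python loop. This is an illustrative alternative using `index_add_`, not something the diff introduces; the function name is hypothetical:

```python
import torch

def mean_per_segment(dense_features: torch.Tensor, seg: torch.Tensor) -> torch.Tensor:
    # dense_features: [1, C, H, W]; seg: [1, 1, H, W] with ids 0..N-1.
    C = dense_features.shape[1]
    ids = seg[0, 0].reshape(-1)                 # [H*W] segment id per pixel
    feats = dense_features[0].reshape(C, -1).T  # [H*W, C] feature per pixel
    n = int(ids.max()) + 1
    # Sum features and count pixels per segment id, then average.
    sums = torch.zeros(n, C, device=feats.device).index_add_(0, ids, feats)
    counts = torch.zeros(n, device=feats.device).index_add_(0, ids, torch.ones_like(ids, dtype=feats.dtype))
    return sums / counts[:, None]               # [N, C], same as the stacked loop output
```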
diff --git a/wild_visual_navigation/feature_extractor/stego_interface.py b/wild_visual_navigation/feature_extractor/stego_interface.py
index 47802f98..83a5ecdc 100644
--- a/wild_visual_navigation/feature_extractor/stego_interface.py
+++ b/wild_visual_navigation/feature_extractor/stego_interface.py
@@ -99,7 +99,7 @@ def inference(self, img: torch.tensor):
         self._cluster_pred = F.interpolate(self._cluster_pred[None].float(), new_features_size, mode="nearest").int()
         self._linear_pred = F.interpolate(self._linear_pred[None].float(), new_features_size, mode="nearest").int()
 
-        return self._linear_pred[0], self._cluster_pred[0]
+        return self._linear_pred, self._cluster_pred
 
     @property
     def model(self):
@@ -165,9 +165,9 @@ def run_stego_interfacer():
     ax[0].imshow(img[0].permute(1, 2, 0).cpu().numpy())
     ax[0].set_title("Image")
-    ax[1].imshow(si.cmap[cluster_pred[0].cpu() % si.cmap.shape[0]])
+    ax[1].imshow(si.cmap[cluster_pred[0, 0].cpu() % si.cmap.shape[0]])
     ax[1].set_title("Cluster Predictions")
-    ax[2].imshow(si.cmap[linear_pred[0].cpu()])
+    ax[2].imshow(si.cmap[linear_pred[0, 0].cpu()])
     ax[2].set_title("Linear Probe Predictions")
 
     remove_axes(ax)
diff --git a/wild_visual_navigation_ros/config/wild_visual_navigation/default.yaml b/wild_visual_navigation_ros/config/wild_visual_navigation/default.yaml
index 96f3d0f9..9b7bb318 100644
--- a/wild_visual_navigation_ros/config/wild_visual_navigation/default.yaml
+++ b/wild_visual_navigation_ros/config/wild_visual_navigation/default.yaml
@@ -21,8 +21,8 @@ image_graph_dist_thr: 0.2 # meters
 supervision_graph_dist_thr: 0.1 # meters
 network_input_image_height: 224 # 448
 network_input_image_width: 224 # 448
-segmentation_type: "slic"
-feature_type: "dino" # TODO verify this here
+segmentation_type: "stego"
+feature_type: "stego" # TODO verify this here
 dino_patch_size: 8 # 8 or 16; 8 is finer
 slic_num_components: 100
 dino_dim: 384 # 90 or 384; 384 is better
@@ -49,16 +49,14 @@ status_thread_rate: 0.5 # hertz
 
 # Runtime options
 device: "cuda"
-mode: "debug" # check out comments in the class WVNMode
+mode: "online" # check out comments in the class WVNMode
 colormap: "RdYlBu"
 
 print_image_callback_time: false
 print_supervision_callback_time: false
 log_time: false
 log_confidence: false
-verbose: false
-debug_supervision_node_index_from_last: 10
-use_debug_for_desired: false
+verbose: true
 
 extraction_store_folder: "nan"
 exp: "nan"
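The `stego_interface.py` hunk stops slicing the predictions before returning them, so `inference` now hands back the full 4D tensors and the plotting code slices `[0, 0]` to recover a 2D label image. A shape-only sketch of the new contract; sizes and dtype are assumptions:

```python
import torch

B, H, W = 1, 224, 224  # assumed sizes
linear_pred = torch.zeros(B, 1, H, W, dtype=torch.int)   # stands in for self._linear_pred
cluster_pred = torch.zeros(B, 1, H, W, dtype=torch.int)  # stands in for self._cluster_pred

# Visualization slices [0, 0] to obtain a 2D label image, as in run_stego_interfacer.
assert cluster_pred[0, 0].shape == (H, W)
```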
diff --git a/wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py b/wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py
index 4bcb47d0..b3372d45 100644
--- a/wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py
+++ b/wild_visual_navigation_ros/scripts/wvn_feature_extractor_node.py
@@ -26,10 +26,17 @@ class WvnFeatureExtractor:
-    def __init__(self):
+    def __init__(self, node_name):
         # Read params
         self.read_params()
 
+        # Initialize variables
+        self._node_name = node_name
+        self._load_model_counter = 0
+
+        self.model = get_model(self.params.model).to(self.ros_params.device)
+        self.model.eval()
+
         self.feature_extractor = FeatureExtractor(
             self.ros_params.device,
             segmentation_type=self.ros_params.segmentation_type,
@@ -38,10 +45,6 @@ def __init__(self):
             slic_num_components=self.ros_params.slic_num_components,
             dino_dim=self.ros_params.dino_dim,
         )
-        self.i = 0
-
-        self.model = get_model(self.params.model).to(self.ros_params.device)
-        self.model.eval()
 
         if not self.anomaly_detection:
             self.confidence_generator = ConfidenceGenerator(
@@ -81,12 +84,12 @@ def shutdown_callback(self, *args, **kwargs):
         sys.exit(0)
 
     def status_thread_loop(self):
-        # rate = rospy.Rate(self.ros_params.status_thread_rate)
+        rate = rospy.Rate(self.ros_params.status_thread_rate)
         # Learning loop
         while self.run_status_thread:
             self.status_thread_stop_event.wait(timeout=0.01)
             if self.status_thread_stop_event.is_set():
-                rospy.logwarn("Stopped learning thread")
+                rospy.logwarn(f"[{self._node_name}] Stopped learning thread")
                 break
 
             t = rospy.get_time()
@@ -107,12 +110,12 @@ def status_thread_loop(self):
                     x.add_row([k, colored(round(d, 2), c)])
                 else:
                     x.add_row([k, v])
-            print(x)
-            # try:
-            #     rate.sleep()
-            # except Exception as e:
-            #     rate = rospy.Rate(self.ros_params.status_thread_rate)
-            #     print("Ignored jump pack in time!")
+            print(f"[{self._node_name}]\n{x}")
+            try:
+                rate.sleep()
+            except Exception:
+                rate = rospy.Rate(self.ros_params.status_thread_rate)
+                print(f"[{self._node_name}] Ignored jump back in time!")
         self.status_thread_stop_event.clear()
 
     def read_params(self):
@@ -153,9 +156,9 @@ def setup_ros(self, setup_fully=True):
                 self.ros_params.camera_topics[cam]["name"] = cam
 
                 # Camera info
-                camera_info_msg = rospy.wait_for_message(
-                    self.ros_params.camera_topics[cam]["info_topic"], CameraInfo, timeout=15
-                )
+                rospy.loginfo(f"[{self._node_name}] Waiting for camera info topic...")
+                camera_info_msg = rospy.wait_for_message(self.ros_params.camera_topics[cam]["info_topic"], CameraInfo)
+                rospy.loginfo(f"[{self._node_name}] Done")
                 K, H, W = rc.ros_cam_info_to_tensors(camera_info_msg, device=self.ros_params.device)
 
                 self.camera_handler[cam]["camera_info"] = camera_info_msg
@@ -216,7 +219,7 @@ def setup_ros(self, setup_fully=True):
                 self.camera_handler[cam]["trav_pub"] = trav_pub
                 self.camera_handler[cam]["info_pub"] = info_pub
                 if self.anomaly_detection and self.ros_params.camera_topics[cam]["publish_confidence"]:
-                    print(colored("Warning force set public confidence to false", "red"))
+                    rospy.logwarn(f"[{self._node_name}] Warning: forcing publish_confidence to false")
                     self.ros_params.camera_topics[cam]["publish_confidence"] = False
 
                 if self.ros_params.camera_topics[cam]["publish_input_image"]:
@@ -259,6 +262,7 @@ def image_callback(self, image_msg: Image, cam: str):  # info_msg: CameraInfo
 
         # Update model from file if possible
         self.load_model()
+
         # Convert image message to torch image
         torch_image = rc.ros_image_to_torch(image_msg, device=self.ros_params.device)
         torch_image = self.camera_handler[cam]["image_projector"].resize_image(torch_image)
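The re-enabled `rate.sleep()` in `status_thread_loop` is wrapped in a try/except because rospy raises when ROS time jumps backwards, for example when a rosbag loops. The diff catches a bare `Exception`; this sketch narrows it to the specific rospy exception, which is a deliberate deviation:

```python
import rospy

rate = rospy.Rate(0.5)  # status_thread_rate from default.yaml, in hertz
while not rospy.is_shutdown():
    # ... build and print the status table ...
    try:
        rate.sleep()
    except rospy.exceptions.ROSTimeMovedBackwardsException:
        # Rebuild the Rate so its internal last-sleep timestamp is reset.
        rate = rospy.Rate(0.5)
```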
self.confidence_generator_state["var"] self.confidence_generator.mean = self.confidence_generator_state["mean"] self.confidence_generator.std = self.confidence_generator_state["std"] @@ -397,7 +406,7 @@ def load_model(self): except Exception as e: if self.ros_params.verbose: - print(f"Model Loading Failed: {e}") + rospy.logerr(f"[{self._node_name}] Model Loading Failed: {e}") if __name__ == "__main__": @@ -409,10 +418,10 @@ def load_model(self): rospack = rospkg.RosPack() wvn_path = rospack.get_path("wild_visual_navigation_ros") - os.system(f"rosparam load {wvn_path}/config/wild_visual_navigation/default.yaml wvn_feature_extractor_node") + os.system(f"rosparam load {wvn_path}/config/wild_visual_navigation/default.yaml {node_name}") os.system( - f"rosparam load {wvn_path}/config/wild_visual_navigation/inputs/alphasense_compressed_front.yaml wvn_feature_extractor_node" + f"rosparam load {wvn_path}/config/wild_visual_navigation/inputs/alphasense_compressed_front.yaml {node_name}" ) - wvn = WvnFeatureExtractor() + wvn = WvnFeatureExtractor(node_name) rospy.spin() diff --git a/wild_visual_navigation_ros/scripts/wvn_learning_node.py b/wild_visual_navigation_ros/scripts/wvn_learning_node.py index dd73802e..6cbbee14 100644 --- a/wild_visual_navigation_ros/scripts/wvn_learning_node.py +++ b/wild_visual_navigation_ros/scripts/wvn_learning_node.py @@ -43,11 +43,14 @@ class WvnLearning: - def __init__(self): + def __init__(self, node_name): # Timers to control the rate of the publishers self.last_image_ts = rospy.get_time() self.last_supervision_ts = rospy.get_time() + # Prepare variables + self._node_name = node_name + # Read params self.read_params() @@ -124,7 +127,7 @@ def __init__(self): # Launch processes print("-" * 80) - print("Launching [learning] thread") + rospy.loginfo(f"[{self._node_name}] Launching [learning] thread") if self.ros_params.mode != WVNMode.EXTRACT_LABELS: self.learning_thread_stop_event = Event() self.learning_thread = Thread(target=self.learning_thread_loop, name="learning") @@ -133,34 +136,34 @@ def __init__(self): # self.logging_thread_stop_event = Event() # self.logging_thread = Thread(target=self.logging_thread_loop, name="logging") # self.logging_thread.start() - print("[WVN] System ready") + rospy.loginfo(f"[{self._node_name}] [WVN] System ready") def shutdown_callback(self, *args, **kwargs): # Write stuff to files rospy.logwarn("Shutdown callback called") if self.ros_params.mode != WVNMode.EXTRACT_LABELS: self.learning_thread_stop_event.set() - self.logging_thread_stop_event.set() + # self.logging_thread_stop_event.set() - print("Storing learned checkpoint...", end="") + print(f"[{self._node_name}] Storing learned checkpoint...", end="") self.traversability_estimator.save_checkpoint(self.params.general.model_path, "last_checkpoint.pt") print("done") if self.ros_params.log_time: - print("Storing timer data...", end="") + print(f"[{self._node_name}] Storing timer data...", end="") self.timer.store(folder=self.params.general.model_path) print("done") - print("Joining learning thread...", end="") + print(f"[{self._node_name}] Joining learning thread...", end="") if self.ros_params.mode != WVNMode.EXTRACT_LABELS: self.learning_thread_stop_event.set() self.learning_thread.join() - self.logging_thread_stop_event.set() - self.logging_thread.join() + # self.logging_thread_stop_event.set() + # self.logging_thread.join() print("done") - rospy.signal_shutdown(f"Wild Visual Navigation killed {args}") + rospy.signal_shutdown(f"[{self._node_name}] Wild Visual Navigation killed 
{args}") sys.exit(0) @accumulate_time @@ -225,8 +228,10 @@ def learning_thread_loop(self): cg = self.traversability_estimator._traversability_loss._confidence_generator res["confidence_generator"] = cg.get_dict() - os.system(f"rm {WVN_ROOT_DIR}/tmp_state_dict2.pt") - torch.save(res, f"{WVN_ROOT_DIR}/tmp_state_dict2.pt") + os.remove( + f"{WVN_ROOT_DIR}/.tmp_state_dict.pt", + ) + torch.save(res, f"{WVN_ROOT_DIR}/.tmp_state_dict.pt") i += 1 self.system_events["learning_thread_loop"] = { @@ -284,7 +289,9 @@ def read_params(self): # Parse operation modes if self.ros_params.mode == WVNMode.ONLINE: - print("\nWARNING: online_mode enabled. The graph will not store any debug/training data such as images\n") + rospy.logwarn( + f"[{self._node_name}] WARNING: online_mode enabled. The graph will not store any debug/training data such as images\n" + ) elif self.ros_params.mode == WVNMode.EXTRACT_LABELS: with read_write(self.ros_params): @@ -323,9 +330,13 @@ def setup_ros(self, setup_fully=True): [robot_state_sub, desired_twist_sub], queue_size=10, slop=0.5 ) - print(f"Start waiting for RobotState topic {self.ros_params.robot_state_topic} being published!") + rospy.loginfo( + f"[{self._node_name}] Start waiting for RobotState topic {self.ros_params.robot_state_topic} being published!" + ) rospy.wait_for_message(self.ros_params.robot_state_topic, RobotState) - print(f"Start waiting for TwistStamped topic {self.ros_params.desired_twist_topic} being published!") + rospy.loginfo( + f"[{self._node_name}] Start waiting for TwistStamped topic {self.ros_params.desired_twist_topic} being published!" + ) rospy.wait_for_message(self.ros_params.desired_twist_topic, TwistStamped) self.robot_state_sub.registerCallback(self.robot_state_callback) @@ -413,14 +424,14 @@ def pause_learning_callback(self, req): def reset_callback(self, req): """Resets the system""" - print("WARNING: System reset!") + rospy.logwarn(f"[{self._node_name}] System reset!") - print("Storing learned checkpoint...", end="") + print(f"[{self._node_name}] Storing learned checkpoint...", end="") self.traversability_estimator.save_checkpoint(self.params.general.model_path, "last_checkpoint.pt") print("done") if self.ros_params.log_time: - print("Storing timer data...", end="") + print(f"[{self._node_name}] Storing timer data...", end="") self.timer.store(folder=self.params.general.model_path) print("done") @@ -430,7 +441,7 @@ def reset_callback(self, req): # Reset traversability estimator self.traversability_estimator.reset() - print("Reset done") + print(f"[{self._node_name}] Reset done") return TriggerResponse(True, "Reset done!") @accumulate_time @@ -496,10 +507,10 @@ def query_tf(self, parent_frame: str, child_frame: str, stamp: Optional[rospy.Ti ) rot /= np.linalg.norm(rot) return (trans, tuple(rot)) - except Exception as e: + except Exception: if self.ros_params.verbose: - print("Error in query tf: ", e) - rospy.logwarn(f"Couldn't get between {parent_frame} and {child_frame}") + # print("Error in query tf: ", e) + rospy.logwarn(f"[{self._node_name}] Couldn't get between {parent_frame} and {child_frame}") return (None, None) @accumulate_time @@ -525,7 +536,7 @@ def robot_state_callback(self, state_msg, desired_twist_msg: TwistStamped): self.last_propio_ts = ts # Query transforms from TF - suc, pose_base_in_world = rc.ros_tf_to_torch( + success, pose_base_in_world = rc.ros_tf_to_torch( self.query_tf( self.ros_params.fixed_frame, self.ros_params.base_frame, @@ -533,14 +544,14 @@ def robot_state_callback(self, state_msg, desired_twist_msg: 
@@ -525,7 +536,7 @@ def robot_state_callback(self, state_msg, desired_twist_msg: TwistStamped):
             self.last_propio_ts = ts
 
             # Query transforms from TF
-            suc, pose_base_in_world = rc.ros_tf_to_torch(
+            success, pose_base_in_world = rc.ros_tf_to_torch(
                 self.query_tf(
                     self.ros_params.fixed_frame,
                     self.ros_params.base_frame,
@@ -533,14 +544,14 @@ def robot_state_callback(self, state_msg, desired_twist_msg: TwistStamped):
                 ),
                 device=self.ros_params.device,
             )
-            if not suc:
+            if not success:
                 self.system_events["robot_state_callback_cancled"] = {
                     "time": rospy.get_time(),
                     "value": "cancled due to pose_base_in_world",
                 }
                 return
 
-            suc, pose_footprint_in_base = rc.ros_tf_to_torch(
+            success, pose_footprint_in_base = rc.ros_tf_to_torch(
                 self.query_tf(
                     self.ros_params.base_frame,
                     self.ros_params.footprint_frame,
@@ -548,10 +559,10 @@ def robot_state_callback(self, state_msg, desired_twist_msg: TwistStamped):
                 ),
                 device=self.ros_params.device,
             )
-            if not suc:
-                self.system_events["robot_state_callback_cancled"] = {
+            if not success:
+                self.system_events["robot_state_callback_canceled"] = {
                     "time": rospy.get_time(),
-                    "value": "cancled due to pose_footprint_in_base",
+                    "value": "canceled due to pose_footprint_in_base",
                 }
                 return
@@ -597,7 +608,7 @@ def robot_state_callback(self, state_msg, desired_twist_msg: TwistStamped):
                 self.visualize_supervision()
 
             if self.ros_params.print_supervision_callback_time:
-                print(self.timer)
+                print(f"[{self._node_name}]\n{self.timer}")
 
             self.system_events["robot_state_callback_state"] = {
                 "time": rospy.get_time(),
@@ -606,7 +617,7 @@ def robot_state_callback(self, state_msg, desired_twist_msg: TwistStamped):
 
         except Exception as e:
             traceback.print_exc()
-            print("error state callback", e)
+            rospy.logerr(f"[{self._node_name}] error state callback: {e}")
             self.system_events["robot_state_callback_state"] = {
                 "time": rospy.get_time(),
                 "value": f"failed to execute {e}",
@@ -634,21 +645,21 @@ def imagefeat_callback(self, *args):
             "value": "message received",
         }
         if self.ros_params.verbose:
-            print(f"\nImage callback: {camera_options['name']}... ", end="")
+            print(f"[{self._node_name}] Image callback: {camera_options['name']}... ", end="")
         try:
             # Run the callback so as to match the desired rate
             ts = imagefeat_msg.header.stamp.to_sec()
             if abs(ts - self.last_image_ts) < 1.0 / self.ros_params.image_callback_rate:
                 if self.ros_params.verbose:
-                    print("skip")
+                    print(f"skip")
                 return
             else:
                 if self.ros_params.verbose:
-                    print("process")
+                    print(f"process")
                 self.last_image_ts = ts
 
             # Query transforms from TF
-            suc, pose_base_in_world = rc.ros_tf_to_torch(
+            success, pose_base_in_world = rc.ros_tf_to_torch(
                 self.query_tf(
                     self.ros_params.fixed_frame,
                     self.ros_params.base_frame,
@@ -656,13 +667,13 @@ def imagefeat_callback(self, *args):
                 ),
                 device=self.ros_params.device,
             )
-            if not suc:
+            if not success:
                 self.system_events["image_callback_cancled"] = {
                     "time": rospy.get_time(),
                     "value": "cancled due to pose_base_in_world",
                 }
                 return
-            suc, pose_cam_in_base = rc.ros_tf_to_torch(
+            success, pose_cam_in_base = rc.ros_tf_to_torch(
                 self.query_tf(
                     self.ros_params.base_frame,
                     imagefeat_msg.header.frame_id,
@@ -671,10 +682,10 @@ def imagefeat_callback(self, *args):
                 ),
                 device=self.ros_params.device,
             )
-            if not suc:
+            if not success:
                 self.system_events["image_callback_cancled"] = {
                     "time": rospy.get_time(),
-                    "value": "cancled due to pose_cam_in_base",
+                    "value": "canceled due to pose_cam_in_base",
                 }
                 return
             # Prepare image projector
@@ -740,7 +751,7 @@ def imagefeat_callback(self, *args):
 
             # Print callback time if required
             if self.ros_params.print_image_callback_time:
-                print(self.timer)
+                rospy.loginfo(f"[{self._node_name}]\n{self.timer}")
 
             self.system_events["image_callback_state"] = {
                 "time": rospy.get_time(),
@@ -749,7 +760,7 @@ def imagefeat_callback(self, *args):
 
         except Exception as e:
             traceback.print_exc()
-            print("error image callback", e)
+            rospy.logerr(f"[{self._node_name}] error image callback: {e}")
             self.system_events["image_callback_state"] = {
                 "time": rospy.get_time(),
                 "value": f"failed to execute {e}",
             }
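Both callbacks gate their work on the message header stamp so that processing matches the configured rate, printing "skip" or "process" in verbose mode. The guard, isolated as a self-contained sketch with hypothetical names:

```python
class RateGate:
    """Distills the stamp-based rate limiting used by the callbacks."""

    def __init__(self, rate_hz: float):
        self._min_dt = 1.0 / rate_hz
        self._last_ts = 0.0

    def should_process(self, stamp_sec: float) -> bool:
        if abs(stamp_sec - self._last_ts) < self._min_dt:
            return False  # "skip": arrived faster than the configured rate
        self._last_ts = stamp_sec
        return True       # "process": record the stamp and handle the message
```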
self.system_events["image_callback_state"] = { "time": rospy.get_time(), "value": f"failed to execute {e}", @@ -848,14 +859,8 @@ def visualize_supervision(self): if node.is_untraversable: untraversable_plane = node.get_untraversable_plane(grid_size=2) N, D = untraversable_plane.shape - for n in [ - 0, - 1, - 3, - 2, - 0, - 3, - ]: # this is a hack to show the triangles correctly + # the following is a 'hack' to show the triangles correctly + for n in [0, 1, 3, 2, 0, 3]: p = Point() p.x = untraversable_plane[n, 0] p.y = untraversable_plane[n, 1] @@ -866,7 +871,7 @@ def visualize_supervision(self): # Publish if len(footprints_marker.points) % 3 != 0: if self.ros_params.verbose: - print(f"number of points for footprint is {len(footprints_marker.points)}") + rospy.loginfo(f"[{self._node_name}] number of points for footprint is {len(footprints_marker.points)}") return self.pub_graph_footprints.publish(footprints_marker) self.pub_debug_supervision_graph.publish(supervision_graph_msg) @@ -925,10 +930,10 @@ def visualize_image_overlay(self): rospack = rospkg.RosPack() wvn_path = rospack.get_path("wild_visual_navigation_ros") - os.system(f"rosparam load {wvn_path}/config/wild_visual_navigation/default.yaml wvn_learning_node") + os.system(f"rosparam load {wvn_path}/config/wild_visual_navigation/default.yaml {node_name}") os.system( - f"rosparam load {wvn_path}/config/wild_visual_navigation/inputs/alphasense_compressed_front.yaml wvn_learning_node" + f"rosparam load {wvn_path}/config/wild_visual_navigation/inputs/alphasense_compressed_front.yaml {node_name}" ) - wvn = WvnLearning() + wvn = WvnLearning(node_name) rospy.spin()