From 0b67ed7c58ad450eeee0f2c1de539d336e683f01 Mon Sep 17 00:00:00 2001 From: Matias Mattamala Date: Fri, 26 Jan 2024 15:52:43 +0000 Subject: [PATCH] Add option to StegoInterface to use feature clustering --- .../feature_extractor/feature_extractor.py | 2 +- .../feature_extractor/stego_interface.py | 13 ++++++++----- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/wild_visual_navigation/feature_extractor/feature_extractor.py b/wild_visual_navigation/feature_extractor/feature_extractor.py index a5aa5fc3..6bf1be9b 100644 --- a/wild_visual_navigation/feature_extractor/feature_extractor.py +++ b/wild_visual_navigation/feature_extractor/feature_extractor.py @@ -216,7 +216,7 @@ def segment_random(self, img, **kwargs): def segment_stego(self, img, **kwargs): # Prepare input image img_internal = img.clone() - self.extractor.inference_crf(img_internal) + self.extractor.inference(img_internal) seg = torch.from_numpy(self.extractor.cluster_segments).to(self._device) # Change the segment indices by numbers from 0 to N diff --git a/wild_visual_navigation/feature_extractor/stego_interface.py b/wild_visual_navigation/feature_extractor/stego_interface.py index 2fa63082..47802f98 100644 --- a/wild_visual_navigation/feature_extractor/stego_interface.py +++ b/wild_visual_navigation/feature_extractor/stego_interface.py @@ -17,6 +17,7 @@ def __init__( device: str, input_size: int = 448, model_path: str = f"{STEGO_ROOT_DIR}/models/stego_cocostuff27_vit_base_5_cluster_linear_fine_tuning.ckpt", + n_image_clusters: int = 20, run_crf: bool = True, run_clustering: bool = False, cfg: OmegaConf = OmegaConf.create({}), @@ -29,12 +30,13 @@ def __init__( "input_size": input_size, "run_crf": run_crf, "run_clustering": run_clustering, + "n_image_clusters": n_image_clusters, } ) else: self._cfg = cfg - self._model = Stego.load_from_checkpoint(self._cfg.model_path) + self._model = Stego.load_from_checkpoint(self._cfg.model_path, n_image_clusters=self._cfg.n_image_clusters) self._model.eval().to(device) self._device = device @@ -142,8 +144,9 @@ def run_stego_interfacer(): si = StegoInterface( device=device, input_size=448, - run_crf=True, - run_clustering=False, + run_crf=False, + run_clustering=True, + n_image_clusters=20, ) p = join(WVN_ROOT_DIR, "assets/images/forest_clean.png") @@ -152,9 +155,9 @@ def run_stego_interfacer(): img = torch.from_numpy(np_img).to(device) img = img.permute(2, 0, 1) img = (img.type(torch.float32) / 255)[None] - img = F.interpolate(img, scale_factor=0.25) + img = F.interpolate(img, scale_factor=0.5) - with Timer(f"Stego (input {si.input_size}"): + with Timer(f"Stego input {si.input_size}"): linear_pred, cluster_pred = si.inference(img) # Plot result as in colab