Skip to content

Commit

Permalink
Add option to StegoInterface to use feature clustering
Browse files Browse the repository at this point in the history
  • Loading branch information
mmattamala committed Jan 26, 2024
1 parent 9d6809f commit 0b67ed7
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def segment_random(self, img, **kwargs):
def segment_stego(self, img, **kwargs):
# Prepare input image
img_internal = img.clone()
self.extractor.inference_crf(img_internal)
self.extractor.inference(img_internal)
seg = torch.from_numpy(self.extractor.cluster_segments).to(self._device)

# Change the segment indices by numbers from 0 to N
Expand Down
13 changes: 8 additions & 5 deletions wild_visual_navigation/feature_extractor/stego_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(
device: str,
input_size: int = 448,
model_path: str = f"{STEGO_ROOT_DIR}/models/stego_cocostuff27_vit_base_5_cluster_linear_fine_tuning.ckpt",
n_image_clusters: int = 20,
run_crf: bool = True,
run_clustering: bool = False,
cfg: OmegaConf = OmegaConf.create({}),
Expand All @@ -29,12 +30,13 @@ def __init__(
"input_size": input_size,
"run_crf": run_crf,
"run_clustering": run_clustering,
"n_image_clusters": n_image_clusters,
}
)
else:
self._cfg = cfg

self._model = Stego.load_from_checkpoint(self._cfg.model_path)
self._model = Stego.load_from_checkpoint(self._cfg.model_path, n_image_clusters=self._cfg.n_image_clusters)
self._model.eval().to(device)
self._device = device

Expand Down Expand Up @@ -142,8 +144,9 @@ def run_stego_interfacer():
si = StegoInterface(
device=device,
input_size=448,
run_crf=True,
run_clustering=False,
run_crf=False,
run_clustering=True,
n_image_clusters=20,
)

p = join(WVN_ROOT_DIR, "assets/images/forest_clean.png")
Expand All @@ -152,9 +155,9 @@ def run_stego_interfacer():
img = torch.from_numpy(np_img).to(device)
img = img.permute(2, 0, 1)
img = (img.type(torch.float32) / 255)[None]
img = F.interpolate(img, scale_factor=0.25)
img = F.interpolate(img, scale_factor=0.5)

with Timer(f"Stego (input {si.input_size}"):
with Timer(f"Stego input {si.input_size}"):
linear_pred, cluster_pred = si.inference(img)

# Plot result as in colab
Expand Down

0 comments on commit 0b67ed7

Please sign in to comment.