Skip to content

Commit

Permalink
Merge pull request #261 from synsense/develop
Browse files Browse the repository at this point in the history
v2.0.1
  • Loading branch information
ssinhaleite authored Jan 23, 2025
2 parents dbc3421 + ace99fd commit c0fce4e
Show file tree
Hide file tree
Showing 7 changed files with 170 additions and 30 deletions.
3 changes: 3 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
CHANGES
=======

* Spike count plot in 'DynapcnnVisualizer' is optional
* `DynapcnnVisualizer` allows custom JIT filters to make readout predictions

v2.0.0
------

Expand Down
2 changes: 1 addition & 1 deletion docs/speck/notebooks/nmnist_quick_start.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Covert CNN To SNN"
"### Convert CNN To SNN"
]
},
{
Expand Down
3 changes: 2 additions & 1 deletion docs/speck/visualizer.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ hardware_compatible_model.to(
In order to visualize the class outputs as images, we need to get the images. The images should be passed in the same order as the output layer of the network. Important! <br>
- If you want to visualize power measurements during streaming inference, set `add_power_monitor_plot`=`True`.
- If you want to visualize readout images as class predictions during streaming you need to pass `add_readout_plot`=`True`.
- If you don't want to visualize spike counts of output classes as line graphs over time during streaming you need to pass `add_spike_count_plot`=`False`.
- In order to show a prediction for each `N` milliseconds, set the parameter `spike_collection_interval`=`N`.
- In order to show the images, the paths of these images should be passed to `readout_images` parameter.
- In order to show a prediction only if there are more than a `threshold` number of events from that output, set the `readout_prediction_threshold`=`threshold`.
Expand Down Expand Up @@ -172,4 +173,4 @@ The example script that runs the visualizer can be found under `/examples/visual


#### MacOS users
Due to the difference in the behaviour of python's multiprocessing library on MacOS, you should run the `examples/visualizer/gesture_viz.py` script with `-i` flag. `python -i /examples/visualizer/gesture_viz.py` .
Due to the difference in the behaviour of python's multiprocessing library on MacOS, you should run the `examples/visualizer/gesture_viz.py` script with `-i` flag. `python -i /examples/visualizer/gesture_viz.py` .
7 changes: 4 additions & 3 deletions sinabs/backend/dynapcnn/dynapcnn_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,9 +436,10 @@ def forward(self, x):
# Send input
self.samna_input_buffer.write(x)
received_evts = []
# Record at least until the last event has been replayed
min_duration = max(event.timestamp for event in x) * 1e-6
time.sleep(min_duration)

# Wait a minimum time to guarantee the events were played
time.sleep(1)

# Keep recording if more events are being registered
while True:
prev_length = len(received_evts)
Expand Down
74 changes: 51 additions & 23 deletions sinabs/backend/dynapcnn/dynapcnn_visualizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import socket
import warnings
from typing import Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Optional, Tuple, Union

import samna

Expand All @@ -15,6 +15,7 @@ class DynapcnnVisualizer:
# (tlx, tly, brx, bry)
DEFAULT_LAYOUT_DS = [(0, 0, 0.5, 1), (0.5, 0, 1, 1), None, None]
DEFAULT_LAYOUT_DSP = [(0, 0, 0.5, 0.66), (0.5, 0, 1, 0.66), None, (0, 0.66, 1, 1)]
DEFAULT_LAYOUT_DRP = [(0, 0, 0.5, 0.66), None, (0.5, 0, 1, 0.66), (0, 0.66, 1, 1)]
DEFAULT_LAYOUT_DSR = [(0, 0, 0.33, 1), (0.33, 0, 0.66, 1), (0.66, 0, 1, 1), None]

DEFAULT_LAYOUT_DSRP = [
Expand All @@ -27,6 +28,7 @@ class DynapcnnVisualizer:
LAYOUTS_DICT = {
"ds": DEFAULT_LAYOUT_DS,
"dsp": DEFAULT_LAYOUT_DSP,
"drp": DEFAULT_LAYOUT_DRP,
"dsr": DEFAULT_LAYOUT_DSR,
"dsrp": DEFAULT_LAYOUT_DSRP,
}
Expand All @@ -36,6 +38,7 @@ def __init__(
window_scale: Tuple[int, int] = (4, 8),
dvs_shape: Tuple[int, int] = (128, 128), # height, width
add_readout_plot: bool = False,
add_spike_count_plot: bool = True,
add_power_monitor_plot: bool = False,
spike_collection_interval: int = 500,
readout_prediction_threshold: int = 10,
Expand All @@ -46,6 +49,7 @@ def __init__(
feature_names: Optional[List[str]] = None,
readout_images: Optional[List[str]] = None,
feature_count: Optional[int] = None,
readout_node: Union[str, Callable] = "JitMajorityReadout",
extra_arguments: Optional[Dict[str, Dict[str, any]]] = None,
):
"""Quick wrapper around Samna objects to get a basic dynapcnn visualizer.
Expand All @@ -58,6 +62,10 @@ def __init__(
Defaults to (128, 128) -- Speck sensor resolution.
add_readout_plot: bool (defaults to False)
If set true adds a readout plot to the GUI
It displays an icon for the currently predicted class.
add_spike_count_plot: bool (defaults to True)
If set true adds a spike count plot to the GUI.
A line chart indicating the number of spikes over time.
add_power_monitor_plot: bool (defaults to False)
If set true adds a power monitor plot to the GUI.
spike_collection_interval: int (defaults to 500) (in milliseconds)
Expand Down Expand Up @@ -91,14 +99,17 @@ def __init__(
If the `feature_names` and `readout_images` was passed, this is not needed. Otherwise this parameter
should be passed, so that the GUI knows how many lines should be drawn on the `Spike Count Plot` and
`Readout Layer Plot`.
readout_node: str or Callable
Can either be a string "JitMajorityReadout" or a callable that returns a samna JIT filter
to decide on the readout prediction. Function parameters can be defined freely.
extra_arguments: Optional[Dict[str, Dict[str, any]]] (defaults to None)
Extra arguments that can be passed to individual plots. Available keys are:
- `spike_count`: Arguments that can be passed to `spike_count` plot.
- `readout`: Arguments that can be passed to `readout` plot.
- `power_measurement`: Arguments that can be passed `power_measurement` plot.
"""
# Checks if the configuration passed is valid
if add_readout_plot and not readout_images:
if add_readout_plot and readout_images is None:
raise ValueError(
"If a readout plot is to be displayed image paths should be passed as a list."
+ "The order of the images, should match the model output."
Expand All @@ -112,7 +123,9 @@ def __init__(
self.dvs_shape = dvs_shape

# Modify the GUI type based on the parameters
self.gui_type = "ds"
self.gui_type = "d"
if add_spike_count_plot:
self.gui_type += "s"
if add_readout_plot:
self.gui_type += "r"
if add_power_monitor_plot:
Expand All @@ -126,6 +139,7 @@ def __init__(
self.readout_default_return_value = readout_default_return_value
self.readout_default_threshold_low = readout_default_threshold_low
self.readout_default_threshold_high = readout_default_threshold_high
self.readout_node = readout_node

# Power monitor components
if power_monitor_number_of_items != 3 and power_monitor_number_of_items != 5:
Expand Down Expand Up @@ -338,13 +352,14 @@ def create_plots(self):
plots = []

plots.append(self.add_dvs_plot(shape=self.dvs_shape, layout=layout[0]))
if self.extra_arguments and "spike_count" in self.extra_argument.keys():
spike_count_plot_args = self.extra_arguments["spike_count"]
else:
spike_count_plot_args = {}
plots.append(
self.add_spike_count_plot(layout=layout[1], **spike_count_plot_args)
)
if "s" in self.gui_type:
if self.extra_arguments and "spike_count" in self.extra_argument.keys():
spike_count_plot_args = self.extra_arguments["spike_count"]
else:
spike_count_plot_args = {}
plots.append(
self.add_spike_count_plot(layout=layout[1], **spike_count_plot_args)
)
if "r" in self.gui_type:
try:
if self.extra_arguments and "readout" in self.extra_arguments.keys():
Expand Down Expand Up @@ -508,19 +523,32 @@ def connect(

## Readout node
if "r" in self.gui_type:
(_, majority_readout_node, _) = self.streamer_graph.sequential(
[
spike_collection_node,
samna.graph.JitMajorityReadout(samna.ui.Event),
streamer_node,
]
)
majority_readout_node.set_feature_count(self.feature_count)
majority_readout_node.set_default_feature(self.readout_default_return_value)
majority_readout_node.set_threshold_low(self.readout_default_threshold_low)
majority_readout_node.set_threshold_high(
self.readout_default_threshold_high
)
if self.readout_node == "JitMajorityReadout":
(_, majority_readout_node, _) = self.streamer_graph.sequential(
[
spike_collection_node,
samna.graph.JitMajorityReadout(samna.ui.Event),
streamer_node,
]
)
majority_readout_node.set_feature_count(self.feature_count)
majority_readout_node.set_default_feature(
self.readout_default_return_value
)
majority_readout_node.set_threshold_low(
self.readout_default_threshold_low
)
majority_readout_node.set_threshold_high(
self.readout_default_threshold_high
)
else:
(_, majority_readout_node, _) = self.streamer_graph.sequential(
[
spike_collection_node,
self.readout_node,
streamer_node,
]
)

## Readout layer visualization
if "o" in self.gui_type:
Expand Down
88 changes: 88 additions & 0 deletions tests/test_dynapcnn/custom_jit_filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from typing import Optional

import samna


def majority_readout_filter(
feature_count: int,
default_feature: Optional[int] = None,
detection_threshold: int = 0,
threshold_low: int = 0,
threshold_high: Optional[int] = None,
):
"""
The default reaodut filter of samna's visualizer counts the total
number of events received per timestep to decide whether a detection
should be made or not.
The filter defined here allows for an additional `detection_threshold`
parameter which is compared to the number of spikes of the most
active class.
In other words, for a class to be detected, there needs to be
a minimum number of spikes for this class.
"""

jit_src = f"""
using InputT = speck2f::event::Spike;
using OutputT = ui::Event;
using ReadoutT = ui::Readout;
template<typename Spike>
class CustomMajorityReadout : public iris::FilterInterface<std::shared_ptr<const std::vector<Spike>>, std::shared_ptr<const std::vector<OutputT>>> {{
private:
int featureCount = {feature_count};
uint32_t defaultFeature = {default_feature if default_feature is not None else feature_count};
int detectionThreshold = {detection_threshold};
int thresholdLow = {threshold_low};
int thresholdHigh = {threshold_high if threshold_high is not None else "std::numeric_limits<int>::max()"};
public:
void apply() override
{{
while (const auto maybeSpikesPtr = this->receiveInput()) {{
if (0 == featureCount) {{
return;
}}
auto outputCollection = std::make_shared<std::vector<OutputT>>();
if ((*maybeSpikesPtr)->size() >= thresholdLow && (*maybeSpikesPtr)->size() <= thresholdHigh) {{
std::unordered_map<uint32_t, int> sum; // feature -> count
int maxCount = 0;
uint32_t maxCountFeature = 0;
int maxCountNum = 0;
for (const auto& spike : (**maybeSpikesPtr)) {{
sum[spike.feature]++;
}}
for (const auto& [feature, count] : sum) {{
if (feature >= featureCount) {{
continue;
}}
if (count > maxCount) {{
maxCount = count;
maxCountFeature = feature;
maxCountNum = 1;
}}
else if (count == maxCount) {{
maxCountNum++;
}}
}}
if (maxCount > detectionThreshold && 1 == maxCountNum) {{
outputCollection->emplace_back(ReadoutT{{maxCountFeature}});
}}
else {{
outputCollection->emplace_back(ReadoutT{{defaultFeature}});
}}
}}
else {{
outputCollection->emplace_back(ReadoutT{{defaultFeature}});
}}
this->forwardResult(std::move(outputCollection));
}}
}}
}};
"""
return samna.graph.JitFilter("CustomMajorityReadout", jit_src)
23 changes: 21 additions & 2 deletions tests/test_dynapcnn/test_visualizer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from itertools import product
from typing import Callable, Union

import pytest
import samna
from custom_jit_filters import majority_readout_filter as custom_filter
from hw_utils import find_open_devices, is_any_samna_device_connected

from sinabs.backend.dynapcnn.dynapcnn_visualizer import DynapcnnVisualizer
Expand All @@ -13,17 +17,32 @@ def X_available() -> bool:
return p.returncode == 0


vis_init_args = product(
(True, False),
(True, False),
("JitMajorityReadout", custom_filter),
)


@pytest.mark.skipif(
True,
reason="A window needs to pop. Needs UI. Makes sense to check this test manually",
)
def test_visualizer_initialization():
@pytest.mark.parametrize("spike_count_plot,readout_plot,readout_node", vis_init_args)
def test_visualizer_initialization(
spike_count_plot: bool, readout_plot: bool, readout_node: Union[str, Callable]
):
dvs_shape = (128, 128)
spike_collection_interval = 500
visualizer_id = 3

visualizer = DynapcnnVisualizer(
dvs_shape=dvs_shape, spike_collection_interval=spike_collection_interval
dvs_shape=dvs_shape,
spike_collection_interval=spike_collection_interval,
add_spike_count_plot=spike_count_plot,
add_readout_plot=readout_plot,
readout_node=readout_node,
readout_images=[],
)
visualizer.create_visualizer_process(
f"tcp://0.0.0.0:{visualizer.samna_visualizer_port}"
Expand Down

0 comments on commit c0fce4e

Please sign in to comment.