Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Additional Pre and Post Processors #727

Merged
merged 2 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions frontends/concrete-python/concrete/fhe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
Value,
)
from .compilation.decorators import circuit, compiler
from .dtypes import Integer
from .extensions import (
AutoRounder,
AutoTruncator,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -981,7 +981,8 @@ class Configuration:
relu_on_bits_threshold: int
relu_on_bits_chunk_size: int
if_then_else_chunk_size: int
additional_processors: List[GraphProcessor]
additional_pre_processors: List[GraphProcessor]
additional_post_processors: List[GraphProcessor]
rounding_exactness: Exactness
approximate_rounding_config: ApproximateRoundingConfig

Expand Down Expand Up @@ -1040,7 +1041,8 @@ def __init__(
relu_on_bits_threshold: int = 7,
relu_on_bits_chunk_size: int = 3,
if_then_else_chunk_size: int = 3,
additional_processors: Optional[List[GraphProcessor]] = None,
additional_pre_processors: Optional[List[GraphProcessor]] = None,
additional_post_processors: Optional[List[GraphProcessor]] = None,
rounding_exactness: Exactness = Exactness.EXACT,
approximate_rounding_config: Optional[ApproximateRoundingConfig] = None,
):
Expand Down Expand Up @@ -1125,7 +1127,12 @@ def __init__(
self.relu_on_bits_threshold = relu_on_bits_threshold
self.relu_on_bits_chunk_size = relu_on_bits_chunk_size
self.if_then_else_chunk_size = if_then_else_chunk_size
self.additional_processors = [] if additional_processors is None else additional_processors
self.additional_pre_processors = (
[] if additional_pre_processors is None else additional_pre_processors
)
self.additional_post_processors = (
[] if additional_post_processors is None else additional_post_processors
)
self.rounding_exactness = rounding_exactness
self.approximate_rounding_config = (
approximate_rounding_config or ApproximateRoundingConfig()
Expand Down Expand Up @@ -1190,7 +1197,8 @@ def fork(
relu_on_bits_threshold: Union[Keep, int] = KEEP,
relu_on_bits_chunk_size: Union[Keep, int] = KEEP,
if_then_else_chunk_size: Union[Keep, int] = KEEP,
additional_processors: Union[Keep, Optional[List[GraphProcessor]]] = KEEP,
additional_pre_processors: Union[Keep, Optional[List[GraphProcessor]]] = KEEP,
additional_post_processors: Union[Keep, Optional[List[GraphProcessor]]] = KEEP,
rounding_exactness: Union[Keep, Exactness] = KEEP,
approximate_rounding_config: Union[Keep, Optional[ApproximateRoundingConfig]] = KEEP,
) -> "Configuration":
Expand Down Expand Up @@ -1224,16 +1232,17 @@ def _validate(self):
if name in already_checked_by_parse_methods:
continue

if name == "additional_processors":
valid = isinstance(self.additional_processors, list)
if name in ["additional_pre_processors", "additional_post_processors"]:
attr = getattr(self, name)
valid = isinstance(attr, list)
if valid:
for processor in self.additional_processors:
for processor in attr:
valid = valid and isinstance(processor, GraphProcessor)
if not valid:
hint_type = friendly_type_format(hint)
value_type = friendly_type_format(type(self.additional_processors))
value_type = friendly_type_format(type(attr))
message = (
f"Unexpected type for keyword argument 'additional_processors' "
f"Unexpected type for keyword argument '{name}' "
f"(expected '{hint_type}', got '{value_type}')"
)
raise TypeError(message)
Expand Down
34 changes: 19 additions & 15 deletions frontends/concrete-python/concrete/fhe/mlir/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,21 +184,25 @@ def process(self, graph: Graph, configuration: Configuration):
configuration to use
"""

pipeline = [
CheckIntegerOnly(),
AssignBitWidths(
single_precision=configuration.single_precision,
composable=configuration.composable,
comparison_strategy_preference=configuration.comparison_strategy_preference,
bitwise_strategy_preference=configuration.bitwise_strategy_preference,
shifts_with_promotion=configuration.shifts_with_promotion,
multivariate_strategy_preference=configuration.multivariate_strategy_preference,
min_max_strategy_preference=configuration.min_max_strategy_preference,
),
ProcessRounding(
rounding_exactness=configuration.rounding_exactness,
),
] + configuration.additional_processors
pipeline = (
configuration.additional_pre_processors
+ [
CheckIntegerOnly(),
AssignBitWidths(
single_precision=configuration.single_precision,
composable=configuration.composable,
comparison_strategy_preference=configuration.comparison_strategy_preference,
bitwise_strategy_preference=configuration.bitwise_strategy_preference,
shifts_with_promotion=configuration.shifts_with_promotion,
multivariate_strategy_preference=configuration.multivariate_strategy_preference,
min_max_strategy_preference=configuration.min_max_strategy_preference,
),
ProcessRounding(
rounding_exactness=configuration.rounding_exactness,
),
]
+ configuration.additional_post_processors
)

for processor in pipeline:
processor.apply(graph)
Expand Down
12 changes: 10 additions & 2 deletions frontends/concrete-python/tests/compilation/test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,18 @@ def test_configuration_fork():
"'bad' is not a valid 'MinMaxStrategy' (one-tlu-promoted, three-tlu-casted, chunked)",
),
pytest.param(
{"additional_processors": "bad"},
{"additional_pre_processors": "bad"},
TypeError,
(
"Unexpected type for keyword argument 'additional_processors' "
"Unexpected type for keyword argument 'additional_pre_processors' "
"(expected 'Optional[List[GraphProcessor]]', got 'str')"
),
),
pytest.param(
{"additional_post_processors": "bad"},
TypeError,
(
"Unexpected type for keyword argument 'additional_post_processors' "
"(expected 'Optional[List[GraphProcessor]]', got 'str')"
),
),
Expand Down
27 changes: 22 additions & 5 deletions frontends/concrete-python/tests/representation/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,11 +414,21 @@ def test_direct_graph_integer_range(helpers):
assert circuit.graph.integer_range() is None


def test_graph_processor(helpers):
def test_graph_processors(helpers):
"""
Test providing additional graph processors.
"""

class RecordInputBitWidth(fhe.GraphProcessor):
"""Sample graph processor to record the input bit width."""

input_bit_width: int = 0

def apply(self, graph: fhe.Graph):
assert len(graph.input_nodes) == 1
assert isinstance(graph.input_nodes[0].output.dtype, fhe.Integer)
self.input_bit_width = graph.input_nodes[0].output.dtype.bit_width

class CountNodes(fhe.GraphProcessor):
"""Sample graph processor to count nodes."""

Expand All @@ -427,10 +437,17 @@ class CountNodes(fhe.GraphProcessor):
def apply(self, graph: fhe.Graph):
self.node_count += len(graph.query_nodes())

processor = CountNodes()
configuration = helpers.configuration().fork(additional_processors=[processor])
pre_processor1 = RecordInputBitWidth()
post_processor1 = RecordInputBitWidth()
post_processor2 = CountNodes()
configuration = helpers.configuration().fork(
additional_pre_processors=[pre_processor1],
additional_post_processors=[post_processor1, post_processor2],
)

compiler = fhe.Compiler(lambda x: x**2, {"x": "encrypted"})
compiler = fhe.Compiler(lambda x: (x + 5) ** 2, {"x": "encrypted"})
compiler.compile(range(8), configuration)

assert processor.node_count == 3
assert pre_processor1.input_bit_width == 3
assert post_processor1.input_bit_width == 8 if configuration.single_precision else 4
assert post_processor2.node_count == 5
Loading