Skip to content

Commit

Permalink
feat(compiler/frontend): change partition operation
Browse files Browse the repository at this point in the history
  • Loading branch information
youben11 committed May 30, 2024
1 parent 61dd2f1 commit 0c87d72
Show file tree
Hide file tree
Showing 9 changed files with 144 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,25 @@ def FHE_ReinterpretPrecisionEintOp: FHE_Op<"reinterpret_precision", [Pure, Unary
let results = (outs FHE_AnyEncryptedInteger);
}

def FHE_ChangePartitionEintOp: FHE_Op<"change_partition", [Pure, UnaryEint]> {

let summary = "Change partition if necessary.";

let description = [{
Changing the partition of a ciphertext.
If necessary, it keyswitch the ciphertext to a different key having a different set of parameters than the original one.

Example:
```mlir
%new_eint = "FHE.change_partition"(%eint): (!FHE.eint<16>) -> (!FHE.eint<16>)
```
}];

let arguments = (ins FHE_AnyEncryptedInteger:$input);
let results = (outs FHE_AnyEncryptedInteger);
let hasVerifier = 1;
}

// FHE Boolean Operations

def FHE_GenGateOp : FHE_Op<"gen_gate", [Pure]> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1415,6 +1415,24 @@ def FHELinalg_ReinterpretPrecisionEintOp: FHELinalg_Op<"reinterpret_precision",
let hasVerifier = 1;
}

def FHELinalg_ChangePartitionEintOp: FHELinalg_Op<"change_partition", [Pure, TensorUnaryEint, UnaryEint]> {

let summary = "Change partition if necessary.";

let description = [{
Changing the partition of a ciphertext.
If necessary, it keyswitch the ciphertext to a different key having a different set of parameters than the original one.

Example:
```mlir
%new_eint = "FHE.change_partition"(%eint): (tensor<2x3x!FHE.eint<16>>) -> (tensor<2x3x!FHE.eint<16>>)
}];

let arguments = (ins Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>:$input);
let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>:$output);
let hasVerifier = 1;
}

def FHELinalg_FancyIndexOp : FHELinalg_Op<"fancy_index", [Pure]> {
let summary = "Index into a tensor using a tensor of indices.";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2228,6 +2228,24 @@ struct FancyAssignToSfcForall
};
};

// This operation should be used by the optimizer in multi-parameters, then
// removed. Its presence may indicate that mono-parameters might have been used.
// This patterns just hint for a potential fix.
struct ChangePartitionOpPattern
: public mlir::OpRewritePattern<FHELinalg::ChangePartitionEintOp> {
ChangePartitionOpPattern(mlir::MLIRContext *context)
: mlir::OpRewritePattern<FHELinalg::ChangePartitionEintOp>(
context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

mlir::LogicalResult
matchAndRewrite(FHELinalg::ChangePartitionEintOp op,
mlir::PatternRewriter &rewriter) const override {
op.emitError(llvm::Twine("change_partition shouldn't be present at this "
"level. Maybe you didn't use multi-parameters?"));
return mlir::failure();
};
};

namespace {
struct FHETensorOpsToLinalg
: public FHETensorOpsToLinalgBase<FHETensorOpsToLinalg> {
Expand Down Expand Up @@ -2414,6 +2432,7 @@ void FHETensorOpsToLinalg::runOnOperation() {
patterns.insert<TensorPartitionFrontierOpToLinalgGeneric>(&getContext());
patterns.insert<FancyIndexToTensorGenerate>(&getContext());
patterns.insert<FancyAssignToSfcForall>(&getContext());
patterns.insert<ChangePartitionOpPattern>(&getContext());

if (mlir::applyPartialConversion(function, target, std::move(patterns))
.failed())
Expand Down
11 changes: 11 additions & 0 deletions compilers/concrete-compiler/compiler/lib/Dialect/FHE/IR/FHEOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,17 @@ mlir::LogicalResult RoundEintOp::verify() {
return mlir::success();
}

mlir::LogicalResult ChangePartitionEintOp::verify() {
auto input = this->getInput().getType().cast<FheIntegerInterface>();
auto output = this->getResult().getType().cast<FheIntegerInterface>();

if (verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(),
input, output)) {
return mlir::success();
}
return mlir::failure();
}

OpFoldResult RoundEintOp::fold(FoldAdaptor operands) {
auto input = this->getInput();
auto inputTy = input.getType().dyn_cast_or_null<FheIntegerInterface>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,19 @@ struct OptimizerPartitionFrontierMaterializationPass
mlir::func::FuncOp func = this->getOperation();

func.walk([&](mlir::Operation *producer) {
mlir::IRRewriter rewriter(producer->getContext());

// Remove the change_partition op.
// TODO: The crypto parameters used in the op should be considered before
// removal
if (mlir::dyn_cast_or_null<FHELinalg::ChangePartitionEintOp>(producer) ||
mlir::dyn_cast_or_null<FHE::ChangePartitionEintOp>(producer)) {
rewriter.startRootUpdate(func);
rewriter.replaceOp(producer, producer->getOperand(0));
rewriter.finalizeRootUpdate(func);
return;
}

std::optional<uint64_t> producerOid =
getOid(producer, OperationKind::PRODUCER);

Expand All @@ -81,7 +94,6 @@ struct OptimizerPartitionFrontierMaterializationPass
solverSolution.circuit_keys.conversion_keyswitch_keys[eck[0]]
.output_key.identifier;

mlir::IRRewriter rewriter(producer->getContext());
rewriter.setInsertionPointAfter(producer);

for (mlir::Value res : producer->getResults()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1454,6 +1454,33 @@ mlir::LogicalResult ReinterpretPrecisionEintOp::verify() {
return mlir::success();
}

mlir::LogicalResult ChangePartitionEintOp::verify() {
auto inputType =
this->getInput().getType().dyn_cast_or_null<mlir::RankedTensorType>();
auto outputType =
this->getOutput().getType().dyn_cast_or_null<mlir::RankedTensorType>();

auto inputShape = inputType.getShape();
auto outputShape = outputType.getShape();

if (inputShape != outputShape) {
this->emitOpError()
<< "input and output tensors should have the same shape";
return mlir::failure();
}

auto inputElementType =
inputType.getElementType().cast<FHE::FheIntegerInterface>();
auto outputElementType =
outputType.getElementType().cast<FHE::FheIntegerInterface>();
if (!FHE::verifyEncryptedIntegerInputAndResultConsistency(
*this->getOperation(), inputElementType, outputElementType)) {
return mlir::failure();
}

return mlir::success();
}

mlir::LogicalResult FancyIndexOp::verify() {
auto inputType =
this->getInput().getType().dyn_cast_or_null<mlir::RankedTensorType>();
Expand Down
7 changes: 7 additions & 0 deletions frontends/concrete-python/concrete/fhe/mlir/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4048,4 +4048,11 @@ def zeros(self, resulting_type: ConversionType) -> Conversion:
original_bit_width=1,
)

def change_partition(self, x: Conversion) -> Conversion:
# TODO: get parameters and set them as attr
assert x.is_encrypted
dialect = fhe if x.is_scalar else fhelinalg
operation = dialect.ChangePartitionEintOp
return self.operation(operation, x.type, x.result)

# pylint: enable=missing-function-docstring
7 changes: 6 additions & 1 deletion frontends/concrete-python/concrete/fhe/mlir/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,9 @@ def tfhers_to_native(self, ctx: Context, node: Node, preds: List[Conversion]) ->
dtype.msg_width,
)

# TODO: use parameters to change partition
tfhers_int = ctx.change_partition(tfhers_int)

# number of ciphertexts representing a single integer
num_cts = tfhers_int.shape[-1]
# first table maps to the lsb, and last one maps to the msb
Expand Down Expand Up @@ -913,6 +916,8 @@ def tfhers_from_native(self, ctx: Context, node: Node, preds: List[Conversion])
]

# we are extracting lsb first so we reverse it so we have msb first
return ctx.concatenate(result_type, extracted_bits[::-1], axis=-1)
result = ctx.concatenate(result_type, extracted_bits[::-1], axis=-1)
# TODO: use specified parameters
return ctx.change_partition(result)

# pylint: enable=missing-function-docstring,unused-argument
31 changes: 24 additions & 7 deletions frontends/concrete-python/concrete/fhe/tfhers/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,47 @@
from ..dtypes import Integer


class TFHERSParams:
"""Crypto parameters used for a tfhers integer."""

pass


class TFHERSIntegerType(Integer):
"""
TFHERSIntegerType (Subclass of Integer) to represent tfhers integer types.
"""

carry_width: int
msg_width: int

def __init__(self, is_signed: bool, bit_width: int, carry_width: int, msg_width: int):
params: TFHERSParams

def __init__(
self,
is_signed: bool,
bit_width: int,
carry_width: int,
msg_width: int,
params: TFHERSParams,
):
super().__init__(is_signed, bit_width)
self.carry_width = carry_width
self.msg_width = msg_width
self.params = params

def __eq__(self, other: Any) -> bool:
return (
isinstance(other, self.__class__)
and super().__eq__(other)
and self.carry_width == other.carry_width
and self.msg_width == other.msg_width
and self.params == other.params
)

def __str__(self) -> str:
return (
f"tfhers<{('int' if self.is_signed else 'uint')}"
f"{self.bit_width}, {self.carry_width}, {self.msg_width}>"
f"{self.bit_width}, {self.carry_width}, {self.msg_width}, {self.params}>"
)

def encode(self, value: Union[int, np.integer, np.ndarray]) -> np.ndarray:
Expand Down Expand Up @@ -97,7 +113,8 @@ def decode(self, value: np.ndarray) -> Union[int, np.ndarray]:
int16 = partial(TFHERSIntegerType, True, 16)
uint16 = partial(TFHERSIntegerType, False, 16)

int8_2_2 = int8(2, 2)
uint8_2_2 = uint8(2, 2)
int16_2_2 = int16(2, 2)
uint16_2_2 = uint16(2, 2)
# TODO: make these partials as well, so that params have to be specified
int8_2_2 = int8(2, 2, TFHERSParams())
uint8_2_2 = uint8(2, 2, TFHERSParams())
int16_2_2 = int16(2, 2, TFHERSParams())
uint16_2_2 = uint16(2, 2, TFHERSParams())

0 comments on commit 0c87d72

Please sign in to comment.