diff --git a/RELEASE.md b/RELEASE.md index 3ea56327ea..67e2c5b057 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -11,6 +11,8 @@ ## Major Features and Improvements * Added an implementation of `__eq__()` on `building blocks`. +* Added a new field, `content`, to the `Data` building block and updated + tests. ## Bug Fixes diff --git a/tensorflow_federated/python/core/backends/mapreduce/BUILD b/tensorflow_federated/python/core/backends/mapreduce/BUILD index d3f22fc9c8..8a5eb60425 100644 --- a/tensorflow_federated/python/core/backends/mapreduce/BUILD +++ b/tensorflow_federated/python/core/backends/mapreduce/BUILD @@ -60,6 +60,7 @@ py_test( ":mapreduce_test_utils", "//tensorflow_federated/proto/v0:computation_py_pb2", "//tensorflow_federated/python/core/impl/compiler:building_block_factory", + "//tensorflow_federated/python/core/impl/compiler:building_block_test_utils", "//tensorflow_federated/python/core/impl/compiler:building_blocks", "//tensorflow_federated/python/core/impl/compiler:intrinsic_defs", "//tensorflow_federated/python/core/impl/compiler:tensorflow_computation_factory", diff --git a/tensorflow_federated/python/core/backends/mapreduce/compiler_test.py b/tensorflow_federated/python/core/backends/mapreduce/compiler_test.py index baca90bbe7..4b7754909b 100644 --- a/tensorflow_federated/python/core/backends/mapreduce/compiler_test.py +++ b/tensorflow_federated/python/core/backends/mapreduce/compiler_test.py @@ -21,6 +21,7 @@ from tensorflow_federated.python.core.backends.mapreduce import form_utils from tensorflow_federated.python.core.backends.mapreduce import mapreduce_test_utils from tensorflow_federated.python.core.impl.compiler import building_block_factory +from tensorflow_federated.python.core.impl.compiler import building_block_test_utils from tensorflow_federated.python.core.impl.compiler import building_blocks from tensorflow_federated.python.core.impl.compiler import intrinsic_defs from tensorflow_federated.python.core.impl.compiler import tensorflow_computation_factory @@ -583,8 +584,11 @@ def test_compiles_lambda_under_federated_comp_to_tf(self): identity_lambda = building_blocks.Lambda( ref_to_x.name, ref_to_x.type_signature, ref_to_x ) + any_proto = building_block_test_utils.create_any_proto_from_array( + np.array(1) + ) federated_data = building_blocks.Data( - 'a', + any_proto, computation_types.FederatedType( computation_types.StructType([np.int32, np.float32]), placements.SERVER, diff --git a/tensorflow_federated/python/core/impl/compiler/BUILD b/tensorflow_federated/python/core/impl/compiler/BUILD index 8ea57a3c2b..da5140a598 100644 --- a/tensorflow_federated/python/core/impl/compiler/BUILD +++ b/tensorflow_federated/python/core/impl/compiler/BUILD @@ -158,6 +158,7 @@ py_library( "//tensorflow_federated/python/core/impl/types:type_analysis", "//tensorflow_federated/python/core/impl/types:type_serialization", "//tensorflow_federated/python/core/impl/types:typed_object", + "@com_google_protobuf//:protobuf_python", ], ) @@ -166,6 +167,7 @@ py_test( size = "small", srcs = ["building_blocks_test.py"], deps = [ + ":array", ":building_block_test_utils", ":building_blocks", ":computation_factory", @@ -177,6 +179,7 @@ py_test( "//tensorflow_federated/python/core/impl/types:computation_types", "//tensorflow_federated/python/core/impl/types:placements", "//tensorflow_federated/python/core/impl/types:type_serialization", + "@com_google_protobuf//:protobuf_python", ], ) diff --git a/tensorflow_federated/python/core/impl/compiler/building_block_test_utils.py b/tensorflow_federated/python/core/impl/compiler/building_block_test_utils.py index e0e23fade3..acc1e6a311 100644 --- a/tensorflow_federated/python/core/impl/compiler/building_block_test_utils.py +++ b/tensorflow_federated/python/core/impl/compiler/building_block_test_utils.py @@ -323,20 +323,27 @@ def create_whimsy_called_federated_sum( return building_block_factory.create_federated_sum(value) -def create_whimsy_called_sequence_map(parameter_name, parameter_type=np.int32): +def create_whimsy_called_sequence_map( + # TODO: b/338284242 - Remove any proto from function constructor once the + # compiler tests no longer support string equality. + parameter_name, + parameter_type=np.int32, + any_proto=any_pb2.Any(), +): r"""Returns a whimsy called sequence map. Call / \ - sequence_map data + sequence_map Data(id) Args: parameter_name: The name of the parameter. parameter_type: The type of the parameter. + any_proto: The any proto to use for the data block. """ fn = create_identity_function(parameter_name, parameter_type) arg_type = computation_types.SequenceType(parameter_type) - arg = building_blocks.Data('data', arg_type) + arg = building_blocks.Data(any_proto, arg_type) return building_block_factory.create_sequence_map(fn, arg) diff --git a/tensorflow_federated/python/core/impl/compiler/building_blocks.py b/tensorflow_federated/python/core/impl/compiler/building_blocks.py index 6bb5ad28d8..97d6f65e3e 100644 --- a/tensorflow_federated/python/core/impl/compiler/building_blocks.py +++ b/tensorflow_federated/python/core/impl/compiler/building_blocks.py @@ -22,6 +22,7 @@ import numpy as np +from google.protobuf import any_pb2 from tensorflow_federated.proto.v0 import computation_pb2 as pb from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.common_libs import structure @@ -984,7 +985,7 @@ def __repr__(self) -> str: class Data(ComputationBuildingBlock): """A representation of data (an input pipeline). - This class does not deal with parsing data URIs and verifying correctness, + This class does not deal with parsing data protos and verifying correctness, it is only a container. Parsing and type analysis are a responsibility or a component external to this module. """ @@ -993,15 +994,15 @@ class Data(ComputationBuildingBlock): def from_proto(cls, computation_proto: pb.Computation) -> 'Data': _check_computation_oneof(computation_proto, 'data') return cls( - computation_proto.data.uri, + computation_proto.data.content, type_serialization.deserialize_type(computation_proto.type), ) - def __init__(self, uri: str, type_spec: object): + def __init__(self, content: any_pb2.Any, type_spec: object): """Creates a representation of data. Args: - uri: The URI that characterizes the data. + content: The proto that characterizes the data. type_spec: Either the types.Type that represents the type of this data, or something convertible to it by types.to_type(). @@ -1009,21 +1010,19 @@ def __init__(self, uri: str, type_spec: object): TypeError: if the arguments are of the wrong types. ValueError: if the user tries to specify an empty URI. """ - py_typecheck.check_type(uri, str) - if not uri: - raise ValueError('Empty string cannot be passed as URI to Data.') + py_typecheck.check_type(content, any_pb2.Any) if type_spec is None: raise TypeError( - 'Intrinsic {} cannot be created without a TFF type.'.format(uri) + 'Intrinsic {} cannot be created without a TFF type.'.format(content) ) type_spec = computation_types.to_type(type_spec) super().__init__(type_spec) - self._uri = uri + self._content = content def _proto(self) -> pb.Computation: return pb.Computation( type=type_serialization.serialize_type(self.type_signature), - data=pb.Data(uri=self._uri), + data=pb.Data(content=self._content), ) def children(self) -> Iterator[ComputationBuildingBlock]: @@ -1031,8 +1030,8 @@ def children(self) -> Iterator[ComputationBuildingBlock]: return iter(()) @property - def uri(self) -> str: - return self._uri + def content(self) -> any_pb2.Any: + return self._content def __eq__(self, other: object) -> bool: if self is other: @@ -1040,20 +1039,20 @@ def __eq__(self, other: object) -> bool: elif not isinstance(other, Data): return NotImplemented return ( - self._uri, + id(self._content), self._type_signature, ) == ( - other._uri, + id(other._content), other._type_signature, ) def __hash__(self): if self._cached_hash is None: - self._cached_hash = hash((self._uri, self._type_signature)) + self._cached_hash = hash((id(self._content), self._type_signature)) return self._cached_hash def __repr__(self) -> str: - return "Data('{}', {!r})".format(self._uri, self.type_signature) + return 'Data({}, {!r})'.format(id(self._content), self.type_signature) class CompiledComputation(ComputationBuildingBlock): @@ -1415,7 +1414,7 @@ def _lines_for_comp(comp, formatted): elif isinstance(comp, CompiledComputation): return ['comp#{}'.format(comp.name)] elif isinstance(comp, Data): - return [comp.uri] + return [str(id(comp.content))] elif isinstance(comp, Intrinsic): return [comp.uri] elif isinstance(comp, Lambda): @@ -1679,7 +1678,7 @@ def _get_node_label(comp): elif isinstance(comp, CompiledComputation): return 'Compiled({})'.format(comp.name) elif isinstance(comp, Data): - return comp.uri + return f'Data({id(comp.content)})' elif isinstance(comp, Intrinsic): return comp.uri elif isinstance(comp, Lambda): diff --git a/tensorflow_federated/python/core/impl/compiler/building_blocks_test.py b/tensorflow_federated/python/core/impl/compiler/building_blocks_test.py index 34a67edfef..1249064602 100644 --- a/tensorflow_federated/python/core/impl/compiler/building_blocks_test.py +++ b/tensorflow_federated/python/core/impl/compiler/building_blocks_test.py @@ -17,10 +17,12 @@ import numpy as np import tree +from google.protobuf import any_pb2 from tensorflow_federated.proto.v0 import array_pb2 from tensorflow_federated.proto.v0 import computation_pb2 from tensorflow_federated.proto.v0 import data_type_pb2 from tensorflow_federated.python.common_libs import structure +from tensorflow_federated.python.core.impl.compiler import array from tensorflow_federated.python.core.impl.compiler import building_block_test_utils from tensorflow_federated.python.core.impl.compiler import building_blocks from tensorflow_federated.python.core.impl.compiler import computation_factory @@ -411,25 +413,38 @@ def test_intrinsic_class_succeeds_simple_federated_map(self): self._serialize_deserialize_roundtrip_test(concrete_federated_map) def test_basic_functionality_of_data_class(self): + test_proto = array.to_proto(np.array([1, 2, 3], np.int32)) + any_proto = any_pb2.Any() + any_proto.Pack(test_proto) x = building_blocks.Data( - '/tmp/mydata', computation_types.SequenceType(np.int32) + any_proto, computation_types.SequenceType(np.int32) ) self.assertEqual(str(x.type_signature), 'int32*') - self.assertEqual(x.uri, '/tmp/mydata') + self.assertEqual(x.content, any_proto) + arr = array_pb2.Array() + x.content.Unpack(arr) + self.assertEqual(arr, test_proto) + as_string = str(id(any_proto)) self.assertEqual( - repr(x), "Data('/tmp/mydata', SequenceType(TensorType(np.int32)))" + repr(x), f'Data({as_string}, SequenceType(TensorType(np.int32)))' + ) + self.assertEqual( + x.compact_representation(), + as_string, ) - self.assertEqual(x.compact_representation(), '/tmp/mydata') x_proto = x.proto self.assertEqual( type_serialization.deserialize_type(x_proto.type), x.type_signature ) self.assertEqual(x_proto.WhichOneof('computation'), 'data') - self.assertEqual(x_proto.data.uri, x.uri) + self.assertEqual(x_proto.data.content, x.content) self._serialize_deserialize_roundtrip_test(x) def test_data_children_is_empty(self): - data = building_blocks.Data('a', np.int32) + any_proto = building_block_test_utils.create_any_proto_from_array( + np.array(1, np.int32) + ) + data = building_blocks.Data(any_proto, np.int32) self.assertEqual([], list(data.children())) def test_basic_functionality_of_compiled_computation_class(self): @@ -1433,57 +1448,75 @@ def test_hash_returns_different_value(self, intrinsic, other): class DataTest(parameterized.TestCase): def test_eq_returns_true(self): + any_proto = building_block_test_utils.create_any_proto_from_array( + np.array([1, 2, 3], np.int32) + ) type_signature = computation_types.TensorType(np.int32) - data = building_blocks.Data('data', type_signature) - other = building_blocks.Data('data', type_signature) + data = building_blocks.Data(any_proto, type_signature) + other = building_blocks.Data(any_proto, type_signature) self.assertIsNot(data, other) self.assertEqual(data, other) - @parameterized.named_parameters( - ( - 'different_uri', - building_blocks.Data('data', computation_types.TensorType(np.int32)), - building_blocks.Data( - 'different', computation_types.TensorType(np.int32) - ), - ), - ( - 'different_type_signature', - building_blocks.Data('data', computation_types.TensorType(np.int32)), - building_blocks.Data( - 'data', computation_types.TensorType(np.float32) - ), - ), - ) - def test_eq_returns_false(self, data, other): + def test_eq_returns_false_different_content(self): + any_proto1 = building_block_test_utils.create_any_proto_from_array( + np.array([1, 2, 3], np.int32) + ) + type_signature = computation_types.TensorType(np.int32) + data = building_blocks.Data(any_proto1, type_signature) + + any_proto2 = building_block_test_utils.create_any_proto_from_array( + np.array([4], np.int32) + ) + other = building_blocks.Data(any_proto2, type_signature) + self.assertIsNot(data, other) + self.assertNotEqual(data, other) + + def test_eq_returns_false_different_type_signatures(self): + any_proto = building_block_test_utils.create_any_proto_from_array( + np.array([1, 1, 1], np.int32) + ) + type_signature1 = computation_types.TensorType(np.int32) + type_signature2 = computation_types.TensorType(np.float32) + data = building_blocks.Data(any_proto, type_signature1) + other = building_blocks.Data(any_proto, type_signature2) + self.assertIsNot(data, other) self.assertNotEqual(data, other) def test_hash_returns_same_value(self): + any_proto = building_block_test_utils.create_any_proto_from_array( + np.array([1, 2, 3], np.int32) + ) type_signature = computation_types.TensorType(np.int32) - data = building_blocks.Data('data', type_signature) - other = building_blocks.Data('data', type_signature) + data = building_blocks.Data(any_proto, type_signature) + other = building_blocks.Data(any_proto, type_signature) self.assertEqual(hash(data), hash(other)) - @parameterized.named_parameters( - ( - 'different_uri', - building_blocks.Data('data', computation_types.TensorType(np.int32)), - building_blocks.Data( - 'different', computation_types.TensorType(np.int32) - ), - ), - ( - 'different_type_signature', - building_blocks.Data('data', computation_types.TensorType(np.int32)), - building_blocks.Data( - 'data', computation_types.TensorType(np.float32) - ), - ), - ) - def test_hash_returns_different_value(self, data, other): + def test_hash_returns_different_value_for_different_content(self): + any_proto1 = building_block_test_utils.create_any_proto_from_array( + np.array([1, 2, 3], np.int32) + ) + type_signature = computation_types.TensorType(np.int32) + data = building_blocks.Data(any_proto1, type_signature) + + any_proto2 = building_block_test_utils.create_any_proto_from_array( + np.array([4], np.int32) + ) + other = building_blocks.Data(any_proto2, type_signature) + self.assertNotEqual(data, other) + self.assertNotEqual(hash(data), hash(other)) + + def test_hash_returns_different_value_for_different_type_signatures(self): + any_proto = building_block_test_utils.create_any_proto_from_array( + np.array([1, 1, 1], np.int32) + ) + type_signature1 = computation_types.TensorType(np.int32) + type_signature2 = computation_types.TensorType(np.float32) + data = building_blocks.Data(any_proto, type_signature1) + other = building_blocks.Data(any_proto, type_signature2) + self.assertNotEqual(data, other) self.assertNotEqual(hash(data), hash(other)) @@ -2313,11 +2346,15 @@ def test_returns_string_for_compiled_computation(self): self.assertEqual(comp.structural_representation(), 'Compiled(a)') def test_returns_string_for_data(self): - comp = building_blocks.Data('data', np.int32) + any_proto = building_block_test_utils.create_any_proto_from_array( + np.array([1, 2, 3], np.int32) + ) + comp = building_blocks.Data(any_proto, np.int32) - self.assertEqual(comp.compact_representation(), 'data') - self.assertEqual(comp.formatted_representation(), 'data') - self.assertEqual(comp.structural_representation(), 'data') + expected = str(id(any_proto)) + self.assertEqual(comp.compact_representation(), expected) + self.assertEqual(comp.formatted_representation(), expected) + self.assertEqual(comp.structural_representation(), f'Data({expected})') def test_returns_string_for_intrinsic(self): comp_type = computation_types.TensorType(np.int32) diff --git a/tensorflow_federated/python/core/impl/compiler/transformation_utils_test.py b/tensorflow_federated/python/core/impl/compiler/transformation_utils_test.py index d7390e64c7..8b77582d58 100644 --- a/tensorflow_federated/python/core/impl/compiler/transformation_utils_test.py +++ b/tensorflow_federated/python/core/impl/compiler/transformation_utils_test.py @@ -94,13 +94,16 @@ def __str__(self): def _construct_trivial_instance_of_all_computation_building_blocks(): cbb_list = [] + any_proto = building_block_test_utils.create_any_proto_from_array( + np.array(1, np.int32) + ) ref_to_x = building_blocks.Reference('x', np.int32) cbb_list.append(('reference', ref_to_x)) lam = building_blocks.Lambda('x', np.int32, ref_to_x) cbb_list.append(('lambda', lam)) block = building_blocks.Block([('x', ref_to_x)], lam) cbb_list.append(('block', block)) - data = building_blocks.Data('x', np.int32) + data = building_blocks.Data(any_proto, np.int32) cbb_list.append(('data', data)) function_type = computation_types.FunctionType(np.int32, np.int32) intrinsic = building_blocks.Intrinsic('whimsy_intrinsic', function_type) diff --git a/tensorflow_federated/python/core/impl/compiler/transformations_test.py b/tensorflow_federated/python/core/impl/compiler/transformations_test.py index ab77b68e1e..12a316ff71 100644 --- a/tensorflow_federated/python/core/impl/compiler/transformations_test.py +++ b/tensorflow_federated/python/core/impl/compiler/transformations_test.py @@ -125,7 +125,10 @@ def test_creates_binding_for_each_call(self): int_type = computation_types.to_type(np.int32) int_to_int_type = computation_types.FunctionType(int_type, int_type) bb = building_blocks - int_to_int_fn = bb.Data('ext', int_to_int_type) + any_proto = building_block_test_utils.create_any_proto_from_array( + np.array([1, 2, 3]) + ) + int_to_int_fn = bb.Data(any_proto, int_to_int_type) before = bb.Lambda( 'x', int_type, @@ -155,7 +158,10 @@ def test_evaluates_called_lambdas(self): int_to_int_type = computation_types.FunctionType(int_type, int_type) int_thunk_type = computation_types.FunctionType(None, int_type) bb = building_blocks - int_to_int_fn = bb.Data('ext', int_to_int_type) + any_proto = building_block_test_utils.create_any_proto_from_array( + np.array([1, 2, 3]) + ) + int_to_int_fn = bb.Data(any_proto, int_to_int_type) # -> (let result = ext(x) in (-> result)) # Each call of the outer lambda should create a single binding, with @@ -222,7 +228,10 @@ def test_creates_block_for_non_lambda(self): [(None, int_type), (None, int_type)] ) get_two_int_type = computation_types.FunctionType(None, two_int_type) - call_ext = bb.Call(bb.Data('ext', get_two_int_type)) + any_proto = building_block_test_utils.create_any_proto_from_array( + np.array([1, 2, 3]) + ) + call_ext = bb.Call(bb.Data(any_proto, get_two_int_type)) before = bb.Selection(call_ext, index=0) after = transformations.to_call_dominant(before) expected = bb.Block( @@ -376,9 +385,13 @@ def test_splits_on_selected_intrinsic_nested_in_tuple_broadcast(self): first_broadcast = ( building_block_test_utils.create_whimsy_called_federated_broadcast() ) + any_proto = building_block_test_utils.create_any_proto_from_array( + np.array([1, 2, 3]) + ) packed_broadcast = building_blocks.Struct([ building_blocks.Data( - 'a', computation_types.FederatedType(np.int32, placements.SERVER) + any_proto, + computation_types.FederatedType(np.int32, placements.SERVER), ), first_broadcast, ]) @@ -1810,8 +1823,11 @@ def test_splits_on_nested_in_tuple_broadcast(self): building_blocks.Reference('arg', arg_type), index=server_data_index ) ) + any_proto = building_block_test_utils.create_any_proto_from_array( + np.array([1, 2, 3]) + ) packed_broadcast = building_blocks.Struct( - [building_blocks.Data('a', server_val_type), first_broadcast] + [building_blocks.Data(any_proto, server_val_type), first_broadcast] ) sel = building_blocks.Selection(packed_broadcast, index=0) second_broadcast = building_block_factory.create_federated_broadcast(sel) diff --git a/tensorflow_federated/python/core/impl/compiler/tree_analysis_test.py b/tensorflow_federated/python/core/impl/compiler/tree_analysis_test.py index 45cac85f51..fa2932ba1b 100644 --- a/tensorflow_federated/python/core/impl/compiler/tree_analysis_test.py +++ b/tensorflow_federated/python/core/impl/compiler/tree_analysis_test.py @@ -349,7 +349,9 @@ def _create_trivial_mean(value_type=np.int32): fed_value_type = computation_types.FederatedType( value_type, placements.CLIENTS ) - any_proto = 'any_proto' + any_proto = building_block_test_utils.create_any_proto_from_array( + np.array([1, 2, 3]) + ) values = building_blocks.Data(any_proto, fed_value_type) return building_block_factory.create_federated_mean(values, None) @@ -360,7 +362,9 @@ def _create_trivial_secure_sum(value_type=np.int32): federated_type = computation_types.FederatedType( value_type, placements.CLIENTS ) - any_proto = 'any_proto' + any_proto = building_block_test_utils.create_any_proto_from_array( + np.array([1, 2, 3]) + ) value = building_blocks.Data(any_proto, federated_type) bitwidth = building_blocks.Data(any_proto, value_type) return building_block_factory.create_federated_secure_sum_bitwidth( @@ -392,6 +396,9 @@ class ContainsAggregationShared(parameterized.TestCase): @parameterized.named_parameters([ ('non_aggregation_intrinsics', non_aggregation_intrinsics), ('trivial_mean', trivial_mean), + # TODO: b/120439632 - Enable once federated_mean accepts structured + # weight. + # ('trivial_weighted_mean', trivial_weighted_mean), ('trivial_secure_sum', trivial_secure_sum), ]) def test_returns_none(self, comp): @@ -399,9 +406,12 @@ def test_returns_none(self, comp): self.assertEmpty(tree_analysis.find_secure_aggregation_in_tree(comp)) def test_throws_on_unresolvable_function_call(self): + any_proto = building_block_test_utils.create_any_proto_from_array( + np.array([1, 2, 3]) + ) comp = building_blocks.Call( building_blocks.Data( - 'unknown_func', + any_proto, computation_types.FunctionType( None, computation_types.FederatedType(np.int32, placements.CLIENTS), @@ -419,12 +429,15 @@ def test_returns_none_on_unresolvable_function_call_with_non_federated_output( ): input_type = computation_types.FederatedType(np.int32, placements.CLIENTS) output_type = np.int32 + any_proto = building_block_test_utils.create_any_proto_from_array( + np.array([1, 2, 3]) + ) comp = building_blocks.Call( building_blocks.Data( - 'unknown_func', + any_proto, computation_types.FunctionType(input_type, output_type), ), - building_blocks.Data('client_data', input_type), + building_blocks.Data(any_proto, input_type), ) self.assertEmpty(tree_analysis.find_unsecure_aggregation_in_tree(comp)) diff --git a/tensorflow_federated/python/core/impl/compiler/tree_transformations_test.py b/tensorflow_federated/python/core/impl/compiler/tree_transformations_test.py index 9544a019a7..aaf99b13ae 100644 --- a/tensorflow_federated/python/core/impl/compiler/tree_transformations_test.py +++ b/tensorflow_federated/python/core/impl/compiler/tree_transformations_test.py @@ -128,21 +128,25 @@ def test_removes_federated_apply(self): self.assertTrue(modified) def test_removes_sequence_map(self): + any_proto = building_block_test_utils.create_any_proto_from_array( + np.array(1, np.int32) + ) call = building_block_test_utils.create_whimsy_called_sequence_map( - parameter_name='a' + parameter_name='a', any_proto=any_proto ) comp = call transformed_comp, modified = ( tree_transformations.remove_mapped_or_applied_identity(comp) ) + data_str = str(id(any_proto)) self.assertEqual( comp.compact_representation(), - 'sequence_map(<(a -> a),data>)', + f'sequence_map(<(a -> a),{data_str}>)', ) self.assertEqual( transformed_comp.compact_representation(), - 'data', + data_str, ) self.assertEqual(transformed_comp.type_signature, comp.type_signature) self.assertTrue(modified) @@ -153,18 +157,21 @@ def test_removes_federated_map_with_named_result(self): arg_type = computation_types.FederatedType( parameter_type, placements.CLIENTS ) - arg = building_blocks.Data('data', arg_type) + any_proto = building_block_test_utils.create_any_proto_from_array( + np.array(1, np.int32) + ) + arg = building_blocks.Data(any_proto, arg_type) call = building_block_factory.create_federated_map(fn, arg) comp = call transformed_comp, modified = ( tree_transformations.remove_mapped_or_applied_identity(comp) ) - + str_data = str(id(any_proto)) self.assertEqual( - comp.compact_representation(), 'federated_map(<(c -> c),data>)' + comp.compact_representation(), f'federated_map(<(c -> c),{str_data}>)' ) - self.assertEqual(transformed_comp.compact_representation(), 'data') + self.assertEqual(transformed_comp.compact_representation(), str_data) self.assertEqual(transformed_comp.type_signature, comp.type_signature) self.assertTrue(modified)