Skip to content

Commit

Permalink
Update tff.framework.Data for new proto field.
Browse files Browse the repository at this point in the history
We have updated the `Data` proto message with a new `content` field that is an `Any` proto, deprecating the `string uri` field. The intention is to make this field more flexible for implementing environments. This change updates the `Data` class in building_blocks to use the new field.

PiperOrigin-RevId: 629785319
  • Loading branch information
eglanz authored and tensorflow-copybara committed May 2, 2024
1 parent 643e4d6 commit d53d52a
Show file tree
Hide file tree
Showing 11 changed files with 180 additions and 88 deletions.
2 changes: 2 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
## Major Features and Improvements

* Added an implementation of `__eq__()` on `building blocks`.
* Added a new field, `content`, to the `Data` building block and updated
tests.

## Bug Fixes

Expand Down
1 change: 1 addition & 0 deletions tensorflow_federated/python/core/backends/mapreduce/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ py_test(
":mapreduce_test_utils",
"//tensorflow_federated/proto/v0:computation_py_pb2",
"//tensorflow_federated/python/core/impl/compiler:building_block_factory",
"//tensorflow_federated/python/core/impl/compiler:building_block_test_utils",
"//tensorflow_federated/python/core/impl/compiler:building_blocks",
"//tensorflow_federated/python/core/impl/compiler:intrinsic_defs",
"//tensorflow_federated/python/core/impl/compiler:tensorflow_computation_factory",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from tensorflow_federated.python.core.backends.mapreduce import form_utils
from tensorflow_federated.python.core.backends.mapreduce import mapreduce_test_utils
from tensorflow_federated.python.core.impl.compiler import building_block_factory
from tensorflow_federated.python.core.impl.compiler import building_block_test_utils
from tensorflow_federated.python.core.impl.compiler import building_blocks
from tensorflow_federated.python.core.impl.compiler import intrinsic_defs
from tensorflow_federated.python.core.impl.compiler import tensorflow_computation_factory
Expand Down Expand Up @@ -583,8 +584,11 @@ def test_compiles_lambda_under_federated_comp_to_tf(self):
identity_lambda = building_blocks.Lambda(
ref_to_x.name, ref_to_x.type_signature, ref_to_x
)
any_proto = building_block_test_utils.create_any_proto_from_array(
np.array(1)
)
federated_data = building_blocks.Data(
'a',
any_proto,
computation_types.FederatedType(
computation_types.StructType([np.int32, np.float32]),
placements.SERVER,
Expand Down
3 changes: 3 additions & 0 deletions tensorflow_federated/python/core/impl/compiler/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ py_library(
"//tensorflow_federated/python/core/impl/types:type_analysis",
"//tensorflow_federated/python/core/impl/types:type_serialization",
"//tensorflow_federated/python/core/impl/types:typed_object",
"@com_google_protobuf//:protobuf_python",
],
)

Expand All @@ -166,6 +167,7 @@ py_test(
size = "small",
srcs = ["building_blocks_test.py"],
deps = [
":array",
":building_block_test_utils",
":building_blocks",
":computation_factory",
Expand All @@ -177,6 +179,7 @@ py_test(
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/core/impl/types:type_serialization",
"@com_google_protobuf//:protobuf_python",
],
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -323,20 +323,27 @@ def create_whimsy_called_federated_sum(
return building_block_factory.create_federated_sum(value)


def create_whimsy_called_sequence_map(parameter_name, parameter_type=np.int32):
def create_whimsy_called_sequence_map(
# TODO: b/338284242 - Remove any proto from function constructor once the
# compiler tests no longer support string equality.
parameter_name,
parameter_type=np.int32,
any_proto=any_pb2.Any(),
):
r"""Returns a whimsy called sequence map.
Call
/ \
sequence_map data
sequence_map Data(id)
Args:
parameter_name: The name of the parameter.
parameter_type: The type of the parameter.
any_proto: The any proto to use for the data block.
"""
fn = create_identity_function(parameter_name, parameter_type)
arg_type = computation_types.SequenceType(parameter_type)
arg = building_blocks.Data('data', arg_type)
arg = building_blocks.Data(any_proto, arg_type)
return building_block_factory.create_sequence_map(fn, arg)


Expand Down
35 changes: 17 additions & 18 deletions tensorflow_federated/python/core/impl/compiler/building_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import numpy as np

from google.protobuf import any_pb2
from tensorflow_federated.proto.v0 import computation_pb2 as pb
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.common_libs import structure
Expand Down Expand Up @@ -984,7 +985,7 @@ def __repr__(self) -> str:
class Data(ComputationBuildingBlock):
"""A representation of data (an input pipeline).
This class does not deal with parsing data URIs and verifying correctness,
This class does not deal with parsing data protos and verifying correctness,
it is only a container. Parsing and type analysis are a responsibility
of a component external to this module.
"""
Expand All @@ -993,67 +994,65 @@ class Data(ComputationBuildingBlock):
def from_proto(cls, computation_proto: pb.Computation) -> 'Data':
_check_computation_oneof(computation_proto, 'data')
return cls(
computation_proto.data.uri,
computation_proto.data.content,
type_serialization.deserialize_type(computation_proto.type),
)

def __init__(self, uri: str, type_spec: object):
def __init__(self, content: any_pb2.Any, type_spec: object):
"""Creates a representation of data.
Args:
uri: The URI that characterizes the data.
content: The proto that characterizes the data.
type_spec: Either the types.Type that represents the type of this data, or
something convertible to it by types.to_type().
Raises:
TypeError: if the arguments are of the wrong types.
ValueError: if the user tries to specify an empty URI.
"""
py_typecheck.check_type(uri, str)
if not uri:
raise ValueError('Empty string cannot be passed as URI to Data.')
py_typecheck.check_type(content, any_pb2.Any)
if type_spec is None:
raise TypeError(
'Intrinsic {} cannot be created without a TFF type.'.format(uri)
'Intrinsic {} cannot be created without a TFF type.'.format(content)
)
type_spec = computation_types.to_type(type_spec)
super().__init__(type_spec)
self._uri = uri
self._content = content

def _proto(self) -> pb.Computation:
return pb.Computation(
type=type_serialization.serialize_type(self.type_signature),
data=pb.Data(uri=self._uri),
data=pb.Data(content=self._content),
)

def children(self) -> Iterator[ComputationBuildingBlock]:
del self
return iter(())

@property
def uri(self) -> str:
return self._uri
def content(self) -> any_pb2.Any:
return self._content

def __eq__(self, other: object) -> bool:
if self is other:
return True
elif not isinstance(other, Data):
return NotImplemented
return (
self._uri,
id(self._content),
self._type_signature,
) == (
other._uri,
id(other._content),
other._type_signature,
)

def __hash__(self):
if self._cached_hash is None:
self._cached_hash = hash((self._uri, self._type_signature))
self._cached_hash = hash((id(self._content), self._type_signature))
return self._cached_hash

def __repr__(self) -> str:
return "Data('{}', {!r})".format(self._uri, self.type_signature)
return 'Data({}, {!r})'.format(id(self._content), self.type_signature)


class CompiledComputation(ComputationBuildingBlock):
Expand Down Expand Up @@ -1415,7 +1414,7 @@ def _lines_for_comp(comp, formatted):
elif isinstance(comp, CompiledComputation):
return ['comp#{}'.format(comp.name)]
elif isinstance(comp, Data):
return [comp.uri]
return [str(id(comp.content))]
elif isinstance(comp, Intrinsic):
return [comp.uri]
elif isinstance(comp, Lambda):
Expand Down Expand Up @@ -1679,7 +1678,7 @@ def _get_node_label(comp):
elif isinstance(comp, CompiledComputation):
return 'Compiled({})'.format(comp.name)
elif isinstance(comp, Data):
return comp.uri
return f'Data({id(comp.content)})'
elif isinstance(comp, Intrinsic):
return comp.uri
elif isinstance(comp, Lambda):
Expand Down
Loading

0 comments on commit d53d52a

Please sign in to comment.