Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update value_count serialization/deserialization to be consistent with original schema #111

Merged
merged 21 commits into main from schema-io-value-count
Nov 19, 2022
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
a8b6dd5
Add test for value_count property
oliverholworthy Aug 3, 2022
4516dbf
reformat test_schema_io.py
oliverholworthy Aug 3, 2022
d2b21ec
Merge branch 'main' into schema-io-value-count
karlhigley Aug 15, 2022
6b65311
Merge branch 'main' into schema-io-value-count
oliverholworthy Nov 18, 2022
814472f
Update value_count serialization and default is_list/is_ragged.
oliverholworthy Nov 18, 2022
bd68a14
Add check for value count and is_ragged compatibility
oliverholworthy Nov 18, 2022
d5a0ca7
Update formatting
oliverholworthy Nov 18, 2022
6610d7f
Update formatting.
oliverholworthy Nov 18, 2022
64d0aa7
Restore is_list/is_ragged inference from value_count when loading schema
oliverholworthy Nov 18, 2022
f1f9c43
Correct test passing list of properties
oliverholworthy Nov 18, 2022
1230469
Update test to reflect default is_ragged attribute
oliverholworthy Nov 18, 2022
8c549ec
Only check is_ragged min/max if provided in constructor
oliverholworthy Nov 18, 2022
3feb59a
Update formatting
oliverholworthy Nov 18, 2022
b89438d
Merge branch 'main' into schema-io-value-count
karlhigley Nov 18, 2022
01010c5
Merge branch 'main' into schema-io-value-count
karlhigley Nov 18, 2022
95ddcfa
Revert change to default value of `is_ragged`
oliverholworthy Nov 18, 2022
4cdaa78
Add check for value count of zero and raise ValueError
oliverholworthy Nov 18, 2022
0ca930d
Add test for zero value in properties
oliverholworthy Nov 18, 2022
f7fd2e1
Enable partial value count to be specified
oliverholworthy Nov 18, 2022
c73f7bb
Add test for specifying only max value count and fix deserialization
oliverholworthy Nov 18, 2022
f48634c
Check min or max when setting is_list
oliverholworthy Nov 18, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 11 additions & 22 deletions merlin/schema/io/tensorflow_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,7 @@
import numpy

from merlin.schema.io import proto_utils, schema_bp
from merlin.schema.io.schema_bp import (
Feature,
FeatureType,
FixedShape,
FixedShapeDim,
FloatDomain,
IntDomain,
)
from merlin.schema.io.schema_bp import Feature, FeatureType, FloatDomain, IntDomain
from merlin.schema.io.schema_bp import Schema as ProtoSchema
from merlin.schema.io.schema_bp import ValueCount
from merlin.schema.schema import ColumnSchema
Expand Down Expand Up @@ -282,17 +275,11 @@ def _pb_feature(column_schema):

feature = _set_feature_domain(feature, column_schema)

if column_schema.is_list:
value_count = column_schema.properties.get("value_count", {})
min_length = value_count.get("min")
max_length = value_count.get("max")

if min_length and max_length and min_length == max_length:
feature.shape = FixedShape([FixedShapeDim(size=min_length)])
elif min_length and max_length and min_length < max_length:
feature.value_count = ValueCount(min=min_length, max=max_length)
else:
feature.value_count = ValueCount(min=0, max=0)
value_count = column_schema.properties.get("value_count", {})
if value_count:
min_length = value_count.get("min", 0)
max_length = value_count.get("max", 0)
feature.value_count = ValueCount(min=min_length, max=max_length)

feature.annotation.tag = _pb_tag(column_schema)
feature.annotation.extra_metadata.append(_pb_extra_metadata(column_schema))
Expand Down Expand Up @@ -342,8 +329,7 @@ def _merlin_domain(feature):
def _merlin_value_count(feature):
if proto_utils.has_field(feature, "value_count"):
value_count = feature.value_count
if value_count.min != value_count.max != 0:
return {"min": value_count.min, "max": value_count.max}
return {"min": value_count.min, "max": value_count.max}


def _merlin_properties(feature):
Expand All @@ -363,14 +349,17 @@ def _merlin_properties(feature):
properties = {}

domain = _merlin_domain(feature)

if domain:
properties["domain"] = domain

value_count = _merlin_value_count(feature)

if value_count:
properties["value_count"] = value_count
properties["is_list"] = True
properties["is_list"] = value_count.get("min") > 0
properties["is_ragged"] = value_count.get("min") != value_count.get("max")

return properties


Expand Down
24 changes: 22 additions & 2 deletions merlin/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class ColumnSchema:
tags: Optional[TagSet] = field(default_factory=TagSet)
properties: Optional[Dict] = field(default_factory=dict)
dtype: Optional[object] = None
is_list: bool = False
is_list: Optional[bool] = None
is_ragged: Optional[bool] = None

def __post_init__(self):
Expand Down Expand Up @@ -75,15 +75,35 @@ def __post_init__(self):

object.__setattr__(self, "dtype", dtype)

value_count = self.properties.get("value_count")
value_count_provided = value_count is not None
value_count = value_count or {"min": 0, "max": 0}

if self.is_list is None:
if value_count["max"] > 0:
object.__setattr__(self, "is_list", True)
else:
object.__setattr__(self, "is_list", False)

if self.is_ragged is None:
object.__setattr__(self, "is_ragged", self.is_list)
if value_count["max"] > value_count["min"]:
object.__setattr__(self, "is_ragged", True)
else:
object.__setattr__(self, "is_ragged", False)

if self.is_ragged and not self.is_list:
raise ValueError(
"`is_ragged` is set to `True` but `is_list` is not. "
"Only list columns can set the `is_ragged` flag."
)

if self.is_ragged and value_count_provided and value_count["max"] == value_count["min"]:
raise ValueError(
"`is_ragged` is set to `True` but `value_count.min` == `value_count.max`. "
"If value_count min/max are equal. "
"This is a fixed size list and `is_ragged` should be set to False. "
)

@property
def quantity(self):
"""
Expand Down
14 changes: 11 additions & 3 deletions tests/unit/schema/test_column_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

@pytest.mark.parametrize("d_types", [numpy.float32, numpy.float64, numpy.uint32, numpy.uint64])
def test_dtype_column_schema(d_types):
column = ColumnSchema("name", tags=[], properties=[], dtype=d_types)
column = ColumnSchema("name", tags=[], properties={}, dtype=d_types)
assert column.dtype == d_types


Expand Down Expand Up @@ -180,8 +180,8 @@ def test_list_column_attributes():
col2_schema = ColumnSchema("col2", is_list=True)

assert col2_schema.is_list
assert col2_schema.is_ragged
assert col2_schema.quantity == ColumnQuantity.RAGGED_LIST
assert not col2_schema.is_ragged
assert col2_schema.quantity == ColumnQuantity.FIXED_LIST

col3_schema = ColumnSchema("col3", is_list=True, is_ragged=True)

Expand All @@ -197,3 +197,11 @@ def test_list_column_attributes():

with pytest.raises(ValueError):
ColumnSchema("col5", is_list=False, is_ragged=True)


def test_value_count():
    """Check that a ragged column with equal min/max value counts is rejected.

    Constructing a ``ColumnSchema`` with ``is_ragged=True`` together with a
    ``value_count`` whose ``min`` equals its ``max`` is expected to raise a
    ``ValueError`` carrying the message fragment asserted below.
    """
    fixed_size_properties = {"value_count": {"min": 2, "max": 2}}
    expected_fragment = "`is_ragged` is set to `True` but `value_count.min` == `value_count.max`"

    with pytest.raises(ValueError) as raised:
        ColumnSchema("col", is_ragged=True, properties=fixed_size_properties)

    # Inspect the message only after the context manager has captured the error.
    message = str(raised.value)
    assert expected_fragment in message
27 changes: 27 additions & 0 deletions tests/unit/schema/test_schema_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,33 @@ def test_merlin_to_proto_to_json_to_merlin():
assert output_schema == schema


@pytest.mark.parametrize(
    ("value_count", "expected_is_list", "expected_is_ragged"),
    [
        ({"min": 0, "max": 0}, False, False),
        ({"min": 1, "max": 1}, True, False),
        ({"min": 1, "max": 2}, True, True),
    ],
)
def test_value_count(value_count, expected_is_list, expected_is_ragged):
    """Check list/ragged inference from ``value_count`` and the JSON round trip.

    Each case builds a one-column schema whose only property is
    ``value_count``, asserts the ``is_list`` / ``is_ragged`` flags on that
    column, then serializes the schema through ``TensorflowMetadata`` to JSON
    and back and asserts the result equals the original schema.
    """
    column = ColumnSchema("example", properties={"value_count": value_count})
    schema = Schema([column])

    example = schema["example"]
    assert example.is_list == expected_is_list
    assert example.is_ragged == expected_is_ragged

    # Round trip: merlin schema -> JSON -> merlin schema.
    serialized = TensorflowMetadata.from_merlin_schema(schema).to_json()
    restored = TensorflowMetadata.from_json(serialized).to_merlin_schema()
    assert restored == schema


def test_column_schema_protobuf_domain_check(tmpdir):
# create a schema
schema1 = ColumnSchema(
Expand Down