Skip to content

Commit

Permalink
Always pass default field's values to __new__()/__init__().
Browse files Browse the repository at this point in the history
Previously, during Struct instantiation, if values for fields with default values were not explicitly provided, these fields would be initialized but not passed to the Struct's __new__() and __init__() methods. This meant that in order to cover all cases of instantiation, the user had to include default values in the parameter lists for __new__()/__init__().

Under the new behavior, the user should specify all fields as parameters in __new__()/__init__(), but need not include their default values.
  • Loading branch information
brandjon committed Jul 20, 2015
1 parent 797b12c commit 60a7f99
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 12 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## 0.2.2 (unreleased)

- fields with default values are properly passed to __new__()/__init__()
- added support for coercion of tuples for Struct-typed fields
- added support for `__getitem__` and `__setitem__`
- testing a Struct for equality with itself succeeds quickly
Expand Down
7 changes: 4 additions & 3 deletions examples/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,13 @@ class MutablePoint(Struct):
#
# If the fields have default values, these are substituted in before
# calling the constructor. Thus providing default parameter values
# in the constructor argument list is meaningless.
# in the constructor argument list is meaningless, as they will always
# be overridden by the defaults from the field's declaration.

class DoublingVector2D(Struct):

x = Field
y = Field
x = Field(default=0)
y = Field(default=0)

def __new__(cls, x, y):
print('Vector2D.__new__() has been called')
Expand Down
24 changes: 17 additions & 7 deletions simplestruct/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ class MetaStruct(type):
Upon instantiation of a Struct subtype, set the instance's
_initialized attribute to True after __init__() returns.
Preprocess its __new__/__init__() arguments as well.
"""

# Use OrderedDict to preserve Field declaration order.
Expand Down Expand Up @@ -144,9 +145,23 @@ def __new__(mcls, clsname, bases, namespace, **kargs):

return cls

def get_boundargs(cls, *args, **kargs):
"""Return an inspect.BoundArguments object for the application
of this Struct's signature to its arguments. Add missing values
for default fields as keyword arguments.
"""
boundargs = cls._signature.bind(*args, **kargs)
# Include default arguments.
for param in cls._signature.parameters.values():
if (param.name not in boundargs.arguments and
param.default is not param.empty):
boundargs.arguments[param.name] = param.default
return boundargs

# Mark the class as _initialized after construction.
def __call__(cls, *args, **kargs):
inst = super().__call__(*args, **kargs)
boundargs = cls.get_boundargs(*args, **kargs)
inst = super().__call__(*boundargs.args, **boundargs.kwargs)
inst._initialized = True
return inst

Expand Down Expand Up @@ -193,12 +208,7 @@ def __new__(cls, *args, **kargs):

f = None
try:
boundargs = cls._signature.bind(*args, **kargs)
# Include default arguments.
for param in cls._signature.parameters.values():
if (param.name not in boundargs.arguments and
param.default is not param.empty):
boundargs.arguments[param.name] = param.default
boundargs = cls.get_boundargs(*args, **kargs)
for f in cls._struct:
setattr(inst, f.name, boundargs.arguments[f.name])
f = None
Expand Down
7 changes: 5 additions & 2 deletions tests/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,13 @@ class Foo(Struct):
class Foo(Struct):
a = Field()
b = Field(default='b')
# Make sure default field values are passed to __init__() too.
def __init__(self, a, b):
self.c = b
f = Foo(1, 2)
self.assertEqual((f.a, f.b), (1, 2))
self.assertEqual((f.a, f.b, f.c), (1, 2, 2))
f = Foo(1)
self.assertEqual((f.a, f.b), (1, 'b'))
self.assertEqual((f.a, f.b, f.c), (1, 'b', 'b'))

# Parentheses-less shorthand.
class Foo(Struct):
Expand Down

0 comments on commit 60a7f99

Please sign in to comment.