Skip to content

Commit

Permalink
Merge pull request #1116 from sphuber/fix_1063_process_builder
Browse files Browse the repository at this point in the history
Implement a ProcessBuilder to normalize the launching of processes
  • Loading branch information
muhrin authored Feb 14, 2018
2 parents 317bcfe + 5f3a9ba commit 03a7bca
Show file tree
Hide file tree
Showing 16 changed files with 433 additions and 373 deletions.
1 change: 1 addition & 0 deletions aiida/backends/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
'work.futures': ['aiida.backends.tests.work.test_futures'],
'work.persistence': ['aiida.backends.tests.work.persistence'],
'work.process': ['aiida.backends.tests.work.process'],
'work.process_builder': ['aiida.backends.tests.work.test_process_builder'],
'work.process_spec': ['aiida.backends.tests.work.test_process_spec'],
'work.rmq': ['aiida.backends.tests.work.test_rmq'],
'work.run': ['aiida.backends.tests.work.run'],
Expand Down
49 changes: 49 additions & 0 deletions aiida/backends/tests/work/test_process_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at /~https://github.com/aiidateam/aiida_core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################

from aiida.backends.testbase import AiidaTestCase
from aiida.orm import CalculationFactory
from aiida.orm.data.parameter import ParameterData
from aiida.work.process_builder import ProcessBuilder
from aiida.work import utils


class TestProcessBuilder(AiidaTestCase):

def setUp(self):
super(TestProcessBuilder, self).setUp()
self.assertEquals(len(utils.ProcessStack.stack()), 0)
self.calculation_class = CalculationFactory('simpleplugins.templatereplacer')
self.process_class = self.calculation_class.process()
self.builder = self.process_class.get_builder()

def tearDown(self):
super(TestProcessBuilder, self).tearDown()
self.assertEquals(len(utils.ProcessStack.stack()), 0)

def test_process_builder_attributes(self):
"""
Check that the builder has all the input ports of the process class as attributes
"""
for name, port in self.process_class.spec().inputs.iteritems():
self.assertTrue(hasattr(self.builder, name))

def test_process_builder_set_attributes(self):
"""
Verify that setting attributes in builder works
"""
label = 'Test label'
description = 'Test description'

self.builder.label = label
self.builder.description = description

self.assertEquals(self.builder.label, label)
self.assertEquals(self.builder.description, description)
36 changes: 5 additions & 31 deletions aiida/backends/tests/work/test_process_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,35 +8,21 @@
# For further information please visit http://www.aiida.net #
###########################################################################


from collections import Mapping
from aiida.backends.testbase import AiidaTestCase
from aiida.work.processes import Process, ProcessSpec
import aiida.work.utils as util
from aiida.work.processes import Process
from aiida.work import utils


class TestProcessSpec(AiidaTestCase):

def setUp(self):
super(TestProcessSpec, self).setUp()
self.assertEquals(len(util.ProcessStack.stack()), 0)
self.assertEquals(len(utils.ProcessStack.stack()), 0)
self.spec = Process.spec()

def tearDown(self):
super(TestProcessSpec, self).tearDown()
self.assertEquals(len(util.ProcessStack.stack()), 0)

def test_get_inputs_template(self):
s = ProcessSpec()
s.input('a')
s.input('b', default=5)

template = s.get_inputs_template()
self.assertIsInstance(template, Mapping)
self._test_template(template)
for attr in ['b']:
self.assertTrue(
attr in template,
"Attribute '{}' not found in template".format(attr))
self.assertEquals(len(utils.ProcessStack.stack()), 0)

def test_dynamic_input(self):
from aiida.orm import Node
Expand All @@ -59,15 +45,3 @@ def test_dynamic_output(self):
self.assertFalse(self.spec.validate_inputs({'key': 5})[0])
self.assertFalse(self.spec.validate_inputs({'key': n})[0])
self.assertTrue(self.spec.validate_inputs({'key': d})[0])

def _test_template(self, template):
template.a = 2
self.assertEqual(template.a, 2)
# Check the default is what we expect
self.assertEqual(template.b, 5)
with self.assertRaises(AttributeError):
template.c = 6

# Check that we can unpack
self.assertEqual(dict(**template)['a'], 2)

150 changes: 1 addition & 149 deletions aiida/common/extendeddicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,11 @@
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
from aiida.common.exceptions import ValidationError
import collections
from aiida.common.exceptions import ValidationError
from aiida.common.lang import override


## TODO: see if we want to have a function to rebuild a nested dictionary as
## a nested AttributeDict object when deserializing with json.
## (now it deserialized to a standard dictionary; comparison of
## AttributeDict == dict works fine, though.
## Note also that when pickling, instead, the structure is well preserved)

## Note that for instance putting this code in __getattr__ doesn't work:
## everytime I try to write on a.b.c I am actually writing on a copy
## return AttributeDict(item) if type(item) == dict else item


class Enumerate(frozenset):
def __getattr__(self, name):
if name in self:
Expand Down Expand Up @@ -287,140 +276,3 @@ def extrakeys(self):
Return the extra keys defined in the instance.
"""
return [_ for _ in self.keys() if _ not in self._default_fields]


class FixedDict(collections.MutableMapping, object):
def __init__(self, valid_keys):
class M(object):
pass

self._m = M()
self._m.values = {}
self._m.valid_keys = valid_keys

# Methods from MutableMapping ##########################
@override
def __dir__(self):
return self._m.valid_keys

@override
def __getitem__(self, key):
return self._m.values.__getitem__(key)

@override
def __setitem__(self, key, value):
if key not in self._m.valid_keys:
raise AttributeError("Invalid attribute: {}".format(key))
return self._m.values.__setitem__(key, value)

@override
def __delitem__(self, key):
assert key in self._m.values, \
"Cannot delete an item that has not been set."
return self._m.values.__delitem__(key)

@override
def __iter__(self):
return self._m.values.__iter__()

@override
def __len__(self):
return self._m.values.__len__()

########################################################

def __getattr__(self, item):
if item == '_m':
return super(FixedDict, self).__getattr__(item)
try:
return self.__getitem__(item)
except KeyError:
raise AttributeError("AttributeError: '{}'".format(item))

def __setattr__(self, key, value):
if key == '_m':
return super(FixedDict, self).__setattr__(key, value)
return self.__setitem__(key, value)

def __delattr__(self, item):
return self.__delitem__(item)


class _WithDefaults(object):
def __init__(self, defaults):
self._m._defaults = {}
if defaults:
self._m._defaults.update(defaults)

def get_default(self, key):
return self._m._defaults[key]

@property
def defaults(self):
return self._m._defaults


class DefaultsDict(collections.MutableMapping):
def __init__(self, valid_keys, defaults=None):
self._set_internal('_valid_keys', valid_keys)
self._set_internal('_user_supplied', {})

if defaults is None:
defaults = {}
for key in defaults:
assert key in valid_keys
self._set_internal('_defaults', defaults)

def __dir__(self):
return self._get_internal('_valid_keys')

def __getitem__(self, item):
return self._get_internal('_user_supplied')[item]

def __setitem__(self, key, value):
if key not in self._get_internal('_valid_keys'):
raise KeyError("KeyError: '{}'".format(key))
self._get_internal('_user_supplied')[key] = value

def __iter__(self):
self._get_internal('_user_supplied').__iter__()

def __len__(self):
return len(self._get_internal('_user_supplied'))

def __delitem__(self, key):
del self._get_internal('_user_supplied')[key]

def __getattr__(self, item):
try:
return self._user_supplied[item]
except KeyError:
try:
self._defaults[item]
except KeyError:
raise AttributeError("AttributeError: '{}'".format(item))

def __setattr__(self, key, value):
try:
self.__setitem__(key, value)
except KeyError:
raise AttributeError("AttributeError: '{}'".format(key))

def __delattr__(self, key):
try:
self.__delitem__(key)
except KeyError:
raise AttributeError("AttributeError: '{}'".format(key))

def _get_internal(self, item):
return super(DefaultsDict, self).__getattribute__(item)

def _set_internal(self, key, value):
return super(DefaultsDict, self).__setattr__(key, value)

def _set_value(self, key, value):
self._get_internal('_user_supplied')[key] = value

@property
def defaults(self):
return self._get_internal('_defaults')
56 changes: 2 additions & 54 deletions aiida/common/test_extendeddicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,60 +7,9 @@
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
from aiida.common.extendeddicts import *
import unittest
import copy



class TestDefaultsDict(unittest.TestCase):
def setUp(self):
self.defaults_dict = DefaultsDict(
valid_keys=['foo', 'bar'],
defaults={'bar': 'bar_default'})

def test_setattr(self):
# Test setting and getting a value
self.defaults_dict.foo = 'hello'
self.assertEqual(self.defaults_dict.foo, 'hello')

# Test setting an invalid attribute
with self.assertRaises(AttributeError):
self.defaults_dict.non_existent = None

def test_defaults(self):
self.assertEquals(self.defaults_dict.defaults, {'bar': 'bar_default'})

def test_getattr(self):
# Test getting an unset value
with self.assertRaises(AttributeError):
self.defaults_dict.unset

def test_delattr(self):
self.defaults_dict.foo = 'hello'
del self.defaults_dict.foo
# Now try deleting it again
with self.assertRaises(AttributeError):
del self.defaults_dict.foo

# Try deleting one that never existed
with self.assertRaises(AttributeError):
del self.defaults_dict.foo

def test_delitem(self):
self.defaults_dict['foo'] = 'test'
del self.defaults_dict['foo']
# Try deleting again
with self.assertRaises(KeyError):
del self.defaults_dict['foo']

# Try deleting on that never existed
with self.assertRaises(KeyError):
del self.defaults_dict['non_existent']

def test_invalid_default(self):
with self.assertRaises(AssertionError):
DefaultsDict([], defaults={'foo': 'bar'})
import unittest
from aiida.common.extendeddicts import *


class TestFFADExample(FixedFieldsAttributeDict):
Expand Down Expand Up @@ -378,4 +327,3 @@ def test_validation(self):
# a.a must be a positive integer
with self.assertRaises(ValidationError):
o.validate()

Loading

0 comments on commit 03a7bca

Please sign in to comment.