Skip to content

Commit

Permalink
feat: don't resample structure
Browse files Browse the repository at this point in the history
  • Loading branch information
younesStrittmatter committed Feb 12, 2024
1 parent af454db commit 83e8894
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 32 deletions.
56 changes: 32 additions & 24 deletions src/equation_tree/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,115 +98,121 @@ def __sample_tree_raw_fast(
prior,
tree_depth,
max_num_variables,
structure=None
):
equation_tree = EquationTree.from_prior_fast(prior, tree_depth, max_num_variables)
equation_tree = EquationTree.from_prior_fast(prior, tree_depth, max_num_variables, structure=structure)
_tmp = equation_tree.prefix.copy()
_tmp_structure = equation_tree.standard_structure.copy()

# Check if tree is valid
if not equation_tree.check_validity():
return None
return None, _tmp_structure

try:
equation_tree.simplify(
function_test=lambda x: x in get_defined_functions(prior),
operator_test=lambda x: x in get_defined_operators(prior),
)
if equation_tree.prefix != _tmp:
return None
return None, _tmp_structure
except ValueError:
return None
return None, _tmp_structure

# Check is nan
if equation_tree.is_nan:
return None
return None, _tmp_structure

# Check if duplicate constants
if (
equation_tree.n_non_numeric_constants
> equation_tree.n_non_numeric_constants_unique
):
return None
return None, _tmp_structure

# Check if more variables than max:
if equation_tree.n_variables_unique > max_num_variables:
return None

# Check if tree depth is exact
if len(equation_tree.structure) != tree_depth:
return None
return None, _tmp_structure

if not equation_tree.check_validity():
return None
return None, _tmp_structure

if not equation_tree.check_possible_from_prior(prior):
return None
return None, _tmp_structure

equation_tree.get_evaluation()
if not equation_tree.has_valid_value:
return None
return None, _tmp_structure

return equation_tree
return equation_tree, _tmp_structure


def __sample_tree_raw(
prior,
max_num_variables,
structure=None
):
equation_tree = EquationTree.from_prior(prior, max_num_variables)
equation_tree = EquationTree.from_prior(prior, max_num_variables, structure=structure)
_tmp = equation_tree.prefix.copy()
_tmp_structure = equation_tree.standard_structure.copy()

# Check if tree is valid
if not equation_tree.check_validity():
return None
return None, _tmp_structure

try:
equation_tree.simplify(
function_test=lambda x: x in get_defined_functions(prior),
operator_test=lambda x: x in get_defined_operators(prior),
)
if equation_tree.prefix != _tmp:
return None
return None, _tmp_structure
except ValueError:
return None
return None, _tmp_structure

# Check is nan
if equation_tree.is_nan:
return None
return None, _tmp_structure

# Check if duplicate constants
if (
equation_tree.n_non_numeric_constants
> equation_tree.n_non_numeric_constants_unique
):
return None
return None, _tmp_structure

# Check if more variables than max:
if equation_tree.n_variables_unique > max_num_variables:
return None
return None, _tmp_structure

if not equation_tree.check_validity():
return None
return None, _tmp_structure

if not equation_tree.check_possible_from_prior(prior):
return None
return None, _tmp_structure

equation_tree.get_evaluation()
if not equation_tree.has_valid_value:
return None
return None, _tmp_structure

return equation_tree
return equation_tree, _tmp_structure


def _sample_tree_iter_fast(
prior,
tree_depth,
max_num_variables,
):
_structure = None
for _ in range(MAX_ITER):
equation_tree = __sample_tree_raw_fast(
equation_tree, _structure = __sample_tree_raw_fast(
prior,
tree_depth,
max_num_variables,
structure=_structure
)
if equation_tree is not None:
return equation_tree
Expand All @@ -216,10 +222,12 @@ def _sample_tree_iter(
prior,
max_num_variables,
):
_structure = None
for _ in range(MAX_ITER):
equation_tree = __sample_tree_raw(
equation_tree, _structure = __sample_tree_raw(
prior,
max_num_variables,
structure=_structure
)

if equation_tree is not None:
Expand Down
12 changes: 8 additions & 4 deletions src/equation_tree/src/tree_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,7 @@ def sample_tree(
return tree


def sample_tree_full(prior, max_var_unique):
def sample_tree_full(prior, max_var_unique, structure=None):
"""
Examples:
>>> np.random.seed(42)
Expand Down Expand Up @@ -780,7 +780,9 @@ def sample_tree_full(prior, max_var_unique):
True
"""
tree_structure = sample_tree_structure(prior["structures"])
tree_structure = structure
if tree_structure is None:
tree_structure = sample_tree_structure(prior["structures"])
function_conditionals = None
operator_conditionals = None
if "function_conditionals" in prior.keys():
Expand All @@ -801,7 +803,7 @@ def sample_tree_full(prior, max_var_unique):
return tree


def sample_tree_full_fast(prior, tree_depth, max_var_unique):
def sample_tree_full_fast(prior, tree_depth, max_var_unique, structure=None):
"""
Examples:
>>> np.random.seed(42)
Expand Down Expand Up @@ -883,7 +885,9 @@ def sample_tree_full_fast(prior, tree_depth, max_var_unique):
True
"""
tree_structure = sample_tree_structure_fast(tree_depth)
tree_structure = structure
if tree_structure is None:
tree_structure = sample_tree_structure_fast(tree_depth)
function_conditionals = None
operator_conditionals = None
if "function_conditionals" in prior.keys():
Expand Down
8 changes: 4 additions & 4 deletions src/equation_tree/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def from_full_prior(cls, prior):
return cls(root)

@classmethod
def from_prior_fast(cls, prior: Dict, tree_depth, max_variables_unique: int):
def from_prior_fast(cls, prior: Dict, tree_depth, max_variables_unique: int, structure=None):
"""
Initiate a tree from a prior with fast sampling
Attention: structure prior is not supported
Expand All @@ -267,11 +267,11 @@ def from_prior_fast(cls, prior: Dict, tree_depth, max_variables_unique: int):
max_variables_unique: The maximum number of unique variables (a tree can have less then
this number)
"""
root = sample_tree_full_fast(prior, tree_depth, max_variables_unique)
root = sample_tree_full_fast(prior, tree_depth, max_variables_unique, structure=structure)
return cls(root)

@classmethod
def from_prior(cls, prior: Dict, max_variables_unique: int):
def from_prior(cls, prior: Dict, max_variables_unique: int, structure=None):
"""
Initiate a tree from a prior
Expand Down Expand Up @@ -353,7 +353,7 @@ def from_prior(cls, prior: Dict, max_variables_unique: int):
# Note: this would be discarded in a future step as unnecesarry complex
"""
root = sample_tree_full(prior, max_variables_unique)
root = sample_tree_full(prior, max_variables_unique, structure=structure)
return cls(root)

@classmethod
Expand Down

0 comments on commit 83e8894

Please sign in to comment.