Skip to content

Commit

Permalink
Merge pull request #1175 from SciML/change_equation_inference
Browse files Browse the repository at this point in the history
Update inference of variables and default differential from `@equations` macro
  • Loading branch information
TorkelE authored Jan 17, 2025
2 parents 41ce941 + 64f1e39 commit 1205ea1
Show file tree
Hide file tree
Showing 6 changed files with 278 additions and 97 deletions.
23 changes: 23 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,29 @@
(at the time the release is made). If you need a dependency version increased,
please open an issue and we can update it and make a new Catalyst release once
testing against the newer dependency version is complete.
- New formula for inferring variables from equations (declared using the `@equations` options) in the DSL. The order of inference of species/variables/parameters is now:
(1) Every symbol explicitly declared using `@species`, `@variables`, and `@parameters` are assigned to the correct category.
(2) Every symbol used as a reaction reactant is inferred as a species.
(3) Every symbol not declared in (1) or (2) that occurs in an expression provided after `@equations` is inferred as a variable.
(4) Every symbol not declared in (1), (2), or (3) that occurs either as a reaction rate or stoichiometric coefficient is inferred to be a parameter.
E.g. in
```julia
@reaction_network begin
@equations V1 + S ~ V2^2
(p + S + V1), S --> 0
end
```
`S` is inferred as a species, `V1` and `V2` as variables, and `p` as a parameter. The previous special cases for the `@observables`, `@compounds`, and `@differentials` options still hold. Finally, the `@require_declaration` options (described in more detail below) can now be used to require everything to be explicitly declared.
- New formula for determining whether the default differentials have been used within an `@equations` option. Now, if any expression `D(...)` is encountered (where `...` can be anything), this is inferred as usage of the default differential D. E.g. in the following equations `D` is inferred as a differential with respect to the default independent variable:
```julia
@reaction_network begin
@equations D(V) + V ~ 1
end
@reaction_network begin
@equations D(D(V)) ~ 1
end
```
Please note that this cannot be used at the same time as `D` is used to represent a species, variable, or parameter (including is these are implicitly designated as such by e.g. appearing as a reaction reactant).
- Array symbolics support is more consistent with ModelingToolkit v9. Parameter
arrays are no longer scalarized by Catalyst, while species and variables
arrays still are (as in ModelingToolkit). As such, parameter arrays should now
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ function optimise_p(pinit, tend)
newprob = remake(prob; tspan = (0.0, tend), p = p)
sol = Array(solve(newprob, Rosenbrock23(); saveat = newtimes))
loss = sum(abs2, sol .- sample_vals[:, 1:size(sol,2)])
return loss, sol
return loss
end
# optimize for the parameters that minimize the loss
Expand Down
72 changes: 41 additions & 31 deletions src/dsl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ struct UndeclaredSymbolicError <: Exception
msg::String
end

function Base.showerror(io::IO, err::UndeclaredSymbolicError)
function Base.showerror(io::IO, err::UndeclaredSymbolicError)
print(io, "UndeclaredSymbolicError: ")
print(io, err.msg)
end
Expand Down Expand Up @@ -328,11 +328,6 @@ function make_reaction_system(ex::Expr; name = :(gensym(:ReactionSystem)))
parameters_declared = extract_syms(options, :parameters)
variables_declared = extract_syms(options, :variables)

# Reads equations.
vars_extracted, add_default_diff, equations = read_equations_options(
options, variables_declared; requiredec)
variables = vcat(variables_declared, vars_extracted)

# Handle independent variables
if haskey(options, :ivs)
ivs = Tuple(extract_syms(options, :ivs))
Expand All @@ -352,23 +347,32 @@ function make_reaction_system(ex::Expr; name = :(gensym(:ReactionSystem)))
combinatoric_ratelaws = true
end

# Reads observables.
observed_vars, observed_eqs, obs_syms = read_observed_options(
options, [species_declared; variables], all_ivs; requiredec)

# Collect species and parameters, including ones inferred from the reactions.
declared_syms = Set(Iterators.flatten((parameters_declared, species_declared,
variables)))
variables_declared)))
species_extracted, parameters_extracted = extract_species_and_parameters!(
reactions, declared_syms; requiredec)

# Reads equations (and infers potential variables).
# Excludes any parameters already extracted (if they also was a variable).
declared_syms = union(declared_syms, species_extracted)
vars_extracted, add_default_diff, equations = read_equations_options(
options, declared_syms, parameters_extracted; requiredec)
variables = vcat(variables_declared, vars_extracted)
parameters_extracted = setdiff(parameters_extracted, vars_extracted)

# Creates the finalised parameter and species lists.
species = vcat(species_declared, species_extracted)
parameters = vcat(parameters_declared, parameters_extracted)

# Create differential expression.
diffexpr = create_differential_expr(
options, add_default_diff, [species; parameters; variables], tiv)

# Reads observables.
observed_vars, observed_eqs, obs_syms = read_observed_options(
options, [species_declared; variables], all_ivs; requiredec)

# Checks for input errors.
(sum(length.([reaction_lines, option_lines])) != length(ex.args)) &&
error("@reaction_network input contain $(length(ex.args) - sum(length.([reaction_lines,option_lines]))) malformed lines.")
Expand Down Expand Up @@ -701,7 +705,7 @@ end
# `vars_extracted`: A vector with extracted variables (lhs in pure differential equations only).
# `dtexpr`: If a differential equation is defined, the default derivative (D ~ Differential(t)) must be defined.
# `equations`: a vector with the equations provided.
function read_equations_options(options, variables_declared; requiredec = false)
function read_equations_options(options, syms_declared, parameters_extracted; requiredec = false)
# Prepares the equations. First, extracts equations from provided option (converting to block form if required).
# Next, uses MTK's `parse_equations!` function to split input into a vector with the equations.
eqs_input = haskey(options, :equations) ? options[:equations].args[3] : :(begin end)
Expand All @@ -713,34 +717,40 @@ function read_equations_options(options, variables_declared; requiredec = false)
# Loops through all equations, checks for lhs of the form `D(X) ~ ...`.
# When this is the case, the variable X and differential D are extracted (for automatic declaration).
# Also performs simple error checks.
vars_extracted = Vector{Symbol}()
vars_extracted = OrderedSet{Union{Symbol, Expr}}()
add_default_diff = false
for eq in equations
if (eq.head != :call) || (eq.args[1] != :~)
error("Malformed equation: \"$eq\". Equation's left hand and right hand sides should be separated by a \"~\".")
end

# Checks if the equation have the format D(X) ~ ... (where X is a symbol). This means that the
# default differential has been used. X is added as a declared variable to the system, and
# we make a note that a differential D = Differential(iv) should be made as well.
lhs = eq.args[2]
# if lhs: is an expression. Is a function call. The function's name is D. Calls a single symbol.
if (lhs isa Expr) && (lhs.head == :call) && (lhs.args[1] == :D) &&
(lhs.args[2] isa Symbol)
diff_var = lhs.args[2]
if in(diff_var, forbidden_symbols_error)
error("A forbidden symbol ($(diff_var)) was used as an variable in this differential equation: $eq")
elseif (!in(diff_var, variables_declared)) && requiredec
throw(UndeclaredSymbolicError(
"Unrecognized symbol $(diff_var) was used as a variable in an equation: \"$eq\". Since the @require_declaration flag is set, all variables in equations must be explicitly declared via @variables, @species, or @parameters."))
else
add_default_diff = true
in(diff_var, variables_declared) || push!(vars_extracted, diff_var)
end
# If the default differential (`D`) is used, record that it should be decalred later on.
if (:D union(syms_declared, parameters_extracted)) && find_D_call(eq)
requiredec && throw(UndeclaredSymbolicError(
"Unrecognized symbol D was used as a differential in an equation: \"$eq\". Since the @require_declaration flag is set, all differentials in equations must be explicitly declared using the @differentials option."))
add_default_diff = true
push!(syms_declared, :D)
end

# Any undecalred symbolic variables encountered should be extracted as variables.
add_syms_from_expr!(vars_extracted, eq, syms_declared)
(!isempty(vars_extracted) && requiredec) && throw(UndeclaredSymbolicError(
"Unrecognized symbolic variables $(join(vars_extracted, ", ")) detected in equation expression: \"$(string(eq))\". Since the flag @require_declaration is declared, all symbolic variables must be explicitly declared with the @species, @variables, and @parameters options."))
end

return vars_extracted, add_default_diff, equations
return collect(vars_extracted), add_default_diff, equations
end

# Searches an expresion `expr` and returns true if it have any subexpression `D(...)` (where `...` can be anything).
# Used to determine whether the default differential D has been used in any equation provided to `@equations`.
function find_D_call(expr)
return if Base.isexpr(expr, :call) && expr.args[1] == :D
true
elseif expr isa Expr
any(find_D_call, expr.args)
else
false
end
end

# Creates an expression declaring differentials. Here, `tiv` is the time independent variables,
Expand Down
32 changes: 16 additions & 16 deletions src/reactionsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@ Base.@kwdef mutable struct NetworkProperties{I <: Integer, V <: BasicSymbolic{Re
stronglinkageclasses::Vector{Vector{Int}} = Vector{Vector{Int}}(undef, 0)
terminallinkageclasses::Vector{Vector{Int}} = Vector{Vector{Int}}(undef, 0)

checkedrobust::Bool = false
checkedrobust::Bool = false
robustspecies::Vector{Int} = Vector{Int}(undef, 0)
deficiency::Int = -1
deficiency::Int = -1
end
#! format: on

Expand Down Expand Up @@ -215,11 +215,11 @@ end

### ReactionSystem Structure ###

"""
"""
WARNING!!!
The following variable is used to check that code that should be updated when the `ReactionSystem`
fields are updated has in fact been updated. Do not just blindly update this without first checking
The following variable is used to check that code that should be updated when the `ReactionSystem`
fields are updated has in fact been updated. Do not just blindly update this without first checking
all such code and updating it appropriately (e.g. serialization). Please use a search for
`reactionsystem_fields` throughout the package to ensure all places which should be updated, are updated.
"""
Expand Down Expand Up @@ -318,7 +318,7 @@ struct ReactionSystem{V <: NetworkProperties} <:
"""
discrete_events::Vector{MT.SymbolicDiscreteCallback}
"""
Metadata for the system, to be used by downstream packages.
Metadata for the system, to be used by downstream packages.
"""
metadata::Any
"""
Expand Down Expand Up @@ -480,10 +480,10 @@ function ReactionSystem(iv; kwargs...)
ReactionSystem(Reaction[], iv, [], []; kwargs...)
end

# Called internally (whether DSL-based or programmatic model creation is used).
# Called internally (whether DSL-based or programmatic model creation is used).
# Creates a sorted reactions + equations vector, also ensuring reaction is first in this vector.
# Extracts potential species, variables, and parameters from the input (if not provided as part of
# the model creation) and creates the corresponding vectors.
# Extracts potential species, variables, and parameters from the input (if not provided as part of
# the model creation) and creates the corresponding vectors.
# While species are ordered before variables in the unknowns vector, this ordering is not imposed here,
# but carried out at a later stage.
function make_ReactionSystem_internal(rxs_and_eqs::Vector, iv, us_in, ps_in;
Expand All @@ -495,7 +495,7 @@ function make_ReactionSystem_internal(rxs_and_eqs::Vector, iv, us_in, ps_in;
any(in(obs_vars), us_in) &&
error("Found an observable in the list of unknowns. This is not allowed.")

# Creates a combined iv vector (iv and sivs). This is used later in the function (so that
# Creates a combined iv vector (iv and sivs). This is used later in the function (so that
# independent variables can be excluded when encountered quantities are added to `us` and `ps`).
t = value(iv)
ivs = Set([t])
Expand Down Expand Up @@ -560,7 +560,7 @@ function make_ReactionSystem_internal(rxs_and_eqs::Vector, iv, us_in, ps_in;
end
psv = collect(new_ps)

# Passes the processed input into the next `ReactionSystem` call.
# Passes the processed input into the next `ReactionSystem` call.
ReactionSystem(fulleqs, t, usv, psv; spatial_ivs, continuous_events,
discrete_events, observed, kwargs...)
end
Expand Down Expand Up @@ -1062,8 +1062,8 @@ end

### General `ReactionSystem`-specific Functions ###

# Checks if the `ReactionSystem` structure have been updated without also updating the
# `reactionsystem_fields` constant. If this is the case, returns `false`. This is used in
# Checks if the `ReactionSystem` structure have been updated without also updating the
# `reactionsystem_fields` constant. If this is the case, returns `false`. This is used in
# certain functionalities which would break if the `ReactionSystem` structure is updated without
# also updating these functionalities.
function reactionsystem_uptodate_check()
Expand Down Expand Up @@ -1241,7 +1241,7 @@ end
### `ReactionSystem` Remaking ###

"""
remake_ReactionSystem_internal(rs::ReactionSystem;
remake_ReactionSystem_internal(rs::ReactionSystem;
default_reaction_metadata::Vector{Pair{Symbol, T}} = Vector{Pair{Symbol, Any}}()) where {T}
Takes a `ReactionSystem` and remakes it, returning a modified `ReactionSystem`. Modifications depend
Expand Down Expand Up @@ -1274,7 +1274,7 @@ function set_default_metadata(rs::ReactionSystem; default_reaction_metadata = []
# Currently, `noise_scaling` is the only relevant metadata supported this way.
drm_dict = Dict(default_reaction_metadata)
if haskey(drm_dict, :noise_scaling)
# Finds parameters, species, and variables in the noise scaling term.
# Finds parameters, species, and variables in the noise scaling term.
ns_expr = drm_dict[:noise_scaling]
ns_syms = [Symbolics.unwrap(sym) for sym in get_variables(ns_expr)]
ns_ps = Iterators.filter(ModelingToolkit.isparameter, ns_syms)
Expand Down Expand Up @@ -1414,7 +1414,7 @@ function ModelingToolkit.compose(sys::ReactionSystem, systems::AbstractArray; na
MT.collect_scoped_vars!(newunknowns, newparams, ssys, iv)
end

if !isempty(newunknowns)
if !isempty(newunknowns)
@set! sys.unknowns = union(get_unknowns(sys), newunknowns)
sort!(get_unknowns(sys), by = !isspecies)
@set! sys.species = filter(isspecies, get_unknowns(sys))
Expand Down
Loading

0 comments on commit 1205ea1

Please sign in to comment.