Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type-stable access to fx #114

Merged
merged 3 commits into from
Sep 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/State/GenericStatemod.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,18 @@ end
scoretype(typestate::AbstractState{S, T}) where {S, T} = S
xtype(typestate::AbstractState{S, T}) where {S, T} = T

for field in fieldnames(GenericState)
meth = Symbol("get_", field)
@eval begin
@doc """
$($meth)(state)
Return the value $($(QuoteNode(field))) from the state.
"""
$meth(state::GenericState) = getproperty(state, $(QuoteNode(field)))
end
@eval export $meth
end

"""
`update!(:: AbstractState; convert = false, kwargs...)`

Expand Down
28 changes: 20 additions & 8 deletions src/State/NLPAtXmod.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ Note:

See also: `GenericState`, `update!`, `update_and_start!`, `update_and_stop!`, `reinit!`
"""
mutable struct NLPAtX{S, T <: AbstractVector} <: AbstractState{S, T}
mutable struct NLPAtX{Score, S, T <: AbstractVector} <: AbstractState{S, T}

#Unconstrained State
x::T # current point
fx::eltype(T) # objective function
fx::S # objective function
gx::T # gradient size: x
Hx # hessian size: |x| x |x|

Expand All @@ -66,12 +66,12 @@ mutable struct NLPAtX{S, T <: AbstractVector} <: AbstractState{S, T}

#Resources State
current_time::Float64
current_score::S
current_score::Score

function NLPAtX(
x::T,
lambda::T,
current_score::S;
current_score::Score;
fx::eltype(T) = _init_field(eltype(T)),
gx::T = _init_field(T),
Hx = _init_field(Matrix{eltype(T)}),
Expand All @@ -81,10 +81,10 @@ mutable struct NLPAtX{S, T <: AbstractVector} <: AbstractState{S, T}
d::T = _init_field(T),
res::T = _init_field(T),
current_time::Float64 = NaN,
) where {S, T <: AbstractVector}
) where {Score, S, T <: AbstractVector}
_size_check(x, lambda, fx, gx, Hx, mu, cx, Jx)

return new{S, T}(x, fx, gx, Hx, mu, cx, Jx, lambda, d, res, current_time, current_score)
return new{Score, eltype(T), T}(x, fx, gx, Hx, mu, cx, Jx, lambda, d, res, current_time, current_score)
end
end

Expand Down Expand Up @@ -156,6 +156,18 @@ function NLPAtX(
)
end

for field in fieldnames(NLPAtX)
meth = Symbol("get_", field)
@eval begin
@doc """
$($meth)(state)
Return the value $($(QuoteNode(field))) from the state.
"""
$meth(state::NLPAtX) = getproperty(state, $(QuoteNode(field)))
end
@eval export $meth
end

"""
reinit!: function that set all the entries at void except the mandatory x

Expand All @@ -166,7 +178,7 @@ reinit!: function that set all the entries at void except the mandatory x
Note: if `x` or `lambda` are given as keyword arguments they will be
prioritized over the existing `x`, `lambda` and the default `Counters`.
"""
function reinit!(stateatx::NLPAtX{S, T}, x::T, l::T; kwargs...) where {S, T}
function reinit!(stateatx::NLPAtX{Score, S, T}, x::T, l::T; kwargs...) where {Score, S, T}
for k ∈ fieldnames(NLPAtX)
if k ∉ [:x, :lambda]
setfield!(stateatx, k, _init_field(typeof(getfield(stateatx, k))))
Expand All @@ -183,7 +195,7 @@ function reinit!(stateatx::NLPAtX{S, T}, x::T, l::T; kwargs...) where {S, T}
return update!(stateatx; kwargs...)
end

function reinit!(stateatx::NLPAtX{S, T}, x::T; kwargs...) where {S, T}
function reinit!(stateatx::NLPAtX{Score, S, T}, x::T; kwargs...) where {Score, S, T}
for k ∈ fieldnames(NLPAtX)
if k ∉ [:x, :lambda]
setfield!(stateatx, k, _init_field(typeof(getfield(stateatx, k))))
Expand Down
12 changes: 12 additions & 0 deletions src/State/OneDAtXmod.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,15 @@ function OneDAtX(
current_time = current_time,
)
end

for field in fieldnames(OneDAtX)
meth = Symbol("get_", field)
@eval begin
@doc """
$($meth)(state)
Return the value $($(QuoteNode(field))) from the state.
"""
$meth(state::OneDAtX) = getproperty(state, $(QuoteNode(field)))
end
@eval export $meth
end
20 changes: 10 additions & 10 deletions src/Stopping/NLPStoppingmod.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ fill_in!: (NLPStopping version) a function that fill in the required values in t
`fill_in!( :: NLPStopping, :: Union{AbstractVector, Nothing}; fx :: Union{AbstractVector, Nothing} = nothing, gx :: Union{AbstractVector, Nothing} = nothing, Hx :: Union{MatrixType, Nothing} = nothing, cx :: Union{AbstractVector, Nothing} = nothing, Jx :: Union{MatrixType, Nothing} = nothing, lambda :: Union{AbstractVector, Nothing} = nothing, mu :: Union{AbstractVector, Nothing} = nothing, matrix_info :: Bool = true, kwargs...)`
"""
function fill_in!(
stp::NLPStopping{Pb, M, SRC, NLPAtX{S, T}, MStp, LoS},
stp::NLPStopping{Pb, M, SRC, NLPAtX{Score, S, T}, MStp, LoS},
x::T;
fx::Union{eltype(T), Nothing} = nothing,
gx::Union{T, Nothing} = nothing,
Expand All @@ -219,7 +219,7 @@ function fill_in!(
matrix_info::Bool = true,
convert::Bool = true,
kwargs...,
) where {Pb, M, SRC, MStp, LoS, S, T}
) where {Pb, M, SRC, MStp, LoS, Score, S, T}
gfx = isnothing(fx) ? obj(stp.pb, x) : fx
ggx = isnothing(gx) ? grad(stp.pb, x) : gx

Expand Down Expand Up @@ -411,17 +411,17 @@ Note:
- if minimize problem (i.e. nlp.meta.minimize is true) check if `state.fx <= - meta.unbounded_threshold`, otherwise check `state.fx ≥ meta.unbounded_threshold`.
"""
function _unbounded_problem_check!(
stp::NLPStopping{Pb, M, SRC, NLPAtX{S, T}, MStp, LoS},
stp::NLPStopping{Pb, M, SRC, NLPAtX{Score, S, T}, MStp, LoS},
x::AbstractVector,
) where {Pb, M, SRC, MStp, LoS, S, T}
if isnan(stp.current_state.fx)
) where {Pb, M, SRC, MStp, LoS, Score, S, T}
if isnan(get_fx(stp.current_state))
stp.current_state.fx = obj(stp.pb, x)
end

if stp.pb.meta.minimize
f_too_large = stp.current_state.fx <= -stp.meta.unbounded_threshold
f_too_large = get_fx(stp.current_state) <= -stp.meta.unbounded_threshold
else
f_too_large = stp.current_state.fx >= stp.meta.unbounded_threshold
f_too_large = get_fx(stp.current_state) >= stp.meta.unbounded_threshold
end

if f_too_large
Expand All @@ -435,14 +435,14 @@ function _unbounded_problem_check!(
stp::NLPStopping{Pb, M, SRC, OneDAtX{S, T}, MStp, LoS},
x::Union{AbstractVector, Number},
) where {Pb, M, SRC, MStp, LoS, S, T}
if isnan(stp.current_state.fx)
if isnan(get_fx(stp.current_state))
stp.current_state.fx = obj(stp.pb, x)
end

if stp.pb.meta.minimize
f_too_large = stp.current_state.fx <= -stp.meta.unbounded_threshold
f_too_large = get_fx(stp.current_state) <= -stp.meta.unbounded_threshold
else
f_too_large = stp.current_state.fx >= stp.meta.unbounded_threshold
f_too_large = get_fx(stp.current_state) >= stp.meta.unbounded_threshold
end

return stp.meta.unbounded_pb
Expand Down