From f1e421319658e94de665f9374dffd9d12c76f267 Mon Sep 17 00:00:00 2001 From: tmigot Date: Fri, 9 Sep 2022 14:35:49 -0400 Subject: [PATCH 1/3] add parametric type for `fx` --- src/State/NLPAtXmod.jl | 16 ++++++++-------- src/Stopping/NLPStoppingmod.jl | 8 ++++---- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/State/NLPAtXmod.jl b/src/State/NLPAtXmod.jl index 4415a01d..ed85edd3 100644 --- a/src/State/NLPAtXmod.jl +++ b/src/State/NLPAtXmod.jl @@ -45,11 +45,11 @@ Note: See also: `GenericState`, `update!`, `update_and_start!`, `update_and_stop!`, `reinit!` """ -mutable struct NLPAtX{S, T <: AbstractVector} <: AbstractState{S, T} +mutable struct NLPAtX{Score, S, T <: AbstractVector} <: AbstractState{S, T} #Unconstrained State x::T # current point - fx::eltype(T) # objective function + fx::S # objective function gx::T # gradient size: x Hx # hessian size: |x| x |x| @@ -66,12 +66,12 @@ mutable struct NLPAtX{S, T <: AbstractVector} <: AbstractState{S, T} #Resources State current_time::Float64 - current_score::S + current_score::Score function NLPAtX( x::T, lambda::T, - current_score::S; + current_score::Score; fx::eltype(T) = _init_field(eltype(T)), gx::T = _init_field(T), Hx = _init_field(Matrix{eltype(T)}), @@ -81,10 +81,10 @@ mutable struct NLPAtX{S, T <: AbstractVector} <: AbstractState{S, T} d::T = _init_field(T), res::T = _init_field(T), current_time::Float64 = NaN, - ) where {S, T <: AbstractVector} + ) where {Score, S, T <: AbstractVector} _size_check(x, lambda, fx, gx, Hx, mu, cx, Jx) - return new{S, T}(x, fx, gx, Hx, mu, cx, Jx, lambda, d, res, current_time, current_score) + return new{Score, eltype(T), T}(x, fx, gx, Hx, mu, cx, Jx, lambda, d, res, current_time, current_score) end end @@ -166,7 +166,7 @@ reinit!: function that set all the entries at void except the mandatory x Note: if `x` or `lambda` are given as keyword arguments they will be prioritized over the existing `x`, `lambda` and the default `Counters`. """ -function reinit!(stateatx::NLPAtX{S, T}, x::T, l::T; kwargs...) where {S, T} +function reinit!(stateatx::NLPAtX{Score, S, T}, x::T, l::T; kwargs...) where {Score, S, T} for k ∈ fieldnames(NLPAtX) if k ∉ [:x, :lambda] setfield!(stateatx, k, _init_field(typeof(getfield(stateatx, k)))) @@ -183,7 +183,7 @@ function reinit!(stateatx::NLPAtX{S, T}, x::T, l::T; kwargs...) where {S, T} return update!(stateatx; kwargs...) end -function reinit!(stateatx::NLPAtX{S, T}, x::T; kwargs...) where {S, T} +function reinit!(stateatx::NLPAtX{Score, S, T}, x::T; kwargs...) where {Score, S, T} for k ∈ fieldnames(NLPAtX) if k ∉ [:x, :lambda] setfield!(stateatx, k, _init_field(typeof(getfield(stateatx, k)))) diff --git a/src/Stopping/NLPStoppingmod.jl b/src/Stopping/NLPStoppingmod.jl index 5869da7a..a76763ce 100644 --- a/src/Stopping/NLPStoppingmod.jl +++ b/src/Stopping/NLPStoppingmod.jl @@ -207,7 +207,7 @@ fill_in!: (NLPStopping version) a function that fill in the required values in t `fill_in!( :: NLPStopping, :: Union{AbstractVector, Nothing}; fx :: Union{AbstractVector, Nothing} = nothing, gx :: Union{AbstractVector, Nothing} = nothing, Hx :: Union{MatrixType, Nothing} = nothing, cx :: Union{AbstractVector, Nothing} = nothing, Jx :: Union{MatrixType, Nothing} = nothing, lambda :: Union{AbstractVector, Nothing} = nothing, mu :: Union{AbstractVector, Nothing} = nothing, matrix_info :: Bool = true, kwargs...)` """ function fill_in!( - stp::NLPStopping{Pb, M, SRC, NLPAtX{S, T}, MStp, LoS}, + stp::NLPStopping{Pb, M, SRC, NLPAtX{Score, S, T}, MStp, LoS}, x::T; fx::Union{eltype(T), Nothing} = nothing, gx::Union{T, Nothing} = nothing, @@ -219,7 +219,7 @@ function fill_in!( matrix_info::Bool = true, convert::Bool = true, kwargs..., -) where {Pb, M, SRC, MStp, LoS, S, T} +) where {Pb, M, SRC, MStp, LoS, Score, S, T} gfx = isnothing(fx) ? obj(stp.pb, x) : fx ggx = isnothing(gx) ? grad(stp.pb, x) : gx @@ -411,9 +411,9 @@ Note: - if minimize problem (i.e. nlp.meta.minimize is true) check if `state.fx <= - meta.unbounded_threshold`, otherwise check `state.fx ≥ meta.unbounded_threshold`. """ function _unbounded_problem_check!( - stp::NLPStopping{Pb, M, SRC, NLPAtX{S, T}, MStp, LoS}, + stp::NLPStopping{Pb, M, SRC, NLPAtX{Score, S, T}, MStp, LoS}, x::AbstractVector, -) where {Pb, M, SRC, MStp, LoS, S, T} +) where {Pb, M, SRC, MStp, LoS, Score, S, T} if isnan(stp.current_state.fx) stp.current_state.fx = obj(stp.pb, x) end From f89f686ff53a326f0aa0280204f93f12ecb0fd1f Mon Sep 17 00:00:00 2001 From: tmigot Date: Fri, 9 Sep 2022 14:46:51 -0400 Subject: [PATCH 2/3] add getter for the state --- src/State/GenericStatemod.jl | 12 ++++++++++++ src/State/NLPAtXmod.jl | 12 ++++++++++++ src/State/OneDAtXmod.jl | 12 ++++++++++++ 3 files changed, 36 insertions(+) diff --git a/src/State/GenericStatemod.jl b/src/State/GenericStatemod.jl index 7f1263fa..924b2354 100644 --- a/src/State/GenericStatemod.jl +++ b/src/State/GenericStatemod.jl @@ -82,6 +82,18 @@ end scoretype(typestate::AbstractState{S, T}) where {S, T} = S xtype(typestate::AbstractState{S, T}) where {S, T} = T +for field in fieldnames(GenericState) + meth = Symbol("get_", field) + @eval begin + @doc """ + $($meth)(state) + Return the value $($(QuoteNode(field))) from the state. + """ + $meth(state::GenericState) = getproperty(state, $(QuoteNode(field))) + end + @eval export $meth +end + """ `update!(:: AbstractState; convert = false, kwargs...)` diff --git a/src/State/NLPAtXmod.jl b/src/State/NLPAtXmod.jl index ed85edd3..047bff3b 100644 --- a/src/State/NLPAtXmod.jl +++ b/src/State/NLPAtXmod.jl @@ -156,6 +156,18 @@ function NLPAtX( ) end +for field in fieldnames(NLPAtX) + meth = Symbol("get_", field) + @eval begin + @doc """ + $($meth)(state) + Return the value $($(QuoteNode(field))) from the state. + """ + $meth(state::NLPAtX) = getproperty(state, $(QuoteNode(field))) + end + @eval export $meth +end + """ reinit!: function that set all the entries at void except the mandatory x diff --git a/src/State/OneDAtXmod.jl b/src/State/OneDAtXmod.jl index d871af82..25bec766 100644 --- a/src/State/OneDAtXmod.jl +++ b/src/State/OneDAtXmod.jl @@ -79,3 +79,15 @@ function OneDAtX( current_time = current_time, ) end + +for field in fieldnames(OneDAtX) + meth = Symbol("get_", field) + @eval begin + @doc """ + $($meth)(state) + Return the value $($(QuoteNode(field))) from the state. + """ + $meth(state::OneDAtX) = getproperty(state, $(QuoteNode(field))) + end + @eval export $meth +end From 486a2be064984af0b82c7e10de84871aefe3fab4 Mon Sep 17 00:00:00 2001 From: tmigot Date: Fri, 9 Sep 2022 14:50:54 -0400 Subject: [PATCH 3/3] use `get_fx` function to check unboundedness --- src/Stopping/NLPStoppingmod.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/Stopping/NLPStoppingmod.jl b/src/Stopping/NLPStoppingmod.jl index a76763ce..565e51ab 100644 --- a/src/Stopping/NLPStoppingmod.jl +++ b/src/Stopping/NLPStoppingmod.jl @@ -414,14 +414,14 @@ function _unbounded_problem_check!( stp::NLPStopping{Pb, M, SRC, NLPAtX{Score, S, T}, MStp, LoS}, x::AbstractVector, ) where {Pb, M, SRC, MStp, LoS, Score, S, T} - if isnan(stp.current_state.fx) + if isnan(get_fx(stp.current_state)) stp.current_state.fx = obj(stp.pb, x) end if stp.pb.meta.minimize - f_too_large = stp.current_state.fx <= -stp.meta.unbounded_threshold + f_too_large = get_fx(stp.current_state) <= -stp.meta.unbounded_threshold else - f_too_large = stp.current_state.fx >= stp.meta.unbounded_threshold + f_too_large = get_fx(stp.current_state) >= stp.meta.unbounded_threshold end if f_too_large @@ -435,14 +435,14 @@ function _unbounded_problem_check!( stp::NLPStopping{Pb, M, SRC, OneDAtX{S, T}, MStp, LoS}, x::Union{AbstractVector, Number}, ) where {Pb, M, SRC, MStp, LoS, S, T} - if isnan(stp.current_state.fx) + if isnan(get_fx(stp.current_state)) stp.current_state.fx = obj(stp.pb, x) end if stp.pb.meta.minimize - f_too_large = stp.current_state.fx <= -stp.meta.unbounded_threshold + f_too_large = get_fx(stp.current_state) <= -stp.meta.unbounded_threshold else - f_too_large = stp.current_state.fx >= stp.meta.unbounded_threshold + f_too_large = get_fx(stp.current_state) >= stp.meta.unbounded_threshold end return stp.meta.unbounded_pb