Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

QoL improvements for SDeMo #367

Merged
merged 15 commits into from
Feb 28, 2025
Previous commit
feat(demo): do not display the regularization part of the loss
  • Loading branch information
tpoisot committed Feb 28, 2025
commit 81450b57cfa3e2cf66ebdab8a9a02c705eccd087
21 changes: 12 additions & 9 deletions SDeMo/src/classifiers/logistic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ function __sigmoid!(store::Vector{<:AbstractFloat}, z::Vector{<:AbstractFloat})
store[i] = __sigmoid(z[i])
end
return store
end
end

@testitem "We get the correct response for a sigmoid" begin
@test SDeMo.__sigmoid(0.0) == 0.5
Expand Down Expand Up @@ -85,7 +85,12 @@ end

Base.zero(::Type{Logistic}) = 0.5

function SDeMo.train!(lreg::Logistic, y::Vector{Bool}, X::Matrix{T}; kwargs...) where {T <: Number}
function SDeMo.train!(
lreg::Logistic,
y::Vector{Bool},
X::Matrix{T};
kwargs...,
) where {T <: Number}
# Get the validation data if relevant
Xt = get(kwargs, :Xt, nothing)
yt = get(kwargs, :yt, nothing)
Expand All @@ -108,12 +113,10 @@ function SDeMo.train!(lreg::Logistic, y::Vector{Bool}, X::Matrix{T}; kwargs...)
if validation_data
zv = Xvt * lreg.θ
__sigmoid!(zv, zv)
validation_loss = -mean(yt .* log.(zv) .+ (1 .- yt) .* log.(1 .- zv)) + (lreg.λ / (2 * length(lreg.θ))) * sum(lreg.θ[2:end] .^ 2)
validation_loss = -mean(yt .* log.(zv) .+ (1 .- yt) .* log.(1 .- zv))
end
z = clamp.(z, eps(), 1 - eps())
loss =
-mean(y .* log.(z) .+ (1 .- y) .* log.(1 .- z)) +
(lreg.λ / (2 * length(lreg.θ))) * sum(lreg.θ[2:end] .^ 2)
loss = -mean(y .* log.(z) .+ (1 .- y) .* log.(1 .- z))
# Percent done
prct = lpad(round(Int64, (epoch / lreg.epochs) * 100), 3, " ")
infostr = "[$(prct)%] LOSS: training ≈ $(rpad(round(loss; digits=4), 6, " "))"
Expand Down Expand Up @@ -205,17 +208,17 @@ end

@testitem "We can run a Logistic model" begin
X, y = SDeMo.__demodata()
sdm = SDM(ZScore(), Logistic(), 0.5, X, y, [1,2,12])
sdm = SDM(ZScore(), Logistic(), 0.5, X, y, [1, 2, 12])
folds = holdout(sdm)
classifier(sdm).verbose = true
classifier(sdm).η = 1e-3
classifier(sdm).verbosity = 10
train!(sdm; training=folds[1])
train!(sdm; training = folds[1])
end

@testitem "We can run a verbose Logistic model with no training data" begin
X, y = SDeMo.__demodata()
sdm = SDM(ZScore(), Logistic(), 0.5, X, y, [1,2,12])
sdm = SDM(ZScore(), Logistic(), 0.5, X, y, [1, 2, 12])
classifier(sdm).verbose = true
classifier(sdm).η = 1e-3
classifier(sdm).verbosity = 10
Expand Down
Loading