I'm having trouble with getting the standard error for the predictions from a parsnip::bart() model in the same way that can be done for a stan model.
I think this should be possible and is a bug rather than a feature request.
Thanks!
Reproducible example
# Load packages and prepare data ------------------------------------------
# Code adapted from: https://parsnip.tidymodels.org/articles/Examples.html
library(tidymodels)
library(dbarts)
#>
#> Attaching package: 'dbarts'
#> The following object is masked from 'package:tidyr':
#>
#>     extract
#> The following object is masked from 'package:parsnip':
#>
#>     bart
library(rstanarm)
#> Loading required package: Rcpp
#>
#> Attaching package: 'Rcpp'
#> The following object is masked from 'package:rsample':
#>
#>     populate
#> This is rstanarm version 2.21.4
#> - See https://mc-stan.org/rstanarm/articles/priors for changes to default priors!
#> - Default priors may change, so it's safest to specify priors, even if equivalent to the defaults.
#> - For execution on a local, multicore CPU with excess RAM we recommend calling
#>   options(mc.cores = parallel::detectCores())
tidymodels_prefer()
# Data
data(two_class_dat)
data_train <- two_class_dat[-(1:10), ]
data_test <- two_class_dat[1:10, ]
# BART model --------------------------------------------------------------
# Example to show that BART model does not include the standard error for the
# predictions

# BART model specification and fit
set.seed(1)
bt_cls_fit <- parsnip::bart() %>%
  set_mode("classification") %>%
  set_engine("dbarts") %>%
  fit(Class ~ ., data = data_train)
# Make predictions - output does not include the .std_error column
set.seed(2)
bind_cols(
  predict(bt_cls_fit, data_test, type = "prob"),
  predict(bt_cls_fit, data_test, type = "pred_int", std_error = TRUE)
) %>%
  select(-contains(c("lower", "upper")))
#> # A tibble: 10 × 2
#>    .pred_Class1 .pred_Class2
#>           <dbl>        <dbl>
#>  1       0.344         0.656
#>  2       0.82          0.18
#>  3       0.562         0.438
#>  4       0.608         0.392
#>  5       0.438         0.562
#>  6       0.234         0.766
#>  7       0.632         0.368
#>  8       0.448         0.552
#>  9       0.971         0.029
#> 10       0.0780        0.922

# Check to see if the information is contained in the dbarts model fit object to
# calculate the standard error for the predictions (i.e. check to make sure that
# calling std_error = TRUE in predict can return the standard error). It looks
# like the information is there but I'm not 100% sure that the below is correct.

# Extract dbarts fit
bt_cls_eng <-
  extract_fit_engine(bt_cls_fit)
# Make predictions
set.seed(2)
bt_cls_pred <-
  predict(bt_cls_eng, data_test)
# Summarise posterior predictions for each observation
tibble(
  .pred_class1 = 1 - apply(bt_cls_pred, 2, base::mean, na.rm = TRUE),
  .pred_class2 = apply(bt_cls_pred, 2, base::mean, na.rm = TRUE),
  .std_error = apply(bt_cls_pred, 2, stats::sd, na.rm = TRUE)
)
#> # A tibble: 10 × 3
#>    .pred_class1 .pred_class2 .std_error
#>           <dbl>        <dbl>      <dbl>
#>  1       0.335        0.665      0.0947
#>  2       0.830        0.170      0.0743
#>  3       0.586        0.414      0.0969
#>  4       0.606        0.394      0.123
#>  5       0.434        0.566      0.104
#>  6       0.231        0.769      0.0876
#>  7       0.649        0.351      0.111
#>  8       0.448        0.552      0.110
#>  9       0.977        0.0228     0.0265
#> 10       0.0785       0.922      0.0486

# Stan model --------------------------------------------------------------
# Example to show that a Stan model does include the standard error for the
# predictions (what I'm hoping the bart model can provide).

# Stan model specification and fit
set.seed(1)
logreg_cls_fit <-
  logistic_reg() %>%
  set_engine("stan") %>%
  fit(Class ~ ., data = data_train)
# Make predictions - output includes the .std_error column
bind_cols(
  predict(logreg_cls_fit, data_test, type = "prob"),
  predict(logreg_cls_fit, data_test, type = "pred_int", std_error = TRUE)
) %>%
  select(-contains(c("lower", "upper")))
#> # A tibble: 10 × 3
#>    .pred_Class1 .pred_Class2 .std_error
#>           <dbl>        <dbl>      <dbl>
#>  1        0.518      0.482       0.500
#>  2        0.909      0.0909      0.287
#>  3        0.650      0.350       0.474
#>  4        0.609      0.391       0.491
#>  5        0.443      0.557       0.497
#>  6        0.206      0.794       0.402
#>  7        0.708      0.292       0.454
#>  8        0.568      0.432       0.497
#>  9        0.994      0.00580     0.0834
#> 10        0.108      0.892       0.313
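One thing worth checking before comparing the two .std_error columns: in the Stan output the values sit close to sqrt(p * (1 - p)), the standard deviation of a single 0/1 outcome with probability p. The short check below uses only the numbers printed above. It may indicate that the Stan engine's .std_error describes the spread of simulated binary outcomes rather than the uncertainty in the predicted probability, which is what the manual BART calculation measures. That reading is an assumption based on the printed values, not something confirmed in this thread.

```r
# Values copied from the printed Stan output above.
stan_out <- data.frame(
  .pred_Class1 = c(0.518, 0.909, 0.650, 0.609, 0.443,
                   0.206, 0.708, 0.568, 0.994, 0.108),
  .std_error   = c(0.500, 0.287, 0.474, 0.491, 0.497,
                   0.402, 0.454, 0.497, 0.0834, 0.313)
)

# Standard deviation of a Bernoulli outcome with success probability p.
bernoulli_sd <- sqrt(stan_out$.pred_Class1 * (1 - stan_out$.pred_Class1))

# Largest absolute gap between the Bernoulli sd and the reported .std_error.
max(abs(bernoulli_sd - stan_out$.std_error))
```

If the gap is small for every row, the two .std_error columns are probably not on the same scale and should not be compared directly.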
Thanks for the issue @jdberson ! I think you're right, this should be possible.
dbart_predict_calc() has a std_err argument, but it is neither used inside the function nor passed the std_error value from predict(). I'd take this as an indication that the intention was there but it hasn't been implemented yet!
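As a rough illustration of what using that argument could involve, here is a minimal sketch of the summary step, mirroring the manual calculation in the reprex. The helper name summarise_bart_draws() is hypothetical and is not part of parsnip or dbarts. It assumes the posterior draws arrive as a matrix with one row per draw and one column per observation (the layout implied by the apply(..., 2, ...) calls in the reprex), holding the probability of the second class. Simulated draws stand in for the output of predict() on the engine fit so that the example runs without dbarts.

```r
# Hypothetical helper, not part of parsnip or dbarts.
# `draws`: numeric matrix, one row per posterior draw, one column per
# observation, assumed to hold the probability of the second class.
summarise_bart_draws <- function(draws, std_error = FALSE) {
  stopifnot(is.matrix(draws), is.numeric(draws))
  p2 <- colMeans(draws, na.rm = TRUE)
  res <- data.frame(
    .pred_Class1 = 1 - p2,
    .pred_Class2 = p2
  )
  if (std_error) {
    # Posterior standard deviation of the predicted probability.
    res$.std_error <- apply(draws, 2, stats::sd, na.rm = TRUE)
  }
  res
}

# Simulated stand-in for posterior draws: 1000 draws for 3 observations.
set.seed(1)
draws <- matrix(stats::rbeta(1000 * 3, shape1 = 2, shape2 = 5),
                nrow = 1000, ncol = 3)

res <- summarise_bart_draws(draws, std_error = TRUE)
res
```

With std_error = FALSE the helper returns only the two probability columns, which matches what the reprex currently gets back from predict().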
This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.
Created on 2023-05-29 with reprex v2.0.2