
bart model predictions do not provide the standard error for the predictions #976

Closed
jdberson opened this issue May 29, 2023 · 3 comments · Fixed by #978
Labels
feature a feature request or enhancement

Comments

@jdberson

The problem

I'm having trouble getting the standard error for the predictions from a parsnip::bart() model in the same way that I can for a Stan model.

I think this should be possible and is a bug rather than a feature request.

Thanks!

Reproducible example

# Load packages and prepare data ------------------------------------------

# Code adapted from: https://parsnip.tidymodels.org/articles/Examples.html

library(tidymodels)
library(dbarts)
#> 
#> Attaching package: 'dbarts'
#> The following object is masked from 'package:tidyr':
#> 
#>     extract
#> The following object is masked from 'package:parsnip':
#> 
#>     bart
library(rstanarm)
#> Loading required package: Rcpp
#> 
#> Attaching package: 'Rcpp'
#> The following object is masked from 'package:rsample':
#> 
#>     populate
#> This is rstanarm version 2.21.4
#> - See https://mc-stan.org/rstanarm/articles/priors for changes to default priors!
#> - Default priors may change, so it's safest to specify priors, even if equivalent to the defaults.
#> - For execution on a local, multicore CPU with excess RAM we recommend calling
#>   options(mc.cores = parallel::detectCores())

tidymodels_prefer()

# Data
data(two_class_dat)
data_train <- two_class_dat[-(1:10), ]
data_test  <- two_class_dat[  1:10 , ]


# BART model --------------------------------------------------------------

# Example to show that BART model does not include the standard error for the
# predictions

# BART model specification and fit
set.seed(1)
bt_cls_fit <- 
  parsnip::bart() %>% 
  set_mode("classification") %>% 
  set_engine("dbarts") %>% 
  fit(Class ~ ., data = data_train)

# Make predictions - output does not include the .std_error column
set.seed(2)
bind_cols(
  predict(bt_cls_fit, data_test, type = "prob"),
  predict(bt_cls_fit, data_test, type = "pred_int", std_error = TRUE)
) %>%
  select(-contains(c("lower", "upper")))
#> # A tibble: 10 × 2
#>    .pred_Class1 .pred_Class2
#>           <dbl>        <dbl>
#>  1       0.344         0.656
#>  2       0.82          0.18 
#>  3       0.562         0.438
#>  4       0.608         0.392
#>  5       0.438         0.562
#>  6       0.234         0.766
#>  7       0.632         0.368
#>  8       0.448         0.552
#>  9       0.971         0.029
#> 10       0.0780        0.922

# Check whether the dbarts fit object contains the information needed to
# calculate the standard error of the predictions (i.e., whether calling
# std_error = TRUE in predict() could return it). It looks like the
# information is there, but I'm not 100% sure that the code below is correct.

# Extract dbarts fit
bt_cls_eng <- 
  extract_fit_engine(bt_cls_fit)

# Make predictions
set.seed(2)
bt_cls_pred <- 
  predict(bt_cls_eng, data_test)

# Summarise posterior predictions for each observation
tibble(
  .pred_class1 = 1 - apply(bt_cls_pred, 2, base::mean, na.rm = TRUE),
  .pred_class2 = apply(bt_cls_pred, 2, base::mean, na.rm = TRUE),
  .std_error = apply(bt_cls_pred, 2, stats::sd, na.rm = TRUE)
)
#> # A tibble: 10 × 3
#>    .pred_class1 .pred_class2 .std_error
#>           <dbl>        <dbl>      <dbl>
#>  1       0.335        0.665      0.0947
#>  2       0.830        0.170      0.0743
#>  3       0.586        0.414      0.0969
#>  4       0.606        0.394      0.123 
#>  5       0.434        0.566      0.104 
#>  6       0.231        0.769      0.0876
#>  7       0.649        0.351      0.111 
#>  8       0.448        0.552      0.110 
#>  9       0.977        0.0228     0.0265
#> 10       0.0785       0.922      0.0486


# Stan model --------------------------------------------------------------

# Example to show that a Stan model does include the standard error for the 
# predictions (what I'm hoping the bart model can provide).

# Stan model specification and fit
set.seed(1)
logreg_cls_fit <- 
  logistic_reg() %>% 
  set_engine("stan") %>% 
  fit(Class ~ ., data = data_train)

# Make predictions - output includes the .std_error column
bind_cols(
  predict(logreg_cls_fit, data_test, type = "prob"),
  predict(logreg_cls_fit, data_test, type = "pred_int", std_error = TRUE)
) %>%
  select(-contains(c("lower", "upper")))
#> # A tibble: 10 × 3
#>    .pred_Class1 .pred_Class2 .std_error
#>           <dbl>        <dbl>      <dbl>
#>  1        0.518      0.482       0.500 
#>  2        0.909      0.0909      0.287 
#>  3        0.650      0.350       0.474 
#>  4        0.609      0.391       0.491 
#>  5        0.443      0.557       0.497 
#>  6        0.206      0.794       0.402 
#>  7        0.708      0.292       0.454 
#>  8        0.568      0.432       0.497 
#>  9        0.994      0.00580     0.0834
#> 10        0.108      0.892       0.313

Created on 2023-05-29 with reprex v2.0.2

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.3 (2023-03-15 ucrt)
#>  os       Windows 10 x64 (build 19042)
#>  system   x86_64, mingw32
#>  ui       RTerm
#>  language (EN)
#>  collate  English_Australia.utf8
#>  ctype    English_Australia.utf8
#>  tz       Australia/Perth
#>  date     2023-05-29
#>  pandoc   2.19.2 @ C:/program files/rstudio/resources/app/bin/quarto/bin/tools/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  ! package      * version    date (UTC) lib source
#>    backports      1.4.1      2021-12-13 [2] CRAN (R 4.2.0)
#>    base64enc      0.1-3      2015-07-28 [2] CRAN (R 4.2.0)
#>    bayesplot      1.10.0     2022-11-16 [1] CRAN (R 4.2.3)
#>    boot           1.3-28.1   2022-11-22 [2] CRAN (R 4.2.3)
#>    broom        * 1.0.4      2023-03-11 [2] CRAN (R 4.2.3)
#>    cachem         1.0.8      2023-05-01 [1] CRAN (R 4.2.3)
#>    callr          3.7.3      2022-11-02 [2] CRAN (R 4.2.3)
#>    class          7.3-21     2023-01-23 [2] CRAN (R 4.2.3)
#>    cli            3.6.1      2023-03-23 [1] CRAN (R 4.2.3)
#>    codetools      0.2-19     2023-02-01 [2] CRAN (R 4.2.3)
#>    colorspace     2.1-0      2023-01-23 [2] CRAN (R 4.2.3)
#>    colourpicker   1.2.0      2022-10-28 [1] CRAN (R 4.2.3)
#>    conflicted     1.2.0      2023-02-01 [2] CRAN (R 4.2.3)
#>    crayon         1.5.2      2022-09-29 [2] CRAN (R 4.2.3)
#>    crosstalk      1.2.0      2021-11-04 [2] CRAN (R 4.2.3)
#>    curl           5.0.0      2023-01-12 [2] CRAN (R 4.2.3)
#>    data.table     1.14.8     2023-02-17 [2] CRAN (R 4.2.3)
#>    dbarts       * 0.9-23     2023-01-23 [1] CRAN (R 4.2.3)
#>    dials        * 1.2.0      2023-04-03 [2] CRAN (R 4.2.3)
#>    DiceDesign     1.9        2021-02-13 [2] CRAN (R 4.2.3)
#>    digest         0.6.31     2022-12-11 [2] CRAN (R 4.2.3)
#>    dplyr        * 1.1.2      2023-04-20 [1] CRAN (R 4.2.3)
#>    DT             0.27       2023-01-17 [1] CRAN (R 4.2.3)
#>    dygraphs       1.1.1.6    2018-07-11 [1] CRAN (R 4.2.3)
#>    ellipsis       0.3.2      2021-04-29 [2] CRAN (R 4.2.3)
#>    evaluate       0.21       2023-05-05 [1] CRAN (R 4.2.3)
#>    fansi          1.0.4      2023-01-22 [1] CRAN (R 4.2.2)
#>    fastmap        1.1.1      2023-02-24 [2] CRAN (R 4.2.3)
#>    foreach        1.5.2      2022-02-02 [2] CRAN (R 4.2.3)
#>    fs             1.6.2      2023-04-25 [1] CRAN (R 4.2.3)
#>    furrr          0.3.1      2022-08-15 [2] CRAN (R 4.2.3)
#>    future         1.32.0     2023-03-07 [2] CRAN (R 4.2.3)
#>    future.apply   1.10.0     2022-11-05 [2] CRAN (R 4.2.3)
#>    generics       0.1.3      2022-07-05 [2] CRAN (R 4.2.3)
#>    ggplot2      * 3.4.2      2023-04-03 [1] CRAN (R 4.2.3)
#>    globals        0.16.2     2022-11-21 [2] CRAN (R 4.2.2)
#>    glue           1.6.2      2022-02-24 [2] CRAN (R 4.2.3)
#>    gower          1.0.1      2022-12-22 [2] CRAN (R 4.2.2)
#>    GPfit          1.0-8      2019-02-08 [2] CRAN (R 4.2.3)
#>    gridExtra      2.3        2017-09-09 [2] CRAN (R 4.2.3)
#>    gtable         0.3.3      2023-03-21 [2] CRAN (R 4.2.3)
#>    gtools         3.9.4      2022-11-27 [1] CRAN (R 4.2.3)
#>    hardhat        1.3.0      2023-03-30 [2] CRAN (R 4.2.3)
#>    htmltools      0.5.5      2023-03-23 [2] CRAN (R 4.2.3)
#>    htmlwidgets    1.6.2      2023-03-17 [2] CRAN (R 4.2.3)
#>    httpuv         1.6.11     2023-05-11 [1] CRAN (R 4.2.3)
#>    igraph         1.4.3      2023-05-22 [1] CRAN (R 4.2.3)
#>    infer        * 1.0.4      2022-12-02 [2] CRAN (R 4.2.3)
#>    inline         0.3.19     2021-05-31 [1] CRAN (R 4.2.3)
#>    ipred          0.9-14     2023-03-09 [2] CRAN (R 4.2.3)
#>    iterators      1.0.14     2022-02-05 [2] CRAN (R 4.2.3)
#>    jsonlite       1.8.4      2022-12-06 [1] CRAN (R 4.2.2)
#>    knitr          1.42       2023-01-25 [2] CRAN (R 4.2.3)
#>    later          1.3.1      2023-05-02 [1] CRAN (R 4.2.3)
#>    lattice        0.20-45    2021-09-22 [2] CRAN (R 4.2.3)
#>    lava           1.7.2.1    2023-02-27 [2] CRAN (R 4.2.3)
#>    lhs            1.1.6      2022-12-17 [2] CRAN (R 4.2.3)
#>    lifecycle      1.0.3      2022-10-07 [2] CRAN (R 4.2.3)
#>    listenv        0.9.0      2022-12-16 [2] CRAN (R 4.2.3)
#>    lme4           1.1-33     2023-04-25 [1] CRAN (R 4.2.3)
#>    loo            2.6.0      2023-03-31 [1] CRAN (R 4.2.3)
#>    lubridate      1.9.2      2023-02-10 [1] CRAN (R 4.2.2)
#>    magrittr       2.0.3      2022-03-30 [2] CRAN (R 4.2.3)
#>    markdown       1.7        2023-05-16 [1] CRAN (R 4.2.3)
#>    MASS           7.3-58.2   2023-01-23 [2] CRAN (R 4.2.3)
#>    Matrix         1.5-3      2022-11-11 [1] CRAN (R 4.2.2)
#>    matrixStats    0.63.0     2022-11-18 [1] CRAN (R 4.2.3)
#>    memoise        2.0.1      2021-11-26 [2] CRAN (R 4.2.3)
#>    mime           0.12       2021-09-28 [2] CRAN (R 4.2.0)
#>    miniUI         0.1.1.1    2018-05-18 [2] CRAN (R 4.2.3)
#>    minqa          1.2.5      2022-10-19 [2] CRAN (R 4.2.3)
#>    modeldata    * 1.1.0      2023-01-25 [2] CRAN (R 4.2.3)
#>    munsell        0.5.0      2018-06-12 [2] CRAN (R 4.2.3)
#>    nlme           3.1-162    2023-01-31 [2] CRAN (R 4.2.3)
#>    nloptr         2.0.3      2022-05-26 [2] CRAN (R 4.2.3)
#>    nnet           7.3-18     2022-09-28 [2] CRAN (R 4.2.3)
#>    parallelly     1.35.0     2023-03-23 [2] CRAN (R 4.2.3)
#>    parsnip      * 1.1.0      2023-04-12 [2] CRAN (R 4.2.3)
#>    pillar         1.9.0      2023-03-22 [2] CRAN (R 4.2.3)
#>    pkgbuild       1.4.0      2022-11-27 [2] CRAN (R 4.2.3)
#>    pkgconfig      2.0.3      2019-09-22 [2] CRAN (R 4.2.3)
#>    plyr           1.8.8      2022-11-11 [2] CRAN (R 4.2.3)
#>    prettyunits    1.1.1      2020-01-24 [2] CRAN (R 4.2.3)
#>    processx       3.8.1      2023-04-18 [1] CRAN (R 4.2.3)
#>    prodlim        2023.03.31 2023-04-02 [2] CRAN (R 4.2.3)
#>    promises       1.2.0.1    2021-02-11 [2] CRAN (R 4.2.3)
#>    ps             1.7.5      2023-04-18 [1] CRAN (R 4.2.3)
#>    purrr        * 1.0.1      2023-01-10 [1] CRAN (R 4.2.2)
#>    R6             2.5.1      2021-08-19 [2] CRAN (R 4.2.3)
#>    Rcpp         * 1.0.10     2023-01-22 [1] CRAN (R 4.2.2)
#>  D RcppParallel   5.1.7      2023-02-27 [1] CRAN (R 4.2.3)
#>    recipes      * 1.0.6      2023-04-25 [2] CRAN (R 4.2.3)
#>    reprex         2.0.2      2022-08-17 [1] CRAN (R 4.2.3)
#>    reshape2       1.4.4      2020-04-09 [1] CRAN (R 4.2.3)
#>    rlang          1.1.1      2023-04-28 [1] CRAN (R 4.2.3)
#>    rmarkdown      2.21       2023-03-26 [2] CRAN (R 4.2.3)
#>    rpart          4.1.19     2022-10-21 [2] CRAN (R 4.2.3)
#>    rsample      * 1.1.1      2022-12-07 [2] CRAN (R 4.2.3)
#>    rstan          2.26.15    2023-02-11 [1] local
#>    rstanarm     * 2.21.4     2023-04-08 [1] CRAN (R 4.2.3)
#>    rstantools     2.3.1      2023-03-30 [1] CRAN (R 4.2.3)
#>    rstudioapi     0.14       2022-08-22 [2] CRAN (R 4.2.3)
#>    scales       * 1.2.1      2022-08-20 [2] CRAN (R 4.2.3)
#>    sessioninfo    1.2.2      2021-12-06 [2] CRAN (R 4.2.3)
#>    shiny          1.7.4      2022-12-15 [2] CRAN (R 4.2.3)
#>    shinyjs        2.1.0      2021-12-23 [1] CRAN (R 4.2.3)
#>    shinystan      2.6.0      2022-03-03 [1] CRAN (R 4.2.3)
#>    shinythemes    1.2.0      2021-01-25 [1] CRAN (R 4.2.3)
#>    StanHeaders    2.26.15    2023-02-11 [1] local
#>    stringi        1.7.12     2023-01-11 [1] CRAN (R 4.2.2)
#>    stringr        1.5.0      2022-12-02 [1] CRAN (R 4.2.2)
#>    survival       3.5-3      2023-02-12 [2] CRAN (R 4.2.3)
#>    threejs        0.3.3      2020-01-21 [1] CRAN (R 4.2.3)
#>    tibble       * 3.2.1      2023-03-20 [1] CRAN (R 4.2.3)
#>    tidymodels   * 1.1.0      2023-05-01 [1] CRAN (R 4.2.3)
#>    tidyr        * 1.3.0      2023-01-24 [1] CRAN (R 4.2.2)
#>    tidyselect     1.2.0      2022-10-10 [2] CRAN (R 4.2.3)
#>    timechange     0.2.0      2023-01-11 [1] CRAN (R 4.2.2)
#>    timeDate       4022.108   2023-01-07 [2] CRAN (R 4.2.3)
#>    tune         * 1.1.1      2023-04-11 [2] CRAN (R 4.2.3)
#>    utf8           1.2.3      2023-01-31 [1] CRAN (R 4.2.2)
#>    V8             4.3.0      2023-04-08 [1] CRAN (R 4.2.3)
#>    vctrs          0.6.2      2023-04-19 [1] CRAN (R 4.2.3)
#>    withr          2.5.0      2022-03-03 [2] CRAN (R 4.2.3)
#>    workflows    * 1.1.3      2023-02-22 [2] CRAN (R 4.2.3)
#>    workflowsets * 1.0.1      2023-04-06 [2] CRAN (R 4.2.3)
#>    xfun           0.39       2023-04-20 [1] CRAN (R 4.2.3)
#>    xtable         1.8-4      2019-04-21 [2] CRAN (R 4.2.3)
#>    xts            0.13.1     2023-04-16 [1] CRAN (R 4.2.3)
#>    yaml           2.3.7      2023-01-23 [2] CRAN (R 4.2.3)
#>    yardstick    * 1.2.0      2023-04-21 [2] CRAN (R 4.2.3)
#>    zoo            1.8-12     2023-04-13 [1] CRAN (R 4.2.3)
#> 
#>  [1] C:/Users/00055815/AppData/Local/R/win-library/4.2
#>  [2] C:/Program Files/R/R-4.2.3/library
#> 
#>  D ── DLL MD5 mismatch, broken installation.
#> 
#> ──────────────────────────────────────────────────────────────────────────────
@hfrick
Member

hfrick commented Jun 5, 2023

Thanks for the issue, @jdberson! I think you're right; this should be possible.

dbart_predict_calc() has a std_err argument, but it is neither used inside the function nor passed the std_error value from predict(). I'd take this as an indication that the intention was there but it hasn't been implemented yet!
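A minimal base-R sketch of what honoring a `std_error` flag could look like, assuming the posterior draws arrive as a matrix with one row per draw and one column per observation, and that the draws are probabilities of the second class level (the same conventions the reprex's manual check uses). `summarise_class_draws()` is a hypothetical helper for illustration only; it is not parsnip's `dbart_predict_calc()`, and the actual fix is in #978.

```r
# Hypothetical helper (NOT parsnip's implementation): summarise a matrix of
# posterior draws of P(second class), rows = draws, columns = observations,
# into class probabilities and, optionally, a standard error column.
summarise_class_draws <- function(draws, levels = c("Class1", "Class2"),
                                  std_error = FALSE) {
  stopifnot(is.matrix(draws), length(levels) == 2)
  # Posterior mean of P(second level) for each observation (column)
  p2 <- colMeans(draws, na.rm = TRUE)
  res <- data.frame(1 - p2, p2)
  names(res) <- paste0(".pred_", levels)
  if (std_error) {
    # Posterior standard deviation per observation, mirroring the
    # apply(..., 2, stats::sd) calculation in the reprex
    res$.std_error <- apply(draws, 2, stats::sd, na.rm = TRUE)
  }
  res
}

# Toy draws for illustration only: 4 posterior draws x 3 observations
draws <- matrix(c(0.2, 0.4, 0.9,
                  0.4, 0.6, 0.7,
                  0.2, 0.4, 0.9,
                  0.4, 0.6, 0.7),
                nrow = 4, byrow = TRUE)
summarise_class_draws(draws, std_error = TRUE)
```

Keeping the standard error opt-in means the default output matches the two probability columns that `type = "prob"` already returns, and the extra column only appears when requested, as with the Stan engine in the reprex.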

@hfrick hfrick added the feature a feature request or enhancement label Jun 5, 2023
@jdberson
Author

jdberson commented Jun 6, 2023

Hi @hfrick, thanks very much for looking into this and for making the changes so quickly! That is absolutely brilliant.

@github-actions

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.

@github-actions github-actions bot locked and limited conversation to collaborators Jun 23, 2023

2 participants