bart model predictions do not provide the standard error for the predictions #976

jdberson opened this issue May 29, 2023 · 3 comments · Fixed by #978


The problem

I can't get the standard error for the predictions from a parsnip::bart() model in the same way that I can for a Stan model.

I think this should be possible and is a bug rather than a feature request.


Reproducible example

# Load packages and prepare data ------------------------------------------

# Code adapted from:

library(tidymodels)
library(dbarts)
#> Attaching package: 'dbarts'
#> The following object is masked from 'package:tidyr':
#>     extract
#> The following object is masked from 'package:parsnip':
#>     bart
library(rstanarm)
#> Loading required package: Rcpp
#> Attaching package: 'Rcpp'
#> The following object is masked from 'package:rsample':
#>     populate
#> This is rstanarm version 2.21.4
#> - See for changes to default priors!
#> - Default priors may change, so it's safest to specify priors, even if equivalent to the defaults.
#> - For execution on a local, multicore CPU with excess RAM we recommend calling
#>   options(mc.cores = parallel::detectCores())


# Data
data_train <- two_class_dat[-(1:10), ]
data_test  <- two_class_dat[  1:10 , ]

# BART model --------------------------------------------------------------

# Example to show that BART model does not include the standard error for the
# predictions

# BART model specification and fit
bt_cls_fit <- 
  parsnip::bart() %>% 
  set_mode("classification") %>% 
  set_engine("dbarts") %>% 
  fit(Class ~ ., data = data_train)

# Make predictions - output does not include the .std_error column
bind_cols(
  predict(bt_cls_fit, data_test, type = "prob"),
  predict(bt_cls_fit, data_test, type = "pred_int", std_error = TRUE)
) %>%
  select(-contains(c("lower", "upper")))
#> # A tibble: 10 × 2
#>    .pred_Class1 .pred_Class2
#>           <dbl>        <dbl>
#>  1       0.344         0.656
#>  2       0.82          0.18 
#>  3       0.562         0.438
#>  4       0.608         0.392
#>  5       0.438         0.562
#>  6       0.234         0.766
#>  7       0.632         0.368
#>  8       0.448         0.552
#>  9       0.971         0.029
#> 10       0.0780        0.922

# Check to see if the information is contained in the dbarts model fit object to
# calculate the standard error for the predictions (i.e. check to make sure that
# calling std_error = TRUE in predict can return the standard error). It looks 
# like the information is there but I'm not 100% sure that the below is correct.

# Extract dbarts fit
bt_cls_eng <- 
  extract_fit_engine(bt_cls_fit)

# Make predictions
bt_cls_pred <- 
  predict(bt_cls_eng, data_test)

# Summarise posterior predictions for each observation
tibble(
  .pred_class1 = 1 - apply(bt_cls_pred, 2, base::mean, na.rm = TRUE),
  .pred_class2 = apply(bt_cls_pred, 2, base::mean, na.rm = TRUE),
  .std_error = apply(bt_cls_pred, 2, stats::sd, na.rm = TRUE)
)
#> # A tibble: 10 × 3
#>    .pred_class1 .pred_class2 .std_error
#>           <dbl>        <dbl>      <dbl>
#>  1       0.335        0.665      0.0947
#>  2       0.830        0.170      0.0743
#>  3       0.586        0.414      0.0969
#>  4       0.606        0.394      0.123 
#>  5       0.434        0.566      0.104 
#>  6       0.231        0.769      0.0876
#>  7       0.649        0.351      0.111 
#>  8       0.448        0.552      0.110 
#>  9       0.977        0.0228     0.0265
#> 10       0.0785       0.922      0.0486

# Stan model --------------------------------------------------------------

# Example to show that a Stan model does include the standard error for the 
# predictions (what I'm hoping the bart model can provide).

# Stan model specification and fit
logreg_cls_fit <- 
  logistic_reg() %>% 
  set_engine("stan") %>% 
  fit(Class ~ ., data = data_train)

# Make predictions - output includes the .std_error column
bind_cols(
  predict(logreg_cls_fit, data_test, type = "prob"),
  predict(logreg_cls_fit, data_test, type = "pred_int", std_error = TRUE)
) %>%
  select(-contains(c("lower", "upper")))
#> # A tibble: 10 × 3
#>    .pred_Class1 .pred_Class2 .std_error
#>           <dbl>        <dbl>      <dbl>
#>  1        0.518      0.482       0.500 
#>  2        0.909      0.0909      0.287 
#>  3        0.650      0.350       0.474 
#>  4        0.609      0.391       0.491 
#>  5        0.443      0.557       0.497 
#>  6        0.206      0.794       0.402 
#>  7        0.708      0.292       0.454 
#>  8        0.568      0.432       0.497 
#>  9        0.994      0.00580     0.0834
#> 10        0.108      0.892       0.313

Created on 2023-05-29 with reprex v2.0.2
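
The manual check in the reprex reduces to column-wise summaries of a matrix of posterior draws. Below is a minimal base-R sketch of that step, assuming the matrix has one row per draw and one column per observation (the layout implied by `apply(..., 2, ...)`). The `draws` matrix is simulated, not output from dbarts, and all object names are made up for illustration.

```r
# Simulated stand-in for posterior draws: rows are draws, columns are observations.
# (Assumption: this mirrors the layout summarised with apply(..., 2, ...) in the reprex.)
set.seed(1)
n_draws <- 1000
centre_prob <- c(0.2, 0.5, 0.9)
draws <- sapply(centre_prob, function(p) {
  # Jitter each probability on the logit scale to mimic posterior spread.
  plogis(qlogis(p) + rnorm(n_draws, sd = 0.3))
})

# Column-wise summaries: one row per observation.
summary_df <- data.frame(
  .pred_class1 = 1 - colMeans(draws),
  .pred_class2 = colMeans(draws),
  .std_error   = apply(draws, 2, stats::sd)
)
summary_df
```

Here `.std_error` is the standard deviation of the per-draw probabilities, so it describes posterior spread of the predicted probability for each observation.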

Session info
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.3 (2023-03-15 ucrt)
#>  os       Windows 10 x64 (build 19042)
#>  system   x86_64, mingw32
#>  ui       RTerm
#>  language (EN)
#>  collate  English_Australia.utf8
#>  ctype    English_Australia.utf8
#>  tz       Australia/Perth
#>  date     2023-05-29
#>  pandoc   2.19.2 @ C:/program files/rstudio/resources/app/bin/quarto/bin/tools/ (via rmarkdown)
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  ! package      * version    date (UTC) lib source
#>    backports      1.4.1      2021-12-13 [2] CRAN (R 4.2.0)
#>    base64enc      0.1-3      2015-07-28 [2] CRAN (R 4.2.0)
#>    bayesplot      1.10.0     2022-11-16 [1] CRAN (R 4.2.3)
#>    boot           1.3-28.1   2022-11-22 [2] CRAN (R 4.2.3)
#>    broom        * 1.0.4      2023-03-11 [2] CRAN (R 4.2.3)
#>    cachem         1.0.8      2023-05-01 [1] CRAN (R 4.2.3)
#>    callr          3.7.3      2022-11-02 [2] CRAN (R 4.2.3)
#>    class          7.3-21     2023-01-23 [2] CRAN (R 4.2.3)
#>    cli            3.6.1      2023-03-23 [1] CRAN (R 4.2.3)
#>    codetools      0.2-19     2023-02-01 [2] CRAN (R 4.2.3)
#>    colorspace     2.1-0      2023-01-23 [2] CRAN (R 4.2.3)
#>    colourpicker   1.2.0      2022-10-28 [1] CRAN (R 4.2.3)
#>    conflicted     1.2.0      2023-02-01 [2] CRAN (R 4.2.3)
#>    crayon         1.5.2      2022-09-29 [2] CRAN (R 4.2.3)
#>    crosstalk      1.2.0      2021-11-04 [2] CRAN (R 4.2.3)
#>    curl           5.0.0      2023-01-12 [2] CRAN (R 4.2.3)
#>    data.table     1.14.8     2023-02-17 [2] CRAN (R 4.2.3)
#>    dbarts       * 0.9-23     2023-01-23 [1] CRAN (R 4.2.3)
#>    dials        * 1.2.0      2023-04-03 [2] CRAN (R 4.2.3)
#>    DiceDesign     1.9        2021-02-13 [2] CRAN (R 4.2.3)
#>    digest         0.6.31     2022-12-11 [2] CRAN (R 4.2.3)
#>    dplyr        * 1.1.2      2023-04-20 [1] CRAN (R 4.2.3)
#>    DT             0.27       2023-01-17 [1] CRAN (R 4.2.3)
#>    dygraphs    2018-07-11 [1] CRAN (R 4.2.3)
#>    ellipsis       0.3.2      2021-04-29 [2] CRAN (R 4.2.3)
#>    evaluate       0.21       2023-05-05 [1] CRAN (R 4.2.3)
#>    fansi          1.0.4      2023-01-22 [1] CRAN (R 4.2.2)
#>    fastmap        1.1.1      2023-02-24 [2] CRAN (R 4.2.3)
#>    foreach        1.5.2      2022-02-02 [2] CRAN (R 4.2.3)
#>    fs             1.6.2      2023-04-25 [1] CRAN (R 4.2.3)
#>    furrr          0.3.1      2022-08-15 [2] CRAN (R 4.2.3)
#>    future         1.32.0     2023-03-07 [2] CRAN (R 4.2.3)
#>    future.apply   1.10.0     2022-11-05 [2] CRAN (R 4.2.3)
#>    generics       0.1.3      2022-07-05 [2] CRAN (R 4.2.3)
#>    ggplot2      * 3.4.2      2023-04-03 [1] CRAN (R 4.2.3)
#>    globals        0.16.2     2022-11-21 [2] CRAN (R 4.2.2)
#>    glue           1.6.2      2022-02-24 [2] CRAN (R 4.2.3)
#>    gower          1.0.1      2022-12-22 [2] CRAN (R 4.2.2)
#>    GPfit          1.0-8      2019-02-08 [2] CRAN (R 4.2.3)
#>    gridExtra      2.3        2017-09-09 [2] CRAN (R 4.2.3)
#>    gtable         0.3.3      2023-03-21 [2] CRAN (R 4.2.3)
#>    gtools         3.9.4      2022-11-27 [1] CRAN (R 4.2.3)
#>    hardhat        1.3.0      2023-03-30 [2] CRAN (R 4.2.3)
#>    htmltools      0.5.5      2023-03-23 [2] CRAN (R 4.2.3)
#>    htmlwidgets    1.6.2      2023-03-17 [2] CRAN (R 4.2.3)
#>    httpuv         1.6.11     2023-05-11 [1] CRAN (R 4.2.3)
#>    igraph         1.4.3      2023-05-22 [1] CRAN (R 4.2.3)
#>    infer        * 1.0.4      2022-12-02 [2] CRAN (R 4.2.3)
#>    inline         0.3.19     2021-05-31 [1] CRAN (R 4.2.3)
#>    ipred          0.9-14     2023-03-09 [2] CRAN (R 4.2.3)
#>    iterators      1.0.14     2022-02-05 [2] CRAN (R 4.2.3)
#>    jsonlite       1.8.4      2022-12-06 [1] CRAN (R 4.2.2)
#>    knitr          1.42       2023-01-25 [2] CRAN (R 4.2.3)
#>    later          1.3.1      2023-05-02 [1] CRAN (R 4.2.3)
#>    lattice        0.20-45    2021-09-22 [2] CRAN (R 4.2.3)
#>    lava     2023-02-27 [2] CRAN (R 4.2.3)
#>    lhs            1.1.6      2022-12-17 [2] CRAN (R 4.2.3)
#>    lifecycle      1.0.3      2022-10-07 [2] CRAN (R 4.2.3)
#>    listenv        0.9.0      2022-12-16 [2] CRAN (R 4.2.3)
#>    lme4           1.1-33     2023-04-25 [1] CRAN (R 4.2.3)
#>    loo            2.6.0      2023-03-31 [1] CRAN (R 4.2.3)
#>    lubridate      1.9.2      2023-02-10 [1] CRAN (R 4.2.2)
#>    magrittr       2.0.3      2022-03-30 [2] CRAN (R 4.2.3)
#>    markdown       1.7        2023-05-16 [1] CRAN (R 4.2.3)
#>    MASS           7.3-58.2   2023-01-23 [2] CRAN (R 4.2.3)
#>    Matrix         1.5-3      2022-11-11 [1] CRAN (R 4.2.2)
#>    matrixStats    0.63.0     2022-11-18 [1] CRAN (R 4.2.3)
#>    memoise        2.0.1      2021-11-26 [2] CRAN (R 4.2.3)
#>    mime           0.12       2021-09-28 [2] CRAN (R 4.2.0)
#>    miniUI    2018-05-18 [2] CRAN (R 4.2.3)
#>    minqa          1.2.5      2022-10-19 [2] CRAN (R 4.2.3)
#>    modeldata    * 1.1.0      2023-01-25 [2] CRAN (R 4.2.3)
#>    munsell        0.5.0      2018-06-12 [2] CRAN (R 4.2.3)
#>    nlme           3.1-162    2023-01-31 [2] CRAN (R 4.2.3)
#>    nloptr         2.0.3      2022-05-26 [2] CRAN (R 4.2.3)
#>    nnet           7.3-18     2022-09-28 [2] CRAN (R 4.2.3)
#>    parallelly     1.35.0     2023-03-23 [2] CRAN (R 4.2.3)
#>    parsnip      * 1.1.0      2023-04-12 [2] CRAN (R 4.2.3)
#>    pillar         1.9.0      2023-03-22 [2] CRAN (R 4.2.3)
#>    pkgbuild       1.4.0      2022-11-27 [2] CRAN (R 4.2.3)
#>    pkgconfig      2.0.3      2019-09-22 [2] CRAN (R 4.2.3)
#>    plyr           1.8.8      2022-11-11 [2] CRAN (R 4.2.3)
#>    prettyunits    1.1.1      2020-01-24 [2] CRAN (R 4.2.3)
#>    processx       3.8.1      2023-04-18 [1] CRAN (R 4.2.3)
#>    prodlim        2023.03.31 2023-04-02 [2] CRAN (R 4.2.3)
#>    promises    2021-02-11 [2] CRAN (R 4.2.3)
#>    ps             1.7.5      2023-04-18 [1] CRAN (R 4.2.3)
#>    purrr        * 1.0.1      2023-01-10 [1] CRAN (R 4.2.2)
#>    R6             2.5.1      2021-08-19 [2] CRAN (R 4.2.3)
#>    Rcpp         * 1.0.10     2023-01-22 [1] CRAN (R 4.2.2)
#>  D RcppParallel   5.1.7      2023-02-27 [1] CRAN (R 4.2.3)
#>    recipes      * 1.0.6      2023-04-25 [2] CRAN (R 4.2.3)
#>    reprex         2.0.2      2022-08-17 [1] CRAN (R 4.2.3)
#>    reshape2       1.4.4      2020-04-09 [1] CRAN (R 4.2.3)
#>    rlang          1.1.1      2023-04-28 [1] CRAN (R 4.2.3)
#>    rmarkdown      2.21       2023-03-26 [2] CRAN (R 4.2.3)
#>    rpart          4.1.19     2022-10-21 [2] CRAN (R 4.2.3)
#>    rsample      * 1.1.1      2022-12-07 [2] CRAN (R 4.2.3)
#>    rstan          2.26.15    2023-02-11 [1] local
#>    rstanarm     * 2.21.4     2023-04-08 [1] CRAN (R 4.2.3)
#>    rstantools     2.3.1      2023-03-30 [1] CRAN (R 4.2.3)
#>    rstudioapi     0.14       2022-08-22 [2] CRAN (R 4.2.3)
#>    scales       * 1.2.1      2022-08-20 [2] CRAN (R 4.2.3)
#>    sessioninfo    1.2.2      2021-12-06 [2] CRAN (R 4.2.3)
#>    shiny          1.7.4      2022-12-15 [2] CRAN (R 4.2.3)
#>    shinyjs        2.1.0      2021-12-23 [1] CRAN (R 4.2.3)
#>    shinystan      2.6.0      2022-03-03 [1] CRAN (R 4.2.3)
#>    shinythemes    1.2.0      2021-01-25 [1] CRAN (R 4.2.3)
#>    StanHeaders    2.26.15    2023-02-11 [1] local
#>    stringi        1.7.12     2023-01-11 [1] CRAN (R 4.2.2)
#>    stringr        1.5.0      2022-12-02 [1] CRAN (R 4.2.2)
#>    survival       3.5-3      2023-02-12 [2] CRAN (R 4.2.3)
#>    threejs        0.3.3      2020-01-21 [1] CRAN (R 4.2.3)
#>    tibble       * 3.2.1      2023-03-20 [1] CRAN (R 4.2.3)
#>    tidymodels   * 1.1.0      2023-05-01 [1] CRAN (R 4.2.3)
#>    tidyr        * 1.3.0      2023-01-24 [1] CRAN (R 4.2.2)
#>    tidyselect     1.2.0      2022-10-10 [2] CRAN (R 4.2.3)
#>    timechange     0.2.0      2023-01-11 [1] CRAN (R 4.2.2)
#>    timeDate       4022.108   2023-01-07 [2] CRAN (R 4.2.3)
#>    tune         * 1.1.1      2023-04-11 [2] CRAN (R 4.2.3)
#>    utf8           1.2.3      2023-01-31 [1] CRAN (R 4.2.2)
#>    V8             4.3.0      2023-04-08 [1] CRAN (R 4.2.3)
#>    vctrs          0.6.2      2023-04-19 [1] CRAN (R 4.2.3)
#>    withr          2.5.0      2022-03-03 [2] CRAN (R 4.2.3)
#>    workflows    * 1.1.3      2023-02-22 [2] CRAN (R 4.2.3)
#>    workflowsets * 1.0.1      2023-04-06 [2] CRAN (R 4.2.3)
#>    xfun           0.39       2023-04-20 [1] CRAN (R 4.2.3)
#>    xtable         1.8-4      2019-04-21 [2] CRAN (R 4.2.3)
#>    xts            0.13.1     2023-04-16 [1] CRAN (R 4.2.3)
#>    yaml           2.3.7      2023-01-23 [2] CRAN (R 4.2.3)
#>    yardstick    * 1.2.0      2023-04-21 [2] CRAN (R 4.2.3)
#>    zoo            1.8-12     2023-04-13 [1] CRAN (R 4.2.3)
#>  [1] C:/Users/00055815/AppData/Local/R/win-library/4.2
#>  [2] C:/Program Files/R/R-4.2.3/library
#>  D ── DLL MD5 mismatch, broken installation.
#> ──────────────────────────────────────────────────────────────────────────────

hfrick commented Jun 5, 2023

Thanks for the issue @jdberson! I think you're right, this should be possible.

dbart_predict_calc() has a std_err argument, but it is neither used inside the function nor passed the std_error value from predict(). I'd take this as an indication that the intention was there but it hasn't happened yet!
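
As a rough illustration of the two missing pieces described above (forwarding the flag and acting on it), here is a base-R sketch. `calc_sketch()` and `predict_sketch()` are hypothetical names invented for this example; this is not the parsnip implementation or the change made in the linked fix.

```r
# Hypothetical inner helper: summarises a draws-by-observations matrix.
# The std_err flag is actually consulted here.
calc_sketch <- function(draws, std_err = FALSE) {
  res <- data.frame(.pred = colMeans(draws))
  if (std_err) {
    res$.std_error <- apply(draws, 2, stats::sd)
  }
  res
}

# Hypothetical outer wrapper: forwards the user's std_error flag
# to the inner helper instead of dropping it.
predict_sketch <- function(draws, std_error = FALSE) {
  calc_sketch(draws, std_err = std_error)
}

# Toy input: 3 draws (rows) for 2 observations (columns).
draws <- matrix(c(0.1, 0.2, 0.3, 0.6, 0.7, 0.8), nrow = 3)

without_se <- predict_sketch(draws)
with_se    <- predict_sketch(draws, std_error = TRUE)
names(without_se)
names(with_se)
```

With the flag left at its default, only the point prediction column is returned; setting `std_error = TRUE` adds the `.std_error` column.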

hfrick added the "feature" label (a feature request or enhancement) on Jun 5, 2023

jdberson commented Jun 6, 2023

Hi @hfrick thanks very much for looking into this and for making the changes so quickly! That is absolutely brilliant.


This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex) and link to this issue.

github-actions bot locked and limited conversation to collaborators on Jun 23, 2023