Skip to content

Commit

Permalink
Merge pull request #1604 from venpopov/get_prior_s3
Browse files Browse the repository at this point in the history
Proposal: transform get_prior, make_stancode and make_standata into S3 methods
  • Loading branch information
paul-buerkner authored Mar 3, 2024
2 parents 81d9233 + 354be0d commit 2c00f71
Show file tree
Hide file tree
Showing 44 changed files with 1,191 additions and 928 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ Package: brms
Encoding: UTF-8
Type: Package
Title: Bayesian Regression Models using 'Stan'
Version: 2.20.13
Version: 2.20.14
Date: 2024-02-27
Authors@R:
c(person("Paul-Christian", "Bürkner", email = "paul.buerkner@gmail.com",
Expand Down Expand Up @@ -98,4 +98,4 @@ Additional_repositories:
VignetteBuilder:
knitr,
R.rsp
RoxygenNote: 7.2.3
RoxygenNote: 7.3.1
5 changes: 5 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ S3method(data_response,brmsterms)
S3method(data_response,mvbrmsterms)
S3method(def_scale_prior,brmsterms)
S3method(def_scale_prior,mvbrmsterms)
S3method(default_prior,brmsfit)
S3method(default_prior,default)
S3method(dpar_family,default)
S3method(dpar_family,mixfamily)
S3method(duplicated,brmsprior)
Expand Down Expand Up @@ -249,7 +251,9 @@ S3method(stan_predictor,btl)
S3method(stan_predictor,btnl)
S3method(stan_predictor,mvbrmsterms)
S3method(stancode,brmsfit)
S3method(stancode,default)
S3method(standata,brmsfit)
S3method(standata,default)
S3method(standata_basis,brmsterms)
S3method(standata_basis,btl)
S3method(standata_basis,btnl)
Expand Down Expand Up @@ -376,6 +380,7 @@ export(data_predictor)
export(data_response)
export(dbeta_binomial)
export(ddirichlet)
export(default_prior)
export(density_ratio)
export(dexgaussian)
export(dfrechet)
Expand Down
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ if potentially results-changing arguments are provided to the criterion method.

### Other Changes

* Change `make_stancode` and `make_standata` to be aliases of `stancode` and
`standata`, respectively. Change `get_prior` to be an alias of a new generic
method `default_prior`. This enable other packages to define new `stancode`,
`standata` and `default_prior` methods to generate Stan code and data, and extract
the default priors, for their own objects building on brms. Thanks to Ven Popov
for helping with this. (#1604)
* No longer automatically canonicalize the Stan code if cmdstanr is used
as backend. (#1544)
* Improve parameter class names in the `summary` output.
Expand Down
24 changes: 12 additions & 12 deletions R/brm.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
#' \code{family} might also be a list of families.
#' @param prior One or more \code{brmsprior} objects created by
#' \code{\link{set_prior}} or related functions and combined using the
#' \code{c} method or the \code{+} operator. See also \code{\link{get_prior}}
#' \code{c} method or the \code{+} operator. See also \code{\link[brms:default_prior.default]{default_prior}}
#' for more help.
#' @param data2 A named \code{list} of objects containing data, which
#' cannot be passed via argument \code{data}. Required for some objects
Expand Down Expand Up @@ -271,7 +271,7 @@
#' \code{\link[brms:set_prior]{set_prior}} function. Its documentation
#' contains detailed information on how to correctly specify priors. To find
#' out on which parameters or parameter classes priors can be defined, use
#' \code{\link[brms:get_prior]{get_prior}}. Default priors are chosen to be
#' \code{\link[brms:default_prior.default]{default_prior}}. Default priors are chosen to be
#' non or very weakly informative so that their influence on the results will
#' be negligible and you usually don't have to worry about them. However,
#' after getting more familiar with Bayesian statistics, I recommend you to
Expand Down Expand Up @@ -318,12 +318,12 @@
#' @examples
#' \dontrun{
#' # Poisson regression for the number of seizures in epileptic patients
#' # using normal priors for population-level effects
#' # and half-cauchy priors for standard deviations of group-level effects
#' prior1 <- prior(normal(0, 10), class = b) +
#' prior(cauchy(0, 2), class = sd)
#' fit1 <- brm(count ~ zBase * Trt + (1|patient), data = epilepsy,
#' family = poisson(), prior = prior1)
#' fit1 <- brm(
#' count ~ zBase * Trt + (1|patient),
#' data = epilepsy, family = poisson(),
#' prior = prior(normal(0, 10), class = b) +
#' prior(cauchy(0, 2), class = sd)
#' )
#'
#' # generate a summary of the results
#' summary(fit1)
Expand Down Expand Up @@ -418,8 +418,8 @@
#'
#'
#' # fit a model manually via rstan
#' scode <- make_stancode(count ~ Trt, data = epilepsy)
#' sdata <- make_standata(count ~ Trt, data = epilepsy)
#' scode <- stancode(count ~ Trt, data = epilepsy)
#' sdata <- standata(count ~ Trt, data = epilepsy)
#' stanfit <- rstan::stan(model_code = scode, data = sdata)
#' # feed the Stan model back into brms
#' fit8 <- brm(count ~ Trt, data = epilepsy, empty = TRUE)
Expand Down Expand Up @@ -537,7 +537,7 @@ brm <- function(formula, data, family = gaussian(), prior = NULL,
)
ranef <- tidy_ranef(bterms, data = data)
# generate Stan code
model <- .make_stancode(
model <- .stancode(
bterms, data = data, prior = prior,
stanvars = stanvars, save_model = save_model,
backend = backend, threads = threads, opencl = opencl,
Expand All @@ -556,7 +556,7 @@ brm <- function(formula, data, family = gaussian(), prior = NULL,
exclude <- exclude_pars(x)
# generate Stan data before compiling the model to avoid
# unnecessary compilations in case of invalid data
sdata <- .make_standata(
sdata <- .standata(
bterms, data = data, prior = prior, data2 = data2,
stanvars = stanvars, threads = threads
)
Expand Down
4 changes: 2 additions & 2 deletions R/brms-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
#' formula syntax to specify a wide range of complex Bayesian models
#' (see \code{\link{brmsformula}} for details). Based on the supplied
#' formulas, data, and additional information, it writes the Stan code
#' on the fly via \code{\link{make_stancode}}, prepares the data via
#' \code{\link{make_standata}}, and fits the model using
#' on the fly via \code{\link[brms:stancode.default]{stancode}}, prepares the data via
#' \code{\link[brms:standata.default]{standata}} and fits the model using
#' \pkg{\link[rstan:rstan]{Stan}}.
#'
#' Subsequently, a large number of post-processing methods can be applied:
Expand Down
4 changes: 2 additions & 2 deletions R/brmsfit-helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -871,9 +871,9 @@ validate_cores_post_processing <- function(cores) {
#' and scripts should not use it.
#'
#' @param fit Old \code{brmsfit} object (e.g., loaded from file).
#' @param sdata New Stan data (result of a call to \code{\link{make_standata}}).
#' @param sdata New Stan data (result of a call to \code{\link[brms:standata.default]{standata}}).
#' Pass \code{NULL} to avoid this data check.
#' @param scode New Stan code (result of a call to \code{\link{make_stancode}}).
#' @param scode New Stan code (result of a call to \code{\link[brms:stancode.default]{stancode}}).
#' Pass \code{NULL} to avoid this code check.
#' @param data New data to check consistency of factor level names.
#' Pass \code{NULL} to avoid this data check.
Expand Down
2 changes: 1 addition & 1 deletion R/families.R
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,7 @@ acat <- function(link = "logit", link_disc = "log",
#' pp_check(fit4)
#'
#' ## compare model fit
#' LOO(fit1, fit2, fit3, fit4)
#' loo(fit1, fit2, fit3, fit4)
#' }
#'
#' @export
Expand Down
2 changes: 1 addition & 1 deletion R/formula-gp.R
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@
#' plot(me3, ask = FALSE, points = TRUE)
#'
#' # compare model fit
#' LOO(fit1, fit2, fit3)
#' loo(fit1, fit2, fit3)
#'
#' # simulate data with a factor covariate
#' dat2 <- mgcv::gamSim(4, n = 90, scale = 2)
Expand Down
132 changes: 85 additions & 47 deletions R/make_stancode.R
Original file line number Diff line number Diff line change
@@ -1,44 +1,89 @@
#' @title Stan Code for Bayesian models
#'
#' @description \code{stancode} is a generic function that can be used to
#' generate Stan code for Bayesian models. It's original use is
#' within the \pkg{brms} package, but new methods for use
#' with objects from other packages can be registered to the same generic.
#'
#' @param object An object whose class will determine which method to apply.
#' Usually, it will be some kind of symbolic description of the model
#' form which Stan code should be generated.
#' @param formula Synonym of \code{object} for use in \code{make_stancode}.
#' @param ... Further arguments passed to the specific method.
#'
#' @return Usually, a character string containing the generated Stan code.
#' For pretty printing, we recommend the returned object to be of class
#' \code{c("character", "brmsmodel")}.
#'
#' @details
#' See \code{\link[brms:stancode.default]{stancode.default}} for the default
#' method applied for \pkg{brms} models.
#' You can view the available methods by typing: \code{methods(stancode)}
#' The \code{make_stancode} function is an alias of \code{stancode}.
#'
#' @seealso
#' \code{\link{stancode.default}}, \code{\link{stancode.brmsfit}}
#'
#' @examples
#' stancode(rating ~ treat + period + carry + (1|subject),
#' data = inhaler, family = "cumulative")
#'
#' @export
stancode <- function(object, ...) {
UseMethod("stancode")
}

#' @rdname stancode
#' @export
make_stancode <- function(formula, ...) {
# became an alias of 'stancode' in 2.20.14
stancode(formula, ...)
}

#' Stan Code for \pkg{brms} Models
#'
#' Generate Stan code for \pkg{brms} models
#'
#' @inheritParams brm
#' @param object An object of class \code{\link[stats:formula]{formula}},
#' \code{\link{brmsformula}}, or \code{\link{mvbrmsformula}} (or one that can
#' be coerced to that classes): A symbolic description of the model to be
#' fitted. The details of model specification are explained in
#' \code{\link{brmsformula}}.
#' @param ... Other arguments for internal usage only.
#'
#' @return A character string containing the fully commented \pkg{Stan} code
#' to fit a \pkg{brms} model.
#' to fit a \pkg{brms} model. It is of class \code{c("character", "brmsmodel")}
#' to facilitate pretty printing.
#'
#' @examples
#' make_stancode(rating ~ treat + period + carry + (1|subject),
#' data = inhaler, family = "cumulative")
#' stancode(rating ~ treat + period + carry + (1|subject),
#' data = inhaler, family = "cumulative")
#'
#' make_stancode(count ~ zAge + zBase * Trt + (1|patient),
#' data = epilepsy, family = "poisson")
#' stancode(count ~ zAge + zBase * Trt + (1|patient),
#' data = epilepsy, family = "poisson")
#'
#' @export
make_stancode <- function(formula, data, family = gaussian(),
prior = NULL, autocor = NULL, data2 = NULL,
cov_ranef = NULL, sparse = NULL,
sample_prior = "no", stanvars = NULL,
stan_funs = NULL, knots = NULL,
drop_unused_levels = TRUE,
threads = getOption("brms.threads", NULL),
normalize = getOption("brms.normalize", TRUE),
save_model = NULL, ...) {
stancode.default <- function(object, data, family = gaussian(),
prior = NULL, autocor = NULL, data2 = NULL,
cov_ranef = NULL, sparse = NULL,
sample_prior = "no", stanvars = NULL,
stan_funs = NULL, knots = NULL,
drop_unused_levels = TRUE,
threads = getOption("brms.threads", NULL),
normalize = getOption("brms.normalize", TRUE),
save_model = NULL, ...) {

if (is.brmsfit(formula)) {
stop2("Use 'stancode' to extract Stan code from 'brmsfit' objects.")
}
formula <- validate_formula(
formula, data = data, family = family,
object <- validate_formula(
object, data = data, family = family,
autocor = autocor, sparse = sparse,
cov_ranef = cov_ranef
)
bterms <- brmsterms(formula)
bterms <- brmsterms(object)
data2 <- validate_data2(
data2, bterms = bterms,
get_data2_autocor(formula),
get_data2_cov_ranef(formula)
get_data2_autocor(object),
get_data2_cov_ranef(object)
)
data <- validate_data(
data, bterms = bterms,
Expand All @@ -52,24 +97,25 @@ make_stancode <- function(formula, data, family = gaussian(),
stanvars <- validate_stanvars(stanvars, stan_funs = stan_funs)
threads <- validate_threads(threads)

.make_stancode(
.stancode(
bterms, data = data, prior = prior,
stanvars = stanvars, threads = threads,
normalize = normalize, save_model = save_model,
...
)
}

# internal work function of 'make_stancode'
# internal work function of 'stancode.default'
# @param parse parse the Stan model for automatic syntax checking
# @param backend name of the backend used for parsing
# @param silent silence parsing messages
.make_stancode <- function(bterms, data, prior, stanvars,
.stancode <- function(bterms, data, prior, stanvars,
threads = threading(),
normalize = getOption("brms.normalize", TRUE),
parse = getOption("brms.parse_stancode", FALSE),
backend = getOption("brms.backend", "rstan"),
silent = TRUE, save_model = NULL, ...) {

normalize <- as_one_logical(normalize)
parse <- as_one_logical(parse)
backend <- match.arg(backend, backend_choices())
Expand Down Expand Up @@ -329,31 +375,29 @@ print.brmsmodel <- function(x, ...) {
invisible(x)
}

#' Extract Stan model code
#'
#' Extract Stan code that was used to specify the model.
#' Extract Stan code from \code{brmsfit} objects
#'
#' @aliases stancode.brmsfit
#' Extract Stan code from a fitted \pkg{brms} model.
#'
#' @param object An object of class \code{brmsfit}.
#' @param version Logical; indicates if the first line containing
#' the \pkg{brms} version number should be included.
#' Defaults to \code{TRUE}.
#' @param regenerate Logical; indicates if the Stan code should
#' be regenerated with the current \pkg{brms} version.
#' By default, \code{regenerate} will be \code{FALSE} unless required
#' to be \code{TRUE} by other arguments.
#' @param threads Controls whether the Stan code should be threaded.
#' See \code{\link{threading}} for details.
#' @param version Logical; indicates if the first line containing the \pkg{brms}
#' version number should be included. Defaults to \code{TRUE}.
#' @param regenerate Logical; indicates if the Stan code should be regenerated
#' with the current \pkg{brms} version. By default, \code{regenerate} will be
#' \code{FALSE} unless required to be \code{TRUE} by other arguments.
#' @param threads Controls whether the Stan code should be threaded. See
#' \code{\link{threading}} for details.
#' @param backend Controls the Stan backend. See \code{\link{brm}} for details.
#' @param ... Further arguments passed to \code{\link{make_stancode}} if the
#' Stan code is regenerated.
#' @param ... Further arguments passed to
#' \code{\link[brms:stancode.default]{stancode}} if the Stan code is
#' regenerated.
#'
#' @return Stan model code for further processing.
#' @return Stan code for further processing.
#'
#' @export
stancode.brmsfit <- function(object, version = TRUE, regenerate = NULL,
threads = NULL, backend = NULL, ...) {

if (is.null(regenerate)) {
# determine whether regenerating the Stan code is required
regenerate <- FALSE
Expand Down Expand Up @@ -400,12 +444,6 @@ stancode.brmsfit <- function(object, version = TRUE, regenerate = NULL,
out
}

#' @rdname stancode.brmsfit
#' @export
stancode <- function(object, ...) {
UseMethod("stancode")
}

# expand '#include' statements
# This could also be done automatically by Stan at compilation time
# but would result in Stan code that is not self-contained until compilation
Expand Down
Loading

0 comments on commit 2c00f71

Please sign in to comment.