Skip to content

Commit 9fd3604

Browse files
fix: use modelbased predictions for bayesian check_predictions
Agent-Logs-Url: https://github.com/easystats/performance/sessions/8f6f86ac-dce2-4dd7-ab79-6115b8855b4d Co-authored-by: DominiqueMakowski <8875533+DominiqueMakowski@users.noreply.github.com>
1 parent 2fbf9d6 commit 9fd3604

File tree

5 files changed

+95
-45
lines changed

5 files changed

+95
-45
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
Type: Package
22
Package: performance
33
Title: Assessment of Regression Models Performance
4-
Version: 0.16.0
4+
Version: 0.16.0.1
55
Authors@R:
66
c(person(given = "Daniel",
77
family = "Lüdecke",

NEWS.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
# performance 0.16.0.1
2+
3+
## Bug fixes
4+
5+
* `check_predictions()` for Bayesian models now uses
6+
`modelbased::estimate_prediction()` and returns posterior predictive data in
7+
the same format as for other supported models.
8+
19
# performance 0.16.0
210

311
## Breaking Changes

R/check_predictions.R

Lines changed: 38 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
#'
1010
#' **performance** provides posterior predictive check methods for a variety
1111
#' of frequentist models (e.g., `lm`, `merMod`, `glmmTMB`, ...). For Bayesian
12-
#' models, the model is passed to [`bayesplot::pp_check()`].
12+
#' models, posterior predictions are computed with
13+
#' `modelbased::estimate_prediction()` and plotted with the same machinery as
14+
#' for other supported models.
1315
#'
1416
#' If `check_predictions()` doesn't work as expected, try setting
1517
#' `verbose = TRUE` to get hints about possible problems.
@@ -38,7 +40,8 @@
3840
#' @param verbose Toggle warnings.
3941
#' @param ... Additional arguments passed on to downstream functions. For
4042
#' frequentist models, these are forwarded to `simulate()`; for Bayesian models
41-
#' (e.g., `stanreg`, `brmsfit`), they are forwarded to `bayesplot::pp_check()`.
43+
#' (e.g., `stanreg`, `brmsfit`), they are forwarded to
44+
#' `modelbased::estimate_prediction()`.
4245
#' @param object Deprecated, please use `model` instead.
4346
#'
4447
#' @return A data frame of simulated responses and the original response vector.
@@ -203,57 +206,50 @@ check_predictions.stanreg <- function(
203206
c("density", "discrete_dots", "discrete_interval", "discrete_both")
204207
)
205208

206-
# convert to type-argument for pp_check
207-
pp_type <- switch(type, density = "dens", "bars")
208-
209209
insight::check_if_installed(
210-
"bayesplot",
211-
"to create posterior prediction plots for Stan models"
210+
"modelbased",
211+
"to create posterior predictive checks for Bayesian models"
212212
)
213213

214-
# for plotting
215-
resp_string <- insight::find_terms(model)$response
214+
out <- modelbased::estimate_prediction(
215+
model,
216+
iterations = iterations,
217+
keep_iterations = TRUE,
218+
re_formula = re_formula,
219+
verbose = verbose,
220+
...
221+
)
216222

217-
if (inherits(model, "brmsfit")) {
218-
out <- as.data.frame(
219-
bayesplot::pp_check(model, type = pp_type, ndraws = iterations, ...)$data
220-
)
221-
} else {
222-
out <- as.data.frame(
223-
bayesplot::pp_check(model, plotfun = pp_type, nreps = iterations, ...)$data
223+
iter_columns <- startsWith(colnames(out), "iter_")
224+
if (!any(iter_columns)) {
225+
insight::format_error(
226+
"Could not retrieve posterior predictive draws for the Bayesian model."
224227
)
225228
}
226229

227-
# bring data into shape, like we have for other models with `check_predictions()`
228-
if (pp_type == "dens") {
229-
d_filter <- out[!out$is_y, ]
230-
d_filter <- datawizard::data_to_wide(
231-
d_filter,
232-
id_cols = "y_id",
233-
values_from = "value",
234-
names_from = "rep_id"
235-
)
236-
d_filter$y_id <- NULL
237-
colnames(d_filter) <- paste0("sim_", colnames(d_filter))
238-
d_filter$y <- out$value[out$is_y]
239-
out <- d_filter
240-
} else {
241-
colnames(out) <- c("x", "y", "CI_low", "Mean", "CI_high")
242-
# to long, for plotting
243-
out <- datawizard::data_to_long(
244-
out,
245-
select = c("y", "Mean"),
246-
names_to = "Group",
247-
values_to = "Count"
248-
)
230+
out <- as.data.frame(out[iter_columns])
231+
colnames(out) <- sub("^iter_", "sim_", colnames(out))
232+
233+
resp_string <- insight::find_terms(model)$response
234+
pattern <- "^(scale|exp|expm1|log|log1p|log10|log2|sqrt)"
235+
236+
if (
237+
!is.null(resp_string) &&
238+
length(resp_string) == 1 &&
239+
grepl(paste0(pattern, "\\("), resp_string)
240+
) {
241+
out <- .backtransform_sims(out, resp_string)
249242
}
250243

251-
# make x cateogorical for bernoulli/categorical/multinomial models
252-
if (minfo$is_bernoulli || minfo$is_categorical || minfo$is_multinomial) {
253-
out$x <- as.factor(out$x)
244+
response <- insight::get_response(model)
245+
if (is.data.frame(response)) {
246+
response <- eval(
247+
str2lang(insight::find_response(model)),
248+
envir = insight::get_response(model)
249+
)
254250
}
251+
out$y <- response
255252

256-
attr(out, "is_stan") <- TRUE
257253
attr(out, "check_range") <- check_range
258254
attr(out, "response_name") <- resp_string
259255
attr(out, "bandwidth") <- bandwidth

man/check_predictions.Rd

Lines changed: 5 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-check_predictions.R

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,3 +454,46 @@ test_that("check_predictions, glmer, works with proportion and cbind binomial an
454454
)
455455
)
456456
})
457+
458+
459+
test_that("check_predictions, Bayesian models use standard predictive-check data", {
460+
skip_if_not_installed("modelbased", minimum_version = "0.12.0")
461+
skip_if_not_installed("curl")
462+
skip_if_not_installed("httr2")
463+
skip_if_offline()
464+
465+
model <- insight::download_model("stanreg_lm_1")
466+
skip_if(is.null(model))
467+
468+
set.seed(123)
469+
out <- check_predictions(model, iterations = 5)
470+
471+
expect_s3_class(out, "performance_pp_check")
472+
expect_named(out, c(paste0("sim_", 1:5), "y"))
473+
expect_false("x" %in% names(out))
474+
expect_false("Group" %in% names(out))
475+
expect_false(isTRUE(attr(out, "is_stan")))
476+
expect_identical(attr(out, "type"), "density")
477+
})
478+
479+
480+
test_that("check_predictions, Bayesian discrete models use standard predictive-check data", {
481+
skip_if_not_installed("modelbased", minimum_version = "0.12.0")
482+
skip_if_not_installed("curl")
483+
skip_if_not_installed("httr2")
484+
skip_if_offline()
485+
486+
model <- insight::download_model("brms_ordinal_1")
487+
skip_if(is.null(model))
488+
489+
set.seed(123)
490+
out <- check_predictions(model, iterations = 5, type = "discrete_interval")
491+
492+
expect_s3_class(out, "performance_pp_check")
493+
expect_named(out, c(paste0("sim_", 1:5), "y"))
494+
expect_false("x" %in% names(out))
495+
expect_false("Group" %in% names(out))
496+
expect_false(isTRUE(attr(out, "is_stan")))
497+
expect_identical(attr(out, "type"), "discrete_interval")
498+
expect_true(attr(out, "model_info")$is_ordinal)
499+
})

0 commit comments

Comments
 (0)