In the following, we explain the counterfactuals workflow for both a classification and a regression task using concrete use cases.
The Predictor class of the iml package provides the flexibility needed to cover classification and regression models fitted with diverse R packages. In the introduction vignette, we saw models fitted with the mlr3 and randomForest packages. In the following, we show extensions to a regression tree fitted with the caret package, the mlr package (a predecessor of mlr3) and tidymodels, as well as a tree fitted directly with rpart. For each model, we generate counterfactuals for the first row of the BostonHousing dataset using the NICE method.
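This flexibility rests on the fact that a Predictor only needs a way to obtain predictions from the model. For model classes that iml does not handle out of the box, a prediction function taking the model and new data can be passed via the `predict.function` argument of `Predictor$new()`. The following is a minimal sketch that wraps a plain rpart tree this way. The helper name `pred_fun` is our own choice, and because rpart is in fact supported natively, the custom function serves only as an illustration.

```r
library("iml")
library("rpart")
data("BostonHousing", package = "mlbench")

# Fit a regression tree on all rows except the first one
fit = rpart(medv ~ ., data = BostonHousing[-1L, ])

# Hypothetical helper: first argument is the model, second the new data;
# it returns the predictions as a plain numeric vector
pred_fun = function(model, newdata) {
  as.numeric(predict(model, newdata = newdata))
}

# Wrap the model together with its training data and target name
pred_custom = Predictor$new(model = fit, data = BostonHousing[-1L, ],
  y = "medv", predict.function = pred_fun)

# Predict the held-out first row; iml returns a data.frame
out_custom = pred_custom$predict(BostonHousing[1L, ])
print(out_custom)
```

Any model whose predictions can be expressed through such a function can then be handed to the counterfactual methods in the same way as the models below.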
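library("counterfactuals")
library("iml")
library("caret")
# x_interest: the first row of BostonHousing; all models are fitted on the remaining rows
data("BostonHousing", package = "mlbench")
x_interest = BostonHousing[1L, ]
treecaret = caret::train(medv ~ ., data = BostonHousing[-1L, ], method = "rpart",
  tuneGrid = data.frame(cp = 0.01))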
predcaret = Predictor$new(model = treecaret, data = BostonHousing[-1L,], y = "medv")
predcaret$predict(x_interest)
#> .prediction
#> 1 27.49074
nicecaret = NICERegr$new(predcaret, optimization = "plausibility",
margin_correct = 0.5, return_multiple = FALSE)
nicecaret$find_counterfactuals(x_interest, desired_outcome = c(30, 40))
#> 1 Counterfactual(s)
#>
#> Desired outcome range: [30, 40]
#>
#> Head:
#> crim zn indus chas nox rm age dis rad tax ptratio b lstat
#> 1: 0.07503 33 2.18 0 0.472 7.42 71.9 3.0992 7 222 18.4 396.9 6.47
library("tidymodels")
treetm = decision_tree(mode = "regression", engine = "rpart") %>%
fit(medv ~ ., data = BostonHousing[-1L,])
predtm = Predictor$new(model = treetm, data = BostonHousing[-1L,], y = "medv")
predtm$predict(x_interest)
#> .pred
#> 1 27.49074
nicetm = NICERegr$new(predtm, optimization = "plausibility",
margin_correct = 0.5, return_multiple = FALSE)
nicetm$find_counterfactuals(x_interest, desired_outcome = c(30, 40))
#> 1 Counterfactual(s)
#>
#> Desired outcome range: [30, 40]
#>
#> Head:
#> crim zn indus chas nox rm age dis rad tax ptratio b lstat
#> 1: 0.07503 33 2.18 0 0.472 7.42 71.9 3.0992 7 222 18.4 396.9 6.47
library("mlr")
task = mlr::makeRegrTask(data = BostonHousing[-1L,], target = "medv")
mod = makeLearner("regr.rpart")
# Use the mlr namespace explicitly, since caret also exports a train() function
treemlr = mlr::train(mod, task)
predmlr = Predictor$new(model = treemlr, data = BostonHousing[-1L,], y = "medv")
predmlr$predict(x_interest)
#> .prediction
#> 1 27.49074
nicemlr = NICERegr$new(predmlr, optimization = "plausibility",
margin_correct = 0.5, return_multiple = FALSE)
nicemlr$find_counterfactuals(x_interest, desired_outcome = c(30, 40))
#> 1 Counterfactual(s)
#>
#> Desired outcome range: [30, 40]
#>
#> Head:
#> crim zn indus chas nox rm age dis rad tax ptratio b lstat
#> 1: 0.07503 33 2.18 0 0.472 7.42 71.9 3.0992 7 222 18.4 396.9 6.47
library("rpart")
treerpart = rpart(medv ~ ., data = BostonHousing[-1L,])
predrpart = Predictor$new(model = treerpart, data = BostonHousing[-1L,], y = "medv")
predrpart$predict(x_interest)
#> pred
#> 1 27.49074
nicerpart = NICERegr$new(predrpart, optimization = "plausibility",
margin_correct = 0.5, return_multiple = FALSE)
nicerpart$find_counterfactuals(x_interest, desired_outcome = c(30, 40))
#> 1 Counterfactual(s)
#>
#> Desired outcome range: [30, 40]
#>
#> Head:
#> crim zn indus chas nox rm age dis rad tax ptratio b lstat
#> 1: 0.07503 33 2.18 0 0.472 7.42 71.9 3.0992 7 222 18.4 396.9 6.47
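A natural sanity check is to re-predict the returned counterfactual with the same Predictor and confirm that it falls within the desired range [30, 40]. The sketch below does this for the rpart model. It repeats the setup so that it runs on its own, and it assumes that the `data` field of the returned Counterfactuals object holds the counterfactual rows; the variable names are our own.

```r
library("counterfactuals")
library("iml")
library("rpart")
data("BostonHousing", package = "mlbench")

# Same setup as above: fit on all rows but the first, explain the first row
x_interest = BostonHousing[1L, ]
tree = rpart(medv ~ ., data = BostonHousing[-1L, ])
pred = Predictor$new(model = tree, data = BostonHousing[-1L, ], y = "medv")

nice = NICERegr$new(pred, optimization = "plausibility",
  margin_correct = 0.5, return_multiple = FALSE)
cfactuals = nice$find_counterfactuals(x_interest, desired_outcome = c(30, 40))

# Re-predict the counterfactual(s) and check them against the desired range
cf_pred = pred$predict(cfactuals$data)[[1L]]
in_range = cf_pred >= 30 & cf_pred <= 40
print(in_range)
```

Since all four models above yield the same counterfactual, running this check for one of them is sufficient here.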