| Type: | Package |
| Title: | Optimal Transport Weights for Causal Inference |
| Version: | 1.0.2 |
| Date: | 2024-02-17 |
| Author: | Eric Dunipace |
| Maintainer: | Eric Dunipace <edunipace@mail.harvard.edu> |
| Description: | Uses optimal transport distances to find probabilistic matching estimators for causal inference. These methods are described in Dunipace, Eric (2021) <doi:10.48550/arXiv.2109.01991>. The package will build the weights, estimate treatment effects, and calculate confidence intervals via the methods described in the paper. The package also supports several other methods as described in the help files. |
| License: | GPL (== 3.0) |
| Imports: | CBPS, ggplot2, lbfgsb3c, loo, Matrix (≥ 1.5-0), matrixStats, methods, osqp, R6 (≥ 2.4.1), Rcpp (≥ 1.0.3), rlang, sandwich, torch, utils |
| LinkingTo: | BH (≥ 1.66.0), Rcpp (≥ 0.12.0), RcppEigen (≥ 0.3.3.3.0), torch |
| Suggests: | data.table (≥ 1.12.8), testthat (≥ 2.1.0), knitr, reticulate, rkeops (≥ 2.2.2), rmarkdown, V8, withr |
| Additional_repositories: | https://ericdunipace.github.io/drat/ |
| Biarch: | true |
| Depends: | R (≥ 3.5.0) |
| Encoding: | UTF-8 |
| RoxygenNote: | 7.3.1 |
| LazyData: | true |
| VignetteBuilder: | knitr |
| Collate: | 'DataSimClass.R' 'dataHolder.R' 'weightsClass.R' 'ESS.R' 'OT.R' 'PSIS.R' 'RcppExports.R' 'balanceFunctions.R' 'barycentricProjection.R' 'calc_weight.R' 'causalOT-package.R' 'cost_functions.R' 'scmClass.R' 'gridSearch.R' 'cotClass.R' 'cotOOP.R' 'cot_opts.R' 'likelihoodClass.R' 'mean_balance.R' 'summary.R' 'supportedMethods.R' 'treatment_effect.R' 'utils.R' 'zzz.R' |
| NeedsCompilation: | yes |
| Packaged: | 2024-02-18 21:20:35 UTC; eifer |
| Repository: | CRAN |
| Date/Publication: | 2024-02-18 22:50:08 UTC |
An R package to perform causal inference using optimal transport distances.
Description
R code to perform causal inference weighting using a variety of methods and optimizers. The code can estimate weights, estimate treatment effects, and also give variance estimates. These methods are described in Dunipace, Eric (2021) https://arxiv.org/abs/2109.01991.
Author(s)
Eric Dunipace
CRASH3 data example
Description
CRASH3 data example
Details
Returns the CRASH3 data. Note that gen_data() will initialize the fixed data for x and y, but z is generated from Binom(0.5).
Value
Super class
causalOT::DataSim -> CRASH3
Public fields
site_id: The site of the observation in terms of the original RCT.
Methods
Public methods
Inherited methods
Method gen_data()
The site ID for the observations
Draws new treatment indicators. x and y data are fixed.
Usage
CRASH3$gen_data()
Method gen_x()
Sets up the covariate data. This data is fixed.
Usage
CRASH3$gen_x()
Method gen_y()
Sets up the outcome data. This data is fixed.
Usage
CRASH3$gen_y()
Method gen_z()
Sets up the treatment indicator. Drawn as Z ~ Binom(0.5)
Usage
CRASH3$gen_z()
Method new()
Initializes the CRASH3 object.
Usage
CRASH3$new(n = NULL, p = NULL, param = list(), design = NA_character_, ...)
Arguments
n: Not used. Maintained for symmetry with other DataSim objects.
p: Not used. Maintained for symmetry with other DataSim objects.
param: Not used. Maintained for symmetry with other DataSim objects.
design: Not used.
...: Not used.
Examples
crash <- CRASH3$new()
crash$gen_data()
crash$get_n()
crash$site_id
Method clone()
The objects of this class are cloneable with this method.
Usage
CRASH3$clone(deep = FALSE)
Arguments
deep: Whether to make a deep clone.
Examples
## ------------------------------------------------
## Method `CRASH3$new`
## ------------------------------------------------
crash <- CRASH3$new()
crash$gen_data()
crash$get_n()
crash$site_id
R6 Data Generating Parent Class
Description
R6 Data Generating Parent Class
Details
Can be used to make your own data simulation class. Should have the same slots listed in this class at a minimum, but you can add your own, of course. An easy way to do this is to make your class inherit from this one. See the example.
Value
An R6 object
Methods
Public methods
Method get_x()
Gets the covariate data
Usage
DataSim$get_x()
Method get_y()
Gets the outcome vector
Usage
DataSim$get_y()
Method get_z()
Gets the treatment indicator
Usage
DataSim$get_z()
Method get_n()
Gets the number of observations
Usage
DataSim$get_n()
Method get_x1()
Gets the covariate data for the treated individuals
Usage
DataSim$get_x1()
Method get_x0()
Gets the covariate data for the control individuals
Usage
DataSim$get_x0()
Method get_p()
Gets the dimensionality of the covariate data
Usage
DataSim$get_p()
Method get_tau()
Gets the individual treatment effects
Usage
DataSim$get_tau()
Method gen_data()
Generates the data. Default is an empty function
Usage
DataSim$gen_data()
Method clone()
The objects of this class are cloneable with this method.
Usage
DataSim$clone(deep = FALSE)
Arguments
deep: Whether to make a deep clone.
Examples
MyClass <- R6::R6Class("MyClass",
inherit = DataSim,
public = list(),
private = list())
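A slightly fuller sketch of the inheritance pattern described in Details is below. The private field names (x, y, z) are assumptions about the parent class internals, so treat this as an illustration rather than a drop-in template.
MySim <- R6::R6Class("MySim",
  inherit = DataSim,
  public = list(
    # override gen_data() to draw a small simulated data set
    gen_data = function() {
      private$x <- matrix(stats::rnorm(100 * 2), 100, 2)
      private$z <- stats::rbinom(100, 1, 0.5)
      private$y <- as.numeric(private$x %*% c(1, -1)) + private$z + stats::rnorm(100)
      invisible(self)
    }
  ),
  private = list(x = NULL, y = NULL, z = NULL))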
Effective Sample Size
Description
Effective Sample Size
Usage
ESS(x)
## S4 method for signature 'numeric'
ESS(x)
## S4 method for signature 'causalWeights'
ESS(x)
Arguments
x |
Either a vector of weights summing to 1 or an object of class causalWeights |
Details
Calculates the effective sample size as described by Kish (1965).
However, this calculation has some problems and the PSIS()
function should be used instead.
Value
Either a number denoting the effective sample size or if x is of class
causalWeights, then returns a list of both values in the treatment
and control groups.
Methods (by class)
- ESS(numeric): default ESS method for numeric vectors
- ESS(causalWeights): ESS method for objects of class causalWeights
See Also
PSIS()
Examples
x <- rep(1/100,100)
ESS(x)
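A short numeric check of the formula described in Details; it assumes the numeric method implements Kish's (1965) effective sample size, (sum of weights)^2 / (sum of squared weights), which reduces to 1 / sum(w^2) when the weights sum to 1.
w <- c(0.5, rep(0.5/99, 99)) # one dominant weight
kish <- 1 / sum(w^2)         # Kish (1965) effective sample size
ESS(w)                       # should agree with `kish`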
Hainmueller data example
Description
Hainmueller data example
Details
Generates the data as described in Hainmueller (2012).
Value
Super class
causalOT::DataSim -> Hainmueller
Methods
Public methods
Inherited methods
Method gen_data()
Generates the data
Usage
Hainmueller$gen_data()
Method gen_x()
Generates the covariate data
Usage
Hainmueller$gen_x()
Method gen_y()
Generates the outcome data
Usage
Hainmueller$gen_y()
Method gen_z()
Generates the treatment indicator
Usage
Hainmueller$gen_z()
Method new()
Generates the Hainmueller R6 class
Usage
Hainmueller$new( n = 100, p = 6, param = list(), design = "A", overlap = "low", ... )
Arguments
n: The number of observations
p: The dimensions of the covariates. Fixed to 6.
param: The data generating parameters fed as a list.
design: One of "A" or "B". See details.
overlap: One of "high", "low", or "medium". See details.
...: Extra arguments. Currently unused.
Details
Design
Design "A" is the setting where the outcome is generated from a linear model,
Y(0) = Y(1) = X_1 + X_2 + X_3 - X_4 + X_5 + X_6 + \eta,
and design "B" is where the outcome is generated from the non-linear model
Y(0) = Y(1) = (X_1 + X_2 + X_5)^2 + \eta.
Overlap
The treatment indicator is generated from Z = 1(X_1 + 2 X_2 - 2 X_3 - X_4 - 0.5 X_5 + X_6 + \nu > 0), where \nu depends on the overlap selected. If overlap is "high", then \nu \sim N(0, 100). If overlap is "low", then \nu \sim N(0, 30). Finally, if overlap is "medium", then \nu is drawn from a \chi^2 with 5 degrees of freedom that is scaled and centered to have mean 0.5 and variance 67.6.
Returns
An object of class DataSim.
Examples
data <- Hainmueller$new(n = 100, p = 6, design = "A", overlap = "low")
data$gen_data()
print(data$get_x()[1:2,])
Method get_design()
Returns the chosen design parameters
Usage
Hainmueller$get_design()
Method get_pscore()
Returns the true propensity score
Usage
Hainmueller$get_pscore()
Method clone()
The objects of this class are cloneable with this method.
Usage
Hainmueller$clone(deep = FALSE)
Arguments
deep: Whether to make a deep clone.
Examples
## ------------------------------------------------
## Method `Hainmueller$new`
## ------------------------------------------------
data <- Hainmueller$new(n = 100, p = 6, design = "A", overlap = "low")
data$gen_data()
print(data$get_x()[1:2,])
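A hedged sketch contrasting the overlap settings described in Details, using the get_design() and get_pscore() accessors documented above; the summaries of the true propensity scores are illustrative only.
low <- Hainmueller$new(n = 100, p = 6, design = "B", overlap = "low")
low$gen_data()
low$get_design()
summary(low$get_pscore())  # propensity scores closer to 0/1 than with overlap = "high"
high <- Hainmueller$new(n = 100, p = 6, design = "B", overlap = "high")
high$gen_data()
summary(high$get_pscore())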
LaLonde data example
Description
LaLonde data example
Details
Returns the LaLonde data as used by Dehejia and Wahba. Note the data
is fixed and gen_data() will just initialize the fixed data.
Value
Super class
causalOT::DataSim -> LaLonde
Methods
Public methods
Inherited methods
Method gen_data()
Sets up the data
Usage
LaLonde$gen_data()
Method get_tau()
Returns the experimental treatment effect, $1794
Usage
LaLonde$get_tau()
Method gen_x()
Sets up the covariate data
Usage
LaLonde$gen_x()
Method gen_y()
Sets up the outcome data
Usage
LaLonde$gen_y()
Method gen_z()
Sets up the treatment indicator
Usage
LaLonde$gen_z()
Method new()
Initializes the LaLonde object.
Usage
LaLonde$new(n = NULL, p = NULL, param = list(), design = "NSW", ...)
Arguments
n: Not used. Maintained for symmetry with other DataSim objects.
p: Not used. Maintained for symmetry with other DataSim objects.
param: Not used. Maintained for symmetry with other DataSim objects.
design: One of "NSW" or "Full". "NSW" uses the original experimental data from the job training program while option "Full" uses the treated individuals from LaLonde's study and compares them to individuals from the Current Population Survey as controls.
...: Not used.
Examples
nsw <- LaLonde$new(design = "NSW")
nsw$gen_data()
nsw$get_n()
obs.study <- LaLonde$new(design = "Full")
obs.study$gen_data()
obs.study$get_n()
Method get_design()
Returns the chosen design parameters
Usage
LaLonde$get_design()
Method clone()
The objects of this class are cloneable with this method.
Usage
LaLonde$clone(deep = FALSE)
Arguments
deep: Whether to make a deep clone.
Examples
## ------------------------------------------------
## Method `LaLonde$new`
## ------------------------------------------------
nsw <- LaLonde$new(design = "NSW")
nsw$gen_data()
nsw$get_n()
obs.study <- LaLonde$new(design = "Full")
obs.study$gen_data()
obs.study$get_n()
An R6 Class for setting up measures
Description
An R6 Class for setting up measures
Usage
Measure(
x,
weights = NULL,
probability.measure = TRUE,
adapt = c("none", "weights", "x"),
balance.functions = NA_real_,
target.values = NA_real_,
dtype = NULL,
device = NULL
)
Arguments
x |
The data points |
weights |
The empirical measure. If NULL, assigns equal weight to each observation |
probability.measure |
Is the empirical measure a probability measure? Default is TRUE. |
adapt |
Should we try to adapt the data ("x"), the weights ("weights"), or neither ("none"). Default is "none". |
balance.functions |
A matrix of functions of the covariates to target for mean balance. If NULL and |
target.values |
The targets for the balance functions. Should be the same length as columns in |
dtype |
The torch_tensor dtype or NULL. |
device |
The device to have the data on. Should be result of torch::torch_device() or NULL. |
Value
Returns a Measure object
Public fields
balance_functions: the functions of the data that we want to adjust towards the targets
balance_target: the values the balance_functions are targeting
adapt: What aspect of the data will be adapted. One of "none", "weights", or "x".
device: the torch::torch_device of the data.
dtype: the torch::torch_dtype of the data.
n: the rows of the covariates, x.
d: the columns of the covariates, x.
probability_measure: is the measure a probability measure?
Active bindings
grad: gets or sets gradient
init_weights: returns the initial value of the weights
init_data: returns the initial value of the data
requires_grad: checks or turns on/off gradient
weights: gets or sets weights
x: Gets or sets the data
Methods
Public methods
Method detach()
generates a deep clone of the object without gradients.
Usage
Measure$detach()
Method get_weight_parameters()
Makes a copy of the weights parameters.
Usage
Measure$get_weight_parameters()
Method clone()
The objects of this class are cloneable with this method.
Usage
Measure$clone(deep = FALSE)
Arguments
deep: Whether to make a deep clone.
Examples
if(torch::torch_is_installed()) {
m <- Measure(x = matrix(0, 10, 2), adapt = "none")
print(m)
m$x
m$x <- matrix(1,10,2) # must have same dimensions
m$x
m$weights
m$weights <- 1:10/sum(1:10)
m$weights
# with gradients
m <- Measure(x = matrix(0, 10, 2), adapt = "weights")
m$requires_grad # TRUE
m$requires_grad <- "none" # turns off
m$requires_grad # FALSE
m$requires_grad <- "x"
m$requires_grad # TRUE
m <- Measure(matrix(0, 10, 2), adapt = "none")
m$grad # NULL
m <- Measure(matrix(0, 10, 2), adapt = "weights")
loss <- sum(m$weights * 1:10)
loss$backward()
m$grad
# note the weights gradient is on the log softmax scale
#and the first parameter is fixed for identifiability
m$grad <- rep(1,9)
m$grad
}
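A hedged sketch of the balance.functions and target.values arguments described above: the measure's own covariate columns are used as balance functions and pushed toward hypothetical target means of zero.
if (torch::torch_is_installed()) {
  x0 <- matrix(stats::rnorm(20), 10, 2)
  m_bal <- Measure(x = x0,
                   adapt = "weights",
                   balance.functions = x0,   # functions of the covariates
                   target.values = c(0, 0))  # hypothetical target means
  m_bal$balance_functions
  m_bal$balance_target
}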
An R6 object for measures
Description
Internal R6 class object for Measure objects
Public fields
balance_functions: the functions of the data that we want to adjust towards the targets
balance_target: the values the balance_functions are targeting
adapt: What aspect of the data will be adapted. One of "none", "weights", or "x".
device: the torch::torch_device() of the data.
dtype: the torch::torch_dtype of the data.
n: the rows of the covariates, x.
d: the columns of the covariates, x.
probability_measure: is the measure a probability measure?
Active bindings
grad: gets or sets gradient
init_weights: returns the initial value of the weights
init_data: returns the initial value of the data
requires_grad: checks or turns on/off gradient
weights: gets or sets weights
x: Gets or sets the data.
Methods
Public methods
Method detach()
generates a deep clone of the object without gradients.
Usage
Measure_$detach()
Method get_weight_parameters()
Makes a copy of the weights parameters.
Usage
Measure_$get_weight_parameters()
Method print()
prints the measure object
Usage
Measure_$print(...)
Arguments
...: Not used
Method new()
Constructor function
Usage
Measure_$new(
x,
weights = NULL,
probability.measure = TRUE,
adapt = c("none", "weights", "x"),
balance.functions = NA_real_,
target.values = NA_real_,
dtype = NULL,
device = NULL
)
Arguments
x: The data points
weights: The empirical measure. If NULL, assigns equal weight to each observation
probability.measure: Is the empirical measure a probability measure? Default is TRUE.
adapt: Should we try to adapt the data ("x"), the weights ("weights"), or neither ("none"). Default is "none".
balance.functions: A matrix of functions of the covariates to target for mean balance. If NULL and target.values are provided, will use the data in x.
target.values: The targets for the balance functions. Should be the same length as columns in balance.functions.
dtype: The torch::torch_dtype or NULL.
device: The device to have the data on. Should be result of torch::torch_device() or NULL.
Method clone()
The objects of this class are cloneable with this method.
Usage
Measure_$clone(deep = FALSE)
Arguments
deep: Whether to make a deep clone.
Object Oriented OT Problem
Description
Object Oriented OT Problem
Usage
OTProblem(measure_1, measure_2, ...)
Arguments
measure_1 |
An object of class Measure |
measure_2 |
An object of class Measure |
... |
Not used at this time |
Value
An R6 object of class "OTProblem"
Public fields
device: the torch::torch_device() of the data.
dtype: the torch::torch_dtype of the data.
selected_delta: the delta value selected after choose_hyperparameters
selected_lambda: the lambda value selected after choose_hyperparameters
Active bindings
loss: prints the current value of the objective. Only available after the OTProblem$solve() method has been run
penalty: Returns a list of the lambda and delta penalties that will be iterated through. To set these values, use the OTProblem$setup_arguments() function.
Methods
Public methods
Method add()
adds o2 to the OTProblem
Usage
OTProblem$add(o2)
Arguments
o2: A number or object of class OTProblem
Method subtract()
subtracts o2 from OTProblem
Usage
OTProblem$subtract(o2)
Arguments
o2: A number or object of class OTProblem
Method multiply()
multiplies OTProblem by o2
Usage
OTProblem$multiply(o2)
Arguments
o2: A number or an object of class OTProblem
Method divide()
divides OTProblem by o2
Usage
OTProblem$divide(o2)
Arguments
o2: A number or object of class OTProblem
Method setup_arguments()
Usage
OTProblem$setup_arguments(
lambda,
delta,
grid.length = 7L,
cost.function = NULL,
p = 2,
cost.online = "auto",
debias = TRUE,
diameter = NULL,
ot_niter = 1000L,
ot_tol = 0.001
)
Arguments
lambda: The penalty parameters to try for the OT problems. If not provided, the function will select some
delta: The constraint parameters to try for the balance function problems, if any
grid.length: The number of hyperparameters to try if not provided
cost.function: The cost function for the data. Can be any function that takes arguments x1, x2, p. Defaults to the Euclidean distance
p: The power to raise the cost matrix by. Default is 2
cost.online: Should online costs be used? Default is "auto" but "tensorized" stores the cost matrix in memory while "online" will calculate it on the fly.
debias: Should debiased OT problems be used? Defaults to TRUE
diameter: Diameter of the cost function.
ot_niter: Number of iterations to run the OT problems
ot_tol: The tolerance for convergence of the OT problems
Returns
NULL
Examples
ot$setup_arguments(lambda = c(1000,10))
Method solve()
Solve the OTProblem at each parameter value. Must run setup_arguments first.
Usage
OTProblem$solve(
niter = 1000L,
tol = 1e-05,
optimizer = c("torch", "frank-wolfe"),
torch_optim = torch::optim_lbfgs,
torch_scheduler = torch::lr_reduce_on_plateau,
torch_args = NULL,
osqp_args = NULL,
quick.balance.function = TRUE
)
Arguments
niter: The number of iterations to run the solver at each combination of hyperparameter values
tol: The tolerance for convergence
optimizer: The optimizer to use. One of "torch" or "frank-wolfe"
torch_optim: The torch_optimizer to use. Default is torch::optim_lbfgs
torch_scheduler: The torch::lr_scheduler to use. Default is torch::lr_reduce_on_plateau
torch_args: Arguments passed to the torch optimizer and scheduler
osqp_args: Arguments passed to osqp::osqpSettings() if appropriate
quick.balance.function: Should osqp::osqp() be used to select balance function constraints (delta) or not. Default true.
Examples
ot$solve(niter = 1, torch_optim = torch::optim_rmsprop)
Method choose_hyperparameters()
Selects the hyperparameter values through a bootstrap algorithm
Usage
OTProblem$choose_hyperparameters(
n_boot_lambda = 100L,
n_boot_delta = 1000L,
lambda_bootstrap = Inf
)
Arguments
n_boot_lambda: The number of bootstrap iterations to run when selecting lambda
n_boot_delta: The number of bootstrap iterations to run when selecting delta
lambda_bootstrap: The penalty parameter to use when selecting lambda. Higher numbers run faster.
Examples
ot$choose_hyperparameters(n_boot_lambda = 10,
n_boot_delta = 10,
lambda_bootstrap = Inf)
Method info()
Provides diagnostics after solve and choose_hyperparameter methods have been run.
Usage
OTProblem$info()
Returns
a list with slots
- loss: the final loss values
- iterations: The number of iterations run for each combination of parameters
- balance.function.differences: The final differences in the balance functions
- hyperparam.metrics: A list of the bootstrap evaluations for the delta and lambda values
Examples
ot$info()
Method clone()
The objects of this class are cloneable with this method.
Usage
OTProblem$clone(deep = FALSE)
Arguments
deep: Whether to make a deep clone.
Examples
## ------------------------------------------------
## Method `OTProblem(measure_1, measure_2)`
## ------------------------------------------------
if (torch::torch_is_installed()) {
# setup measures
x <- matrix(1, 100, 10)
m1 <- Measure(x = x)
y <- matrix(2, 100, 10)
m2 <- Measure(x = y, adapt = "weights")
z <- matrix(3,102, 10)
m3 <- Measure(x = z)
# setup OT problems
ot1 <- OTProblem(m1, m2)
ot2 <- OTProblem(m3, m2)
ot <- 0.5 * ot1 + 0.5 * ot2
print(ot)
## ------------------------------------------------
## Method `OTProblem$setup_arguments`
## ------------------------------------------------
ot$setup_arguments(lambda = 1000)
## ------------------------------------------------
## Method `OTProblem$solve`
## ------------------------------------------------
ot$solve(niter = 1, torch_optim = torch::optim_rmsprop)
## ------------------------------------------------
## Method `OTProblem$choose_hyperparameters`
## ------------------------------------------------
ot$choose_hyperparameters(n_boot_lambda = 1,
n_boot_delta = 1,
lambda_bootstrap = Inf)
## ------------------------------------------------
## Method `OTProblem$info`
## ------------------------------------------------
ot$info()
}
An R6 class to construct OTProblems
Description
OTProblem R6 class
Public fields
device: the torch::torch_device() of the data.
dtype: the torch::torch_dtype of the data.
selected_delta: the delta value selected after choose_hyperparameters
selected_lambda: the lambda value selected after choose_hyperparameters
Active bindings
loss: prints the current value of the objective. Only available after the solve method has been run
penalty: Returns a list of the lambda and delta penalties that will be iterated through. To set these values, use the setup_arguments function.
Methods
Public methods
Method add()
adds o2 to the OTProblem
Usage
OTProblem_$add(o2)
Arguments
o2: A number or object of class OTProblem
Method subtract()
subtracts o2 from OTProblem
Usage
OTProblem_$subtract(o2)
Arguments
o2: A number or object of class OTProblem
Method multiply()
multiplies OTProblem by o2
Usage
OTProblem_$multiply(o2)
Arguments
o2: A number or object of class OTProblem
Method divide()
divides OTProblem by o2
Usage
OTProblem_$divide(o2)
Arguments
o2: A number or object of class OTProblem
Method print()
prints the OT problem object
Usage
OTProblem_$print(...)
Arguments
...: Not used
Method new()
Constructor method
Usage
OTProblem_$new(measure_1, measure_2)
Arguments
measure_1: An object of class Measure
measure_2: An object of class Measure
Returns
An R6 object of class "OTProblem"
Method setup_arguments()
Usage
OTProblem_$setup_arguments( lambda, delta, grid.length = 7L, cost.function = NULL, p = 2, cost.online = "auto", debias = TRUE, diameter = NULL, ot_niter = 1000L, ot_tol = 0.001 )
Arguments
lambda: The penalty parameters to try for the OT problems. If not provided, the function will select some
delta: The constraint parameters to try for the balance function problems, if any
grid.length: The number of hyperparameters to try if not provided
cost.function: The cost function for the data. Can be any function that takes arguments x1, x2, p. Defaults to the Euclidean distance
p: The power to raise the cost matrix by. Default is 2
cost.online: Should online costs be used? Default is "auto" but "tensorized" stores the cost matrix in memory while "online" will calculate it on the fly.
debias: Should debiased OT problems be used? Defaults to TRUE
diameter: Diameter of the cost function.
ot_niter: Number of iterations to run the OT problems
ot_tol: The tolerance for convergence of the OT problems
Returns
NULL
Method solve()
Solve the OTProblem at each parameter value. Must run setup_arguments first.
Usage
OTProblem_$solve(
niter = 1000L,
tol = 1e-05,
optimizer = c("torch", "frank-wolfe"),
torch_optim = torch::optim_lbfgs,
torch_scheduler = torch::lr_reduce_on_plateau,
torch_args = NULL,
osqp_args = NULL,
quick.balance.function = TRUE
)
Arguments
niter: The number of iterations to run the solver at each combination of hyperparameter values
tol: The tolerance for convergence
optimizer: The optimizer to use. One of "torch" or "frank-wolfe"
torch_optim: The torch_optimizer to use. Default is torch::optim_lbfgs
torch_scheduler: The torch::lr_scheduler to use. Default is torch::lr_reduce_on_plateau
torch_args: Arguments passed to the torch optimizer and scheduler
osqp_args: Arguments passed to osqp::osqpSettings() if appropriate
quick.balance.function: Should osqp::osqp() be used to select balance function constraints (delta) or not. Default true.
Method choose_hyperparameters()
Selects the hyperparameter values through a bootstrap algorithm
Usage
OTProblem_$choose_hyperparameters( n_boot_lambda = 100L, n_boot_delta = 1000L, lambda_bootstrap = Inf )
Arguments
n_boot_lambda: The number of bootstrap iterations to run when selecting lambda
n_boot_delta: The number of bootstrap iterations to run when selecting delta
lambda_bootstrap: The penalty parameter to use when selecting lambda. Higher numbers run faster.
Method info()
Provides diagnostics after solve and choose_hyperparameter methods have been run.
Usage
OTProblem_$info()
Returns
a list with slots
- loss: the final loss values
- iterations: The number of iterations run for each combination of parameters
- balance.function.differences: The final differences in the balance functions
- hyperparam.metrics: A list of the bootstrap evaluations for the delta and lambda values
Method clone()
The objects of this class are cloneable with this method.
Usage
OTProblem_$clone(deep = FALSE)
Arguments
deep: Whether to make a deep clone.
Pareto-Smoothed Importance Sampling
Description
Pareto-Smoothed Importance Sampling
Usage
PSIS(x, r_eff = NULL, ...)
## S4 method for signature 'numeric'
PSIS(x, r_eff = NULL, ...)
## S4 method for signature 'causalWeights'
PSIS(x, r_eff = NULL, ...)
## S4 method for signature 'list'
PSIS(x, r_eff = NULL, ...)
PSIS_diag(x, ...)
## S4 method for signature 'numeric'
PSIS_diag(x, r_eff = NULL)
## S4 method for signature 'causalWeights'
PSIS_diag(x, r_eff = NULL)
## S4 method for signature 'causalPSIS'
PSIS_diag(x, ...)
## S4 method for signature 'list'
PSIS_diag(x, r_eff = NULL)
## S4 method for signature 'psis'
PSIS_diag(x, r_eff = NULL)
Arguments
x |
For |
r_eff |
A vector of relative effective sample sizes with one estimate per observation. If providing an object of class causalWeights, should be a list of vectors with one vector for each sample. See psis() from the loo package. |
... |
Arguments passed to the psis() function. |
Details
Acts as a wrapper to the psis() function from the loo package. It
is built to handle the data types found in this package. This method is preferred to the ESS()
function in causalOT since the latter is prone to error (infinite variances) but will not give any indication that the estimates
are problematic.
Value
For PSIS(), returns a list. See psis() from loo for a description of the outputs. Will give the log of the
smoothed weights in slot log_weights, and in the slot diagnostics, it will give
the pareto_k parameter (see the pareto-k-diagnostic page) and
the n_eff estimates. PSIS_diag() returns the diagnostic slot from an object of class "psis".
Methods (by class)
- PSIS(numeric): numeric weights
- PSIS(causalWeights): object of class causalWeights
- PSIS(list): list of weights
- PSIS_diag(numeric): numeric weights
- PSIS_diag(causalWeights): object of class causalWeights diagnostics
- PSIS_diag(causalPSIS): diagnostics from the output of a previous call to PSIS
- PSIS_diag(list): a list of objects
- PSIS_diag(psis): output of PSIS function
See Also
ESS()
Examples
x <- runif(100)
w <- x/sum(x)
res <- PSIS(x = w, r_eff = 1)
PSIS_diag(res)
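A short follow-on using the objects from the example above; the slot names come from the Value section (log_weights and diagnostics, following psis() in the loo package).
res$diagnostics$pareto_k         # Pareto k diagnostic
res$diagnostics$n_eff            # PSIS effective sample size estimate
w_smooth <- exp(res$log_weights) # smoothed weights on the natural scale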
PSIS causalWeights class
Description
PSIS causalWeights class
Usage
PSIS.causalWeights(x, r_eff = NULL, ...)
Arguments
x |
object of class causalWeights |
r_eff |
pass to PSIS |
... |
pass to PSIS method |
Value
object of class causalPSIS
Barycentric Projection outcome estimation
Description
Barycentric Projection outcome estimation
Usage
barycentric_projection(
formula,
data,
weights,
separate.samples.on = "z",
penalty = NULL,
cost_function = NULL,
p = 2,
debias = FALSE,
cost.online = "auto",
diameter = NULL,
niter = 1000L,
tol = 1e-07,
...
)
Arguments
formula |
A formula object specifying the outcome and covariates. |
data |
A data.frame of the data to use in the model. |
weights |
Either a vector of weights, one for each observations, or an object of class causalWeights. |
separate.samples.on |
The variable in the data denoting the treatment indicator. How to separate samples for the optimal transport calculation |
penalty |
The penalty parameter to use in the optimal transport calculation. By default it is selected by the function. |
cost_function |
A user supplied cost function. If supplied, must take arguments x1, x2, and p. |
p |
The power to raise the cost function. Default is 2.0. For user supplied cost functions, the cost will not be raised by this power unless the user so specifies. |
debias |
Should debiased barycentric projections be used? See details. |
cost.online |
Should an online cost algorithm be used? Default is "auto", which selects an online cost algorithm when the sample size in each group specified by separate.samples.on is large. |
diameter |
The diameter of the covariate space, if known. |
niter |
The maximum number of iterations to run the optimal transport problems |
tol |
The tolerance for convergence of the optimal transport problems |
... |
Not used at this time. |
Details
The barycentric projection uses the dual potentials from the optimal transport distance between the two samples to calculate projections from one sample into another. For example, in the sample of controls, we may wish to know their outcome had they been treated. In general, we then seek to minimize
\text{argmin}_{\eta} \sum_{ij} cost(\eta_i, y_j) \pi_{ij}
where \pi_{ij} is the primal solution from the optimal transport problem.
These values can also be de-biased using the solutions from running an optimal transport problem of one sample against itself. Details are listed in Pooladian et al. (2022) https://arxiv.org/abs/2202.08919.
Value
An object of class "bp" which is a list with slots:
- potentials: The dual potentials from calculating the optimal transport distance
- penalty: The value of the penalty parameter used in calculating the optimal transport distance
- cost_function: The cost function used to calculate the distances between units.
- cost_alg: A character vector denoting if an L_1 distance, a squared Euclidean distance, or other distance metric was used.
- p: The power to which the cost matrix was raised if not using a user supplied cost function.
- debias: Whether barycentric projections should be debiased.
- tensorized: TRUE/FALSE denoting whether to use offline cost matrices.
- data: An object of class dataHolder with the data used to calculate the optimal transport distance.
- y_a: The outcome vector in the first sample.
- y_b: The outcome vector in the second sample.
- x_a: The covariate matrix in the first sample.
- x_b: The covariate matrix in the second sample.
- a: The empirical measure in the first sample.
- b: The empirical measure in the second sample.
- terms: The terms object from the formula.
Examples
if(torch::torch_is_installed()) {
set.seed(23483)
n <- 2^5
pp <- 6
overlap <- "low"
design <- "A"
estimate <- "ATT"
power <- 2
data <- causalOT::Hainmueller$new(n = n, p = pp,
design = design, overlap = overlap)
data$gen_data()
weights <- causalOT::calc_weight(x = data,
z = NULL, y = NULL,
estimand = estimate,
method = "NNM")
df <- data.frame(y = data$get_y(), z = data$get_z(), data$get_x())
fit <- causalOT::barycentric_projection(y ~ ., data = df,
weight = weights,
separate.samples.on = "z",
niter = 2)
inherits(fit, "bp")
}
Estimate causal weights
Description
Estimate causal weights
Usage
calc_weight(
x,
z,
estimand = c("ATC", "ATT", "ATE"),
method = supported_methods(),
options = NULL,
weights = NULL,
...
)
Arguments
x |
A numeric matrix of covariates. You can also pass an object of class dataHolder or DataSim, which will make argument z unnecessary. |
z |
A binary treatment indicator. |
estimand |
The estimand of interest. One of "ATT","ATC", or "ATE". |
method |
The method to estimate the causal weights. Must be one of the methods returned by supported_methods(). |
options |
The options for the solver. Specific options depend on the solver you will be using and you can use the solver specific options functions as detailed below. |
weights |
The sample weights. Should be NULL or a vector of sampling weights, one for each observation. |
... |
Not used at this time. |
Details
We detail some of the particulars of the function arguments below.
Causal Optimal Transport (COT)
This is the main method of the package. This method relies on various solvers depending on the particular options chosen. Please see cotOptions() for more details.
Energy Balancing Weights (EnergyBW)
This is equivalent to COT with an infinite penalty parameter, options(lambda = Inf). Uses the same solver and options as COT, cotOptions().
Nearest Neighbor Matching with replacement (NNM)
This is equivalent to COT with a penalty parameter = 0, options(lambda = 0). Uses the same solver and options as COT, cotOptions().
Synthetic Control Method (SCM)
The SCM method is equivalent to an OT problem from a different angle. See scmOptions().
Entropy Balancing Weights (EntropyBW)
This method balances chosen functions of the covariates specified in the data argument, x. See entBWOptions() for more details. Hainmueller (2012).
Stable Balancing Weights (SBW)
Entropy Balancing Weights with a different penalty parameter, proposed by Zubizarreta (2012). See sbwOptions() for more details.
Covariate Balancing Propensity Score (CBPS)
The CBPS method of Imai and Ratkovic. Options argument is passed to the function CBPS().
Logistic Regression or Probit Regression
The main methods historically for implementing inverse probability weights. Options are passed directly to the glm function from R.
Value
An object of class causalWeights
See Also
cotOptions(), entBWOptions(), sbwOptions(), scmOptions(), estimate_effect()
Examples
set.seed(23483)
n <- 2^5
p <- 6
#### get data ####
data <- Hainmueller$new(n = n, p = p)
data$gen_data()
x <- data$get_x()
z <- data$get_z()
if (torch::torch_is_installed()) {
# estimate weights
weights <- calc_weight(x = x,
z = z,
estimand = "ATE",
method = "COT",
options = list(lambda = 0))
#we can also use the dataSim object directly
weightsDS <- calc_weight(x = data,
z = NULL,
estimand = "ATE",
method = "COT",
options = list(lambda = 0))
all.equal(weights@w0, weightsDS@w0)
all.equal(weights@w1, weightsDS@w1)
}
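A hedged sketch of one of the other supported methods listed in Details, reusing x and z from the example above; the delta value is illustrative only.
w_sbw <- calc_weight(x = x, z = z,
                     estimand = "ATT",
                     method = "SBW",
                     options = sbwOptions(delta = 0.1))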
causalEffect class
Description
causalEffect class
causalEffect constructor function
Usage
causalEffect(data, causalWeights, model.outputs, augment.estimate, call)
Arguments
data |
an object of class dataHolder |
causalWeights |
an object of class causalWeights |
model.outputs |
Outputs of the estimate_model() function |
augment.estimate |
Is the estimate to be the augmented (doubly robust) estimator? TRUE/FALSE |
call |
the call used to calculate the treatment effects |
Details
The variables in slot augmentedData are
- weights: The causalWeights targeting the causal estimand.
- y_obs: The vector of the observed outcomes for each observation
- y_0: The outcome under the control condition. Missingness respects the design of the experiment, i.e., Y(0) | Z = 1 is NA.
- y_hat_0: The conditional mean outcome under the control condition. Estimated from a model.
- y_hat_1: The conditional mean outcome under the treatment condition. Estimated from a model.
- x: The columns denoting the covariates.
- z: The treatment indicator.
The slot fit is a list with slots control, treated, and overall_sample. Control and treated will be filled if estimate.separately is TRUE in estimate_effect. overall_sample will be filled if estimate.separately is FALSE.
Value
an object of class causalEffect
Slots
estimate: The estimated treatment effect.
estimand: The estimand of interest
weights: The weights as an object of class causalWeights
augmentedData: The data as a data.frame with variables weights, y_obs, y_0, y_1, y_hat_0, y_hat_1, x, and z. See details for more info.
fit: The fitted model if present. See details.
call: The call from the estimate_effect() function.
causalWeights class
Description
causalWeights class
Details
This object is returned by the calc_weight function in this package. The slots can be accessed as any S4 object. There is no publicly accessible constructor function.
Slots
w0: A slot with the weights for the control group with n_0 entries. Weights sum to 1.
w1: The weights for the treated group with n_1 entries. Weights sum to 1.
estimand: A character denoting the estimand targeted by the weights. One of "ATT", "ATC", or "ATE".
info: A slot to store a variety of info for inference. Currently under development.
method: A character denoting the method used to estimate the weights.
penalty: A list of the selected penalty parameters, if relevant.
data: The dataHolder object containing the original data.
call: The call used to construct the weights.
Extract treatment effect estimate
Description
Extract treatment effect estimate
Usage
## S3 method for class 'causalEffect'
coef(object, ...)
Arguments
object |
An object of class causalEffect |
... |
Not used |
Value
A number corresponding to the estimated treatment effect
Examples
# set-up data
set.seed(1234)
data <- Hainmueller$new()
data$gen_data()
# calculate quantities
weight <- calc_weight(data, method = "Logistic", estimand = "ATE")
tx_eff <- estimate_effect(causalWeights = weight)
all.equal(coef(tx_eff), c(estimate = tx_eff@estimate))
Options available for the COT method
Description
Options available for the COT method
Usage
cotOptions(
lambda = NULL,
delta = NULL,
opt.direction = c("dual", "primal"),
debias = TRUE,
p = 2,
cost.function = NULL,
cost.online = "auto",
diameter = NULL,
balance.formula = NULL,
quick.balance.function = TRUE,
grid.length = 7L,
torch.optimizer = torch::optim_rmsprop,
torch.scheduler = torch::lr_multiplicative,
niter = 2000,
nboot = 100L,
lambda.bootstrap = 0.05,
tol = 1e-04,
device = NULL,
dtype = NULL,
...
)
Arguments
lambda |
The penalty parameter for the entropy penalized optimal transport. Default is NULL. Can be a single number or a set of numbers to try. |
delta |
The bound for balancing functions if they are being used. Only available for biased entropy penalized optimal transport. Can be a single number or a set of numbers to try. |
opt.direction |
Should the optimizer solve the primal or dual problems. Should be one of "dual" or "primal" with a default of "dual" since it is typically faster. |
debias |
Should debiased optimal transport be used? TRUE or FALSE. |
p |
The power of the cost function to use for the cost. |
cost.function |
A function to calculate the pairwise costs. Should take arguments |
cost.online |
Should an online cost algorithm be used? One of "auto", "online", or "tensorized". "tensorized" is the offline option. |
diameter |
The diameter of the covariate space, if known. Default is NULL. |
balance.formula |
Formula for the balancing functions. |
quick.balance.function |
TRUE or FALSE denoting whether balance function constraints should be selected via a linear program (TRUE) or just checked for feasibility (FALSE). Default is TRUE. |
grid.length |
The number of penalty parameters to explore in a grid search if none are provided in arguments |
torch.optimizer |
The torch optimizer to use for methods using debiased entropy penalized optimal transport. If |
torch.scheduler |
The scheduler for the optimizer. Defaults to |
niter |
The number of iterations to run the solver |
nboot |
The number of iterations for the bootstrap to select the final penalty parameters. |
lambda.bootstrap |
The penalty parameter to use for the bootstrap hyperparameter selection of lambda. |
tol |
The tolerance for convergence |
device |
An object of class |
dtype |
An object of class |
... |
Arguments passed to the solvers. See details |
Value
A list of class cotOptions with the following slots
- lambda: The penalty parameter for the optimal transport distance
- delta: The constraint for the balancing functions
- opt.direction: Whether to solve the primal or dual optimization problems
- debias: TRUE or FALSE if debiased optimal transport distances are used
- balance.formula: The formula giving how to generate the balancing functions.
- quick.balance.function: TRUE or FALSE whether quick balance functions will be run.
- grid.length: The number of parameters to check in a grid search of best parameters
- p: The power of the cost function
- cost.online: Whether online costs are used
- cost.function: The user supplied cost function if supplied.
- diameter: The diameter of the covariate space.
- torch.optimizer: The torch optimizer used for Sinkhorn Divergences
- torch.scheduler: The scheduler for the torch optimizer
- solver.options: The arguments to be passed to the torch.optimizer
- scheduler.options: The arguments to be passed to the torch.scheduler
- osqp.options: Arguments passed to the osqp function if quick balance functions are used.
- niter: The number of iterations to run the solver
- nboot: The number of bootstrap samples
- lambda.bootstrap: The penalty parameter to use for the bootstrap hyperparameter selection.
- tol: The tolerance for convergence.
- device: An object of class torch_device.
- dtype: An object of class torch_dtype.
Solvers and distances
The function is setup to direct the COT optimizer to run two basic methods: debiased entropy penalized optimal transport (Sinkhorn Divergences) or entropy penalized optimal transport (Sinkhorn Distances).
Sinkhorn Distances
The optimal transport problem solved is min_w OT_\lambda(w,b) where
OT_\lambda(w,b) = \sum_{ij} C(x_i, x_j) P_{ij} + \lambda \sum_{ij} P_{ij}\log(P_{ij}),
such that the rows of the matrix P_{ij} sum to w and the columns sum to b. In this case C(,) is the cost between units i and j.
Sinkhorn Divergences
The Sinkhorn Divergence solves
min_w OT_\lambda(w,b) - 0.5 OT_\lambda(w,w) - 0.5 * OT_\lambda(b,b).
The solver for this function uses the torch package in R and by default will use the optim_rmsprop solver. Your desired torch optimizer can be passed via torch.optimizer with a scheduler passed via torch.scheduler. GPU support is available as detailed in the torch package. Additional arguments in ... are passed as extra arguments to the torch optimizer and schedulers as appropriate.
Function balancing
There may be certain functions of the covariates that we wish to balance within some tolerance, \delta. For these functions B, we will desire
\frac{\sum_{i: Z_i = 0} w_i B(x_i) - \sum_{j: Z_j = 1} B(x_j)/n_1}{\sigma} \leq \delta
, where in this case we are targeting balance with the treatment group for the ATT. \sigma is the pooled standard deviation prior to balancing.
Cost functions
The cost function specifies pairwise distances. If argument cost.function is NULL, the function will default to using L_p^p distances with a default p = 2 supplied by the argument p. So for p = 2, the cost between units x_i and x_j will be
C(x_i, x_j) = \frac{1}{2} \| x_i - x_j \|_2^2.
If cost.function is provided, it should be a function that takes arguments x1, x2, and p: function(x1, x2, p){...}.
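A hedged sketch of a user supplied cost function with the (x1, x2, p) signature described above; it assumes the solver passes torch_tensor inputs, which is an assumption about the package internals rather than something documented here.
if (torch::torch_is_installed()) {
  my_cost <- function(x1, x2, p) {
    torch::torch_cdist(x1, x2, p = p)^p  # pairwise L_p^p costs
  }
  opts_user <- cotOptions(cost.function = my_cost, p = 1)
}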
Examples
if ( torch::torch_is_installed()) {
opts1 <- cotOptions(lambda = 1e3, torch.optimizer = torch::optim_rmsprop)
opts2 <- cotOptions(lambda = NULL)
opts3 <- cotOptions(lambda = seq(0.1, 100, length.out = 7))
}
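A hedged sketch of passing these options on to calc_weight() (documented elsewhere in this manual); the data are simulated with the Hainmueller class, and the small lambda and niter values are chosen only to keep the example fast.
if (torch::torch_is_installed()) {
  sim <- Hainmueller$new(n = 2^6, p = 6)
  sim$gen_data()
  w_cot <- calc_weight(x = sim$get_x(), z = sim$get_z(),
                       estimand = "ATT", method = "COT",
                       options = cotOptions(lambda = 10, niter = 10))
}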
cot_solve method for ateClass objects
Description
cot_solve method for ateClass objects
Usage
## S4 method for signature 'ateClass'
cot_solve(object)
Arguments
object |
ateClass. |
Value
object of class causalWeights
cot_solve for gridSearch
Description
cot_solve for gridSearch
Usage
## S4 method for signature 'gridSearch'
cot_solve(object)
Arguments
object |
gridSearch. |
Value
returns object of class causalWeights
cot_solve method for likelihoodMethods
Description
cot_solve method for likelihoodMethods
Usage
## S4 method for signature 'likelihoodMethods'
cot_solve(object)
Arguments
object |
likelihoodMethods. |
Value
object of class causalWeights
dataHolder
Description
dataHolder
Usage
dataHolder(x, z, y = NA_real_, weights = NA_real_)
Arguments
x |
the covariate data. Can be a matrix, an object of class |
z |
the treatment indicator |
y |
the outcome data |
weights |
the empirical distribution of the sample |
Details
Creates an object used internally by the causalOT package for data management.
Value
Returns an object of class dataHolder with slots
- x: matrix. A matrix of confounders.
- z: integer. The treatment indicator, z_i \in \{0,1\}.
- y: numeric. The outcome data.
- n0: integer. The number of observations where z==0
- n1: integer. The number of observations where z==1
- weights: numeric. The empirical distribution of the full sample.
Examples
x <- matrix(0, 100, 10)
z <- stats::rbinom(100, 1, 0.5)
# don't need to provide outcome
# function will assume each observation gets equal mass
dataHolder(x = x, z = z)
dataHolder-methods
Description
dataHolder-methods
Usage
## S4 method for signature 'dataHolder'
dataHolder(x, z = NA_integer_, y = NA_real_)
## S4 method for signature 'matrix'
dataHolder(x, z, y = NA_real_, weights = NA_real_)
dataHolder.DataSim(x, z, y = NA_real_, weights = NA_real_)
## S4 method for signature 'ANY'
dataHolder(x, z = NA_integer_, y = NA_real_, weights = NA_real_)
## S3 method for class 'dataHolder'
terms(x, ...)
Arguments
x |
dataHolder object constructed from a formula |
... |
Not used at this time |
Value
a list with the formula terms for treatment and, if present, outcome formulae.
dataHolder-class
Description
dataHolder-class
Slots
x: matrix. A matrix of confounders.
z: integer. The treatment indicator, z_i \in \{0,1\}.
y: numeric. The outcome data.
n0: integer. The number of observations where z==0
n1: integer. The number of observations where z==1
weights: numeric. The empirical distribution of the full sample.
data_separate method for dataHolder objects
Description
Internal method to separate a dataHolder object based on the estimand.
Usage
## S3 method for class 'dataHolder'
data_separate(data, estimand)
Arguments
data |
dataHolder. |
estimand |
character. |
df2dataHolder
Description
Function to turn a data.frame into a dataHolder object.
Usage
df2dataHolder(
treatment.formula,
outcome.formula = NA_character_,
data,
weights = NA_real_
)
Arguments
treatment.formula |
a formula specifying the treatment indicator and covariates. Required. |
outcome.formula |
an optional formula specifying the outcome function. |
data |
a data.frame with the data |
weights |
optional vector of sampling weights for the data |
Details
This will take the formulas specified and transform the data.frame into a dataHolder object that is used internally by the causalOT package. If you do not specify an outcome formula, take care not to include the outcome in the data.frame; otherwise, the function may include the outcome as a covariate, which is not kosher in causal inference during the design phase.
If both outcome.formula and treatment.formula are specified, it will assume you are in the design phase, and create a combined covariate matrix to balance on the assumed treatment and outcome models.
If you are in the outcome phase of estimation, you can just provide a dummy formula for the treatment.formula like "z ~ 0" just so the function can identify the treatment indicator appropriately in the data creation phase.
Value
Returns an object of class dataHolder()
Examples
set.seed(20348)
n <- 15
d <- 3
x <- matrix(stats::rnorm(n*d), n, d)
z <- rbinom(n, 1, prob = 0.5)
y <- rnorm(n)
weights <- rep(1/n,n)
df <- data.frame(x, z, y)
dh <- df2dataHolder(
treatment.formula = "z ~ .",
outcome.formula = "y ~ ." ,
data = df,
weights = weights)
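A short sketch of the outcome-phase usage described in Details, reusing df from the example above: a dummy treatment formula lets the function locate the treatment indicator while the outcome formula does the real work.
dh_outcome <- df2dataHolder(
  treatment.formula = "z ~ 0",
  outcome.formula = "y ~ .",
  data = df)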
df2dataHolder-methods
Description
df2dataHolder-methods
Usage
## S4 method for signature 'ANY,ANY,data.frame'
df2dataHolder(
treatment.formula = NA_character_,
outcome.formula = NA_character_,
data,
weights = NA_real_
)
Options for the Entropy Balancing Weights
Description
Options for the Entropy Balancing Weights
Usage
entBWOptions(delta = NULL, grid.length = 20L, nboot = 1000L, ...)
Arguments
delta |
A number or vector of tolerances for the balancing functions. Default is NULL which will use a grid search |
grid.length |
The number of values to try in the grid search |
nboot |
The number of bootstrap samples to run during the grid search. |
... |
Arguments passed on to lbfgsb3c() |
Value
A list of class entBWOptions with slots
- delta: Delta values to try
- grid.length: The number of parameters to try
- nboot: Number of bootstrap samples
- solver.options: A list of options passed to lbfgsb3c()
Function balancing
This method will balance functions of the covariates within some tolerance, \delta. For these functions B, we will desire
\frac{\sum_{i: Z_i = 0} w_i B(x_i) - \sum_{j: Z_j = 1} B(x_j)/n_1}{\sigma} \leq \delta
, where in this case we are targeting balance with the treatment group for the ATT. \sigma is the pooled standard deviation prior to balancing.
Examples
opts <- entBWOptions(delta = 0.1)
Estimate treatment effects
Description
Estimate treatment effects
Usage
estimate_effect(
causalWeights,
x = NULL,
y = NULL,
model.function,
estimate.separately = TRUE,
augment.estimate = FALSE,
normalize.weights = TRUE,
...
)
Arguments
causalWeights |
An object of class causalWeights |
x |
A dataHolder, matrix, data.frame, or object of class DataSim. See calc_weight for more details how to input the data. If |
y |
The outcome vector. |
model.function |
The modeling function to use, if desired. Must take arguments "formula", "data", and "weights". Other arguments passed via |
estimate.separately |
Should the outcome model be estimated separately in each treatment group? TRUE or FALSE. |
augment.estimate |
Should an augmented, doubly robust estimator be used? |
normalize.weights |
Should the weights in the |
... |
Pass additional arguments to the outcome modeling functions. |
Value
an object of class causalEffect
Examples
if ( torch::torch_is_installed() ){
# set-up data
data <- Hainmueller$new()
data$gen_data()
# calculate quantities
weight <- calc_weight(data, method = "COT",
estimand = "ATT",
options = list(lambda = 0))
tx_eff <- estimate_effect(causalWeights = weight)
# get estimate
print(tx_eff@estimate)
all.equal(coef(tx_eff), c(estimate = tx_eff@estimate))
}
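A hedged sketch of the augmented (doubly robust) estimator with a linear outcome model; stats::lm accepts the formula, data, and weights arguments required of model.function, and logistic weights are used so the sketch does not require torch.
dr_data <- Hainmueller$new()
dr_data$gen_data()
w_log <- calc_weight(dr_data, method = "Logistic", estimand = "ATT")
tx_dr <- estimate_effect(causalWeights = w_log,
                         model.function = lm,
                         augment.estimate = TRUE)
coef(tx_dr)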
Function to estimate outcome models
Description
Function to estimate outcome models
Usage
estimate_model(data, causalWeights, model.function, separate.estimation, ...)
Arguments
data |
A |
causalWeights |
A causalWeights object |
model.function |
The model function passed by the user |
separate.estimation |
TRUE or FALSE, should models be estimated separately in each group? |
... |
Extra arguments passed to the predict functions |
Value
a list with slots y_hat_0, y_hat_1, and fit.
gridSearch S4 class
Description
gridSearch S4 class
Slots
penalty_list: numeric.
nboot: integer.
solver: R6.
method: character.
estimand: character.
Standardized absolute mean difference calculations
Description
This function will calculate the difference in means between treatment groups standardized by the pooled standard-deviation of the respective covariates.
Usage
mean_balance(x = NULL, z = NULL, weights = NULL, ...)
Arguments
x |
Either a matrix, an object of class dataHolder, or an object of class DataSim |
z |
An integer vector denoting the treatment of each observation. Can be NULL if x is an object of class dataHolder or DataSim. |
weights |
An object of class causalWeights. |
... |
Not used at this time. |
Value
A vector of mean balances
Examples
n <- 100
p <- 6
x <- matrix(stats::rnorm(n * p), n, p)
z <- stats::rbinom(n, 1, 0.5)
weights <- calc_weight(x = x, z = z, estimand = "ATT", method = "Logistic")
mb <- mean_balance(x = x, z = z, weights = weights)
print(mb)
Internal function to select appropriate loss function
Description
Selects sinkhorn or energy distance losses depending on value of penalty parameter
Usage
oop_loss_select(ot)
Arguments
ot |
an OT object |
Optimal Transport Distance
Description
Optimal Transport Distance
Usage
ot_distance(
x1,
x2 = NULL,
a = NULL,
b = NULL,
penalty,
p = 2,
cost = NULL,
debias = TRUE,
online.cost = "auto",
diameter = NULL,
niter = 1000,
tol = 1e-07
)
## S3 method for class 'causalWeights'
ot_distance(
x1,
x2 = NULL,
a = NULL,
b = NULL,
penalty,
p = 2,
cost = NULL,
debias = TRUE,
online.cost = "auto",
diameter = NULL,
niter = 1000,
tol = 1e-07
)
## S3 method for class 'matrix'
ot_distance(
x1,
x2,
a = NULL,
b = NULL,
penalty,
p = 2,
cost = NULL,
debias = TRUE,
online.cost = "auto",
diameter = NULL,
niter = 1000,
tol = 1e-07
)
## S3 method for class 'array'
ot_distance(
x1,
x2,
a = NULL,
b = NULL,
penalty,
p = 2,
cost = NULL,
debias = TRUE,
online.cost = "auto",
diameter = NULL,
niter = 1000,
tol = 1e-07
)
## S3 method for class 'torch_tensor'
ot_distance(
x1,
x2,
a = NULL,
b = NULL,
penalty,
p = 2,
cost = NULL,
debias = TRUE,
online.cost = "auto",
diameter = NULL,
niter = 1000,
tol = 1e-07
)
Arguments
x1 |
Either an object of class causalWeights or a matrix of the covariates in the first sample |
x2 |
NULL or a matrix of the covariates in the second sample. Ignored if x1 is an object of class causalWeights. |
a |
Empirical measure of the first sample. If NULL, assumes each observation gets equal mass. Ignored for objects of class causalWeights. |
b |
Empirical measure of the second sample. If NULL, assumes each observation gets equal mass. Ignored for objects of class causalWeights. |
penalty |
The penalty of the optimal transport distance to use. If missing or NULL, the function will try to guess a suitable value depending if debias is TRUE or FALSE. |
p |
The power of the Lp distance to use. |
cost |
Supply your own cost function. Should take arguments |
debias |
TRUE or FALSE. Should the debiased optimal transport distances be used. |
online.cost |
How to calculate the distance matrix. One of "auto", "tensorized", or "online". |
diameter |
The diameter of the metric space, if known. Default is NULL. |
niter |
The maximum number of iterations for the Sinkhorn updates |
tol |
The tolerance for convergence |
Value
For objects of class matrix, numeric value giving the optimal transport distance. For objects of class causalWeights, results are returned as a list for before ('pre') and after adjustment ('post').
Methods (by class)
- ot_distance(causalWeights): method for causalWeights class
- ot_distance(matrix): method for matrices
- ot_distance(array): method for arrays
- ot_distance(torch_tensor): method for torch_tensors
Examples
if ( torch::torch_is_installed()) {
x <- matrix(stats::rnorm(10*5), 10, 5)
z <- stats::rbinom(10, 1, 0.5)
weights <- calc_weight(x = x, z = z, method = "Logistic", estimand = "ATT")
ot1 <- ot_distance(x1 = weights, penalty = 100,
p = 2, debias = TRUE, online.cost = "auto",
diameter = NULL)
ot2<- ot_distance(x1 = x[z==0, ], x2 = x[z == 1,],
a= weights@w0/sum(weights@w0), b = weights@w1,
penalty = 100, p = 2, debias = TRUE, online.cost = "auto", diameter = NULL)
all.equal(ot1$post, ot2)
}
plot.causalWeights
Description
plot.causalWeights
Usage
## S3 method for class 'causalWeights'
plot(
x,
r_eff = NULL,
penalty,
p = 2,
cost = NULL,
debias = TRUE,
online.cost = "auto",
diameter = NULL,
niter = 1000,
tol = 1e-07,
...
)
Arguments
x |
A causalWeights object |
r_eff |
The r_eff used in the PSIS() calculation. See psis() from the loo package. |
penalty |
The penalty of the optimal transport distance to use. If missing or NULL, the function will try to guess a suitable value depending if debias is TRUE or FALSE. |
p |
The power of the Lp distance to use. Overridden by argument cost. |
cost |
Supply your own cost function. Should take arguments |
debias |
TRUE or FALSE. Should the debiased optimal transport distances be used. |
online.cost |
How to calculate the distance matrix. One of "auto", "tensorized", or "online". |
diameter |
The diameter of the metric space, if known. Default is NULL. |
niter |
The maximum number of iterations for the Sinkhorn updates |
tol |
The tolerance for convergence |
... |
Not used at this time |
Details
The plot method first calls summary.causalWeights on the causalWeights object. Then plots the diagnostics from that summary object.
Value
The plot method returns an invisible object of class summary_causalWeights.
See Also
summary.causalWeights()
An external control trial of treatments for post-partum hemorrhage
Description
A dataset evaluating treatments for post-partum hemorrhage. The data contain treatment groups receiving misoprostol vs potential controls from other locations that received only oxytocin. The data is stored as a numeric matrix.
Usage
data(pph)
Format
A matrix with 802 rows and 17 variables
Details
The variables are as follows:
cum_blood_20m. The outcome variable denoting cumulative blood loss in mL 20 minutes after the diagnosis of post-partum hemorrhage (650 – 2000).
tx. The treatment indicator of whether an individual received misoprostol (1) or oxytocin (0).
age. the mother's age in years (15 – 43).
no_educ. whether a woman had no education (1) or some education (0).
num_livebirth. the number of previous live births.
cur_married. whether a mother is currently married (1 = yes, 0 = no).
gest_age. the gestational age of the fetus in weeks (35 – 43).
prev_pphyes. whether the woman has had a previous post-partum hemorrhage.
hb_test. the woman's hemoglobin in mg/dL (7 – 15).
induced_laboryes. whether labor was induced (1 = yes, 0 = no).
augmented_laboryes. whether labor was augmented (1 = yes, 0 = no).
early_cordclampyes. whether the umbilical cord was clamped early (1 = yes, 0 = no).
control_cordtractionyes. whether cord traction was controlled (1 = yes, 0 = no).
uterine_massageyes. whether a uterine massage was given (1 = yes, 0 = no).
placenta. whether placenta was delivered before treatment given (1 = yes, 0 = no).
bloodlossattx. amount of blood lost when treatment given (500 mL – 1800 mL)
sitecode. Which site is the individual from? (1 = Cairo, Egypt, 2 = Turkey, 3 = Hocmon, Vietnam, 4 = Cuchi, Vietnam, and 5 = Burkina Faso).
Source
Data from the following Harvard Dataverse:
Winikoff, Beverly, 2019, "Two randomized controlled trials of misoprostol for the treatment of postpartum hemorrhage", https://doi.org/10.7910/DVN/ETHH4N, Harvard Dataverse, V1.
The data was originally analyzed in
Blum, J. et al. Treatment of post-partum haemorrhage with sublingual misoprostol versus oxytocin in women receiving prophylactic oxytocin: a double-blind, randomised, non-inferiority trial. The Lancet 375, 217–223 (2010).
Predict method for barycentric projection models
Description
Predict method for barycentric projection models
Usage
## S3 method for class 'bp'
predict(
object,
newdata = NULL,
source.sample,
cost_function = NULL,
niter = 1000,
tol = 1e-07,
...
)
Arguments
object |
An object of class "bp" |
newdata |
a data.frame containing new observations |
source.sample |
a vector giving the sample each observation arises from |
cost_function |
a cost metric between observations |
niter |
number of iterations to run the barycentric projection for powers > 2. |
tol |
Tolerance on the optimization problem for projections with powers > 2. |
... |
Dots passed to the lbfgs method in the torch package. |
Examples
if(torch::torch_is_installed()) {
set.seed(23483)
n <- 2^5
pp <- 6
overlap <- "low"
design <- "A"
estimate <- "ATT"
power <- 2
data <- causalOT::Hainmueller$new(n = n, p = pp,
design = design, overlap = overlap)
data$gen_data()
weights <- causalOT::calc_weight(x = data,
z = NULL, y = NULL,
estimand = estimate,
method = "NNM")
df <- data.frame(y = data$get_y(), z = data$get_z(), data$get_x())
# undebiased
fit <- causalOT::barycentric_projection(y ~ ., data = df,
weight = weights,
separate.samples.on = "z", niter = 2)
#debiased
fit_d <- causalOT::barycentric_projection(y ~ ., data = df,
weight = weights,
separate.samples.on = "z", debias = TRUE, niter = 2)
# predictions, without new data
undebiased_predictions <- predict(fit, source.sample = df$z)
debiased_predictions <- predict(fit_d, source.sample = df$z)
isTRUE(all.equal(unname(undebiased_predictions), df$y)) # FALSE
isTRUE(all.equal(unname(debiased_predictions), df$y)) # TRUE
}
print.dataHolder
Description
print.dataHolder
Usage
## S3 method for class 'dataHolder'
print(x, ...)
Arguments
x |
dataHolder object |
... |
Not used |
Options for the SBW method
Description
Options for the SBW method
Usage
sbwOptions(delta = NULL, grid.length = 20L, nboot = 1000L, ...)
Arguments
delta |
A number or vector of tolerances for the balancing functions. Default is NULL which will use a grid search |
grid.length |
The number of values to try in the grid search |
nboot |
The number of bootstrap samples to run during the grid search. |
... |
Arguments passed on to osqpSettings() |
Value
A list of class sbwOptions with slots
- delta: Delta values to try
- grid.length: The number of parameters to try
- sumto1: Forced to be TRUE. Weights will always sum to 1.
- nboot: Number of bootstrap samples
- solver.options: A list with arguments passed to osqpSettings()
Function balancing
This method will balance functions of the covariates within some tolerance, \delta. For these functions B, we will desire
\frac{\sum_{i: Z_i = 0} w_i B(x_i) - \sum_{j: Z_j = 1} B(x_j)/n_1}{\sigma} \leq \delta
, where in this case we are targeting balance with the treatment group for the ATT. \sigma is the pooled standard deviation prior to balancing.
Examples
opts <- sbwOptions(delta = 0.1)
Options for the SCM Method
Description
Options for the SCM Method
Usage
scmOptions(...)
Arguments
... |
Arguments passed to the osqpSettings() function which solves the problem. |
Details
Options for the solver used in the optimization of the Synthetic Control Method of Abadie and Gardeazabal (2003).
Value
A list with arguments to pass to osqpSettings()
Examples
opts <- scmOptions()
Summary diagnostics for causalWeights
Description
Summary diagnostics for causalWeights
print.summary_causalWeights
plot.summary_causalWeights
Usage
## S3 method for class 'causalWeights'
summary(
object,
r_eff = NULL,
penalty,
p = 2,
cost = NULL,
debias = TRUE,
online.cost = "auto",
diameter = NULL,
niter = 1000,
tol = 1e-07,
...
)
## S3 method for class 'summary_causalWeights'
print(x, ...)
## S3 method for class 'summary_causalWeights'
plot(x, ...)
Arguments
object |
an object of class causalWeights |
r_eff |
The r_eff used in the PSIS calculation. See psis() from the loo package. |
penalty |
The penalty parameter to use |
p |
The power of the Lp distance to use. Overridden by argument cost. |
cost |
A user supplied cost function. Should take arguments x1, x2, and p. |
debias |
Should debiased optimal transport distances be used. TRUE or FALSE |
online.cost |
Should the cost be calculated online? One of "auto","tensorized", or "online". |
diameter |
the diameter of the covariate space. Default is NULL. |
niter |
the number of iterations to run the optimal transport distances |
tol |
the tolerance for convergence for the optimal transport distances |
... |
Not used |
x |
an object of class "summary_causalWeights" |
Value
The summary method returns an object of class "summary_causalWeights".
Functions
- print(summary_causalWeights): print method
- plot(summary_causalWeights): plot method
Examples
if(torch::torch_is_installed()) {
n <- 2^6
p <- 6
overlap <- "high"
design <- "A"
estimand <- "ATE"
#### get simulation functions ####
original <- Hainmueller$new(n = n, p = p,
design = design, overlap = overlap)
original$gen_data()
weights <- calc_weight(x = original, estimand = estimand, method = "Logistic")
s <- summary(weights)
plot(s)
}
Supported Methods
Description
Supported Methods
Usage
supported_methods()
Value
A character list with supported methods. Note "COT" is the same as "Wasserstein". We provide the second name for backwards compatibility.
Examples
supported_methods()
Get the variance of a causalEffect
Description
Get the variance of a causalEffect
Usage
## S3 method for class 'causalEffect'
vcov(object, ...)
Arguments
object |
An object of class causalEffect |
... |
Passed on to the sandwich estimator if there is a model fit that supports one |
Value
The variance of the treatment effect as a matrix
Examples
# set-up data
set.seed(1234)
data <- Hainmueller$new()
data$gen_data()
# calculate quantities
weight <- calc_weight(data, estimand = "ATT", method = "Logistic")
tx_eff <- estimate_effect(causalWeights = weight)
vcov(tx_eff)
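A short follow-on using the objects above: an approximate 95 percent confidence interval built from the point estimate and the variance matrix.
est <- unname(coef(tx_eff))
se <- sqrt(vcov(tx_eff)[1, 1])
c(lower = est - 1.96 * se, upper = est + 1.96 * se)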