Cram Policy: Efficient Simultaneous Policy Learning and Evaluation
Source: R/cram_policy.R
This function performs the cram method (simultaneous policy learning and evaluation) for binary policies on data consisting of covariates (X), a binary treatment indicator (D), and outcomes (Y).
Usage
cram_policy(
  X,
  D,
  Y,
  batch,
  model_type = "causal_forest",
  learner_type = "ridge",
  baseline_policy = NULL,
  parallelize_batch = FALSE,
  model_params = NULL,
  custom_fit = NULL,
  custom_predict = NULL,
  alpha = 0.05,
  propensity = NULL
)
Arguments
- X
A matrix or data frame of covariates for each sample.
- D
A vector of binary treatment indicators (1 for treated, 0 for non-treated).
- Y
A vector of outcome values for each sample.
- batch
Either an integer specifying the number of batches (which will be created by random sampling) or a vector of length equal to the sample size providing the batch assignment (index) for each individual in the sample; the vector form is illustrated in the sketch after this list.
- model_type
The model type for policy learning. Options include
"causal_forest"
,"s_learner"
, and"m_learner"
. Default is"causal_forest"
. Note: you can also set model_type to NULL and specify custom_fit and custom_predict to use your custom model.- learner_type
The learner type for the chosen model. Options include
"ridge"
for Ridge Regression,"fnn"
for Feedforward Neural Network and"caret"
for Caret. Default is"ridge"
. if model_type is 'causal_forest', choose NULL, if model_type is 's_learner' or 'm_learner', choose between 'ridge', 'fnn' and 'caret'.- baseline_policy
A list providing the baseline policy (binary 0 or 1) for each sample. If
NULL
, defaults to a list of zeros with the same length as the number of rows inX
.- parallelize_batch
Logical. Whether to parallelize batch processing (i.e. the cram method learns T policies, with T the number of batches. They are learned in parallel when parallelize_batch is TRUE vs. learned sequentially using the efficient data.table structure when parallelize_batch is FALSE, recommended for light weight training). Defaults to
FALSE
.- model_params
A list of additional parameters to pass to the model, which can be any parameter defined in the model reference package. Defaults to
NULL
.- custom_fit
A custom, user-defined, function that outputs a fitted model given training data (allows flexibility). Defaults to
NULL
.- custom_predict
A custom, user-defined, function for making predictions given a fitted model and test data (allow flexibility). Defaults to
NULL
.- alpha
Significance level for confidence intervals. Default is 0.05 (95% confidence).
- propensity
The propensity score function for binary treatment indicator (D) (probability for each unit to receive treatment). Defaults to 0.5 (random assignment).
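As an illustration, the sketch below wires these arguments together with the custom-model interface and an explicit batch assignment vector. The helper names my_fit and my_predict are user-chosen, and the signatures shown (custom_fit receiving X, Y, D; custom_predict receiving the fitted model, X, D and returning binary treatment decisions) are assumptions to verify against the package documentation:

# Simulated data for the sketch
set.seed(123)
X <- matrix(rnorm(100 * 5), nrow = 100, ncol = 5)
D <- rbinom(100, 1, 0.5)
Y <- rnorm(100)

# Hypothetical custom S-learner based on lm(); my_fit / my_predict
# are user-supplied names, not package functions, and their assumed
# signatures should be checked against the package documentation.
my_fit <- function(X, Y, D) {
  lm(Y ~ ., data = data.frame(X, D = D, Y = Y))
}
my_predict <- function(model, X, D) {
  # Estimated CATE: prediction under D = 1 minus prediction under D = 0
  cate <- predict(model, newdata = data.frame(X, D = 1)) -
    predict(model, newdata = data.frame(X, D = 0))
  as.integer(cate > 0)  # treat when the estimated effect is positive
}

result_custom <- cram_policy(
  X = X, D = D, Y = Y,
  batch = sample(rep(1:5, length.out = 100)),  # explicit batch assignment vector
  model_type = NULL,       # use the custom model instead of a built-in one
  learner_type = NULL,
  custom_fit = my_fit,
  custom_predict = my_predict
)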
Value
A list containing:
- raw_results
A data frame summarizing key metrics with truncated decimals:
  - Delta Estimate: the estimated treatment effect (delta).
  - Delta Standard Error: the standard error of the delta estimate.
  - Delta CI Lower: the lower bound of the confidence interval for delta.
  - Delta CI Upper: the upper bound of the confidence interval for delta.
  - Policy Value Estimate: the estimated policy value.
  - Policy Value Standard Error: the standard error of the policy value estimate.
  - Policy Value CI Lower: the lower bound of the confidence interval for the policy value.
  - Policy Value CI Upper: the upper bound of the confidence interval for the policy value.
  - Proportion Treated: the proportion of individuals treated under the final policy.
- interactive_table
An interactive table summarizing key metrics for detailed exploration.
- final_policy_model
The final fitted policy model, based on model_type and learner_type, or on custom_fit and custom_predict.
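The reported confidence bounds are consistent with a normal-approximation interval, estimate ± qnorm(1 - alpha/2) × standard error; this is an inference from the example output below, not a statement of the package internals. A quick check in R:

# Reconstruct the Delta CI from the example output below
est <- 0.36852   # Delta Estimate
se  <- 0.26573   # Delta Standard Error
est + c(-1, 1) * qnorm(1 - 0.05 / 2) * se
# gives approximately -0.1523 and 0.8893, matching Delta CI Lower/Upper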
Examples
# Example data
X_data <- matrix(rnorm(100 * 5), nrow = 100, ncol = 5)
D_data <- as.integer(sample(c(0, 1), 100, replace = TRUE))
Y_data <- rnorm(100)
nb_batch <- 5
# Perform CRAM policy
result <- cram_policy(X = X_data,
D = D_data,
Y = Y_data,
batch = nb_batch)
# Access results
result$raw_results
#> Metric Value
#> 1 Delta Estimate 0.36852
#> 2 Delta Standard Error 0.26573
#> 3 Delta CI Lower -0.15230
#> 4 Delta CI Upper 0.88935
#> 5 Policy Value Estimate 0.53258
#> 6 Policy Value Standard Error 0.27544
#> 7 Policy Value CI Lower -0.00727
#> 8 Policy Value CI Upper 1.07243
#> 9 Proportion Treated 1.00000
result$interactive_table
result$final_policy_model
#> GRF forest object of type causal_forest
#> Number of trees: 100
#> Number of training samples: 100
#> Variable importance:
#> 1 2 3 4 5
#> 0.089 0.187 0.142 0.124 0.160
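Tuning parameters for the underlying model can be supplied through model_params. A minimal sketch, assuming that for model_type = "causal_forest" these are forwarded to grf::causal_forest (num.trees is an argument of that function; the pass-through should be checked against the model reference package):

# Pass a tuning parameter through to the underlying model
# (assumption: forwarded to grf::causal_forest for this model_type)
result_cf <- cram_policy(X = X_data,
                         D = D_data,
                         Y = Y_data,
                         batch = nb_batch,
                         model_type = "causal_forest",
                         learner_type = NULL,
                         model_params = list(num.trees = 200),
                         alpha = 0.10)  # 90% confidence intervals
result_cf$raw_results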