
This function performs the cram method (simultaneous policy learning and evaluation) for binary policies on data consisting of covariates (X), a binary treatment indicator (D), and outcomes (Y).

Usage

cram_policy(
  X,
  D,
  Y,
  batch,
  model_type = "causal_forest",
  learner_type = "ridge",
  baseline_policy = NULL,
  parallelize_batch = FALSE,
  model_params = NULL,
  custom_fit = NULL,
  custom_predict = NULL,
  alpha = 0.05,
  propensity = NULL
)

Arguments

X

A matrix or data frame of covariates for each sample.

D

A vector of binary treatment indicators (1 for treated, 0 for untreated).

Y

A vector of outcome values for each sample.

batch

Either an integer specifying the number of batches (which will be created by random sampling) or a vector of length equal to the sample size providing the batch assignment (index) for each individual in the sample.
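As an illustration of the vector form of batch, the base R sketch below randomly assigns 100 samples to 5 equal-sized batches. It is only one way to build such a vector and is not necessarily how cram_policy forms batches internally when batch is an integer; the names n, nb_batch and batch_vec are illustrative.

```r
set.seed(42)                      # fixed seed so the assignment is reproducible
n <- 100                          # sample size (number of rows in X)
nb_batch <- 5                     # desired number of batches

# Repeat the batch indices 1..nb_batch up to length n, then shuffle them,
# so every sample gets exactly one batch index and batch sizes are balanced
batch_vec <- sample(rep(seq_len(nb_batch), length.out = n))

table(batch_vec)                  # 20 samples in each of the 5 batches
```

A vector built this way could then be passed as batch = batch_vec instead of an integer.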

model_type

The model type for policy learning. Options are "causal_forest", "s_learner", and "m_learner". Default is "causal_forest". Alternatively, set model_type to NULL and supply custom_fit and custom_predict to use a custom model.

learner_type

The learner type for the chosen model. Options are "ridge" (ridge regression), "fnn" (feedforward neural network), and "caret" (a model trained with the caret package). Default is "ridge". If model_type is "causal_forest", set learner_type to NULL; if model_type is "s_learner" or "m_learner", choose "ridge", "fnn", or "caret".

baseline_policy

A list providing the baseline policy (binary 0 or 1) for each sample. If NULL, defaults to a list of zeros (no one treated) with one entry per row of X.
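The default baseline described above, and a hypothetical "treat everyone" alternative, can be written in base R as follows. The object names are illustrative; the only assumption is that cram_policy accepts a list with one 0/1 entry per row of X, as stated above.

```r
n <- 100                          # number of rows in X

# Default described above: a list of zeros, i.e. no one is treated
baseline_none <- as.list(rep(0L, n))

# Hypothetical alternative baseline: treat everyone
baseline_all <- as.list(rep(1L, n))

# Proportion treated under each baseline
mean(unlist(baseline_none))       # 0
mean(unlist(baseline_all))        # 1
```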

parallelize_batch

Logical. Whether to parallelize batch processing. The cram method learns T policies, where T is the number of batches. When parallelize_batch is TRUE, these policies are learned in parallel; when FALSE, they are learned sequentially using the efficient data.table structure, which is recommended for lightweight training. Defaults to FALSE.

model_params

A list of additional parameters to pass to the model; any parameter accepted by the underlying model's package can be used. Defaults to NULL.

custom_fit

A user-defined function that takes training data and returns a fitted model, for use in place of the built-in models. Defaults to NULL.

custom_predict

A user-defined function that takes a fitted model and test data and returns predictions. Defaults to NULL.
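The following is a minimal sketch of what a custom_fit / custom_predict pair might look like, using a simple T-learner built from base R lm() (one linear model per treatment arm). The argument lists custom_fit(X, Y, D, ...) and custom_predict(model, X_new, ...), and the convention that predictions are 0/1 treatment decisions, are assumptions made for illustration; check the package documentation for the exact signatures cram_policy expects.

```r
# Hypothetical custom_fit (assumed signature): fits one linear model on
# treated units and one on control units, returning both in a list
custom_fit <- function(X, Y, D, ...) {
  df <- as.data.frame(X)          # a matrix gets columns V1, V2, ...
  fit_arm <- function(rows) {
    lm(y ~ ., data = data.frame(df[rows, , drop = FALSE], y = Y[rows]))
  }
  list(m1 = fit_arm(D == 1), m0 = fit_arm(D == 0))
}

# Hypothetical custom_predict (assumed signature): treat a unit when the
# estimated effect (treated prediction minus control prediction) is positive
custom_predict <- function(model, X_new, ...) {
  df_new <- as.data.frame(X_new)
  cate <- predict(model$m1, newdata = df_new) -
    predict(model$m0, newdata = df_new)
  as.integer(cate > 0)
}

# Usage on toy data with both arms present
set.seed(1)
X <- matrix(rnorm(100 * 5), nrow = 100, ncol = 5)
D <- rep(c(0L, 1L), 50)
Y <- rnorm(100)
fit <- custom_fit(X, Y, D)
policy <- custom_predict(fit, X)
mean(policy)                      # proportion treated under this policy
```

If these signatures match what the package expects, the pair would be passed as cram_policy(X, D, Y, batch = 5, model_type = NULL, custom_fit = custom_fit, custom_predict = custom_predict).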

alpha

Significance level for confidence intervals. Default is 0.05 (95% confidence).
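The intervals in the example output below are consistent with a two-sided normal-approximation interval, estimate ± z(1 − alpha/2) × standard error. The base R sketch below recomputes them from the printed (truncated) estimates and standard errors; it is a consistency check under that assumption, not the package's implementation.

```r
alpha <- 0.05
z <- qnorm(1 - alpha / 2)         # about 1.96 for alpha = 0.05

# Estimates and standard errors copied from the example output
est <- c(delta = 0.36852, policy_value = 0.53258)
se  <- c(delta = 0.26573, policy_value = 0.27544)

# Bounds agree with the printed CI columns up to the last displayed digit
ci_lower <- est - z * se
ci_upper <- est + z * se
round(cbind(ci_lower, ci_upper), 5)
```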

propensity

The propensity score function for the binary treatment indicator (D), that is, the probability of each unit receiving treatment. Defaults to 0.5 (random assignment).
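To illustrate what a propensity score is used for, the sketch below computes a standard inverse-propensity-weighted (IPW) estimate of a policy's value, and a delta as the difference between a learned policy's value and a baseline policy's value. This generic estimator is chosen for illustration only and is not necessarily the estimator cram_policy uses internally; the function ipw_value is hypothetical.

```r
# IPW estimate of the value of a 0/1 `policy` from data (Y, D) collected
# with treatment probability `e` (a scalar or a per-unit vector)
ipw_value <- function(Y, D, policy, e = 0.5) {
  # probability of receiving the action the policy recommends
  p_match <- ifelse(policy == 1, e, 1 - e)
  # keep only units whose observed treatment matches the policy, reweighted
  mean(Y * (D == policy) / p_match)
}

# Toy data: 4 units, half treated, assignment probability 0.5
Y <- c(1, 2, 3, 4)
D <- c(1, 0, 1, 0)
policy_learned <- c(1, 1, 0, 0)
policy_base    <- c(0, 0, 0, 0)   # the default all-zeros baseline

v_learned <- ipw_value(Y, D, policy_learned)  # 2.5
v_base    <- ipw_value(Y, D, policy_base)     # 3
delta     <- v_learned - v_base               # -0.5
```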

Value

A list containing:

  • raw_results: A data frame summarizing key metrics with truncated decimals:

    • Delta Estimate: The estimated difference in policy value between the learned policy and the baseline policy (delta).

    • Delta Standard Error: The standard error of the delta estimate.

    • Delta CI Lower: The lower bound of the confidence interval for delta.

    • Delta CI Upper: The upper bound of the confidence interval for delta.

    • Policy Value Estimate: The estimated policy value.

    • Policy Value Standard Error: The standard error of the policy value estimate.

    • Policy Value CI Lower: The lower bound of the confidence interval for policy value.

    • Policy Value CI Upper: The upper bound of the confidence interval for policy value.

    • Proportion Treated: The proportion of individuals treated under the final policy.

  • interactive_table: An interactive table summarizing key metrics for detailed exploration.

  • final_policy_model: The final fitted policy model based on model_type and learner_type or custom_fit.
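Because raw_results is a data frame with Metric and Value columns (as in the example output below), individual metrics can be looked up by name. The sketch below rebuilds a few rows of that output by hand so it runs without the package; get_metric is a hypothetical helper.

```r
# A few rows copied from the example output of result$raw_results
raw_results <- data.frame(
  Metric = c("Delta Estimate", "Delta Standard Error",
             "Policy Value Estimate", "Proportion Treated"),
  Value  = c(0.36852, 0.26573, 0.53258, 1.00000)
)

# Return the single Value whose Metric matches `name`, or fail loudly
get_metric <- function(results, name) {
  value <- results$Value[results$Metric == name]
  if (length(value) != 1) stop("metric not found: ", name)
  value
}

get_metric(raw_results, "Delta Estimate")     # 0.36852
```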

Examples

# Example data
X_data <- matrix(rnorm(100 * 5), nrow = 100, ncol = 5)
D_data <- as.integer(sample(c(0, 1), 100, replace = TRUE))
Y_data <- rnorm(100)
nb_batch <- 5

# Perform CRAM policy
result <- cram_policy(X = X_data,
                      D = D_data,
                      Y = Y_data,
                      batch = nb_batch)

# Access results
result$raw_results
#>                        Metric    Value
#> 1              Delta Estimate  0.36852
#> 2        Delta Standard Error  0.26573
#> 3              Delta CI Lower -0.15230
#> 4              Delta CI Upper  0.88935
#> 5       Policy Value Estimate  0.53258
#> 6 Policy Value Standard Error  0.27544
#> 7       Policy Value CI Lower -0.00727
#> 8       Policy Value CI Upper  1.07243
#> 9          Proportion Treated  1.00000
result$interactive_table
result$final_policy_model
#> GRF forest object of type causal_forest
#> Number of trees: 100
#> Number of training samples: 100
#> Variable importance:
#>     1     2     3     4     5
#> 0.089 0.187 0.142 0.124 0.160