This function performs the learning part of the Cram Policy method.

Usage

cram_learning(
  X,
  D,
  Y,
  batch,
  model_type = "causal_forest",
  learner_type = "ridge",
  baseline_policy = NULL,
  parallelize_batch = FALSE,
  model_params = NULL,
  custom_fit = NULL,
  custom_predict = NULL,
  n_cores = detectCores() - 1,
  propensity = NULL
)

Arguments

X

A matrix or data frame of covariates for each sample.

D

A vector of binary treatment indicators (1 for treated, 0 for untreated).

Y

A vector of outcome values for each sample.

batch

Either an integer specifying the number of batches (which will be created by random sampling) or a vector of length equal to the sample size providing the batch assignment (index) for each individual in the sample.
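As an illustration of the vector form of batch, the sketch below builds a random assignment of individuals to batches in base R. The sample size, the number of batches, and the variable names are hypothetical, and this is not necessarily how the package itself samples batches when batch is an integer.

```r
set.seed(42)  # for a reproducible illustration

n <- 12          # hypothetical sample size
n_batches <- 3   # the value one would otherwise pass as an integer `batch`

# Random, equal-sized batch assignment: one batch index per individual
batch_assignment <- sample(rep(seq_len(n_batches), length.out = n))

length(batch_assignment)  # equals n, as required for a vector-valued `batch`
table(batch_assignment)   # number of individuals in each batch
```

A vector like batch_assignment could then be passed as batch when the batch membership needs to be controlled or reproduced exactly.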

model_type

The model type for policy learning. Options include "causal_forest", "s_learner", and "m_learner". Default is "causal_forest". Note: you can also set model_type to NULL and specify custom_fit and custom_predict to use your custom model.

learner_type

The learner type for the chosen model. Options are "ridge" (ridge regression), "fnn" (feedforward neural network), and "caret" (a model trained through caret). Default is "ridge". If model_type is "causal_forest", set learner_type to NULL; if model_type is "s_learner" or "m_learner", choose one of "ridge", "fnn", or "caret".

baseline_policy

A list providing the baseline policy (binary 0 or 1) for each sample. If NULL, the baseline policy defaults to a list of zeros with the same length as the number of rows in X.
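A minimal sketch of what a baseline policy list can look like, using a hypothetical covariate matrix. The first object mirrors the documented default (all zeros, one entry per row of X); the second is an illustrative "treat everyone" alternative.

```r
set.seed(1)
X <- matrix(rnorm(20), nrow = 10, ncol = 2)  # hypothetical covariates

# Same shape as the documented default: a list of zeros, one per row of X
default_baseline <- as.list(rep(0, nrow(X)))

# An alternative baseline that assigns treatment to every individual
treat_all_baseline <- as.list(rep(1, nrow(X)))

length(default_baseline) == nrow(X)
```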

parallelize_batch

Logical. Whether to parallelize batch processing. The Cram method learns T policies, where T is the number of batches. When parallelize_batch is TRUE, the policies are learned in parallel; when it is FALSE, they are learned sequentially using an efficient data.table structure, which is recommended for lightweight training. Defaults to FALSE.

model_params

A list of additional parameters to pass to the model, which can be any parameter defined in the model reference package. Defaults to NULL.

custom_fit

A custom, user-defined function that returns a fitted model given training data. This allows models beyond the built-in options. Defaults to NULL.

custom_predict

A custom, user-defined function that makes predictions given a fitted model and test data. This allows prediction rules beyond the built-in options. Defaults to NULL.
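The following sketch shows one possible custom_fit / custom_predict pair, written in base R. The argument names and order (X, Y, D for fitting; model, X_new, D_new for prediction) and the convention that the prediction function returns a 0/1 treatment decision per row are assumptions made for illustration, not signatures confirmed by this page. The model is a simple two-regression learner: one linear model for treated units, one for untreated units, with treatment recommended where the estimated effect is positive.

```r
# Assumed signature: takes covariates, outcomes, and treatments; returns a model object
custom_fit <- function(X, Y, D) {
  df <- data.frame(X, Y = Y)
  list(
    treated = lm(Y ~ ., data = df[D == 1, , drop = FALSE]),
    control = lm(Y ~ ., data = df[D == 0, , drop = FALSE])
  )
}

# Assumed signature: takes the fitted model and new covariates; returns 0/1 decisions
custom_predict <- function(model, X_new, D_new = NULL) {
  newdata <- data.frame(X_new)
  # Estimated treatment effect: predicted outcome if treated minus if untreated
  cate <- predict(model$treated, newdata) - predict(model$control, newdata)
  as.numeric(cate > 0)
}

# Simulated data in which treatment helps only when the first covariate is positive
set.seed(1)
n <- 200
X <- matrix(rnorm(n * 2), ncol = 2)
D <- rbinom(n, 1, 0.5)
Y <- X[, 1] * D + rnorm(n, sd = 0.1)

model <- custom_fit(X, Y, D)
policy <- custom_predict(model, X)
table(policy)
```

Under these assumptions, the pair would be supplied together with model_type = NULL, as described for the model_type argument.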

n_cores

Number of cores to use for parallelization when parallelize_batch is set to TRUE. Defaults to detectCores() - 1.
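Because detectCores() from the parallel package can return 1 (or NA on some platforms), the default expression may evaluate to 0 or NA. A defensive way to choose a value before passing it as n_cores might look like the sketch below; the variable names are illustrative.

```r
library(parallel)

available <- detectCores()  # may be NA if the core count cannot be determined

# Leave one core free, but never go below one worker
n_cores <- if (is.na(available)) 1L else max(1L, available - 1L)
n_cores
```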

propensity

The propensity score. Defaults to NULL.

Value

A list containing:

final_policy_model

The final fitted policy model, depending on model_type and learner_type.

policies

A matrix of learned policies, where each column represents a batch's learned policy and the first column is the baseline policy.

batch_indices

The indices for each batch, either as generated (if batch is an integer) or as provided by the user.
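A hedged end-to-end sketch on simulated data follows. It assumes the function is available from a package named cramR (the package name is not stated on this page) and only runs the call if that package is installed. The data-generating process is invented for illustration.

```r
set.seed(123)

# Simulated inputs: two covariates, randomized binary treatment, noisy outcome
n <- 300
X <- matrix(rnorm(n * 2), ncol = 2, dimnames = list(NULL, c("x1", "x2")))
D <- rbinom(n, 1, 0.5)
Y <- 0.5 * X[, "x1"] * D + rnorm(n)

# Assumption: the function is exported by a package called cramR
if (requireNamespace("cramR", quietly = TRUE)) {
  res <- cramR::cram_learning(
    X = X, D = D, Y = Y,
    batch = 5,                     # five randomly sampled batches
    model_type = "causal_forest",
    learner_type = NULL            # NULL is the documented choice for causal_forest
  )
  dim(res$policies)                       # baseline column plus one column per batch
  str(res$batch_indices, max.level = 1)   # batch membership as generated
}
```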