This function performs the learning part of the Cram Policy method.
Usage
cram_learning(
  X,
  D,
  Y,
  batch,
  model_type = "causal_forest",
  learner_type = "ridge",
  baseline_policy = NULL,
  parallelize_batch = FALSE,
  model_params = NULL,
  custom_fit = NULL,
  custom_predict = NULL,
  n_cores = detectCores() - 1,
  propensity = NULL
)
Arguments
- X
A matrix or data frame of covariates for each sample.
- D
A vector of binary treatment indicators (1 for treated, 0 for untreated).
- Y
A vector of outcome values for each sample.
- batch
Either an integer specifying the number of batches (which will be created by random sampling) or a vector of length equal to the sample size providing the batch assignment (index) for each individual in the sample.
- model_type
The model type for policy learning. Options include "causal_forest", "s_learner", and "m_learner". Default is "causal_forest". Note: you can also set model_type to NULL and specify custom_fit and custom_predict to use your custom model.
- learner_type
The learner type for the chosen model. Options include "ridge" for Ridge Regression, "fnn" for Feedforward Neural Network, and "caret" for Caret. Default is "ridge". If model_type is "causal_forest", set learner_type to NULL; if model_type is "s_learner" or "m_learner", choose among "ridge", "fnn", and "caret".
- baseline_policy
A list providing the baseline policy (binary 0 or 1) for each sample. If NULL, the baseline policy defaults to a list of zeros with the same length as the number of rows in X.
- parallelize_batch
Logical. Whether to parallelize batch processing. The Cram method learns T policies, where T is the number of batches. When parallelize_batch is TRUE, the policies are learned in parallel; when it is FALSE, they are learned sequentially using the efficient data.table structure, which is recommended for lightweight training. Defaults to FALSE.
- model_params
A list of additional parameters to pass to the model, which can be any parameter defined in the model reference package. Defaults to NULL.
- custom_fit
A custom, user-defined function that outputs a fitted model given training data (allows flexibility). Defaults to NULL.
- custom_predict
A custom, user-defined function for making predictions given a fitted model and test data (allows flexibility). Defaults to NULL.
- n_cores
Number of cores to use for parallelization when parallelize_batch is set to TRUE. Defaults to detectCores() - 1.
- propensity
The propensity score
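As a minimal sketch of how these arguments might be prepared, the base R code below builds simulated inputs, shows the two accepted forms of batch, and constructs the baseline policy that is described above as the default. The simulated data, the variable names, and the use of sample() to build the batch vector are illustrative assumptions; how cram_learning() creates batches internally when given an integer is not shown here.

```r
set.seed(42)
n <- 100

# Simulated inputs (illustrative only)
X <- matrix(rnorm(n * 3), nrow = n, ncol = 3)  # covariates
D <- rbinom(n, size = 1, prob = 0.5)           # binary treatment indicator
Y <- rnorm(n)                                  # outcomes

# Form 1: an integer giving the number of batches
batch_count <- 5

# Form 2: a vector of length n giving each individual's batch index.
# Here the indices 1..5 are balanced and randomly permuted (an assumption
# about how one might build such a vector by hand).
batch_vector <- sample(rep(seq_len(batch_count), length.out = n))

# Baseline policy matching the documented default: a list of zeros,
# one entry per row of X
baseline_policy <- as.list(rep(0, nrow(X)))

# Core count mirroring the documented default, guarded because
# parallel::detectCores() can return NA on some platforms
n_detected <- parallel::detectCores()
n_cores <- if (is.na(n_detected)) 1L else max(1L, n_detected - 1L)
```

Either batch_count or batch_vector could then be supplied as the batch argument; the vector form is useful when batch membership must be reproducible or is fixed by the study design.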
Value
A list containing:
- final_policy_model
The final fitted policy model, depending on model_type and learner_type.
- policies
A matrix of learned policies, where each column represents a batch's learned policy and the first column is the baseline policy.
- batch_indices
The indices for each batch, either as generated (if batch is an integer) or as provided by the user.
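The following sketch shows one way the returned list might be inspected. It assumes that cram_learning() is available from the cramR package and that the default "causal_forest" model can be fitted in the current environment; the call is guarded so that the snippet still runs when the package is not installed. The simulated data and variable names are illustrative only.

```r
set.seed(1)
n <- 100

# Simulated inputs (illustrative only)
X_data <- matrix(rnorm(n * 5), nrow = n, ncol = 5)
D_data <- as.integer(sample(c(0, 1), n, replace = TRUE))
Y_data <- rnorm(n)
nb_batch <- 5

# Assumption: the function is exported by the cramR package
if (requireNamespace("cramR", quietly = TRUE)) {
  # All other arguments are left at the defaults listed under Usage
  result <- cramR::cram_learning(
    X = X_data, D = D_data, Y = Y_data, batch = nb_batch
  )

  # The three documented elements of the returned list
  final_model   <- result$final_policy_model
  policies      <- result$policies
  batch_indices <- result$batch_indices

  # Per the description above, the first column holds the baseline policy
  baseline_column <- policies[, 1]

  print(dim(policies))
  str(batch_indices, max.level = 1)
}
```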