This function maps the model type and learner type to the corresponding model function.
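The sketch below shows how the model_type/learner_type combinations documented under Arguments could be validated before a model is selected. It is not the package's code. resolve_model is a hypothetical name, and it returns a string label instead of an actual model function.

```r
# Hypothetical sketch, not the package's implementation: checks a
# (model_type, learner_type) pair against the documented rules and returns
# an illustrative label for the model that would be used.
resolve_model <- function(model_type, learner_type,
                          custom_fit = NULL, custom_predict = NULL) {
  if (is.null(model_type)) {
    # NULL model_type: the caller supplies their own fit/predict functions
    if (!is.function(custom_fit) || !is.function(custom_predict)) {
      stop("model_type = NULL requires custom_fit and custom_predict functions")
    }
    return("custom")
  }
  valid_models <- c("causal_forest", "s_learner", "m_learner")
  if (!(model_type %in% valid_models)) {
    stop("unknown model_type: ", model_type)
  }
  if (model_type == "causal_forest") {
    # causal_forest does not take a separate learner
    if (!is.null(learner_type)) {
      stop("learner_type must be NULL when model_type is \"causal_forest\"")
    }
    return("causal_forest")
  }
  # s_learner and m_learner need one of the supported learners
  valid_learners <- c("ridge", "fnn", "caret")
  if (is.null(learner_type) || !(learner_type %in% valid_learners)) {
    stop("learner_type must be one of: ", paste(valid_learners, collapse = ", "))
  }
  paste(model_type, learner_type, sep = "_")
}

resolve_model("s_learner", "fnn")
resolve_model("causal_forest", NULL)
```

Under these assumptions, combinations such as resolve_model("m_learner", NULL) or resolve_model("causal_forest", "ridge") stop with an error.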
Arguments
- model_type
  The model type for policy learning. Options are "causal_forest", "s_learner", and "m_learner". Default is "causal_forest". Note: you can also set model_type to NULL and supply custom_fit and custom_predict to use your own model.
- learner_type
  The learner type for the chosen model. Options are "ridge" for ridge regression, "fnn" for a feedforward neural network, and "caret" for the caret package. Default is "ridge". If model_type is "causal_forest", set learner_type to NULL; if model_type is "s_learner" or "m_learner", choose "ridge", "fnn", or "caret".
- model_params
  A list of additional parameters to pass to the model; it may contain any parameter accepted by the package that implements the chosen model. Defaults to NULL. For FNNs, the model_params list supports the following elements:
  - input_layer
    A list defining the input layer. Must include:
    - units: Number of units in the input layer.
    - activation: Activation function for the input layer.
    - input_shape: Input shape for the layer.
  - layers
    A list of lists, where each sublist specifies one hidden layer with:
    - units: Number of units in the layer.
    - activation: Activation function for the layer.
  - output_layer
    A list defining the output layer. Must include:
    - units: Number of units in the output layer.
    - activation: Activation function for the output layer (e.g., "linear" or "sigmoid").
  - compile_args
    A list of arguments for compiling the model. Must include optimizer and loss; metrics is optional:
    - optimizer: Optimizer used for training (e.g., "adam" or "sgd").
    - loss: Loss function (e.g., "mse" or "binary_crossentropy").
    - metrics: Optional metrics for evaluation (e.g., c("accuracy")).
  For other configurations (e.g., the "ridge" learner or the "causal_forest" model type), model_params can include the hyperparameters relevant to that model.
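The nested FNN structure described for model_params can be written as an ordinary R list. The sketch below assumes 10 input features and arbitrary layer sizes, and uses only option values named in the documentation. check_fnn_params is a hypothetical helper, not part of the package. It checks that the elements documented as required are present.

```r
# Illustrative model_params for learner_type = "fnn".
# n_features and all layer sizes are assumptions chosen for the example.
n_features <- 10

fnn_params <- list(
  input_layer = list(
    units = 32,
    activation = "relu",
    input_shape = c(n_features)
  ),
  # one sublist per hidden layer
  layers = list(
    list(units = 16, activation = "relu"),
    list(units = 8, activation = "relu")
  ),
  output_layer = list(
    units = 1,
    activation = "linear"  # e.g., "sigmoid" for a binary outcome
  ),
  compile_args = list(
    optimizer = "adam",
    loss = "mse"
    # metrics is optional and omitted here
  )
)

# Hypothetical helper: errors if any element documented as required is missing.
check_fnn_params <- function(p) {
  has <- function(x, fields) all(fields %in% names(x))
  stopifnot(
    has(p$input_layer, c("units", "activation", "input_shape")),
    is.list(p$layers),
    all(vapply(p$layers, has, logical(1), fields = c("units", "activation"))),
    has(p$output_layer, c("units", "activation")),
    has(p$compile_args, c("optimizer", "loss"))
  )
  invisible(TRUE)
}

check_fnn_params(fnn_params)
```

In this sketch, fnn_params would be passed as model_params together with learner_type = "fnn". Removing a required entry such as compile_args$loss makes check_fnn_params stop with an error.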