This function maps the model type and learner type to the corresponding model function.
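The mapping works like a lookup keyed on the two arguments. The R sketch below is hypothetical. `get_model_function`, `model_registry`, and the placeholder functions are illustrative names, not the package's internals, and each placeholder only returns a label. The sketch encodes the combination rules listed under Arguments: "causal_forest" takes no learner, while "s_learner" and "m_learner" take "ridge", "fnn", or "caret". In this sketch, learner_type defaults to NULL so that the default "causal_forest" call is valid.

```r
# Hypothetical sketch of the model_type/learner_type -> model function mapping.
# Placeholders stand in for the package's real fitting routines and only
# return a label identifying the combination they represent.
make_placeholder <- function(label) {
  force(label)
  function(...) label
}

model_registry <- list(
  "causal_forest"   = make_placeholder("causal_forest"),
  "s_learner/ridge" = make_placeholder("s_learner/ridge"),
  "s_learner/fnn"   = make_placeholder("s_learner/fnn"),
  "s_learner/caret" = make_placeholder("s_learner/caret"),
  "m_learner/ridge" = make_placeholder("m_learner/ridge"),
  "m_learner/fnn"   = make_placeholder("m_learner/fnn"),
  "m_learner/caret" = make_placeholder("m_learner/caret")
)

get_model_function <- function(model_type = "causal_forest", learner_type = NULL) {
  if (is.null(model_type)) {
    # Custom models bypass the registry entirely
    stop("model_type is NULL: supply custom_fit and custom_predict instead.")
  }
  if (model_type == "causal_forest") {
    # causal_forest takes no learner, so learner_type must be NULL
    if (!is.null(learner_type)) {
      stop("learner_type must be NULL when model_type is 'causal_forest'.")
    }
    key <- "causal_forest"
  } else {
    # s_learner / m_learner are keyed by "<model_type>/<learner_type>"
    key <- paste(model_type, learner_type, sep = "/")
  }
  fn <- model_registry[[key]]
  if (is.null(fn)) {
    stop(sprintf("Unsupported combination: model_type = '%s', learner_type = '%s'.",
                 model_type, if (is.null(learner_type)) "NULL" else learner_type))
  }
  fn
}
```

For example, `get_model_function("m_learner", "caret")` returns the placeholder for an M-learner with a caret learner, while `get_model_function("causal_forest", "ridge")` stops with an error.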
Arguments
- model_type
The model type for policy learning. Options include "causal_forest", "s_learner", and "m_learner". Default is "causal_forest". Note: you can also set model_type to NULL and specify custom_fit and custom_predict to use your own custom model.
- learner_type
The learner type for the chosen model. Options include "ridge" for Ridge Regression, "fnn" for Feedforward Neural Network, and "caret" for Caret. Default is "ridge". If model_type is "causal_forest", set learner_type to NULL; if model_type is "s_learner" or "m_learner", choose one of "ridge", "fnn", or "caret".
- model_params
A list of additional parameters to pass to the model; these can be any parameters defined in the package that implements the model. Defaults to NULL. For FNNs, the model_params list defines the following elements:
  - input_layer: A list defining the input layer. Must include:
    - units: Number of units in the input layer.
    - activation: Activation function for the input layer.
    - input_shape: Input shape for the layer.
  - layers: A list of lists, where each sublist specifies a hidden layer with:
    - units: Number of units in the layer.
    - activation: Activation function for the layer.
  - output_layer: A list defining the output layer. Must include:
    - units: Number of units in the output layer.
    - activation: Activation function for the output layer (e.g., "linear" or "sigmoid").
  - compile_args: A list of arguments for compiling the model. Must include:
    - optimizer: Optimizer for training (e.g., "adam" or "sgd").
    - loss: Loss function (e.g., "mse" or "binary_crossentropy").
    - metrics: Optional list of metrics for evaluation (e.g., c("accuracy")).
For other models and learners (e.g., "ridge" or "causal_forest"), model_params can include the hyperparameters relevant to that model.
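To illustrate the FNN structure described under model_params, the R sketch below builds a model_params list and checks it for the documented required elements. The layer sizes, activations, and metric are illustrative values only. The input_shape of c(10) assumes 10 input covariates, and the exact form the package expects for input_shape is an assumption. `check_fnn_params` is a hypothetical helper, not a function provided by the package.

```r
# Illustrative model_params for learner_type = "fnn" (values are examples only;
# input_shape = c(10) assumes 10 input covariates).
fnn_params <- list(
  input_layer = list(units = 64, activation = "relu", input_shape = c(10)),
  layers = list(
    list(units = 32, activation = "relu"),
    list(units = 16, activation = "relu")
  ),
  output_layer = list(units = 1, activation = "linear"),
  compile_args = list(optimizer = "adam", loss = "mse", metrics = c("mae"))
)

# Hypothetical helper: return the documented required elements that are
# missing from an FNN model_params list (empty character vector if none).
check_fnn_params <- function(params) {
  # Fields from `fields` absent in `x`, prefixed with their parent element
  missing_in <- function(x, fields, prefix) {
    absent <- setdiff(fields, names(x))
    if (length(absent) == 0) character(0) else paste0(prefix, absent)
  }
  problems <- missing_in(params,
                         c("input_layer", "layers", "output_layer", "compile_args"), "")
  if (!is.null(params$input_layer)) {
    problems <- c(problems, missing_in(params$input_layer,
                                       c("units", "activation", "input_shape"),
                                       "input_layer$"))
  }
  # Every hidden layer needs units and activation
  for (i in seq_along(params$layers)) {
    problems <- c(problems, missing_in(params$layers[[i]],
                                       c("units", "activation"),
                                       sprintf("layers[[%d]]$", i)))
  }
  if (!is.null(params$output_layer)) {
    problems <- c(problems, missing_in(params$output_layer,
                                       c("units", "activation"), "output_layer$"))
  }
  if (!is.null(params$compile_args)) {
    # metrics is optional, so only optimizer and loss are required
    problems <- c(problems, missing_in(params$compile_args,
                                       c("optimizer", "loss"), "compile_args$"))
  }
  problems
}
```

Because metrics is documented as optional, the check accepts a compile_args list without it.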