This function maps the model type and learner type to the corresponding model function.

Usage

set_model(model_type, learner_type, model_params)

Arguments

model_type

The model type for policy learning. Options include "causal_forest", "s_learner", and "m_learner". Default is "causal_forest". Note: you can also set model_type to NULL and specify custom_fit and custom_predict to use your custom model.

learner_type

The learner type for the chosen model. Options include "ridge" for Ridge Regression, "fnn" for Feedforward Neural Network, and "caret" for Caret. Default is "ridge". If model_type is "causal_forest", set learner_type to NULL; if model_type is "s_learner" or "m_learner", choose among "ridge", "fnn", and "caret".
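The pairing rule between model_type and learner_type can be sketched as a small check. This is a minimal illustration only: check_model_learner is a hypothetical helper, not a function of this package, and it covers only the built-in model types listed above.

```r
# Hypothetical helper (not part of the package): encodes the pairing rule
# between model_type and learner_type described in the Arguments section.
check_model_learner <- function(model_type, learner_type = NULL) {
  meta_learners <- c("s_learner", "m_learner")
  learners <- c("ridge", "fnn", "caret")

  # Custom models (model_type = NULL) are outside the scope of this sketch.
  if (is.null(model_type)) {
    return(NA)
  }
  # "causal_forest" takes no learner, so learner_type must be NULL.
  if (identical(model_type, "causal_forest")) {
    return(is.null(learner_type))
  }
  # "s_learner" and "m_learner" require one of the listed learners.
  if (model_type %in% meta_learners) {
    return(!is.null(learner_type) && learner_type %in% learners)
  }
  # Any other model_type is not one of the documented options.
  FALSE
}

check_model_learner("causal_forest", NULL)
check_model_learner("s_learner", "ridge")
check_model_learner("causal_forest", "ridge")
```

Under these assumptions, the first two calls describe valid combinations and the third does not.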

model_params

A list of additional parameters to pass to the model; any parameter defined in the underlying model package can be supplied. Defaults to NULL. For FNNs, the model_params list contains the following elements:

input_layer

A list defining the input layer. Must include:

units

Number of units in the input layer.

activation

Activation function for the input layer.

input_shape

Input shape for the layer.

layers

A list of lists, where each sublist specifies a hidden layer with:

units

Number of units in the layer.

activation

Activation function for the layer.

output_layer

A list defining the output layer. Must include:

units

Number of units in the output layer.

activation

Activation function for the output layer (e.g., "linear" or "sigmoid").

compile_args

A list of arguments for compiling the model. Must include:

optimizer

Optimizer for training (e.g., "adam" or "sgd").

loss

Loss function (e.g., "mse" or "binary_crossentropy").

metrics

Optional list of metrics for evaluation (e.g., c("accuracy")).

For other learners and model types (e.g., "ridge" or "causal_forest"), model_params can include the relevant hyperparameters of the underlying model.
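The FNN elements described above can be assembled into a single nested list. The following is a sketch under stated assumptions: the unit counts, activations, the input shape of 10 covariates, and the "mae" metric are illustrative values, not package defaults.

```r
# Illustrative model_params for learner_type = "fnn".
# All numeric values and activation choices below are assumptions.
fnn_params <- list(
  input_layer = list(
    units = 64,
    activation = "relu",
    input_shape = c(10)   # assumed: 10 input covariates
  ),
  layers = list(          # one sublist per hidden layer
    list(units = 32, activation = "relu"),
    list(units = 16, activation = "relu")
  ),
  output_layer = list(
    units = 1,
    activation = "linear" # continuous outcome
  ),
  compile_args = list(
    optimizer = "adam",
    loss = "mse",
    metrics = c("mae")    # optional
  )
)

# With the package loaded, the list would be passed as the third argument:
# model <- set_model("s_learner", "fnn", fnn_params)
str(fnn_params, max.level = 2)
```

For a binary outcome, the output activation and loss would typically change to "sigmoid" and "binary_crossentropy", as the argument descriptions suggest.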

Value

The instantiated model object or the corresponding model function.