
This function performs inference with a trained model and returns binary policy assignments. It supports the model types "causal_forest", "s_learner", and "m_learner", with ridge regression or feedforward neural networks (FNNs) as the learner.

Usage

model_predict(model, X, D, model_type, learner_type, model_params)

Arguments

model

A trained model object returned by the `fit_model` function.

X

A matrix or data frame of covariates for which predictions are required.

D

A vector of binary treatment indicators (1 for treated, 0 for untreated). Whether it is required depends on `model_type`.

model_type

The model type for policy learning. Options include "causal_forest", "s_learner", and "m_learner". Default is "causal_forest".

learner_type

The learner type for the chosen model. Options include "ridge" for ridge regression and "fnn" for a feedforward neural network. Default is "ridge".

model_params

A list of additional parameters passed to the underlying model. Any parameter defined by the package that implements the chosen model can be supplied. Defaults to NULL.
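The `D` argument expects 0/1 coding, while treatment data often arrive as logicals or factors. The snippet below is a hypothetical base-R helper, `as_binary_treatment()`, that is not part of this package; it is a minimal sketch showing one way to coerce and check `D` before calling `model_predict`.

```r
# Hypothetical helper (not part of the package): coerce a treatment vector to
# the 0/1 integer coding expected for `D`, and stop on any other value.
as_binary_treatment <- function(D) {
  if (is.logical(D)) D <- as.integer(D)                 # TRUE/FALSE -> 1/0
  if (is.factor(D)) D <- as.integer(as.character(D))    # factor("0","1") -> 0/1
  if (anyNA(D) || !all(D %in% c(0, 1))) {
    stop("D must contain only 0 (untreated) and 1 (treated)")
  }
  as.integer(D)
}

as_binary_treatment(c(TRUE, FALSE, TRUE))
as_binary_treatment(factor(c("0", "1", "1")))
```

Values such as `factor(c("yes", "no"))` are rejected rather than silently recoded, because the mapping from labels to treated/untreated cannot be inferred.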

Value

A vector of binary policy assignments, one for each row of `X`. How the assignments are computed depends on `model_type` and `learner_type`.
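As an illustration of how an S-learner with a ridge learner can turn outcome predictions into binary assignments, the sketch below fits one ridge regression on covariates, treatment, and their interactions. It then predicts each unit's outcome with D set to 1 and to 0, and assigns treatment when the predicted difference is positive. This is a self-contained base-R sketch, not the package's implementation: `ridge_fit()`, `ridge_predict()`, `s_design()`, and `s_learner_policy()` are hypothetical helpers. The rule "treat if the estimated effect is positive" assumes that larger outcomes are better.

```r
# Illustrative S-learner policy sketch; NOT the package's implementation.

# Closed-form ridge: beta = (Z'Z + lambda * P)^{-1} Z'y, where P is the
# identity except that the intercept entry is 0 (intercept not penalized).
ridge_fit <- function(Z, y, lambda = 1) {
  Z <- cbind(1, Z)
  penalty <- diag(lambda, ncol(Z))
  penalty[1, 1] <- 0
  solve(crossprod(Z) + penalty, crossprod(Z, y))
}

ridge_predict <- function(beta, Z) {
  drop(cbind(1, Z) %*% beta)
}

# S-learner design: covariates, treatment, and treatment-covariate
# interactions, so a linear model can express heterogeneous effects.
s_design <- function(X, D) {
  X <- as.matrix(X)
  cbind(X, D = D, X * D)   # X * D scales row i of X by D[i]
}

# Fit once on observed (X, D, y), then compare predictions under D = 1 and D = 0.
s_learner_policy <- function(X, D, y, lambda = 1) {
  beta <- ridge_fit(s_design(X, D), y, lambda)
  n <- nrow(as.matrix(X))
  mu1 <- ridge_predict(beta, s_design(X, rep(1, n)))
  mu0 <- ridge_predict(beta, s_design(X, rep(0, n)))
  as.integer(mu1 - mu0 > 0)  # 1 = assign treatment, 0 = do not
}

# Simulated example: the true effect equals the first covariate,
# so the ideal policy treats exactly the units with X[, 1] > 0.
set.seed(1)
n <- 500
X <- matrix(rnorm(n * 2), ncol = 2)
D <- rbinom(n, 1, 0.5)
y <- X[, 2] + D * X[, 1] + rnorm(n, sd = 0.5)

policy <- s_learner_policy(X, D, y, lambda = 1)
table(policy)
mean(policy == as.integer(X[, 1] > 0))  # agreement with the ideal policy
```

Including the interaction terms matters here. With only `X` and `D` as regressors, a linear S-learner would estimate a single constant effect and assign the same action to every unit.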