Define cross-validation scheme and training parameters
Usage
define_cv(
  train_data,
  outcome_colname,
  hyperparams_list,
  perf_metric_function,
  class_probs,
  kfold = 5,
  cv_times = 100,
  groups = NULL,
  group_partitions = NULL
)
Arguments
- train_data
Data frame used to train the model.
- outcome_colname
Name of the outcome variable column, as a string (default: NULL; the first column is then chosen automatically).
- hyperparams_list
Named list of lists of hyperparameters.
- perf_metric_function
Function to calculate the performance metric used for cross-validation and test performance. Some functions are provided by caret (see caret::defaultSummary()). Defaults: binary classification = twoClassSummary, multi-class classification = multiClassSummary, regression = defaultSummary.
- class_probs
Whether to use class probabilities (TRUE for categorical outcomes, FALSE for numeric outcomes).
- kfold
Number of folds for k-fold cross-validation (default: 5).
- cv_times
Number of times to repeat the k-fold partitioning, i.e. the number of cross-validation partitions to create (default: 100).
- groups
Vector of groups to keep together when splitting the data into train and test sets. Its length must match the number of rows in the dataset (default: NULL). If the number of groups in the training set is larger than kfold, the groups will also be kept together for cross-validation.
- group_partitions
Specifies how to assign groups to the training and testing partitions (default: NULL). If groups specifies that some samples belong to group "A" and some belong to group "B", then setting group_partitions = list(train = c("A", "B"), test = c("B")) will place all samples from group "A" in the training set, some samples from "B" in the training set as well, and the remaining samples from "B" in the testing set. The partition sizes will be as close to training_frac (the fraction of samples assigned to the training set, as in get_partition_indices()) as possible. If the number of groups in the training set is larger than kfold, the groups will also be kept together for cross-validation.
Examples
library(mikropml)
# Split otu_small into training (80%) and testing sets using the "dx" outcome
training_inds <- get_partition_indices(dplyr::pull(otu_small, "dx"),
  training_frac = 0.8,
  groups = NULL
)
train_data <- otu_small[training_inds, ]
test_data <- otu_small[-training_inds, ]
# Define 5-fold cross-validation on the training set (cv_times defaults to 100)
cv <- define_cv(train_data,
  outcome_colname = "dx",
  hyperparams_list = get_hyperparams_list(otu_small, "glmnet"),
  perf_metric_function = caret::multiClassSummary,
  class_probs = TRUE,
  kfold = 5
)