/**
* jQuery Plugin: Sticky Tabs
*
* @author Aidan Lister
// Set the correct tab when the page loads showStuffFromHash(context);
// Set the correct tab when a user uses their back/forward button $(window).on('hashchange', function() { showStuffFromHash(context); });
// Change the URL when tabs are clicked $('a', context).on('click', function(e) { history.pushState(null, null, this.href); showStuffFromHash(context); });
return this; }; }(jQuery));
// Convert every div.section.tabset in the document into a Bootstrap tabset.
// tocID: id of the table-of-contents container whose entries for converted
// headings should be removed (since the headings themselves are removed).
// NOTE(review): reconstructed from extraction-mangled source — the HTML
// string literals below were stripped by the scraper; verify against the
// original rmarkdown tabsets script.
window.buildTabsets = function(tocID) {

  // build a tabset from a section div with the .tabset class
  function buildTabset(tabset) {

    // check for fade and pills options
    var fade = tabset.hasClass("tabset-fade");
    var pills = tabset.hasClass("tabset-pills");
    var navClass = pills ? "nav-pills" : "nav-tabs";

    // determine the heading level of the tabset and tabs
    var match = tabset.attr('class').match(/level(\d) /);
    if (match === null)
      return;
    var tabsetLevel = Number(match[1]);
    var tabLevel = tabsetLevel + 1;

    // find all subheadings immediately below
    var tabs = tabset.find("div.section.level" + tabLevel);
    if (!tabs.length)
      return;

    // create tablist and tab-content elements
    var tabList = $('<ul class="nav ' + navClass + '" role="tablist"></ul>');
    $(tabs[0]).before(tabList);
    var tabContent = $('<div class="tab-content"></div>');
    $(tabs[0]).before(tabContent);

    // build the tabset
    var activeTab = 0;
    tabs.each(function(i) {

      // get the tab div
      var tab = $(tabs[i]);

      // get the id then sanitize it for use with bootstrap tabs
      var id = tab.attr('id');

      // see if this is marked as the active tab
      if (tab.hasClass('active'))
        activeTab = i;

      // remove any table of contents entries associated with
      // this ID (since we'll be removing the heading element)
      $("div#" + tocID + " li a[href='#" + id + "']").parent().remove();

      // sanitize the id for use with bootstrap tabs
      id = id.replace(/[.\/?&!#<>]/g, '').replace(/\s/g, '_');
      tab.attr('id', id);

      // get the heading element within it, grab its text, then remove it
      var heading = tab.find('h' + tabLevel + ':first');
      var headingText = heading.html();
      heading.remove();

      // build and append the tab list item
      var a = $('<a role="tab" data-toggle="tab">' + headingText + '</a>');
      a.attr('href', '#' + id);
      a.attr('aria-controls', id);
      var li = $('<li role="presentation"></li>');
      li.append(a);
      tabList.append(li);

      // set its attributes
      tab.attr('role', 'tabpanel');
      tab.addClass('tab-pane');
      tab.addClass('tabbed-pane');
      if (fade)
        tab.addClass('fade');

      // move it into the tab content div
      tab.detach().appendTo(tabContent);
    });

    // set active tab
    $(tabList.children('li')[activeTab]).addClass('active');
    var active = $(tabContent.children('div.section')[activeTab]);
    active.addClass('active');
    if (fade)
      active.addClass('in');

    // make sticky if requested (plugin defined elsewhere in this file)
    if (tabset.hasClass("tabset-sticky"))
      tabset.rmarkdownStickyTabs();
  }

  // convert section divs with the .tabset class to tabsets
  var tabsets = $("div.section.tabset");
  tabsets.each(function(i) {
    buildTabset($(tabsets[i]));
  });
};
Introduction to tidymodels: logistic regression in R
A step-by-step tutorial
In this blogpost, we will learn how to build a complete logistic regression workflow using the tidymodels framework in R. The goal is to understand not just what each step does, but what kind of object it creates and how those objects connect to one another. I’ll add links to other ML-related topics as we go, and there’s an “additional resources” section at the end of this blogpost with webpages and other blogs that really helped me understand how it all works.
So if you are ready… let’s dive in!
Logistic regression tutorial with tidymodels
Overview and introduction to the dataset
The aim of this tutorial is to build a logistic regression model using the tidymodels framework. If you’d like a simple explanation of logistic regression, you can check my other blogpost here.
We will predict whether a patient has diabetes using the PimaIndiansDiabetes2 dataset from the mlbench package.
This tutorial covers: 1) a simple model with just a train/test split 2) a cross-validated model, using k-fold (or v-fold) cross-validation.
The tutorial follows these stages:
- Load libraries and data
- Split data with rsample
- Preprocess with recipes
- Define a model with parsnip
- Define metrics with yardstick
- Bundle everything into a workflow
- Fit using
fit()/predict() vs last_fit()
0. Installing necessary packages
For this tutorial you will need R, or Rstudio, and you will need to install the following packages:
1. Load Libraries and Data
First, let’s set up our R environment and load the necessary packages.
## used (Mb) gc trigger (Mb) max used (Mb)
## Ncells 622146 33.3 1388091 74.2 781804 41.8
## Vcells 1169257 9.0 8388608 64.0 1876569 14.4
# Global options: print everything, avoid scientific notation, keep strings
# as characters, and silence dplyr's summarise grouping messages.
# Use TRUE/FALSE rather than T/F (T and F are ordinary objects that can be
# reassigned, so relying on them is fragile).
options(max.print = .Machine$integer.max, scipen = 999, stringsAsFactors = FALSE, dplyr.summarise.inform = FALSE)
# Loading relevant libraries
library(tidyverse)
library(tidymodels)
library(mlbench) # contains PimaIndiansDiabetes2
library(patchwork)
library(ggcorrplot)
tidymodels_prefer() # resolve function conflicts in favour of tidymodels
# Set seed for reproducibility
set.seed(123)
SquidTip! Always clean your environment and set a seed to ensure reproducible results.
Now, let’s load the dataset, (PimaIndiansDiabetes2). It has different clinical features from diabetic and non-diabetic patients. You can read more about it here.
# Load the Pima Indians Diabetes dataset (the version that keeps NAs)
# from the mlbench package into the global environment
data("PimaIndiansDiabetes2", package = "mlbench")
df <- PimaIndiansDiabetes2
# Quick look at the data
glimpse(df)
## Rows: 768
## Columns: 9
## $ pregnant <dbl> 6, 1, 8, 1, 0, 5, 3, 10, 2, 8, 4, 10, 10, 1, 5, 7, 0, 7, 1, 1…
## $ glucose <dbl> 148, 85, 183, 89, 137, 116, 78, 115, 197, 125, 110, 168, 139,…
## $ pressure <dbl> 72, 66, 64, 66, 40, 74, 50, NA, 70, 96, 92, 74, 80, 60, 72, N…
## $ triceps <dbl> 35, 29, NA, 23, 35, NA, 32, NA, 45, NA, NA, NA, NA, 23, 19, N…
## $ insulin <dbl> NA, NA, NA, 94, 168, NA, 88, NA, 543, NA, NA, NA, NA, 846, 17…
## $ mass <dbl> 33.6, 26.6, 23.3, 28.1, 43.1, 25.6, 31.0, 35.3, 30.5, NA, 37.…
## $ pedigree <dbl> 0.627, 0.351, 0.672, 0.167, 2.288, 0.201, 0.248, 0.134, 0.158…
## $ age <dbl> 50, 31, 32, 21, 33, 30, 26, 29, 53, 54, 30, 34, 57, 59, 51, 3…
## $ diabetes <fct> pos, neg, pos, neg, pos, neg, pos, neg, pos, pos, neg, pos, n…
The dataset contains 768 rows and 9 columns. The outcome variable is diabetes, a factor with levels "neg" and "pos". The eight predictors are numeric measurements (glucose, blood pressure, BMI, etc.) and some contain NA values.
The next step is quite important. Essentially, in logistic regression, we fit a model to classify samples into 2 groups: 0 or 1. In this case, our outcome variable is “diabetes” (pos/neg), which we need to relevel to a factor (0/1). Depending on what we want the event (1) to be, we can relevel it one way or another. The first level is the event in tidymodels.
If we build a model with “neg” being the reference, a positive coefficient actually means an increase in the odds of being healthy (negative for diabetes). That is usually the opposite of what a clinician wants! (And confusing for all)
So we’ll set “pos” as the first level, making it the event. Once we build the model, a positive coefficient for a given feature means the feature increases the odds of being ‘1’ (diabetic). A negative coefficient would mean the feature decreases the odds of being ‘1’ (diabetic).
# tidymodels treats the FIRST factor level as the "event" (positive class)
# We want "pos" to be the event, so we relevel accordingly
# (relabelled to "1" = diabetic, "0" = non-diabetic for readability)
df$diabetes <- factor(df$diabetes, levels = c('pos', 'neg'), labels = c('1', '0'))
levels(df$diabetes)
## [1] "1" "0"
SquidTip! In standard R (glm), the model usually picks the levels alphabetically. But in tidymodels, the first factor level is treated as the “event” (the thing you are trying to predict).
2. Resampling with rsample
2.a. Train test split
Now, it’s time to split our dataset into train and test. There’s another blogpost here where I go into more detail on the ways one can do this.
initial_split() creates an rsplit object that records which rows belong to training and which to testing.
It’s also important that we maintain a similar proportion of positive/negative cases in both sets, especially in smaller or unbalanced datasets where one class is more predominant than the other. If you don’t, you might accidentally end up with a test set that contains zero positive cases, making it impossible to evaluate your model’s performance.
This is called stratification and we can stratify by 1 or more features. For example, if we had the information available, we might want to add “sex” as a stratifying variable, to make sure there’s similar proportions of male/female.
SquidTip! Note that the proportions don’t necessarily need to be near 50/50. One outcome could be much more common than another: there could be just 1% of positive cases for a rare disease. This is a bit of an extreme example, but with unbalanced datasets you may want to consider
- Down-sampling or Up-sampling: Using the themis package within your tidymodels recipe (e.g., step_downsample() or step_smote()) to balance the classes so the model doesn’t just learn to predict “negative” for everyone to achieve 99% accuracy.
- Alternative Metrics: Shifting your focus from Accuracy (which is misleading for rare events) to Precision-Recall curves or F1-scores, which better capture how well you are identifying that rare 1%.
In this case, we’ll pass the strata = diabetes to ensure that the proportion of positive/negative cases is maintained in both sets.
We’ll go with a 75-25 train-test split, but you can also use 80-20, or 70-30, depending on your sample size. The 75/25 split is a popular “middle ground” that ensures your training set is large enough to learn the underlying biological patterns while leaving enough data in the test set to calculate a reliable ROC AUC or accuracy score.
Don’t forget to set the random seed before running these functions for reproducibility.
# set.seed(123) — already done at the start of the script!
# Stratified 75/25 split: keeps the pos/neg ratio similar in train and test
data_split <- initial_split(df, prop = 0.75, strata = diabetes)
data_split # shows <Training/Testing/Total>
## <Training/Testing/Total>
## <576/192/768>
Have a look at the data_split object we just created. This object is not a flat dataframe but rather a nested list. We can extract the actual data frames with training() and testing():
# Extract the actual data frames from the rsplit object
train_data <- training(data_split)
test_data <- testing(data_split)
# Report both set sizes (the rendered output shows a "Testing rows" line,
# so the second cat() was evidently part of the original chunk)
cat("Training rows:", nrow(train_data), "\n")
cat("Testing rows:", nrow(test_data), "\n")
## Training rows: 576
## Testing rows: 192
## pregnant glucose pressure triceps insulin mass pedigree age diabetes
## 4 1 89 66 23 94 28.1 0.167 21 0
## 6 5 116 74 NA NA 25.6 0.201 30 0
## 8 10 115 NA NA NA 35.3 0.134 29 0
## 11 4 110 92 NA NA 37.6 0.191 30 0
## 19 1 103 30 38 83 43.3 0.183 33 0
## 28 1 97 66 15 140 23.2 0.487 22 0
We can check the proportions to make sure the stratification worked properly:
# Train and test targets
targets_train <- train_data %>% select(diabetes)
targets_test <- test_data %>% select(diabetes)
# Class proportions (%) — train and test should match closely thanks to
# stratification (the rendered output shows all four tables, so the three
# missing print() calls are restored here)
print(prop.table(table(targets_train)) * 100)
print(prop.table(table(targets_test)) * 100)
# Raw counts per class
print(table(targets_train))
print(table(targets_test))
## diabetes
## 1 0
## 34.89583 65.10417
## diabetes
## 1 0
## 34.89583 65.10417
## diabetes
## 1 0
## 201 375
## diabetes
## 1 0
## 67 125
We can also visualise it in a plot:
# Summarise class counts and proportions per dataset for plotting
split_summary <- bind_rows(
targets_train %>% count(diabetes) %>% mutate(dataset = "Train"),
targets_test %>% count(diabetes) %>% mutate(dataset = "Test")) %>%
# per-dataset proportions; .by avoids a persistent group_by()/ungroup() pair
mutate(prop = n / sum(n), .by = dataset,
label = sprintf("n=%d\n(%.1f%%)", n, prop * 100),
diabetes = factor(diabetes, levels = c("1", "0"),
labels = c("Diabetic", "Non-diabetic")))
# Stacked 100% bars with count + percentage labels centred in each segment
p1 <- ggplot(split_summary, aes(x = dataset, y = prop, fill = diabetes)) +
geom_col(position = "fill", colour = "black") +
geom_text(aes(label = label),
position = position_fill(vjust = 0.5),
colour = "white", fontface = "bold", size = 3.5) +
scale_fill_manual(values = c("Diabetic" = "aquamarine3", "Non-diabetic" = "pink3")) +
scale_y_continuous(labels = scales::percent) +
labs(x = "Dataset", y = "Proportion", fill = "Diabetes status") +
theme_bw(base_size = 14)
p1
As you can see, both train and test have a similar (identical in this case) proportion of values.
If you have any stratifying variables, I suggest you also look at the distribution of them in the train vs test datasets.
Now is also a good time to save your split results!
2.b. Cross validation
If you want to skip the multiple “folds” of cross-validation and just use a single Training and Validation split, you use validation_split(). This will enable you to build a simple logistic regression model without tuning:
# Create a single 80/20 split within your training data
# NOTE(review): validation_split() is superseded by initial_validation_split()
# in recent rsample releases — confirm which one your installed version expects
train_val_split <- validation_split(train_data, prop = 0.8, strata = diabetes)
It’s faster but riskier - if your validation slice happens to have all the “easy” diabetes cases, your tuning results will be misleading.
As you may know already, cross-validation is a resampling technique used to evaluate how well a machine learning model will generalize to an independent, “unseen” dataset. Instead of relying on a single train-test split—which might be lucky or unlucky depending on which rows end up in which set; cross-validation repeats the process multiple times on different subsets of your data. You can read more about cross validation here, but these are the main advantages in a nutshell:
- Reduced Variance: It provides a much more stable and “honest” estimate of model performance than a single split.
- Efficiency: Every single observation in your training set gets to be part of an assessment set exactly once, making the most of your available data.
- Tuning: It is the gold standard for choosing hyperparameters (like your penalty and mixture). By testing your 30-combination grid against 5 or 10 different folds, you ensure the “best” settings aren’t just overfit to one specific slice of data.
Cross validation follows these steps:
- Split: Your training data is randomly partitioned into V equal-sized groups (folds).
- Iterate: The model is trained V times. Each time, one fold is held out as the “assessment” set (test), and the remaining V-1 folds are used for “analysis” (training).
- Average: You calculate the performance metric (like accuracy or RMSE) for each of the V iterations and average them together.
Let’s use 10-fold cross-validation: dividing our dataset into 10 folds.
# v-fold cross validation
v_folds <- 10
# Stratified 10-fold CV on the training data only (test set stays untouched)
folds <- vfold_cv(train_data, v = v_folds, strata = 'diabetes')
# Tile plot: which rows land in analysis vs assessment for each fold
cv <- ggplot(tidy(folds), aes(x = Fold, y = Row, fill = Data)) +
geom_tile() +
scale_fill_brewer() +
theme_bw(base_size = 14)
cv
Take a look at your folds object. It is an rset object, which is a collection of rsplit objects. We can check that our splits and folds are properly stratified.
# # Check whether stratified folds are stratified (manual, single-fold version)
# ex_split <- folds$splits[[5]]
# analysis(ex_split) %>% dim()
# assessment(ex_split) %>% dim()
# # If we want to get the specific row indices of the training set:
# ex_ind <- as.integer(ex_split, data = "assessment")
# table(train_data[ex_ind,'diabetes'])/ sum(table(train_data[ex_ind,'diabetes'])) * 100
# table(train_data[-ex_ind,'diabetes']) / sum(table(train_data[-ex_ind,'diabetes'])) * 100
# A clean way to get proportions for every fold:
# for each fold, count classes in its assessment set and compute proportions,
# then unnest into one long table (one row per fold x class)
fold_checks <- folds %>%
mutate(props = map(splits, ~ {
assessment(.x) %>%
count(diabetes) %>%
mutate(prop = n / sum(n))
})) %>%
select(id, props) %>%
unnest(props)
# Take a quick look at the table
# print(fold_checks)
# Plot proportions across folds — should be near-identical if stratification worked
# (one stacked bar per fold; fill order follows the factor levels "1", "0")
ggplot(fold_checks, aes(x = id, y = prop, fill = diabetes)) +
geom_col(position = "stack", colour = "black", width = 0.6) +
scale_fill_manual(values = c("1" = "aquamarine3", "0" = "pink3"),
labels = c("Diabetic", "Non-diabetic")) +
scale_y_continuous(labels = scales::percent) +
labs(x = "Fold", y = "Proportion", fill = "Diabetes",
title = "Class balance per CV fold") +
theme_bw(base_size = 13) +
theme(axis.text.x = element_text(angle = 45, hjust = 1))
Nice! Before we move on to model building, let’s do a couple of QC checks.
2.c. Preprocessing checks
Logistic regression models do not deal well with highly correlated variables and features with very little variance. These will be removed during model building if you haven’t done so already during the data preprocessing steps, before resampling.
This is only a simple dataset with very few features, so we’re just going to do a few sanity checks. The best practice is to put these steps in the recipe (check step 3 of this tutorial) - the recipe will calculate nzv features and highly correlated features based on the training set and apply them to the test set, avoiding data leakage.
In the following steps, I am going to use the training set to check which features are problematic. They are useful as an exploratory sanity check on the train_data before we build the recipe, but the recipe is the thing that actually enforces the removal.
You can read more about preprocessing steps and data leakage in this other blogpost: feature preprocessing for ML (coming up soon!).
2.d. Check for nzv
Checking for near zero variance is a critical preprocessing step because predictors with only one unique value (zero variance) or very few unique values relative to the sample size (near-zero variance) provide almost no information to the model. In logistic regression, these “constant” variables can cause the math to break down, leading to unstable coefficient estimates or even preventing the model from converging.
Essentially, if every patient in your dataset has the exact same blood pressure reading, that variable cannot help you predict who has diabetes.
Before building the model, you can use caret::nearZeroVar() to see which columns are problematic. You may need to install the caret package beforehand. You can tweak the parameters freqCut and uniqueCut to define exactly how “boring” a variable must be before it gets tossed out. By default, step_nzv() uses a freqCut of 95/5 (meaning if the most common value is 19 times more frequent than the second most common, it’s a candidate for removal) and a uniqueCut of 10% (meaning the number of unique values must be less than 10% of the total sample size). Run ?nearZeroVar to find out more.
# With saveMetrics = TRUE, nearZeroVar() returns a data frame of variance
# diagnostics (freqRatio, percentUnique, zeroVar, nzv) — one row per predictor
# (it only returns column indices when saveMetrics = FALSE, the default)
nzv_cols <- caret::nearZeroVar(train_data, saveMetrics = TRUE)
# View the report: keep only predictors flagged as near-zero variance
nzv_cols %>% filter(nzv)
## [1] freqRatio percentUnique zeroVar nzv
## <0 rows> (or 0-length row.names)
No near zero variance features, easy peasy! We don’t need to worry about this. If your dataset does have nzv, I would recommend removing them before you run step 1 of this tutorial, as part of the preprocessing. Check out this other tutorial for more on feature preprocessing for ML (coming up soon!).
3. Preprocessing with recipes
Great! Now we’re ready to start building the model!
The Recipe Workflow: Prep, Bake, Juice
In the tidymodels ecosystem, the recipes package is where the “feature engineering” magic happens. Think of a recipe as a blueprint for transforming your raw data into a format that a machine learning model can actually digest. Instead of manually scaling or encoding variables every time you run a model, you define a sequence of steps that are applied consistently across training and testing data. This also avoids data leakage (coming up soon!), because stats are calculated only on training data.
This is the link to the package recipes: recipes.tidymodels.org
To understand how it works, it’s helpful to view it as a three-stage process:
- Define (recipe()):: You specify the formula (outcome ~ predictors) and the data.
- Estimate (prep()): The recipe looks at the training data to calculate necessary statistics (like the mean for centering or the variance for scaling).
- Apply (bake()): You apply those calculated statistics to any dataset (like your test set) to transform it.
Recipes allow for a really clean syntax, and can also be saved and applied to any new data - so if you want to try out different combinations of features before deciding on a final model, recipes are for you!
A typical recipe is built using a series of “steps.” Here are the most common ones:
-
Handling Categorical Data:
step_dummy() converts nominal variables into numeric “dummy” variables (one-hot encoding). -
Normalization:
step_center() and step_scale() ensure all numeric variables have a mean of zero and a standard deviation of one. -
Imputation:
step_impute_mean() or step_impute_knn() fill in missing values based on the training data. -
Transformations:
step_log() or step_BoxCox() help manage skewed data. -
Dimension Reduction:
step_pca() can collapse many predictors into a few principal components.
SquidTip! Using a recipe ensures scientific integrity by preventing data leakage, as it calculates preprocessing statistics solely on your training data and applies them consistently to new data. It also transforms your feature engineering into a reusable object, allowing you to bundle it with a model into a single, production-ready workflow. This automation eliminates the need for manual, error-prone scripts every time you want to cross-validate or predict on a single new observation.
It just makes it easier and cleaner than preprocessing it manually!
Ok! So how do we actually code this?
The recipe() function takes a formula and the training data, then we chain feature engineering steps using step_*() functions. Let’s add 3 simple steps:
-
step_impute_median() — replaces NA with the median of each column. This is a very simple approach; there are other methods for missing data imputation. -
step_normalize() — centres (mean = 0) and scales (sd = 1) numeric predictors. Normalising and scaling your data is vital because many machine learning algorithms are sensitive to the scale of your input data. If one variable ranges from 0 to 1 (like an age ratio) and another ranges from 0 to 1,000,000 (like annual income), the model may incorrectly treat the larger numbers as more “important” simply because of their magnitude. -
step_zv() — removes any predictor with zero variance (which we already checked for, and we know it’s not a problem, but good to leave it in if we ever want to reuse the code!)
You might also want to check out my recipe in this other tutorial: Advanced tidymodels tutorial (coming up soon!). There, I also add steps to handle class imbalance (SMOTE, downsampling…).
SquidTip! Note the order matters: 1. Impute first — step_corr needs complete data to compute correlations. If you have dummy variables (categorical), make sure you are imputing them or taking care of NAs properly! 2. NZV before correlation — no point computing correlations for near-constant columns 3. Normalize last — centering/scaling doesn’t affect the correlation structure but should happen after you’ve decided what to keep
# Preprocessing blueprint: all statistics (medians, means, sds, correlations)
# are estimated on the training data only when the recipe is prepped, then
# applied unchanged to new data — this is what prevents data leakage
rec <- recipe(diabetes ~ ., data = train_data) %>%
step_impute_median(all_numeric_predictors()) %>% # handle NAs
step_nzv(all_predictors()) %>% # removes near-zero variance
step_corr(all_numeric_predictors(), threshold = 0.8) %>% # removes highly correlated
step_normalize(all_numeric_predictors()) %>% # centre and scale
step_zv(all_predictors()) # remove zero-variance cols
rec # prints a summary — not yet "trained" (prepped)
As we mentioned, NZV and correlation filtering should be included in your recipe using step_nzv() and step_corr(), so they are computed only on the training data and applied consistently across folds. This keeps the pipeline leak-free and ensures you can distinguish features removed for numerical reasons from those removed by the model for lack of predictive value.
SquidTip! Use the selectors like all_numeric_predictors() or has_type("nominal") instead of naming every column. It makes your code much more robust if your data schema changes slightly.
4. Defining the Model with parsnip
Nice! So now that the preprocessing steps are finished (we haven’t applied them yet, but we have set the recipe!), we can start building our model. The package parsnip provides a unified interface to many modelling engines. Here we specify logistic regression solved with the "glm" engine (base R’s stats::glm).
If you are just using the full train (+validation) dataset (so no cross validation), then building a model is super easy:
# Option A: Simple logistic regression (no tuning, uses base R glm)
# This is only a model specification — no data has been touched yet
lr_simple <- logistic_reg() %>%
set_engine("glm") %>%
set_mode("classification")
lr_simple
## Logistic Regression Model Specification (classification)
##
## Computational engine: glm
As we mentioned, cross validation allows us to tune the model’s parameters by training and testing using different subsets of data (folds). Let’s set a model that allows us to do that. It’s very easy to do with tidymodels!
# Option B: Regularised logistic regression with hyperparameter tuning
# penalty = A non-negative number representing the total amount of regularization
# mixture = proportion of regularisation. 0 = Ridge, 1 = Lasso, between = Elastic Net
# tune() leaves both knobs as placeholders to be filled in by tune_grid() later
lr_tuned <- logistic_reg(
mode = "classification",
penalty = tune(),
mixture = tune()
) %>%
set_engine("glmnet")
# Define the search space for hyperparameters
# The range is on the log10 scale, so penalty spans 10^-3 to 10^2
glmnet_params <- parameters(
penalty(range = c(-3, 2), trans = log10_trans()), # Higher penalties to prevent overfitting (stronger regularization)
mixture() # tests ridge (0) through elastic net to lasso (1)
)
# Space-filling design: 30 combinations spread across the parameter space
# More efficient than a regular grid because it avoids clustering
glmnet_grid <- grid_space_filling(glmnet_params, size = 30)
The model spec (logistic_reg) defines the structure of the model without looking at data yet. We are building a regularized Logistic Regression classifier but leaving the “tuning knobs” (penalty and mixture) blank with tune(). This tells tidymodels to find the best values later rather than picking them now.
The parameters set the boundaries for those knobs. The penalty is set on a log_{10} scale (ranging from 0.001 to 100) to test both weak and aggressive regularization, while mixture will test everything from Ridge (0) to Lasso (1). If you are not too familiar with logistic regression, I highly recommend the tutorials from Josh Starmer’s Statquest.
grid_space_filling creates 30 diverse combinations of these two settings. Instead of a simple uniform grid, it uses a mathematical design to spread the points out as much as possible, ensuring you cover the widest range of “model behaviors” with fewer tests.
At this point, no data has been touched. The model spec just describes what to fit.
5. Metrics with yardstick
metric_set() bundles multiple performance metrics together. We include metrics that require predicted probabilities (like roc_auc) as well as hard-class metrics (like accuracy, sensitivity, specificity). Check out my other blogpost covering ML performance metrics and how to interpret them (coming up soon!).
# Bundle the metrics to compute for every candidate model: roc_auc needs
# predicted probabilities; the others are hard-class metrics
model_metrics <- metric_set(
roc_auc,
accuracy,
sensitivity,
specificity,
j_index # Youden's J = sensitivity + specificity - 1
)
Nice! Almost there!
6. Bundling with workflows
Now it’s time to introduce workflows.
In the tidymodels ecosystem, a workflow combines your recipe (preprocessing) and your model (algorithm) into a single, portable object. Think of it as a specialized “wrapper” that ensures your data transformations and your model training always stay perfectly synced.
The main goal of a workflow is to simplify the management of the machine learning pipeline. Without it, you have to manually prep and bake your recipes and then pass the results into a model fitting function. With a workflow, you simply call fit() or predict(), and the package handles the internal data flow for you.
SquidTip! A workflow makes it easy to swap out recipes or models and ensures they are always applied together consistently.
You can also try out several models (random forest, SVC…) at once!
## ══ Workflow ════════════════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: logistic_reg()
##
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 5 Recipe Steps
##
## • step_impute_median()
## • step_nzv()
## • step_corr()
## • step_normalize()
## • step_zv()
##
## ── Model ───────────────────────────────────────────────────────────────────────
## Logistic Regression Model Specification (classification)
##
## Main Arguments:
## penalty = tune()
## mixture = tune()
##
## Computational engine: glmnet
7. Fitting a model
7a. Training a model
First, let’s train a model using 1) train-validation split and 2) cross-validation.
If we wanted to train on the training set and automatically validate on the held-out (validation) set, we would use:
# This trains on 'analysis' and scores on 'assessment' automatically
# (mywf here must be the workflow built from the un-tuned model spec —
# fit_resamples() cannot handle tune() placeholders)
validation_results <- fit_resamples(mywf, resamples = train_val_split)
Important! mywf should be based on the simple glmnet_model we created without the tune() steps. Here, we are not tuning parameters on each of the folds, we are just fitting the model once.
We can collect the performance metrics based on the validation set using collect_metrics(validation_results). These are more accurate than the performance metrics from a fit() object. While you can do it, those metrics are usually fake/biased because the model is “testing” itself on data it already saw. To get an honest score, you always look at the output of fit_resamples() or even better, tune_grid().
A better option is to use a cross-validated dataset to find the optimal hyperparameters. (An even better option would be to use nested cross-validation).
# Run the grid search
start_time <- Sys.time()
# Optional: register a parallel backend to speed this up
# library(doParallel)
# registerDoParallel(cores = parallel::detectCores(logical = FALSE) - 1)
# 30 hyperparameter combinations x 10 folds = 300 model fits in total
tune_results <- tune_grid(mywf, # Your workflow containing the model + recipe
resamples = folds,
grid = glmnet_grid,
metrics = model_metrics,
control = control_grid(save_pred = TRUE)) # keep per-fold predictions for ROC/confusion analysis
end_time <- Sys.time()
# Calculate and print the difference
print(end_time - start_time)
## Time difference of 36.98848 secs
Here is exactly what happens under the hood:
- The Workflow (mywf): It takes your blueprint (the recipe + the glmnet model spec).
- The Resamples (folds): Instead of running once, it repeats the process for every fold you created (with our 10 folds, it runs everything 10 times).
- The Grid (glmnet_grid): It looks at your 30 different combinations of “knobs” (hyperparameters).
-
The Execution: It trains \(30 \text{ (combinations)} \times 10 \text{ (folds)} = 300\) total models. For each of the 300 runs, it:
- Preps the recipe on the training folds.
- Trains the model with a specific combination.
- Evaluates performance on the “held-out” fold.
- The Control (save_pred = TRUE): This tells R to keep the actual predictions (the “guesses” the model made) for every fold. This is incredibly useful later if you want to build a ROC Curve or a Confusion Matrix to see where the model is making mistakes.
Nice! Let’s see how our model performed across the folds:
# Performance map across the grid: rank configurations by Youden's J and
# highlight only the best configuration per metric panel
p1 <- autoplot(tune_results, rank_metric = "j_index", select_best = TRUE) +
theme_bw(14) +
theme(strip.background = element_rect(fill = 'white', colour = 'black')) +
geom_point(size = 2, col = 'dodgerblue')
p1
This autoplot() output is our “performance map” for the Glmnet model. It visualizes how our two hyperparameters — Penalty (amount of regularization) and Mixture (proportion of Lasso Penalty) — affect our model’s ability to predict correctly across five different metrics.
Let’s see…
The sharp drop-off in the left-hand plots indicates that our model has a narrow window of effective regularization. Between 10^{-3} and 10^{-1.5}, the model maintains a strong balance, but as the penalty increases beyond 0.1 (10^{-1}), it likely becomes “null,” predicting only the majority class. This explains why specificity hits 1.0 while sensitivity and j_index crash to zero; the model is simply playing it safe by never predicting a positive case. Our best model is likely located where the penalty is between -3 and -1.5. In this range, we maximize accuracy, J-index, and sensitivity without sacrificing too much specificity.
Because the Proportion of Lasso Penalty plots show no clear trend or “curve” (points are all over the place), the mixture parameter doesn’t matter nearly as much as the amount of regularization for this dataset.
We could likely prune our search grid if we were to redo this - we don’t need to test any penalty values higher than 0.1 in the future, as they clearly degrade the model. So to refine our model, we should probably zoom in on the amount of regularization by creating a narrower grid between \(-4\) and \(-1\) on the \(\log_{10}\) scale. This will allow us to pinpoint the exact transition where the model begins to lose predictive power and ensure we aren’t over-regularizing our features into extinction.
If we’re happy with the model’s training performance, we can now retrain it on the full training set (train+validation) with the optimal parameters. You can also choose your best parameters based on J index, ROC AUC, or whichever metric is more important for you to optimise. If you are going with the simple train-validation split, use select_best(validation_results, metric = "roc_auc").
Our original workflow, mywf, still contains tune() placeholders. finalize_workflow() replaces those placeholders with the best values. This will convert mywf into a fully specified workflow with fixed hyperparameters.
# This selects the best parameters based on the j_index
best_params <- select_best(tune_results, metric = "j_index")
# Replace the tune() placeholders in mywf with the winning values, yielding
# a fully specified (but not yet fitted) workflow.
final_wf <- finalize_workflow(mywf, best_params)
We can now fit the final, optimal model on the entire training dataset and use it for prediction. There are two main ways we can do this. They should produce the same predictions, but they create very different objects and offer different convenience functions:
-
Using fit() + predict()
-
Using last_fit()
Let’s check both of them!
7b. Evaluating a model: fit() / predict()
When we use fit(final_wf, data = train_data), the result is a trained workflow object. We can then call predict() on the test data.
Nice! Now we can get hard-class predictions on the test set:
## # A tibble: 6 × 1
## .pred_class
## <fct>
## 1 0
## 2 1
## 3 1
## 4 1
## 5 1
## 6 0
Get predicted probabilities by passing type = "prob":
## # A tibble: 6 × 2
## .pred_1 .pred_0
## <dbl> <dbl>
## 1 0.0460 0.954
## 2 0.903 0.0970
## 3 0.849 0.151
## 4 0.741 0.259
## 5 0.661 0.339
## 6 0.448 0.552
predict() returns a simple tibble — easy to work with. Now bind the predictions to the test data to evaluate performance:
# Combine the truth column with the class and probability predictions.
# bind_cols() is safe here because predict() preserves the row order of
# test_data, so predictions line up with the observed outcomes.
test_results <- test_data %>%
select(diabetes) %>%
bind_cols(pred_class, pred_prob)
head(test_results)
## diabetes .pred_class .pred_1 .pred_0
## 2 0 0 0.04604872 0.95395128
## 5 1 1 0.90303884 0.09696116
## 9 1 1 0.84885238 0.15114762
## 13 0 1 0.74074093 0.25925907
## 15 1 1 0.66070270 0.33929730
## 17 1 0 0.44793162 0.55206838
Calculate metrics for the test set; we can also compare them to the training set:
# Compute test-set metrics. NOTE(review): `model_metrics` is presumably a
# yardstick::metric_set() defined earlier in the post -- confirm upstream.
test_metrics <- model_metrics(test_results,
truth = diabetes,
estimate = .pred_class,
.pred_1) # column name for the "event" class probability
head(test_metrics)
## # A tibble: 5 × 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 accuracy binary 0.771
## 2 sensitivity binary 0.597
## 3 specificity binary 0.864
## 4 j_index binary 0.461
## 5 roc_auc binary 0.854
# Collect cross-validation metrics. We'll isolate the single best version of the model from the tuning process
cv_metrics <- collect_metrics(tune_results) %>%
# This semi_join acts like a filter that keeps only the winning 'mixture' and 'penalty'
semi_join(best_params, by = c("penalty", "mixture")) %>% # filter without adding cols
dplyr::filter(.metric %in% c("accuracy", "roc_auc", "sensitivity", "specificity", "j_index")) %>%
dplyr::select(.metric, mean, std_err) %>%
dplyr::mutate(dataset = "Training (CV)")
# Test metrics usually don't have fold variability, so std_err is NA.
# Use NA_real_ (not bare NA, which is logical) so the column stays numeric
# and matches cv_metrics$std_err exactly when the tables are row-bound.
test_metrics <- test_metrics %>%
dplyr::select(.metric, .estimate) %>%
dplyr::rename(mean = .estimate) %>%
dplyr::mutate(std_err = NA_real_, dataset = "Test")
# Stack both tables into one long data frame for plotting.
all_metrics <- dplyr::bind_rows(cv_metrics, test_metrics)
# Dot-and-errorbar comparison of CV (training) vs test metrics.
# Test rows carry std_err = NA, so their error bars are dropped
# (ggplot will emit a "removed rows" warning -- expected here).
ggplot(all_metrics, aes(x = .metric, y = mean, colour = dataset)) +
geom_point(position = position_dodge(width = 0.5), size = 3) +
geom_errorbar(aes(ymin = mean - std_err, ymax = mean + std_err),
position = position_dodge(width = 0.5), width = 0.2) +
scale_colour_manual(values = c("Training (CV)" = "royalblue", "Test" = "firebrick")) +
# Fix the axis to [0, 1] so metrics are comparable at a glance.
scale_y_continuous(limits = c(0, 1)) +
theme_bw(14)
When comparing training (via v-fold cross-validation) and test set performance, the goal is to check generalization. The test set simulates unseen data, so metrics should be similar but usually slightly worse than training. The reason is that the model was trained using the training folds, so it fits that data a bit better. With cross-validation, your training results already include uncertainty (std_err), so ideally the test metric should fall within the CV error bars. If it does, your model is behaving as expected. If the test performance results are way worse than training, the model is probably overfitted; if they are way better, the model is probably underfitted or there is data leakage.
Youden’s index combines sensitivity and specificity - a large drop means the model does not generalize well.
Check out my other blogpost covering ML performance metrics and how to interpret them (coming up soon!).
Inspecting the Model Coefficients
broom::tidy() extracts coefficients and p-values into a tibble:
## # A tibble: 9 × 3
## term estimate penalty
## <chr> <dbl> <dbl>
## 1 (Intercept) 0.844 0.00221
## 2 pregnant -0.329 0.00221
## 3 glucose -1.10 0.00221
## 4 pressure 0.177 0.00221
## 5 triceps -0.0636 0.00221
## 6 insulin 0.0284 0.00221
## 7 mass -0.622 0.00221
## 8 pedigree -0.246 0.00221
## 9 age -0.231 0.00221
broom::glance() gives overall model statistics:
## # A tibble: 1 × 3
## nulldev npasses nobs
## <dbl> <int> <int>
## 1 745. 386 576
ROC Curve
A Confusion Matrix is a table that summarizes a model’s performance by comparing actual versus predicted classifications, showing exactly where the model succeeded (True Positives/Negatives) and where it stumbled (False Positives/Negatives). While the matrix provides a snapshot of performance at a specific decision threshold, the ROC (Receiver Operating Characteristic) Curve offers a broader view by plotting the True Positive Rate against the False Positive Rate across all possible thresholds. Essentially, the Confusion Matrix tells you how many errors you made on a specific “pass/fail” cutoff (the default is 0.5, meaning that if a predicted probability > 0.5 - then it’s positive, if it’s < 0.5, it’s classified as negative), while the ROC curve illustrates the model’s fundamental ability to distinguish between classes - the further the curve arches toward the top-left corner, the better the model is at separating diabetic from non-diabetic regardless of the threshold you choose.
# Only plot the test ROC curve
# roc_curve() computes sensitivity/specificity at every threshold;
# autoplot() turns that into the standard ROC plot.
p_roc <- test_results %>%
roc_curve(truth = diabetes, .pred_1) %>%
autoplot()
# Confusion matrix - Test set
# Uses the default 0.5 probability cutoff baked into .pred_class.
test_conf_mat <- test_results %>%
conf_mat(truth = diabetes, estimate = .pred_class)
p_cm <- autoplot(test_conf_mat, type = "heatmap") +
theme_bw(14)
# Combine them
# `|` places the two plots side by side (patchwork syntax).
p_roc | p_cm
Nice! There’s more useful checks and plots we can look at but let’s explore them using the second option to fit a model, last_fit().
7c. Evaluating a model: last_fit()
last_fit() is described in the tidymodels docs as emulating the process of taking a finalised model, fitting it on the entire training set, and evaluating it on the test set — all in one step.
You pass the untrained workflow with the finalised model (so after running select_best() and finalize_workflow()) and the data_split object (which holds the train/test partition information):
## # Resampling results
## # Manual resampling
## # A tibble: 1 × 6
## splits id .metrics .notes .predictions .workflow
## <list> <chr> <list> <list> <list> <list>
## 1 <split [576/192]> train/test split <tibble> <tibble> <tibble> <workflow>
The result looks like a one-row tibble. It comes from the tune package and works with tune helper functions.
SquidTip!
So, what’s the difference between fit() + predict() and last_fit()?
With fit() and predict(), we manually train the model, generate predictions, and compute metrics. It’s more flexible and works for any dataset. However, it’s more manual - this makes it easier to accidentally leak test data. Moreover, it doesn’t integrate directly with rsample splits.
last_fit() automates the final model training and evaluation using the original train/test split. It:
-
Fits the finalized workflow on the training set
-
Predicts on the test set
-
Returns predictions + metrics
Collect Metrics
## # A tibble: 5 × 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 accuracy binary 0.771 pre0_mod0_post0
## 2 sensitivity binary 0.597 pre0_mod0_post0
## 3 specificity binary 0.864 pre0_mod0_post0
## 4 j_index binary 0.461 pre0_mod0_post0
## 5 roc_auc binary 0.854 pre0_mod0_post0
Collect Predictions
Use collect_predictions() instead of indexing into the list directly:
## # A tibble: 6 × 7
## .pred_class .pred_1 .pred_0 id diabetes .row .config
## <fct> <dbl> <dbl> <chr> <fct> <int> <chr>
## 1 0 0.0460 0.954 train/test split 0 2 pre0_mod0_post0
## 2 1 0.903 0.0970 train/test split 1 5 pre0_mod0_post0
## 3 1 0.849 0.151 train/test split 1 9 pre0_mod0_post0
## 4 1 0.741 0.259 train/test split 0 13 pre0_mod0_post0
## 5 1 0.661 0.339 train/test split 1 15 pre0_mod0_post0
## 6 0 0.448 0.552 train/test split 1 17 pre0_mod0_post0
Notice that last_fit() includes the truth column automatically — handy for immediate evaluation without manually joining to the test set.
ROC Curve from last_fit
This time, we’ll go one step further and add the mean ROC across folds.
# Simple test-only ROC (kept for reference):
# lf_preds %>%
# roc_curve(truth = diabetes, .pred_1) %>%
# autoplot()
# 1. Get predictions from all CV folds for best hyperparameters
# collect_predictions() returns rows for every tuning combination, so we
# filter down to the winning penalty/mixture pair only.
cv_predictions <- tune_results %>%
collect_predictions() %>%
filter(penalty == best_params$penalty,
mixture == best_params$mixture)
# 2. Extract ROC data for CV (Training)
roc_cv_data <- cv_predictions %>%
roc_curve(truth = diabetes, .pred_1) %>%
mutate(dataset = "Training (CV)")
# 3. Extract ROC data for Final Fit (Test)
roc_test_data <- final_fit %>%
collect_predictions() %>%
roc_curve(truth = diabetes, .pred_1) %>%
mutate(dataset = "Test")
# One long data frame so both curves can share a single ggplot.
combined_roc <- bind_rows(roc_cv_data, roc_test_data)
# 4. Compare train (CV) vs test metrics.
# NOTE(review): by this point `test_metrics` has already been reshaped above
# (.estimate renamed to mean, std_err/dataset added), so selecting .estimate
# from it again would fail. Both tables already carry a `dataset` label, so
# we can simply stack them -- this reproduces the `all_metrics` table.
metrics_comparison <- dplyr::bind_rows(cv_metrics, test_metrics) %>%
dplyr::select(dataset, .metric, mean, std_err)
# 5. Extract AUC values for the label
# Build a single two-line string ("AUC (train) = ...\nAUC (test) = ...")
# for annotating the ROC plot.
auc_labels <- metrics_comparison %>%
filter(.metric == "roc_auc") %>%
mutate(label = paste0(
ifelse(dataset == "Training (CV)", "AUC (train)", "AUC (test)"), " = ", round(mean, 3))) %>%
summarise(label = paste(label, collapse = "\n")) %>%
pull(label)
# ROC
# Overlay the CV (training) and test ROC curves, with AUC values annotated
# in the bottom-right corner.
p1 <- ggplot(combined_roc, aes(x = 1 - specificity, y = sensitivity, color = dataset)) +
geom_path(linewidth = 1) +
geom_abline(lty = 3, color = "gray50") + # Diagonal reference line
coord_equal() +
ggplot2::annotate("text", x = 1, y = 0.05, # bottom right (coord_equal keeps x/y in 0-1)
label = auc_labels, hjust = 1, vjust = 0, size = 4,
fontface = "italic", color = "black") +
scale_color_manual(values = c("Training (CV)" = "royalblue", "Test" = "firebrick")) +
labs(color = "") +
theme_bw(14)
p1
Final model coefficients, odds ratios and pretty plots
extract_fit_parsnip() pulls out the underlying parsnip object (in this case, our glmnet trained model) so you can call tidy() and glance() just as with fit():
The resulting table usually contains these columns:
-
term: The name of the predictor (e.g., glucose, bmi, age).
-
estimate: The weight or coefficient. For a logistic regression, this tells you the direction and strength of the relationship with the outcome.
-
penalty: (if using glmnet) The specific regularization value used.
## # A tibble: 6 × 3
## term estimate penalty
## <chr> <dbl> <dbl>
## 1 (Intercept) 0.844 0.00221
## 2 pregnant -0.329 0.00221
## 3 glucose -1.10 0.00221
## 4 pressure 0.177 0.00221
## 5 triceps -0.0636 0.00221
## 6 insulin 0.0284 0.00221
Let’s extract and plot the coefficients. The odds ratios are just the exponential!
Note that because we are using glmnet, we cannot calculate standard errors. This is because the “shrinkage” used in Lasso/Elastic Net biases the estimates. Standard statistical theory for p-values and standard errors doesn’t apply the same way it does for a “normal” glm model.
# Get coefficients (log odds)
# NOTE(review): `fitted_model` appears to already be a tidy() tibble of
# glmnet coefficients (term/estimate/penalty columns) -- confirm upstream.
# Drop the intercept and terms shrunk to exactly zero by the penalty,
# then order by effect magnitude.
coef_df <- fitted_model %>%
filter(term != "(Intercept)", estimate != 0) %>%
arrange(desc(abs(estimate)))
# Get odds ratios
or_df <- coef_df %>%
mutate(odds_ratio = exp(estimate)) %>%
# Mostly redundant after `estimate != 0` above; note this is an exact
# float comparison, so it also drops terms whose tiny coefficient makes
# exp(estimate) round to exactly 1.
filter(odds_ratio != 1)
## Coefficient plot ----------
# Lollipop chart of log-odds coefficients: a segment from zero to each
# estimate, ordered by value, with a dashed reference line at zero.
p_coef <- ggplot(coef_df, aes(estimate, fct_reorder(term, estimate))) +
geom_vline(xintercept = 0, colour = "gray50", lty = 2, linewidth = 1) +
geom_segment(aes(x = 0, xend = estimate,
y = fct_reorder(term, estimate),
yend = fct_reorder(term, estimate)),
colour = "gray70", linewidth = 1.2) + # linewidth replaces deprecated size
geom_point(size = 3, colour = "#85144B") +
labs(y = NULL, x = "Coefficient (log-odds)") +
theme_bw(14)
# p_coef
## Odds ratios -------
# Same lollipop layout on the odds-ratio scale; the reference line sits at
# 1 (no effect) and the x axis is log-scaled so ratios above/below 1 are
# visually symmetric.
p_or <- ggplot(or_df, aes(odds_ratio, fct_reorder(term, odds_ratio))) +
geom_vline(xintercept = 1, colour = "gray50", lty = 2, linewidth = 1) +
geom_segment(aes(x = 1, xend = odds_ratio,
y = fct_reorder(term, odds_ratio),
yend = fct_reorder(term, odds_ratio)),
colour = "gray70", linewidth = 1.2) +
geom_point(size = 3, colour = "#85144B") +
scale_x_log10() +
labs(y = NULL, x = "Odds ratio (log scale)") +
theme_bw(14)
# p_or
# Side-by-side layout (patchwork syntax).
p_coef + p_or + plot_layout(ncol = 2)
Nice!
Session Info
Squidtastic! And that’s the end of this tutorial.
In this post, we covered a beginner’s tutorial into logistic regression with tidymodels in R. Hope you found it useful!
I’m leaving some useful resources below, as well as some additional blogposts with ML topics you might want to dive into.
Additional resources
You might be interested in…
- Credit Card Fraud: A Tidymodels Tutorial. Really useful step by step tutorial which covers a classification problem from importing the data, cleaning, exploring, fitting, choosing a model, and finalizing the model.
- Handling class imbalance with tidymodels
- Making Predictions from Cross-Validated Workflow Using tidymodels
- A gentle introduction to SHAP values in R
Squidtastic! You made it till the end! Hope you found this post useful.
If you have any questions, or if there are any more topics you would like to see here, leave me a comment down below.
Otherwise, have a very nice day and… see you in the next one!
Before you go, you might want to check:
// add bootstrap table styles to pandoc tables function bootstrapStylePandocTables() { $('tr.odd').parent('tbody').parent('table').addClass('table table-condensed'); } $(document).ready(function () { bootstrapStylePandocTables(); });