Multi-Task Learning with torch in R
Multi-task learning (MTL) is an approach where a single neural network model is trained to perform multiple related tasks simultaneously. This methodology can improve model generalization, reduce overfitting, and leverage shared information across tasks. This post explores how to implement a multi-task learning model using the torch package in R.
Introduction
Multi-task learning operates by sharing representations between related tasks, enabling models to generalize more effectively. Instead of training separate models for each task, this approach develops a single model with:
- Shared layers that learn common features across tasks
- Task-specific layers that specialize for each individual task
- Multiple loss functions, one for each task
This approach is particularly valuable when dealing with related prediction problems that can benefit from shared feature representations.
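Before writing any torch code, the core mechanic can be illustrated with a minimal base-R sketch of how per-task losses are combined into a single training objective. The loss values below are illustrative placeholders, not model output; the real computation appears in the training loop later in the post.
# Illustrative placeholder losses for the two tasks
loss_regression <- 0.40       # e.g. mean squared error from the regression head
loss_classification <- 0.65   # e.g. binary cross-entropy from the classification head

# Task weights control how much each task influences the shared layers
task_weights <- c(regression = 0.5, classification = 0.5)

# A single weighted scalar drives updates to both shared and task-specific parameters
total_loss <- task_weights["regression"] * loss_regression +
  task_weights["classification"] * loss_classification
total_loss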
Packages
# install.packages(c("torch", "tidyverse", "corrplot"))
library(torch)
library(tidyverse)
Creating an MTL Model
The implementation will construct a model that simultaneously performs two related tasks:
- Regression: Predicting a continuous value
- Classification: Predicting a binary outcome
Sample Data
# Set seed for reproducibility
set.seed(123)

# Number of samples
n <- 1000

# Create a dataset with 5 features
x <- torch_randn(n, 5)

# Task 1 (Regression): Predict continuous value
# Create a target that's a function of the input features plus some noise
y_regression <- x[, 1] * 0.7 + x[, 2] * 0.3 - x[, 3] * 0.5 + torch_randn(n) * 0.2

# Task 2 (Classification): Predict binary outcome
# Create a classification target by thresholding a combination of features
logits <- x[, 1] * 0.8 - x[, 4] * 0.4 + x[, 5] * 0.6
y_classification <- (logits > 0)$to(torch_float())

# Split into training (70%) and testing (30%) sets
train_idx <- 1:round(0.7 * n)
test_idx <- (round(0.7 * n) + 1):n

# Training data
x_train <- x[train_idx, ]
y_reg_train <- y_regression[train_idx]
y_cls_train <- y_classification[train_idx]

# Testing data
x_test <- x[test_idx, ]
y_reg_test <- y_regression[test_idx]
y_cls_test <- y_classification[test_idx]
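As an optional sanity check (not part of the modelling pipeline itself), the simulated tensors can be inspected before training; the calls below assume the objects created above and use the same as.numeric() conversion the evaluation code uses later.
# Quick look at the simulated data
dim(x_train)                      # expected: 700 x 5
mean(as.numeric(y_cls_train))     # share of class 1 in the training labels
summary(as.numeric(y_reg_train))  # distribution of the regression target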
Define the Multi-Task Neural Network
The architecture design creates a neural network with shared layers and task-specific branches:
# Define the multi-task neural network
multi_task_net <- nn_module(
  "MultiTaskNet",
  initialize = function(input_size, hidden_size,
                        reg_output_size = 1, cls_output_size = 1) {
    self$input_size <- input_size
    self$hidden_size <- hidden_size
    self$reg_output_size <- reg_output_size
    self$cls_output_size <- cls_output_size

    # Shared layers - these learn representations useful for both tasks
    self$shared_layer1 <- nn_linear(input_size, hidden_size)
    self$shared_layer2 <- nn_linear(hidden_size, hidden_size)

    # Task-specific layers
    # Regression branch
    self$regression_layer <- nn_linear(hidden_size, reg_output_size)

    # Classification branch
    self$classification_layer <- nn_linear(hidden_size, cls_output_size)
  },
  forward = function(x) {
    # Shared feature extraction
    shared_features <- x %>%
      self$shared_layer1() %>%
      nnf_relu() %>%
      self$shared_layer2() %>%
      nnf_relu()

    # Task-specific predictions
    regression_output <- self$regression_layer(shared_features)
    classification_logits <- self$classification_layer(shared_features)

    list(
      regression = regression_output,
      classification = classification_logits
    )
  }
)

# Create model instance
model <- multi_task_net(
  input_size = 5,
  hidden_size = 10
)

# Print model architecture
print(model)
An `nn_module` containing 192 parameters.
── Modules ─────────────────────────────────────────────────────────────────────
• shared_layer1: <nn_linear> #60 parameters
• shared_layer2: <nn_linear> #110 parameters
• regression_layer: <nn_linear> #11 parameters
• classification_layer: <nn_linear> #11 parameters
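A quick, optional forward pass on a few rows confirms that the untrained model returns both outputs with the expected shapes before training begins.
# Sanity check: forward pass on a small batch
sample_out <- model(x_train[1:4, ])
dim(sample_out$regression)        # 4 x 1
dim(sample_out$classification)    # 4 x 1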
Define Loss Functions and Optimizer
Multi-task learning requires separate loss functions for each task.
# Loss functions
regression_loss_fn <- nnf_mse_loss # Mean squared error for regression
classification_loss_fn <- nnf_binary_cross_entropy_with_logits # Binary cross-entropy for classification

# Optimizer
optimizer <- optim_adam(model$parameters, lr = 0.01)

# Task weights - these control the relative importance of each task
task_weights <- c(regression = 0.5, classification = 0.5)

# Validation split from training data
val_size <- round(0.2 * length(train_idx))
val_indices <- sample(train_idx, val_size)
train_indices <- setdiff(train_idx, val_indices)

# Create validation sets
x_val <- x[val_indices, ]
y_reg_val <- y_regression[val_indices]
y_cls_val <- y_classification[val_indices]

# Update training sets
x_train <- x[train_indices, ]
y_reg_train <- y_regression[train_indices]
y_cls_train <- y_classification[train_indices]
Training Loop
# Hyperparameters
epochs <- 100

# Training history tracking
training_history <- data.frame(
  epoch = integer(),
  train_reg_loss = numeric(),
  train_cls_loss = numeric(),
  train_total_loss = numeric(),
  val_reg_loss = numeric(),
  val_cls_loss = numeric(),
  val_total_loss = numeric(),
  val_accuracy = numeric()
)

for (epoch in 1:epochs) {
  # Training phase
  model$train()
  optimizer$zero_grad()

  # Forward pass on training data
  outputs <- model(x_train)

  # Calculate training loss for each task
  train_reg_loss <- regression_loss_fn(
    outputs$regression$squeeze(),
    y_reg_train
  )
  train_cls_loss <- classification_loss_fn(
    outputs$classification$squeeze(),
    y_cls_train
  )

  # Weighted combined training loss
  train_total_loss <- task_weights["regression"] * train_reg_loss +
    task_weights["classification"] * train_cls_loss

  # Backward pass and optimize
  train_total_loss$backward()

  # Gradient clipping to prevent exploding gradients
  nn_utils_clip_grad_norm_(model$parameters, max_norm = 1.0)
  optimizer$step()

  # Validation phase
  model$eval()

  with_no_grad({
    val_outputs <- model(x_val)

    # Calculate validation losses
    val_reg_loss <- regression_loss_fn(
      val_outputs$regression$squeeze(),
      y_reg_val
    )
    val_cls_loss <- classification_loss_fn(
      val_outputs$classification$squeeze(),
      y_cls_val
    )
    val_total_loss <- task_weights["regression"] * val_reg_loss +
      task_weights["classification"] * val_cls_loss

    # Calculate validation accuracy
    val_cls_probs <- nnf_sigmoid(val_outputs$classification$squeeze())
    val_cls_preds <- (val_cls_probs > 0.5)$to(torch_int())
    val_accuracy <- (val_cls_preds == y_cls_val$to(torch_int()))$sum()$item() / length(val_indices)
  })

  # Record history
  training_history <- rbind(
    training_history,
    data.frame(
      epoch = epoch,
      train_reg_loss = as.numeric(train_reg_loss$item()),
      train_cls_loss = as.numeric(train_cls_loss$item()),
      train_total_loss = as.numeric(train_total_loss$item()),
      val_reg_loss = as.numeric(val_reg_loss$item()),
      val_cls_loss = as.numeric(val_cls_loss$item()),
      val_total_loss = as.numeric(val_total_loss$item()),
      val_accuracy = val_accuracy
    )
  )

  # Print progress every 25 epochs
  if (epoch %% 25 == 0 || epoch == 1) {
    cat(sprintf(
      "Epoch %d - Train Loss: %.4f, Val Loss: %.4f, Val Acc: %.3f\n",
      epoch,
      train_total_loss$item(),
      val_total_loss$item(),
      val_accuracy
    ))
  }
}
Epoch 1 - Train Loss: 0.7958, Val Loss: 0.7369, Val Acc: 0.493
Epoch 25 - Train Loss: 0.3267, Val Loss: 0.3035, Val Acc: 0.821
Epoch 50 - Train Loss: 0.1548, Val Loss: 0.1350, Val Acc: 0.971
Epoch 75 - Train Loss: 0.0599, Val Loss: 0.0479, Val Acc: 0.993
Epoch 100 - Train Loss: 0.0381, Val Loss: 0.0356, Val Acc: 1.000
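Once training has finished, the fitted module can be saved for later reuse. This step is optional and the file name below is arbitrary; torch provides torch_save() and torch_load() for serializing nn_module objects.
# Persist the trained model (file name is illustrative)
torch_save(model, "multi_task_model.pt")

# Reload later with:
# model <- torch_load("multi_task_model.pt")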
Model Evaluation
# Set model to evaluation mode
model$eval()

# Make predictions on test set
with_no_grad({
  outputs <- model(x_test)

  # Regression evaluation
  reg_preds <- outputs$regression$squeeze()
  reg_test_loss <- regression_loss_fn(reg_preds, y_reg_test)

  # Classification evaluation
  cls_preds <- outputs$classification$squeeze()
  cls_probs <- nnf_sigmoid(cls_preds)
  cls_test_loss <- classification_loss_fn(cls_preds, y_cls_test)

  # Convert predictions to binary (threshold = 0.5)
  cls_pred_labels <- (cls_probs > 0.5)$to(torch_int())

  # Calculate accuracy
  accuracy <- (cls_pred_labels == y_cls_test$to(torch_int()))$sum()$item() / length(test_idx)
})

# Calculate additional metrics
reg_preds_r <- as.numeric(reg_preds)
y_reg_test_r <- as.numeric(y_reg_test)
cls_probs_r <- as.numeric(cls_probs)
y_cls_test_r <- as.numeric(y_cls_test)

# Regression metrics
rmse <- sqrt(mean((reg_preds_r - y_reg_test_r)^2))
mae <- mean(abs(reg_preds_r - y_reg_test_r))
r_squared <- cor(reg_preds_r, y_reg_test_r)^2

# Classification metrics
auc <- pROC::auc(pROC::roc(y_cls_test_r, cls_probs_r, quiet = TRUE))

# Display results
performance_results <- data.frame(
  Task = c("Regression", "Regression", "Regression", "Classification", "Classification", "Classification"),
  Metric = c("Test Loss (MSE)", "RMSE", "R-squared", "Test Loss (BCE)", "Accuracy", "AUC"),
  Value = c(
    round(reg_test_loss$item(), 4),
    round(rmse, 4),
    round(r_squared, 4),
    round(cls_test_loss$item(), 4),
    round(accuracy * 100, 2),
    round(auc * 100, 2)
  )
)
print(performance_results)
Task Metric Value
1 Regression Test Loss (MSE) 0.0533
2 Regression RMSE 0.2308
3 Regression R-squared 0.9449
4 Classification Test Loss (BCE) 0.0492
5 Classification Accuracy 98.0000
6 Classification AUC 99.9400
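Beyond the summary table, a confusion matrix gives a more granular view of the classification head. This short sketch reuses cls_pred_labels and y_cls_test_r from the evaluation code above and relies only on base R.
# Confusion matrix for the classification task
conf_mat <- table(
  Predicted = as.numeric(cls_pred_labels),
  Actual = y_cls_test_r
)
conf_mat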
Visualization and Overfitting Analysis
# Plot training history with overfitting detection
p1 <- training_history %>%
  select(epoch, train_total_loss, val_total_loss) %>%
  pivot_longer(cols = c(train_total_loss, val_total_loss),
               names_to = "split", values_to = "loss") %>%
  mutate(split = case_when(
    split == "train_total_loss" ~ "Training",
    split == "val_total_loss" ~ "Validation"
  )) %>%
  ggplot(aes(x = epoch, y = loss, color = split)) +
  geom_line(size = 1) +
  geom_vline(xintercept = which.min(training_history$val_total_loss),
             linetype = "dashed", color = "red", alpha = 0.7) +
  labs(title = "Training vs Validation Loss",
       subtitle = "Red line shows optimal stopping point",
       x = "Epoch", y = "Total Loss", color = "Dataset") +
  theme_minimal() +
  scale_color_brewer(palette = "Set1")

# Separate task losses
p2 <- training_history %>%
  select(epoch, train_reg_loss, val_reg_loss, train_cls_loss, val_cls_loss) %>%
  pivot_longer(cols = -epoch, names_to = "metric", values_to = "loss") %>%
  separate(metric, into = c("split", "task", "loss_type"), sep = "_") %>%
  mutate(
    split = ifelse(split == "train", "Training", "Validation"),
    task = ifelse(task == "reg", "Regression", "Classification"),
    metric_name = paste(split, task)
  ) %>%
  ggplot(aes(x = epoch, y = loss, color = metric_name)) +
  geom_line(size = 1) +
  facet_wrap(~task, scales = "free_y") +
  labs(title = "Task-Specific Loss Curves",
       subtitle = "Monitoring overfitting in individual tasks",
       x = "Epoch", y = "Loss", color = "Split & Task") +
  theme_minimal() +
  scale_color_brewer(palette = "Set2")

# Validation accuracy progression
p3 <- ggplot(training_history, aes(x = epoch, y = val_accuracy)) +
  geom_line(color = "#2c3e50", size = 1) +
  geom_hline(yintercept = max(training_history$val_accuracy),
             linetype = "dashed", color = "red", alpha = 0.7) +
  labs(title = "Validation Accuracy Progression",
       subtitle = paste("Peak accuracy:", round(max(training_history$val_accuracy), 3)),
       x = "Epoch", y = "Validation Accuracy") +
  theme_minimal()

# Overfitting analysis
training_history$overfitting_gap <- training_history$train_total_loss - training_history$val_total_loss

p4 <- ggplot(training_history, aes(x = epoch, y = overfitting_gap)) +
  geom_line(color = "#e74c3c", size = 1) +
  geom_hline(yintercept = 0, linetype = "dashed", alpha = 0.5) +
  labs(title = "Overfitting Gap Analysis",
       subtitle = "Difference between training and validation loss",
       x = "Epoch", y = "Training Loss - Validation Loss") +
  theme_minimal()

# Regression predictions vs actual values
regression_results <- data.frame(
  Actual = y_reg_test_r,
  Predicted = reg_preds_r
)

p5 <- ggplot(regression_results, aes(x = Actual, y = Predicted)) +
  geom_point(alpha = 0.6, color = "#2c3e50") +
  geom_abline(slope = 1, intercept = 0, color = "#e74c3c", linetype = "dashed", size = 1) +
  geom_smooth(method = "lm", color = "#3498db", se = TRUE) +
  labs(title = "Regression Task: Actual vs Predicted Values",
       subtitle = paste("R² =", round(r_squared, 3), ", RMSE =", round(rmse, 3)),
       x = "Actual Values", y = "Predicted Values") +
  theme_minimal()

# Classification probability distribution
cls_results <- data.frame(
  Probability = cls_probs_r,
  Actual_Class = factor(y_cls_test_r, labels = c("Class 0", "Class 1"))
)

p6 <- ggplot(cls_results, aes(x = Probability, fill = Actual_Class)) +
  geom_histogram(alpha = 0.7, bins = 20, position = "identity") +
  geom_vline(xintercept = 0.5, linetype = "dashed", color = "red") +
  labs(title = "Classification Task: Predicted Probabilities",
       subtitle = paste("Accuracy =", round(accuracy * 100, 1), "%"),
       x = "Predicted Probability", y = "Count", fill = "Actual Class") +
  theme_minimal() +
  scale_fill_brewer(palette = "Set1")

# Combine plots
library(patchwork)
(p1 | p3) / (p2) / (p4) / (p5 | p6)
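If the combined figure needs to be reused outside the R session, it can be written to disk. The snippet below is a minimal example; the object name, file name, and dimensions are illustrative, and ggsave() accepts a patchwork composition via its plot argument.
# Save the combined figure (names and sizes are illustrative)
combined_plot <- (p1 | p3) / (p2) / (p4) / (p5 | p6)
ggsave("mtl_training_overview.png", plot = combined_plot, width = 10, height = 14, dpi = 300)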
Key Takeaways
- Architecture Design: The shared-private paradigm enables models to learn both common and task-specific representations
- Loss Combination: Properly weighting multiple loss functions proves crucial for balanced learning across tasks
- Evaluation Strategy: Each task requires appropriate metrics, and overall model success depends on performance across all tasks
- Parameter Efficiency: Multi-task models can achieve comparable performance with fewer total parameters when properly regularized (see the sketch after this list)
- Knowledge Transfer: Related tasks can benefit from shared feature learning, especially when data is limited
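To put a rough number on the parameter-efficiency point, the sketch below compares the multi-task model defined earlier with a hypothetical pair of single-task networks of the same width. single_task_net and count_params are illustrative helpers introduced here, not part of the model above.
# Hypothetical single-task architecture with the same hidden size
single_task_net <- nn_module(
  "SingleTaskNet",
  initialize = function(input_size, hidden_size) {
    self$layer1 <- nn_linear(input_size, hidden_size)
    self$layer2 <- nn_linear(hidden_size, hidden_size)
    self$output <- nn_linear(hidden_size, 1)
  },
  forward = function(x) {
    x %>%
      self$layer1() %>%
      nnf_relu() %>%
      self$layer2() %>%
      nnf_relu() %>%
      self$output()
  }
)

# Count the parameters in a module
count_params <- function(m) sum(sapply(m$parameters, function(p) p$numel()))

count_params(model)                       # 192 for the shared multi-task model
2 * count_params(single_task_net(5, 10))  # 362 for two separate single-task models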