Multi-Task Learning with torch in R

Categories: R, Deep Learning, torch, Multi-Task Learning

Published: May 11, 2025

Multi-task learning (MTL) is an approach in which a single neural network is trained to perform multiple related tasks simultaneously. This can improve generalization, reduce overfitting, and exploit information shared across tasks. In this post, we’ll implement a multi-task learning model using the torch package in R.

Introduction to Multi-Task Learning

Multi-task learning works by sharing representations between related tasks, allowing the model to generalize better. Instead of training separate models for each task, we train a single model with:

  • Shared layers that learn common features across tasks
  • Task-specific layers that specialize for each individual task
  • Multiple loss functions, one per task, combined into a single training objective (see the sketch just below)
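
In practice, the per-task losses are combined into one scalar objective before backpropagation, most simply as a weighted sum. Here is a minimal sketch with hand-picked, illustrative weights; regression_loss and classification_loss are placeholders for the per-task loss tensors, and the full version appears in step 4 below:

# Combine per-task losses into a single objective (placeholder tensors)
total_loss <- 0.5 * regression_loss + 0.5 * classification_loss
total_loss$backward()  # one backward pass updates shared and task-specific layers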

Installation and Setup

First, let’s install and load the required packages:

# install.packages(c("torch", "tidyverse"))
library(torch)
library(tidyverse)

Creating a Multi-Task Learning Model

We’ll build a model that simultaneously performs two related tasks:

  1. Regression: predicting a continuous value
  2. Classification: predicting a binary outcome

1. Generate Sample Data for Multiple Tasks

# Seed both R's RNG and torch's own RNG (torch_randn draws from the latter)
set.seed(123)
torch_manual_seed(123)

# Number of samples
n <- 1000

# Create a dataset with 5 features
x <- torch_randn(n, 5)

# Task 1 (Regression): Predict continuous value
# We'll create a target that's a function of the input features plus some noise
y_regression <- x[, 1] * 0.7 + x[, 2] * 0.3 - x[, 3] * 0.5 + torch_randn(n) * 0.2

# Task 2 (Classification): Predict binary outcome
# Create a classification target based on a nonlinear combination of features
logits <- x[, 1] * 0.8 - x[, 4] * 0.4 + x[, 5] * 0.6
y_classification <- (logits > 0)$to(torch_float())

# Split into training (70%) and testing (30%) sets
train_idx <- 1:round(0.7 * n)
test_idx <- (round(0.7 * n) + 1):n

# Training data
x_train <- x[train_idx, ]
y_reg_train <- y_regression[train_idx]
y_cls_train <- y_classification[train_idx]

# Testing data
x_test <- x[test_idx, ]
y_reg_test <- y_regression[test_idx]
y_cls_test <- y_classification[test_idx]
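
Before moving on, a quick sanity check on the generated data can catch shape or labeling mistakes early. This check is optional and not part of the modeling workflow, and the printed values will depend on the random draw:

# Optional sanity checks: tensor shapes and class balance
x_train$shape                   # expect 700 x 5
y_classification$mean()$item()  # proportion of positive labels, roughly 0.5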

2. Define the Multi-Task Neural Network

We’ll create a neural network with:

  • A shared base network used by both tasks
  • Task-specific branches for regression and classification

multi_task_net <- nn_module(
  "MultiTaskNet",
  
  initialize = function(input_size, 
                        hidden_size, 
                        reg_output_size = 1, 
                        cls_output_size = 1) {
    
    self$input_size <- input_size
    self$hidden_size <- hidden_size
    self$reg_output_size <- reg_output_size
    self$cls_output_size <- cls_output_size
    
    # Shared layers
    self$shared_layer1 <- nn_linear(input_size, hidden_size)
    self$shared_layer2 <- nn_linear(hidden_size, hidden_size)
    
    # Task-specific layers
    
    # Regression branch
    self$regression_layer <- nn_linear(hidden_size, reg_output_size)
    
    # Classification branch
    self$classification_layer <- nn_linear(hidden_size, cls_output_size)
  },
  
  forward = function(x) {
    # Shared feature extraction
    shared_features <- x %>%
      self$shared_layer1() %>%
      nnf_relu() %>%
      self$shared_layer2() %>%
      nnf_relu()
    
    # Task-specific predictions
    regression_output <- self$regression_layer(shared_features)
    classification_logits <- self$classification_layer(shared_features)
    
    list(
      regression = regression_output,
      classification = classification_logits
    )
  }
)

# Create model instance
model <- multi_task_net(
  input_size = 5,
  hidden_size = 64
)

# Print model architecture
model
An `nn_module` containing 4,674 parameters.

── Modules ─────────────────────────────────────────────────────────────────────
• shared_layer1: <nn_linear> #384 parameters
• shared_layer2: <nn_linear> #4,160 parameters
• regression_layer: <nn_linear> #65 parameters
• classification_layer: <nn_linear> #65 parameters
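
Before training, it can be reassuring to run a small forward pass and confirm that both heads return the expected shapes. This is an optional check, not part of the original workflow:

# Optional: forward a few rows and inspect the two outputs
out <- model(x_train[1:4, ])
out$regression$shape        # expect 4 x 1
out$classification$shape    # expect 4 x 1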

3. Define Loss Functions and Optimizer

For multi-task learning, we need separate loss functions for each task:

# Loss functions
regression_loss_fn <- nnf_mse_loss  # Mean squared error for regression
classification_loss_fn <- nnf_binary_cross_entropy_with_logits  # Binary cross-entropy for classification

# Optimizer
optimizer <- optim_adam(model$parameters, lr = 0.01)
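
Note that nnf_binary_cross_entropy_with_logits() expects raw logits (no sigmoid applied) together with float targets in {0, 1}, which is why the classification branch returns logits directly. A quick illustrative call on dummy values:

# Illustrative only: the loss consumes raw logits and float targets
nnf_binary_cross_entropy_with_logits(
  torch_tensor(c(2.0, -1.0)),  # logits
  torch_tensor(c(1.0, 0.0))    # targets
)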

4. Custom Training Loop for Multi-Task Learning

We’ll train the model by combining the losses from both tasks:

# Hyperparameters
epochs <- 200
task_weights <- c(regression = 0.5, 
                  classification = 0.5)  # Relative importance of each task

# Training loop
training_history <- data.frame(
  epoch = integer(),
  reg_loss = numeric(),
  cls_loss = numeric(),
  total_loss = numeric()
)

for (epoch in 1:epochs) {
  model$train()
  optimizer$zero_grad()
  
  # Forward pass
  outputs <- model(x_train)
  
  # Calculate loss for each task
  reg_loss <- regression_loss_fn(
    outputs$regression$squeeze(), 
    y_reg_train
  )
  
  cls_loss <- classification_loss_fn(
    outputs$classification$squeeze(), 
    y_cls_train
  )
  
  # Weighted combined loss
  total_loss <- task_weights["regression"] * reg_loss + 
               task_weights["classification"] * cls_loss
  
  # Backward pass and optimize
  total_loss$backward()
  optimizer$step()
  
  # Record history (first epoch, then every 20 epochs)
  if (epoch %% 20 == 0 || epoch == 1) {
    training_history <- rbind(
      training_history,
      data.frame(
        epoch = epoch,
        reg_loss = as.numeric(reg_loss$item()),
        cls_loss = as.numeric(cls_loss$item()),
        total_loss = as.numeric(total_loss$item())
      )
    )
  }
}
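
The loop above trains on the full training set in every iteration, which is fine for 700 rows. For larger datasets you would typically iterate over mini-batches instead; below is a minimal sketch using tensor_dataset() and dataloader(). It assumes the coro package (a torch dependency) is available, and one pass over the dataloader corresponds to one epoch, so it would replace the body of the epoch loop:

# Mini-batch sketch: wrap the training tensors and iterate in batches
train_ds <- tensor_dataset(x_train, y_reg_train, y_cls_train)
train_dl <- dataloader(train_ds, batch_size = 64, shuffle = TRUE)

coro::loop(for (batch in train_dl) {
  optimizer$zero_grad()
  outputs <- model(batch[[1]])
  loss <- task_weights["regression"] *
            regression_loss_fn(outputs$regression$squeeze(), batch[[2]]) +
          task_weights["classification"] *
            classification_loss_fn(outputs$classification$squeeze(), batch[[3]])
  loss$backward()
  optimizer$step()
})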

5. Evaluate the Model

Let’s evaluate the model’s performance on both tasks:

model$eval()
with_no_grad({
  outputs <- model(x_test)
  
  # Regression evaluation
  reg_preds <- outputs$regression$squeeze()
  reg_test_loss <- regression_loss_fn(reg_preds, y_reg_test)
  
  # Classification evaluation
  cls_preds <- outputs$classification$squeeze()
  cls_probs <- nnf_sigmoid(cls_preds)
  cls_test_loss <- classification_loss_fn(cls_preds, y_cls_test)
  
  # Convert predictions to binary (threshold = 0.5)
  cls_pred_labels <- (cls_probs > 0.5)$to(torch_int())
  
  # Calculate accuracy
  accuracy <- (cls_pred_labels == y_cls_test$to(torch_int()))$sum()$item() / length(test_idx)
})

cat(sprintf("Test Regression Loss: %.4f\n", reg_test_loss$item()))
Test Regression Loss: 0.0426
cat(sprintf("Test Classification Loss: %.4f\n", cls_test_loss$item()))
Test Classification Loss: 0.0328
cat(sprintf("Classification Accuracy: %.2f%%\n", accuracy * 100))
Classification Accuracy: 99.00%
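
For the regression head it can also be useful to report metrics that are easier to interpret than the raw MSE, such as RMSE and R-squared. These are computed from the same test predictions, and the exact values will vary with the random data:

# Additional regression metrics on the test set
reg_rmse <- sqrt(mean((as.numeric(reg_preds) - as.numeric(y_reg_test))^2))
reg_r2   <- cor(as.numeric(reg_preds), as.numeric(y_reg_test))^2
cat(sprintf("Test RMSE: %.4f | R-squared: %.4f\n", reg_rmse, reg_r2))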

6. Visualize the Results

Let’s visualize the training progress and model predictions:

# Plot training history
ggplot(training_history, aes(x = epoch)) +
  geom_line(aes(y = reg_loss, color = "Regression Loss")) +
  geom_line(aes(y = cls_loss, color = "Classification Loss")) +
  labs(title = "Training Loss Over Time",
       x = "Epoch",
       y = "Loss",
       color = "Task") +
  theme_minimal()

# Convert test predictions to R objects for visualization
reg_preds_r <- as.numeric(reg_preds)
y_reg_test_r <- as.numeric(y_reg_test)

# Plot regression predictions vs actual values
regression_results <- data.frame(
  Actual = y_reg_test_r,
  Predicted = reg_preds_r
)

ggplot(regression_results, aes(x = Actual, y = Predicted)) +
  geom_point(alpha = 0.5) +
  geom_abline(slope = 1, intercept = 0, color = "red", linetype = "dashed") +
  labs(title = "Regression Task: Actual vs Predicted Values",
       x = "Actual Values",
       y = "Predicted Values") +
  theme_minimal()

# Confusion matrix visualization
cls_pred_labels_r <- as.integer(as.array(cls_pred_labels))
y_cls_test_r <- as.integer(as.array(y_cls_test))

confusion_data <- table(
  Predicted = cls_pred_labels_r,
  Actual = y_cls_test_r
)

confusion_df <- as.data.frame(confusion_data)

ggplot(confusion_df, aes(x = Actual, y = Predicted, fill = Freq)) +
  geom_tile() +
  geom_text(aes(label = Freq), color = "white", size = 10) +
  scale_fill_gradient(low = "steelblue", high = "darkblue") +
  labs(title = "Classification Task: Confusion Matrix",
       x = "Actual Class",
       y = "Predicted Class") +
  theme_minimal()

Benefits of Multi-Task Learning

  1. Parameter Efficiency: We’re using a single model for two different tasks instead of training separate models.

  2. Improved Generalization: The shared layers learn representations that work well for both tasks, which can lead to better generalization.

  3. Training Efficiency: We’re updating the shared parameters using signals from both tasks, which can speed up learning.

Extensions and Variations

  • Task Weighting: Experiment with different weights for each task based on their relative importance, or learn the weights during training (see the sketch after this list).
  • Additional Tasks: Add more related tasks to benefit from the shared representations.
  • Task-Specific Layers: Add more specialized layers to each task branch for better task-specific performance.
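
As an example of the task-weighting idea, the fixed weights can be replaced by weights learned during training. Below is a minimal sketch of uncertainty-based weighting in the spirit of Kendall et al. (2018); the module and parameter names are illustrative and not part of the code above:

# Learnable task weights via per-task log-variances (sketch)
weighted_mtl_loss <- nn_module(
  "WeightedMTLLoss",
  initialize = function() {
    self$log_var_reg <- nn_parameter(torch_zeros(1))
    self$log_var_cls <- nn_parameter(torch_zeros(1))
  },
  forward = function(reg_loss, cls_loss) {
    torch_exp(-self$log_var_reg) * reg_loss + self$log_var_reg +
      torch_exp(-self$log_var_cls) * cls_loss + self$log_var_cls
  }
)

loss_module <- weighted_mtl_loss()

# Optimize the weighting parameters jointly with the model
optimizer <- optim_adam(c(model$parameters, loss_module$parameters), lr = 0.01)

# Inside the training loop, replace the fixed weighted sum with:
# total_loss <- loss_module(reg_loss, cls_loss)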