Building a Simple Neural Network in R with torch

Categories: R, Deep Learning, torch

Published: December 5, 2024

The torch package brings deep learning to R by binding directly to LibTorch, the C++ library that powers PyTorch, so no Python installation is required. In this post, you’ll learn how to build and train a simple neural network using torch in R.

Installation

To get started, install the torch package from CRAN. The first time you load it, torch offers to download the LibTorch backend; you can also trigger that step manually with install_torch():

# install.packages("torch")   # install the package from CRAN (run once)
library(torch)
# torch::install_torch()      # download the LibTorch backend (run once)
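
Once the backend is in place, a quick tensor operation confirms that everything works. This is just a minimal sanity check, not part of the modeling workflow:

torch_tensor(c(1, 2, 3)) * 2  # should print a tensor containing 2, 4, 6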

Creating a Simple Neural Network

Let’s create a neural network to perform regression on a simple dataset (predicting y from x).

1. Generate Sample Data

torch_manual_seed(42)  # seed torch's RNG; base R's set.seed() does not affect torch

x <- torch_randn(100, 1)                    # 100 inputs drawn from a standard normal
y <- 3 * x + 2 + torch_randn(100, 1) * 0.3  # linear signal plus Gaussian noise
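
Before moving on, it is worth confirming the shapes: both tensors should be 100 x 1. A quick optional check using the tensors above:

x$shape  # 100 1
y$shape  # 100 1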

2. Define the Neural Network Module

net <- nn_module(
  initialize = function() {
    self$fc1 <- nn_linear(1, 8)   # input layer: 1 feature to 8 hidden units
    self$fc2 <- nn_linear(8, 1)   # output layer: 8 hidden units to 1 prediction
  },
  forward = function(x) {
    x <- nnf_relu(self$fc1(x))    # hidden layer with ReLU activation
    self$fc2(x)                   # linear output, suitable for regression
  }
)

model <- net()
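
Printing the model gives a quick summary. With layers of size 1 to 8 and 8 to 1, it should report 25 parameters: 8 + 8 weights plus 8 + 1 biases. This inspection step is optional:

model  # prints the module and its parameter count (25 here)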

3. Set Up the Optimizer and Loss Function

optimizer <- optim_adam(model$parameters, lr = 0.01)  # Adam over all trainable parameters
loss_fn <- nnf_mse_loss                               # mean squared error, standard for regression
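
Before training, it can be useful to record the loss of the untrained model as a baseline. The exact value depends on the random initialization, so treat this as an optional sketch:

with_no_grad({
  cat("Initial loss:", loss_fn(model(x), y)$item(), "\n")
})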

4. Training Loop

model$train()  # put the model in training mode

for (epoch in 1:300) {
  optimizer$zero_grad()         # clear gradients from the previous iteration
  y_pred <- model(x)            # forward pass
  loss <- loss_fn(y_pred, y)    # mean squared error against the targets
  loss$backward()               # backpropagate to compute gradients
  optimizer$step()              # update the weights
  if (epoch %% 50 == 0) {
    cat(sprintf("Epoch %d, Loss: %f\n", epoch, loss$item()))
  }
}
Epoch 50, Loss: 0.498060
Epoch 100, Loss: 0.136693
Epoch 150, Loss: 0.119539
Epoch 200, Loss: 0.109248
Epoch 250, Loss: 0.102379
Epoch 300, Loss: 0.096627
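
Because the underlying relationship is linear, ordinary least squares makes a good sanity check: its coefficients should land near the true intercept 2 and slope 3. This comparison is optional and not part of the torch workflow:

x_vec <- as.numeric(x$squeeze())
y_vec <- as.numeric(y$squeeze())
fit <- lm(y_vec ~ x_vec)   # baseline linear fit on the same data
coef(fit)                  # expect values near intercept 2 and slope 3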

5. Visualize the Results

model$eval()  # switch to evaluation mode before generating predictions

x_np <- as.numeric(x$squeeze())
y_np <- as.numeric(y$squeeze())
y_pred_np <- with_no_grad(as.numeric(model(x)$squeeze()))  # no gradient tracking needed for plotting

plot(x_np, 
     y_np, 
     main = "Neural Net Regression with torch", 
     xlab = "x",
     ylab = "y")

points(x_np,
       y_pred_np,
       col = "red",
       pch = 20)

legend("topleft",
       legend = c("Actual", "Predicted"),
       col = c("black", "red"),
       pch = c(1, 20))
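
If you want to reuse the trained network later, torch can serialize the whole module with torch_save() and restore it with torch_load(). The file name below is just an example:

torch_save(model, "simple_net.pt")             # write the trained module to disk
model_restored <- torch_load("simple_net.pt")  # load it back in a later session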

This example shows how little code it takes to build and train a neural network in R with torch. The same pattern applies to harder problems: define a module, choose an optimizer and a loss function, and write a training loop. From there you can extend the approach to larger datasets and deeper architectures.