9  Graph (convolutional) neural networks

Graph neural networks (GNN) is a young representative of the deep neural network family but is receiving more and more attention in the last years because of their ability to process non-Euclidean data such as graphs.

Currently there is no R package for GNNs available. However, we can use the ‘reticulate’ package (Ushey, Allaire, and Tang (2022)) to use the python packages ‘torch’ and ‘torch_geometric’ (Paszke et al. (2019), 2019; Fey and Lenssen (2019)).

The following example was mostly adapted from the ‘Node Classification with Graph Neural Networks’ example from the torch_geometric documentation (https://pytorch-geometric.readthedocs.io/en/latest/notes/colabs.html).

The dataset is also provided by the ‘torch_geometric’ package and consists of molecules presented as graphs and the task is to predict whether HIV virus replication is inhibited by the molecule or not (classification, binary classification).

We have not implemented this example in Julia because there is not yet a well-established library for GNNs.

library(reticulate)
# Load python packages torch and torch_geometric via the reticulate R package
torch = import("torch") 
torch_geometric = import("torch_geometric")

# helper functions from the torch_geometric modules
GCNConv = torch_geometric$nn$GCNConv
global_mean_pool = torch_geometric$nn$global_mean_pool


# Download the MUTAG TUDataset
dataset = torch_geometric$datasets$TUDataset(root='data/TUDataset', 
                                             name='MUTAG')
dataloader = torch_geometric$loader$DataLoader(dataset, 
                                               batch_size=64L,
                                               shuffle=TRUE)

# Create the model with a python class
# There are two classes in the response variable
GCN = PyClass(
  "GCN", 
   inherit = torch$nn$Module, 
   defs = list(
       `__init__` = function(self, hidden_channels) {
         super()$`__init__`()
         torch$manual_seed(42L)
         self$conv = GCNConv(dataset$num_node_features, hidden_channels)
         self$linear = torch$nn$Linear(hidden_channels, dataset$num_classes)
         NULL
       },
       forward = function(self, x, edge_index, batch) {
         x = self$conv(x, edge_index)
         x = x$relu()
         x = global_mean_pool(x, batch)
         
         x = torch$nn$functional$dropout(x, p = 0.5, training=self$training)
         x = self$linear(x)
         return(x)
       }
   ))

Training loop:

# create model object
model = GCN(hidden_channels = 64L)

# get optimizer and loss function
optimizer = torch$optim$Adamax(model$parameters(), lr = 0.01)
loss_func = torch$nn$CrossEntropyLoss()

# set model into training mode (because of the dropout layer)
model$train()
GCN(
  (conv): GCNConv(7, 64)
  (linear): Linear(in_features=64, out_features=2, bias=True)
)
# train model
for(e in 1:50) {
  iterator = reticulate::as_iterator(dataloader)
  coro::loop(for (b in iterator) { 
     pred = model(b$x, b$edge_index, b$batch)
     loss = loss_func(pred, b$y)
     loss$backward()
     optimizer$step()
     optimizer$zero_grad()
  })
  if(e %% 10 ==0) cat(paste0("Epoch: ",e," Loss: ", round(loss$item()[1], 4), "\n"))
}
Epoch: 10 Loss: 0.6151
Epoch: 20 Loss: 0.6163
Epoch: 30 Loss: 0.5745
Epoch: 40 Loss: 0.5362
Epoch: 50 Loss: 0.5829

Make predictions:

preds = list()
test = torch_geometric$loader$DataLoader(dataset, batch_size=64L,shuffle=FALSE)
iterator = reticulate::as_iterator(test)
model$eval()
GCN(
  (conv): GCNConv(7, 64)
  (linear): Linear(in_features=64, out_features=2, bias=True)
)
counter = 1
coro::loop(for (b in iterator) {
  preds[[counter]] = model(b$x, b$edge_index, b$batch)
  counter <<- counter + 1
  })
head(torch$concat(preds)$sigmoid()$data$cpu()$numpy(), n = 3)
          [,1]      [,2]
[1,] 0.3076028 0.6427078
[2,] 0.4121239 0.5515330
[3,] 0.4119514 0.5516798
# Load python packages torch and torch_geometric via the reticulate R package
import torch
import torch_geometric

# helper functions from the torch_geometric modules
GCNConv = torch_geometric.nn.GCNConv
global_mean_pool = torch_geometric.nn.global_mean_pool


# Download the MUTAG TUDataset
dataset = torch_geometric.datasets.TUDataset(root='data/TUDataset', 
                                             name='MUTAG')
dataloader = torch_geometric.loader.DataLoader(dataset, 
                                               batch_size=64,
                                               shuffle=True)

# Create the model with a python class
# There are two classes in the response variable
class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
         super().__init__()
         torch.manual_seed(42)
         self.conv = GCNConv(dataset.num_node_features, hidden_channels)
         self.linear = torch.nn.Linear(hidden_channels, dataset.num_classes)
         
    def forward(self, x, edge_index, batch):
        x = self.conv(x, edge_index)
        x = x.relu()
        x = global_mean_pool(x, batch)
        x = torch.nn.functional.dropout(x, p = 0.5, training=self.training)
        x = self.linear(x)
        return x

Training loop:

# create model object
model = GCN(hidden_channels = 64)

# get optimizer and loss function
optimizer = torch.optim.Adamax(model.parameters(), lr = 0.01)
loss_func = torch.nn.CrossEntropyLoss()

# set model into training mode (because of the dropout layer)
model.train()

# train model
GCN(
  (conv): GCNConv(7, 64)
  (linear): Linear(in_features=64, out_features=2, bias=True)
)
for e in range(50):
  for b in dataloader:
  
    pred = model(b.x, b.edge_index, b.batch)
    loss = loss_func(pred, b.y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
     
  if e % 10 ==0:
    print("Epoch: ", e ," Loss: ", loss.item(), "\n")
Epoch:  0  Loss:  0.6617004871368408 

Epoch:  10  Loss:  0.614981472492218 

Epoch:  20  Loss:  0.6161867380142212 

Epoch:  30  Loss:  0.5802667737007141 

Epoch:  40  Loss:  0.5124867558479309 

Make predictions:

preds = []
test = torch_geometric.loader.DataLoader(dataset, batch_size=64,shuffle=False)
model.eval()
GCN(
  (conv): GCNConv(7, 64)
  (linear): Linear(in_features=64, out_features=2, bias=True)
)
counter = 1
for b in test:
  preds.append( model(b.x, b.edge_index, b.batch) )
  
  
torch.concat(preds).sigmoid().data.cpu().numpy()[0:10]
array([[0.30760282, 0.64270777],
       [0.41212386, 0.551533  ],
       [0.4119514 , 0.5516798 ],
       [0.29887193, 0.650517  ],
       [0.48894534, 0.48584774],
       [0.4310807 , 0.5360305 ],
       [0.31375578, 0.63721913],
       [0.34597102, 0.6093393 ],
       [0.50279325, 0.4740774 ],
       [0.30924183, 0.6412629 ]], dtype=float32)