Read it, or Reddit? A GNN Approach to Predicting Relationships on Reddit | by Finn Dayton | Jun, 2023


By: Finn Dayton and Brandon Minsung Kang as part of the Stanford CS224W course project

Feel free to follow along with the associated Colab!

Imagine: you’re mindlessly zombie-scrolling through your Reddit home feed, yearning for a post that’s worth interacting with. Unfortunately, your feed is riddled with irrelevant posts from irrelevant subreddits. You don’t care about posts from r/mildlyinteresting… After all, the entire subreddit is just full of content that wasn’t interesting enough to make r/interesting! Tragically, that’s the sad reality of Reddit scrolling for many, with countless hours of human productivity lost every day to the depths of unengaging content.

Gif courtesy of Giphy

And that’s where Graph Machine Learning (GraphML) comes in. With the power of graph representations and Graph Neural Networks (GNNs), we propose that this problem can be framed as a GraphML task, in which we optimize Reddit home pages and feeds to display content from communities that a user is more likely to interact with, based on their previous activity.

So follow along as we explore how GNNs can leverage the rich network structure of Reddit and capture relationships between users, posts, and communities, in the hope of delivering a better browsing experience.

Table of Contents

We structure our post into five broad sections:

  1. Exploration of Our Dataset
  2. Explanation of Models
  3. Walkthrough of Approach
  4. Discussion of Results
  5. Conclusion

Dataset Structure

The Reddit dataset [1] is a graph dataset created from Reddit posts made in September 2014. The authors sampled 50 large communities to create a dataset containing 232,965 posts with an average degree of 492. Node labels are derived from the subreddit that a post belongs to. From this, the authors constructed a post-to-post graph that connects two posts if the same user commented on both. Additionally, each post has a 602-dimensional feature vector that describes the average embedding of the post title, the average embedding of all of the post’s comments, the post’s score, and the number of comments made on the post. The Reddit graph can be summarized as an undirected graph with no isolated nodes and no self-loops. Its nodes carry many different class labels, one per subreddit, and it has a single edge type that represents whether two posts share a commenter. The code below prints a summary of key statistics that can help in understanding what kind of information the dataset contains.

from torch_geometric.datasets import Reddit

# Load the Reddit dataset (downloaded on first use)
dataset = Reddit(root='data/Reddit')
# The dataset contains a single graph
data = dataset[0]
. . .
>>> Average node degree: 492
>>> Number of nodes: 232965
>>> Number of edges: 114615892
>>> Number of features per node: 602
>>> Number of features per edge: 0

Using the Dataset

The task we explore is link prediction, where the goal is to predict missing or future connections between nodes in a network. In the context of our dataset, link prediction means predicting which posts a user is most likely to comment on, based on the nodes (posts) they have previously commented on and the links that exist between those posts. Broadly speaking, we can build a model for this task by masking existing connections between nodes and training the model to correctly predict the hidden edges. More details to follow!
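The masking idea can be illustrated in a few lines. This is a minimal sketch, assuming each undirected edge is listed once as a `(u, v)` tuple; `mask_edges` and its `mask_ratio` parameter are hypothetical names, not part of our Colab or of PyTorch Geometric. The visible edges are what the GNN may pass messages over, and the hidden edges become the positive examples the model has to recover.

```python
import random


def mask_edges(edges, mask_ratio=0.2, seed=0):
    """Split undirected edges into visible (message-passing) and hidden (target) sets."""
    rng = random.Random(seed)        # seeded so the split is reproducible
    shuffled = list(edges)
    rng.shuffle(shuffled)
    num_hidden = int(len(shuffled) * mask_ratio)
    hidden = shuffled[:num_hidden]   # positive targets the model must predict
    visible = shuffled[num_hidden:]  # edges the GNN can still see
    return visible, hidden


edges = [(0, 1), (1, 2), (2, 3), (3, 4), (0, 4)]
visible, hidden = mask_edges(edges, mask_ratio=0.4)
print(f"visible: {visible}\nhidden: {hidden}")
```

Because the two sets are disjoint, the model can only score a hidden edge highly by generalizing from the structure and features it can see.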

Graphic showing the process of a general link prediction task. Courtesy of AnalyticsVidhya

Data Preprocessing

To make our dataset more manageable, we first perform several preprocessing steps. We import the Reddit dataset directly from torch_geometric, a library built on PyTorch specifically to help train GNNs, and then use its associated Data object. This avoids the work of downloading, re-uploading, and formatting the JSON files provided in the original dataset’s zip. See the code block above if you need a refresher on how to import and extract the Reddit dataset.

As mentioned before, the Reddit dataset is rich, with both a large number of nodes (232,965) and high connectivity between them (114,615,892 edges). This makes exploration and model building time-consuming and computationally expensive with the RAM and GPU sizes provided by Colab Pro. To address this, we selected a subgraph (a subset of the nodes and edges of the original graph) that improves computability while remaining representative of the original graph. We do this in the following steps:

  1. Find representative nodes. Here, we define “representative” nodes to be nodes whose degree equals the average node degree of the graph (492 edges).
  2. Randomly select a subset of representative nodes.
  3. Generate the K-hop subgraph around these representative nodes.
  4. Create a subgraph object from the subset of nodes and edges produced by the steps above.

The next code demonstrates our implementation of the above steps.

import random
from torch_geometric.data import Data
from torch_geometric.utils import degree, k_hop_subgraph

# Calculate the degree of each node in the graph
degrees = degree(data.edge_index[0], num_nodes=data.num_nodes)

# Find the node(s) with degree 492 (the average degree);
# view(-1) keeps a 1-D tensor even if only one node matches
nodes_with_degree_492 = (degrees == 492).nonzero(as_tuple=False).view(-1)

# Count the nodes with degree 492
n = nodes_with_degree_492.shape[0]

# Number of nodes to sample; change as desired
number_of_samples = 10

# Randomly sampled nodes with the target degree
nodes = []

for i in range(number_of_samples):
    random_index = random.randint(0, n - 1)
    node_id = nodes_with_degree_492[random_index].item()
    nodes.append(node_id)

# Create a representative subgraph: the 1-hop neighborhood of the sampled "average" nodes
subset, edge_index, mapping, edge_mask = k_hop_subgraph(nodes, 1, edge_index=data.edge_index, relabel_nodes=True)

# Our subgraph - voila!
sub_graph = Data(x=data.x[subset], edge_index=edge_index)

Once these steps are done, we can convert the final subgraph into a NetworkX Graph object, which lets us create interesting visualizations such as the one below.

A visualization of post connectivity in an example Reddit subgraph, where different colors represent different subreddits

Note: Each time we re-run the graph generation code, the initial nodes are sampled randomly. This means you’ll get a slightly different-looking graph, but most of its characteristics (e.g., graph size, connectedness, edge direction, subreddit variety) will stay roughly the same.

Splitting our Dataset

One final step! We partition sub_graph into training, validation, and test graphs. This is an important step in model development because it allows us to evaluate our models on data they have not seen during training.

We perform an 80/10/10 train/validation/test split on the edges of our subgraph after shuffling the edge list to randomize it. After obtaining the respective edge sets, we add the corresponding nodes, which ensures that none of the three resulting graphs has stranded (isolated) nodes. This gives us three NetworkX Graph objects, which we then convert back to PyTorch Geometric objects for training and testing our GNN model.
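The split can be sketched in plain Python. This is a minimal illustration under stated assumptions, not our Colab code (which worked with NetworkX graphs): `split_edges` is a hypothetical helper, edges are undirected `(u, v)` tuples listed once, and each split’s node set is simply the set of endpoints of its edges. Because nodes enter only through edges, no split contains an isolated node; a node can, however, appear in more than one split.

```python
import random


def split_edges(edges, train_frac=0.8, val_frac=0.1, seed=0):
    """Shuffle edges, split them 80/10/10, and attach each split's endpoint nodes."""
    rng = random.Random(seed)
    shuffled = list(edges)
    rng.shuffle(shuffled)  # randomize the order before slicing

    n_train = int(len(shuffled) * train_frac)
    n_val = int(len(shuffled) * val_frac)
    edge_splits = {
        "train": shuffled[:n_train],
        "val": shuffled[n_train:n_train + n_val],
        "test": shuffled[n_train + n_val:],  # the remainder goes to test
    }

    graphs = {}
    for name, edge_list in edge_splits.items():
        # Add only the nodes touched by this split's edges, so none is isolated
        nodes = sorted({node for edge in edge_list for node in edge})
        graphs[name] = (nodes, edge_list)
    return graphs


toy_edges = [(i, i + 1) for i in range(20)]
splits = split_edges(toy_edges)
print({name: len(edge_list) for name, (_, edge_list) in splits.items()})
```

With 20 edges this prints `{'train': 16, 'val': 2, 'test': 2}`.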

Now that we’ve defined our problem space, let’s explore how Graph Neural Networks can help us predict which Reddit posts you might like! First, let’s go over some basic definitions that will help you understand the concepts we introduce:

What Is a Graph Neural Network (GNN)?

A Graph Neural Network (GNN) is a type of neural network designed to work with graph-structured data, such as social networks or molecular structures. GNNs learn to represent the nodes and edges of a graph as vectors, which allows them to perform machine learning tasks on graph data. In other words, a GNN learns an embedding for each node based on both its features and its neighborhood structure in the graph.

GNNs are a general class of model with several implementations, including Graph Attention Networks (GAT) [2] and GraphSAGE [3].


In our application, GraphSAGE is one of the key components of our GNN architecture. For a given central node, a GraphSAGE layer applies the following message-passing update rule:

Equation to Update a Node’s Embedding, Courtesy of GraphSAGE

Here, W_1 and W_2 are learnable weight matrices that apply linear transformations to the central node’s embedding and to the aggregation output, respectively. The aggregation runs over the central node’s neighbors. In our application, we use mean aggregation for simplicity.
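Written out for a node v at layer l, and assuming the form implied by the description above (W_1 acting on the central node, W_2 on the aggregated neighbors), the update and the mean aggregator read as follows. Any nonlinearity or normalization applied after this step is left out here.

```latex
\begin{aligned}
h_v^{(l)} &= W_1\, h_v^{(l-1)} + W_2 \cdot \operatorname{AGG}\big(\{\, h_u^{(l-1)} : u \in N(v) \,\}\big) \\
\operatorname{AGG}_{\text{mean}}\big(\{\, h_u^{(l-1)} : u \in N(v) \,\}\big) &= \frac{1}{|N(v)|} \sum_{u \in N(v)} h_u^{(l-1)}
\end{aligned}
```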

Example Mean Aggregator Function

Now wait, you might be wondering: what do message passing and message aggregation even mean? Fair question. We provide basic definitions for both below:

  • Message Passing: the passing of information between nodes that are connected by edges. In a GNN, each layer performs message passing, which allows a node to update its embedding based on its neighbors.
  • Message Aggregation: the step applied once a node has received messages from its neighbors. Aggregation determines how the messages are combined; examples include sum, mean, and max aggregation.
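To make the two definitions concrete, here is a small plain-Python sketch of one GraphSAGE-style layer with mean aggregation, following the update rule described above. `sage_layer` and the toy identity weights are hypothetical; the real model uses learned weight matrices inside PyTorch Geometric layers, and any nonlinearity between layers is omitted here for clarity.

```python
def sage_layer(h, neighbors, W1, W2):
    """One GraphSAGE-style update with mean aggregation (no nonlinearity).

    h: list of node embeddings (each a list of floats)
    neighbors: dict mapping node id -> list of neighbor ids
    W1, W2: weight matrices given as lists of rows (out_dim x in_dim)
    """
    def matvec(W, x):
        return [sum(w * xi for w, xi in zip(row, x)) for row in W]

    dim = len(h[0])
    new_h = []
    for v in range(len(h)):
        nbrs = neighbors.get(v, [])
        # Message passing + mean aggregation: average the neighbors' embeddings
        if nbrs:
            agg = [sum(h[u][i] for u in nbrs) / len(nbrs) for i in range(dim)]
        else:
            agg = [0.0] * dim
        # Update: W1 transforms the central node, W2 the aggregate; the results are summed
        new_h.append([s + n for s, n in zip(matvec(W1, h[v]), matvec(W2, agg))])
    return new_h


h = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
neighbors = {0: [1, 2], 1: [0], 2: [0]}
identity = [[1.0, 0.0], [0.0, 1.0]]
print(sage_layer(h, neighbors, identity, identity))
```

With identity weights, node 0 averages its neighbors to [0.5, 1.0] and adds its own embedding, giving [1.5, 1.0]; the full output is [[1.5, 1.0], [1.0, 1.0], [2.0, 1.0]].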

A brief graphic demonstrating message passing and aggregation is shown below:

Example of Message Passing and Aggregation. Courtesy of Giuseppe Futia

Though the math may look convoluted to the uninitiated, fear not! Simply put, the entire point is to learn an embedding (vector) for a given node.

LinkPredictor Class

Given the embeddings of two nodes A and B, you might ask: how do we predict 0 (not connected) or 1 (connected)? This question is key to our link prediction task, since it amounts to estimating the probability that an edge exists between the two nodes.

For our LinkPredictor, we decided to use a simple stack of two linear layers with ReLU and dropout between them. Given two node embeddings output by the GNNStack, the LinkPredictor first combines them by either element-wise multiplication or addition (up to you!). This approach was inspired by the 2018 paper “Link Prediction Based on Graph Neural Networks” by Zhang and Chen [4].

The complete model, as depicted above, stacks several GraphSAGE layers back to back to learn deeper embeddings, with the LinkPredictor class tacked on at the end. Our implementation can be found below:

import torch.nn as nn

class LinkPredictor(nn.Module):
    def __init__(self, input_dim, hidden_dim, dropout, method="multiply"):
        super(LinkPredictor, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.method = method
        # Two linear layers with ReLU and dropout in between; the final sigmoid maps the score into (0, 1)
        self.model = nn.Sequential(nn.Linear(self.input_dim, self.hidden_dim),
                                   nn.ReLU(),
                                   nn.Dropout(p=dropout),
                                   nn.Linear(self.hidden_dim, 1),
                                   nn.Sigmoid())

    def forward(self, a, b):
        # Combine the two node embeddings element-wise, then score the pair
        if self.method == "addition":
            return self.model(a + b)
        elif self.method == "multiply":
            return self.model(a * b)
        raise ValueError(f"Unknown method: {self.method}")
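To see what one forward pass through the LinkPredictor computes, here is a dependency-free sketch for a single node pair at evaluation time, when dropout is inactive. `link_score` and the hand-picked weights are purely illustrative, not trained values; the sketch follows the structure described above: combine the embeddings, then apply Linear, ReLU, Linear, and a final sigmoid.

```python
import math


def link_score(a, b, W_hidden, b_hidden, w_out, b_out, method="multiply"):
    """Score one node pair: combine embeddings, then Linear -> ReLU -> Linear -> sigmoid."""
    if method == "multiply":
        z = [x * y for x, y in zip(a, b)]  # element-wise (Hadamard) product
    elif method == "addition":
        z = [x + y for x, y in zip(a, b)]  # element-wise sum
    else:
        raise ValueError(f"unknown method: {method}")

    # First linear layer + ReLU (dropout would follow during training; it is a no-op at eval)
    hidden = [max(0.0, sum(w * zi for w, zi in zip(row, z)) + bias)
              for row, bias in zip(W_hidden, b_hidden)]
    # Second linear layer down to a single logit, then sigmoid into (0, 1)
    logit = sum(w * h for w, h in zip(w_out, hidden)) + b_out
    return 1.0 / (1.0 + math.exp(-logit))


# Illustrative (untrained) weights and embeddings
a, b = [1.0, 2.0], [3.0, -1.0]
W_hidden, b_hidden = [[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0]
w_out, b_out = [1.0, 1.0], -3.0
print(link_score(a, b, W_hidden, b_hidden, w_out, b_out, method="multiply"))
print(link_score(a, b, W_hidden, b_hidden, w_out, b_out, method="addition"))
```

With these toy weights, multiplication yields a score of 0.5 and addition about 0.88, which shows that the choice of combination alone can change the prediction.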

After splitting our subgraph into three subgraphs (train, validation, and test), there is one more important nuance to call attention to. We must generate a set of “negative” edges (edges that do not exist in our graphs) equal in number to the positive edges that do exist. This is critical because we want a model that can distinguish edges that exist from edges that do not; without negative samples, the model never sees what a non-edge looks like. We provide a glimpse of this process below:

import torch
import torch.nn as nn
from torch_geometric.utils import negative_sampling

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# args contains dropout and num_layers; GraphSage and GNNStack are defined in the Colab
class model(nn.Module):
    def __init__(self, args, input_dim=602, hidden_dim=256, output_dim=256, layer=GraphSage, link_predictor_class=LinkPredictor, stacker=GNNStack):
        super(model, self).__init__()
        # The stacked GraphSAGE layers that produce node embeddings
        self.stacked_layers_model = stacker(layer, input_dim, hidden_dim, output_dim, args).to(device)
        self.link_predictor = link_predictor_class(output_dim, hidden_dim, dropout=args['dropout']).to(device)

    def forward(self, x, edge_index):
        node_embeddings = self.stacked_layers_model(x, edge_index)

        num_nodes, embedding_dim = node_embeddings.size()
        num_edges = edge_index.size(1)

        # Positive edges: the edges that exist in the graph
        pos_edge_indices = edge_index

        # Negative edges: as many sampled non-edges as there are positive edges
        neg_edge_indices = negative_sampling(edge_index, num_nodes, num_neg_samples=int(num_edges), method='dense')

        # Score both sets of node pairs with the link predictor
        link_predictions_pos = self.link_predictor(node_embeddings[pos_edge_indices[0]], node_embeddings[pos_edge_indices[1]])
        link_predictions_neg = self.link_predictor(node_embeddings[neg_edge_indices[0]], node_embeddings[neg_edge_indices[1]])

        return link_predictions_pos, link_predictions_neg
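The `negative_sampling` call above comes from PyTorch Geometric. For intuition, here is a simplified rejection-sampling sketch of the same idea in plain Python; it is not PyG’s implementation, and `sample_negative_edges` is a hypothetical name. It draws random node pairs and keeps only those that are neither self-loops nor existing edges, until it has as many negatives as requested.

```python
import random


def sample_negative_edges(edge_index, num_nodes, num_samples, seed=0):
    """Return `num_samples` distinct (src, dst) pairs that are neither edges nor self-loops."""
    existing = set(zip(*edge_index))
    num_existing = sum(1 for u, v in existing if u != v)
    if num_samples > num_nodes * (num_nodes - 1) - num_existing:
        raise ValueError("not enough non-edges to sample from")

    rng = random.Random(seed)
    negatives = set()
    while len(negatives) < num_samples:
        u, v = rng.randrange(num_nodes), rng.randrange(num_nodes)
        if u != v and (u, v) not in existing:  # reject self-loops and real edges
            negatives.add((u, v))
    pairs = sorted(negatives)
    # Return in the same two-row layout as edge_index
    return [u for u, _ in pairs], [v for _, v in pairs]


# Triangle 0-1-2 plus an isolated node 3; ask for one negative per positive edge
toy_edge_index = ([0, 1, 1, 2, 0, 2], [1, 0, 2, 1, 2, 0])
neg_src, neg_dst = sample_negative_edges(toy_edge_index, num_nodes=4, num_samples=6)
print(list(zip(neg_src, neg_dst)))
```

In this toy graph the only non-edges are the six directed pairs involving node 3, so asking for six negatives returns exactly those.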

Once we have our negative samples, we’re ready to go! From here, we run the train graph through the entire model, get the predictions, and back-propagate the loss through the model. Every 5 epochs, we run the validation and test graphs through the model to see how well it generalizes to unseen graphs. The code below does this; see the linked Colab for definitions of the helper functions.

import matplotlib.pyplot as plt
import torch
import torch.optim as optim
from torch.utils.data import TensorDataset
from tqdm import trange

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

epochs = args['epochs']
epochs_bar = trange(1, epochs + 1, desc='Loss n/a')

# Extract the edge indices and node feature matrices
x_train = pyg_train_graph.x
edge_index_train = pyg_train_graph.edge_index

x_val = pyg_val_graph.x
edge_index_val = pyg_val_graph.edge_index

x_test = pyg_test_graph.x
edge_index_test = pyg_test_graph.edge_index

# Move everything to the GPU (if available)
x_train = x_train.to(device)
x_val = x_val.to(device)
x_test = x_test.to(device)
edge_index_train = edge_index_train.to(device)
edge_index_val = edge_index_val.to(device)
edge_index_test = edge_index_test.to(device)

# Instantiate the model and set it to training mode
my_model = model(args)
my_model.train()

# Get the optimizer
optimizer = optim.Adam(my_model.parameters(), lr=.005)

# Wrap the edge_index tensors in PyTorch Datasets so we can iterate over them in train()
edge_index_train = TensorDataset(edge_index_train)
edge_index_val = TensorDataset(edge_index_val)
edge_index_test = TensorDataset(edge_index_test)

# Train the model
losses = []
valid_hits_validation_list = []
valid_hits_test_list = []
for epoch in epochs_bar:
    loss = train(x_train, edge_index_train, args, my_model, optimizer)
    losses.append(loss)
    epochs_bar.set_description(f'Epoch {epoch}, Loss {loss:0.4f}')
    if epoch % 5 == 0:
        valid_hits_validation = test(my_model, x_val, edge_index_val, args, k=20)
        valid_hits_test = test(my_model, x_test, edge_index_test, args, k=20)
        print(f'Epoch: {epoch}, Train loss: {loss}, Validation Hits@20: {valid_hits_validation}, Test Hits@20: {valid_hits_test}')
        valid_hits_validation_list.append(valid_hits_validation)
        valid_hits_test_list.append(valid_hits_test)
    else:
        # Carry the last evaluation forward so the lists have one entry per epoch (for plotting)
        valid_hits_validation_list.append(valid_hits_validation_list[-1] if valid_hits_validation_list else 0)
        valid_hits_test_list.append(valid_hits_test_list[-1] if valid_hits_test_list else 0)
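The `train` helper lives in the Colab and is not reproduced here. As an assumption for illustration only: since the LinkPredictor ends in a sigmoid, a natural objective for this setup is binary cross-entropy, which pushes positive-edge scores toward 1 and negative-edge scores toward 0. `link_prediction_loss` is a hypothetical plain-Python version; in practice the same quantity would be computed on PyTorch tensors so it can be back-propagated.

```python
import math


def link_prediction_loss(pos_preds, neg_preds, eps=1e-15):
    """Binary cross-entropy over positive (label 1) and negative (label 0) edge scores."""
    # Existing edges should score near 1: penalize -log(p)
    pos_loss = -sum(math.log(p + eps) for p in pos_preds) / len(pos_preds)
    # Sampled non-edges should score near 0: penalize -log(1 - p)
    neg_loss = -sum(math.log(1 - p + eps) for p in neg_preds) / len(neg_preds)
    return pos_loss + neg_loss


print(link_prediction_loss([0.9, 0.8], [0.1, 0.3]))  # confident, mostly correct scores
print(link_prediction_loss([0.5, 0.5], [0.5, 0.5]))  # uninformative scores: 2 * ln(2)
```

The small `eps` keeps the logarithm finite when a score hits exactly 0 or 1.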

We trained our model for 100 epochs on the train graph, running the validation and test graphs through the model every 5 epochs.

Our evaluation metric on the validation and test graphs is Hits@K, where K is a variable set to 20 in our code. Here, we define Hits@K as the number of the top K recommendations that are actually connected.

In other words, the sigmoid at the end of the LinkPredictor class outputs a number between 0.0 and 1.0 for each node pair. We sort these predictions in descending order. Hits@20 is the number of node pairs among the top 20 (those with the highest predicted probability of a link) that actually have a link between them. Therefore, the best score is 20 and the worst is 0. We can also express this as a decimal between 0.0 and 1.0 by dividing the number of “hits” by K. The code for calculating Hits@K is shown below:

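import torch

def calculate_hits_at_k(pos_preds, neg_preds, k=20):
    # Label positive predictions with 1 and negative predictions with 0
    tensor_ones = torch.ones_like(pos_preds)
    tensor_zeros = torch.zeros_like(neg_preds)

    # Append the label as a second column: each row is (score, label)
    pos_preds_labeled = torch.cat((pos_preds, tensor_ones), dim=1)
    neg_preds_labeled = torch.cat((neg_preds, tensor_zeros), dim=1)

    # Stack all predictions and sort them by score, highest first
    combined_preds = torch.cat((pos_preds_labeled, neg_preds_labeled), dim=0)
    _, indices = torch.sort(combined_preds[:, 0], descending=True)
    sorted_combined_preds = combined_preds[indices]

    # Count how many of the top-k predictions are real edges
    hit_indices = torch.arange(min(k, sorted_combined_preds.size(0)))
    hits = torch.sum(sorted_combined_preds[hit_indices, 1])
    return hits.item()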
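Here is the same Hits@K computation in plain Python, with a small worked example. `hits_at_k` is a hypothetical name; it returns the count of real edges among the K highest-scoring pairs, which you can divide by K for the fractional form. When scores tie, Python’s stable sort keeps positives ahead of negatives, whereas a tensor sort may order ties differently, so the two versions can disagree on ties.

```python
def hits_at_k(pos_preds, neg_preds, k=20):
    """Count real edges among the k highest-scoring candidate pairs."""
    # Label each score: 1 for a real (positive) edge, 0 for a sampled negative
    labeled = [(score, 1) for score in pos_preds] + [(score, 0) for score in neg_preds]
    # Sort by score, highest first (stable, so ties keep positives first)
    labeled.sort(key=lambda pair: pair[0], reverse=True)
    return sum(label for _, label in labeled[:k])


pos_preds = [0.9, 0.4, 0.8]
neg_preds = [0.7, 0.2, 0.95]
hits = hits_at_k(pos_preds, neg_preds, k=3)
print(hits, hits / 3)
```

The top three scores are 0.95 (a negative) followed by 0.9 and 0.8 (both positives), so Hits@3 is 2, or about 0.67 as a fraction.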

Overall, our model performed well. Results will vary with the value of K in Hits@K and with the other hyperparameters. As you can see, the model’s training loss decreased monotonically across epochs, while validation converged to 100% Hits@K after only 60 epochs. With more time, we would consider improvements such as gradient clipping, learning rate decay, and sampling larger graphs to train on.

If you’ve stuck around until this point, thank you! We hope you’ve learned as much as we did while building this project. Our work only scratches the surface of what GNNs can do and where they can be applied. But before we wrap things up, let’s summarize everything we’ve learned today.

  1. We learned about the NetworkX and PyTorch Geometric libraries and explored the tradeoffs of representing graphs in one library versus the other.
  2. We explored how to generate representative subgraphs to improve the computability and explorability of large graphs.
  3. We learned how to build a GraphSAGE GNN and used the LinkPredictor class to build an effective model for link prediction.

Overall, we’re excited because we believe these insights and explorations can be extended beyond Reddit to any kind of media that can be represented as a graph. We hope you’re pleased with what you’ve accomplished and that this marks the beginning of further explorations of Graph Neural Networks.

[1] Reddit Dataset. PyTorch Geometric.

[2] Petar Veličković, Guillem Cucurull, Arantxa Casanova, Adriana Romero, Pietro Liò, and Yoshua Bengio. Graph attention networks, 2018.

[3] William L. Hamilton, Rex Ying, and Jure Leskovec. Inductive representation learning on large graphs, 2018.

[4] Muhan Zhang and Yixin Chen. Link prediction based on graph neural networks, 2018.
