With the advancement of Internet technology, recommendations have become an integral part of our online experience. Whether we are watching videos on YouTube, enjoying movies on Netflix, receiving friend suggestions on Facebook, seeing online advertisements, or browsing Amazon product suggestions, all of these scenarios use recommendation systems at the backend to suggest items that could interest us. The main objective of a recommendation system is to suggest suitable items to a user based on his/her past interactions with items and the user's side information (e.g. age, gender, demographics). In other words, recommender systems (RS) are used to find the small number of relevant items that match a user's personalized interests. Throughout this post, I am assuming a recommender system with no cold-start problem.
On the other hand, Graph Neural Network (GNN) based methods have shown great success in tackling recommendation problems when compared to traditional recommendation techniques like collaborative filtering (CF). There are several use cases in industry where GNNs are used to solve recommendation problems. For example, with Food Discovery at Uber Eats, Uber uses the power of GNNs to suggest to its users the dishes, restaurants, and cuisines they might like next. To make these recommendations, Uber Eats uses the GraphSAGE algorithm because of its inductive nature and its ability to scale up to billions of nodes. Another interesting application is PinSage (a random-walk Graph Convolutional Network capable of learning embeddings for nodes in web-scale graphs containing billions of objects), an algorithm developed by Pinterest to perform visual recommendations.
Therefore, in this blogpost, we will together build a complete movie recommendation application using ArangoDB (an open-source, native multi-model graph database) and PyTorch Geometric (a library built upon PyTorch to easily write and train custom GNNs). In graph machine learning (or with graph neural networks) we mainly solve three types of tasks, i.e. node classification, link prediction, and graph classification. In this blogpost, we are going to tackle the challenge of building a movie recommendation application by transforming it into a link prediction task. The aim of link prediction is to predict whether an edge exists between two given nodes. First, we are going to build a bipartite graph of user and movie nodes, where an edge between a user and a movie represents the rating (between 1 and 5) that the user has given to that movie. Then, our goal is to predict missing links between a user and the movies he/she has not watched yet. The missing links are computed with graph neural networks by predicting the ratings for all the unseen movies. At the end, only those movies are recommended to a user for which the predicted rating is equal to 5. The figure below depicts the bipartite graph of user and movie nodes:
We are going to use The Movies Dataset from Kaggle, which contains the metadata for all 45,000 movies listed in the Full MovieLens Dataset. With the help of this metadata we are going to prepare features for our movie nodes, which can then be used as input to our GNNs. The focus of this blogpost revolves around the following key content:
If you are running this notebook on Google Colab, please make sure to enable hardware acceleration using either a GPU or a TPU. If it is run with only a CPU enabled, training the GNNs will take an incredibly long time! Hardware acceleration can be enabled by navigating to Runtime -> Change Runtime Type. This will present you with a popup where you can select an appropriate Hardware Accelerator.
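A quick way to confirm that PyTorch can actually see the accelerator (a minimal check, assuming torch is already available in the Colab runtime) is:

import torch

# True when a GPU runtime is active; otherwise training will silently fall back to the CPU
print(torch.cuda.is_available())
# name of the attached device, if any (purely informational)
print(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU only")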
%%capture
!git clone -b oasis_connector --single-branch https://github.com/arangodb/interactive_tutorials.git
!git clone -b movie-data-source --single-branch https://github.com/arangodb/interactive_tutorials.git movie_data_source
!rsync -av interactive_tutorials/ ./ --exclude=.git
!chmod -R 755 ./tools
!pip install pyarango
!pip install "python-arango>=5.0"
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu113.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+cu113.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
!pip install sentence-transformers
We are going to use a sampled version of The Movies Dataset. This dataset mainly contains three csv files: movies_metadata.csv, links_small.csv, and ratings_small.csv.
# unzip movies dataset zip file
!unzip ./movie_data_source/sampled_movie_dataset.zip
Archive:  ./movie_data_source/sampled_movie_dataset.zip
   creating: sampled_movie_dataset/
  inflating: sampled_movie_dataset/links_small.csv
  inflating: __MACOSX/sampled_movie_dataset/._links_small.csv
  inflating: sampled_movie_dataset/.DS_Store
  inflating: __MACOSX/sampled_movie_dataset/._.DS_Store
  inflating: sampled_movie_dataset/movies_metadata.csv
  inflating: __MACOSX/sampled_movie_dataset/._movies_metadata.csv
  inflating: sampled_movie_dataset/ratings_small.csv
  inflating: __MACOSX/sampled_movie_dataset/._ratings_small.csv
import pandas as pd
from arango import ArangoClient
from tqdm import tqdm
import numpy as np
import itertools
import requests
import sys
import yaml
import oasis
import torch
import torch.nn.functional as F
from torch.nn import Linear
import torch_geometric.transforms as T
from torch_geometric.nn import SAGEConv, to_hetero
from torch_geometric.transforms import RandomLinkSplit, ToUndirected
from torch_geometric.data import HeteroData
from sentence_transformers import SentenceTransformer
In this section, we will read the data from multiple csv files and construct a graph out of it. The graph we are going to construct is a bipartite graph, with user nodes on one side and movie nodes on the other. In addition, the user and movie nodes of the graph will be accompanied by their corresponding feature attributes. Finally, we are going to store this generated graph inside ArangoDB.
metadata_path = './sampled_movie_dataset/movies_metadata.csv'
df = pd.read_csv(metadata_path)
/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py:2882: DtypeWarning: Columns (10) have mixed types.Specify dtype option on import or set low_memory=False. exec(code_obj, self.user_global_ns, self.user_ns)
df.head()
adult | belongs_to_collection | budget | genres | homepage | id | imdb_id | original_language | original_title | overview | ... | release_date | revenue | runtime | spoken_languages | status | tagline | title | video | vote_average | vote_count | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | False | {'id': 10194, 'name': 'Toy Story Collection', ... | 30000000 | [{'id': 16, 'name': 'Animation'}, {'id': 35, '... | http://toystory.disney.com/toy-story | 862 | tt0114709 | en | Toy Story | Led by Woody, Andy's toys live happily in his ... | ... | 1995-10-30 | 373554033.0 | 81.0 | [{'iso_639_1': 'en', 'name': 'English'}] | Released | NaN | Toy Story | False | 7.7 | 5415.0 |
1 | False | NaN | 65000000 | [{'id': 12, 'name': 'Adventure'}, {'id': 14, '... | NaN | 8844 | tt0113497 | en | Jumanji | When siblings Judy and Peter discover an encha... | ... | 1995-12-15 | 262797249.0 | 104.0 | [{'iso_639_1': 'en', 'name': 'English'}, {'iso... | Released | Roll the dice and unleash the excitement! | Jumanji | False | 6.9 | 2413.0 |
2 | False | {'id': 119050, 'name': 'Grumpy Old Men Collect... | 0 | [{'id': 10749, 'name': 'Romance'}, {'id': 35, ... | NaN | 15602 | tt0113228 | en | Grumpier Old Men | A family wedding reignites the ancient feud be... | ... | 1995-12-22 | 0.0 | 101.0 | [{'iso_639_1': 'en', 'name': 'English'}] | Released | Still Yelling. Still Fighting. Still Ready for... | Grumpier Old Men | False | 6.5 | 92.0 |
3 | False | NaN | 16000000 | [{'id': 35, 'name': 'Comedy'}, {'id': 18, 'nam... | NaN | 31357 | tt0114885 | en | Waiting to Exhale | Cheated on, mistreated and stepped on, the wom... | ... | 1995-12-22 | 81452156.0 | 127.0 | [{'iso_639_1': 'en', 'name': 'English'}] | Released | Friends are the people who let you be yourself... | Waiting to Exhale | False | 6.1 | 34.0 |
4 | False | {'id': 96871, 'name': 'Father of the Bride Col... | 0 | [{'id': 35, 'name': 'Comedy'}] | NaN | 11862 | tt0113041 | en | Father of the Bride Part II | Just when George Banks has recovered from his ... | ... | 1995-02-10 | 76578911.0 | 106.0 | [{'iso_639_1': 'en', 'name': 'English'}] | Released | Just When His World Is Back To Normal... He's ... | Father of the Bride Part II | False | 5.7 | 173.0 |
5 rows × 24 columns
df.columns
Index(['adult', 'belongs_to_collection', 'budget', 'genres', 'homepage', 'id', 'imdb_id', 'original_language', 'original_title', 'overview', 'popularity', 'poster_path', 'production_companies', 'production_countries', 'release_date', 'revenue', 'runtime', 'spoken_languages', 'status', 'tagline', 'title', 'video', 'vote_average', 'vote_count'], dtype='object')
# on these rows metadata information is missing
df = df.drop([19730, 29503, 35587])
# sampled from links.csv file
links_small = pd.read_csv('./sampled_movie_dataset/links_small.csv')
links_small.head()
movieId | imdbId | tmdbId | |
---|---|---|---|
0 | 1 | 114709 | 862.0 |
1 | 2 | 113497 | 8844.0 |
2 | 3 | 113228 | 15602.0 |
3 | 4 | 114885 | 31357.0 |
4 | 5 | 113041 | 11862.0 |
# selecting the tmdbId column from the links_small file
links_small = links_small[links_small['tmdbId'].notnull()]['tmdbId'].astype('int')
df['id'] = df['id'].astype('int')
sampled_md = df[df['id'].isin(links_small)]
sampled_md.shape
(9099, 24)
sampled_md['tagline'] = sampled_md['tagline'].fillna('')
sampled_md['description'] = sampled_md['overview'] + sampled_md['tagline']
sampled_md['description'] = sampled_md['description'].fillna('')
sampled_md = sampled_md.reset_index()
sampled_md.head()
index | adult | belongs_to_collection | budget | genres | homepage | id | imdb_id | original_language | original_title | ... | revenue | runtime | spoken_languages | status | tagline | title | video | vote_average | vote_count | description | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0 | False | {'id': 10194, 'name': 'Toy Story Collection', ... | 30000000 | [{'id': 16, 'name': 'Animation'}, {'id': 35, '... | http://toystory.disney.com/toy-story | 862 | tt0114709 | en | Toy Story | ... | 373554033.0 | 81.0 | [{'iso_639_1': 'en', 'name': 'English'}] | Released | Toy Story | False | 7.7 | 5415.0 | Led by Woody, Andy's toys live happily in his ... | |
1 | 1 | False | NaN | 65000000 | [{'id': 12, 'name': 'Adventure'}, {'id': 14, '... | NaN | 8844 | tt0113497 | en | Jumanji | ... | 262797249.0 | 104.0 | [{'iso_639_1': 'en', 'name': 'English'}, {'iso... | Released | Roll the dice and unleash the excitement! | Jumanji | False | 6.9 | 2413.0 | When siblings Judy and Peter discover an encha... |
2 | 2 | False | {'id': 119050, 'name': 'Grumpy Old Men Collect... | 0 | [{'id': 10749, 'name': 'Romance'}, {'id': 35, ... | NaN | 15602 | tt0113228 | en | Grumpier Old Men | ... | 0.0 | 101.0 | [{'iso_639_1': 'en', 'name': 'English'}] | Released | Still Yelling. Still Fighting. Still Ready for... | Grumpier Old Men | False | 6.5 | 92.0 | A family wedding reignites the ancient feud be... |
3 | 3 | False | NaN | 16000000 | [{'id': 35, 'name': 'Comedy'}, {'id': 18, 'nam... | NaN | 31357 | tt0114885 | en | Waiting to Exhale | ... | 81452156.0 | 127.0 | [{'iso_639_1': 'en', 'name': 'English'}] | Released | Friends are the people who let you be yourself... | Waiting to Exhale | False | 6.1 | 34.0 | Cheated on, mistreated and stepped on, the wom... |
4 | 4 | False | {'id': 96871, 'name': 'Father of the Bride Col... | 0 | [{'id': 35, 'name': 'Comedy'}] | NaN | 11862 | tt0113041 | en | Father of the Bride Part II | ... | 76578911.0 | 106.0 | [{'iso_639_1': 'en', 'name': 'English'}] | Released | Just When His World Is Back To Normal... He's ... | Father of the Bride Part II | False | 5.7 | 173.0 | Just when George Banks has recovered from his ... |
5 rows × 26 columns
indices = pd.Series(sampled_md.index, index=sampled_md['title'])
ind_gen = pd.Series(sampled_md.index, index=sampled_md['genres'])
We are going to use the ratings file to construct the bipartite graph. This file contains movies rated by different users on a scale of 1-5, where a rating of 1 implies a very bad movie and 5 corresponds to a very good movie.
ratings_path = './sampled_movie_dataset/ratings_small.csv'
ratings_df = pd.read_csv(ratings_path)
ratings_df.head()
userId | movieId | rating | timestamp | |
---|---|---|---|---|
0 | 1 | 31 | 2.5 | 1260759144 |
1 | 1 | 1029 | 3.0 | 1260759179 |
2 | 1 | 1061 | 3.0 | 1260759182 |
3 | 1 | 1129 | 2.0 | 1260759185 |
4 | 1 | 1172 | 4.0 | 1260759205 |
# maps original user/movie ids to consecutive integer ids
def node_mappings(path, index_col):
df = pd.read_csv(path, index_col=index_col)
mapping = {index: i for i, index in enumerate(df.index.unique())}
return mapping
user_mapping = node_mappings(ratings_path, index_col='userId')
movie_mapping = node_mappings(ratings_path, index_col='movieId')
m_id = ratings_df['movieId'].tolist()
# all unique movie_ids present inside ratings file
#m_id = list(set(m_id))
m_id = list(dict.fromkeys(m_id))
len(m_id)
9066
def convert_int(x):
try:
return int(x)
except (ValueError, TypeError):
return np.nan
id_map = pd.read_csv('./sampled_movie_dataset/links_small.csv')[['movieId', 'tmdbId']]
id_map['tmdbId'] = id_map['tmdbId'].apply(convert_int)
id_map.columns = ['movieId', 'id']
id_map.head()
movieId | id | |
---|---|---|
0 | 1 | 862.0 |
1 | 2 | 8844.0 |
2 | 3 | 15602.0 |
3 | 4 | 31357.0 |
4 | 5 | 11862.0 |
# tmdbId (in links_small) is the same as id in sampled_md
id_map = id_map.merge(sampled_md[['title', 'id']], on='id').set_index('title')
indices_map = id_map.set_index('id')
In this section, we first connect to a temporary ArangoDB instance in the cloud using Oasis (ArangoDB's managed cloud service). Once connected, we can load the movies metadata (movie title, genres, description, etc.) and the ratings information (ratings given by users to movies) into ArangoDB collections.
# get temporary credentials for ArangoDB on cloud
login = oasis.getTempCredentials(tutorialName="MovieRecommendations", credentialProvider="https://tutorials.arangodb.cloud:8529/_db/_system/tutorialDB/tutorialDB")
# Connect to the temp database
# Please note that we use the python-arango driver as it has better support for ArangoSearch
movie_rec_db = oasis.connect_python_arango(login)
Requesting new temp credentials. Temp database ready to use.
Printing authentication credentials
# url to access the ArangoDB Web UI
print("https://"+login["hostname"]+":"+str(login["port"]))
print("Username: " + login["username"])
print("Password: " + login["password"])
print("Database: " + login["dbName"])
https://tutorials.arangodb.cloud:8529 Username: TUTgvm6ryraubpqrj6tvdvuh Password: TUTbiloykvn0m6zdz9fnnrjp Database: TUTcvypyquac6k972l58jgnc
# print 5 mappings of movieIds
list(movie_mapping.items())[:5]
[(31, 0), (1029, 1), (1061, 2), (1129, 3), (1172, 4)]
print("%d number of unique movie ids" %len(m_id))
9066 number of unique movie ids
# remove ids which don't have metadata information
def remove_movies(m_id):
no_metadata = []
for idx in range(len(m_id)):
tmdb_id = id_map.loc[id_map['movieId'] == m_id[idx]]
if tmdb_id.size == 0:
no_metadata.append(m_id[idx])
#print('No Meta data information at:', m_id[idx])
return no_metadata
no_metadata = remove_movies(m_id)
## remove ids which don't have metadata information
for element in no_metadata:
if element in m_id:
print("ids with no metadata information:",element)
m_id.remove(element)
ids with no metadata information: 720 ids with no metadata information: 7502 ids with no metadata information: 26587 ids with no metadata information: 90647 ids with no metadata information: 2851 ids with no metadata information: 52281 ids with no metadata information: 27611 ids with no metadata information: 32352 ids with no metadata information: 55207 ids with no metadata information: 73759 ids with no metadata information: 108583 ids with no metadata information: 162376 ids with no metadata information: 94466 ids with no metadata information: 77658 ids with no metadata information: 4207 ids with no metadata information: 7669 ids with no metadata information: 31193 ids with no metadata information: 106642 ids with no metadata information: 108979 ids with no metadata information: 150548 ids with no metadata information: 26693 ids with no metadata information: 108548 ids with no metadata information: 5069 ids with no metadata information: 69849 ids with no metadata information: 1133 ids with no metadata information: 26649 ids with no metadata information: 62336 ids with no metadata information: 85780 ids with no metadata information: 79299 ids with no metadata information: 100450 ids with no metadata information: 72781 ids with no metadata information: 108727 ids with no metadata information: 126106 ids with no metadata information: 77359 ids with no metadata information: 769 ids with no metadata information: 721 ids with no metadata information: 4568 ids with no metadata information: 96075 ids with no metadata information: 150856 ids with no metadata information: 27724 ids with no metadata information: 4051
print("Number of movies with metadata information:", len(m_id))
Number of movies with metadata information: 9025
# create new movie_mapping dict with only m_ids having metadata information
movie_mappings = {}
for idx, m in enumerate(m_id):
movie_mappings[m] = idx
In this section, we are going to create a "Movie" collection in ArangoDB where each document of the collection represents a unique movie along with its metadata information.
# create a new collection named "Movie" if it does not exist.
# This returns an API wrapper for "Movie" collection.
if not movie_rec_db.has_collection("Movie"):
movie_rec_db.create_collection("Movie", replication_factor=3)
batch = []
BATCH_SIZE = 128
batch_idx = 1
index = 0
movie_collection = movie_rec_db["Movie"]
# loading movies metadata information into ArangoDB's Movie collection
for idx in tqdm(range(len(m_id))):
insert_doc = {}
tmdb_id = id_map.loc[id_map['movieId'] == m_id[idx]]
if tmdb_id.size == 0:
print('No Meta data information at:', m_id[idx])
else:
tmdb_id = int(tmdb_id.iloc[:,1][0])
emb_id = "Movie/" + str(movie_mappings[m_id[idx]])
insert_doc["_id"] = emb_id
m_meta = sampled_md.loc[sampled_md['id'] == tmdb_id]
# adding movie metadata information
m_title = m_meta.iloc[0]['title']
m_poster = m_meta.iloc[0]['poster_path']
m_description = m_meta.iloc[0]['description']
m_language = m_meta.iloc[0]['original_language']
m_genre = m_meta.iloc[0]['genres']
m_genre = yaml.load(m_genre, Loader=yaml.BaseLoader)
genres = [g['name'] for g in m_genre]
insert_doc["movieId"] = m_id[idx]
insert_doc["mapped_movieId"] = movie_mappings[m_id[idx]]
insert_doc["tmdbId"] = tmdb_id
insert_doc['movie_title'] = m_title
insert_doc['description'] = m_description
insert_doc['genres'] = genres
insert_doc['language'] = m_language
if str(m_poster) == "nan":
insert_doc['poster_path'] = "No poster path available"
else:
insert_doc['poster_path'] = m_poster
batch.append(insert_doc)
index +=1
last_record = (idx == (len(m_id) - 1))
if index % BATCH_SIZE == 0:
#print("Inserting batch %d" % (batch_idx))
batch_idx += 1
movie_collection.import_bulk(batch)
batch = []
if last_record and len(batch) > 0:
print("Inserting batch the last batch!")
movie_collection.import_bulk(batch)
100%|██████████| 9025/9025 [00:47<00:00, 191.58it/s]
Inserting batch the last batch!
Since "The Movies Dataset" does not contain any metadata information about users, therefore we are just going to create a user collection with user_ids only.
# create a new collection named "Users" if it does not exist.
# This returns an API wrapper for "Users" collection.
if not movie_rec_db.has_collection("Users"):
movie_rec_db.create_collection("Users", replication_factor=3)
# Users has no side information
total_users = np.unique(ratings_df[['userId']].values.flatten()).shape[0]
print("Total number of Users:", total_users)
Total number of Users: 671
def populate_user_collection(total_users):
batch = []
BATCH_SIZE = 50
batch_idx = 1
index = 0
user_ids = list(user_mapping.keys())
user_collection = movie_rec_db["Users"]
for idx in tqdm(range(total_users)):
insert_doc = {}
insert_doc["_id"] = "Users/" + str(user_mapping[user_ids[idx]])
insert_doc["original_id"] = str(user_ids[idx])
batch.append(insert_doc)
index +=1
last_record = (idx == (total_users - 1))
if index % BATCH_SIZE == 0:
#print("Inserting batch %d" % (batch_idx))
batch_idx += 1
user_collection.import_bulk(batch)
batch = []
if last_record and len(batch) > 0:
print("Inserting batch the last batch!")
user_collection.import_bulk(batch)
populate_user_collection(total_users)
100%|██████████| 671/671 [00:01<00:00, 343.99it/s]
Inserting batch the last batch!
Here, we first create a Ratings (edge) collection in ArangoDB and then populate it with the edges of the bipartite graph. Each edge document in this collection contains the _from (user) and _to (movie) node information along with the rating that user gave to that particular movie. Once this collection is populated, a bipartite graph (of user and movie nodes) exists in ArangoDB and can be viewed in the ArangoDB Web UI under Graphs -> movie_rating_graph.
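To make the edge format concrete, here is a minimal sketch (with made-up ids) of what a single document in the Ratings edge collection will look like; the actual creation code follows below:

# hypothetical example document of the "Ratings" edge collection
example_rating_edge = {
    "_id": "Ratings/user-12-r-movie-345",  # document id (made-up ids)
    "_from": "Users/12",                   # source vertex: the user node
    "_to": "Movie/345",                    # target vertex: the movie node
    "_rating": 4.0                         # rating this user gave to this movie
}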
# create a new collection named "Ratings" if it does not exist.
# This returns an API wrapper for "Ratings" collection.
if not movie_rec_db.has_collection("Ratings"):
movie_rec_db.create_collection("Ratings", edge=True, replication_factor=3)
# defining graph schema
# create a new graph called movie_rating_graph in the temp database if it does not already exist.
if not movie_rec_db.has_graph("movie_rating_graph"):
movie_rec_db.create_graph('movie_rating_graph', smart=True)
# This returns and API wrapper for the above created graphs
movie_rating_graph = movie_rec_db.graph("movie_rating_graph")
# Create a new vertex collection named "Users" if it does not exist.
if not movie_rating_graph.has_vertex_collection("Users"):
movie_rating_graph.vertex_collection("Users")
# Create a new vertex collection named "Movie" if it does not exist.
if not movie_rating_graph.has_vertex_collection("Movie"):
movie_rating_graph.vertex_collection("Movie")
# creating an edge definition named "Ratings". This creates any missing
# collections and returns an API wrapper for "Ratings" edge collection.
if not movie_rating_graph.has_edge_definition("Ratings"):
Ratings = movie_rating_graph.create_edge_definition(
edge_collection='Ratings',
from_vertex_collections=['Users'],
to_vertex_collections=['Movie']
)
user_id, movie_id, ratings = ratings_df[['userId']].values.flatten(), ratings_df[['movieId']].values.flatten() , ratings_df[['rating']].values.flatten()
def create_ratings_graph(user_id, movie_id, ratings):
batch = []
BATCH_SIZE = 100
batch_idx = 1
index = 0
edge_collection = movie_rec_db["Ratings"]
for idx in tqdm(range(ratings.shape[0])):
# removing edges (movies) with no metadata
if movie_id[idx] in no_metadata:
print('Removing edges with no metadata', movie_id[idx])
else:
insert_doc = {}
insert_doc = {"_id": "Ratings" + "/" + 'user-' + str(user_mapping[user_id[idx]]) + "-r-" + "movie-" + str(movie_mappings[movie_id[idx]]),
"_from": ("Users" + "/" + str(user_mapping[user_id[idx]])),
"_to": ("Movie" + "/" + str(movie_mappings[movie_id[idx]])),
"_rating": float(ratings[idx])}
batch.append(insert_doc)
index += 1
last_record = (idx == (ratings.shape[0] - 1))
if index % BATCH_SIZE == 0:
#print("Inserting batch %d" % (batch_idx))
batch_idx += 1
edge_collection.import_bulk(batch)
batch = []
if last_record and len(batch) > 0:
print("Inserting batch the last batch!")
edge_collection.import_bulk(batch)
create_ratings_graph(user_id, movie_id, ratings)
Visualization of the User-Movie-Ratings graph in ArangoDB's Web UI
So far we have seen how to construct a graph from multiple csv files and load that graph into ArangoDB. The next step is to export this graph from ArangoDB and construct a heterogeneous PyG graph from it.
# Get API wrappers for collections.
users = movie_rec_db.collection('Users')
movies = movie_rec_db.collection('Movie')
ratings_graph = movie_rec_db.collection('Ratings')
len(users), len(movies), len(ratings_graph)
(671, 9025, 99810)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Data handling of graphs in PyG: In order to construct the edges of the graph in PyG, we need to represent the graph connectivity in COO format (edge_index), i.e. with shape [2, num_edges]. The create_pyg_edges method can therefore be seen as a generic function which reads the documents from an edge collection (Ratings) and creates the PyG edges (edge_index) using the _from (src) and _to (dst) attributes of the rating documents. Since each edge of the graph is accompanied by rating information, the create_pyg_edges method also reads the _rating attribute from the edge collection and stores it in the PyG data object as the edge_attr variable.
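As a minimal illustration of the COO format (using made-up indices, independent of our dataset), a toy bipartite graph with three rated edges would be represented like this:

import torch

# toy edges: user 0 -> movie 1, user 0 -> movie 2, user 1 -> movie 0 (hypothetical indices)
src = [0, 0, 1]                        # _from: user node ids
dst = [1, 2, 0]                        # _to: movie node ids
edge_index = torch.tensor([src, dst])  # COO connectivity, shape [2, num_edges] == [2, 3]
edge_attr = torch.tensor([5, 3, 4])    # one rating per edge
print(edge_index.shape)                # torch.Size([2, 3])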
def create_pyg_edges(rating_docs):
src = []
dst = []
ratings = []
for doc in rating_docs:
_from = int(doc['_from'].split('/')[1])
_to = int(doc['_to'].split('/')[1])
src.append(_from)
dst.append(_to)
ratings.append(int(doc['_rating']))
edge_index = torch.tensor([src, dst])
edge_attr = torch.tensor(ratings)
return edge_index, edge_attr
edge_index, edge_label = create_pyg_edges(movie_rec_db.aql.execute('FOR doc IN Ratings RETURN doc'))
print(edge_index.shape)
print(edge_label.shape)
torch.Size([2, 99810]) torch.Size([99810])
So, in the above section we read the "Ratings" edge collection from ArangoDB and exported its edges into a PyG-acceptable data format, i.e. edge_index and edge_label. The next step is to construct the movie node features; in order to do so, I have written the following two methods:
Sequence Encoder: This method takes two arguments. The first one is movie_docs, with the help of which we can access the metadata of each movie stored inside the "Movie" collection. The second argument is model_name, which selects a pretrained NLP model (based on transformers) from the SentenceTransformers library to generate text embeddings. In this blogpost, I am generating embeddings for the movie titles and using them as the movie node features. However, instead of the movie title we could also use the movie description attribute to generate embeddings for the movie nodes. Curious readers can try this out and see if the results get better.
Genres Encoder: In this method we perform one-hot encoding of the genres present inside the Movie collection.
Once the features are generated by the sequence encoder and the genres encoder, we concatenate the two feature vectors to construct a single feature vector for each movie node.
Note: This process of feature generation for the movie nodes is inspired by the PyG examples.
def SequenceEncoder(movie_docs , model_name=None):
movie_titles = [doc['movie_title'] for doc in movie_docs]
model = SentenceTransformer(model_name, device=device)
title_embeddings = model.encode(movie_titles, show_progress_bar=True,
convert_to_tensor=True, device=device)
return title_embeddings
def GenresEncoder(movie_docs):
gen = []
#sep = '|'
for doc in movie_docs:
gen.append(doc['genres'])
#genre = doc['movie_genres']
#gen.append(genre.split(sep))
# getting unique genres
unique_gen = set(list(itertools.chain(*gen)))
print("Number of unqiue genres we have:", unique_gen)
mapping = {g: i for i, g in enumerate(unique_gen)}
x = torch.zeros(len(gen), len(mapping))
for i, m_gen in enumerate(gen):
for genre in m_gen:
x[i, mapping[genre]] = 1
return x.to(device)
title_emb = SequenceEncoder(movie_rec_db.aql.execute('FOR doc IN Movie RETURN doc'), model_name='all-MiniLM-L6-v2')
encoded_genres = GenresEncoder(movie_rec_db.aql.execute('FOR doc IN Movie RETURN doc'))
print('Title Embeddings shape:', title_emb.shape)
print("Encoded Genres shape:", encoded_genres.shape)
Number of unique genres we have: {'Fantasy', 'Music', 'Mystery', 'Family', 'Thriller', 'Science Fiction', 'Horror', 'TV Movie', 'Animation', 'Documentary', 'History', 'Foreign', 'Action', 'Adventure', 'Crime', 'Western', 'Drama', 'Romance', 'Comedy', 'War'}
Title Embeddings shape: torch.Size([9025, 384])
Encoded Genres shape: torch.Size([9025, 20])
# concat title and genres features of movies
movie_x = torch.cat((title_emb, encoded_genres), dim=-1)
print("Shape of the concatenated features:", movie_x.shape)
Shape of the concatenated features: torch.Size([9025, 404])
Heterogeneous graphs are graphs that contain different types of nodes and edges, e.g. knowledge graphs. The bipartite graph we have stored in ArangoDB is also a heterogeneous graph, since it contains two types of nodes, i.e. user and movie nodes. Therefore, our next step is to export the graph stored inside ArangoDB into a PyG heterogeneous data object.
Now that we have the PyG edges, labels, and node feature matrix, the next step is to add these tensors to a PyG HeteroData object in order to construct the heterogeneous graph.
data = HeteroData()
data['user'].num_nodes = len(users) # Users do not have any features.
data['movie'].x = movie_x
data['user', 'rates', 'movie'].edge_index = edge_index
data['user', 'rates', 'movie'].edge_label = edge_label
# Add user node features for message passing:
data['user'].x = torch.eye(data['user'].num_nodes, device=device)
del data['user'].num_nodes
We can now convert data into an appropriate format for training a graph-based machine learning model:
Here, ToUndirected() transforms the directed graph into (the PyG representation of) an undirected graph by adding reverse edges for all edges in the graph. Thus, future message passing is performed in both directions of all edges. The function may add reverse edge types to the heterogeneous graph, if necessary.
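As a small sketch of this behaviour (on a made-up two-node-type graph rather than our actual data), ToUndirected() adds a ('movie', 'rev_rates', 'user') edge type holding the flipped edge_index:

import torch
from torch_geometric.data import HeteroData
from torch_geometric.transforms import ToUndirected

toy = HeteroData()
toy['user'].x = torch.randn(2, 4)   # 2 toy user nodes with random features
toy['movie'].x = torch.randn(3, 4)  # 3 toy movie nodes with random features
toy['user', 'rates', 'movie'].edge_index = torch.tensor([[0, 1], [2, 0]])
toy = ToUndirected()(toy)
print(toy.edge_types)
# [('user', 'rates', 'movie'), ('movie', 'rev_rates', 'user')]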
# Add a reverse ('movie', 'rev_rates', 'user') relation for message passing.
data = ToUndirected()(data)
del data['movie', 'rev_rates', 'user'].edge_label # Remove "reverse" label.
data = data.to(device)
# Perform a link-level split into training, validation, and test edges.
train_data, val_data, test_data = T.RandomLinkSplit(
num_val=0.1,
num_test=0.1,
neg_sampling_ratio=0.0,
edge_types=[('user', 'rates', 'movie')],
rev_edge_types=[('movie', 'rev_rates', 'user')],
)(data)
print('Train data:', train_data)
print('Val data:', val_data)
print('Test data', test_data)
Train data: HeteroData( user={ x=[671, 671] }, movie={ x=[9025, 404] }, (user, rates, movie)={ edge_index=[2, 79848], edge_label=[79848], edge_label_index=[2, 79848] }, (movie, rev_rates, user)={ edge_index=[2, 79848] } )
Val data: HeteroData( user={ x=[671, 671] }, movie={ x=[9025, 404] }, (user, rates, movie)={ edge_index=[2, 79848], edge_label=[9981], edge_label_index=[2, 9981] }, (movie, rev_rates, user)={ edge_index=[2, 79848] } )
Test data: HeteroData( user={ x=[671, 671] }, movie={ x=[9025, 404] }, (user, rates, movie)={ edge_index=[2, 89829], edge_label=[9981], edge_label_index=[2, 9981] }, (movie, rev_rates, user)={ edge_index=[2, 89829] } )
Some Heterogeneous graph statistics present inside PyG
# Slicing edge label to get the corresponding split (hence this gives train split)
train_data['user', 'movie'].edge_label_index
tensor([[ 20, 595, 22, ..., 623, 74, 242], [ 180, 416, 485, ..., 8771, 119, 164]], device='cuda:0')
# feature matrix for all the node types
data.x_dict
{'movie': tensor([[-0.0174, 0.0269, -0.0453, ..., 0.0000, 0.0000, 0.0000], [ 0.0046, 0.0148, 0.0306, ..., 0.0000, 0.0000, 0.0000], [-0.0476, 0.0260, -0.0114, ..., 0.0000, 0.0000, 0.0000], ..., [ 0.0203, 0.0347, 0.0071, ..., 1.0000, 1.0000, 0.0000], [-0.0611, 0.0251, 0.0478, ..., 1.0000, 1.0000, 0.0000], [-0.0880, 0.0280, -0.0884, ..., 0.0000, 1.0000, 0.0000]], device='cuda:0'), 'user': tensor([[1., 0., 0., ..., 0., 0., 0.], [0., 1., 0., ..., 0., 0., 0.], [0., 0., 1., ..., 0., 0., 0.], ..., [0., 0., 0., ..., 1., 0., 0.], [0., 0., 0., ..., 0., 1., 0.], [0., 0., 0., ..., 0., 0., 1.]], device='cuda:0')}
data.x_dict['user'].shape, data.x_dict['movie'].shape
(torch.Size([671, 671]), torch.Size([9025, 404]))
# converting everything to dict
data.to_dict()
{('movie', 'rev_rates', 'user'): {'edge_index': tensor([[ 0, 1, 2, ..., 1327, 1329, 2941], [ 0, 0, 0, ..., 670, 670, 670]], device='cuda:0')}, ('user', 'rates', 'movie'): {'edge_index': tensor([[ 0, 0, 0, ..., 670, 670, 670], [ 0, 1, 2, ..., 1327, 1329, 2941]], device='cuda:0'), 'edge_label': tensor([2, 3, 3, ..., 4, 2, 3], device='cuda:0')}, 'movie': {'x': tensor([[-0.0174, 0.0269, -0.0453, ..., 0.0000, 0.0000, 0.0000], [ 0.0046, 0.0148, 0.0306, ..., 0.0000, 0.0000, 0.0000], [-0.0476, 0.0260, -0.0114, ..., 0.0000, 0.0000, 0.0000], ..., [ 0.0203, 0.0347, 0.0071, ..., 1.0000, 1.0000, 0.0000], [-0.0611, 0.0251, 0.0478, ..., 1.0000, 1.0000, 0.0000], [-0.0880, 0.0280, -0.0884, ..., 0.0000, 1.0000, 0.0000]], device='cuda:0')}, 'user': {'x': tensor([[1., 0., 0., ..., 0., 0., 0.], [0., 1., 0., ..., 0., 0., 0.], [0., 0., 1., ..., 0., 0., 0.], ..., [0., 0., 0., ..., 1., 0., 0.], [0., 0., 0., ..., 0., 1., 0.], [0., 0., 0., ..., 0., 0., 1.]], device='cuda:0')}}
data.edge_index_dict
{('movie', 'rev_rates', 'user'): tensor([[ 0, 1, 2, ..., 1327, 1329, 2941], [ 0, 0, 0, ..., 670, 670, 670]], device='cuda:0'), ('user', 'rates', 'movie'): tensor([[ 0, 0, 0, ..., 670, 670, 670], [ 0, 1, 2, ..., 1327, 1329, 2941]], device='cuda:0')}
data.edge_label_dict
{('user', 'rates', 'movie'): tensor([2, 3, 3, ..., 4, 2, 3], device='cuda:0')}
# different types of nodes in Hetero graph
node_types, edge_types = data.metadata()
print('Different types of nodes in graph:',node_types)
print('Different types of edges in graph:',edge_types)
Different types of nodes in graph: ['user', 'movie'] Different types of edges in graph: [('user', 'rates', 'movie'), ('movie', 'rev_rates', 'user')]
# We have an unbalanced dataset with many labels for rating 3 and 4, and very
# few for 0 and 1. Therefore we use a weighted MSE loss.
weight = torch.bincount(train_data['user', 'movie'].edge_label)
weight = weight.max() / weight
def weighted_mse_loss(pred, target, weight=None):
weight = 1. if weight is None else weight[target].to(pred.dtype)
return (weight * (pred - target.to(pred.dtype)).pow(2)).mean()
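To see what this weighting does, here is a small worked example with made-up labels (not our actual rating distribution): ratings that occur rarely receive a proportionally larger weight, so errors on them contribute more to the loss:

import torch

toy_labels = torch.tensor([0, 1, 2, 3, 3, 3, 3, 4, 4, 5])  # hypothetical edge labels
toy_weight = torch.bincount(toy_labels)                     # tensor([1, 1, 1, 4, 2, 1])
toy_weight = toy_weight.max() / toy_weight
print(toy_weight)                                           # tensor([4., 4., 4., 1., 2., 4.])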
The above code section shows how easy it is to transform your ArangoDB graph into a PyG heterogeneous graph object. Once the graph object is created, we can apply heterogeneous graph learning (i.e. heterogeneous graph neural networks) to it in order to predict the missing links between users and movies. Explaining the inner workings of heterogeneous graph learning is beyond the scope of this article, but you can check out this awesome documentation by the PyG team to learn more about it.
The diagram below shows an example of a bipartite graph of user and movie nodes with the predicted links generated using heterogeneous graph learning.
class GNNEncoder(torch.nn.Module):
def __init__(self, hidden_channels, out_channels):
super().__init__()
# these convolutions have been replicated to match the number of edge types
self.conv1 = SAGEConv((-1, -1), hidden_channels)
self.conv2 = SAGEConv((-1, -1), out_channels)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index).relu()
x = self.conv2(x, edge_index)
return x
class EdgeDecoder(torch.nn.Module):
def __init__(self, hidden_channels):
super().__init__()
self.lin1 = Linear(2 * hidden_channels, hidden_channels)
self.lin2 = Linear(hidden_channels, 1)
def forward(self, z_dict, edge_label_index):
row, col = edge_label_index
# concat user and movie embeddings
z = torch.cat([z_dict['user'][row], z_dict['movie'][col]], dim=-1)
# concatenated embeddings passed to linear layer
z = self.lin1(z).relu()
z = self.lin2(z)
return z.view(-1)
class Model(torch.nn.Module):
def __init__(self, hidden_channels):
super().__init__()
self.encoder = GNNEncoder(hidden_channels, hidden_channels)
self.encoder = to_hetero(self.encoder, data.metadata(), aggr='sum')
self.decoder = EdgeDecoder(hidden_channels)
def forward(self, x_dict, edge_index_dict, edge_label_index):
# z_dict contains dictionary of movie and user embeddings returned from GraphSage
z_dict = self.encoder(x_dict, edge_index_dict)
return self.decoder(z_dict, edge_label_index)
model = Model(hidden_channels=32).to(device)
# Due to lazy initialization, we need to run one model step so the number
# of parameters can be inferred:
with torch.no_grad():
model.encoder(train_data.x_dict, train_data.edge_index_dict)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
def train():
model.train()
optimizer.zero_grad()
pred = model(train_data.x_dict, train_data.edge_index_dict,
train_data['user', 'movie'].edge_label_index)
target = train_data['user', 'movie'].edge_label
loss = weighted_mse_loss(pred, target, weight)
loss.backward()
optimizer.step()
return float(loss)
@torch.no_grad()
def test(data):
model.eval()
pred = model(data.x_dict, data.edge_index_dict,
data['user', 'movie'].edge_label_index)
pred = pred.clamp(min=0, max=5)
target = data['user', 'movie'].edge_label.float()
rmse = F.mse_loss(pred, target).sqrt()
return float(rmse)
for epoch in range(1, 300):
loss = train()
train_rmse = test(train_data)
val_rmse = test(val_data)
test_rmse = test(test_data)
print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_rmse:.4f}, '
f'Val: {val_rmse:.4f}, Test: {test_rmse:.4f}')
Epoch: 001, Loss: 22.2784, Train: 3.5865, Val: 3.5914, Test: 3.6023
Epoch: 002, Loss: 20.3715, Train: 3.4158, Val: 3.4221, Test: 3.4327
Epoch: 003, Loss: 18.1650, Train: 3.0623, Val: 3.0708, Test: 3.0808
Epoch: 004, Loss: 14.6564, Train: 2.4797, Val: 2.4919, Test: 2.5009
Epoch: 005, Loss: 10.1307, Train: 1.5875, Val: 1.6027, Test: 1.6097
Epoch: 006, Loss: 6.5024, Train: 1.1967, Val: 1.1825, Test: 1.1837
...
Epoch: 295, Loss: 2.7388, Train: 1.1320, Val: 1.2047, Test: 1.2022
Epoch: 296, Loss: 2.7221, Train: 1.0616, Val: 1.1451, Test: 1.1438
Epoch: 297, Loss: 2.6866, Train: 1.0735, Val: 1.1552, Test: 1.1540
Epoch: 298, Loss: 2.6756, Train: 1.1128, Val: 1.1888, Test: 1.1873
Epoch: 299, Loss: 2.6897, Train: 1.0465, Val: 1.1337, Test: 1.1334
Here, we will predict new links between users and movies with the trained model. We will only select movies for a user where the predicted rating is equal to 5.
total_users = len(users)
total_movies = len(movies)
movie_recs = []
for user_id in tqdm(range(0, total_users)):
user_row = torch.tensor([user_id] * total_movies)
all_movie_ids = torch.arange(total_movies)
edge_label_index = torch.stack([user_row, all_movie_ids], dim=0)
pred = model(data.x_dict, data.edge_index_dict,
edge_label_index)
pred = pred.clamp(min=0, max=5)
# we will only select movies for the user where the predicted rating is 5
rec_movie_ids = (pred == 5).nonzero(as_tuple=True)
top_ten_recs = [rec_movies for rec_movies in rec_movie_ids[0].tolist()[:10]]
movie_recs.append({'user': user_id, 'rec_movies': top_ten_recs})
100%|██████████| 671/671 [00:19<00:00, 34.08it/s]
Storing predictions back to ArangoDB
We will create a new collection named "Recommendation_Inferences" in ArangoDB to store the movie recommendations for each user.
# create a new collection named "Recommendation_Inferences" if it does not exist.
# This returns an API wrapper for "Recommendation_Inferences" collection.
if not movie_rec_db.has_collection("Recommendation_Inferences"):
movie_rec_db.create_collection("Recommendation_Inferences", edge=True, replication_factor=3)
def populate_movies_recommendations(movie_recs):
batch = []
BATCH_SIZE = 100
batch_idx = 1
index = 0
rec_collection = movie_rec_db["Recommendation_Inferences"]
for idx in tqdm(range(total_users)):
insert_doc = {}
to_insert = []
user_id = movie_recs[idx]['user']
movie_ids = movie_recs[idx]['rec_movies']
for m_id in movie_ids:
insert_doc = {
"_from": ("Users" + "/" + str(user_id)),
"_to": ("Movie" + "/" + str(m_id)),
"_rating": 5}
to_insert.append(insert_doc)
batch.extend(to_insert)
index +=1
last_record = (idx == (total_users - 1))
if len(batch) > BATCH_SIZE:
rec_collection.import_bulk(batch)
batch = []
if last_record and len(batch) > 0:
print("Inserting batch the last batch!")
rec_collection.import_bulk(batch)
populate_movies_recommendations(movie_recs)
100%|██████████| 671/671 [00:09<00:00, 74.09it/s]
Inserting batch the last batch!
Working of Heterogeneous Graph Learning.
The heterogeneous graph neural network code is inspired by the PyG example hetero_link_pred.py.