Integrate ArangoDB with PyTorch Geometric to build Recommender Systems using Graph Neural Networks¶


Introduction¶

With the advancement of Internet technology, recommendations have become an integral part of the web. Whether we are watching videos on YouTube, enjoying movies on Netflix, receiving friend suggestions on Facebook, seeing online advertisements, or browsing Amazon product suggestions, all of these scenarios use recommendation systems at the backend to suggest items that could interest us. The main objective of a recommendation system is to suggest suitable items to a user based on his/her past engagement with items and on the user's side information (e.g. age, gender, demographics). Therefore, recommender systems (RS) are used to find the small number of relevant items that match a user's personalized interests. Here, I am assuming a recommender system with no cold-start problem.

On the other hand, Graph Neural Network (GNN) based methods have shown great success in tackling recommendation problems when compared to traditional recommendation techniques like collaborative filtering (CF). There are several use cases in industry which use GNNs to solve recommendation problems. For example, with Food Discovery at Uber Eats, Uber uses the power of GNNs to suggest to its users the dishes, restaurants, and cuisines they might like next. To make these recommendations, Uber Eats uses the GraphSAGE algorithm because of its inductive nature and its ability to scale up to billions of nodes. Another interesting application is PinSage (a random-walk Graph Convolutional Network capable of learning embeddings for nodes in web-scale graphs containing billions of objects), an algorithm developed by Pinterest to perform visual recommendations.

Therefore, in this blog post, we will together build a complete movie recommendation application using ArangoDB (an open-source native multi-model graph database) and PyTorch Geometric (a library built upon PyTorch to easily write and train custom GNNs). In graph machine learning (or graph neural networks) we mainly solve three types of tasks: node classification, link prediction, and graph classification. In this blog post, we tackle the challenge of building a movie recommendation application by transforming it into a link prediction task. The aim of link prediction is to predict whether an edge exists between two given nodes. First, we are going to build a bipartite graph of user and movie nodes where an edge between a user and a movie carries the rating (between 1 and 5) the user has given to that movie. Then, our goal is to predict missing links between a user and the movies they have not watched yet. The missing links are computed with graph neural networks by predicting ratings for all the unseen movies. At the end, only those movies whose predicted ratings are equal to 5 are recommended to a user. The figure below depicts the bipartite graph of user and movie nodes:

[Figure: bipartite graph of user and movie nodes with rating edges]
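To make the recommendation rule concrete, here is a minimal sketch of the final filtering step (the movie ids and predicted ratings are hypothetical, and one reasonable reading of "equal to 5" is to round the continuous prediction; the actual prediction code appears later in the post):

# hypothetical predicted ratings for movies a user has not rated yet
predicted_ratings = {'Movie/42': 3.7, 'Movie/101': 5.0, 'Movie/7': 4.9}
# recommend only the movies whose predicted rating rounds to 5
recommended = [movie for movie, rating in predicted_ratings.items() if round(rating) == 5]
print(recommended)  # ['Movie/101', 'Movie/7']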

Outline¶

We are going to use The Movies Dataset from Kaggle, which contains the metadata for all 45,000 movies listed in the Full MovieLens Dataset. With the help of this metadata we are going to prepare features for our movie nodes, which can then be used as input to our GNN. This blog post focuses on the following key steps:

  • Loading the data present in csv files into ArangoDB (in graph format).
  • Converting the graph present inside ArangoDB into a PyTorch Geometric (PyG) data object.
  • Training a GNN model on this PyG data object.
  • Generating predictions and storing them back to ArangoDB.

N.B. Before you run this notebook!!!¶

If you are running this notebook on Google Colab, please make sure to enable hardware acceleration using either a GPU or a TPU. If it is run with only a CPU, training the GNN will take an incredibly long time! Hardware acceleration can be enabled by navigating to Runtime -> Change Runtime. This will present you with a popup where you can select an appropriate hardware accelerator.
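As an optional sanity check, you can confirm in a fresh cell that PyTorch sees the GPU (torch is typically preinstalled on Colab):

import torch
print(torch.cuda.is_available())  # should print True when a GPU runtime is selected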

Let's Start with Installing Necessary Libraries¶

In [ ]:
%%capture
!git clone -b oasis_connector --single-branch https://github.com/arangodb/interactive_tutorials.git
!git clone -b movie-data-source --single-branch https://github.com/arangodb/interactive_tutorials.git movie_data_source
!rsync -av interactive_tutorials/ ./ --exclude=.git
!chmod -R 755 ./tools
!pip install pyarango
!pip install "python-arango>=5.0"
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu113.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+cu113.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
!pip install sentence-transformers

Setting up the dataset¶

We are going to use a sampled version of The Movies Dataset. This dataset contains three main csv files:

  1. movies_metadata.csv: Contains information on 45,000 movies featured in the Full MovieLens dataset. Features include posters, backdrops, budget, revenue, release dates, languages, production countries and companies.
  2. links_small.csv: Contains the TMDB and IMDB IDs of a small subset of 9,000 movies of the Full Dataset.
  3. ratings_small.csv: The subset of 100,000 ratings from 700 users on 9,000 movies.
In [ ]:
# unzip movies dataset zip file
!unzip ./movie_data_source/sampled_movie_dataset.zip
Archive:  ./movie_data_source/sampled_movie_dataset.zip
   creating: sampled_movie_dataset/
  inflating: sampled_movie_dataset/links_small.csv  
  inflating: __MACOSX/sampled_movie_dataset/._links_small.csv  
  inflating: sampled_movie_dataset/.DS_Store  
  inflating: __MACOSX/sampled_movie_dataset/._.DS_Store  
  inflating: sampled_movie_dataset/movies_metadata.csv  
  inflating: __MACOSX/sampled_movie_dataset/._movies_metadata.csv  
  inflating: sampled_movie_dataset/ratings_small.csv  
  inflating: __MACOSX/sampled_movie_dataset/._ratings_small.csv  

Import Libraries¶

In [ ]:
import pandas as pd
from arango import ArangoClient
from tqdm import tqdm
import numpy as np
import itertools
import requests
import sys
import oasis

import torch
import torch.nn.functional as F
from torch.nn import Linear
import torch_geometric.transforms as T
from torch_geometric.nn import SAGEConv, to_hetero
from torch_geometric.transforms import RandomLinkSplit, ToUndirected
from sentence_transformers import SentenceTransformer
from torch_geometric.data import HeteroData
import yaml

Loading the data present in csv files to ArangoDB¶

In this section, we will read the data from multiple csv files and construct a graph out of it. The graph we are going to construct is a bipartite graph with user nodes on one side and movie nodes on the other. In addition, the user and movie nodes will be accompanied by their corresponding attribute features. Next, we are going to store the generated graph inside ArangoDB.

In [ ]:
metadata_path = './sampled_movie_dataset/movies_metadata.csv'
df = pd.read_csv(metadata_path)
/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py:2882: DtypeWarning: Columns (10) have mixed types.Specify dtype option on import or set low_memory=False.
  exec(code_obj, self.user_global_ns, self.user_ns)
In [ ]:
df.head()
Out[ ]:
adult belongs_to_collection budget genres homepage id imdb_id original_language original_title overview ... release_date revenue runtime spoken_languages status tagline title video vote_average vote_count
0 False {'id': 10194, 'name': 'Toy Story Collection', ... 30000000 [{'id': 16, 'name': 'Animation'}, {'id': 35, '... http://toystory.disney.com/toy-story 862 tt0114709 en Toy Story Led by Woody, Andy's toys live happily in his ... ... 1995-10-30 373554033.0 81.0 [{'iso_639_1': 'en', 'name': 'English'}] Released NaN Toy Story False 7.7 5415.0
1 False NaN 65000000 [{'id': 12, 'name': 'Adventure'}, {'id': 14, '... NaN 8844 tt0113497 en Jumanji When siblings Judy and Peter discover an encha... ... 1995-12-15 262797249.0 104.0 [{'iso_639_1': 'en', 'name': 'English'}, {'iso... Released Roll the dice and unleash the excitement! Jumanji False 6.9 2413.0
2 False {'id': 119050, 'name': 'Grumpy Old Men Collect... 0 [{'id': 10749, 'name': 'Romance'}, {'id': 35, ... NaN 15602 tt0113228 en Grumpier Old Men A family wedding reignites the ancient feud be... ... 1995-12-22 0.0 101.0 [{'iso_639_1': 'en', 'name': 'English'}] Released Still Yelling. Still Fighting. Still Ready for... Grumpier Old Men False 6.5 92.0
3 False NaN 16000000 [{'id': 35, 'name': 'Comedy'}, {'id': 18, 'nam... NaN 31357 tt0114885 en Waiting to Exhale Cheated on, mistreated and stepped on, the wom... ... 1995-12-22 81452156.0 127.0 [{'iso_639_1': 'en', 'name': 'English'}] Released Friends are the people who let you be yourself... Waiting to Exhale False 6.1 34.0
4 False {'id': 96871, 'name': 'Father of the Bride Col... 0 [{'id': 35, 'name': 'Comedy'}] NaN 11862 tt0113041 en Father of the Bride Part II Just when George Banks has recovered from his ... ... 1995-02-10 76578911.0 106.0 [{'iso_639_1': 'en', 'name': 'English'}] Released Just When His World Is Back To Normal... He's ... Father of the Bride Part II False 5.7 173.0

5 rows × 24 columns

In [ ]:
df.columns
Out[ ]:
Index(['adult', 'belongs_to_collection', 'budget', 'genres', 'homepage', 'id',
       'imdb_id', 'original_language', 'original_title', 'overview',
       'popularity', 'poster_path', 'production_companies',
       'production_countries', 'release_date', 'revenue', 'runtime',
       'spoken_languages', 'status', 'tagline', 'title', 'video',
       'vote_average', 'vote_count'],
      dtype='object')
In [ ]:
# on these rows metadata information is missing
df = df.drop([19730, 29503, 35587])
In [ ]:
# sampled from links.csv file
links_small = pd.read_csv('./sampled_movie_dataset/links_small.csv')
In [ ]:
links_small.head()
Out[ ]:
movieId imdbId tmdbId
0 1 114709 862.0
1 2 113497 8844.0
2 3 113228 15602.0
3 4 114885 31357.0
4 5 113041 11862.0
In [ ]:
# selecting the tmdbId column from the links_small file
links_small = links_small[links_small['tmdbId'].notnull()]['tmdbId'].astype('int')
In [ ]:
df['id'] = df['id'].astype('int')
In [ ]:
sampled_md = df[df['id'].isin(links_small)]
sampled_md.shape
Out[ ]:
(9099, 24)
In [ ]:
sampled_md['tagline'] = sampled_md['tagline'].fillna('')
sampled_md['description'] = sampled_md['overview'] + sampled_md['tagline']
sampled_md['description'] = sampled_md['description'].fillna('')
In [ ]:
sampled_md = sampled_md.reset_index()
In [ ]:
sampled_md.head()
Out[ ]:
index adult belongs_to_collection budget genres homepage id imdb_id original_language original_title ... revenue runtime spoken_languages status tagline title video vote_average vote_count description
0 0 False {'id': 10194, 'name': 'Toy Story Collection', ... 30000000 [{'id': 16, 'name': 'Animation'}, {'id': 35, '... http://toystory.disney.com/toy-story 862 tt0114709 en Toy Story ... 373554033.0 81.0 [{'iso_639_1': 'en', 'name': 'English'}] Released Toy Story False 7.7 5415.0 Led by Woody, Andy's toys live happily in his ...
1 1 False NaN 65000000 [{'id': 12, 'name': 'Adventure'}, {'id': 14, '... NaN 8844 tt0113497 en Jumanji ... 262797249.0 104.0 [{'iso_639_1': 'en', 'name': 'English'}, {'iso... Released Roll the dice and unleash the excitement! Jumanji False 6.9 2413.0 When siblings Judy and Peter discover an encha...
2 2 False {'id': 119050, 'name': 'Grumpy Old Men Collect... 0 [{'id': 10749, 'name': 'Romance'}, {'id': 35, ... NaN 15602 tt0113228 en Grumpier Old Men ... 0.0 101.0 [{'iso_639_1': 'en', 'name': 'English'}] Released Still Yelling. Still Fighting. Still Ready for... Grumpier Old Men False 6.5 92.0 A family wedding reignites the ancient feud be...
3 3 False NaN 16000000 [{'id': 35, 'name': 'Comedy'}, {'id': 18, 'nam... NaN 31357 tt0114885 en Waiting to Exhale ... 81452156.0 127.0 [{'iso_639_1': 'en', 'name': 'English'}] Released Friends are the people who let you be yourself... Waiting to Exhale False 6.1 34.0 Cheated on, mistreated and stepped on, the wom...
4 4 False {'id': 96871, 'name': 'Father of the Bride Col... 0 [{'id': 35, 'name': 'Comedy'}] NaN 11862 tt0113041 en Father of the Bride Part II ... 76578911.0 106.0 [{'iso_639_1': 'en', 'name': 'English'}] Released Just When His World Is Back To Normal... He's ... Father of the Bride Part II False 5.7 173.0 Just when George Banks has recovered from his ...

5 rows × 26 columns

In [ ]:
indices = pd.Series(sampled_md.index, index=sampled_md['title'])
In [ ]:
ind_gen = pd.Series(sampled_md.index, index=sampled_md['genres'])

Let's Load Ratings File¶

We are going to use the ratings file to construct the bipartite graph. This file contains movies rated by different users on a scale of 1-5, where a rating of 1 implies a very bad movie and 5 corresponds to a very good movie.

In [ ]:
ratings_path = './sampled_movie_dataset/ratings_small.csv'
In [ ]:
ratings_df = pd.read_csv(ratings_path)
ratings_df.head()
Out[ ]:
userId movieId rating timestamp
0 1 31 2.5 1260759144
1 1 1029 3.0 1260759179
2 1 1061 3.0 1260759182
3 1 1129 2.0 1260759185
4 1 1172 4.0 1260759205
In [ ]:
# performs user and movie mappings
def node_mappings(path, index_col):
    df = pd.read_csv(path, index_col=index_col)
    mapping = {index: i for i, index in enumerate(df.index.unique())}

    return mapping
In [ ]:
user_mapping = node_mappings(ratings_path, index_col='userId')
In [ ]:
movie_mapping = node_mappings(ratings_path, index_col='movieId')
In [ ]:
m_id = ratings_df['movieId'].tolist()
In [ ]:
# all unique movie_ids present inside ratings file
#m_id = list(set(m_id))
m_id = list(dict.fromkeys(m_id))
len(m_id)
Out[ ]:
9066
In [ ]:
def convert_int(x):
    try:
        return int(x)
    except:
        return np.nan
In [ ]:
id_map = pd.read_csv('./sampled_movie_dataset/links_small.csv')[['movieId', 'tmdbId']]
In [ ]:
id_map['tmdbId'] = id_map['tmdbId'].apply(convert_int)
In [ ]:
id_map.columns = ['movieId', 'id']
In [ ]:
id_map.head()
Out[ ]:
movieId id
0 1 862.0
1 2 8844.0
2 3 15602.0
3 4 31357.0
4 5 11862.0
In [ ]:
# tmdbId (from links_small) corresponds to the id column in sampled_md
id_map = id_map.merge(sampled_md[['title', 'id']], on='id').set_index('title')
In [ ]:
indices_map = id_map.set_index('id')

ArangoDB Setup¶

In this section, we first connect to a temporary ArangoDB instance in the cloud using Oasis (our managed cloud service). Once connected, we can load the movie metadata (movie title, genres, description, etc.) and the ratings (ratings given by users to movies) into ArangoDB collections.

In [ ]:
# get temporary credentials for ArangoDB on cloud
login = oasis.getTempCredentials(tutorialName="MovieRecommendations", credentialProvider="https://tutorials.arangodb.cloud:8529/_db/_system/tutorialDB/tutorialDB")

# Connect to the temp database
# Please note that we use the python-arango driver as it has better support for ArangoSearch 
movie_rec_db = oasis.connect_python_arango(login)
Requesting new temp credentials.
Temp database ready to use.

Printing authentication credentials

In [ ]:
# url to access the ArangoDB Web UI
print("https://"+login["hostname"]+":"+str(login["port"]))
print("Username: " + login["username"])
print("Password: " + login["password"])
print("Database: " + login["dbName"])
https://tutorials.arangodb.cloud:8529
Username: TUTgvm6ryraubpqrj6tvdvuh
Password: TUTbiloykvn0m6zdz9fnnrjp
Database: TUTcvypyquac6k972l58jgnc
In [ ]:
# print 5 mappings of movieIds
list(movie_mapping.items())[:5]
Out[ ]:
[(31, 0), (1029, 1), (1061, 2), (1129, 3), (1172, 4)]
In [ ]:
print("%d number of unique movie ids" %len(m_id))
9066 number of unique movie ids
In [ ]:
# collect movie ids which don't have metadata information

def remove_movies(m_id):
    no_metadata = []
    for idx in range(len(m_id)):
        tmdb_id = id_map.loc[id_map['movieId'] == m_id[idx]]
  
        if tmdb_id.size == 0:
            no_metadata.append(m_id[idx])
            #print('No Meta data information at:', m_id[idx])
    return no_metadata
In [ ]:
no_metadata = remove_movies(m_id)
In [ ]:
## remove ids which don't have metadata information
for element in no_metadata:
    if element in m_id:
        print("ids with no metadata information:",element)
        m_id.remove(element)
ids with no metadata information: 720
ids with no metadata information: 7502
ids with no metadata information: 26587
ids with no metadata information: 90647
ids with no metadata information: 2851
ids with no metadata information: 52281
ids with no metadata information: 27611
ids with no metadata information: 32352
ids with no metadata information: 55207
ids with no metadata information: 73759
ids with no metadata information: 108583
ids with no metadata information: 162376
ids with no metadata information: 94466
ids with no metadata information: 77658
ids with no metadata information: 4207
ids with no metadata information: 7669
ids with no metadata information: 31193
ids with no metadata information: 106642
ids with no metadata information: 108979
ids with no metadata information: 150548
ids with no metadata information: 26693
ids with no metadata information: 108548
ids with no metadata information: 5069
ids with no metadata information: 69849
ids with no metadata information: 1133
ids with no metadata information: 26649
ids with no metadata information: 62336
ids with no metadata information: 85780
ids with no metadata information: 79299
ids with no metadata information: 100450
ids with no metadata information: 72781
ids with no metadata information: 108727
ids with no metadata information: 126106
ids with no metadata information: 77359
ids with no metadata information: 769
ids with no metadata information: 721
ids with no metadata information: 4568
ids with no metadata information: 96075
ids with no metadata information: 150856
ids with no metadata information: 27724
ids with no metadata information: 4051
In [ ]:
print("Number of movies with metadata information:", len(m_id))
Number of movies with metadata information: 9025
In [ ]:
# create new movie_mapping dict with only m_ids having metadata information
movie_mappings = {}
for idx, m in enumerate(m_id):
    movie_mappings[m] = idx

Loading movies metadata into ArangoDB's Movie collection¶

In this section, we are going to create a "Movie" collection in ArangoDB where each document of the collection represents a unique movie along with its metadata information.

In [ ]:
# create a new collection named "Movie" if it does not exist.
# This returns an API wrapper for "Movie" collection.
if not movie_rec_db.has_collection("Movie"):
    movie_rec_db.create_collection("Movie", replication_factor=3)
In [ ]:
batch = []
BATCH_SIZE = 128
batch_idx = 1
index = 0
movie_collection = movie_rec_db["Movie"]
In [ ]:
# loading movies metadata information into ArangoDB's Movie collection
for idx in tqdm(range(len(m_id))):
    insert_doc = {}
    tmdb_id = id_map.loc[id_map['movieId'] == m_id[idx]]
  
    if tmdb_id.size == 0:
        print('No Meta data information at:', m_id[idx])
        

    else:
        tmdb_id = int(tmdb_id.iloc[:,1][0])
        emb_id = "Movie/" + str(movie_mappings[m_id[idx]])
        insert_doc["_id"] = emb_id
        m_meta = sampled_md.loc[sampled_md['id'] == tmdb_id]
        # adding movie metadata information 
        m_title = m_meta.iloc[0]['title']
        m_poster = m_meta.iloc[0]['poster_path']
        m_description = m_meta.iloc[0]['description']
        m_language = m_meta.iloc[0]['original_language']
        m_genre = m_meta.iloc[0]['genres']
        m_genre = yaml.load(m_genre, Loader=yaml.BaseLoader)
        genres = [g['name'] for g in m_genre]
         
        insert_doc["movieId"] = m_id[idx]
        insert_doc["mapped_movieId"] = movie_mappings[m_id[idx]]
        insert_doc["tmdbId"] = tmdb_id
        insert_doc['movie_title'] = m_title
     
        insert_doc['description'] = m_description
        insert_doc['genres'] = genres
        insert_doc['language'] = m_language
        
        if str(m_poster) == "nan":
            insert_doc['poster_path'] = "No poster path available"
        else:
            insert_doc['poster_path'] = m_poster
        
        batch.append(insert_doc)
        index +=1
        last_record = (idx == (len(m_id) - 1))
        if index % BATCH_SIZE == 0:
            #print("Inserting batch %d" % (batch_idx))
            batch_idx += 1
            movie_collection.import_bulk(batch)
            batch = []   
        if last_record and len(batch) > 0:
            print("Inserting batch the last batch!")
            movie_collection.import_bulk(batch)
100%|██████████| 9025/9025 [00:47<00:00, 191.58it/s]
Inserting batch the last batch!

Creating User Collection in ArangoDB¶

Since "The Movies Dataset" does not contain any metadata information about users, therefore we are just going to create a user collection with user_ids only.

In [ ]:
# create a new collection named "Users" if it does not exist.
# This returns an API wrapper for "Users" collection.
if not movie_rec_db.has_collection("Users"):
    movie_rec_db.create_collection("Users", replication_factor=3)
In [ ]:
# Users has no side information
total_users = np.unique(ratings_df[['userId']].values.flatten()).shape[0]
print("Total number of Users:", total_users)
Total number of Users: 671
In [ ]:
def populate_user_collection(total_users):
    batch = []
    BATCH_SIZE = 50
    batch_idx = 1
    index = 0
    user_ids = list(user_mapping.keys())
    user_collection = movie_rec_db["Users"]
    for idx in tqdm(range(total_users)):
        insert_doc = {}

        insert_doc["_id"] = "Users/" + str(user_mapping[user_ids[idx]])
        insert_doc["original_id"] = str(user_ids[idx])
        
        batch.append(insert_doc)
        index +=1
        last_record = (idx == (total_users - 1))
        if index % BATCH_SIZE == 0:
            #print("Inserting batch %d" % (batch_idx))
            batch_idx += 1
            user_collection.import_bulk(batch)
            batch = []   
        if last_record and len(batch) > 0:
            print("Inserting batch the last batch!")
            user_collection.import_bulk(batch)
In [ ]:
populate_user_collection(total_users)
100%|██████████| 671/671 [00:01<00:00, 343.99it/s]
Inserting batch the last batch!

Creating Ratings (Edge) Collection¶

Here, we first create a Ratings (edge) collection in ArangoDB and then populate it with the edges of the bipartite graph. Each edge document in this collection contains the _from (user) and _to (movie) node references along with the rating given by that user to that particular movie. Once this collection is populated, a bipartite graph (user and movie nodes) exists in ArangoDB and can be viewed in the ArangoDB Web UI under Graphs -> movie_rating_graph.
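For illustration, a single document in the Ratings collection will look roughly like this (the key format follows the create_ratings_graph function shown below; the concrete ids and rating here are made up):

{
    "_id": "Ratings/user-0-r-movie-0",
    "_from": "Users/0",
    "_to": "Movie/0",
    "_rating": 2.5
}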

In [ ]:
# create a new collection named "Ratings" if it does not exist.
# This returns an API wrapper for "Ratings" collection.
if not movie_rec_db.has_collection("Ratings"):
    movie_rec_db.create_collection("Ratings", edge=True, replication_factor=3)
In [ ]:
# defining graph schema

# create a new graph called movie_rating_graph in the temp database if it does not already exist.
if not movie_rec_db.has_graph("movie_rating_graph"):
    movie_rec_db.create_graph('movie_rating_graph', smart=True)

# This returns an API wrapper for the above created graph
movie_rating_graph = movie_rec_db.graph("movie_rating_graph")
In [ ]:
# Create a new vertex collection named "Users" if it does not exist.
if not movie_rating_graph.has_vertex_collection("Users"):
    movie_rating_graph.vertex_collection("Users")
In [ ]:
# Create a new vertex collection named "Movie" if it does not exist.
if not movie_rating_graph.has_vertex_collection("Movie"):
    movie_rating_graph.vertex_collection("Movie")
In [ ]:
# creating an edge definition named "Ratings". This creates any missing
# collections and returns an API wrapper for the "Ratings" edge collection.
if not movie_rating_graph.has_edge_definition("Ratings"):
    Ratings = movie_rating_graph.create_edge_definition(
        edge_collection='Ratings',
        from_vertex_collections=['Users'],
        to_vertex_collections=['Movie']
    )
In [ ]:
user_id, movie_id, ratings = ratings_df[['userId']].values.flatten(), ratings_df[['movieId']].values.flatten() , ratings_df[['rating']].values.flatten()
In [ ]:
def create_ratings_graph(user_id, movie_id, ratings):
    batch = []
    BATCH_SIZE = 100
    batch_idx = 1
    index = 0
    edge_collection = movie_rec_db["Ratings"]
    for idx in tqdm(range(ratings.shape[0])):
        
        # skipping edges (movies) with no metadata
        if movie_id[idx] in no_metadata:
            print('Removing edges with no metadata', movie_id[idx])
            
        else:
            insert_doc = {}
            insert_doc = {"_id":    "Ratings" + "/" + 'user-' + str(user_mapping[user_id[idx]]) + "-r-" + "movie-" + str(movie_mappings[movie_id[idx]]), 
                          "_from":  ("Users" + "/" + str(user_mapping[user_id[idx]])),
                          "_to":    ("Movie" + "/" + str(movie_mappings[movie_id[idx]])),
                          "_rating": float(ratings[idx])}

            batch.append(insert_doc)
            index += 1
            last_record = (idx == (ratings.shape[0] - 1))

            if index % BATCH_SIZE == 0:
                #print("Inserting batch %d" % (batch_idx))
                batch_idx += 1
                edge_collection.import_bulk(batch)
                batch = []
            if last_record and len(batch) > 0:
                print("Inserting batch the last batch!")
                edge_collection.import_bulk(batch)
In [ ]:
create_ratings_graph(user_id, movie_id, ratings)

Visualization of the User-Movie-Ratings graph in ArangoDB's Web UI

[Figure: user-movie-ratings graph rendered in the ArangoDB Web UI]

Converting the graph present inside ArangoDB into a PyTorch Geometric (PyG) data object¶

So far we have seen how to construct a graph from multiple csv files and load that graph into ArangoDB. The next step is to export this graph from ArangoDB and construct a heterogeneous PyG graph.

In [ ]:
# Get API wrappers for collections.
users = movie_rec_db.collection('Users')
movies = movie_rec_db.collection('Movie')
ratings_graph = movie_rec_db.collection('Ratings')
In [ ]:
len(users), len(movies), len(ratings_graph)
Out[ ]:
(671, 9025, 99810)
In [ ]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Load edges from Ratings collection in ArangoDB and export them to PyG data format.¶

Data handling of graphs in PyG: in order to construct the edges of the graph in PyG, we need to represent the graph connectivity in COO format (edge_index), i.e. with shape [2, num_edges]. The create_pyg_edges method can therefore be seen as a generic function which reads the documents from an edge collection (Ratings) and creates PyG edges (edge_index) using the _from (src) and _to (dst) attributes of the rating documents. Since each edge of the graph also carries rating information, the create_pyg_edges method reads the _rating attribute from the edge collection as well and stores it in the PyG data object as edge_attr.
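As a tiny illustration of the COO format (made-up numbers, not taken from the dataset), a graph with three user-to-movie edges 0→2, 0→5 and 3→2 carrying ratings 4, 3 and 5 would be encoded as:

edge_index = torch.tensor([[0, 0, 3],   # source (user) node ids
                           [2, 5, 2]])  # destination (movie) node ids
edge_attr = torch.tensor([4, 3, 5])     # one rating per edge
print(edge_index.shape)                 # torch.Size([2, 3])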

In [ ]:
def create_pyg_edges(rating_docs):
    src = []
    dst = []
    ratings = []
    for doc in rating_docs:
        _from = int(doc['_from'].split('/')[1])
        _to   = int(doc['_to'].split('/')[1])
         
        src.append(_from)
        dst.append(_to)
        ratings.append(int(doc['_rating']))
        
    edge_index = torch.tensor([src, dst])
    edge_attr = torch.tensor(ratings)

    return edge_index, edge_attr 
In [ ]:
edge_index, edge_label = create_pyg_edges(movie_rec_db.aql.execute('FOR doc IN Ratings RETURN doc'))
In [ ]:
print(edge_index.shape)
print(edge_label.shape)
torch.Size([2, 99810])
torch.Size([99810])

Load nodes from the Movie collection in ArangoDB and export them to PyG data format.¶

So, in the above section we read the "Ratings" edge collection from ArangoDB and exported the edges into a PyG-acceptable data format, i.e. edge_index and edge_label. The next step is to construct movie node features; in order to construct them, I have written the following two methods:

  1. Sequence Encoder: This method takes two arguments. The first one is movie_docs, with the help of which we can access the metadata of each movie stored inside the "Movie" collection. The second argument is model_name, a pretrained NLP (transformer-based) model from the SentenceTransformers library used to generate text embeddings. In this blog post, I generate embeddings for movie titles and use them as movie node features. However, instead of the movie title we could also use the movie description attribute to generate embeddings for movie nodes. Curious readers can try this out and see if the results improve (a minimal sketch of that variant follows the note below).

  2. Genres Encoder: In this method we perform one-hot encoding of the genres present inside the Movie collection.

Once the features are generated by the sequence encoder and the genres encoder, we concatenate the two feature vectors to construct a single feature vector for each movie node.

Note: This process of feature generation for movie nodes is inspired by the PyG examples.
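If you want to experiment with description embeddings instead of title embeddings, a minimal sketch of that variant could look like this (same SentenceTransformer model as below; the function name DescriptionEncoder is my own):

def DescriptionEncoder(movie_docs, model_name=None):
    # use the 'description' attribute stored in the Movie collection instead of the title
    descriptions = [doc['description'] for doc in movie_docs]
    model = SentenceTransformer(model_name, device=device)
    return model.encode(descriptions, show_progress_bar=True,
                        convert_to_tensor=True, device=device)

# e.g. desc_emb = DescriptionEncoder(movie_rec_db.aql.execute('FOR doc IN Movie RETURN doc'), model_name='all-MiniLM-L6-v2')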

In [ ]:
def SequenceEncoder(movie_docs , model_name=None):
    movie_titles = [doc['movie_title'] for doc in movie_docs]
    model = SentenceTransformer(model_name, device=device)
    title_embeddings = model.encode(movie_titles, show_progress_bar=True,
                              convert_to_tensor=True, device=device)
    
    return title_embeddings
In [ ]:
def GenresEncoder(movie_docs):
    gen = []
    #sep = '|'
    for doc in movie_docs:
        gen.append(doc['genres'])
        #genre = doc['movie_genres']
        #gen.append(genre.split(sep))
    
    # getting unique genres
    unique_gen = set(list(itertools.chain(*gen)))
    print("Number of unqiue genres we have:", unique_gen)
    
    mapping = {g: i for i, g in enumerate(unique_gen)}
    x = torch.zeros(len(gen), len(mapping))
    for i, m_gen in enumerate(gen):
        for genre in m_gen:
            x[i, mapping[genre]] = 1
    return x.to(device)
In [ ]:
title_emb = SequenceEncoder(movie_rec_db.aql.execute('FOR doc IN Movie RETURN doc'), model_name='all-MiniLM-L6-v2')
encoded_genres = GenresEncoder(movie_rec_db.aql.execute('FOR doc IN Movie RETURN doc'))
print('Title Embeddings shape:', title_emb.shape)
print("Encoded Genres shape:", encoded_genres.shape)
Downloading:   0%|          | 0.00/1.18k [00:00<?, ?B/s]
Downloading:   0%|          | 0.00/10.2k [00:00<?, ?B/s]
Downloading:   0%|          | 0.00/612 [00:00<?, ?B/s]
Downloading:   0%|          | 0.00/116 [00:00<?, ?B/s]
Downloading:   0%|          | 0.00/39.3k [00:00<?, ?B/s]
Downloading:   0%|          | 0.00/349 [00:00<?, ?B/s]
Downloading:   0%|          | 0.00/90.9M [00:00<?, ?B/s]
Downloading:   0%|          | 0.00/53.0 [00:00<?, ?B/s]
Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]
Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]
Downloading:   0%|          | 0.00/350 [00:00<?, ?B/s]
Downloading:   0%|          | 0.00/13.2k [00:00<?, ?B/s]
Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]
Downloading:   0%|          | 0.00/190 [00:00<?, ?B/s]
Batches:   0%|          | 0/283 [00:00<?, ?it/s]
Number of unique genres we have: {'Fantasy', 'Music', 'Mystery', 'Family', 'Thriller', 'Science Fiction', 'Horror', 'TV Movie', 'Animation', 'Documentary', 'History', 'Foreign', 'Action', 'Adventure', 'Crime', 'Western', 'Drama', 'Romance', 'Comedy', 'War'}
Title Embeddings shape: torch.Size([9025, 384])
Encoded Genres shape: torch.Size([9025, 20])
In [ ]:
# concat title and genres features of movies
movie_x = torch.cat((title_emb, encoded_genres), dim=-1)
print("Shape of the concatenated features:", movie_x.shape)
Shape of the concatenated features: torch.Size([9025, 404])

Creating PyG Heterogeneous Graph¶

Heterogeneous graphs are graphs with different types of nodes and edges, e.g. knowledge graphs. The bipartite graph we have stored in ArangoDB is also a heterogeneous graph since it consists of two types of nodes, i.e. user and movie nodes. Therefore, our next step is to export the graph present inside ArangoDB to a PyG heterogeneous data object.

Now that we have the PyG edges, labels and node feature matrix, the next step is to add these tensors to a PyG HeteroData object in order to construct the heterogeneous graph.

In [ ]:
data = HeteroData()
In [ ]:
data['user'].num_nodes = len(users)  # Users do not have any features.
data['movie'].x = movie_x
data['user', 'rates', 'movie'].edge_index = edge_index
data['user', 'rates', 'movie'].edge_label = edge_label
In [ ]:
# Add user node features for message passing:
data['user'].x = torch.eye(data['user'].num_nodes, device=device)
del data['user'].num_nodes

We can now convert data into an appropriate format for training a graph-based machine learning model:

Here, ToUndirected() transforms a directed graph into (the PyG representation of) an undirected graph by adding reverse edges for all edges in the graph. Thus, future message passing is performed in both directions along all edges. The function may add reverse edge types to the heterogeneous graph, if necessary.

In [ ]:
# Add a reverse ('movie', 'rev_rates', 'user') relation for message passing.
data = ToUndirected()(data)
del data['movie', 'rev_rates', 'user'].edge_label  # Remove "reverse" label.
In [ ]:
data = data.to(device)
In [ ]:
# Perform a link-level split into training, validation, and test edges.
train_data, val_data, test_data = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    neg_sampling_ratio=0.0,
    edge_types=[('user', 'rates', 'movie')],
    rev_edge_types=[('movie', 'rev_rates', 'user')],
)(data)
In [ ]:
print('Train data:', train_data)
print('Val data:', val_data)
print('Test data', test_data)
Train data: HeteroData(
  user={ x=[671, 671] },
  movie={ x=[9025, 404] },
  (user, rates, movie)={
    edge_index=[2, 79848],
    edge_label=[79848],
    edge_label_index=[2, 79848]
  },
  (movie, rev_rates, user)={ edge_index=[2, 79848] }
)
Val data: HeteroData(
  user={ x=[671, 671] },
  movie={ x=[9025, 404] },
  (user, rates, movie)={
    edge_index=[2, 79848],
    edge_label=[9981],
    edge_label_index=[2, 9981]
  },
  (movie, rev_rates, user)={ edge_index=[2, 79848] }
)
Test data HeteroData(
  user={ x=[671, 671] },
  movie={ x=[9025, 404] },
  (user, rates, movie)={
    edge_index=[2, 89829],
    edge_label=[9981],
    edge_label_index=[2, 9981]
  },
  (movie, rev_rates, user)={ edge_index=[2, 89829] }
)

Some heterogeneous graph statistics available in PyG

In [ ]:
# Slicing edge label to get the corresponding split (hence this gives train split)
train_data['user', 'movie'].edge_label_index
Out[ ]:
tensor([[  20,  595,   22,  ...,  623,   74,  242],
        [ 180,  416,  485,  ..., 8771,  119,  164]], device='cuda:0')
In [ ]:
# feature matrix for all the node types
data.x_dict
Out[ ]:
{'movie': tensor([[-0.0174,  0.0269, -0.0453,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0046,  0.0148,  0.0306,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0476,  0.0260, -0.0114,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0203,  0.0347,  0.0071,  ...,  1.0000,  1.0000,  0.0000],
         [-0.0611,  0.0251,  0.0478,  ...,  1.0000,  1.0000,  0.0000],
         [-0.0880,  0.0280, -0.0884,  ...,  0.0000,  1.0000,  0.0000]],
        device='cuda:0'), 'user': tensor([[1., 0., 0.,  ..., 0., 0., 0.],
         [0., 1., 0.,  ..., 0., 0., 0.],
         [0., 0., 1.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 1., 0., 0.],
         [0., 0., 0.,  ..., 0., 1., 0.],
         [0., 0., 0.,  ..., 0., 0., 1.]], device='cuda:0')}
In [ ]:
data.x_dict['user'].shape, data.x_dict['movie'].shape
Out[ ]:
(torch.Size([671, 671]), torch.Size([9025, 404]))
In [ ]:
# converting everything to dict
data.to_dict()
Out[ ]:
{('movie',
  'rev_rates',
  'user'): {'edge_index': tensor([[   0,    1,    2,  ..., 1327, 1329, 2941],
          [   0,    0,    0,  ...,  670,  670,  670]], device='cuda:0')},
 ('user',
  'rates',
  'movie'): {'edge_index': tensor([[   0,    0,    0,  ...,  670,  670,  670],
          [   0,    1,    2,  ..., 1327, 1329, 2941]], device='cuda:0'),
  'edge_label': tensor([2, 3, 3,  ..., 4, 2, 3], device='cuda:0')},
 'movie': {'x': tensor([[-0.0174,  0.0269, -0.0453,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0046,  0.0148,  0.0306,  ...,  0.0000,  0.0000,  0.0000],
          [-0.0476,  0.0260, -0.0114,  ...,  0.0000,  0.0000,  0.0000],
          ...,
          [ 0.0203,  0.0347,  0.0071,  ...,  1.0000,  1.0000,  0.0000],
          [-0.0611,  0.0251,  0.0478,  ...,  1.0000,  1.0000,  0.0000],
          [-0.0880,  0.0280, -0.0884,  ...,  0.0000,  1.0000,  0.0000]],
         device='cuda:0')},
 'user': {'x': tensor([[1., 0., 0.,  ..., 0., 0., 0.],
          [0., 1., 0.,  ..., 0., 0., 0.],
          [0., 0., 1.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 1., 0., 0.],
          [0., 0., 0.,  ..., 0., 1., 0.],
          [0., 0., 0.,  ..., 0., 0., 1.]], device='cuda:0')}}
In [ ]:
data.edge_index_dict
Out[ ]:
{('movie',
  'rev_rates',
  'user'): tensor([[   0,    1,    2,  ..., 1327, 1329, 2941],
         [   0,    0,    0,  ...,  670,  670,  670]], device='cuda:0'),
 ('user',
  'rates',
  'movie'): tensor([[   0,    0,    0,  ...,  670,  670,  670],
         [   0,    1,    2,  ..., 1327, 1329, 2941]], device='cuda:0')}
In [ ]:
data.edge_label_dict
Out[ ]:
{('user', 'rates', 'movie'): tensor([2, 3, 3,  ..., 4, 2, 3], device='cuda:0')}
In [ ]:
# different types of nodes in Hetero graph
node_types, edge_types = data.metadata()
print('Different types of nodes in graph:',node_types)
print('Different types of edges in graph:',edge_types)
Different types of nodes in graph: ['user', 'movie']
Different types of edges in graph: [('user', 'rates', 'movie'), ('movie', 'rev_rates', 'user')]
In [ ]:
# We have an unbalanced dataset with many labels for rating 3 and 4, and very
# few for 0 and 1. Therefore we use a weighted MSE loss.
weight = torch.bincount(train_data['user', 'movie'].edge_label)
weight = weight.max() / weight
In [ ]:
def weighted_mse_loss(pred, target, weight=None):
    weight = 1. if weight is None else weight[target].to(pred.dtype)
    return (weight * (pred - target.to(pred.dtype)).pow(2)).mean()

The above code section shows how easy it is to transform your ArangoDB graph into a PyG heterogeneous graph object. Once the graph object is created, we can apply heterogeneous graph learning (heterogeneous graph neural networks) to it in order to predict missing links between users and movies. Explaining the inner workings of heterogeneous graph learning is beyond the scope of this article, but you can check this awesome documentation by the PyG team to learn more about it.

The below diagram shows an example of a bipartite graph of user and movie nodes with the predicted links generated using Heterogeneous Graph Learning.

[Figure: bipartite graph of user and movie nodes with predicted links]

In [ ]:
class GNNEncoder(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # these convolutions have been replicated to match the number of edge types
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x
In [ ]:
class EdgeDecoder(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.lin1 = Linear(2 * hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, 1)
        
    def forward(self, z_dict, edge_label_index):
        row, col = edge_label_index
        # concat user and movie embeddings
        z = torch.cat([z_dict['user'][row], z_dict['movie'][col]], dim=-1)
        # concatenated embeddings passed to linear layer
        z = self.lin1(z).relu()
        z = self.lin2(z)
        return z.view(-1)
In [ ]:
class Model(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.encoder = GNNEncoder(hidden_channels, hidden_channels)
        self.encoder = to_hetero(self.encoder, data.metadata(), aggr='sum')
        self.decoder = EdgeDecoder(hidden_channels)

    def forward(self, x_dict, edge_index_dict, edge_label_index):
        # z_dict contains a dictionary of movie and user embeddings returned by the GraphSAGE encoder
        z_dict = self.encoder(x_dict, edge_index_dict)
        return self.decoder(z_dict, edge_label_index)
In [ ]:
model = Model(hidden_channels=32).to(device)
In [ ]:
# Due to lazy initialization, we need to run one model step so the number
# of parameters can be inferred:
with torch.no_grad():
    model.encoder(train_data.x_dict, train_data.edge_index_dict)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
In [ ]:
def train():
    model.train()
    optimizer.zero_grad()
    pred = model(train_data.x_dict, train_data.edge_index_dict,
                 train_data['user', 'movie'].edge_label_index)
    target = train_data['user', 'movie'].edge_label
    loss = weighted_mse_loss(pred, target, weight)
    loss.backward()
    optimizer.step()
    return float(loss)
In [ ]:
@torch.no_grad()
def test(data):
    model.eval()
    pred = model(data.x_dict, data.edge_index_dict,
                 data['user', 'movie'].edge_label_index)
    pred = pred.clamp(min=0, max=5)
    target = data['user', 'movie'].edge_label.float()
    rmse = F.mse_loss(pred, target).sqrt()
    return float(rmse)
In [ ]:
for epoch in range(1, 300):
    loss = train()
    train_rmse = test(train_data)
    val_rmse = test(val_data)
    test_rmse = test(test_data)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_rmse:.4f}, '
          f'Val: {val_rmse:.4f}, Test: {test_rmse:.4f}')
Epoch: 001, Loss: 22.2784, Train: 3.5865, Val: 3.5914, Test: 3.6023
Epoch: 002, Loss: 20.3715, Train: 3.4158, Val: 3.4221, Test: 3.4327
Epoch: 003, Loss: 18.1650, Train: 3.0623, Val: 3.0708, Test: 3.0808
Epoch: 004, Loss: 14.6564, Train: 2.4797, Val: 2.4919, Test: 2.5009
Epoch: 005, Loss: 10.1307, Train: 1.5875, Val: 1.6027, Test: 1.6097
Epoch: 006, Loss: 6.5024, Train: 1.1967, Val: 1.1825, Test: 1.1837
Epoch: 007, Loss: 10.3720, Train: 1.1482, Val: 1.1384, Test: 1.1401
Epoch: 008, Loss: 9.5455, Train: 1.1540, Val: 1.1623, Test: 1.1668
Epoch: 009, Loss: 6.9450, Train: 1.4911, Val: 1.5062, Test: 1.5129
Epoch: 010, Loss: 6.2815, Train: 1.8320, Val: 1.8470, Test: 1.8547
Epoch: 011, Loss: 6.9489, Train: 2.0241, Val: 2.0382, Test: 2.0463
Epoch: 012, Loss: 7.6396, Train: 2.0684, Val: 2.0823, Test: 2.0904
Epoch: 013, Loss: 7.8156, Train: 1.9945, Val: 2.0088, Test: 2.0167
Epoch: 014, Loss: 7.4798, Train: 1.8276, Val: 1.8428, Test: 1.8503
Epoch: 015, Loss: 6.8426, Train: 1.5962, Val: 1.6124, Test: 1.6192
Epoch: 016, Loss: 6.2175, Train: 1.3502, Val: 1.3663, Test: 1.3721
Epoch: 017, Loss: 5.9493, Train: 1.1663, Val: 1.1802, Test: 1.1850
Epoch: 018, Loss: 6.1836, Train: 1.0887, Val: 1.0998, Test: 1.1039
Epoch: 019, Loss: 6.5598, Train: 1.0814, Val: 1.0927, Test: 1.0967
Epoch: 020, Loss: 6.5292, Train: 1.1288, Val: 1.1433, Test: 1.1477
Epoch: 021, Loss: 6.1160, Train: 1.2430, Val: 1.2601, Test: 1.2652
Epoch: 022, Loss: 5.7558, Train: 1.3919, Val: 1.4100, Test: 1.4155
Epoch: 023, Loss: 5.6795, Train: 1.5220, Val: 1.5401, Test: 1.5459
Epoch: 024, Loss: 5.7957, Train: 1.5992, Val: 1.6172, Test: 1.6231
Epoch: 025, Loss: 5.9129, Train: 1.6123, Val: 1.6304, Test: 1.6362
Epoch: 026, Loss: 5.9067, Train: 1.5638, Val: 1.5824, Test: 1.5879
Epoch: 027, Loss: 5.7577, Train: 1.4659, Val: 1.4852, Test: 1.4902
Epoch: 028, Loss: 5.5330, Train: 1.3408, Val: 1.3607, Test: 1.3652
Epoch: 029, Loss: 5.3456, Train: 1.2204, Val: 1.2404, Test: 1.2444
Epoch: 030, Loss: 5.2851, Train: 1.1356, Val: 1.1552, Test: 1.1588
Epoch: 031, Loss: 5.3319, Train: 1.0988, Val: 1.1180, Test: 1.1214
Epoch: 032, Loss: 5.3516, Train: 1.1052, Val: 1.1250, Test: 1.1281
Epoch: 033, Loss: 5.2442, Train: 1.1499, Val: 1.1707, Test: 1.1737
Epoch: 034, Loss: 5.0641, Train: 1.2228, Val: 1.2443, Test: 1.2470
Epoch: 035, Loss: 4.9293, Train: 1.3002, Val: 1.3220, Test: 1.3243
Epoch: 036, Loss: 4.8827, Train: 1.3554, Val: 1.3772, Test: 1.3791
Epoch: 037, Loss: 4.8734, Train: 1.3710, Val: 1.3930, Test: 1.3944
Epoch: 038, Loss: 4.8297, Train: 1.3428, Val: 1.3651, Test: 1.3659
Epoch: 039, Loss: 4.7220, Train: 1.2788, Val: 1.3013, Test: 1.3016
Epoch: 040, Loss: 4.5793, Train: 1.1984, Val: 1.2207, Test: 1.2208
Epoch: 041, Loss: 4.4668, Train: 1.1308, Val: 1.1522, Test: 1.1523
Epoch: 042, Loss: 4.4240, Train: 1.1038, Val: 1.1243, Test: 1.1242
Epoch: 043, Loss: 4.3879, Train: 1.1133, Val: 1.1339, Test: 1.1329
Epoch: 044, Loss: 4.2991, Train: 1.1472, Val: 1.1683, Test: 1.1659
Epoch: 045, Loss: 4.1880, Train: 1.1953, Val: 1.2169, Test: 1.2130
Epoch: 046, Loss: 4.1127, Train: 1.2381, Val: 1.2599, Test: 1.2547
Epoch: 047, Loss: 4.0769, Train: 1.2593, Val: 1.2809, Test: 1.2746
Epoch: 048, Loss: 4.0441, Train: 1.2535, Val: 1.2747, Test: 1.2676
Epoch: 049, Loss: 3.9916, Train: 1.2279, Val: 1.2482, Test: 1.2407
Epoch: 050, Loss: 3.9349, Train: 1.1997, Val: 1.2188, Test: 1.2112
Epoch: 051, Loss: 3.9034, Train: 1.1846, Val: 1.2024, Test: 1.1948
Epoch: 052, Loss: 3.8982, Train: 1.1876, Val: 1.2051, Test: 1.1968
Epoch: 053, Loss: 3.8860, Train: 1.2068, Val: 1.2248, Test: 1.2153
Epoch: 054, Loss: 3.8524, Train: 1.2356, Val: 1.2547, Test: 1.2436
Epoch: 055, Loss: 3.8228, Train: 1.2631, Val: 1.2831, Test: 1.2708
Epoch: 056, Loss: 3.8123, Train: 1.2763, Val: 1.2970, Test: 1.2837
Epoch: 057, Loss: 3.8048, Train: 1.2686, Val: 1.2897, Test: 1.2758
Epoch: 058, Loss: 3.7824, Train: 1.2455, Val: 1.2666, Test: 1.2527
Epoch: 059, Loss: 3.7555, Train: 1.2202, Val: 1.2413, Test: 1.2277
Epoch: 060, Loss: 3.7394, Train: 1.2058, Val: 1.2274, Test: 1.2136
Epoch: 061, Loss: 3.7242, Train: 1.2063, Val: 1.2290, Test: 1.2145
Epoch: 062, Loss: 3.6964, Train: 1.2177, Val: 1.2417, Test: 1.2263
Epoch: 063, Loss: 3.6685, Train: 1.2285, Val: 1.2538, Test: 1.2377
Epoch: 064, Loss: 3.6507, Train: 1.2273, Val: 1.2534, Test: 1.2370
Epoch: 065, Loss: 3.6328, Train: 1.2112, Val: 1.2378, Test: 1.2214
Epoch: 066, Loss: 3.6096, Train: 1.1880, Val: 1.2151, Test: 1.1989
Epoch: 067, Loss: 3.5903, Train: 1.1698, Val: 1.1973, Test: 1.1813
Epoch: 068, Loss: 3.5779, Train: 1.1636, Val: 1.1918, Test: 1.1758
Epoch: 069, Loss: 3.5635, Train: 1.1689, Val: 1.1980, Test: 1.1817
Epoch: 070, Loss: 3.5457, Train: 1.1787, Val: 1.2087, Test: 1.1922
Epoch: 071, Loss: 3.5313, Train: 1.1836, Val: 1.2144, Test: 1.1977
Epoch: 072, Loss: 3.5191, Train: 1.1780, Val: 1.2096, Test: 1.1929
Epoch: 073, Loss: 3.5040, Train: 1.1646, Val: 1.1966, Test: 1.1804
Epoch: 074, Loss: 3.4878, Train: 1.1514, Val: 1.1839, Test: 1.1680
Epoch: 075, Loss: 3.4749, Train: 1.1457, Val: 1.1787, Test: 1.1631
Epoch: 076, Loss: 3.4624, Train: 1.1495, Val: 1.1832, Test: 1.1675
Epoch: 077, Loss: 3.4475, Train: 1.1583, Val: 1.1928, Test: 1.1770
Epoch: 078, Loss: 3.4332, Train: 1.1650, Val: 1.2001, Test: 1.1843
Epoch: 079, Loss: 3.4208, Train: 1.1640, Val: 1.1996, Test: 1.1839
Epoch: 080, Loss: 3.4070, Train: 1.1562, Val: 1.1922, Test: 1.1768
Epoch: 081, Loss: 3.3919, Train: 1.1476, Val: 1.1840, Test: 1.1688
Epoch: 082, Loss: 3.3786, Train: 1.1440, Val: 1.1809, Test: 1.1658
Epoch: 083, Loss: 3.3661, Train: 1.1473, Val: 1.1849, Test: 1.1697
Epoch: 084, Loss: 3.3529, Train: 1.1545, Val: 1.1927, Test: 1.1773
Epoch: 085, Loss: 3.3407, Train: 1.1597, Val: 1.1985, Test: 1.1830
Epoch: 086, Loss: 3.3300, Train: 1.1589, Val: 1.1984, Test: 1.1829
Epoch: 087, Loss: 3.3189, Train: 1.1532, Val: 1.1931, Test: 1.1779
Epoch: 088, Loss: 3.3077, Train: 1.1474, Val: 1.1879, Test: 1.1728
Epoch: 089, Loss: 3.2978, Train: 1.1460, Val: 1.1871, Test: 1.1721
Epoch: 090, Loss: 3.2881, Train: 1.1495, Val: 1.1911, Test: 1.1761
Epoch: 091, Loss: 3.2783, Train: 1.1541, Val: 1.1963, Test: 1.1812
Epoch: 092, Loss: 3.2694, Train: 1.1553, Val: 1.1979, Test: 1.1828
Epoch: 093, Loss: 3.2610, Train: 1.1516, Val: 1.1946, Test: 1.1796
Epoch: 094, Loss: 3.2523, Train: 1.1458, Val: 1.1892, Test: 1.1745
Epoch: 095, Loss: 3.2443, Train: 1.1425, Val: 1.1864, Test: 1.1719
Epoch: 096, Loss: 3.2370, Train: 1.1436, Val: 1.1879, Test: 1.1734
Epoch: 097, Loss: 3.2296, Train: 1.1468, Val: 1.1915, Test: 1.1770
Epoch: 098, Loss: 3.2227, Train: 1.1482, Val: 1.1934, Test: 1.1788
Epoch: 099, Loss: 3.2163, Train: 1.1458, Val: 1.1916, Test: 1.1771
Epoch: 100, Loss: 3.2099, Train: 1.1415, Val: 1.1877, Test: 1.1733
Epoch: 101, Loss: 3.2038, Train: 1.1384, Val: 1.1852, Test: 1.1708
Epoch: 102, Loss: 3.1980, Train: 1.1385, Val: 1.1857, Test: 1.1713
Epoch: 103, Loss: 3.1923, Train: 1.1404, Val: 1.1881, Test: 1.1735
Epoch: 104, Loss: 3.1868, Train: 1.1413, Val: 1.1894, Test: 1.1746
Epoch: 105, Loss: 3.1816, Train: 1.1397, Val: 1.1883, Test: 1.1734
Epoch: 106, Loss: 3.1764, Train: 1.1369, Val: 1.1861, Test: 1.1711
Epoch: 107, Loss: 3.1714, Train: 1.1357, Val: 1.1852, Test: 1.1701
Epoch: 108, Loss: 3.1666, Train: 1.1365, Val: 1.1863, Test: 1.1710
Epoch: 109, Loss: 3.1619, Train: 1.1376, Val: 1.1876, Test: 1.1723
Epoch: 110, Loss: 3.1575, Train: 1.1376, Val: 1.1878, Test: 1.1724
Epoch: 111, Loss: 3.1532, Train: 1.1362, Val: 1.1866, Test: 1.1712
Epoch: 112, Loss: 3.1490, Train: 1.1347, Val: 1.1853, Test: 1.1698
Epoch: 113, Loss: 3.1449, Train: 1.1343, Val: 1.1849, Test: 1.1694
Epoch: 114, Loss: 3.1411, Train: 1.1349, Val: 1.1856, Test: 1.1699
Epoch: 115, Loss: 3.1374, Train: 1.1352, Val: 1.1860, Test: 1.1703
Epoch: 116, Loss: 3.1338, Train: 1.1348, Val: 1.1856, Test: 1.1699
Epoch: 117, Loss: 3.1304, Train: 1.1339, Val: 1.1848, Test: 1.1690
Epoch: 118, Loss: 3.1271, Train: 1.1333, Val: 1.1841, Test: 1.1683
Epoch: 119, Loss: 3.1238, Train: 1.1330, Val: 1.1839, Test: 1.1679
Epoch: 120, Loss: 3.1207, Train: 1.1328, Val: 1.1837, Test: 1.1676
Epoch: 121, Loss: 3.1177, Train: 1.1326, Val: 1.1835, Test: 1.1673
Epoch: 122, Loss: 3.1147, Train: 1.1322, Val: 1.1831, Test: 1.1668
Epoch: 123, Loss: 3.1119, Train: 1.1319, Val: 1.1829, Test: 1.1665
Epoch: 124, Loss: 3.1091, Train: 1.1317, Val: 1.1827, Test: 1.1662
Epoch: 125, Loss: 3.1064, Train: 1.1314, Val: 1.1824, Test: 1.1658
Epoch: 126, Loss: 3.1038, Train: 1.1311, Val: 1.1821, Test: 1.1654
Epoch: 127, Loss: 3.1013, Train: 1.1310, Val: 1.1820, Test: 1.1652
Epoch: 128, Loss: 3.0988, Train: 1.1311, Val: 1.1820, Test: 1.1652
Epoch: 129, Loss: 3.0963, Train: 1.1310, Val: 1.1818, Test: 1.1650
Epoch: 130, Loss: 3.0939, Train: 1.1303, Val: 1.1811, Test: 1.1643
Epoch: 131, Loss: 3.0916, Train: 1.1299, Val: 1.1807, Test: 1.1638
Epoch: 132, Loss: 3.0893, Train: 1.1303, Val: 1.1810, Test: 1.1642
Epoch: 133, Loss: 3.0871, Train: 1.1305, Val: 1.1812, Test: 1.1644
Epoch: 134, Loss: 3.0850, Train: 1.1304, Val: 1.1811, Test: 1.1643
Epoch: 135, Loss: 3.0828, Train: 1.1298, Val: 1.1805, Test: 1.1636
Epoch: 136, Loss: 3.0808, Train: 1.1296, Val: 1.1803, Test: 1.1634
Epoch: 137, Loss: 3.0787, Train: 1.1298, Val: 1.1803, Test: 1.1635
Epoch: 138, Loss: 3.0768, Train: 1.1295, Val: 1.1799, Test: 1.1633
Epoch: 139, Loss: 3.0747, Train: 1.1292, Val: 1.1796, Test: 1.1630
Epoch: 140, Loss: 3.0729, Train: 1.1294, Val: 1.1797, Test: 1.1632
Epoch: 141, Loss: 3.0709, Train: 1.1293, Val: 1.1795, Test: 1.1631
Epoch: 142, Loss: 3.0690, Train: 1.1285, Val: 1.1788, Test: 1.1624
Epoch: 143, Loss: 3.0672, Train: 1.1283, Val: 1.1786, Test: 1.1622
Epoch: 144, Loss: 3.0653, Train: 1.1285, Val: 1.1786, Test: 1.1624
Epoch: 145, Loss: 3.0636, Train: 1.1280, Val: 1.1781, Test: 1.1619
Epoch: 146, Loss: 3.0618, Train: 1.1278, Val: 1.1778, Test: 1.1616
Epoch: 147, Loss: 3.0600, Train: 1.1277, Val: 1.1776, Test: 1.1615
Epoch: 148, Loss: 3.0584, Train: 1.1272, Val: 1.1772, Test: 1.1610
Epoch: 149, Loss: 3.0566, Train: 1.1275, Val: 1.1773, Test: 1.1613
Epoch: 150, Loss: 3.0549, Train: 1.1276, Val: 1.1774, Test: 1.1614
Epoch: 151, Loss: 3.0533, Train: 1.1268, Val: 1.1765, Test: 1.1606
Epoch: 152, Loss: 3.0515, Train: 1.1266, Val: 1.1763, Test: 1.1604
Epoch: 153, Loss: 3.0499, Train: 1.1269, Val: 1.1765, Test: 1.1606
Epoch: 154, Loss: 3.0483, Train: 1.1269, Val: 1.1765, Test: 1.1607
Epoch: 155, Loss: 3.0466, Train: 1.1269, Val: 1.1762, Test: 1.1606
Epoch: 156, Loss: 3.0450, Train: 1.1256, Val: 1.1750, Test: 1.1595
Epoch: 157, Loss: 3.0434, Train: 1.1258, Val: 1.1751, Test: 1.1597
Epoch: 158, Loss: 3.0418, Train: 1.1268, Val: 1.1760, Test: 1.1606
Epoch: 159, Loss: 3.0401, Train: 1.1263, Val: 1.1755, Test: 1.1601
Epoch: 160, Loss: 3.0386, Train: 1.1256, Val: 1.1747, Test: 1.1594
Epoch: 161, Loss: 3.0369, Train: 1.1253, Val: 1.1744, Test: 1.1592
Epoch: 162, Loss: 3.0353, Train: 1.1260, Val: 1.1749, Test: 1.1598
Epoch: 163, Loss: 3.0337, Train: 1.1254, Val: 1.1744, Test: 1.1594
Epoch: 164, Loss: 3.0322, Train: 1.1254, Val: 1.1743, Test: 1.1593
Epoch: 165, Loss: 3.0305, Train: 1.1252, Val: 1.1740, Test: 1.1592
Epoch: 166, Loss: 3.0289, Train: 1.1247, Val: 1.1736, Test: 1.1588
Epoch: 167, Loss: 3.0273, Train: 1.1253, Val: 1.1740, Test: 1.1593
Epoch: 168, Loss: 3.0257, Train: 1.1248, Val: 1.1734, Test: 1.1588
Epoch: 169, Loss: 3.0241, Train: 1.1235, Val: 1.1722, Test: 1.1577
Epoch: 170, Loss: 3.0225, Train: 1.1240, Val: 1.1726, Test: 1.1583
Epoch: 171, Loss: 3.0207, Train: 1.1243, Val: 1.1728, Test: 1.1586
Epoch: 172, Loss: 3.0190, Train: 1.1237, Val: 1.1721, Test: 1.1581
Epoch: 173, Loss: 3.0174, Train: 1.1231, Val: 1.1715, Test: 1.1574
Epoch: 174, Loss: 3.0158, Train: 1.1228, Val: 1.1713, Test: 1.1572
Epoch: 175, Loss: 3.0141, Train: 1.1237, Val: 1.1721, Test: 1.1579
Epoch: 176, Loss: 3.0123, Train: 1.1230, Val: 1.1715, Test: 1.1572
Epoch: 177, Loss: 3.0106, Train: 1.1218, Val: 1.1703, Test: 1.1562
Epoch: 178, Loss: 3.0090, Train: 1.1230, Val: 1.1713, Test: 1.1571
Epoch: 179, Loss: 3.0073, Train: 1.1226, Val: 1.1710, Test: 1.1569
Epoch: 180, Loss: 3.0055, Train: 1.1221, Val: 1.1705, Test: 1.1564
Epoch: 181, Loss: 3.0037, Train: 1.1219, Val: 1.1702, Test: 1.1562
Epoch: 182, Loss: 3.0020, Train: 1.1213, Val: 1.1699, Test: 1.1558
Epoch: 183, Loss: 3.0005, Train: 1.1223, Val: 1.1707, Test: 1.1566
Epoch: 184, Loss: 2.9985, Train: 1.1213, Val: 1.1697, Test: 1.1557
Epoch: 185, Loss: 2.9969, Train: 1.1198, Val: 1.1684, Test: 1.1546
Epoch: 186, Loss: 2.9950, Train: 1.1222, Val: 1.1705, Test: 1.1568
Epoch: 187, Loss: 2.9931, Train: 1.1221, Val: 1.1703, Test: 1.1566
Epoch: 188, Loss: 2.9916, Train: 1.1186, Val: 1.1673, Test: 1.1536
Epoch: 189, Loss: 2.9896, Train: 1.1207, Val: 1.1693, Test: 1.1555
Epoch: 190, Loss: 2.9877, Train: 1.1224, Val: 1.1708, Test: 1.1571
Epoch: 191, Loss: 2.9860, Train: 1.1183, Val: 1.1670, Test: 1.1536
Epoch: 192, Loss: 2.9840, Train: 1.1190, Val: 1.1676, Test: 1.1542
Epoch: 193, Loss: 2.9820, Train: 1.1217, Val: 1.1701, Test: 1.1567
Epoch: 194, Loss: 2.9802, Train: 1.1189, Val: 1.1678, Test: 1.1544
Epoch: 195, Loss: 2.9784, Train: 1.1185, Val: 1.1674, Test: 1.1541
Epoch: 196, Loss: 2.9763, Train: 1.1199, Val: 1.1686, Test: 1.1553
Epoch: 197, Loss: 2.9746, Train: 1.1181, Val: 1.1671, Test: 1.1540
Epoch: 198, Loss: 2.9729, Train: 1.1193, Val: 1.1681, Test: 1.1552
Epoch: 199, Loss: 2.9705, Train: 1.1190, Val: 1.1677, Test: 1.1548
Epoch: 200, Loss: 2.9688, Train: 1.1160, Val: 1.1652, Test: 1.1524
Epoch: 201, Loss: 2.9665, Train: 1.1189, Val: 1.1679, Test: 1.1551
Epoch: 202, Loss: 2.9646, Train: 1.1196, Val: 1.1684, Test: 1.1557
Epoch: 203, Loss: 2.9624, Train: 1.1146, Val: 1.1638, Test: 1.1511
Epoch: 204, Loss: 2.9604, Train: 1.1167, Val: 1.1659, Test: 1.1534
Epoch: 205, Loss: 2.9582, Train: 1.1205, Val: 1.1695, Test: 1.1570
Epoch: 206, Loss: 2.9561, Train: 1.1143, Val: 1.1636, Test: 1.1512
Epoch: 207, Loss: 2.9540, Train: 1.1143, Val: 1.1638, Test: 1.1515
Epoch: 208, Loss: 2.9515, Train: 1.1202, Val: 1.1695, Test: 1.1571
Epoch: 209, Loss: 2.9495, Train: 1.1142, Val: 1.1639, Test: 1.1517
Epoch: 210, Loss: 2.9469, Train: 1.1127, Val: 1.1627, Test: 1.1506
Epoch: 211, Loss: 2.9446, Train: 1.1186, Val: 1.1683, Test: 1.1564
Epoch: 212, Loss: 2.9425, Train: 1.1146, Val: 1.1648, Test: 1.1528
Epoch: 213, Loss: 2.9398, Train: 1.1121, Val: 1.1625, Test: 1.1506
Epoch: 214, Loss: 2.9378, Train: 1.1167, Val: 1.1669, Test: 1.1551
Epoch: 215, Loss: 2.9350, Train: 1.1141, Val: 1.1648, Test: 1.1531
Epoch: 216, Loss: 2.9327, Train: 1.1130, Val: 1.1638, Test: 1.1520
Epoch: 217, Loss: 2.9303, Train: 1.1139, Val: 1.1648, Test: 1.1531
Epoch: 218, Loss: 2.9276, Train: 1.1140, Val: 1.1651, Test: 1.1535
Epoch: 219, Loss: 2.9251, Train: 1.1125, Val: 1.1638, Test: 1.1521
Epoch: 220, Loss: 2.9226, Train: 1.1126, Val: 1.1641, Test: 1.1525
Epoch: 221, Loss: 2.9200, Train: 1.1138, Val: 1.1654, Test: 1.1538
Epoch: 222, Loss: 2.9174, Train: 1.1100, Val: 1.1621, Test: 1.1505
Epoch: 223, Loss: 2.9149, Train: 1.1136, Val: 1.1655, Test: 1.1541
Epoch: 224, Loss: 2.9122, Train: 1.1098, Val: 1.1622, Test: 1.1511
Epoch: 225, Loss: 2.9097, Train: 1.1126, Val: 1.1648, Test: 1.1537
Epoch: 226, Loss: 2.9070, Train: 1.1086, Val: 1.1614, Test: 1.1506
Epoch: 227, Loss: 2.9042, Train: 1.1123, Val: 1.1649, Test: 1.1542
Epoch: 228, Loss: 2.9015, Train: 1.1081, Val: 1.1613, Test: 1.1508
Epoch: 229, Loss: 2.8986, Train: 1.1105, Val: 1.1637, Test: 1.1533
Epoch: 230, Loss: 2.8957, Train: 1.1089, Val: 1.1624, Test: 1.1520
Epoch: 231, Loss: 2.8930, Train: 1.1074, Val: 1.1614, Test: 1.1512
Epoch: 232, Loss: 2.8903, Train: 1.1118, Val: 1.1657, Test: 1.1553
Epoch: 233, Loss: 2.8875, Train: 1.1022, Val: 1.1571, Test: 1.1471
Epoch: 234, Loss: 2.8849, Train: 1.1168, Val: 1.1705, Test: 1.1604
Epoch: 235, Loss: 2.8829, Train: 1.0958, Val: 1.1518, Test: 1.1425
Epoch: 236, Loss: 2.8813, Train: 1.1216, Val: 1.1753, Test: 1.1656
Epoch: 237, Loss: 2.8795, Train: 1.0932, Val: 1.1500, Test: 1.1411
Epoch: 238, Loss: 2.8763, Train: 1.1173, Val: 1.1718, Test: 1.1626
Epoch: 239, Loss: 2.8717, Train: 1.1008, Val: 1.1573, Test: 1.1485
Epoch: 240, Loss: 2.8670, Train: 1.1037, Val: 1.1601, Test: 1.1514
Epoch: 241, Loss: 2.8633, Train: 1.1117, Val: 1.1675, Test: 1.1589
Epoch: 242, Loss: 2.8610, Train: 1.0945, Val: 1.1524, Test: 1.1444
Epoch: 243, Loss: 2.8593, Train: 1.1173, Val: 1.1730, Test: 1.1645
Epoch: 244, Loss: 2.8573, Train: 1.0916, Val: 1.1505, Test: 1.1426
Epoch: 245, Loss: 2.8540, Train: 1.1141, Val: 1.1708, Test: 1.1626
Epoch: 246, Loss: 2.8495, Train: 1.0970, Val: 1.1560, Test: 1.1481
Epoch: 247, Loss: 2.8448, Train: 1.1031, Val: 1.1618, Test: 1.1539
Epoch: 248, Loss: 2.8407, Train: 1.1057, Val: 1.1643, Test: 1.1565
Epoch: 249, Loss: 2.8376, Train: 1.0941, Val: 1.1544, Test: 1.1470
Epoch: 250, Loss: 2.8351, Train: 1.1120, Val: 1.1708, Test: 1.1631
Epoch: 251, Loss: 2.8328, Train: 1.0886, Val: 1.1506, Test: 1.1436
Epoch: 252, Loss: 2.8310, Train: 1.1161, Val: 1.1752, Test: 1.1676
Epoch: 253, Loss: 2.8284, Train: 1.0848, Val: 1.1480, Test: 1.1411
Epoch: 254, Loss: 2.8254, Train: 1.1155, Val: 1.1754, Test: 1.1679
Epoch: 255, Loss: 2.8218, Train: 1.0863, Val: 1.1501, Test: 1.1434
Epoch: 256, Loss: 2.8175, Train: 1.1097, Val: 1.1710, Test: 1.1639
Epoch: 257, Loss: 2.8126, Train: 1.0909, Val: 1.1549, Test: 1.1482
Epoch: 258, Loss: 2.8078, Train: 1.1013, Val: 1.1644, Test: 1.1575
Epoch: 259, Loss: 2.8035, Train: 1.0977, Val: 1.1618, Test: 1.1550
Epoch: 260, Loss: 2.7997, Train: 1.0929, Val: 1.1581, Test: 1.1514
Epoch: 261, Loss: 2.7965, Train: 1.1046, Val: 1.1687, Test: 1.1618
Epoch: 262, Loss: 2.7939, Train: 1.0832, Val: 1.1504, Test: 1.1442
Epoch: 263, Loss: 2.7930, Train: 1.1191, Val: 1.1820, Test: 1.1754
Epoch: 264, Loss: 2.7958, Train: 1.0671, Val: 1.1375, Test: 1.1320
Epoch: 265, Loss: 2.8024, Train: 1.1427, Val: 1.2035, Test: 1.1971
Epoch: 266, Loss: 2.8145, Train: 1.0597, Val: 1.1322, Test: 1.1270
Epoch: 267, Loss: 2.8085, Train: 1.1270, Val: 1.1903, Test: 1.1840
Epoch: 268, Loss: 2.7918, Train: 1.0841, Val: 1.1533, Test: 1.1477
Epoch: 269, Loss: 2.7708, Train: 1.0790, Val: 1.1493, Test: 1.1439
Epoch: 270, Loss: 2.7699, Train: 1.1260, Val: 1.1903, Test: 1.1845
Epoch: 271, Loss: 2.7811, Train: 1.0634, Val: 1.1368, Test: 1.1322
Epoch: 272, Loss: 2.7802, Train: 1.1179, Val: 1.1838, Test: 1.1784
Epoch: 273, Loss: 2.7679, Train: 1.0834, Val: 1.1544, Test: 1.1495
Epoch: 274, Loss: 2.7532, Train: 1.0793, Val: 1.1512, Test: 1.1465
Epoch: 275, Loss: 2.7512, Train: 1.1157, Val: 1.1828, Test: 1.1778
Epoch: 276, Loss: 2.7572, Train: 1.0656, Val: 1.1402, Test: 1.1363
Epoch: 277, Loss: 2.7568, Train: 1.1139, Val: 1.1819, Test: 1.1773
Epoch: 278, Loss: 2.7494, Train: 1.0779, Val: 1.1513, Test: 1.1475
Epoch: 279, Loss: 2.7376, Train: 1.0854, Val: 1.1582, Test: 1.1543
Epoch: 280, Loss: 2.7317, Train: 1.1024, Val: 1.1732, Test: 1.1691
Epoch: 281, Loss: 2.7322, Train: 1.0685, Val: 1.1448, Test: 1.1415
Epoch: 282, Loss: 2.7340, Train: 1.1148, Val: 1.1848, Test: 1.1807
Epoch: 283, Loss: 2.7350, Train: 1.0653, Val: 1.1431, Test: 1.1400
Epoch: 284, Loss: 2.7298, Train: 1.1078, Val: 1.1796, Test: 1.1758
Epoch: 285, Loss: 2.7233, Train: 1.0731, Val: 1.1504, Test: 1.1474
Epoch: 286, Loss: 2.7150, Train: 1.0907, Val: 1.1658, Test: 1.1624
Epoch: 287, Loss: 2.7087, Train: 1.0865, Val: 1.1626, Test: 1.1595
Epoch: 288, Loss: 2.7046, Train: 1.0777, Val: 1.1556, Test: 1.1529
Epoch: 289, Loss: 2.7025, Train: 1.0996, Val: 1.1747, Test: 1.1717
Epoch: 290, Loss: 2.7024, Train: 1.0639, Val: 1.1450, Test: 1.1428
Epoch: 291, Loss: 2.7051, Train: 1.1197, Val: 1.1929, Test: 1.1898
Epoch: 292, Loss: 2.7145, Train: 1.0482, Val: 1.1331, Test: 1.1316
Epoch: 293, Loss: 2.7254, Train: 1.1466, Val: 1.2169, Test: 1.2140
Epoch: 294, Loss: 2.7471, Train: 1.0410, Val: 1.1281, Test: 1.1270
Epoch: 295, Loss: 2.7388, Train: 1.1320, Val: 1.2047, Test: 1.2022
Epoch: 296, Loss: 2.7221, Train: 1.0616, Val: 1.1451, Test: 1.1438
Epoch: 297, Loss: 2.6866, Train: 1.0735, Val: 1.1552, Test: 1.1540
Epoch: 298, Loss: 2.6756, Train: 1.1128, Val: 1.1888, Test: 1.1873
Epoch: 299, Loss: 2.6897, Train: 1.0465, Val: 1.1337, Test: 1.1334
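Looking at the log above, the validation RMSE stops improving and starts to oscillate after roughly epoch 230, so training longer mainly overfits. Below is a minimal sketch of how one could keep the best checkpoint during such a loop; it assumes the `train()` and `test()` helpers used in the training loop above return the loss and the three RMSE values, which is an assumption, not part of the original notebook.

import copy

best_val_rmse = float('inf')
best_state = None

for epoch in range(1, 300):
    loss = train()                               # one optimisation pass over the training edges (assumed helper)
    train_rmse, val_rmse, test_rmse = test()     # RMSE on the train / val / test splits (assumed helper)
    if val_rmse < best_val_rmse:
        # remember the weights that performed best on the validation split
        best_val_rmse = val_rmse
        best_state = copy.deepcopy(model.state_dict())

# restore the best checkpoint before generating recommendations
if best_state is not None:
    model.load_state_dict(best_state)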

Generate predictions and store them back to ArangoDB¶

Here, we will predict new links between users and movies with the trained model. We will only recommend movies to a user for which the predicted rating is equal to 5.

In [ ]:
total_users = len(users)
total_movies = len(movies)
movie_recs = []
for user_id in tqdm(range(total_users)):
    # score every movie for this user
    user_row = torch.tensor([user_id] * total_movies)
    all_movie_ids = torch.arange(total_movies)
    edge_label_index = torch.stack([user_row, all_movie_ids], dim=0)
    pred = model(data.x_dict, data.edge_index_dict, edge_label_index)
    pred = pred.clamp(min=0, max=5)
    # only keep movies for which the predicted rating is exactly 5
    rec_movie_ids = (pred == 5).nonzero(as_tuple=True)
    top_ten_recs = rec_movie_ids[0].tolist()[:10]
    movie_recs.append({'user': user_id, 'rec_movies': top_ten_recs})
100%|██████████| 671/671 [00:19<00:00, 34.08it/s]
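Filtering on an exact predicted rating of 5 can leave some users with very few (or even no) recommendations. A common alternative, sketched below under the same assumptions as the cell above, is to rank all movies by their predicted rating and keep the top k per user; `torch.topk` does the ranking and the variable names mirror the loop above. This guarantees exactly k recommendations per user.

K = 10
topk_recs = []
with torch.no_grad():
    for user_id in tqdm(range(total_users)):
        user_row = torch.tensor([user_id] * total_movies)
        all_movie_ids = torch.arange(total_movies)
        edge_label_index = torch.stack([user_row, all_movie_ids], dim=0)
        pred = model(data.x_dict, data.edge_index_dict, edge_label_index)
        # rank every movie by its predicted rating and keep the K highest
        scores, movie_idx = torch.topk(pred.view(-1), K)
        topk_recs.append({'user': user_id, 'rec_movies': movie_idx.tolist()})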

Storing predictions back to ArangoDB

We will create a new collection named "Recommendation_Inferences" in ArangoDB to store the movie recommendations for each user.

In [ ]:
# Create a new edge collection named "Recommendation_Inferences" if it does not already exist.
if not movie_rec_db.has_collection("Recommendation_Inferences"):
    movie_rec_db.create_collection("Recommendation_Inferences", edge=True, replication_factor=3)
    
In [ ]:
def populate_movies_recommendations(movie_recs):
    BATCH_SIZE = 100
    batch = []
    rec_collection = movie_rec_db["Recommendation_Inferences"]

    for idx in tqdm(range(total_users)):
        user_id = movie_recs[idx]['user']
        movie_ids = movie_recs[idx]['rec_movies']

        # build one edge document per recommended movie
        for m_id in movie_ids:
            insert_doc = {
                "_from": "Users" + "/" + str(user_id),
                "_to": "Movie" + "/" + str(m_id),
                "_rating": 5}
            batch.append(insert_doc)

        last_record = (idx == (total_users - 1))
        if len(batch) > BATCH_SIZE:
            rec_collection.import_bulk(batch)
            batch = []
        if last_record and len(batch) > 0:
            print("Inserting the last batch!")
            rec_collection.import_bulk(batch)
In [ ]:
populate_movies_recommendations(movie_recs)
100%|██████████| 671/671 [00:09<00:00, 74.09it/s]
Inserting the last batch!
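After the insert completes, the stored edges can be spot-checked directly from Python. The AQL query below is a minimal sketch (the user key "Users/0" is just an illustrative example); it returns the recommended movie ids for a single user from the new collection.

# fetch all recommendation edges that start at user 0
cursor = movie_rec_db.aql.execute(
    """
    FOR rec IN Recommendation_Inferences
        FILTER rec._from == @user
        RETURN rec._to
    """,
    bind_vars={"user": "Users/0"},
)
print(list(cursor))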

References¶

  1. Heterogeneous Graph Learning (PyTorch Geometric documentation).

  2. The heterogeneous graph neural network code is adapted from the PyG example hetero_link_pred.py.