NimaKL committed on
Commit
e7e5b40
·
verified ·
1 Parent(s): 2b309fd

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -0
app.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import pandas as pd
4
+ import numpy as np
5
+ from torch_geometric.data import Data
6
+ from torch_geometric.nn import GATConv
7
+ from sentence_transformers import SentenceTransformer
8
+ from sklearn.metrics.pairwise import cosine_similarity
9
+
10
+ # Define the GATConv model architecture
11
class ModeratelySimplifiedGATConvModel(torch.nn.Module):
    """Two-layer graph attention network built from ``GATConv`` layers.

    The first layer uses two attention heads and is followed by ReLU and
    dropout (p=0.45); the second layer uses a single head and produces the
    final node embeddings.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Per-head output size of the first layer.
        out_channels: Size of each output node embedding.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # Attribute names are part of the checkpoint's state-dict keys,
        # so they must not be renamed.
        self.conv1 = GATConv(in_channels, hidden_channels, heads=2)
        self.dropout1 = torch.nn.Dropout(0.45)
        # Input width is hidden_channels * 2 to match the two heads of conv1.
        self.conv2 = GATConv(hidden_channels * 2, out_channels, heads=1)

    def forward(self, x, edge_index, edge_attr=None):
        """Return node embeddings for the graph described by the inputs.

        Args:
            x: Node feature matrix.
            edge_index: Graph connectivity passed to both layers.
            edge_attr: Optional edge features, forwarded unchanged to
                both ``GATConv`` layers.
        """
        hidden = self.conv1(x, edge_index, edge_attr)
        hidden = self.dropout1(torch.relu(hidden))
        return self.conv2(hidden, edge_index, edge_attr)
24
+
25
# Load the graph dataset onto the CPU.
# NOTE(review): torch.load unpickles arbitrary objects, so these files must
# come from a trusted source.
data = torch.load("graph_data.pt", map_location=torch.device("cpu"))

# The checkpoint stores a single "lin.weight" per layer, whereas the model
# built below is loaded with separate "lin_src.weight" / "lin_dst.weight"
# entries. Duplicate the shared weight under both names; every other entry
# is copied through untouched.
original_state_dict = torch.load("graph_model.pth", map_location=torch.device("cpu"))
corrected_state_dict = {}
for param_name, param_value in original_state_dict.items():
    if "lin.weight" not in param_name:
        corrected_state_dict[param_name] = param_value
        continue
    for replacement in ("lin_src.weight", "lin_dst.weight"):
        corrected_state_dict[param_name.replace("lin.weight", replacement)] = param_value

# Build the GATConv model (input width taken from the node feature matrix)
# and restore the remapped weights.
gatconv_model = ModeratelySimplifiedGATConvModel(
    in_channels=data.x.shape[1],
    hidden_channels=32,
    out_channels=768,
)
gatconv_model.load_state_dict(corrected_state_dict)
43
+
44
# Load the sentence-transformer model used to embed the user's query text.
model_bert = SentenceTransformer("all-mpnet-base-v2")

# Load the video metadata (gzip-compressed JSON Lines). Everything below
# depends on `df`, so halt this script run if the file cannot be read
# instead of continuing into a NameError.
try:
    df = pd.read_json("combined_data.json.gz", orient='records', lines=True, compression='gzip')
except Exception as e:
    st.error(f"Error reading JSON file: {e}")
    st.stop()

# Generate GNN-based embeddings once, up front.
# Switch to eval mode first: torch.no_grad() only disables gradient tracking,
# it does not turn off the model's Dropout layer, which would otherwise
# randomly zero activations and make the embeddings differ between runs.
gatconv_model.eval()
with torch.no_grad():
    all_video_embeddings = gatconv_model(data.x, data.edge_index, data.edge_attr).cpu()
56
+
57
+ # Function to find the most similar video and recommend the top 10 based on GNN embeddings
58
def get_similar_and_recommend(input_text):
    """Find the video closest to ``input_text`` and recommend related videos.

    The query is embedded with ``model_bert`` and compared (cosine
    similarity) against the precomputed vectors in ``df["embeddings"]`` to
    pick the best-matching video. Up to 10 other videos are then recommended
    by ranking the dot product between that video's GNN embedding and every
    other row of ``all_video_embeddings``.

    NOTE(review): this assumes row ``i`` of ``df`` and row ``i`` of
    ``all_video_embeddings`` describe the same video - confirm against the
    code that built ``graph_data.pt``.

    Args:
        input_text: Free-text query entered by the user.

    Returns:
        A dict with three keys:
        ``"search_context"`` (the input text and one weight per
        recommendation), ``"most_similar_video"`` (the matched row as a
        dict) and ``"final_recommendations"`` (list of recommended rows as
        dicts). The ``"text_for_embedding"`` and ``"embeddings"`` fields are
        removed from every returned row.
    """

    def strip_embedding_fields(record):
        # Drop bulky/internal fields that should not be shown to the user.
        record.pop("text_for_embedding", None)
        record.pop("embeddings", None)
        return record

    def recommend_top_10(given_video_index, all_video_embeddings):
        # One matrix-vector product scores every video against the given one.
        scores = all_video_embeddings @ all_video_embeddings[given_video_index]
        # Exclude the most similar video itself from its own recommendations.
        scores[given_video_index] = -float("inf")
        # Never request more results than there are *other* videos, so the
        # excluded video cannot slip back in for small catalogues.
        k = min(10, max(scores.shape[0] - 1, 0))
        top_indices = torch.topk(scores, k).indices.tolist()
        return [df.iloc[idx].to_dict() for idx in top_indices]

    # Find the most similar video based on cosine similarity of text embeddings.
    embeddings_matrix = np.array(df["embeddings"].tolist())
    input_embedding = model_bert.encode([input_text])[0]
    similarities = cosine_similarity([input_embedding], embeddings_matrix)[0]
    most_similar_index = int(np.argmax(similarities))

    most_similar_video_features = strip_embedding_fields(
        df.iloc[most_similar_index].to_dict()
    )

    # Recommend the top 10 videos based on GNN embeddings.
    top_10_recommended_videos_features = [
        strip_embedding_fields(video)
        for video in recommend_top_10(most_similar_index, all_video_embeddings)
    ]

    # Apply search context: every query word that exactly equals a
    # (lower-cased) video title bumps the shared weight by 0.1.
    # NOTE(review): this compares each keyword with whole titles, not words
    # inside titles - confirm that is the intended matching rule.
    lowercase_titles = set(df["title"].str.lower())
    weight = 1.0  # Initial weight factor
    for keyword in input_text.split():
        if keyword.lower() in lowercase_titles:
            weight += 0.1

    # The same weight is applied to every GNN-based recommendation.
    video_weights = [weight] * len(top_10_recommended_videos_features)

    # Output with the most similar video, final recommendations, and weights.
    return {
        "search_context": {
            "input_text": input_text,  # What the user provided
            "weights": video_weights,  # Weights for each GNN-based recommendation
        },
        "most_similar_video": most_similar_video_features,
        "final_recommendations": top_10_recommended_videos_features,
    }
118
+
119
# Streamlit entry point: a text box for the query. The most similar video and
# the top 10 recommended videos are rendered as JSON below it.
user_input = st.text_input("Enter text to find the most similar video")

# An empty text box yields a falsy string, so nothing is computed until the
# user has entered some text.
if user_input:
    st.json(get_similar_and_recommend(user_input))