Spaces:
Sleeping
Sleeping
Upload 11 files
Browse files- Dockerfile +25 -0
- chat_management.py +94 -0
- document_loaders.py +193 -0
- embedding.py +11 -0
- llm_initialization.py +17 -0
- main.py +355 -0
- prompt_templates.py +242 -0
- requirements.txt +23 -0
- retrieval_chain.py +110 -0
- text_splitter.py +30 -0
- vector_store.py +234 -0
Dockerfile
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Base image using Python 3.9
FROM python:3.9

# Create a new non-root user to run the app (HF Spaces convention: uid 1000)
RUN useradd -m -u 1000 user
USER user

# Set environment variables so user-level pip installs are on PATH
ENV PATH="/home/user/.local/bin:$PATH"

# Set the working directory
WORKDIR /app

# Copy the requirements first and install dependencies
# (separate layer so dependency installs are cached across code changes)
COPY --chown=user ./requirements.txt requirements.txt
RUN pip install --no-cache-dir --upgrade -r requirements.txt

# Copy the rest of the application
COPY --chown=user . /app

# Expose port 7860 for the application
EXPOSE 7860

# Command to run the FastAPI app using uvicorn
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
chat_management.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import uuid
|
2 |
+
from pymongo import MongoClient
|
3 |
+
from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory
|
4 |
+
|
5 |
+
|
class ChatManagement:
    """Manages chat sessions backed by MongoDB-stored message histories.

    Keeps an in-memory cache (``chat_sessions``) of
    ``MongoDBChatMessageHistory`` objects keyed by chat ID so repeated
    lookups avoid rebuilding history wrappers.
    """

    def __init__(self, cluster_url, database_name, collection_name):
        """
        Initializes the ChatManagement class with MongoDB connection details.

        Args:
            cluster_url (str): MongoDB cluster URL / connection string.
            database_name (str): Name of the database.
            collection_name (str): Name of the collection.
        """
        self.connection_string = cluster_url
        self.database_name = database_name
        self.collection_name = collection_name
        # In-memory cache: chat_id -> MongoDBChatMessageHistory
        self.chat_sessions = {}

    def _make_history(self, chat_id):
        """Build a MongoDBChatMessageHistory bound to *chat_id* (no caching).

        Single construction point shared by the public methods below, so the
        connection parameters are spelled out exactly once.
        """
        return MongoDBChatMessageHistory(
            session_id=chat_id,
            connection_string=self.connection_string,
            database_name=self.database_name,
            collection_name=self.collection_name,
        )

    def create_new_chat(self):
        """
        Creates a new chat session by initializing a MongoDBChatMessageHistory object.

        Returns:
            str: The unique chat ID.
        """
        chat_id = str(uuid.uuid4())
        self.chat_sessions[chat_id] = self._make_history(chat_id)
        return chat_id

    def get_chat_history(self, chat_id):
        """
        Retrieves the MongoDBChatMessageHistory object for a given chat session.

        Checks the in-memory cache first; otherwise probes MongoDB and caches
        the history only when the session already has stored messages.

        Args:
            chat_id (str): The unique ID of the chat session.

        Returns:
            MongoDBChatMessageHistory or None: The chat history object, or
            None if the session is unknown both in memory and in the database.
        """
        if chat_id in self.chat_sessions:
            return self.chat_sessions[chat_id]

        # Not in memory: check whether the session exists in the database.
        history = self._make_history(chat_id)
        if history.messages:
            self.chat_sessions[chat_id] = history
            return history

        return None  # Chat session not found

    def initialize_chat_history(self, chat_id):
        """
        Initializes a new chat history for *chat_id* if it does not already exist.

        Args:
            chat_id (str): The unique ID of the chat session.

        Returns:
            MongoDBChatMessageHistory: The (possibly newly created) history.
        """
        if chat_id not in self.chat_sessions:
            self.chat_sessions[chat_id] = self._make_history(chat_id)
        return self.chat_sessions[chat_id]
document_loaders.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_community.document_loaders import (CSVLoader, WikipediaLoader, UnstructuredURLLoader,
|
2 |
+
YoutubeLoader, PyPDFLoader, BSHTMLLoader,
|
3 |
+
Docx2txtLoader, UnstructuredMarkdownLoader)
|
4 |
+
|
5 |
+
from langchain_unstructured import UnstructuredLoader
|
6 |
+
|
7 |
+
|
class DocumentLoader:
    """Thin wrappers around LangChain document loaders.

    Every method prints a diagnostic on failure and then RE-RAISES the
    exception, so callers (e.g. the API layer, which wraps these calls in
    try/except) see the error instead of silently receiving ``None`` —
    the previous swallow-and-return-None behavior made downstream code
    crash later in the text splitter.
    """

    def _attempt(self, action, error_prefix):
        """Run *action* (a zero-arg callable); print and re-raise on failure."""
        try:
            return action()
        except Exception as e:
            print(f"{error_prefix}: {e}")
            raise

    def load_unstructured(self, path):
        """
        Load data from a file at the specified path via UnstructuredLoader.

        Supported files:
        "csv", "doc", "docx", "epub", "image", "md", "msg", "odt", "org",
        "pdf", "ppt", "pptx", "rtf", "rst", "tsv", "xlsx"

        Args:
            path (str): The file path.

        Returns:
            The loaded documents.
        """
        return self._attempt(lambda: UnstructuredLoader(path).load(),
                             "Error loading Unstructured")

    def load_csv(self, path):
        """
        Load data from a CSV file at the specified path.

        Args:
            path (str): The file path to the CSV file.

        Returns:
            The loaded CSV documents.
        """
        return self._attempt(lambda: CSVLoader(file_path=path).load(),
                             "Error loading CSV")

    def wikipedia_query(self, search_query):
        """
        Query Wikipedia using a given search term (at most 2 pages).

        Args:
            search_query (str): The search term to query on Wikipedia.

        Returns:
            The query results.
        """
        return self._attempt(
            lambda: WikipediaLoader(query=search_query, load_max_docs=2).load(),
            "Error querying Wikipedia")

    def load_urls(self, urls):
        """
        Load and parse content from a list of URLs.

        Args:
            urls (list): A list of URLs to load.

        Returns:
            The loaded documents from the URLs.
        """
        return self._attempt(lambda: UnstructuredURLLoader(urls=urls).load(),
                             "Error loading URLs")

    def load_YouTubeVideo(self, urls):
        """
        Load YouTube video information from a provided URL.

        Args:
            urls: A single YouTube video URL — despite the plural name, the
                underlying ``from_youtube_url`` call takes one URL.

        Returns:
            The loaded documents for the video.
        """
        def _load():
            loader = YoutubeLoader.from_youtube_url(
                urls, add_video_info=True,
                language=["en", "pt", "zh-Hans", "es", "ur", "hi"],
                translation="en")
            return loader.load()

        return self._attempt(_load, "Error loading YouTube video")

    def load_pdf(self, path):
        """
        Load data from a PDF file at the specified path.

        Args:
            path (str): The file path to the PDF file.

        Returns:
            The loaded and split PDF pages.
        """
        return self._attempt(lambda: PyPDFLoader(path).load_and_split(),
                             "Error loading PDF")

    def load_text_from_html(self, path):
        """
        Load and parse text content from an HTML file at the specified path.

        Args:
            path (str): The file path to the HTML file.

        Returns:
            The loaded HTML documents.
        """
        return self._attempt(lambda: BSHTMLLoader(path).load(),
                             "Error loading text from HTML")

    def load_markdown(self, path):
        """
        Load data from a Markdown file at the specified path.

        Args:
            path (str): The file path to the Markdown file.

        Returns:
            The loaded Markdown documents.
        """
        return self._attempt(lambda: UnstructuredMarkdownLoader(path).load(),
                             "Error loading Markdown")

    def load_doc(self, path):
        """
        Load data from a DOC/DOCX file at the specified path.

        Args:
            path (str): The file path to the DOCX file.

        Returns:
            The loaded DOCX documents.
        """
        return self._attempt(lambda: Docx2txtLoader(path).load(),
                             "Error loading DOCX")
embedding.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
2 |
+
|
def get_embeddings():
    """Return a HuggingFace embedding model for document/query encoding.

    Configuration: BAAI/bge-small-en running on CPU, with embedding
    normalization enabled.
    """
    return HuggingFaceEmbeddings(
        model_name="BAAI/bge-small-en",
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": True},
    )
llm_initialization.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_groq import ChatGroq
|
2 |
+
|
def get_llm():
    """
    Returns the language model instance (LLM) using the ChatGroq API.

    The LLM used is Llama 3.3 70B ("llama-3.3-70b-versatile"), with
    deterministic output (temperature 0) and a 1024-token response cap.

    Returns:
        llm (ChatGroq): An instance of the ChatGroq LLM.

    Raises:
        KeyError: If the GROQ_API_KEY environment variable is not set.
    """
    import os

    llm = ChatGroq(
        model="llama-3.3-70b-versatile",
        temperature=0,
        max_tokens=1024,
        # SECURITY: the key was previously hard-coded in source (a leaked
        # credential). It must come from the environment; rotate the old key.
        api_key=os.environ["GROQ_API_KEY"],
    )
    return llm
main.py
ADDED
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import tempfile
|
3 |
+
import zipfile
|
4 |
+
from typing import List, Optional
|
5 |
+
|
6 |
+
from fastapi import FastAPI, File, UploadFile, HTTPException, Query
|
7 |
+
from fastapi.responses import FileResponse, StreamingResponse
|
8 |
+
|
9 |
+
from llm_initialization import get_llm
|
10 |
+
from embedding import get_embeddings
|
11 |
+
from document_loaders import DocumentLoader
|
12 |
+
from text_splitter import TextSplitter
|
13 |
+
from vector_store import VectorStoreManager
|
14 |
+
from prompt_templates import PromptTemplates
|
15 |
+
from chat_management import ChatManagement
|
16 |
+
from retrieval_chain import RetrievalChain
|
17 |
+
from urllib.parse import quote_plus
|
18 |
+
from dotenv import load_dotenv
|
19 |
+
from pymongo import MongoClient
|
20 |
+
|
21 |
+
# Load environment variables
|
22 |
+
load_dotenv()
|
23 |
+
MONGO_PASSWORD = quote_plus(os.getenv("MONGO_PASSWORD"))
|
24 |
+
MONGO_DATABASE_NAME = os.getenv("DATABASE_NAME")
|
25 |
+
MONGO_COLLECTION_NAME = os.getenv("COLLECTION_NAME")
|
26 |
+
MONGO_CLUSTER_URL = os.getenv("CONNECTION_STRING")
|
27 |
+
|
28 |
+
app = FastAPI(title="VectorStore & Document Management API")
|
29 |
+
|
30 |
+
# Global variables (initialized on startup)
|
31 |
+
llm = None
|
32 |
+
embeddings = None
|
33 |
+
chat_manager = None
|
34 |
+
document_loader = None
|
35 |
+
text_splitter = None
|
36 |
+
vector_store_manager = None
|
37 |
+
vector_store = None
|
38 |
+
k = 3 # Number of documents to retrieve per query
|
39 |
+
|
40 |
+
# Global MongoDB collection to store retrieval chain configuration per chat session.
|
41 |
+
chat_chains_collection = None
|
42 |
+
|
43 |
+
# ----------------------- Startup Event -----------------------
|
@app.on_event("startup")  # NOTE(review): on_event is deprecated in newer FastAPI; a lifespan handler is preferred — confirm target FastAPI version
async def startup_event():
    """Initialize all shared components once, when the API process starts.

    Populates the module-level globals (LLM, embeddings, chat manager,
    document loader, text splitter, vector store, and the Mongo collection
    holding per-chat chain configuration) that the endpoints below read.
    Order matters: embeddings must exist before the vector store manager.
    """
    global llm, embeddings, chat_manager, document_loader, text_splitter, vector_store_manager, vector_store, chat_chains_collection

    print("Starting up: Initializing components...")

    # Initialize LLM and embeddings
    llm = get_llm()
    print("LLM initialized.")
    embeddings = get_embeddings()
    print("Embeddings initialized.")

    # Setup chat management (MongoDB-backed chat histories)
    chat_manager = ChatManagement(
        cluster_url=MONGO_CLUSTER_URL,
        database_name=MONGO_DATABASE_NAME,
        collection_name=MONGO_COLLECTION_NAME,
    )
    print("Chat management initialized.")

    # Initialize document loader and text splitter
    document_loader = DocumentLoader()
    text_splitter = TextSplitter()
    print("Document loader and text splitter initialized.")

    # Initialize vector store manager and ensure vectorstore is set
    vector_store_manager = VectorStoreManager(embeddings)
    vector_store = vector_store_manager.vectorstore  # Now properly initialized
    print("Vector store initialized.")

    # Connect to MongoDB and get the collection that stores the per-chat
    # retrieval-chain configuration written by /create_chain.
    client = MongoClient(MONGO_CLUSTER_URL)
    db = client[MONGO_DATABASE_NAME]
    chat_chains_collection = db["chat_chains"]
    print("Chat chains collection initialized in MongoDB.")
79 |
+
|
80 |
+
|
81 |
+
# ----------------------- Root Endpoint -----------------------
|
@app.get("/")
def root():
    """Health/landing endpoint: confirms the API is up and reachable."""
    welcome = "Welcome to the VectorStore & Document Management API!"
    return {"message": welcome}
88 |
+
|
89 |
+
|
90 |
+
# ----------------------- New Chat Endpoint -----------------------
|
@app.post("/new_chat")
def new_chat():
    """Create a fresh chat session and return its unique identifier."""
    return {"chat_id": chat_manager.create_new_chat()}
98 |
+
|
99 |
+
|
100 |
+
# ----------------------- Create Chain Endpoint -----------------------
|
@app.post("/create_chain")
def create_chain(
    chat_id: str = Query(..., description="Existing chat session ID"),
    template: str = Query(
        "quiz_solving",
        description="Select prompt template. Options: quiz_solving, assignment_solving, paper_solving, quiz_creation, assignment_creation, paper_creation",
    ),
):
    """Store (upsert) the prompt-template choice for a chat session.

    The selection is persisted in the ``chat_chains`` MongoDB collection;
    /chat later reads it to rebuild the retrieval chain for this session.
    """
    global chat_chains_collection  # Ensure we reference the global variable

    allowed_templates = {
        "quiz_solving",
        "assignment_solving",
        "paper_solving",
        "quiz_creation",
        "assignment_creation",
        "paper_creation",
    }
    if template not in allowed_templates:
        raise HTTPException(status_code=400, detail="Invalid template selection.")

    # Upsert the configuration document for this chat session.
    chat_chains_collection.update_one(
        {"chat_id": chat_id}, {"$set": {"template": template}}, upsert=True
    )

    return {"message": "Retrieval chain configuration stored successfully.", "chat_id": chat_id, "template": template}
128 |
+
|
129 |
+
|
130 |
+
# ----------------------- Chat Endpoint -----------------------
|
@app.get("/chat")
def chat(query: str, chat_id: str = Query(..., description="Chat session ID created via /new_chat and configured via /create_chain")):
    """
    Process a chat query using the retrieval chain associated with *chat_id*.

    Looks up the chat's stored template configuration in MongoDB, rebuilds
    the retrieval chain for this session, and streams the LLM response as
    server-sent events.

    Raises:
        HTTPException: 400 if the chat has no stored configuration or the
            stored template is unknown; 500 if streaming fails.
    """
    # Retrieve the chat configuration from MongoDB.
    config = chat_chains_collection.find_one({"chat_id": chat_id})
    if not config:
        raise HTTPException(status_code=400, detail="Chat configuration not found. Please create a chain using /create_chain.")

    # Dispatch table: stored template name -> prompt factory.
    prompt_factories = {
        "quiz_solving": PromptTemplates.get_quiz_solving_prompt,
        "assignment_solving": PromptTemplates.get_assignment_solving_prompt,
        "paper_solving": PromptTemplates.get_paper_solving_prompt,
        "quiz_creation": PromptTemplates.get_quiz_creation_prompt,
        "assignment_creation": PromptTemplates.get_assignment_creation_prompt,
        "paper_creation": PromptTemplates.get_paper_creation_prompt,
    }
    factory = prompt_factories.get(config.get("template", "quiz_solving"))
    if factory is None:
        raise HTTPException(status_code=400, detail="Invalid chat configuration.")
    prompt = factory()

    # Re-create the retrieval chain for this chat session.
    retrieval_chain = RetrievalChain(
        llm,
        vector_store.as_retriever(search_kwargs={"k": k}),
        prompt,
        verbose=True,
    )

    try:
        stream_generator = retrieval_chain.stream_chat_response(
            query=query,
            chat_id=chat_id,
            get_chat_history=chat_manager.get_chat_history,
            initialize_chat_history=chat_manager.initialize_chat_history,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing chat query: {str(e)}")

    return StreamingResponse(stream_generator, media_type="text/event-stream")
192 |
+
|
193 |
+
|
194 |
+
# ----------------------- Add Document Endpoint -----------------------
|
195 |
+
from typing import Any, Optional
|
196 |
+
|
@app.post("/add_document")
async def add_document(
    file: Optional[Any] = File(None),
    wiki_query: Optional[str] = Query(None),
    wiki_url: Optional[str] = Query(None)
):
    """
    Upload a document OR load data from a Wikipedia query or URL.

    Exactly one source is used, in priority order: uploaded file, then
    ``wiki_query`` (Wikipedia search), then ``wiki_url`` (generic URL).
    The loaded document(s) are split into chunks and added to the vector
    store.

    Raises:
        HTTPException: 400 for missing or unloadable input; 500 for
            splitting or indexing failures.
    """
    # If file is provided but not as an UploadFile (e.g. an empty string),
    # treat it as absent.
    if not isinstance(file, UploadFile):
        file = None

    # Ensure at least one input is provided.
    if file is None and wiki_query is None and wiki_url is None:
        raise HTTPException(status_code=400, detail="No document input provided (file, wiki_query, or wiki_url).")

    if file is not None:
        # Keep the original extension on the temp file: several loaders
        # sniff the format from the filename.
        ext = file.filename.split(".")[-1].lower()
        with tempfile.NamedTemporaryFile(delete=False, suffix=f".{ext}") as tmp:
            tmp.write(await file.read())
            tmp_filename = tmp.name

        try:
            if ext == "pdf":
                documents = document_loader.load_pdf(tmp_filename)
            elif ext == "csv":
                documents = document_loader.load_csv(tmp_filename)
            elif ext in ["doc", "docx"]:
                documents = document_loader.load_doc(tmp_filename)
            elif ext in ["html", "htm"]:
                documents = document_loader.load_text_from_html(tmp_filename)
            elif ext in ["md", "markdown"]:
                documents = document_loader.load_markdown(tmp_filename)
            else:
                documents = document_loader.load_unstructured(tmp_filename)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Error loading document from file: {str(e)}")
        finally:
            # Always remove the temp file — previously it leaked when a
            # loader raised.
            os.remove(tmp_filename)
    elif wiki_query is not None:
        try:
            documents = document_loader.wikipedia_query(wiki_query)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Error loading Wikipedia query: {str(e)}")
    else:
        try:
            documents = document_loader.load_urls([wiki_url])
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Error loading URL: {str(e)}")

    # Loaders may swallow errors and return None; fail fast with a clear
    # message instead of crashing inside the splitter.
    if not documents:
        raise HTTPException(status_code=400, detail="No content could be loaded from the provided input.")

    try:
        chunks = text_splitter.split_documents(documents)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error splitting document: {str(e)}")

    try:
        ids = vector_store_manager.add_documents(chunks)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error indexing document chunks: {str(e)}")

    return {"message": f"Added {len(chunks)} document chunks.", "ids": ids}
267 |
+
|
268 |
+
|
269 |
+
# ----------------------- Delete Document Endpoint -----------------------
|
@app.post("/delete_document")
def delete_document(ids: List[str]):
    """Remove documents from the vector store by their IDs."""
    try:
        deleted = vector_store_manager.delete_documents(ids)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Error deleting documents: {str(exc)}")
    if not deleted:
        raise HTTPException(status_code=400, detail="Failed to delete documents.")
    return {"message": f"Deleted documents with IDs: {ids}"}
282 |
+
|
283 |
+
|
284 |
+
# ----------------------- Save Vectorstore Endpoint -----------------------
|
@app.get("/save_vectorstore")
def save_vectorstore():
    """Persist the current vector store locally and serve it as a download.

    If the saved store is a directory, the manager zips it first.
    """
    try:
        saved = vector_store_manager.save("faiss_index")
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Error saving vectorstore: {str(exc)}")
    return FileResponse(
        path=saved["file_path"],
        media_type=saved["media_type"],
        filename=saved["serve_filename"],
    )
301 |
+
|
302 |
+
|
303 |
+
# ----------------------- Load Vectorstore Endpoint -----------------------
|
@app.post("/load_vectorstore")
async def load_vectorstore(file: UploadFile = File(...)):
    """
    Load a vector store from an uploaded file (raw or zipped).

    This replaces the current vector store AND refreshes the module-level
    ``vector_store`` global that /chat builds its retriever from —
    previously only the manager was swapped, so /chat kept querying the
    old, stale store.
    """
    global vector_store_manager, vector_store

    tmp_filename = None
    try:
        # Save the uploaded file content to a temporary file.
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            file_bytes = await file.read()  # await to get bytes
            tmp.write(file_bytes)
            tmp_filename = tmp.name

        instance, message = VectorStoreManager.load(tmp_filename, embeddings)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error loading vectorstore: {str(e)}")
    finally:
        if tmp_filename and os.path.exists(tmp_filename):
            os.remove(tmp_filename)

    vector_store_manager = instance
    # Keep the retriever source in sync with the newly loaded manager.
    vector_store = vector_store_manager.vectorstore
    return {"message": message}
327 |
+
|
328 |
+
|
329 |
+
# ----------------------- Merge Vectorstore Endpoint -----------------------
|
@app.post("/merge_vectorstore")
async def merge_vectorstore(file: UploadFile = File(...)):
    """Merge an uploaded vector store (raw or zipped) into the current one."""
    tmp_filename = None
    try:
        # Persist the upload to disk first; the merge API expects a path.
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            uploaded_bytes = await file.read()  # Await the coroutine to get bytes
            tmp.write(uploaded_bytes)
            tmp_filename = tmp.name

        # Hand the filename (a string) to the merge method.
        result = vector_store_manager.merge(tmp_filename, embeddings)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Error merging vectorstore: {str(exc)}")
    finally:
        if tmp_filename and os.path.exists(tmp_filename):
            os.remove(tmp_filename)
    return result
351 |
+
|
352 |
+
|
if __name__ == "__main__":
    # Local development entry point on port 8000; the container instead runs
    # uvicorn on port 7860 via the Dockerfile CMD.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
prompt_templates.py
ADDED
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.prompts import ChatPromptTemplate
|
2 |
+
|
3 |
+
class PromptTemplates:
|
4 |
+
"""
|
5 |
+
A class to encapsulate various prompt templates for solving assignments, papers, creating quizzes, and assignments.
|
6 |
+
"""
|
7 |
+
|
    @staticmethod
    def get_quiz_solving_prompt():
        """Build the chat prompt used for answering quiz questions.

        Returns:
            ChatPromptTemplate: a system message holding the quiz-solving
            instructions (with ``{context}`` and ``{question}`` placeholders
            filled by the retrieval chain) plus a human message carrying
            ``{question}``.
        """
        # NOTE(review): the exact leading whitespace inside this literal is part
        # of the prompt sent to the model — confirm against the original file.
        quiz_solving_prompt = '''
You are an assistant specialized in solving quizzes. Your goal is to provide accurate, concise, and contextually relevant answers.
Use the following retrieved context to answer the user's question.
If the context lacks sufficient information, respond with "I don't know." Do not make up answers or provide unverified information.

Guidelines:
1. Extract key information from the context to form a coherent response.
2. Maintain a clear and professional tone.
3. If the question requires clarification, specify it politely.

Retrieved context:
{context}

User's question:
{question}

Your response:
'''


        # Create a prompt template to pass the context and user input to the chain
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", quiz_solving_prompt),
                ("human", "{question}"),
            ]
        )

        return prompt
40 |
+
|
41 |
+
@staticmethod
|
42 |
+
def get_assignment_solving_prompt():
|
43 |
+
# Prompt template for solving assignments
|
44 |
+
assignment_solving_prompt = '''
|
45 |
+
You are an expert assistant specializing in solving academic assignments with clarity and precision.
|
46 |
+
Your task is to provide step-by-step solutions and detailed explanations that align with the given requirements.
|
47 |
+
|
48 |
+
Retrieved context:
|
49 |
+
{context}
|
50 |
+
|
51 |
+
Assignment Details:
|
52 |
+
{question}
|
53 |
+
|
54 |
+
Guidelines:
|
55 |
+
1. **Understand the Problem:** Carefully analyze the assignment details to identify the objective and requirements.
|
56 |
+
2. **Provide a Step-by-Step Solution:** Break down the solution into clear, logical steps. Use examples where appropriate.
|
57 |
+
3. **Explain Your Reasoning:** Include concise explanations for each step to enhance understanding.
|
58 |
+
4. **Follow Formatting Rules:** Ensure the response matches any specified formatting or citation guidelines.
|
59 |
+
5. **Maintain Academic Integrity:** Do not fabricate information, copy content verbatim without attribution, or complete the task in a way that breaches academic honesty policies.
|
60 |
+
|
61 |
+
Deliverable:
|
62 |
+
Provide the final answer in the format outlined in the assignment description. Where relevant, include:
|
63 |
+
- A brief introduction summarizing the approach.
|
64 |
+
- Calculations or code (if applicable).
|
65 |
+
- Any necessary diagrams, tables, or figures (use textual descriptions for diagrams if unavailable).
|
66 |
+
- A conclusion summarizing the findings.
|
67 |
+
|
68 |
+
If the assignment details are incomplete or ambiguous, specify what additional information is required to proceed.
|
69 |
+
|
70 |
+
Assignment Response:
|
71 |
+
'''
|
72 |
+
|
73 |
+
# Create a prompt template to pass the context and user input to the chain
|
74 |
+
prompt = ChatPromptTemplate.from_messages(
|
75 |
+
[
|
76 |
+
("system", assignment_solving_prompt),
|
77 |
+
("human", "{question}"),
|
78 |
+
]
|
79 |
+
)
|
80 |
+
|
81 |
+
return prompt
|
82 |
+
|
83 |
+
|
84 |
+
@staticmethod
|
85 |
+
def get_paper_solving_prompt():
|
86 |
+
# Prompt template for solving papers
|
87 |
+
paper_solving_prompt = '''
|
88 |
+
You are an expert assistant specialized in solving academic papers with precision and clarity.
|
89 |
+
Your task is to provide well-structured answers to the questions in the paper, ensuring accuracy, depth, and adherence to any specified instructions.
|
90 |
+
|
91 |
+
Retrieved context:
|
92 |
+
{context}
|
93 |
+
|
94 |
+
|
95 |
+
Paper Information:
|
96 |
+
{question}
|
97 |
+
|
98 |
+
Instructions:
|
99 |
+
1. **Understand Each Question:** Read each question carefully and identify its requirements, keywords, and scope.
|
100 |
+
2. **Structured Responses:** Provide answers in a clear, logical structure (e.g., Introduction, Body, Conclusion).
|
101 |
+
3. **Depth and Accuracy:** Support answers with explanations, examples, calculations, or references where applicable.
|
102 |
+
4. **Formatting Guidelines:** Adhere to any specified format or style (e.g., bullet points, paragraphs, equations).
|
103 |
+
5. **Time Efficiency:** If the paper is timed, prioritize accuracy and completeness over excessive detail.
|
104 |
+
6. **Clarify Ambiguities:** If any question is unclear, mention the assumptions made while answering.
|
105 |
+
7. **Ethical Guidelines:** Ensure the answers are original and aligned with academic integrity standards.
|
106 |
+
|
107 |
+
Deliverables:
|
108 |
+
- Answer all questions to the best of your ability.
|
109 |
+
- Include relevant diagrams, tables, or code (describe diagrams in text if unavailable).
|
110 |
+
- Summarize key points in a conclusion where applicable.
|
111 |
+
- Clearly number and label answers to match the questions in the paper.
|
112 |
+
|
113 |
+
|
114 |
+
If the paper includes multiple sections, label each section and solve sequentially.
|
115 |
+
|
116 |
+
Paper Solution:
|
117 |
+
'''
|
118 |
+
|
119 |
+
# Create a prompt template to pass the context and user input to the chain
|
120 |
+
prompt = ChatPromptTemplate.from_messages(
|
121 |
+
[
|
122 |
+
("system", paper_solving_prompt),
|
123 |
+
("human", "{question}"),
|
124 |
+
]
|
125 |
+
)
|
126 |
+
|
127 |
+
return prompt
|
128 |
+
|
129 |
+
@staticmethod
|
130 |
+
def get_quiz_creation_prompt():
|
131 |
+
# Prompt template for creating a quiz
|
132 |
+
quiz_creation_prompt = '''
|
133 |
+
You are an expert assistant specializing in creating engaging and educational quizzes for students.
|
134 |
+
Your task is to design a quiz based on the topic, difficulty level, and format specified by the teacher.
|
135 |
+
|
136 |
+
Retrieved context:
|
137 |
+
{context}
|
138 |
+
|
139 |
+
Quiz Details:
|
140 |
+
Topic: {question}
|
141 |
+
|
142 |
+
Guidelines for Quiz Creation:
|
143 |
+
1. **Relevance to Topic:** Ensure all questions are directly related to the specified topic.
|
144 |
+
2. **Clear and Concise Wording:** Write questions clearly and concisely to avoid ambiguity.
|
145 |
+
3. **Diverse Question Types:** Incorporate a variety of question types if specified.
|
146 |
+
4. **Appropriate Difficulty:** Tailor the complexity of the questions to match the target audience and difficulty level.
|
147 |
+
5. **Answer Key:** Provide correct answers or explanations for each question.
|
148 |
+
|
149 |
+
Deliverables:
|
150 |
+
- A complete quiz with numbered questions.
|
151 |
+
- An answer key with correct answers and explanations where relevant.
|
152 |
+
|
153 |
+
Quiz:
|
154 |
+
'''
|
155 |
+
|
156 |
+
# Create a prompt template to pass the context and user input to the chain
|
157 |
+
prompt = ChatPromptTemplate.from_messages(
|
158 |
+
[
|
159 |
+
("system", quiz_creation_prompt),
|
160 |
+
("human", "{question}"),
|
161 |
+
]
|
162 |
+
)
|
163 |
+
|
164 |
+
return prompt
|
165 |
+
|
166 |
+
|
167 |
+
@staticmethod
|
168 |
+
def get_assignment_creation_prompt():
|
169 |
+
# Prompt template for creating an assignment
|
170 |
+
assignment_creation_prompt = '''
|
171 |
+
You are an expert assistant specializing in designing assignments that align with the educational goals and requirements of teachers.
|
172 |
+
Your task is to create a comprehensive assignment based on the provided topic, target audience, and desired outcomes.
|
173 |
+
|
174 |
+
Retrieved context:
|
175 |
+
{context}
|
176 |
+
|
177 |
+
Assignment Details:
|
178 |
+
Topic: {question}
|
179 |
+
|
180 |
+
Guidelines for Assignment Creation:
|
181 |
+
1. **Alignment with Topic:** Ensure all tasks/questions are closely related to the specified topic and designed to achieve the teacher’s learning objectives.
|
182 |
+
2. **Clear Instructions:** Provide detailed and clear instructions for each question or task.
|
183 |
+
3. **Encourage Critical Thinking:** Include questions or tasks that require analysis, creativity, and application of knowledge where appropriate.
|
184 |
+
4. **Variety of Tasks:** Incorporate diverse question types (e.g., short answers, essays, practical tasks) as per the specified format.
|
185 |
+
5. **Grading Rubric (Optional):** Include a grading rubric or evaluation criteria if specified in the instructions.
|
186 |
+
|
187 |
+
Deliverables:
|
188 |
+
- A detailed assignment with numbered tasks/questions.
|
189 |
+
- Any required supporting materials (e.g., diagrams, data tables, references).
|
190 |
+
- (Optional) A grading rubric or expected outcomes for each task.
|
191 |
+
|
192 |
+
Assignment:
|
193 |
+
'''
|
194 |
+
|
195 |
+
# Create a prompt template to pass the context and user input to the chain
|
196 |
+
prompt = ChatPromptTemplate.from_messages(
|
197 |
+
[
|
198 |
+
("system", assignment_creation_prompt),
|
199 |
+
("human", "{question}"),
|
200 |
+
]
|
201 |
+
)
|
202 |
+
|
203 |
+
return prompt
|
204 |
+
|
205 |
+
|
206 |
+
@staticmethod
|
207 |
+
def get_paper_creation_prompt():
|
208 |
+
# Prompt template for creating an academic paper
|
209 |
+
paper_creation_prompt = '''
|
210 |
+
You are an expert assistant specializing in designing comprehensive academic papers tailored to the educational goals and requirements of teachers.
|
211 |
+
Your task is to create a complete paper based on the specified topic, audience, format, and difficulty level.
|
212 |
+
|
213 |
+
Retrieved context:
|
214 |
+
{context}
|
215 |
+
|
216 |
+
Paper Details:
|
217 |
+
Subject/Topic: {question}
|
218 |
+
|
219 |
+
Guidelines for Paper Creation:
|
220 |
+
1. **Relevance and Alignment:** Ensure all questions align with the specified subject/topic and are tailored to the target audience’s curriculum or learning objectives.
|
221 |
+
2. **Clear Wording:** Write questions in clear, concise language to avoid ambiguity or confusion.
|
222 |
+
3. **Diverse Question Types:** Incorporate a variety of question formats as specified (e.g., multiple-choice, fill-in-the-blank, long-form essays).
|
223 |
+
4. **Grading and Marks Allocation:** Provide a suggested mark allocation for each question, ensuring it reflects the question's complexity and time required.
|
224 |
+
5. **Answer Key:** Include correct answers or model responses for objective and descriptive questions (optional).
|
225 |
+
|
226 |
+
Deliverables:
|
227 |
+
- A complete paper with numbered questions, organized by sections if required.
|
228 |
+
- An answer key or marking scheme (if requested).
|
229 |
+
- Any supporting materials (e.g., diagrams, charts, or data tables) if applicable.
|
230 |
+
|
231 |
+
Paper:
|
232 |
+
'''
|
233 |
+
|
234 |
+
# Create a prompt template to pass the context and user input to the chain
|
235 |
+
prompt = ChatPromptTemplate.from_messages(
|
236 |
+
[
|
237 |
+
("system", paper_creation_prompt),
|
238 |
+
("human", "{question}"),
|
239 |
+
]
|
240 |
+
)
|
241 |
+
|
242 |
+
return prompt
|
requirements.txt
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
fastapi
|
2 |
+
# python-jose
|
3 |
+
python-dotenv
|
4 |
+
# bcrypt
|
5 |
+
# passlib
|
6 |
+
uvicorn
|
7 |
+
# pyjwt
|
8 |
+
python-multipart
|
9 |
+
# pydantic[email]
|
10 |
+
pymongo
|
11 |
+
faiss-cpu
|
12 |
+
sentence_transformers
|
13 |
+
langchain_groq
|
14 |
+
langchain
langchain-community
|
15 |
+
langchain_unstructured
|
16 |
+
unstructured[all-docs]
|
17 |
+
unstructured[docx]
|
18 |
+
unstructured
|
19 |
+
unstructured[pdf]
|
20 |
+
langchain-mongodb
|
21 |
+
langchain_huggingface
|
22 |
+
wikipedia
|
23 |
+
docx2txt
|
retrieval_chain.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.chains import ConversationalRetrievalChain
|
2 |
+
from langchain.prompts import ChatPromptTemplate
|
3 |
+
|
4 |
+
class RetrievalChain:
    """Wraps a ConversationalRetrievalChain and streams its answers as Server-Sent Events."""

    def __init__(self, llm, retriever, user_prompt, verbose=False):
        """
        Initializes the RetrievalChain with an LLM and retriever.

        Args:
            llm: Language model to use for the conversational chain.
            retriever: Retriever object to fetch relevant documents.
            user_prompt: Custom prompt to guide the chain (must accept
                {context} and {question}).
            verbose (bool): Whether to print verbose chain outputs.
        """
        self.llm = llm

        self.chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            return_source_documents=True,
            chain_type='stuff',
            combine_docs_chain_kwargs={"prompt": user_prompt},
            verbose=verbose,
        )

    @staticmethod
    def _format_sse(payload):
        """
        Formats a text payload as a single SSE "data" event.

        BUG FIX: the original yielded f"data: {payload}\n\n" directly, which is
        malformed SSE when the payload itself contains newlines — the SSE spec
        requires every line of the payload to carry its own "data: " prefix.
        For single-line payloads the output is byte-identical to the original.

        Args:
            payload (str): Raw event text, possibly multi-line.

        Returns:
            str: A correctly framed SSE event terminated by a blank line.
        """
        return "".join(f"data: {line}\n" for line in payload.split("\n")) + "\n"

    def summarize_messages(self, chat_history):
        """
        Summarizes the chat history into a concise message.

        The existing history is replaced in place: it is cleared and a single
        AI message containing the summary is written back.

        Args:
            chat_history: The chat history object for the session.

        Returns:
            bool: True if summarization is successful, False otherwise.
        """
        stored_messages = chat_history.messages
        if len(stored_messages) == 0:
            return False

        summarization_prompt = ChatPromptTemplate.from_messages(
            [
                ("placeholder", "{chat_history}"),
                (
                    "human",
                    "Summarize the above chat messages into a single concise message. Include only the important specific details.",
                ),
            ]
        )
        # Pipe the prompt into the language model to build a summarization chain.
        summarization_chain = summarization_prompt | self.llm
        summary_message = summarization_chain.invoke({"chat_history": stored_messages})

        chat_history.clear()  # Drop the full transcript...
        chat_history.add_ai_message(summary_message.content)  # ...and keep only the summary.
        return True

    def stream_chat_response(self, query, chat_id, get_chat_history, initialize_chat_history):
        """
        Streams the response to a query in real-time for a given chat session using SSE formatting.

        Args:
            query (str): The user's query.
            chat_id (str): The unique ID of the chat session.
            get_chat_history (function): Function to retrieve chat history by chat ID.
            initialize_chat_history (function): Function to initialize a new chat history.

        Yields:
            str: Server-Sent Event (SSE) formatted string for each chunk of the response.
        """
        # Retrieve (or lazily create) the chat history for the session.
        chat_message_history = get_chat_history(chat_id)
        if not chat_message_history:
            chat_message_history = initialize_chat_history(chat_id)

        # Compress prior turns into a single summary message before answering.
        self.summarize_messages(chat_message_history)
        chat_history = chat_message_history.messages

        # Input expected by ConversationalRetrievalChain: question + history.
        input_data_for_chain = {
            "question": query,
            "chat_history": chat_history
        }

        # Record the user query before streaming begins.
        chat_message_history.add_user_message(query)

        # Execute the chain in streaming mode.
        response_stream = self.chain.stream(input_data_for_chain)

        accumulated_response = ""
        for chunk in response_stream:
            if 'answer' in chunk:
                accumulated_response += chunk['answer']
                yield self._format_sse(chunk['answer'])
            else:
                # Surface unexpected chunk shapes as debug events rather than dropping them.
                debug_msg = f"Unexpected chunk structure: {chunk}"
                yield self._format_sse(debug_msg)

        # Once streaming is complete, persist the full response to the history.
        if accumulated_response:
            chat_message_history.add_ai_message(accumulated_response)
        else:
            yield "data: No valid response content was generated.\n\n"
|
110 |
+
|
text_splitter.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
2 |
+
|
3 |
+
class TextSplitter:
    """Thin wrapper around RecursiveCharacterTextSplitter with application defaults."""

    def __init__(self, chunk_size=1024, chunk_overlap=100):
        """
        Initialize the TextSplitter with a specific chunk size and overlap.

        Args:
            chunk_size (int): The size of each text chunk.
            chunk_overlap (int): The overlap size between chunks.
        """
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)

    def split_documents(self, documents):
        """
        Split the provided documents into chunks based on the chunk size and overlap.

        Args:
            documents (list): A list of documents to be split.

        Returns:
            list: The split documents, or an empty list if splitting fails
            (the error is printed rather than raised).
        """
        try:
            return self.text_splitter.split_documents(documents)
        except Exception as e:
            print(f"Error splitting documents: {e}")
            # BUG FIX: the original implicitly returned None after printing,
            # so callers iterating the result crashed later. Fail soft with
            # an empty list instead.
            return []
|
vector_store.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import uuid
|
3 |
+
import shutil
|
4 |
+
import tempfile
|
5 |
+
import zipfile
|
6 |
+
|
7 |
+
from faiss import IndexFlatL2
|
8 |
+
|
9 |
+
from langchain_community.vectorstores import FAISS
|
10 |
+
from langchain_community.docstore.in_memory import InMemoryDocstore
|
11 |
+
|
12 |
+
|
13 |
+
class VectorStoreManager:
    """Manages an in-memory FAISS vector store: creation, CRUD, persistence, and merging."""

    def __init__(self, embeddings=None):
        """
        Initializes the VectorStoreManager with a FAISS vector store.

        Args:
            embeddings (Embeddings, optional): Embeddings model used for the
                vector store. When omitted, ``self.vectorstore`` stays None
                until a store is loaded or assigned.
        """
        self.vectorstore = None
        if embeddings:
            self.vectorstore = self.create_vectorstore(embeddings)

    def create_vectorstore(self, embeddings):
        """
        Creates and initializes a FAISS vector store.

        Args:
            embeddings (Embeddings): Embeddings model used for the vector store.

        Returns:
            FAISS: Initialized vector store.
        """
        # Probe the embedding model once to discover the vector dimensionality.
        dimensions = len(embeddings.embed_query("dummy"))

        # Flat L2 index, in-memory docstore, no ID mapping yet.
        vectorstore = FAISS(
            embedding_function=embeddings,
            index=IndexFlatL2(dimensions),
            docstore=InMemoryDocstore(),
            index_to_docstore_id={},
            normalize_L2=False
        )

        print("Created a new FAISS vector store.")
        return vectorstore

    def add_documents(self, documents):
        """
        Adds new documents to the FAISS vector store, each document with a unique UUID.

        Args:
            documents (list): List of Document objects to be added to the vector store.

        Returns:
            list: List of UUIDs corresponding to the added documents.

        Raises:
            ValueError: If no vector store has been created or loaded yet.
        """
        if not self.vectorstore:
            raise ValueError("Vector store is not initialized. Please create or load a vector store first.")

        uuids = [str(uuid.uuid4()) for _ in range(len(documents))]
        self.vectorstore.add_documents(documents=documents, ids=uuids)

        print(f"Added {len(documents)} documents to the vector store with IDs: {uuids}")
        return uuids

    def delete_documents(self, ids):
        """
        Deletes documents from the FAISS vector store using their unique IDs.

        Args:
            ids (list): List of UUIDs corresponding to the documents to be deleted.

        Returns:
            bool: True if the documents were successfully deleted, False otherwise
            (including when ``ids`` is empty).

        Raises:
            ValueError: If no vector store has been created or loaded yet.
        """
        if not self.vectorstore:
            raise ValueError("Vector store is not initialized. Please create or load a vector store first.")

        if not ids:
            print("No document IDs provided for deletion.")
            return False

        success = self.vectorstore.delete(ids=ids)
        if success:
            print(f"Successfully deleted documents with IDs: {ids}")
        else:
            print(f"Failed to delete documents with IDs: {ids}")
        return success

    def save(self, filename="faiss_index"):
        """
        Saves the current FAISS vector store locally. If the saved store is a directory,
        it compresses it into a ZIP archive.

        Args:
            filename (str): The filename or directory name where the vector store will be saved.

        Returns:
            dict: A dictionary with details about the saved file including file path and media type.

        Raises:
            ValueError: If no vector store has been created or loaded yet.
            FileNotFoundError: If the save produced no file on disk.
        """
        if not self.vectorstore:
            raise ValueError("Vector store is not initialized. Please create or load a vector store first.")

        self.vectorstore.save_local(filename)
        # BUG FIX: the original printed a literal "(unknown)" instead of the
        # actual save target.
        print(f"Vector store saved to {filename}")

        if not os.path.exists(filename):
            raise FileNotFoundError("Saved vectorstore not found.")

        # FAISS.save_local writes a directory; ship it as a single ZIP file.
        if os.path.isdir(filename):
            zip_filename = filename + ".zip"
            shutil.make_archive(filename, 'zip', filename)
            return {
                "file_path": zip_filename,
                "media_type": "application/zip",
                "serve_filename": os.path.basename(zip_filename),
                "original": filename,
            }
        else:
            return {
                "file_path": filename,
                "media_type": "application/octet-stream",
                "serve_filename": os.path.basename(filename),
                "original": filename,
            }

    @staticmethod
    def _load_faiss_from_file(path, embeddings):
        """
        Loads a FAISS store from a local save directory/file or a ZIP archive of one.

        Shared by ``load`` and ``merge`` (the original duplicated this logic).

        Args:
            path (str): Path to the saved store or ZIP archive.
            embeddings (Embeddings): Embeddings model used for loading.

        Returns:
            tuple: (FAISS store, bool indicating whether the input was a ZIP).
        """
        if zipfile.is_zipfile(path):
            with tempfile.TemporaryDirectory() as extract_dir:
                with zipfile.ZipFile(path, 'r') as zip_ref:
                    zip_ref.extractall(extract_dir)
                # If the archive wraps a single top-level directory, descend into it.
                extracted_items = os.listdir(extract_dir)
                if len(extracted_items) == 1 and os.path.isdir(os.path.join(extract_dir, extracted_items[0])):
                    store_dir = os.path.join(extract_dir, extracted_items[0])
                else:
                    store_dir = extract_dir
                return FAISS.load_local(store_dir, embeddings, allow_dangerous_deserialization=True), True
        return FAISS.load_local(path, embeddings, allow_dangerous_deserialization=True), False

    @staticmethod
    def load(file_input, embeddings):
        """
        Loads a FAISS vector store from an uploaded file or a filename.

        If file_input is a file-like object, it is saved to a temporary file.
        If it's a string (filename), it is used directly.

        Returns:
            tuple: (VectorStoreManager instance holding the loaded store, status message).
        """
        if isinstance(file_input, str):
            tmp_filename = file_input
            temp_created = False
        else:
            with tempfile.NamedTemporaryFile(delete=False) as tmp:
                tmp.write(file_input.read())
                tmp_filename = tmp.name
            temp_created = True

        try:
            new_vectorstore, from_zip = VectorStoreManager._load_faiss_from_file(tmp_filename, embeddings)
            message = "Vector store loaded successfully from ZIP." if from_zip else "Vector store loaded successfully."
        except Exception as e:
            # BUG FIX: the original raised fastapi's HTTPException, which is not
            # imported in this module and therefore surfaced as a NameError.
            # Raise a plain Exception, consistent with merge() below.
            raise Exception(f"Error loading vectorstore: {str(e)}")
        finally:
            # Only remove the temp file if we created it here.
            if temp_created and os.path.exists(tmp_filename):
                os.remove(tmp_filename)

        instance = VectorStoreManager()
        instance.vectorstore = new_vectorstore
        print(message)
        return instance, message

    def merge(self, file_input, embeddings):
        """
        Merges an uploaded vector store file into the current FAISS vector store.

        Args:
            file_input (Union[file-like object, str]): An object with a .read() method or a filename (str).
            embeddings (Embeddings): Embeddings model used for loading the vector store.

        Returns:
            dict: A dictionary containing a message indicating successful merging.

        Raises:
            Exception: If loading or merging the source store fails.
        """
        if isinstance(file_input, str):
            tmp_filename = file_input
            temp_created = False
        else:
            with tempfile.NamedTemporaryFile(delete=False) as tmp:
                tmp.write(file_input.read())
                tmp_filename = tmp.name
            temp_created = True

        try:
            source_store, _ = self._load_faiss_from_file(tmp_filename, embeddings)

            if not self.vectorstore:
                raise ValueError("Vector store is not initialized. Please create or load a vector store first.")

            self.vectorstore.merge_from(source_store)
            print("Successfully merged the source vector store into the current vector store.")
        except Exception as e:
            raise Exception(f"Error merging vectorstore: {str(e)}")
        finally:
            if temp_created and os.path.exists(tmp_filename):
                os.remove(tmp_filename)
        return {"message": "Vector stores merged successfully"}
|