mominah commited on
Commit
7b7cab6
·
verified ·
1 Parent(s): 8360eea

Upload 11 files

Browse files
Dockerfile ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Base image using Python 3.9
FROM python:3.9

# Run as an unprivileged user rather than root
RUN useradd -m -u 1000 user
USER user

# Make user-level pip installs discoverable on PATH
ENV PATH="/home/user/.local/bin:$PATH"

# All subsequent paths are relative to /app
WORKDIR /app

# Install Python dependencies first so this layer is cached
# independently of application-code changes
COPY --chown=user ./requirements.txt requirements.txt
RUN pip install --no-cache-dir --upgrade -r requirements.txt

# Copy the application source
COPY --chown=user . /app

# The FastAPI app listens on 7860
EXPOSE 7860

# Launch the FastAPI app with uvicorn
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
chat_management.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uuid
2
+ from pymongo import MongoClient
3
+ from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory
4
+
5
+
6
class ChatManagement:
    """Manage chat sessions backed by MongoDB-persisted message histories.

    Each session is identified by a UUID string; created history objects
    are cached in ``self.chat_sessions`` so they are not rebuilt on every
    access.
    """

    def __init__(self, cluster_url, database_name, collection_name):
        """
        Initializes the ChatManagement class with MongoDB connection details.

        Args:
            cluster_url (str): MongoDB connection string / cluster URL.
            database_name (str): Name of the database.
            collection_name (str): Name of the collection.
        """
        self.connection_string = cluster_url
        self.database_name = database_name
        self.collection_name = collection_name
        # In-memory cache: chat_id -> MongoDBChatMessageHistory
        self.chat_sessions = {}

    def _build_history(self, chat_id):
        """Construct a MongoDBChatMessageHistory bound to *chat_id*.

        Centralizes the construction that was previously duplicated
        verbatim in create_new_chat / get_chat_history /
        initialize_chat_history.
        """
        return MongoDBChatMessageHistory(
            session_id=chat_id,
            connection_string=self.connection_string,
            database_name=self.database_name,
            collection_name=self.collection_name,
        )

    def create_new_chat(self):
        """
        Creates a new chat session with a freshly generated ID and caches
        its history object.

        Returns:
            str: The unique chat ID.
        """
        chat_id = str(uuid.uuid4())
        self.chat_sessions[chat_id] = self._build_history(chat_id)
        return chat_id

    def get_chat_history(self, chat_id):
        """
        Retrieves the chat history object for an existing chat session.

        Args:
            chat_id (str): The unique ID of the chat session.

        Returns:
            MongoDBChatMessageHistory or None: The chat history object, or
            None if the session has no persisted messages and is unknown.
        """
        # Fast path: session already cached in memory.
        if chat_id in self.chat_sessions:
            return self.chat_sessions[chat_id]

        # Not cached: build a history object and check whether messages
        # for this session already exist in the database.
        history = self._build_history(chat_id)
        if history.messages:  # non-empty => session exists in MongoDB
            self.chat_sessions[chat_id] = history
            return history

        return None  # Chat session not found

    def initialize_chat_history(self, chat_id):
        """
        Initializes (or returns the cached) chat history for *chat_id*.

        Unlike get_chat_history, this always returns a usable history
        object, creating an empty one when the session is new.

        Args:
            chat_id (str): The unique ID of the chat session.

        Returns:
            MongoDBChatMessageHistory: The initialized chat history object.
        """
        if chat_id in self.chat_sessions:
            return self.chat_sessions[chat_id]

        history = self._build_history(chat_id)
        self.chat_sessions[chat_id] = history
        return history
document_loaders.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_community.document_loaders import (CSVLoader, WikipediaLoader, UnstructuredURLLoader,
2
+ YoutubeLoader, PyPDFLoader, BSHTMLLoader,
3
+ Docx2txtLoader, UnstructuredMarkdownLoader)
4
+
5
+ from langchain_unstructured import UnstructuredLoader
6
+
7
+
8
class DocumentLoader:
    """Thin wrappers around LangChain document loaders.

    Every method returns the loaded documents on success; on any failure
    it prints a diagnostic message and implicitly returns None.
    """

    def load_unstructured(self, path):
        """Load a file of (almost) any supported type via UnstructuredLoader.

        Supported files:
        "csv", "doc", "docx", "epub", "image", "md", "msg", "odt", "org",
        "pdf", "ppt", "pptx", "rtf", "rst", "tsv", "xlsx"

        Args:
            path (str): The file path.

        Returns:
            The loaded documents, or None on error (message printed).
        """
        try:
            return UnstructuredLoader(path).load()
        except Exception as e:
            print(f"Error loading Unstructured: {e}")

    def load_csv(self, path):
        """Load a CSV file; returns None and prints on failure."""
        try:
            return CSVLoader(file_path=path).load()
        except Exception as e:
            print(f"Error loading CSV: {e}")

    def wikipedia_query(self, search_query):
        """Query Wikipedia (up to 2 documents) for *search_query*."""
        try:
            return WikipediaLoader(query=search_query, load_max_docs=2).load()
        except Exception as e:
            print(f"Error querying Wikipedia: {e}")

    def load_urls(self, urls):
        """Load and parse content from a list of URLs."""
        try:
            return UnstructuredURLLoader(urls=urls).load()
        except Exception as e:
            print(f"Error loading URLs: {e}")

    def load_YouTubeVideo(self, urls):
        """Load YouTube transcript/metadata for the given video URL(s),
        trying several source languages and translating to English."""
        try:
            yt_loader = YoutubeLoader.from_youtube_url(
                urls, add_video_info=True, language=["en", "pt", "zh-Hans", "es", "ur", "hi"],
                translation="en")
            return yt_loader.load()
        except Exception as e:
            print(f"Error loading YouTube video: {e}")

    def load_pdf(self, path):
        """Load a PDF and split it into pages."""
        try:
            return PyPDFLoader(path).load_and_split()
        except Exception as e:
            print(f"Error loading PDF: {e}")

    def load_text_from_html(self, path):
        """Extract text content from an HTML file via BeautifulSoup."""
        try:
            return BSHTMLLoader(path).load()
        except Exception as e:
            print(f"Error loading text from HTML: {e}")

    def load_markdown(self, path):
        """Load a Markdown file."""
        try:
            return UnstructuredMarkdownLoader(path).load()
        except Exception as e:
            print(f"Error loading Markdown: {e}")

    def load_doc(self, path):
        """Load a DOC/DOCX file via docx2txt."""
        try:
            return Docx2txtLoader(path).load()
        except Exception as e:
            print(f"Error loading DOCX: {e}")
+
embedding.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_huggingface import HuggingFaceEmbeddings
2
+
3
def get_embeddings():
    """Build the shared embedding model.

    Returns a CPU-backed HuggingFace embedding model (BAAI/bge-small-en)
    configured to L2-normalize its output vectors.
    """
    return HuggingFaceEmbeddings(
        model_name="BAAI/bge-small-en",
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": True},
    )
llm_initialization.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_groq import ChatGroq
2
+
3
def get_llm():
    """
    Returns the language model instance (LLM) using the ChatGroq API.

    The model used is llama-3.3-70b-versatile with deterministic output
    (temperature 0) and a 1024-token response cap.

    SECURITY: the API key was previously hard-coded in this file (a leaked
    secret — it must be rotated). It is now read from the GROQ_API_KEY
    environment variable.

    Returns:
        llm (ChatGroq): An instance of the ChatGroq LLM.

    Raises:
        RuntimeError: If GROQ_API_KEY is not set in the environment.
    """
    import os  # local import keeps this module's top-level deps unchanged

    api_key = os.getenv("GROQ_API_KEY")
    if not api_key:
        raise RuntimeError("GROQ_API_KEY environment variable is not set.")

    return ChatGroq(
        model="llama-3.3-70b-versatile",
        temperature=0,
        max_tokens=1024,
        api_key=api_key,
    )
main.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ import zipfile
4
+ from typing import List, Optional
5
+
6
+ from fastapi import FastAPI, File, UploadFile, HTTPException, Query
7
+ from fastapi.responses import FileResponse, StreamingResponse
8
+
9
+ from llm_initialization import get_llm
10
+ from embedding import get_embeddings
11
+ from document_loaders import DocumentLoader
12
+ from text_splitter import TextSplitter
13
+ from vector_store import VectorStoreManager
14
+ from prompt_templates import PromptTemplates
15
+ from chat_management import ChatManagement
16
+ from retrieval_chain import RetrievalChain
17
+ from urllib.parse import quote_plus
18
+ from dotenv import load_dotenv
19
+ from pymongo import MongoClient
20
+
21
# Load environment variables from a local .env file, if present.
load_dotenv()
# NOTE(review): MONGO_PASSWORD is computed but never used below —
# presumably it was meant to be interpolated into the connection
# string; confirm against the deployment config.
MONGO_PASSWORD = quote_plus(os.getenv("MONGO_PASSWORD"))
MONGO_DATABASE_NAME = os.getenv("DATABASE_NAME")
MONGO_COLLECTION_NAME = os.getenv("COLLECTION_NAME")
MONGO_CLUSTER_URL = os.getenv("CONNECTION_STRING")

app = FastAPI(title="VectorStore & Document Management API")

# Global variables; all are populated once by startup_event().
llm = None                   # ChatGroq LLM (see llm_initialization.get_llm)
embeddings = None            # HuggingFace embeddings (see embedding.get_embeddings)
chat_manager = None          # ChatManagement instance (MongoDB-backed histories)
document_loader = None       # DocumentLoader instance
text_splitter = None         # TextSplitter instance
vector_store_manager = None  # VectorStoreManager wrapping the vector store
vector_store = None          # the manager's underlying vectorstore object
k = 3  # Number of documents to retrieve per query

# Global MongoDB collection to store retrieval chain configuration per chat session.
chat_chains_collection = None
42
+
43
+ # ----------------------- Startup Event -----------------------
44
@app.on_event("startup")
async def startup_event():
    """Initialize all shared components once when the server starts.

    Populates the module-level globals declared above. Order matters:
    ``embeddings`` must exist before ``VectorStoreManager`` is built,
    and ``vector_store`` is taken from the manager afterwards.
    """
    global llm, embeddings, chat_manager, document_loader, text_splitter, vector_store_manager, vector_store, chat_chains_collection

    print("Starting up: Initializing components...")

    # Initialize LLM and embeddings
    llm = get_llm()
    print("LLM initialized.")
    embeddings = get_embeddings()
    print("Embeddings initialized.")

    # Setup chat management (MongoDB-backed chat message histories)
    chat_manager = ChatManagement(
        cluster_url=MONGO_CLUSTER_URL,
        database_name=MONGO_DATABASE_NAME,
        collection_name=MONGO_COLLECTION_NAME,
    )
    print("Chat management initialized.")

    # Initialize document loader and text splitter
    document_loader = DocumentLoader()
    text_splitter = TextSplitter()
    print("Document loader and text splitter initialized.")

    # Initialize vector store manager and ensure vectorstore is set
    vector_store_manager = VectorStoreManager(embeddings)
    vector_store = vector_store_manager.vectorstore  # Now properly initialized
    print("Vector store initialized.")

    # Connect to MongoDB and get the collection used to persist the
    # per-chat retrieval-chain configuration (written by /create_chain,
    # read by /chat).
    client = MongoClient(MONGO_CLUSTER_URL)
    db = client[MONGO_DATABASE_NAME]
    chat_chains_collection = db["chat_chains"]
    print("Chat chains collection initialized in MongoDB.")
79
+
80
+
81
+ # ----------------------- Root Endpoint -----------------------
82
@app.get("/")
def root():
    """Return a simple welcome message for the API root."""
    return {"message": "Welcome to the VectorStore & Document Management API!"}
88
+
89
+
90
+ # ----------------------- New Chat Endpoint -----------------------
91
@app.post("/new_chat")
def new_chat():
    """Create a fresh chat session and return its generated ID."""
    return {"chat_id": chat_manager.create_new_chat()}
98
+
99
+
100
+ # ----------------------- Create Chain Endpoint -----------------------
101
@app.post("/create_chain")
def create_chain(
    chat_id: str = Query(..., description="Existing chat session ID"),
    template: str = Query(
        "quiz_solving",
        description="Select prompt template. Options: quiz_solving, assignment_solving, paper_solving, quiz_creation, assignment_creation, paper_creation",
    ),
):
    """Persist the prompt-template choice for a chat session.

    Validates *template* and upserts a ``{chat_id, template}`` document
    into the MongoDB ``chat_chains`` collection, which /chat reads later.
    """
    global chat_chains_collection  # Ensure we reference the global variable

    allowed_templates = {
        "quiz_solving",
        "assignment_solving",
        "paper_solving",
        "quiz_creation",
        "assignment_creation",
        "paper_creation",
    }
    if template not in allowed_templates:
        raise HTTPException(status_code=400, detail="Invalid template selection.")

    # Upsert the configuration document for this chat session.
    chat_chains_collection.update_one(
        {"chat_id": chat_id}, {"$set": {"template": template}}, upsert=True
    )

    return {"message": "Retrieval chain configuration stored successfully.", "chat_id": chat_id, "template": template}
128
+
129
+
130
+ # ----------------------- Chat Endpoint -----------------------
131
@app.get("/chat")
def chat(query: str, chat_id: str = Query(..., description="Chat session ID created via /new_chat and configured via /create_chain")):
    """
    Stream an answer to *query* using the chain configured for *chat_id*.

    Looks up the chat's stored template in MongoDB (written by
    /create_chain), builds the matching prompt, re-creates the retrieval
    chain over the current vector store, and streams the LLM response
    as ``text/event-stream``.

    Raises:
        HTTPException 400: unknown chat_id or invalid stored template.
        HTTPException 500: failure while producing the response stream.
    """
    # Retrieve the chat configuration from MongoDB.
    config = chat_chains_collection.find_one({"chat_id": chat_id})
    if not config:
        raise HTTPException(status_code=400, detail="Chat configuration not found. Please create a chain using /create_chain.")

    # Dispatch table replaces the previous six-way if/elif chain.
    prompt_factories = {
        "quiz_solving": PromptTemplates.get_quiz_solving_prompt,
        "assignment_solving": PromptTemplates.get_assignment_solving_prompt,
        "paper_solving": PromptTemplates.get_paper_solving_prompt,
        "quiz_creation": PromptTemplates.get_quiz_creation_prompt,
        "assignment_creation": PromptTemplates.get_assignment_creation_prompt,
        "paper_creation": PromptTemplates.get_paper_creation_prompt,
    }
    template = config.get("template", "quiz_solving")
    factory = prompt_factories.get(template)
    if factory is None:
        raise HTTPException(status_code=400, detail="Invalid chat configuration.")
    prompt = factory()

    # Re-create the retrieval chain for this chat session.
    retrieval_chain = RetrievalChain(
        llm,
        vector_store.as_retriever(search_kwargs={"k": k}),
        prompt,
        verbose=True,
    )

    try:
        stream_generator = retrieval_chain.stream_chat_response(
            query=query,
            chat_id=chat_id,
            get_chat_history=chat_manager.get_chat_history,
            initialize_chat_history=chat_manager.initialize_chat_history,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing chat query: {str(e)}")

    return StreamingResponse(stream_generator, media_type="text/event-stream")
192
+
193
+
194
+ # ----------------------- Add Document Endpoint -----------------------
195
from typing import Any, Optional

@app.post("/add_document")
async def add_document(
    file: Optional[Any] = File(None),
    wiki_query: Optional[str] = Query(None),
    wiki_url: Optional[str] = Query(None)
):
    """
    Upload a document OR load data from a Wikipedia query or URL.

    - If a file is provided, the document is loaded from the file.
    - If 'wiki_query' is provided, Wikipedia page(s) are loaded.
    - If 'wiki_url' is provided, the URL is loaded.

    Input priority: file > wiki_query > wiki_url. The loaded document(s)
    are split into chunks and added to the vector store.
    """
    # If file is provided but not as an UploadFile (e.g. an empty string
    # from a form), treat it as absent.
    if not isinstance(file, UploadFile):
        file = None

    # Ensure at least one input is provided.
    if file is None and wiki_query is None and wiki_url is None:
        raise HTTPException(status_code=400, detail="No document input provided (file, wiki_query, or wiki_url).")

    if file is not None:
        documents = await _load_uploaded_file(file)
    elif wiki_query is not None:
        try:
            documents = document_loader.wikipedia_query(wiki_query)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Error loading Wikipedia query: {str(e)}")
    else:
        try:
            documents = document_loader.load_urls([wiki_url])
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Error loading URL: {str(e)}")

    try:
        chunks = text_splitter.split_documents(documents)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error splitting document: {str(e)}")

    try:
        ids = vector_store_manager.add_documents(chunks)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error indexing document chunks: {str(e)}")

    return {"message": f"Added {len(chunks)} document chunks.", "ids": ids}


async def _load_uploaded_file(file: UploadFile):
    """Persist *file* to a temp path, load it with the extension-matched
    loader, and always remove the temp file.

    The cleanup now lives in one ``finally`` instead of being duplicated
    on the success and error paths, and the extension if/elif chain is a
    dispatch dict falling back to the generic unstructured loader.
    """
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        tmp.write(await file.read())
        tmp_filename = tmp.name

    ext = file.filename.split(".")[-1].lower()
    loaders = {
        "pdf": document_loader.load_pdf,
        "csv": document_loader.load_csv,
        "doc": document_loader.load_doc,
        "docx": document_loader.load_doc,
        "html": document_loader.load_text_from_html,
        "htm": document_loader.load_text_from_html,
        "md": document_loader.load_markdown,
        "markdown": document_loader.load_markdown,
    }
    loader = loaders.get(ext, document_loader.load_unstructured)
    try:
        return loader(tmp_filename)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error loading document from file: {str(e)}")
    finally:
        os.remove(tmp_filename)
267
+
268
+
269
+ # ----------------------- Delete Document Endpoint -----------------------
270
@app.post("/delete_document")
def delete_document(ids: List[str]):
    """Remove documents from the vector store by their IDs."""
    try:
        deleted = vector_store_manager.delete_documents(ids)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error deleting documents: {str(e)}")
    if not deleted:
        raise HTTPException(status_code=400, detail="Failed to delete documents.")
    return {"message": f"Deleted documents with IDs: {ids}"}
282
+
283
+
284
+ # ----------------------- Save Vectorstore Endpoint -----------------------
285
@app.get("/save_vectorstore")
def save_vectorstore():
    """Persist the current vector store locally and return it for download.

    The manager zips the store if it is saved as a directory; the
    resulting file is streamed back as a downloadable response.
    """
    try:
        saved = vector_store_manager.save("faiss_index")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error saving vectorstore: {str(e)}")
    return FileResponse(
        path=saved["file_path"],
        media_type=saved["media_type"],
        filename=saved["serve_filename"],
    )
301
+
302
+
303
+ # ----------------------- Load Vectorstore Endpoint -----------------------
304
@app.post("/load_vectorstore")
async def load_vectorstore(file: UploadFile = File(...)):
    """
    Load a vector store from an uploaded file (raw or zipped),
    replacing the current vector store.

    Bug fix: the module-level ``vector_store`` handle is now refreshed
    along with ``vector_store_manager`` — previously only the manager was
    swapped, so /chat kept retrieving from the stale store.
    """
    global vector_store_manager, vector_store
    tmp_filename = None
    try:
        # Save the uploaded file content to a temporary file.
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            tmp.write(await file.read())
            tmp_filename = tmp.name

        instance, message = VectorStoreManager.load(tmp_filename, embeddings)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error loading vectorstore: {str(e)}")
    finally:
        if tmp_filename and os.path.exists(tmp_filename):
            os.remove(tmp_filename)
    vector_store_manager = instance
    # Keep the retriever used by /chat in sync with the new manager.
    vector_store = vector_store_manager.vectorstore
    return {"message": message}
327
+
328
+
329
+ # ----------------------- Merge Vectorstore Endpoint -----------------------
330
@app.post("/merge_vectorstore")
async def merge_vectorstore(file: UploadFile = File(...)):
    """Merge an uploaded vector store (raw or zipped) into the current one."""
    tmp_filename = None
    try:
        # Persist the uploaded bytes to a temporary file on disk.
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            tmp.write(await file.read())  # file.read() is a coroutine
            tmp_filename = tmp.name

        # merge() expects a filesystem path, not a file object.
        result = vector_store_manager.merge(tmp_filename, embeddings)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error merging vectorstore: {str(e)}")
    finally:
        if tmp_filename and os.path.exists(tmp_filename):
            os.remove(tmp_filename)
    return result
351
+
352
+
353
# Run a local development server when this module is executed directly
# (the Docker image instead launches uvicorn on port 7860 via CMD).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
prompt_templates.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.prompts import ChatPromptTemplate
2
+
3
+ class PromptTemplates:
4
+ """
5
+ A class to encapsulate various prompt templates for solving assignments, papers, creating quizzes, and assignments.
6
+ """
7
+
8
+ @staticmethod
9
+ def get_quiz_solving_prompt():
10
+
11
+ quiz_solving_prompt = '''
12
+ You are an assistant specialized in solving quizzes. Your goal is to provide accurate, concise, and contextually relevant answers.
13
+ Use the following retrieved context to answer the user's question.
14
+ If the context lacks sufficient information, respond with "I don't know." Do not make up answers or provide unverified information.
15
+
16
+ Guidelines:
17
+ 1. Extract key information from the context to form a coherent response.
18
+ 2. Maintain a clear and professional tone.
19
+ 3. If the question requires clarification, specify it politely.
20
+
21
+ Retrieved context:
22
+ {context}
23
+
24
+ User's question:
25
+ {question}
26
+
27
+ Your response:
28
+ '''
29
+
30
+
31
+ # Create a prompt template to pass the context and user input to the chain
32
+ prompt = ChatPromptTemplate.from_messages(
33
+ [
34
+ ("system", quiz_solving_prompt),
35
+ ("human", "{question}"),
36
+ ]
37
+ )
38
+
39
+ return prompt
40
+
41
+ @staticmethod
42
+ def get_assignment_solving_prompt():
43
+ # Prompt template for solving assignments
44
+ assignment_solving_prompt = '''
45
+ You are an expert assistant specializing in solving academic assignments with clarity and precision.
46
+ Your task is to provide step-by-step solutions and detailed explanations that align with the given requirements.
47
+
48
+ Retrieved context:
49
+ {context}
50
+
51
+ Assignment Details:
52
+ {question}
53
+
54
+ Guidelines:
55
+ 1. **Understand the Problem:** Carefully analyze the assignment details to identify the objective and requirements.
56
+ 2. **Provide a Step-by-Step Solution:** Break down the solution into clear, logical steps. Use examples where appropriate.
57
+ 3. **Explain Your Reasoning:** Include concise explanations for each step to enhance understanding.
58
+ 4. **Follow Formatting Rules:** Ensure the response matches any specified formatting or citation guidelines.
59
+ 5. **Maintain Academic Integrity:** Do not fabricate information, copy content verbatim without attribution, or complete the task in a way that breaches academic honesty policies.
60
+
61
+ Deliverable:
62
+ Provide the final answer in the format outlined in the assignment description. Where relevant, include:
63
+ - A brief introduction summarizing the approach.
64
+ - Calculations or code (if applicable).
65
+ - Any necessary diagrams, tables, or figures (use textual descriptions for diagrams if unavailable).
66
+ - A conclusion summarizing the findings.
67
+
68
+ If the assignment details are incomplete or ambiguous, specify what additional information is required to proceed.
69
+
70
+ Assignment Response:
71
+ '''
72
+
73
+ # Create a prompt template to pass the context and user input to the chain
74
+ prompt = ChatPromptTemplate.from_messages(
75
+ [
76
+ ("system", assignment_solving_prompt),
77
+ ("human", "{question}"),
78
+ ]
79
+ )
80
+
81
+ return prompt
82
+
83
+
84
+ @staticmethod
85
+ def get_paper_solving_prompt():
86
+ # Prompt template for solving papers
87
+ paper_solving_prompt = '''
88
+ You are an expert assistant specialized in solving academic papers with precision and clarity.
89
+ Your task is to provide well-structured answers to the questions in the paper, ensuring accuracy, depth, and adherence to any specified instructions.
90
+
91
+ Retrieved context:
92
+ {context}
93
+
94
+
95
+ Paper Information:
96
+ {question}
97
+
98
+ Instructions:
99
+ 1. **Understand Each Question:** Read each question carefully and identify its requirements, keywords, and scope.
100
+ 2. **Structured Responses:** Provide answers in a clear, logical structure (e.g., Introduction, Body, Conclusion).
101
+ 3. **Depth and Accuracy:** Support answers with explanations, examples, calculations, or references where applicable.
102
+ 4. **Formatting Guidelines:** Adhere to any specified format or style (e.g., bullet points, paragraphs, equations).
103
+ 5. **Time Efficiency:** If the paper is timed, prioritize accuracy and completeness over excessive detail.
104
+ 6. **Clarify Ambiguities:** If any question is unclear, mention the assumptions made while answering.
105
+ 7. **Ethical Guidelines:** Ensure the answers are original and aligned with academic integrity standards.
106
+
107
+ Deliverables:
108
+ - Answer all questions to the best of your ability.
109
+ - Include relevant diagrams, tables, or code (describe diagrams in text if unavailable).
110
+ - Summarize key points in a conclusion where applicable.
111
+ - Clearly number and label answers to match the questions in the paper.
112
+
113
+
114
+ If the paper includes multiple sections, label each section and solve sequentially.
115
+
116
+ Paper Solution:
117
+ '''
118
+
119
+ # Create a prompt template to pass the context and user input to the chain
120
+ prompt = ChatPromptTemplate.from_messages(
121
+ [
122
+ ("system", paper_solving_prompt),
123
+ ("human", "{question}"),
124
+ ]
125
+ )
126
+
127
+ return prompt
128
+
129
@staticmethod
def get_quiz_creation_prompt():
    """Build the chat prompt template used for generating a quiz.

    Returns:
        ChatPromptTemplate: Template expecting ``{context}`` (retrieved
        documents) and ``{question}`` (the quiz topic supplied by the teacher).
    """
    # System instructions that frame the LLM as a quiz designer.
    system_template = '''
    You are an expert assistant specializing in creating engaging and educational quizzes for students.
    Your task is to design a quiz based on the topic, difficulty level, and format specified by the teacher.

    Retrieved context:
    {context}

    Quiz Details:
    Topic: {question}

    Guidelines for Quiz Creation:
    1. **Relevance to Topic:** Ensure all questions are directly related to the specified topic.
    2. **Clear and Concise Wording:** Write questions clearly and concisely to avoid ambiguity.
    3. **Diverse Question Types:** Incorporate a variety of question types if specified.
    4. **Appropriate Difficulty:** Tailor the complexity of the questions to match the target audience and difficulty level.
    5. **Answer Key:** Provide correct answers or explanations for each question.

    Deliverables:
    - A complete quiz with numbered questions.
    - An answer key with correct answers and explanations where relevant.

    Quiz:
    '''

    # Assemble the (role, template) message pairs before building the prompt.
    messages = [
        ("system", system_template),
        ("human", "{question}"),
    ]
    return ChatPromptTemplate.from_messages(messages)
165
+
166
+
167
@staticmethod
def get_assignment_creation_prompt():
    """Build the chat prompt template used for generating an assignment.

    Returns:
        ChatPromptTemplate: Template expecting ``{context}`` (retrieved
        documents) and ``{question}`` (the assignment topic).
    """
    # System instructions that frame the LLM as an assignment designer.
    system_template = '''
    You are an expert assistant specializing in designing assignments that align with the educational goals and requirements of teachers.
    Your task is to create a comprehensive assignment based on the provided topic, target audience, and desired outcomes.

    Retrieved context:
    {context}

    Assignment Details:
    Topic: {question}

    Guidelines for Assignment Creation:
    1. **Alignment with Topic:** Ensure all tasks/questions are closely related to the specified topic and designed to achieve the teacher’s learning objectives.
    2. **Clear Instructions:** Provide detailed and clear instructions for each question or task.
    3. **Encourage Critical Thinking:** Include questions or tasks that require analysis, creativity, and application of knowledge where appropriate.
    4. **Variety of Tasks:** Incorporate diverse question types (e.g., short answers, essays, practical tasks) as per the specified format.
    5. **Grading Rubric (Optional):** Include a grading rubric or evaluation criteria if specified in the instructions.

    Deliverables:
    - A detailed assignment with numbered tasks/questions.
    - Any required supporting materials (e.g., diagrams, data tables, references).
    - (Optional) A grading rubric or expected outcomes for each task.

    Assignment:
    '''

    # Assemble the (role, template) message pairs before building the prompt.
    messages = [
        ("system", system_template),
        ("human", "{question}"),
    ]
    return ChatPromptTemplate.from_messages(messages)
204
+
205
+
206
@staticmethod
def get_paper_creation_prompt():
    """Build the chat prompt template used for generating an academic paper.

    Returns:
        ChatPromptTemplate: Template expecting ``{context}`` (retrieved
        documents) and ``{question}`` (the subject/topic of the paper).
    """
    # System instructions that frame the LLM as an exam-paper designer.
    system_template = '''
    You are an expert assistant specializing in designing comprehensive academic papers tailored to the educational goals and requirements of teachers.
    Your task is to create a complete paper based on the specified topic, audience, format, and difficulty level.

    Retrieved context:
    {context}

    Paper Details:
    Subject/Topic: {question}

    Guidelines for Paper Creation:
    1. **Relevance and Alignment:** Ensure all questions align with the specified subject/topic and are tailored to the target audience’s curriculum or learning objectives.
    2. **Clear Wording:** Write questions in clear, concise language to avoid ambiguity or confusion.
    3. **Diverse Question Types:** Incorporate a variety of question formats as specified (e.g., multiple-choice, fill-in-the-blank, long-form essays).
    4. **Grading and Marks Allocation:** Provide a suggested mark allocation for each question, ensuring it reflects the question's complexity and time required.
    5. **Answer Key:** Include correct answers or model responses for objective and descriptive questions (optional).

    Deliverables:
    - A complete paper with numbered questions, organized by sections if required.
    - An answer key or marking scheme (if requested).
    - Any supporting materials (e.g., diagrams, charts, or data tables) if applicable.

    Paper:
    '''

    # Assemble the (role, template) message pairs before building the prompt.
    messages = [
        ("system", system_template),
        ("human", "{question}"),
    ]
    return ChatPromptTemplate.from_messages(messages)
requirements.txt ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ # python-jose
3
+ python-dotenv
4
+ # bcrypt
5
+ # passlib
6
+ uvicorn
7
+ # pyjwt
8
+ python-multipart
9
+ # pydantic[email]
10
+ pymongo
11
+ faiss-cpu
12
+ sentence_transformers
13
+ langchain_groq
14
+ langchain-community
15
+ langchain_unstructured
16
+ unstructured[all-docs]
17
+ unstructured[docx]
18
+ unstructured
19
+ unstructured[pdf]
20
+ langchain-mongodb
21
+ langchain_huggingface
22
+ wikipedia
23
+ docx2txt
retrieval_chain.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.chains import ConversationalRetrievalChain
2
+ from langchain.prompts import ChatPromptTemplate
3
+
4
class RetrievalChain:
    """Conversational retrieval chain with history summarization and SSE streaming.

    Wraps a LangChain ``ConversationalRetrievalChain`` and adds two behaviors:
    compressing stored chat history into a single summary message before each
    query, and streaming answers to the client as Server-Sent Events.
    """

    def __init__(self, llm, retriever, user_prompt, verbose=False):
        """
        Initializes the RetrievalChain with an LLM and retriever.

        Args:
            llm: Language model to use for the conversational chain.
            retriever: Retriever object to fetch relevant documents.
            user_prompt: Custom prompt to guide the chain.
            verbose (bool): Whether to print verbose chain outputs.
        """
        self.llm = llm

        # 'stuff' chain type: all retrieved documents are injected into one
        # prompt (via `user_prompt`) rather than map-reduced.
        self.chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            return_source_documents=True,
            chain_type='stuff',
            combine_docs_chain_kwargs={"prompt": user_prompt},
            verbose=verbose,
        )

    @staticmethod
    def _sse_event(payload):
        """Format *payload* as a single Server-Sent Event string.

        The SSE wire format requires multi-line payloads to be split into one
        ``data:`` line per text line; emitting ``data: <text>\\n\\n`` with
        embedded newlines would otherwise terminate the event early and
        corrupt the stream. Spec-compliant clients re-join the lines with
        ``\\n``.

        Args:
            payload: The text to send (converted to ``str``).

        Returns:
            str: A complete SSE event terminated by a blank line.
        """
        lines = str(payload).split("\n")
        return "".join(f"data: {line}\n" for line in lines) + "\n"

    def summarize_messages(self, chat_history):
        """
        Summarizes the chat history into a concise message.

        NOTE: this is destructive — the existing history is cleared and
        replaced by a single AI message containing the summary.

        Args:
            chat_history: The chat history object for the session.

        Returns:
            bool: True if summarization was performed, False if the history
            was empty and nothing was done.
        """
        stored_messages = chat_history.messages
        if len(stored_messages) == 0:
            return False

        summarization_prompt = ChatPromptTemplate.from_messages(
            [
                ("placeholder", "{chat_history}"),
                (
                    "human",
                    "Summarize the above chat messages into a single concise message. Include only the important specific details.",
                ),
            ]
        )
        # Create a chain for summarization by piping the prompt into the language model.
        summarization_chain = summarization_prompt | self.llm
        summary_message = summarization_chain.invoke({"chat_history": stored_messages})

        chat_history.clear()  # Clear the existing chat history
        chat_history.add_ai_message(summary_message.content)  # Add the summary message as the first entry
        return True

    def stream_chat_response(self, query, chat_id, get_chat_history, initialize_chat_history):
        """
        Streams the response to a query in real-time for a given chat session using SSE formatting.

        Args:
            query (str): The user's query.
            chat_id (str): The unique ID of the chat session.
            get_chat_history (function): Function to retrieve chat history by chat ID.
            initialize_chat_history (function): Function to initialize a new chat history.

        Yields:
            str: Server-Sent Event (SSE) formatted string for each chunk of the response.
        """
        # Retrieve the chat history for the session.
        chat_message_history = get_chat_history(chat_id)
        if not chat_message_history:
            # If no chat history exists, initialize one.
            chat_message_history = initialize_chat_history(chat_id)

        # Compress previous messages into one summary (no-op on empty history).
        self.summarize_messages(chat_message_history)
        chat_history = chat_message_history.messages

        # Prepare input data for the conversational retrieval chain.
        input_data_for_chain = {
            "question": query,
            "chat_history": chat_history
        }

        # Add the user query to the chat history.
        chat_message_history.add_user_message(query)

        # Execute the chain in streaming mode (this assumes the chain supports a `stream` method).
        response_stream = self.chain.stream(input_data_for_chain)

        accumulated_response = ""
        # Process the response stream and yield SSE events.
        for chunk in response_stream:
            if 'answer' in chunk:
                accumulated_response += chunk['answer']
                # _sse_event splits multi-line answers into multiple `data:`
                # lines so the event framing stays valid.
                yield self._sse_event(chunk['answer'])
            else:
                # Yield an SSE event with debug info if the chunk structure is unexpected.
                yield self._sse_event(f"Unexpected chunk structure: {chunk}")

        # Once streaming is complete, update chat history with the final response.
        if accumulated_response:
            chat_message_history.add_ai_message(accumulated_response)
        else:
            yield self._sse_event("No valid response content was generated.")
110
+
text_splitter.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
2
+
3
class TextSplitter:
    """Thin wrapper around RecursiveCharacterTextSplitter for chunking documents."""

    def __init__(self, chunk_size=1024, chunk_overlap=100):
        """
        Initialize the TextSplitter with a specific chunk size and overlap.

        Args:
            chunk_size (int): The size of each text chunk.
            chunk_overlap (int): The overlap size between chunks.
        """
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )

    def split_documents(self, documents):
        """
        Split the provided documents into chunks based on the chunk size and overlap.

        Args:
            documents (list): A list of documents to be split.

        Returns:
            list: The split documents, or an empty list if splitting fails.
            (Previously a failure implicitly returned ``None``, which broke
            callers that iterate over the result.)
        """
        try:
            return self.text_splitter.split_documents(documents)
        except Exception as e:
            # Best-effort: report the error and return an iterable so
            # downstream code does not crash on None.
            print(f"Error splitting documents: {e}")
            return []
vector_store.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uuid
3
+ import shutil
4
+ import tempfile
5
+ import zipfile
6
+
7
+ from faiss import IndexFlatL2
8
+
9
+ from langchain_community.vectorstores import FAISS
10
+ from langchain_community.docstore.in_memory import InMemoryDocstore
11
+
12
+
13
class VectorStoreManager:
    """Manages a FAISS vector store: create, add/delete documents, save, load, merge."""

    def __init__(self, embeddings=None):
        """
        Initializes the VectorStoreManager with a FAISS vector store.

        Args:
            embeddings (Embeddings, optional): Embeddings model used for the vector store.
                If omitted, the store stays uninitialized until one is loaded.
        """
        self.vectorstore = None
        if embeddings:
            self.vectorstore = self.create_vectorstore(embeddings)

    def create_vectorstore(self, embeddings):
        """
        Creates and initializes a FAISS vector store.

        Args:
            embeddings (Embeddings): Embeddings model used for the vector store.

        Returns:
            FAISS: Initialized vector store.
        """
        # Probe the embedding model once to discover the vector dimensionality.
        dimensions = len(embeddings.embed_query("dummy"))

        # Initialize FAISS vector store with a flat L2 (exact-search) index.
        vectorstore = FAISS(
            embedding_function=embeddings,
            index=IndexFlatL2(dimensions),
            docstore=InMemoryDocstore(),
            index_to_docstore_id={},
            normalize_L2=False
        )

        print("Created a new FAISS vector store.")
        return vectorstore

    def add_documents(self, documents):
        """
        Adds new documents to the FAISS vector store, each document with a unique UUID.

        Args:
            documents (list): List of Document objects to be added to the vector store.

        Returns:
            list: List of UUIDs corresponding to the added documents.

        Raises:
            ValueError: If the vector store has not been initialized.
        """
        if not self.vectorstore:
            raise ValueError("Vector store is not initialized. Please create or load a vector store first.")

        uuids = [str(uuid.uuid4()) for _ in range(len(documents))]
        self.vectorstore.add_documents(documents=documents, ids=uuids)

        print(f"Added {len(documents)} documents to the vector store with IDs: {uuids}")
        return uuids

    def delete_documents(self, ids):
        """
        Deletes documents from the FAISS vector store using their unique IDs.

        Args:
            ids (list): List of UUIDs corresponding to the documents to be deleted.

        Returns:
            bool: True if the documents were successfully deleted, False otherwise.

        Raises:
            ValueError: If the vector store has not been initialized.
        """
        if not self.vectorstore:
            raise ValueError("Vector store is not initialized. Please create or load a vector store first.")

        if not ids:
            print("No document IDs provided for deletion.")
            return False

        success = self.vectorstore.delete(ids=ids)
        if success:
            print(f"Successfully deleted documents with IDs: {ids}")
        else:
            print(f"Failed to delete documents with IDs: {ids}")
        return success

    def save(self, filename="faiss_index"):
        """
        Saves the current FAISS vector store locally. If the saved store is a directory,
        it compresses it into a ZIP archive.

        Args:
            filename (str): The filename or directory name where the vector store will be saved.

        Returns:
            dict: A dictionary with details about the saved file including file path and media type.

        Raises:
            ValueError: If the vector store has not been initialized.
            FileNotFoundError: If nothing exists at `filename` after saving.
        """
        if not self.vectorstore:
            raise ValueError("Vector store is not initialized. Please create or load a vector store first.")

        # Save the vectorstore locally (FAISS writes a directory of index files).
        self.vectorstore.save_local(filename)
        print(f"Vector store saved to {filename}")

        if not os.path.exists(filename):
            raise FileNotFoundError("Saved vectorstore not found.")

        # If the saved vectorstore is a directory, compress it into a zip file.
        if os.path.isdir(filename):
            zip_filename = filename + ".zip"
            shutil.make_archive(filename, 'zip', filename)
            return {
                "file_path": zip_filename,
                "media_type": "application/zip",
                "serve_filename": os.path.basename(zip_filename),
                "original": filename,
            }
        else:
            return {
                "file_path": filename,
                "media_type": "application/octet-stream",
                "serve_filename": os.path.basename(filename),
                "original": filename,
            }

    @staticmethod
    def load(file_input, embeddings):
        """
        Loads a FAISS vector store from an uploaded file or a filename.

        If file_input is a file-like object, it is saved to a temporary file.
        If it's a string (filename), it is used directly. ZIP archives are
        extracted first; a single top-level directory inside the archive is
        treated as the store directory.

        Args:
            file_input (Union[file-like object, str]): An object with a .read() method or a filename.
            embeddings (Embeddings): Embeddings model used for loading the vector store.

        Returns:
            tuple: (VectorStoreManager instance holding the loaded store, status message).

        Raises:
            Exception: If the vector store cannot be loaded.
        """
        # Check if file_input is a string (filename) or a file-like object.
        if isinstance(file_input, str):
            tmp_filename = file_input
        else:
            with tempfile.NamedTemporaryFile(delete=False) as tmp:
                tmp.write(file_input.read())
                tmp_filename = tmp.name

        try:
            if zipfile.is_zipfile(tmp_filename):
                with tempfile.TemporaryDirectory() as extract_dir:
                    with zipfile.ZipFile(tmp_filename, 'r') as zip_ref:
                        zip_ref.extractall(extract_dir)
                    extracted_items = os.listdir(extract_dir)
                    if len(extracted_items) == 1:
                        potential_dir = os.path.join(extract_dir, extracted_items[0])
                        if os.path.isdir(potential_dir):
                            vectorstore_dir = potential_dir
                        else:
                            vectorstore_dir = extract_dir
                    else:
                        vectorstore_dir = extract_dir

                    # Must load while the TemporaryDirectory still exists.
                    new_vectorstore = FAISS.load_local(vectorstore_dir, embeddings, allow_dangerous_deserialization=True)
                    message = "Vector store loaded successfully from ZIP."
            else:
                new_vectorstore = FAISS.load_local(tmp_filename, embeddings, allow_dangerous_deserialization=True)
                message = "Vector store loaded successfully."
        except Exception as e:
            # BUGFIX: previously raised fastapi's HTTPException, which is not
            # imported in this module and produced a NameError instead of the
            # intended error. Raise a plain Exception, matching merge().
            raise Exception(f"Error loading vectorstore: {str(e)}")
        finally:
            # Only remove the temp file if we created it here (i.e. file_input was not a filename)
            if not isinstance(file_input, str) and os.path.exists(tmp_filename):
                os.remove(tmp_filename)

        instance = VectorStoreManager()
        instance.vectorstore = new_vectorstore
        print(message)
        return instance, message

    def merge(self, file_input, embeddings):
        """
        Merges an uploaded vector store file into the current FAISS vector store.

        Args:
            file_input (Union[file-like object, str]): An object with a .read() method or a filename (str).
            embeddings (Embeddings): Embeddings model used for loading the vector store.

        Returns:
            dict: A dictionary containing a message indicating successful merging.

        Raises:
            Exception: If the source store cannot be loaded or merged.
        """
        # Determine if file_input is a filename (str) or a file-like object.
        if isinstance(file_input, str):
            tmp_filename = file_input
            temp_created = False
        else:
            with tempfile.NamedTemporaryFile(delete=False) as tmp:
                tmp.write(file_input.read())
                tmp_filename = tmp.name
            temp_created = True

        try:
            # Check if the file is a ZIP archive.
            if zipfile.is_zipfile(tmp_filename):
                with tempfile.TemporaryDirectory() as extract_dir:
                    with zipfile.ZipFile(tmp_filename, 'r') as zip_ref:
                        zip_ref.extractall(extract_dir)
                    extracted_items = os.listdir(extract_dir)
                    if len(extracted_items) == 1:
                        potential_dir = os.path.join(extract_dir, extracted_items[0])
                        if os.path.isdir(potential_dir):
                            vectorstore_dir = potential_dir
                        else:
                            vectorstore_dir = extract_dir
                    else:
                        vectorstore_dir = extract_dir

                    source_store = FAISS.load_local(
                        vectorstore_dir, embeddings, allow_dangerous_deserialization=True
                    )
            else:
                source_store = FAISS.load_local(
                    tmp_filename, embeddings, allow_dangerous_deserialization=True
                )

            if not self.vectorstore:
                raise ValueError("Vector store is not initialized. Please create or load a vector store first.")

            self.vectorstore.merge_from(source_store)
            print("Successfully merged the source vector store into the current vector store.")
        except Exception as e:
            raise Exception(f"Error merging vectorstore: {str(e)}")
        finally:
            if temp_created and os.path.exists(tmp_filename):
                os.remove(tmp_filename)
        return {"message": "Vector stores merged successfully"}