Spaces:
Sleeping
Sleeping
import os | |
import tempfile | |
import zipfile | |
from typing import List, Optional | |
from fastapi import FastAPI, File, UploadFile, HTTPException, Query | |
from fastapi.responses import FileResponse, StreamingResponse | |
from llm_initialization import get_llm | |
from embedding import get_embeddings | |
from document_loaders import DocumentLoader | |
from text_splitter import TextSplitter | |
from vector_store import VectorStoreManager | |
from prompt_templates import PromptTemplates | |
from chat_management import ChatManagement | |
from retrieval_chain import RetrievalChain | |
from urllib.parse import quote_plus | |
from dotenv import load_dotenv | |
from pymongo import MongoClient | |
# Load environment variables (python-dotenv) before any of them are read.
load_dotenv()

# URL-encoded MongoDB password. A missing variable falls back to "" because
# quote_plus(None) raises an opaque TypeError that would abort the import.
# NOTE(review): this value is not referenced elsewhere in the visible code —
# presumably it is meant to be spliced into the connection string; confirm.
MONGO_PASSWORD = quote_plus(os.getenv("MONGO_PASSWORD", ""))
MONGO_DATABASE_NAME = os.getenv("DATABASE_NAME")
MONGO_COLLECTION_NAME = os.getenv("COLLECTION_NAME")
MONGO_CLUSTER_URL = os.getenv("CONNECTION_STRING")

app = FastAPI(title="VectorStore & Document Management API")

# Shared components. They stay None until startup_event() populates them.
llm = None
embeddings = None
chat_manager = None
document_loader = None
text_splitter = None
vector_store_manager = None
vector_store = None
k = 3  # Number of documents to retrieve per query

# MongoDB collection that stores the selected prompt template per chat session.
chat_chains_collection = None
# ----------------------- Startup Event ----------------------- | |
async def startup_event():
    """Create every shared component once and publish it via module globals.

    Populates, in order: the LLM, the embeddings, the chat manager, the
    document loader and text splitter, the vector store manager (plus a
    reference to its ``vectorstore``), and the MongoDB ``chat_chains``
    collection. Progress is reported with ``print``.

    NOTE(review): no registration of this coroutine with ``app`` is visible
    in this view — confirm it is wired up as a startup handler.
    """
    global llm, embeddings, chat_manager, document_loader, text_splitter
    global vector_store_manager, vector_store, chat_chains_collection

    print("Starting up: Initializing components...")

    # Model and embedding backends.
    llm = get_llm()
    print("LLM initialized.")
    embeddings = get_embeddings()
    print("Embeddings initialized.")

    # Chat history persistence.
    mongo_settings = {
        "cluster_url": MONGO_CLUSTER_URL,
        "database_name": MONGO_DATABASE_NAME,
        "collection_name": MONGO_COLLECTION_NAME,
    }
    chat_manager = ChatManagement(**mongo_settings)
    print("Chat management initialized.")

    # Ingestion helpers.
    document_loader = DocumentLoader()
    text_splitter = TextSplitter()
    print("Document loader and text splitter initialized.")

    # Vector index; keep a direct handle on the underlying store as well.
    vector_store_manager = VectorStoreManager(embeddings)
    vector_store = vector_store_manager.vectorstore
    print("Vector store initialized.")

    # Per-chat template configuration lives in its own collection.
    mongo_database = MongoClient(MONGO_CLUSTER_URL)[MONGO_DATABASE_NAME]
    chat_chains_collection = mongo_database["chat_chains"]
    print("Chat chains collection initialized in MongoDB.")
# ----------------------- Root Endpoint ----------------------- | |
def root():
    """Return the static welcome payload served at the API root."""
    welcome_text = "Welcome to the VectorStore & Document Management API!"
    return {"message": welcome_text}
# ----------------------- New Chat Endpoint ----------------------- | |
def new_chat():
    """Open a new chat session and return its identifier.

    Returns:
        A dict with the single key ``"chat_id"`` holding whatever
        ``chat_manager.create_new_chat()`` produced.
    """
    return {"chat_id": chat_manager.create_new_chat()}
# ----------------------- Create Chain Endpoint ----------------------- | |
def create_chain(
    chat_id: str = Query(..., description="Existing chat session ID"),
    template: str = Query(
        "quiz_solving",
        description="Select prompt template. Options: quiz_solving, assignment_solving, paper_solving, quiz_creation, assignment_creation, paper_creation",
    ),
):
    """Store the prompt template chosen for a chat session.

    The template name is validated against the supported set and then
    upserted into ``chat_chains_collection`` keyed by ``chat_id``; the chain
    itself is not built here.

    Raises:
        HTTPException: 400 when ``template`` is not a supported name.
    """
    supported_templates = (
        "quiz_solving",
        "assignment_solving",
        "paper_solving",
        "quiz_creation",
        "assignment_creation",
        "paper_creation",
    )
    if template not in supported_templates:
        raise HTTPException(status_code=400, detail="Invalid template selection.")

    # One configuration document per chat session: create it or overwrite it.
    selector = {"chat_id": chat_id}
    change = {"$set": {"template": template}}
    chat_chains_collection.update_one(selector, change, upsert=True)

    return {
        "message": "Retrieval chain configuration stored successfully.",
        "chat_id": chat_id,
        "template": template,
    }
# ----------------------- Chat Endpoint ----------------------- | |
def chat(query: str, chat_id: str = Query(..., description="Chat session ID created via /new_chat and configured via /create_chain")):
    """Answer a query for a configured chat session as a streamed response.

    Steps:
      1. Look up the session's stored template in ``chat_chains_collection``.
      2. Resolve the matching ``PromptTemplates`` factory and build the prompt.
      3. Build a ``RetrievalChain`` over the current vector store (top ``k``).
      4. Return the chain's response generator as ``text/event-stream``.

    Raises:
        HTTPException: 400 when no configuration exists for ``chat_id`` or the
            stored template is unknown; 500 when creating the response
            stream raises.
    """
    stored = chat_chains_collection.find_one({"chat_id": chat_id})
    if not stored:
        raise HTTPException(status_code=400, detail="Chat configuration not found. Please create a chain using /create_chain.")

    template = stored.get("template", "quiz_solving")

    # Template name -> name of the PromptTemplates factory that builds its prompt.
    prompt_factories = (
        ("quiz_solving", "get_quiz_solving_prompt"),
        ("assignment_solving", "get_assignment_solving_prompt"),
        ("paper_solving", "get_paper_solving_prompt"),
        ("quiz_creation", "get_quiz_creation_prompt"),
        ("assignment_creation", "get_assignment_creation_prompt"),
        ("paper_creation", "get_paper_creation_prompt"),
    )
    factory_name = next(
        (factory for name, factory in prompt_factories if template == name),
        None,
    )
    if factory_name is None:
        raise HTTPException(status_code=400, detail="Invalid chat configuration.")
    prompt = getattr(PromptTemplates, factory_name)()

    # The chain is rebuilt on every request from the stored configuration.
    retriever = vector_store.as_retriever(search_kwargs={"k": k})
    retrieval_chain = RetrievalChain(llm, retriever, prompt, verbose=True)

    try:
        stream_generator = retrieval_chain.stream_chat_response(
            query=query,
            chat_id=chat_id,
            get_chat_history=chat_manager.get_chat_history,
            initialize_chat_history=chat_manager.initialize_chat_history,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing chat query: {str(e)}")

    return StreamingResponse(stream_generator, media_type="text/event-stream")
# ----------------------- Add Document Endpoint ----------------------- | |
from typing import Any, Optional | |
async def _load_uploaded_document(file):
    """Write an upload to a temp file, parse it, and always delete the temp copy.

    The loader is picked from the lower-cased filename extension; an unknown
    or absent extension falls back to ``document_loader.load_unstructured``.

    Raises:
        HTTPException: 400 when the selected loader raises.
    """
    # splitext yields "" for extension-less names, and `or ""` guards a
    # missing filename, so both cases route to the generic loader.
    ext = os.path.splitext(file.filename or "")[1].lstrip(".").lower()
    loader_names = {
        "pdf": "load_pdf",
        "csv": "load_csv",
        "doc": "load_doc",
        "docx": "load_doc",
        "html": "load_text_from_html",
        "htm": "load_text_from_html",
        "md": "load_markdown",
        "markdown": "load_markdown",
    }
    loader_name = loader_names.get(ext, "load_unstructured")

    tmp_filename = None
    try:
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            # Record the path before reading so a failed read/write cannot
            # leave an orphaned temp file behind.
            tmp_filename = tmp.name
            tmp.write(await file.read())
        try:
            return getattr(document_loader, loader_name)(tmp_filename)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Error loading document from file: {str(e)}") from e
    finally:
        if tmp_filename is not None and os.path.exists(tmp_filename):
            os.remove(tmp_filename)


async def add_document(
    file: Optional[Any] = File(None),
    wiki_query: Optional[str] = Query(None),
    wiki_url: Optional[str] = Query(None),
):
    """Ingest a document into the vector store from a file, Wikipedia query, or URL.

    Exactly one source is used, in priority order: ``file`` > ``wiki_query`` >
    ``wiki_url``. The loaded documents are split into chunks with
    ``text_splitter`` and indexed through ``vector_store_manager``.

    Returns:
        A dict with a human-readable ``"message"`` and the ``"ids"`` returned
        by ``vector_store_manager.add_documents``.

    Raises:
        HTTPException: 400 when no source is given or loading fails; 500 when
            splitting or indexing fails.
    """
    # Anything that is not a real upload (e.g. an empty string) counts as absent.
    if not isinstance(file, UploadFile):
        file = None

    if file is None and wiki_query is None and wiki_url is None:
        raise HTTPException(status_code=400, detail="No document input provided (file, wiki_query, or wiki_url).")

    if file is not None:
        documents = await _load_uploaded_document(file)
    elif wiki_query is not None:
        try:
            documents = document_loader.wikipedia_query(wiki_query)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Error loading Wikipedia query: {str(e)}") from e
    else:
        # The guard above guarantees wiki_url is set when we get here.
        try:
            documents = document_loader.load_urls([wiki_url])
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Error loading URL: {str(e)}") from e

    try:
        chunks = text_splitter.split_documents(documents)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error splitting document: {str(e)}") from e

    try:
        ids = vector_store_manager.add_documents(chunks)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error indexing document chunks: {str(e)}") from e

    return {"message": f"Added {len(chunks)} document chunks.", "ids": ids}
# ----------------------- Delete Document Endpoint ----------------------- | |
def delete_document(ids: List[str]):
    """Remove the documents with the given IDs from the vector store.

    Raises:
        HTTPException: 500 when the vector store manager raises; 400 when it
            reports a falsy (unsuccessful) result.
    """
    try:
        deleted = vector_store_manager.delete_documents(ids)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error deleting documents: {str(e)}")
    else:
        if not deleted:
            raise HTTPException(status_code=400, detail="Failed to delete documents.")
    return {"message": f"Deleted documents with IDs: {ids}"}
# ----------------------- Save Vectorstore Endpoint ----------------------- | |
def save_vectorstore():
    """Persist the current vector store and return it as a file download.

    ``vector_store_manager.save("faiss_index")`` is expected to return a
    mapping with ``file_path``, ``media_type`` and ``serve_filename``; those
    three values are handed straight to ``FileResponse``.
    NOTE(review): whether a directory result gets zipped is decided inside
    the manager, not here — confirm against ``VectorStoreManager.save``.

    Raises:
        HTTPException: 500 when saving raises.
    """
    try:
        saved = vector_store_manager.save("faiss_index")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error saving vectorstore: {str(e)}")

    download_path = saved["file_path"]
    content_type = saved["media_type"]
    download_name = saved["serve_filename"]
    return FileResponse(path=download_path, media_type=content_type, filename=download_name)
# ----------------------- Load Vectorstore Endpoint ----------------------- | |
async def load_vectorstore(file: UploadFile = File(...)):
    """Replace the current vector store with one loaded from an uploaded file.

    The upload (raw or zipped, as accepted by ``VectorStoreManager.load``) is
    written to a temporary file, loaded with the shared ``embeddings``, and the
    temporary file is removed whether or not loading succeeds.

    Both ``vector_store_manager`` and the module-level ``vector_store`` are
    rebound, so retrieval in the chat endpoint uses the newly loaded store
    instead of the one created at startup.

    Returns:
        A dict with the ``"message"`` reported by ``VectorStoreManager.load``.

    Raises:
        HTTPException: 500 when reading the upload or loading the store raises.
    """
    global vector_store_manager, vector_store

    tmp_filename = None
    try:
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            # Record the path first so cleanup also runs if read/write fails.
            tmp_filename = tmp.name
            tmp.write(await file.read())
        instance, message = VectorStoreManager.load(tmp_filename, embeddings)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error loading vectorstore: {str(e)}") from e
    finally:
        if tmp_filename and os.path.exists(tmp_filename):
            os.remove(tmp_filename)

    vector_store_manager = instance
    # Keep the direct store handle in sync with the new manager; leaving it
    # untouched would make chat retrieval keep querying the previous store.
    vector_store = instance.vectorstore
    return {"message": message}
# ----------------------- Merge Vectorstore Endpoint ----------------------- | |
async def merge_vectorstore(file: UploadFile = File(...)):
    """Merge an uploaded vector store (raw or zipped) into the current one.

    The upload is written to a temporary file whose path is passed to
    ``vector_store_manager.merge`` together with the shared ``embeddings``.
    The temporary file is removed whether or not the merge succeeds.

    Returns:
        Whatever ``vector_store_manager.merge`` returns.

    Raises:
        HTTPException: 500 when reading the upload or merging raises.
    """
    global vector_store

    tmp_filename = None
    try:
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            # Record the path first so cleanup also runs if read/write fails.
            tmp_filename = tmp.name
            tmp.write(await file.read())
        # merge() takes the temp file's path (a string), not the upload object.
        result = vector_store_manager.merge(tmp_filename, embeddings)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error merging vectorstore: {str(e)}") from e
    finally:
        if tmp_filename and os.path.exists(tmp_filename):
            os.remove(tmp_filename)

    # Re-sync the direct store handle used for retrieval.
    # NOTE(review): a no-op if merge() mutates the store in place — confirm
    # whether merge() can rebind ``vectorstore`` on the manager.
    vector_store = vector_store_manager.vectorstore
    return result
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces, port 8000.
    # The import is local so uvicorn is only needed when run as a script.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")