import os
import tempfile
import zipfile
from typing import List, Optional
from urllib.parse import quote_plus

from dotenv import load_dotenv
from fastapi import FastAPI, File, UploadFile, HTTPException, Query
from fastapi.responses import FileResponse, StreamingResponse
from pymongo import MongoClient

from llm_initialization import get_llm
from embedding import get_embeddings
from document_loaders import DocumentLoader
from text_splitter import TextSplitter
from vector_store import VectorStoreManager
from prompt_templates import PromptTemplates
from chat_management import ChatManagement
from retrieval_chain import RetrievalChain

# Load environment variables (from a .env file, if present) before reading them.
load_dotenv()

# URL-escaped MongoDB password. Defaults to "" so an unset variable does not
# crash the whole module at import time (quote_plus(None) raises TypeError).
# NOTE(review): this value is not referenced anywhere in this module -- the
# connection below uses CONNECTION_STRING; confirm whether it is still needed.
MONGO_PASSWORD = quote_plus(os.getenv("MONGO_PASSWORD", ""))
MONGO_DATABASE_NAME = os.getenv("DATABASE_NAME")
MONGO_COLLECTION_NAME = os.getenv("COLLECTION_NAME")
MONGO_CLUSTER_URL = os.getenv("CONNECTION_STRING")

app = FastAPI(title="VectorStore & Document Management API")

# Global variables (populated by the startup event handler).
llm = None
embeddings = None
chat_manager = None
document_loader = None
text_splitter = None
vector_store_manager = None
vector_store = None
k = 3  # Number of documents to retrieve per query

# Global MongoDB collection to store retrieval chain configuration per chat session.
chat_chains_collection = None


# ----------------------- Startup Event -----------------------
@app.on_event("startup")
async def startup_event():
    """Build the shared components used by every endpoint and store them in module globals.

    Initializes the LLM, embeddings, chat manager, document loader, text splitter,
    vector store manager / vector store, and the ``chat_chains`` MongoDB collection.
    """
    global llm, embeddings, chat_manager, document_loader, text_splitter, vector_store_manager, vector_store, chat_chains_collection
    print("Starting up: Initializing components...")

    # Initialize LLM and embeddings
    llm = get_llm()
    print("LLM initialized.")
    embeddings = get_embeddings()
    print("Embeddings initialized.")

    # Setup chat management
    chat_manager = ChatManagement(
        cluster_url=MONGO_CLUSTER_URL,
        database_name=MONGO_DATABASE_NAME,
        collection_name=MONGO_COLLECTION_NAME,
    )
    print("Chat management initialized.")

    # Initialize document loader and text splitter
    document_loader = DocumentLoader()
    text_splitter = TextSplitter()
    print("Document loader and text splitter initialized.")

    # Initialize vector store manager; /chat retrieves from `vector_store`.
    vector_store_manager = VectorStoreManager(embeddings)
    vector_store = vector_store_manager.vectorstore
    print("Vector store initialized.")

    # Connect to MongoDB and get the per-chat chain configuration collection.
    client = MongoClient(MONGO_CLUSTER_URL)
    db = client[MONGO_DATABASE_NAME]
    chat_chains_collection = db["chat_chains"]
    print("Chat chains collection initialized in MongoDB.")


# ----------------------- Root Endpoint -----------------------
@app.get("/")
def root():
    """
    Root endpoint that returns a welcome message.
    """
    return {"message": "Welcome to the VectorStore & Document Management API!"}


# ----------------------- New Chat Endpoint -----------------------
@app.post("/new_chat")
def new_chat():
    """
    Create a new chat session.
    """
    new_chat_id = chat_manager.create_new_chat()
    return {"chat_id": new_chat_id}


# Supported template names -> PromptTemplates factory method that builds the prompt.
# Shared by /create_chain (validation) and /chat (prompt lookup) so the two
# endpoints always agree on which templates exist.
_PROMPT_TEMPLATE_METHODS = {
    "quiz_solving": "get_quiz_solving_prompt",
    "assignment_solving": "get_assignment_solving_prompt",
    "paper_solving": "get_paper_solving_prompt",
    "quiz_creation": "get_quiz_creation_prompt",
    "assignment_creation": "get_assignment_creation_prompt",
    "paper_creation": "get_paper_creation_prompt",
}


# ----------------------- Create Chain Endpoint -----------------------
@app.post("/create_chain")
def create_chain(
    chat_id: str = Query(..., description="Existing chat session ID"),
    template: str = Query(
        "quiz_solving",
        description="Select prompt template. Options: quiz_solving, assignment_solving, paper_solving, quiz_creation, assignment_creation, paper_creation",
    ),
):
    """Store (upsert) the prompt template chosen for ``chat_id`` in the chat_chains collection.

    Raises:
        HTTPException(400): ``template`` is not one of the supported template names.
    """
    if template not in _PROMPT_TEMPLATE_METHODS:
        raise HTTPException(status_code=400, detail="Invalid template selection.")

    # Upsert the configuration document for this chat session.
    chat_chains_collection.update_one(
        {"chat_id": chat_id}, {"$set": {"template": template}}, upsert=True
    )
    return {"message": "Retrieval chain configuration stored successfully.", "chat_id": chat_id, "template": template}


# ----------------------- Chat Endpoint -----------------------
@app.get("/chat")
def chat(query: str, chat_id: str = Query(..., description="Chat session ID created via /new_chat and configured via /create_chain")):
    """Answer ``query`` with the retrieval chain configured for ``chat_id``.

    Reads the template name stored by /create_chain, builds a fresh RetrievalChain
    over the current vector store (top ``k`` documents), and returns whatever
    ``RetrievalChain.stream_chat_response`` yields as a ``text/event-stream`` response.

    Raises:
        HTTPException(400): no configuration exists for ``chat_id``, or the stored
            template name is not a supported template.
        HTTPException(500): creating the response stream raised.
    """
    config = chat_chains_collection.find_one({"chat_id": chat_id})
    if not config:
        raise HTTPException(status_code=400, detail="Chat configuration not found. Please create a chain using /create_chain.")

    template = config.get("template", "quiz_solving")
    # Guard against non-string values stored in MongoDB (unhashable dict keys).
    method_name = _PROMPT_TEMPLATE_METHODS.get(template) if isinstance(template, str) else None
    if method_name is None:
        raise HTTPException(status_code=400, detail="Invalid chat configuration.")
    prompt = getattr(PromptTemplates, method_name)()

    # Re-create the retrieval chain for this chat session.
    retrieval_chain = RetrievalChain(
        llm,
        vector_store.as_retriever(search_kwargs={"k": k}),
        prompt,
        verbose=True,
    )

    # NOTE(review): if stream_chat_response is a generator function, errors raised
    # while producing chunks surface during streaming, not in this try block.
    try:
        stream_generator = retrieval_chain.stream_chat_response(
            query=query,
            chat_id=chat_id,
            get_chat_history=chat_manager.get_chat_history,
            initialize_chat_history=chat_manager.initialize_chat_history,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing chat query: {str(e)}") from e

    return StreamingResponse(stream_generator, media_type="text/event-stream")


# ----------------------- Add Document Endpoint -----------------------
from typing import Any, Optional

# Lower-cased file extension -> DocumentLoader method used to load it.
# Any other extension falls back to ``load_unstructured``.
_EXTENSION_LOADER_METHODS = {
    "pdf": "load_pdf",
    "csv": "load_csv",
    "doc": "load_doc",
    "docx": "load_doc",
    "html": "load_text_from_html",
    "htm": "load_text_from_html",
    "md": "load_markdown",
    "markdown": "load_markdown",
}


async def _load_documents_from_upload(file):
    """Write an upload to a temporary file, load it by extension, and always delete the temp file.

    Raises:
        HTTPException(400): the selected DocumentLoader method raised.
    """
    # Read before creating the temp file so a failed read cannot orphan a file.
    contents = await file.read()
    tmp = tempfile.NamedTemporaryFile(delete=False)
    try:
        with tmp:
            tmp.write(contents)
        # An upload's filename may be None; treat that as "no extension".
        ext = (file.filename or "").split(".")[-1].lower()
        loader_name = _EXTENSION_LOADER_METHODS.get(ext, "load_unstructured")
        try:
            return getattr(document_loader, loader_name)(tmp.name)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Error loading document from file: {str(e)}") from e
    finally:
        os.remove(tmp.name)


@app.post("/add_document")
async def add_document(
    file: Optional[Any] = File(None),
    wiki_query: Optional[str] = Query(None),
    wiki_url: Optional[str] = Query(None),
):
    """
    Upload a document OR load data from a Wikipedia query or URL.

    - If a file is provided, the document is loaded from the file.
    - If 'wiki_query' is provided, the Wikipedia page(s) are loaded using document_loader.wikipedia_query.
    - If 'wiki_url' is provided, the URL is loaded using document_loader.load_urls.

    Priority when several are given: file > wiki_query > wiki_url. The loaded
    document(s) are split into chunks and added to the vector store.

    Raises:
        HTTPException(400): no input was provided, or loading the input failed.
        HTTPException(500): splitting or indexing the documents failed.
    """
    # A value that is not an upload (e.g. an empty string) means "no file".
    # Duck-typed on purpose: the parsed form value is typically Starlette's
    # UploadFile, of which fastapi.UploadFile is a *subclass*, so
    # isinstance(file, UploadFile) can be False for a genuine upload.
    if not (hasattr(file, "read") and hasattr(file, "filename")):
        file = None

    # Ensure at least one input is provided.
    if file is None and wiki_query is None and wiki_url is None:
        raise HTTPException(status_code=400, detail="No document input provided (file, wiki_query, or wiki_url).")

    if file is not None:
        documents = await _load_documents_from_upload(file)
    elif wiki_query is not None:
        try:
            documents = document_loader.wikipedia_query(wiki_query)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Error loading Wikipedia query: {str(e)}") from e
    else:
        try:
            documents = document_loader.load_urls([wiki_url])
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Error loading URL: {str(e)}") from e

    try:
        chunks = text_splitter.split_documents(documents)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error splitting document: {str(e)}") from e

    try:
        ids = vector_store_manager.add_documents(chunks)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error indexing document chunks: {str(e)}") from e

    return {"message": f"Added {len(chunks)} document chunks.", "ids": ids}
# ----------------------- Delete Document Endpoint -----------------------
@app.post("/delete_document")
def delete_document(ids: List[str]):
    """Delete document(s) from the vector store by ID.

    Raises:
        HTTPException(500): the vector store manager raised.
        HTTPException(400): the manager reported the deletion as unsuccessful.
    """
    try:
        success = vector_store_manager.delete_documents(ids)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error deleting documents: {str(e)}") from e
    if not success:
        raise HTTPException(status_code=400, detail="Failed to delete documents.")
    return {"message": f"Deleted documents with IDs: {ids}"}


# ----------------------- Save Vectorstore Endpoint -----------------------
@app.get("/save_vectorstore")
def save_vectorstore():
    """Save the current vector store under "faiss_index" and return it as a download.

    The file path, media type and download name come from the dict returned by
    ``vector_store_manager.save``.

    Raises:
        HTTPException(500): saving raised.
    """
    try:
        save_result = vector_store_manager.save("faiss_index")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error saving vectorstore: {str(e)}") from e
    return FileResponse(
        path=save_result["file_path"],
        media_type=save_result["media_type"],
        filename=save_result["serve_filename"],
    )


async def _write_upload_to_tempfile(file) -> str:
    """Copy an upload's bytes into a new temporary file and return its path.

    The caller owns (and must delete) the returned file. If writing fails, the
    partially written file is removed before the error propagates.
    """
    # Read first: a failed read then never leaves an orphaned temp file.
    contents = await file.read()
    tmp = tempfile.NamedTemporaryFile(delete=False)
    try:
        with tmp:
            tmp.write(contents)
    except BaseException:
        os.remove(tmp.name)
        raise
    return tmp.name


# ----------------------- Load Vectorstore Endpoint -----------------------
@app.post("/load_vectorstore")
async def load_vectorstore(file: UploadFile = File(...)):
    """Replace the current vector store with one loaded from an uploaded file (raw or zipped).

    Both ``vector_store_manager`` and the ``vector_store`` global that /chat builds
    its retriever from are swapped, so later chat queries use the loaded store.

    Raises:
        HTTPException(500): reading the upload or loading the store failed.
    """
    global vector_store_manager, vector_store
    tmp_filename = None
    try:
        tmp_filename = await _write_upload_to_tempfile(file)
        instance, message = VectorStoreManager.load(tmp_filename, embeddings)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error loading vectorstore: {str(e)}") from e
    finally:
        if tmp_filename and os.path.exists(tmp_filename):
            os.remove(tmp_filename)

    # Read the new store before mutating any global so state stays consistent.
    new_store = instance.vectorstore
    vector_store_manager = instance
    # Without this, /chat would keep retrieving from the previously loaded store.
    vector_store = new_store
    return {"message": message}


# ----------------------- Merge Vectorstore Endpoint -----------------------
@app.post("/merge_vectorstore")
async def merge_vectorstore(file: UploadFile = File(...)):
    """Merge an uploaded vector store (raw or zipped) into the current vector store.

    Returns whatever ``vector_store_manager.merge`` returns.

    Raises:
        HTTPException(500): reading the upload or merging failed.
    """
    global vector_store
    tmp_filename = None
    try:
        tmp_filename = await _write_upload_to_tempfile(file)
        result = vector_store_manager.merge(tmp_filename, embeddings)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error merging vectorstore: {str(e)}") from e
    finally:
        if tmp_filename and os.path.exists(tmp_filename):
            os.remove(tmp_filename)

    # NOTE(review): merge() may rebind the manager's ``vectorstore`` instead of
    # mutating it in place; re-reading it keeps /chat's retriever on the merged
    # store either way (a no-op if the object is unchanged).
    vector_store = vector_store_manager.vectorstore
    return result


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)