# EduLearn AI — FastAPI service for vector-store and document management.
import os | |
import tempfile | |
import zipfile | |
from typing import List, Optional, Any | |
import uuid | |
from datetime import datetime | |
from fastapi import FastAPI, File, UploadFile, HTTPException, Query, Depends | |
from fastapi.responses import FileResponse, StreamingResponse | |
# Removed static files mounting for avatars as avatars are now served via GridFS in auth | |
#from fastapi.staticfiles import StaticFiles | |
from llm_initialization import get_llm | |
from embedding import get_embeddings | |
from document_loaders import DocumentLoader | |
from text_splitter import TextSplitter | |
from vector_store import VectorStoreManager | |
from prompt_templates import PromptTemplates | |
from chat_management import ChatManagement | |
from retrieval_chain import RetrievalChain | |
from urllib.parse import quote_plus | |
from dotenv import load_dotenv | |
from pymongo import MongoClient | |
# Load environment variables from a local .env file (no-op if the file is absent).
load_dotenv()
# MongoDB connection settings. The password is URL-escaped so special characters
# survive inclusion in a connection string; default to "" so quote_plus does not
# raise TypeError when the variable is unset (os.getenv returns None).
MONGO_PASSWORD = quote_plus(os.getenv("MONGO_PASSWORD") or "")
MONGO_DATABASE_NAME = os.getenv("DATABASE_NAME")
MONGO_COLLECTION_NAME = os.getenv("COLLECTION_NAME")
MONGO_CLUSTER_URL = os.getenv("CONNECTION_STRING")
# FastAPI application instance; title shows in the auto-generated docs.
app = FastAPI(title="VectorStore & Document Management API")
# Note: Since user avatars are now stored in MongoDB via GridFS and served via /auth/avatar,
# we no longer mount a local avatars directory.
# Import auth router and dependencies
from auth import router as auth_router, get_current_user, users_collection
# Mount auth endpoints under /auth
app.include_router(auth_router, prefix="/auth")
# Mount audio transcription endpoints under /audio
from transcribe import router as transcribe_router
app.include_router(transcribe_router, prefix="/audio")
# Global variables (initialized on startup) — all None until startup runs.
llm = None                   # language model client (set via get_llm)
embeddings = None            # embedding model (set via get_embeddings)
chat_manager = None          # ChatManagement instance backed by MongoDB
document_loader = None       # DocumentLoader for files/wiki/URLs
text_splitter = None         # TextSplitter for chunking documents
vector_store_manager = None  # VectorStoreManager wrapping the vector store
vector_store = None          # underlying vector store (vector_store_manager.vectorstore)
k = 3  # Number of documents to retrieve per query
# ----------------------- Startup Event ----------------------- | |
# Register on FastAPI's startup hook so components are initialized before the
# first request is served (the bare function was otherwise never invoked).
@app.on_event("startup")
async def startup_event():
    """
    Initialize all shared components: LLM, embeddings, chat management,
    document loader/splitter, and the vector store.
    """
    global llm, embeddings, chat_manager, document_loader, text_splitter, vector_store_manager, vector_store
    print("Starting up: Initializing components...")
    # Initialize LLM and embeddings
    llm = get_llm()
    print("LLM initialized.")
    embeddings = get_embeddings()
    print("Embeddings initialized.")
    # Setup chat management (MongoDB-backed chat histories)
    chat_manager = ChatManagement(
        cluster_url=MONGO_CLUSTER_URL,
        database_name=MONGO_DATABASE_NAME,
        collection_name=MONGO_COLLECTION_NAME,
    )
    print("Chat management initialized.")
    # Initialize document loader and text splitter
    document_loader = DocumentLoader()
    text_splitter = TextSplitter()
    print("Document loader and text splitter initialized.")
    # Initialize vector store manager and keep the raw store handy for retrievers
    vector_store_manager = VectorStoreManager(embeddings)
    vector_store = vector_store_manager.vectorstore
    print("Vector store initialized.")
# ----------------------- New Chat Endpoint (Updated) ----------------------- | |
def new_chat(current_user: dict = Depends(get_current_user)):
    """
    Create a new, empty chat session under the current user's document
    and return its generated identifier.
    """
    chat_id = str(uuid.uuid4())
    session = {
        "chat_id": chat_id,
        "created_at": datetime.utcnow(),
        "messages": [],
    }
    # Append the fresh session to the user's chat_histories array.
    users_collection.update_one(
        {"email": current_user["email"]},
        {"$push": {"chat_histories": session}},
    )
    return {"chat_id": chat_id}
# ----------------------- Create Chain Endpoint (Updated) ----------------------- | |
def create_chain(
    chat_id: str = Query(..., description="Existing chat session ID"),
    template: str = Query(
        "quiz_solving",
        description="Select prompt template. Options: quiz_solving, assignment_solving, paper_solving, quiz_creation, assignment_creation, paper_creation",
    ),
    current_user: dict = Depends(get_current_user)
):
    """
    Attach a prompt-template choice to an existing chat session so that
    later /chat calls know which prompt to build.
    """
    allowed_templates = {
        "quiz_solving",
        "assignment_solving",
        "paper_solving",
        "quiz_creation",
        "assignment_creation",
        "paper_creation",
    }
    if template not in allowed_templates:
        raise HTTPException(status_code=400, detail="Invalid template selection.")
    # Persist the template on the matching chat session inside the user document.
    users_collection.update_one(
        {"email": current_user["email"], "chat_histories.chat_id": chat_id},
        {"$set": {"chat_histories.$.template": template}},
    )
    return {"message": "Retrieval chain configuration stored successfully.", "chat_id": chat_id, "template": template}
# ----------------------- Chat Endpoint ----------------------- | |
def chat(
    query: str,
    chat_id: str = Query(..., description="Chat session ID"),
    current_user: dict = Depends(get_current_user)
):
    """
    Process a chat query using the retrieval chain associated with the given chat_id.

    The chat session's stored template selects the prompt; the response is
    streamed back as server-sent events.
    """
    # Locate the session's stored configuration on the user's document.
    sessions = current_user.get("chat_histories", [])
    chat_config = next(
        (session for session in sessions if session.get("chat_id") == chat_id),
        None,
    )
    if not chat_config:
        raise HTTPException(status_code=400, detail="Chat configuration not found. Please create a chain using /create_chain.")
    template = chat_config.get("template", "quiz_solving")
    # Map the stored template name to its prompt factory.
    prompt_factories = {
        "quiz_solving": PromptTemplates.get_quiz_solving_prompt,
        "assignment_solving": PromptTemplates.get_assignment_solving_prompt,
        "paper_solving": PromptTemplates.get_paper_solving_prompt,
        "quiz_creation": PromptTemplates.get_quiz_creation_prompt,
        "assignment_creation": PromptTemplates.get_assignment_creation_prompt,
        "paper_creation": PromptTemplates.get_paper_creation_prompt,
    }
    factory = prompt_factories.get(template)
    if factory is None:
        raise HTTPException(status_code=400, detail="Invalid chat configuration.")
    prompt = factory()
    # Build a retrieval chain over the current vector store (top-k retriever).
    retrieval_chain = RetrievalChain(
        llm,
        vector_store.as_retriever(search_kwargs={"k": k}),
        prompt,
        verbose=True,
    )
    try:
        stream_generator = retrieval_chain.stream_chat_response(
            query=query,
            chat_id=chat_id,
            get_chat_history=chat_manager.get_chat_history,
            initialize_chat_history=chat_manager.initialize_chat_history,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing chat query: {str(e)}")
    return StreamingResponse(stream_generator, media_type="text/event-stream")
# ----------------------- Remaining Endpoints ----------------------- | |
async def add_document(
    file: Optional[UploadFile] = File(None),  # File parameter now is an UploadFile
    wiki_query: Optional[str] = Query(None),
    wiki_url: Optional[str] = Query(None)
):
    """
    Ingest a document into the vector store.

    Exactly one source is used, checked in order: an uploaded file, a
    Wikipedia query, or a URL. The loaded documents are split into chunks
    and indexed.

    Raises:
        HTTPException 400: no input provided, or the source failed to load.
        HTTPException 500: splitting or indexing failed.
    """
    if file is None and wiki_query is None and wiki_url is None:
        raise HTTPException(status_code=400, detail="No document input provided (file, wiki_query, or wiki_url).")
    if file is not None:
        # Persist the upload to a temp file so path-based loaders can read it.
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            contents = await file.read()
            tmp.write(contents)
            tmp_filename = tmp.name
        ext = file.filename.split(".")[-1].lower()
        try:
            # Dispatch on extension; unknown types fall back to the unstructured loader.
            if ext == "pdf":
                documents = document_loader.load_pdf(tmp_filename)
            elif ext == "csv":
                documents = document_loader.load_csv(tmp_filename)
            elif ext in ["doc", "docx"]:
                documents = document_loader.load_doc(tmp_filename)
            elif ext in ["html", "htm"]:
                documents = document_loader.load_text_from_html(tmp_filename)
            elif ext in ["md", "markdown"]:
                documents = document_loader.load_markdown(tmp_filename)
            else:
                documents = document_loader.load_unstructured(tmp_filename)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Error loading document from file: {str(e)}")
        finally:
            # Single cleanup path for both success and failure (the original
            # duplicated os.remove on each branch).
            os.remove(tmp_filename)
    elif wiki_query is not None:
        try:
            documents = document_loader.wikipedia_query(wiki_query)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Error loading Wikipedia query: {str(e)}")
    elif wiki_url is not None:
        try:
            documents = document_loader.load_urls([wiki_url])
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Error loading URL: {str(e)}")
    try:
        chunks = text_splitter.split_documents(documents)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error splitting document: {str(e)}")
    try:
        ids = vector_store_manager.add_documents(chunks)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error indexing document chunks: {str(e)}")
    return {"message": f"Added {len(chunks)} document chunks.", "ids": ids}
def delete_document(ids: List[str]):
    """Remove the documents with the given IDs from the vector store."""
    try:
        removed = vector_store_manager.delete_documents(ids)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Error deleting documents: {str(exc)}")
    if not removed:
        raise HTTPException(status_code=400, detail="Failed to delete documents.")
    return {"message": f"Deleted documents with IDs: {ids}"}
def save_vectorstore():
    """Serialize the vector store to disk and stream it back as a download."""
    try:
        artifact = vector_store_manager.save("faiss_index")
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Error saving vectorstore: {str(exc)}")
    # The manager reports where it wrote the index and how to serve it.
    return FileResponse(
        path=artifact["file_path"],
        media_type=artifact["media_type"],
        filename=artifact["serve_filename"],
    )
async def load_vectorstore(file: UploadFile = File(...)):
    """
    Replace the active vector store with one loaded from an uploaded file.

    The upload is written to a temp file (removed in all cases), handed to
    VectorStoreManager.load, and the module-level globals are updated.

    Raises:
        HTTPException 500: the uploaded data could not be loaded.
    """
    tmp_filename = None
    try:
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            file_bytes = await file.read()
            tmp.write(file_bytes)
            tmp_filename = tmp.name
        instance, message = VectorStoreManager.load(tmp_filename, embeddings)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error loading vectorstore: {str(e)}")
    finally:
        if tmp_filename and os.path.exists(tmp_filename):
            os.remove(tmp_filename)
    global vector_store_manager, vector_store
    vector_store_manager = instance
    # Bug fix: keep the module-level vector_store in sync so the /chat
    # retriever uses the newly loaded index (mirrors startup initialization);
    # previously it kept pointing at the old store.
    vector_store = vector_store_manager.vectorstore
    return {"message": message}
async def merge_vectorstore(file: UploadFile = File(...)):
    """
    Merge an uploaded vector store into the active one and return the
    manager's merge result. The temp file is removed in all cases.
    """
    tmp_filename = None
    try:
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            tmp.write(await file.read())
            tmp_filename = tmp.name
        result = vector_store_manager.merge(tmp_filename, embeddings)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Error merging vectorstore: {str(exc)}")
    finally:
        if tmp_filename and os.path.exists(tmp_filename):
            os.remove(tmp_filename)
    return result
async def root():
    """Return a simple welcome payload (landing / health-check endpoint)."""
    return {"message": "Welcome to the EduLearn AI."}
# Run a local development server when this module is executed directly.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)