# src/main.py
from fastapi import FastAPI, Depends, HTTPException
from pydantic import BaseModel
from typing import List, Optional

from .agents.rag_agent import RAGAgent
from .llms.openai_llm import OpenAILanguageModel
from .llms.ollama_llm import OllamaLanguageModel
from .embeddings.huggingface_embedding import HuggingFaceEmbedding
from .vectorstores.chroma_vectorstore import ChromaVectorStore
from config.config import settings

app = FastAPI(title="RAG Chatbot API")

class ChatRequest(BaseModel):
    query: str
    context_docs: Optional[List[str]] = None
    llm_provider: str = 'openai'

class ChatResponse(BaseModel):
    response: str
    context: Optional[List[str]] = None
# Chat endpoint (route path inferred from the handler name)
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    try:
        # Select LLM based on provider
        if request.llm_provider == 'openai':
            llm = OpenAILanguageModel(api_key=settings.OPENAI_API_KEY)
        elif request.llm_provider == 'ollama':
            llm = OllamaLanguageModel(base_url=settings.OLLAMA_BASE_URL)
        else:
            raise HTTPException(status_code=400, detail="Unsupported LLM provider")

        # Initialize embedding and vector store
        embedding = HuggingFaceEmbedding(model_name=settings.EMBEDDING_MODEL)
        vector_store = ChromaVectorStore(
            embedding_function=embedding.embed_documents,
            persist_directory=settings.CHROMA_PATH
        )

        # Create RAG agent
        rag_agent = RAGAgent(
            llm=llm,
            embedding=embedding,
            vector_store=vector_store
        )

        # Process query
        response = rag_agent.generate_response(
            query=request.query,
            context_docs=request.context_docs
        )

        return ChatResponse(
            response=response.response,
            context=response.context_docs
        )
    except HTTPException:
        # Re-raise explicit HTTP errors (e.g., the 400 above) instead of wrapping them as 500s
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# Optional: health check endpoint
@app.get("/health")
async def health_check():
    return {"status": "healthy"}