# src/main.py
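"""FastAPI entry point for the RAG chatbot.

Exposes a /chat endpoint that routes each request to an OpenAI- or
Ollama-backed language model and answers it through a RAG agent, plus
a /health endpoint for liveness checks.
"""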
from fastapi import FastAPI, Depends, HTTPException
from pydantic import BaseModel
from typing import List, Optional

from .agents.rag_agent import RAGAgent
from .llms.openai_llm import OpenAILanguageModel
from .llms.ollama_llm import OllamaLanguageModel
from .embeddings.huggingface_embedding import HuggingFaceEmbedding
from .vectorstores.chroma_vectorstore import ChromaVectorStore
from config.config import settings

app = FastAPI(title="RAG Chatbot API")

class ChatRequest(BaseModel):
    query: str
    context_docs: Optional[List[str]] = None
    llm_provider: str = "openai"

class ChatResponse(BaseModel):
    response: str
    context: Optional[List[str]] = None

@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    try:
        # Select LLM based on provider
        if request.llm_provider == "openai":
            llm = OpenAILanguageModel(api_key=settings.OPENAI_API_KEY)
        elif request.llm_provider == "ollama":
            llm = OllamaLanguageModel(base_url=settings.OLLAMA_BASE_URL)
        else:
            raise HTTPException(status_code=400, detail="Unsupported LLM provider")
        
        # Initialize embedding model and vector store (built per request here;
        # in production these could be created once at startup and reused)
        embedding = HuggingFaceEmbedding(model_name=settings.EMBEDDING_MODEL)
        vector_store = ChromaVectorStore(
            embedding_function=embedding.embed_documents, 
            persist_directory=settings.CHROMA_PATH
        )
        
        # Create RAG agent
        rag_agent = RAGAgent(
            llm=llm, 
            embedding=embedding, 
            vector_store=vector_store
        )
        
        # Process the query (generate_response is assumed synchronous here;
        # if RAGAgent exposes a coroutine instead, this call should be awaited)
        response = rag_agent.generate_response(
            query=request.query, 
            context_docs=request.context_docs
        )
        
        return ChatResponse(
            response=response.response, 
            context=response.context_docs
        )
    
    except HTTPException:
        # Re-raise intentional HTTP errors (such as the 400 above) so the
        # generic handler below doesn't turn them into 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

# Optional: Health check endpoint
@app.get("/health")
async def health_check():
    return {"status": "healthy"}