File size: 3,320 Bytes
5e433de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29865e5
 
5e433de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
import toml
from typing import Optional
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda
import logging
from langchain_groq import ChatGroq
from src.bot.extract_metadata import Metadata


# Root logging setup: INFO threshold, records rendered as "time - level - message".
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Module-scoped logger shared by the definitions below.
logger = logging.getLogger(__name__)

class Medibot:
    """Retrieval-augmented question-answering bot.

    Wires together a FAISS vector store (queried through an MMR retriever
    with OpenAI embeddings), a chat prompt loaded from a TOML file, a
    Groq-hosted chat model, and a ``Metadata`` helper that is asked for the
    tables and images referenced by the retrieved documents.
    """

    def __init__(self, config_path: str = "src/bot/configs/prompt.toml",
                 metadata_database: str = "database/metadata.csv",
                 faiss_database: str = "database/faiss_index"
                 ):
        """Initialize Medibot with configuration and Groq client.

        Args:
            config_path: Path to a TOML file containing a ``[rag_prompt]``
                table with ``system_prompt`` and ``user_prompt_template``.
            metadata_database: Path handed to ``Metadata``.
            faiss_database: Directory of the saved FAISS index to load.

        Raises:
            ValueError: If ``GROQ_API_KEY`` or ``OPENAI_API_KEY`` is missing
                or empty in the environment.
            FileNotFoundError: If ``config_path`` does not exist.
            toml.TomlDecodeError: If the config file is not valid TOML.
            KeyError: If the config lacks the expected ``rag_prompt`` keys.
        """
        # Both providers read their credentials from the environment, so fail
        # fast with a clear message instead of deep inside a client call.
        api_key = os.environ.get("GROQ_API_KEY")
        if not api_key:
            logger.error("GROQ_API_KEY not found in environment variables")
            raise ValueError("GROQ_API_KEY is required")

        # Assigning os.environ.get(...) back into os.environ raises TypeError
        # when the variable is unset (environment values must be str), so
        # validate explicitly instead.
        if not os.environ.get("OPENAI_API_KEY"):
            logger.error("OPENAI_API_KEY not found in environment variables")
            raise ValueError("OPENAI_API_KEY is required")

        # Load prompt configuration
        try:
            config = toml.load(config_path)
            system_prompt = config["rag_prompt"]["system_prompt"]
            user_prompt_template = config["rag_prompt"]["user_prompt_template"]
        except (FileNotFoundError, toml.TomlDecodeError, KeyError) as e:
            # Log for operators, then let the original exception propagate.
            logger.error("Failed to load config from %s: %s", config_path, e)
            raise

        # Initialize prompt template
        self.prompt_template = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            ("user", user_prompt_template),
        ])

        # Initialize vector database
        embeddings = OpenAIEmbeddings(model="text-embedding-3-large")
        # SECURITY: allow_dangerous_deserialization=True lets the loader
        # unpickle data from `faiss_database`; only point this at an index
        # directory you created or otherwise trust.
        vector_store = FAISS.load_local(
            faiss_database, embeddings, allow_dangerous_deserialization=True
        )
        self.retriever = vector_store.as_retriever(
            search_type="mmr", search_kwargs={"k": 10}
        )

        # Initialize Groq client
        self.model = ChatGroq(
            model="llama-3.1-8b-instant",
            temperature=0.2,
            max_tokens=None,
            timeout=None,
            max_retries=2,
        )

        # NOTE: the attribute name is misspelled ("extactor"); it is kept
        # as-is because code outside this class may already reference it.
        self.metadata_extactor = Metadata(metadata_database)

    def query(self, question: str) -> tuple:
        """Answer ``question`` using retrieved context.

        Args:
            question: The user's question, used both for retrieval and as the
                ``question`` variable of the prompt.

        Returns:
            A 4-tuple ``(answer, retrieved_docs, refered_tables,
            refered_images)``: the model's answer as a string, the documents
            returned by the retriever, and the two values returned by
            ``Metadata.get_data_from_ref`` for those documents (presumably
            the referenced tables and images -- confirm against ``Metadata``).
        """
        # Retrieve once and reuse the result for both the prompt context and
        # the metadata lookup, rather than letting the chain retrieve again.
        retrieved_docs = self.retriever.invoke(question)

        rag_chain = (
            RunnableParallel({
                "context": RunnableLambda(lambda _: retrieved_docs),  # Reuse retrieved docs
                "question": RunnablePassthrough(),
            })
            | self.prompt_template
            | self.model
            | StrOutputParser()
        )

        # Invoke with the raw string: RunnablePassthrough forwards its input
        # unchanged, so passing {"question": question} would place the whole
        # dict (rendered as its repr) into the prompt's question slot.
        answer = rag_chain.invoke(question)

        refered_tables, refered_images = self.metadata_extactor.get_data_from_ref(retrieved_docs)
        return answer, retrieved_docs, refered_tables, refered_images