# qa_system.py
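
# A minimal RAG (retrieval-augmented generation) QA service over a Pinecone
# index of PubMed Open Access articles.
# Assumed dependencies (package names assumed; pin versions per environment):
#   pinecone, langchain, langchain-openai, langchain-pinecone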
import pinecone
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_pinecone import PineconeVectorStore
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate

class PineconeQA:
    def __init__(self, pinecone_api_key, openai_api_key, index_name):
        # Initialize Pinecone
        self.pc = pinecone.Pinecone(api_key=pinecone_api_key)
        self.index = self.pc.Index(index_name)
        
        # Initialize embeddings
        self.embeddings = OpenAIEmbeddings(
            openai_api_key=openai_api_key
        )
        
        # Wrap the existing index in a LangChain vector store
        self.vector_store = PineconeVectorStore(
            index=self.index,
            embedding=self.embeddings
        )
        
        # Initialize LLM
        self.llm = ChatOpenAI(
            openai_api_key=openai_api_key,
            model="gpt-4o",
            temperature=0.2
        )
        
        # Create the RAG chain
        self._create_rag_chain()
    
    def _create_rag_chain(self):
        # Define system prompt
        # system_prompt = (
        #     "You are an assistant for question-answering tasks. "
        #     "Use the following pieces of retrieved context to answer "
        #     "the question. If you don't know the answer, say that you "
        #     "don't know. Use three sentences maximum and keep the "
        #     "answer concise."
        #     "\n\n"
        #     "{context}"
        # )

        
        system_prompt = (
            "You are an expert assistant for biomedical question-answering tasks. "
            "You will be provided with context retrieved from medical literature. "
            "The medical literature is all from PubMed Open Access Articles. "
            "Use this context to answer the question as accurately as possible. "
            "The retrieved context may be imprecise, so derive the answer from it as best you can. "
            "If the context does not contain the required information, explain why. "
            "Provide a concise and accurate answer."
            "\n\n"
            "Context:\n{context}\n"
        )
        # Create chat prompt template
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            ("human", "{input}"),
        ])
        
        # Create question-answer chain
        question_answer_chain = create_stuff_documents_chain(
            self.llm,
            prompt
        )
        
        # Create the RAG chain; MMR (maximal marginal relevance) retrieval
        # balances relevance and diversity among the retrieved chunks
        self.rag_chain = create_retrieval_chain(
            self.vector_store.as_retriever(search_type="mmr"),
            question_answer_chain
        )

    def merge_relevant_chunks(self, retrieved_docs, question, max_tokens=1500):
        """
        Merge document chunks based on their semantic relevance to the question.
        """
        merged_context = ""
        current_tokens = 0

        for doc in retrieved_docs:
            tokens = doc.page_content.split()
            if current_tokens + len(tokens) <= max_tokens:
                merged_context += doc.page_content + "\n"
                current_tokens += len(tokens)
            else:
                break

        return merged_context
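
    # NOTE: merge_relevant_chunks is not called by ask() below; callers can
    # apply it to the retrieved documents themselves, for example:
    #   docs = qa.ask(question)["context"]
    #   context_text = qa.merge_relevant_chunks(docs, question)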


    def ask(self, question):
        """
        Ask a question and get response with sources
        """
        # Initialize conversation history if it doesn't exist
        # (currently unused; reserved for future multi-turn support)
        if not hasattr(self, "conversation_history"):
            self.conversation_history = []

        try:
            system_prompt = (
                "You are an expert assistant for biomedical question-answering tasks. "
                "You will be provided with context retrieved from medical literature, specifically PubMed Open Access Articles. "
                "Use the provided context to directly answer the question in the most accurate and concise manner possible. "
                "If the context does not provide sufficient information, state that the specific details are not available in the context. "
                "Do not include statements about limitations of the context in your response. "
                "Your answer should sound authoritative and professional, tailored for a medical audience."
                "\n\n"
                "Context:\n{context}\n"
            )
            # Create chat prompt template
            prompt = ChatPromptTemplate.from_messages([
                ("system", system_prompt),
                ("human", "{input}"),
            ])
            
            # Create question-answer chain
            question_answer_chain = create_stuff_documents_chain(
                self.llm,
                prompt
            )

            # Build the retrieval chain per call; note this uses a different
            # prompt from the default rag_chain created in _create_rag_chain()
            results = create_retrieval_chain(
                self.vector_store.as_retriever(search_type="mmr"),
                question_answer_chain
            ).invoke({"input": question})
            
            return {
                "answer": results["answer"],
                "context": results["context"]
            }
        except Exception as e:
            return {
                "error": str(e)
            }
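

# A minimal usage sketch. Assumptions: API keys are supplied via the
# PINECONE_API_KEY and OPENAI_API_KEY environment variables, and "pubmed-oa"
# is a hypothetical index name; substitute values from your own deployment.
if __name__ == "__main__":
    import os

    qa = PineconeQA(
        pinecone_api_key=os.environ["PINECONE_API_KEY"],
        openai_api_key=os.environ["OPENAI_API_KEY"],
        index_name="pubmed-oa",  # hypothetical index name
    )

    question = "What are common first-line treatments for hypertension?"
    result = qa.ask(question)

    if "error" in result:
        print("Error:", result["error"])
    else:
        print("Answer:", result["answer"])
        # result["context"] holds the retrieved Document objects; the helper
        # above concatenates them into a single bounded context string
        merged = qa.merge_relevant_chunks(result["context"], question)
        print("Merged context length (words):", len(merged.split()))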