# qa_system.py

import pinecone
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_pinecone import PineconeVectorStore
class PineconeQA:
    def __init__(self, pinecone_api_key, openai_api_key, index_name):
        # Initialize the Pinecone client and connect to the target index
        self.pc = pinecone.Pinecone(api_key=pinecone_api_key)
        self.index = self.pc.Index(index_name)

        # Initialize the embedding model used to encode queries
        self.embeddings = OpenAIEmbeddings(
            openai_api_key=openai_api_key
        )

        # Wrap the Pinecone index in a LangChain vector store
        self.vector_store = PineconeVectorStore(
            index=self.index,
            embedding=self.embeddings
        )

        # Initialize the chat model used to generate answers
        self.llm = ChatOpenAI(
            openai_api_key=openai_api_key,
            model="gpt-4o",
            temperature=0.2
        )

        # Build the default RAG chain
        self._create_rag_chain()
    def _create_rag_chain(self):
        # System prompt for the default biomedical QA chain
        system_prompt = (
            "You are an expert assistant for biomedical question-answering tasks. "
            "You will be provided with context retrieved from medical literature. "
            "The medical literature is all from PubMed Open Access Articles. "
            "Use this context to answer the question as accurately as possible. "
            "The context may not state the answer verbatim, so derive the answer "
            "from it as far as possible. "
            "If the context does not contain the required information, explain why. "
            "Provide a concise and accurate answer."
            "\n\n"
            "Context:\n{context}\n"
        )
        # Create the chat prompt template
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            ("human", "{input}"),
        ])

        # Chain that stuffs the retrieved documents into the prompt
        question_answer_chain = create_stuff_documents_chain(
            self.llm,
            prompt
        )

        # Combine an MMR retriever over the vector store with the answering chain
        self.rag_chain = create_retrieval_chain(
            self.vector_store.as_retriever(search_type="mmr"),
            question_answer_chain
        )
    def merge_relevant_chunks(self, retrieved_docs, question, max_tokens=1500):
        """
        Greedily concatenate retrieved chunks, in retrieval order, until the
        token budget is reached. Tokens are approximated by whitespace-split
        words; the `question` argument is currently unused.
        """
        merged_context = ""
        current_tokens = 0
        for doc in retrieved_docs:
            tokens = doc.page_content.split()
            if current_tokens + len(tokens) <= max_tokens:
                merged_context += doc.page_content + "\n"
                current_tokens += len(tokens)
            else:
                break
        return merged_context
    def ask(self, question):
        """
        Ask a question and return the answer together with its source context.
        """
        # Initialize conversation history if it doesn't exist (currently unused)
        if not hasattr(self, "conversation_history"):
            self.conversation_history = []

        try:
            # Stricter prompt than the default chain: it suppresses
            # meta-commentary about context limitations in the answer.
            system_prompt = (
                "You are an expert assistant for biomedical question-answering tasks. "
                "You will be provided with context retrieved from medical literature, "
                "specifically PubMed Open Access Articles. "
                "Use the provided context to directly answer the question in the most "
                "accurate and concise manner possible. "
                "If the context does not provide sufficient information, state that "
                "the specific details are not available in the context. "
                "Do not include statements about limitations of the context in your response. "
                "Your answer should sound authoritative and professional, tailored "
                "for a medical audience."
                "\n\n"
                "Context:\n{context}\n"
            )

            # Create the chat prompt template
            prompt = ChatPromptTemplate.from_messages([
                ("system", system_prompt),
                ("human", "{input}"),
            ])

            # Chain that stuffs the retrieved documents into the prompt
            question_answer_chain = create_stuff_documents_chain(
                self.llm,
                prompt
            )

            # Rebuild the retrieval chain per call with the stricter prompt
            results = create_retrieval_chain(
                self.vector_store.as_retriever(search_type="mmr"),
                question_answer_chain
            ).invoke({"input": question})

            return {
                "answer": results["answer"],
                "context": results["context"]
            }
        except Exception as e:
            return {
                "error": str(e)
            }
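

# A minimal usage sketch, assuming the environment variables below hold valid
# credentials and that an index with this (hypothetical) name already exists
# and was populated using the same embedding model configured above.
if __name__ == "__main__":
    import os

    qa = PineconeQA(
        pinecone_api_key=os.environ["PINECONE_API_KEY"],
        openai_api_key=os.environ["OPENAI_API_KEY"],
        index_name="pubmed-oa",  # hypothetical index name
    )

    result = qa.ask("What are common first-line treatments for hypertension?")
    if "error" in result:
        print("Query failed:", result["error"])
    else:
        print("Answer:", result["answer"])
        print("Source documents retrieved:", len(result["context"]))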