Shriharshan committed on
Commit
43d3036
·
verified ·
1 Parent(s): a6cef50

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +134 -0
app.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json
from typing import List

# NOTE(review): `spaces` is kept ahead of the model libraries, as in the
# original ordering -- presumably ZeroGPU wants to be imported before anything
# that may initialize CUDA; confirm before re-sorting.
import spaces
import gradio as gr
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline, ChatHuggingFace
from langchain_community.vectorstores import FAISS
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.vectorstores import VectorStoreRetriever
from langchain.docstore.document import Document
13
+
14
+
15
# Text-generation pipeline for the SmolLM2 instruct checkpoint. Sampling is
# disabled (do_sample=False) and only the newly generated text is returned
# (return_full_text=False), capped at 512 new tokens.
model = HuggingFacePipeline.from_model_id(
    model_id="HuggingFaceTB/SmolLM2-360M-Instruct",
    task="text-generation",
    device_map="auto",
    pipeline_kwargs={
        "max_new_tokens": 512,
        "do_sample": False,
        "repetition_penalty": 1.03,
        "return_full_text": False,
    },
)

# Chat-model wrapper around the raw pipeline; this is what the app invokes.
llm = ChatHuggingFace(llm=model)
28
+
29
def create_embeddings_model() -> HuggingFaceEmbeddings:
    """Build the ``BAAI/bge-m3`` embedding model used by the vector store.

    The model is configured to run on the CPU with ``trust_remote_code``
    enabled, to normalize the embeddings it produces, and to show a progress
    indicator while encoding.

    Returns:
        A configured ``HuggingFaceEmbeddings`` instance.
    """
    return HuggingFaceEmbeddings(
        model_name="BAAI/bge-m3",
        model_kwargs={"device": "cpu", "trust_remote_code": True},
        encode_kwargs={"normalize_embeddings": True},
        show_progress=True,
    )
42
+
43
# Module-level embedding model; load_faiss_retriever() reads this global when
# it opens the saved FAISS index.
embeddings = create_embeddings_model()
44
+
45
def load_faiss_retriever(path: str, k: int = 10) -> VectorStoreRetriever:
    """Open a saved FAISS index and expose it as a retriever.

    The index is loaded with the module-level ``embeddings`` model.

    Args:
        path: Directory containing the saved FAISS index.
        k: Number of documents the retriever returns per query. Defaults to
            10, the value that was previously hard-coded.

    Returns:
        The retriever produced by ``as_retriever`` on the loaded store (not
        the ``FAISS`` store itself, which the old annotation claimed).
    """
    # SECURITY: allow_dangerous_deserialization=True lets load_local unpickle
    # data stored next to the index. Only point `path` at index files this
    # project produced itself; never at user-supplied or downloaded files.
    vector_store = FAISS.load_local(
        path,
        embeddings,
        allow_dangerous_deserialization=True,
    )
    return vector_store.as_retriever(search_kwargs={"k": k})
48
+
49
def load_bm25_retriever(load_path: str) -> BM25Retriever:
    """Rebuild a BM25 retriever from documents stored in a JSON file.

    The file is read as UTF-8 and is expected to hold a sequence of objects,
    each with a ``"page_content"`` and a ``"metadata"`` entry; every object is
    turned into a ``Document`` before the retriever is built.

    Args:
        load_path: Path to the JSON file holding the serialized documents.

    Returns:
        A ``BM25Retriever`` built from the loaded documents.
    """
    with open(load_path, "r", encoding="utf-8") as handle:
        records = json.load(handle)

    documents = []
    for record in records:
        documents.append(
            Document(
                page_content=record["page_content"],
                metadata=record["metadata"],
            )
        )
    return BM25Retriever.from_documents(documents, language="english")
54
+
55
class EmbeddingBM25RerankerRetriever:
    """Hybrid retriever: vector search plus BM25, merged and then reranked.

    Both underlying retrievers are queried with the same text, their results
    are merged into one candidate list without repeated passages, and the
    reranker picks the final documents from that list.
    """

    def __init__(self, vector_retriever, bm25_retriever, reranker):
        """Store the collaborators.

        Args:
            vector_retriever: Object with ``invoke(query)`` returning documents.
            bm25_retriever: Object with ``invoke(query)`` returning documents.
            reranker: Object with ``compress_documents(documents, query)``.
        """
        self.vector_retriever = vector_retriever
        self.bm25_retriever = bm25_retriever
        self.reranker = reranker

    def invoke(self, query: str):
        """Return the reranked documents for ``query``.

        Args:
            query: The user's search text.

        Returns:
            Whatever ``reranker.compress_documents`` returns for the merged
            candidates, or an empty list when neither retriever found anything.
        """
        vector_docs = self.vector_retriever.invoke(query)
        bm25_docs = self.bm25_retriever.invoke(query)

        # De-duplicate on the passage text rather than on object equality.
        # Equality between document objects needs every field to match, so the
        # same passage loaded from two different indexes may not compare equal
        # and would occupy two reranker slots with identical text.
        seen_texts = set()
        combined_docs = []
        for doc in [*vector_docs, *bm25_docs]:  # vector hits keep priority
            if doc.page_content in seen_texts:
                continue
            seen_texts.add(doc.page_content)
            combined_docs.append(doc)

        if not combined_docs:
            # Nothing to score; skip the reranker call entirely.
            return []
        return self.reranker.compress_documents(combined_docs, query)
66
+
67
# Locations of the prebuilt indexes.
faiss_path = "VectorDB/faiss_index"
bm25_path = "VectorDB/bm25_index.json"

# One retriever per index.
faiss_retriever = load_faiss_retriever(faiss_path)
bm25_retriever = load_bm25_retriever(bm25_path)

# Cross-encoder reranker that keeps the top 4 candidates.
reranker_model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-v2-m3")
reranker = CrossEncoderReranker(model=reranker_model, top_n=4)

# Hybrid retriever used by the chat handler.
retriever = EmbeddingBM25RerankerRetriever(
    vector_retriever=faiss_retriever,
    bm25_retriever=bm25_retriever,
    reranker=reranker,
)
75
+
76
# Chat prompt: a fixed system message, the running conversation (filled in via
# the "chat_history" placeholder), and a final human turn carrying the
# retrieved context ({context}) together with the user's question ({input}).
qa_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            """You are an AI research assistant specializing in Autism research, powered by a retrieval system of curated PubMed documents.

Response Guidelines:
- Provide precise, evidence-based answers drawing directly from retrieved medical research
- Synthesize information from multiple documents when possible
- Clearly distinguish between established findings and emerging research
- Maintain scientific rigor and objectivity

Query Handling:
- Prioritize direct, informative responses
- When document evidence is incomplete, explain the current state of research
- Highlight areas where more research is needed
- Never introduce speculation or unsupported claims

Contextual Integrity:
- Ensure all statements are traceable to specific research documents
- Preserve the nuance and complexity of scientific findings
- Communicate with clarity, avoiding unnecessary medical jargon

Knowledge Limitations:
- If no relevant information is found, state: "Current research documents do not provide a comprehensive answer to this specific query."
""",
        ),
        MessagesPlaceholder("chat_history"),
        ("human", "Context:\n{context}\n\nQuestion: {input}"),
    ]
)
102
+
103
def format_context(docs) -> str:
    """Render retrieved documents as one numbered context string.

    Args:
        docs: Iterable of objects exposing a ``page_content`` attribute.

    Returns:
        One ``"Doc N: <text>"`` entry per document, numbered from 1 and
        separated by blank lines; an empty string when ``docs`` is empty.
    """
    numbered_entries = (
        f"Doc {position}: {doc.page_content}"
        for position, doc in enumerate(docs, start=1)
    )
    return "\n\n".join(numbered_entries)
105
+
106
@spaces.GPU
def chat_with_rag(query: str, history: List[tuple[str, str]]) -> str:
    """Answer ``query`` with the retrieval-augmented chat model.

    Args:
        query: The user's current question.
        history: Earlier turns, each unpacked as a ``(human, ai)`` pair.
            NOTE(review): assumes the chat UI delivers pair-style history --
            confirm against the Gradio version/type in use.

    Returns:
        The text content of the model's reply.
    """
    # Rebuild the conversation as alternating human / AI messages.
    chat_history = []
    for human, ai in history:
        chat_history.append(HumanMessage(content=human))
        chat_history.append(AIMessage(content=ai))

    # Retrieve supporting passages and render them for the prompt.
    docs = retriever.invoke(query)
    context = format_context(docs)

    # format_messages() keeps the system prompt, the history and the new human
    # turn as separate chat messages. The previous format() call flattened the
    # whole prompt into a single string, which a chat model then receives as
    # one human message, losing the system role and turn boundaries.
    messages = qa_prompt.format_messages(
        chat_history=chat_history,
        context=context,
        input=query,
    )

    response = llm.invoke(messages)
    return response.content
125
+
126
# Gradio chat UI wired to the RAG handler.
chat_interface = gr.ChatInterface(
    fn=chat_with_rag,
    title="Autism RAG Chatbot",
    description="Ask questions about Autism.",
    examples=[
        "What causes Autism?",
        "How is Autism treated?",
        "What is Autism",
    ],
)

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    chat_interface.launch(share=True)