Tamil Eniyan
committed on
Commit
·
226641d
1
Parent(s):
114c773
Add application file
Browse files
app.py
CHANGED
@@ -5,11 +5,13 @@ import pickle
|
|
5 |
from sentence_transformers import SentenceTransformer
|
6 |
from transformers import pipeline
|
7 |
|
8 |
-
# File names for saved data
|
9 |
INDEX_FILE = "faiss_index.index"
|
10 |
CHUNKS_FILE = "chunks.pkl"
|
|
|
|
|
11 |
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
|
12 |
-
QA_MODEL_NAME = "deepset/roberta-large-squad2" #
|
13 |
|
14 |
@st.cache_resource
|
15 |
def load_index_and_chunks():
|
@@ -24,13 +26,67 @@ def load_embedding_model():
|
|
24 |
|
25 |
@st.cache_resource
|
26 |
def load_qa_pipeline():
|
27 |
-
# This QA pipeline expects a question and a context
|
28 |
return pipeline("question-answering", model=QA_MODEL_NAME, tokenizer=QA_MODEL_NAME)
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
def main():
|
31 |
-
st.title("PDF Question-Answering App")
|
32 |
|
33 |
-
#
|
|
|
|
|
|
|
|
|
34 |
index, chunks = load_index_and_chunks()
|
35 |
embed_model = load_embedding_model()
|
36 |
qa_pipeline = load_qa_pipeline()
|
@@ -39,29 +95,37 @@ def main():
|
|
39 |
query = st.text_input("Question:")
|
40 |
|
41 |
if query:
|
42 |
-
#
|
43 |
-
|
44 |
|
45 |
-
# Retrieve
|
46 |
-
|
|
|
47 |
distances, indices = index.search(query_embedding, k)
|
48 |
|
49 |
-
|
50 |
-
context = ""
|
51 |
for idx in indices[0]:
|
52 |
-
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
-
#
|
56 |
-
with st.expander("Show Retrieved Context"):
|
57 |
-
|
58 |
-
|
|
|
|
|
59 |
|
60 |
st.subheader("Answer:")
|
61 |
try:
|
62 |
-
|
63 |
-
|
64 |
-
st.write(
|
|
|
65 |
except Exception as e:
|
66 |
st.error(f"Error generating answer: {e}")
|
67 |
|
|
|
5 |
from sentence_transformers import SentenceTransformer
|
6 |
from transformers import pipeline
|
7 |
|
8 |
+
# Artifacts written by the offline indexing step and loaded at app startup.
INDEX_FILE = "faiss_index.index"
CHUNKS_FILE = "chunks.pkl"

# Model identifiers resolved from the Hugging Face / sentence-transformers hubs.
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
QA_MODEL_NAME = "deepset/roberta-large-squad2"  # More powerful QA model
|
15 |
|
16 |
@st.cache_resource
|
17 |
def load_index_and_chunks():
|
|
|
26 |
|
27 |
@st.cache_resource
def load_qa_pipeline():
    """Construct the extractive question-answering pipeline.

    Cached by Streamlit so the model is downloaded/loaded only once per
    server process.
    """
    return pipeline(
        "question-answering",
        model=QA_MODEL_NAME,
        tokenizer=QA_MODEL_NAME,
    )
|
30 |
|
31 |
+
# Curated Q/A pairs for enhanced context.
# Kept as (question, answer) tuples and expanded below into the
# {"question": ..., "answer": ...} dict shape the retrieval helper expects.
_CURATED_PAIRS = [
    (
        "What is Adagrasib (Krazati®) used for?",
        "Adagrasib (Krazati®) is used for the treatment of adult patients with advanced non-small cell lung cancer (NSCLC) with KRAS G12C mutation and disease progression after at least one prior systemic therapy.",
    ),
    (
        "What are the comparator treatments for Adagrasib in the PICO exercises?",
        "In PICOs 1 and 2, the comparator treatments are Sotorasib monotherapy and Platinum-based doublet chemotherapy. In PICO 3, the comparator treatment is Docetaxel monotherapy. In PICO 4, the comparator treatment is Immune checkpoint inhibitor monotherapy.",
    ),
    (
        "What are the effectiveness measures used in the PICO framework?",
        "Effectiveness measures in the PICO framework typically include outcomes such as overall response rate, progression-free survival, and overall survival. In these PICOs, Adagrasib is used as monotherapy in PICOs 1 and 2, and as monotherapy or in combination with platinum-based chemotherapy in PICO 3, while in PICO 4 it may be used as monotherapy or in combination therapy.",
    ),
    (
        "What adverse events were monitored in the PICO studies?",
        "The adverse events monitored in the PICO studies include a range of treatment-related side effects, though the provided context does not detail them fully. More detailed study data would be required for a complete list.",
    ),
    (
        "How does Adagrasib compare to immune checkpoint inhibitors?",
        "The analysis suggests that comparisons between Adagrasib and immune checkpoint inhibitors (or Docetaxel) should account for patients' prior treatment lines. Direct comparisons may not be relevant in certain subgroups, especially in patients with only one prior line of systemic therapy.",
    ),
]
curated_qa_pairs = [{"question": q, "answer": a} for q, a in _CURATED_PAIRS]
|
54 |
+
|
55 |
+
def get_curated_context(query, curated_qa, embed_model):
    """Return the curated Q/A pair most relevant to *query*, or "" if none is close.

    Parameters
    ----------
    query : str
        The user's question.
    curated_qa : list[dict]
        Dicts carrying "question" and "answer" keys.
    embed_model :
        Object with an ``encode(texts, ...)`` method returning one embedding
        row per input text (e.g. a SentenceTransformer).

    Returns
    -------
    str
        A formatted "Curated Q/A Pair" block when the best match's squared-L2
        distance is below the relevance threshold, otherwise an empty string.
    """
    # Guard: the original crashed on an empty curated list (shape[1] on an
    # empty embedding matrix); an empty list simply means "no curated context".
    if not curated_qa:
        return ""

    curated_questions = [qa["question"] for qa in curated_qa]
    query_embedding = np.asarray(embed_model.encode([query]), dtype="float32")
    curated_embeddings = np.asarray(
        embed_model.encode(curated_questions, show_progress_bar=False),
        dtype="float32",
    )

    # A handful of curated questions does not justify rebuilding a FAISS index
    # on every call: compute squared-L2 distances directly with NumPy. This
    # matches faiss.IndexFlatL2.search, which reports *squared* L2 distances.
    diffs = curated_embeddings - query_embedding  # broadcasts (n, d) - (1, d)
    sq_dists = (diffs ** 2).sum(axis=1)
    best = int(np.argmin(sq_dists))

    # Relevance threshold on squared distance (tune as needed).
    threshold = 1.0
    if sq_dists[best] < threshold:
        qa_pair = curated_qa[best]
        return f"Curated Q/A Pair:\nQuestion: {qa_pair['question']}\nAnswer: {qa_pair['answer']}\n"
    return ""
|
81 |
+
|
82 |
def main():
|
83 |
+
st.title("PDF Question-Answering App with Enhanced Context")
|
84 |
|
85 |
+
# Initialize conversation history if not already in session_state
|
86 |
+
if 'conversation_history' not in st.session_state:
|
87 |
+
st.session_state.conversation_history = ""
|
88 |
+
|
89 |
+
# Load PDF index, chunks, and models
|
90 |
index, chunks = load_index_and_chunks()
|
91 |
embed_model = load_embedding_model()
|
92 |
qa_pipeline = load_qa_pipeline()
|
|
|
95 |
query = st.text_input("Question:")
|
96 |
|
97 |
if query:
|
98 |
+
# Append the current question to conversation history
|
99 |
+
st.session_state.conversation_history += f"User: {query}\n"
|
100 |
|
101 |
+
# Retrieve relevant PDF context using FAISS index
|
102 |
+
query_embedding = embed_model.encode([query]).astype('float32')
|
103 |
+
k = 3 # number of top chunks to retrieve
|
104 |
distances, indices = index.search(query_embedding, k)
|
105 |
|
106 |
+
pdf_context = ""
|
|
|
107 |
for idx in indices[0]:
|
108 |
+
pdf_context += chunks[idx] + "\n"
|
109 |
+
|
110 |
+
# Get curated Q/A context if the query matches any curated questions
|
111 |
+
curated_context = get_curated_context(query, curated_qa_pairs, embed_model)
|
112 |
+
|
113 |
+
# Combine conversation history, curated context, and PDF context for the QA pipeline
|
114 |
+
combined_context = st.session_state.conversation_history + "\n" + curated_context + "\n" + pdf_context
|
115 |
|
116 |
+
# Optionally show the retrieved contexts in expanders
|
117 |
+
with st.expander("Show PDF Retrieved Context"):
|
118 |
+
st.write(pdf_context)
|
119 |
+
if curated_context:
|
120 |
+
with st.expander("Show Curated Q/A Context"):
|
121 |
+
st.write(curated_context)
|
122 |
|
123 |
st.subheader("Answer:")
|
124 |
try:
|
125 |
+
result = qa_pipeline(question=query, context=combined_context)
|
126 |
+
answer = result["answer"]
|
127 |
+
st.write(answer)
|
128 |
+
st.session_state.conversation_history += f"AI: {answer}\n"
|
129 |
except Exception as e:
|
130 |
st.error(f"Error generating answer: {e}")
|
131 |
|