Tamil Eniyan commited on
Commit
226641d
·
1 Parent(s): 114c773

Add application file

Browse files
Files changed (1) hide show
  1. app.py +84 -20
app.py CHANGED
@@ -5,11 +5,13 @@ import pickle
5
  from sentence_transformers import SentenceTransformer
6
  from transformers import pipeline
7
 
8
- # File names for saved data
9
  INDEX_FILE = "faiss_index.index"
10
  CHUNKS_FILE = "chunks.pkl"
 
 
11
  EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
12
- QA_MODEL_NAME = "deepset/roberta-large-squad2" # You can change this to any Hugging Face QA model
13
 
14
  @st.cache_resource
15
  def load_index_and_chunks():
@@ -24,13 +26,67 @@ def load_embedding_model():
24
 
25
  @st.cache_resource
26
  def load_qa_pipeline():
27
- # This QA pipeline expects a question and a context
28
  return pipeline("question-answering", model=QA_MODEL_NAME, tokenizer=QA_MODEL_NAME)
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  def main():
31
- st.title("PDF Question-Answering App")
32
 
33
- # Load FAISS index, chunks, and models
 
 
 
 
34
  index, chunks = load_index_and_chunks()
35
  embed_model = load_embedding_model()
36
  qa_pipeline = load_qa_pipeline()
@@ -39,29 +95,37 @@ def main():
39
  query = st.text_input("Question:")
40
 
41
  if query:
42
- # Encode the query using the same SentenceTransformer model
43
- query_embedding = embed_model.encode([query]).astype('float32')
44
 
45
- # Retrieve top k relevant chunks
46
- k = 3
 
47
  distances, indices = index.search(query_embedding, k)
48
 
49
- # Prepare combined context from the retrieved chunks
50
- context = ""
51
  for idx in indices[0]:
52
- context_piece = chunks[idx]
53
- context += context_piece + " "
 
 
 
 
 
54
 
55
- # Use an expander to optionally display the retrieved context
56
- with st.expander("Show Retrieved Context"):
57
- for idx in indices[0]:
58
- st.write(chunks[idx])
 
 
59
 
60
  st.subheader("Answer:")
61
  try:
62
- # Use the QA pipeline to generate an answer based on the combined context
63
- result = qa_pipeline(question=query, context=context)
64
- st.write(result["answer"])
 
65
  except Exception as e:
66
  st.error(f"Error generating answer: {e}")
67
 
 
5
  from sentence_transformers import SentenceTransformer
6
  from transformers import pipeline
7
 
8
+ # File names for saved PDF-based data
9
  INDEX_FILE = "faiss_index.index"
10
  CHUNKS_FILE = "chunks.pkl"
11
+
12
+ # Models
13
  EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
14
+ QA_MODEL_NAME = "deepset/roberta-large-squad2" # More powerful QA model
15
 
16
  @st.cache_resource
17
  def load_index_and_chunks():
 
26
 
27
  @st.cache_resource
28
  def load_qa_pipeline():
 
29
  return pipeline("question-answering", model=QA_MODEL_NAME, tokenizer=QA_MODEL_NAME)
30
 
31
+ # Curated Q/A pairs for enhanced context
32
+ curated_qa_pairs = [
33
+ {
34
+ "question": "What is Adagrasib (Krazati®) used for?",
35
+ "answer": "Adagrasib (Krazati®) is used for the treatment of adult patients with advanced non-small cell lung cancer (NSCLC) with KRAS G12C mutation and disease progression after at least one prior systemic therapy."
36
+ },
37
+ {
38
+ "question": "What are the comparator treatments for Adagrasib in the PICO exercises?",
39
+ "answer": "In PICOs 1 and 2, the comparator treatments are Sotorasib monotherapy and Platinum-based doublet chemotherapy. In PICO 3, the comparator treatment is Docetaxel monotherapy. In PICO 4, the comparator treatment is Immune checkpoint inhibitor monotherapy."
40
+ },
41
+ {
42
+ "question": "What are the effectiveness measures used in the PICO framework?",
43
+ "answer": "Effectiveness measures in the PICO framework typically include outcomes such as overall response rate, progression-free survival, and overall survival. In these PICOs, Adagrasib is used as monotherapy in PICOs 1 and 2, and as monotherapy or in combination with platinum-based chemotherapy in PICO 3, while in PICO 4 it may be used as monotherapy or in combination therapy."
44
+ },
45
+ {
46
+ "question": "What adverse events were monitored in the PICO studies?",
47
+ "answer": "The adverse events monitored in the PICO studies include a range of treatment-related side effects, though the provided context does not detail them fully. More detailed study data would be required for a complete list."
48
+ },
49
+ {
50
+ "question": "How does Adagrasib compare to immune checkpoint inhibitors?",
51
+ "answer": "The analysis suggests that comparisons between Adagrasib and immune checkpoint inhibitors (or Docetaxel) should account for patients' prior treatment lines. Direct comparisons may not be relevant in certain subgroups, especially in patients with only one prior line of systemic therapy."
52
+ }
53
+ ]
54
+
55
+ def get_curated_context(query, curated_qa, embed_model):
56
+ """
57
+ Retrieve the most relevant curated Q/A pair based on the user's query.
58
+ Returns a formatted string if the similarity (based on L2 distance) is below a threshold.
59
+ """
60
+ curated_questions = [qa["question"] for qa in curated_qa]
61
+ query_embedding = embed_model.encode([query]).astype('float32')
62
+ curated_embeddings = embed_model.encode(curated_questions, show_progress_bar=False)
63
+ curated_embeddings = np.array(curated_embeddings).astype('float32')
64
+
65
+ # Build a temporary FAISS index for curated questions
66
+ dimension = curated_embeddings.shape[1]
67
+ curated_index = faiss.IndexFlatL2(dimension)
68
+ curated_index.add(curated_embeddings)
69
+
70
+ k = 1
71
+ distances, indices = curated_index.search(query_embedding, k)
72
+
73
+ # Define a threshold for relevance (tune as needed)
74
+ threshold = 1.0
75
+ if distances[0][0] < threshold:
76
+ idx = indices[0][0]
77
+ qa_pair = curated_qa[idx]
78
+ return f"Curated Q/A Pair:\nQuestion: {qa_pair['question']}\nAnswer: {qa_pair['answer']}\n"
79
+ else:
80
+ return ""
81
+
82
  def main():
83
+ st.title("PDF Question-Answering App with Enhanced Context")
84
 
85
+ # Initialize conversation history if not already in session_state
86
+ if 'conversation_history' not in st.session_state:
87
+ st.session_state.conversation_history = ""
88
+
89
+ # Load PDF index, chunks, and models
90
  index, chunks = load_index_and_chunks()
91
  embed_model = load_embedding_model()
92
  qa_pipeline = load_qa_pipeline()
 
95
  query = st.text_input("Question:")
96
 
97
  if query:
98
+ # Append the current question to conversation history
99
+ st.session_state.conversation_history += f"User: {query}\n"
100
 
101
+ # Retrieve relevant PDF context using FAISS index
102
+ query_embedding = embed_model.encode([query]).astype('float32')
103
+ k = 3 # number of top chunks to retrieve
104
  distances, indices = index.search(query_embedding, k)
105
 
106
+ pdf_context = ""
 
107
  for idx in indices[0]:
108
+ pdf_context += chunks[idx] + "\n"
109
+
110
+ # Get curated Q/A context if the query matches any curated questions
111
+ curated_context = get_curated_context(query, curated_qa_pairs, embed_model)
112
+
113
+ # Combine conversation history, curated context, and PDF context for the QA pipeline
114
+ combined_context = st.session_state.conversation_history + "\n" + curated_context + "\n" + pdf_context
115
 
116
+ # Optionally show the retrieved contexts in expanders
117
+ with st.expander("Show PDF Retrieved Context"):
118
+ st.write(pdf_context)
119
+ if curated_context:
120
+ with st.expander("Show Curated Q/A Context"):
121
+ st.write(curated_context)
122
 
123
  st.subheader("Answer:")
124
  try:
125
+ result = qa_pipeline(question=query, context=combined_context)
126
+ answer = result["answer"]
127
+ st.write(answer)
128
+ st.session_state.conversation_history += f"AI: {answer}\n"
129
  except Exception as e:
130
  st.error(f"Error generating answer: {e}")
131