Ali2206 commited on
Commit
26faa43
Β·
verified Β·
1 Parent(s): 2debc41

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -139
app.py CHANGED
@@ -3,7 +3,7 @@ import os
3
  import pandas as pd
4
  import json
5
  import gradio as gr
6
- from typing import List, Tuple, Union, Generator
7
  import hashlib
8
  import shutil
9
  import re
@@ -42,20 +42,17 @@ def clean_response(text: str) -> str:
42
  def estimate_tokens(text: str) -> int:
43
  return len(text) // 3.5 + 1
44
 
45
- def extract_text_from_excel(file_obj: Union[str, os.PathLike, 'file']) -> str:
46
  all_text = []
47
- try:
48
- xls = pd.ExcelFile(file_obj)
49
- except Exception as e:
50
- raise ValueError(f"❌ Error reading Excel file: {e}")
51
  for sheet_name in xls.sheet_names:
52
  df = xls.parse(sheet_name).astype(str).fillna("")
53
- rows = df.apply(lambda row: " | ".join([cell for cell in row if cell.strip()]), axis=1)
54
- sheet_text = [f"[{sheet_name}] {line}" for line in rows if line.strip()]
55
  all_text.extend(sheet_text)
56
  return "\n".join(all_text)
57
 
58
- def split_text_into_chunks(text: str, max_tokens: int = MAX_CHUNK_TOKENS, max_chunks: int = 30) -> List[str]:
59
  effective_max = max_tokens - PROMPT_OVERHEAD
60
  lines, chunks, curr_chunk, curr_tokens = text.split("\n"), [], [], 0
61
  for line in lines:
@@ -63,13 +60,11 @@ def split_text_into_chunks(text: str, max_tokens: int = MAX_CHUNK_TOKENS, max_ch
63
  if curr_tokens + t > effective_max:
64
  if curr_chunk:
65
  chunks.append("\n".join(curr_chunk))
66
- if len(chunks) >= max_chunks:
67
- break
68
  curr_chunk, curr_tokens = [line], t
69
  else:
70
  curr_chunk.append(line)
71
  curr_tokens += t
72
- if curr_chunk and len(chunks) < max_chunks:
73
  chunks.append("\n".join(curr_chunk))
74
  return chunks
75
 
@@ -92,48 +87,14 @@ Analyze the following clinical notes and provide a detailed, concise summary foc
92
  Respond in well-structured bullet points with medical reasoning.
93
  """
94
 
95
- def validate_tool_file(file_path):
96
- try:
97
- with open(file_path, 'r') as f:
98
- data = json.load(f)
99
- if isinstance(data, list):
100
- assert all(isinstance(t, dict) and "name" in t for t in data), "Invalid list format"
101
- elif isinstance(data, dict):
102
- assert "tools" in data and isinstance(data["tools"], list), "'tools' field missing or invalid"
103
- assert all(isinstance(t, dict) and "name" in t for t in data["tools"]), "Invalid item in 'tools'"
104
- else:
105
- raise ValueError("Unexpected structure")
106
- return True
107
- except Exception as e:
108
- print(f"❌ Tool validation failed for {file_path}: {e}")
109
- return False
110
-
111
  def init_agent():
112
- all_tool_paths = {
113
- "opentarget": "/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/opentarget_tools.json",
114
- "fda_drug_label": "/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/fda_drug_labeling_tools.json",
115
- "special_tools": "/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/special_tools.json",
116
- "monarch": "/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/monarch_tools.json",
117
- "new_tool": os.path.join(tool_cache_dir, "new_tool.json"),
118
- }
119
-
120
- if not os.path.exists(all_tool_paths["new_tool"]):
121
- shutil.copy(os.path.abspath("data/new_tool.json"), all_tool_paths["new_tool"])
122
-
123
- valid_tool_paths = {}
124
- for key, path in all_tool_paths.items():
125
- if validate_tool_file(path):
126
- valid_tool_paths[key] = path
127
- else:
128
- print(f"⚠️ Skipping invalid tool file: {path}")
129
-
130
- if not valid_tool_paths:
131
- raise RuntimeError("❌ No valid tool files found to load into TxAgent.")
132
-
133
  agent = TxAgent(
134
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
135
  rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
136
- tool_files_dict=valid_tool_paths,
137
  force_finish=True,
138
  enable_checker=True,
139
  step_rag_num=4,
@@ -142,111 +103,126 @@ def init_agent():
142
  agent.init_model()
143
  return agent
144
 
145
- def stream_report(agent, input_file: Union[str, 'file'], full_output: str) -> Generator[Tuple[str, Union[str, None], str], None, None]:
146
- accumulated_text = ""
147
- try:
148
- if input_file is None:
149
- yield "❌ Please upload a valid Excel file.", None, ""
150
- return
151
 
152
- if hasattr(input_file, "read"):
153
- text = extract_text_from_excel(input_file)
154
- elif isinstance(input_file, str) and os.path.exists(input_file):
155
- text = extract_text_from_excel(input_file)
156
- else:
157
- raise ValueError("❌ Invalid or missing file.")
158
-
159
- chunks = split_text_into_chunks(text)
160
-
161
- for i, chunk in enumerate(chunks):
162
- prompt = build_prompt_from_text(chunk)
163
- partial = ""
164
- for res in agent.run_gradio_chat(
165
- message=prompt, history=[], temperature=0.2,
166
- max_new_tokens=MAX_NEW_TOKENS, max_token=MAX_MODEL_TOKENS,
167
- call_agent=False, conversation=[]
168
- ):
169
- if isinstance(res, str):
170
- partial += res
171
- elif hasattr(res, "content"):
172
- partial += res.content
173
- cleaned = clean_response(partial)
174
- accumulated_text += f"\n\nπŸ“„ **Chunk {i+1}**:\n{cleaned}"
175
- yield accumulated_text, None, ""
176
 
177
- summary_prompt = f"Summarize this analysis in a final structured report:\n\n" + accumulated_text
178
- final_report = ""
179
- for res in agent.run_gradio_chat(
180
- message=summary_prompt, history=[], temperature=0.2,
181
- max_new_tokens=MAX_NEW_TOKENS, max_token=MAX_MODEL_TOKENS,
182
- call_agent=False, conversation=[]
183
- ):
184
  if isinstance(res, str):
185
- final_report += res
186
  elif hasattr(res, "content"):
187
- final_report += res.content
188
-
189
- cleaned = clean_response(final_report)
190
- accumulated_text += f"\n\nπŸ“Š **Final Summary**:\n{cleaned}"
191
- report_path = os.path.join(report_dir, f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md")
192
- with open(report_path, 'w') as f:
193
- f.write(f"# 🧠 Final Patient Report\n\n{cleaned}")
194
-
195
- yield accumulated_text, report_path, cleaned
196
-
197
- except Exception as e:
198
- yield f"❌ Error: {str(e)}", None, ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
 
200
  def create_ui(agent):
201
  with gr.Blocks(css="""
202
- body {
203
- background: #10141f;
204
- color: #ffffff;
 
205
  font-family: 'Inter', sans-serif;
206
- margin: 0;
207
- padding: 0;
208
  }
209
- .gradio-container {
210
- padding: 30px;
211
- width: 100vw;
212
- max-width: 100%;
213
- border-radius: 0;
214
- background-color: #1a1f2e;
 
 
 
 
215
  }
216
- .output-markdown {
217
- background-color: #131720;
 
218
  border-radius: 12px;
219
- padding: 20px;
220
- min-height: 600px;
221
- overflow-y: auto;
222
- border: 1px solid #2c3344;
223
  }
224
- .gr-button {
225
- background: linear-gradient(135deg, #4b4ced, #37b6e9);
 
 
 
 
 
 
 
226
  color: white;
227
- font-weight: 500;
228
- border: none;
229
- padding: 10px 20px;
230
  border-radius: 8px;
231
- transition: background 0.3s ease;
232
  }
233
- .gr-button:hover {
234
- background: linear-gradient(135deg, #37b6e9, #4b4ced);
235
  }
236
  """) as demo:
237
- gr.Markdown("""# 🧠 Clinical Reasoning Assistant
238
- Upload clinical Excel records below and click **Analyze** to generate a medical summary.""")
239
- file_upload = gr.File(label="Upload Excel File", file_types=[".xlsx"])
240
- analyze_btn = gr.Button("Analyze")
241
- report_output_markdown = gr.Markdown(elem_classes="output-markdown")
242
- report_file = gr.File(label="Download Report", visible=False)
243
- full_output = gr.State(value="")
244
-
245
- analyze_btn.click(
246
- fn=stream_report,
247
- inputs=[file_upload, full_output],
248
- outputs=[report_output_markdown, report_file, full_output]
249
- )
 
 
 
 
 
 
 
 
 
 
 
 
250
 
251
  return demo
252
 
@@ -254,7 +230,7 @@ if __name__ == "__main__":
254
  try:
255
  agent = init_agent()
256
  demo = create_ui(agent)
257
- demo.launch(server_name="0.0.0.0", server_port=7860, allowed_paths=["/data/hf_cache/reports"], share=True)
258
  except Exception as e:
259
  print(f"Error: {str(e)}")
260
  sys.exit(1)
 
3
  import pandas as pd
4
  import json
5
  import gradio as gr
6
+ from typing import List, Tuple, Dict, Any, Union
7
  import hashlib
8
  import shutil
9
  import re
 
42
  def estimate_tokens(text: str) -> int:
43
  return len(text) // 3.5 + 1
44
 
45
+ def extract_text_from_excel(file_path: str) -> str:
46
  all_text = []
47
+ xls = pd.ExcelFile(file_path)
 
 
 
48
  for sheet_name in xls.sheet_names:
49
  df = xls.parse(sheet_name).astype(str).fillna("")
50
+ rows = df.apply(lambda row: " | ".join(row), axis=1)
51
+ sheet_text = [f"[{sheet_name}] {line}" for line in rows]
52
  all_text.extend(sheet_text)
53
  return "\n".join(all_text)
54
 
55
+ def split_text_into_chunks(text: str, max_tokens: int = MAX_CHUNK_TOKENS) -> List[str]:
56
  effective_max = max_tokens - PROMPT_OVERHEAD
57
  lines, chunks, curr_chunk, curr_tokens = text.split("\n"), [], [], 0
58
  for line in lines:
 
60
  if curr_tokens + t > effective_max:
61
  if curr_chunk:
62
  chunks.append("\n".join(curr_chunk))
 
 
63
  curr_chunk, curr_tokens = [line], t
64
  else:
65
  curr_chunk.append(line)
66
  curr_tokens += t
67
+ if curr_chunk:
68
  chunks.append("\n".join(curr_chunk))
69
  return chunks
70
 
 
87
  Respond in well-structured bullet points with medical reasoning.
88
  """
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  def init_agent():
91
+ tool_path = os.path.join(tool_cache_dir, "new_tool.json")
92
+ if not os.path.exists(tool_path):
93
+ shutil.copy(os.path.abspath("data/new_tool.json"), tool_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  agent = TxAgent(
95
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
96
  rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
97
+ tool_files_dict={"new_tool": tool_path},
98
  force_finish=True,
99
  enable_checker=True,
100
  step_rag_num=4,
 
103
  agent.init_model()
104
  return agent
105
 
106
+ def process_final_report(agent, file, chatbot_state: List[Dict[str, str]]) -> Tuple[List[Dict[str, str]], Union[str, None]]:
107
+ messages = chatbot_state if chatbot_state else []
108
+ if file is None or not hasattr(file, "name"):
109
+ return messages + [{"role": "assistant", "content": "❌ Please upload a valid Excel file."}], None
 
 
110
 
111
+ messages.append({"role": "user", "content": f"Processing Excel file: {os.path.basename(file.name)}"})
112
+ text = extract_text_from_excel(file.name)
113
+ chunks = split_text_into_chunks(text)
114
+ chunk_responses = [None] * len(chunks)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
+ def analyze_chunk(i, chunk):
117
+ prompt = build_prompt_from_text(chunk)
118
+ response = ""
119
+ for res in agent.run_gradio_chat(message=prompt, history=[], temperature=0.2, max_new_tokens=MAX_NEW_TOKENS, max_token=MAX_MODEL_TOKENS, call_agent=False, conversation=[]):
 
 
 
120
  if isinstance(res, str):
121
+ response += res
122
  elif hasattr(res, "content"):
123
+ response += res.content
124
+ elif isinstance(res, list):
125
+ for r in res:
126
+ if hasattr(r, "content"):
127
+ response += r.content
128
+ return i, clean_response(response)
129
+
130
+ with ThreadPoolExecutor(max_workers=1) as executor:
131
+ futures = [executor.submit(analyze_chunk, i, c) for i, c in enumerate(chunks)]
132
+ for f in as_completed(futures):
133
+ i, result = f.result()
134
+ chunk_responses[i] = result
135
+
136
+ valid = [r for r in chunk_responses if r and not r.startswith("❌")]
137
+ if not valid:
138
+ return messages + [{"role": "assistant", "content": "❌ No valid chunk results."}], None
139
+
140
+ summary_prompt = f"Summarize this analysis in a final structured report:\n\n" + "\n\n".join(valid)
141
+ messages.append({"role": "assistant", "content": "πŸ“Š Generating final report..."})
142
+
143
+ final_report = ""
144
+ for res in agent.run_gradio_chat(message=summary_prompt, history=[], temperature=0.2, max_new_tokens=MAX_NEW_TOKENS, max_token=MAX_MODEL_TOKENS, call_agent=False, conversation=[]):
145
+ if isinstance(res, str):
146
+ final_report += res
147
+ elif hasattr(res, "content"):
148
+ final_report += res.content
149
+
150
+ cleaned = clean_response(final_report)
151
+ report_path = os.path.join(report_dir, f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md")
152
+ with open(report_path, 'w') as f:
153
+ f.write(f"# 🧠 Final Patient Report\n\n{cleaned}")
154
+
155
+ messages.append({"role": "assistant", "content": f"πŸ“Š Final Report:\n\n{cleaned}"})
156
+ messages.append({"role": "assistant", "content": f"βœ… Report generated and saved: {os.path.basename(report_path)}"})
157
+ return messages, report_path
158
 
159
  def create_ui(agent):
160
  with gr.Blocks(css="""
161
+ html, body, .gradio-container {
162
+ height: 100vh;
163
+ background-color: #111827;
164
+ color: #e5e7eb;
165
  font-family: 'Inter', sans-serif;
 
 
166
  }
167
+ .message-avatar {
168
+ width: 38px;
169
+ height: 38px;
170
+ border-radius: 50%;
171
+ margin-right: 10px;
172
+ }
173
+ .chat-message {
174
+ display: flex;
175
+ align-items: flex-start;
176
+ margin-bottom: 1rem;
177
  }
178
+ .message-bubble {
179
+ background-color: #1f2937;
180
+ padding: 12px 16px;
181
  border-radius: 12px;
182
+ max-width: 90%;
 
 
 
183
  }
184
+ .chat-input {
185
+ background-color: #1f2937;
186
+ border: 1px solid #374151;
187
+ border-radius: 8px;
188
+ color: #e5e7eb;
189
+ padding: 0.75rem 1rem;
190
+ }
191
+ .gr-button.primary {
192
+ background: #2563eb;
193
  color: white;
 
 
 
194
  border-radius: 8px;
195
+ font-weight: 600;
196
  }
197
+ .gr-button.primary:hover {
198
+ background: #1e40af;
199
  }
200
  """) as demo:
201
+ gr.Markdown("""<h2 style='color:#60a5fa'>🩺 Patient History AI Assistant</h2><p>Upload a clinical Excel file and receive a structured diagnostic summary.</p>""")
202
+ with gr.Row():
203
+ with gr.Column(scale=3):
204
+ chatbot = gr.Chatbot(
205
+ label="Clinical Assistant",
206
+ height=700,
207
+ type="messages",
208
+ avatar_images=[
209
+ "https://ui-avatars.com/api/?name=AI&background=2563eb&color=fff&size=128",
210
+ "https://ui-avatars.com/api/?name=You&background=374151&color=fff&size=128"
211
+ ]
212
+ )
213
+ with gr.Column(scale=1):
214
+ with gr.Row():
215
+ file_upload = gr.File(label="", file_types=[".xlsx"], elem_id="upload-btn")
216
+ analyze_btn = gr.Button("🧠 Analyze", variant="primary")
217
+ report_output = gr.File(label="Download Report", visible=False, interactive=False)
218
+
219
+ chatbot_state = gr.State(value=[])
220
+
221
+ def update_ui(file, current_state):
222
+ messages, report_path = process_final_report(agent, file, current_state)
223
+ return messages, gr.update(visible=report_path is not None, value=report_path), messages
224
+
225
+ analyze_btn.click(fn=update_ui, inputs=[file_upload, chatbot_state], outputs=[chatbot, report_output, chatbot_state])
226
 
227
  return demo
228
 
 
230
  try:
231
  agent = init_agent()
232
  demo = create_ui(agent)
233
+ demo.launch(server_name="0.0.0.0", server_port=7860, allowed_paths=["/data/hf_cache/reports"], share=False)
234
  except Exception as e:
235
  print(f"Error: {str(e)}")
236
  sys.exit(1)