Ali2206 committed · Commit 8c16b9e · verified · 1 Parent(s): aa559b4

Update app.py

Files changed (1):
  1. app.py +137 -147
app.py CHANGED
@@ -1,43 +1,30 @@
-# ✅ Fully updated app.py for TxAgent with strict tool validation to prevent runtime errors
-
 import sys
 import os
 import pandas as pd
 import json
 import gradio as gr
-from typing import List, Tuple, Union, Generator, Dict, Any
+from typing import List, Tuple, Union, Generator
+import hashlib
+import shutil
 import re
 from datetime import datetime
-import atexit
-import torch.distributed as dist
-import logging
-
-# Logging
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-logger = logging.getLogger("app")
-
-# Cleanup
-
-def cleanup():
-    if dist.is_initialized():
-        logger.info("Cleaning up PyTorch distributed process group")
-        dist.destroy_process_group()
-
-atexit.register(cleanup)
+from concurrent.futures import ThreadPoolExecutor, as_completed
 
-# Directories
+# Setup directories
 persistent_dir = "/data/hf_cache"
 os.makedirs(persistent_dir, exist_ok=True)
+
 model_cache_dir = os.path.join(persistent_dir, "txagent_models")
 tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
 file_cache_dir = os.path.join(persistent_dir, "cache")
 report_dir = os.path.join(persistent_dir, "reports")
+
 for d in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
     os.makedirs(d, exist_ok=True)
+
 os.environ["HF_HOME"] = model_cache_dir
 os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
 
-# Import TxAgent
 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "src")))
 from txagent.txagent import TxAgent
 
@@ -46,125 +33,73 @@ MAX_CHUNK_TOKENS = 8192
 MAX_NEW_TOKENS = 2048
 PROMPT_OVERHEAD = 500
 
-
 def clean_response(text: str) -> str:
     text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
     text = re.sub(r"\n{3,}", "\n\n", text)
     text = re.sub(r"[^\n#\-\*\w\s\.,:\(\)]+", "", text)
     return text.strip()
 
-
 def estimate_tokens(text: str) -> int:
     return len(text) // 3.5 + 1
 
-
-def extract_text_from_excel(file_obj: Union[str, Dict[str, Any]]) -> str:
-    if isinstance(file_obj, dict) and 'name' in file_obj:
-        file_path = file_obj['name']
-    elif isinstance(file_obj, str):
-        file_path = file_obj
-    else:
-        raise ValueError("Unsupported file input type")
-    if not os.path.exists(file_path):
-        raise FileNotFoundError(f"File not found: {file_path}")
-    xls = pd.ExcelFile(file_path)
+def extract_text_from_excel(file_obj: Union[str, os.PathLike, 'file']) -> str:
     all_text = []
-    for sheet in xls.sheet_names:
-        try:
-            df = xls.parse(sheet).astype(str).fillna("")
-            rows = df.apply(lambda r: " | ".join([c for c in r if c.strip()]), axis=1)
-            sheet_text = [f"[{sheet}] {line}" for line in rows if line.strip()]
-            all_text.extend(sheet_text)
-        except Exception as e:
-            logger.warning(f"Failed to parse {sheet}: {e}")
+    try:
+        xls = pd.ExcelFile(file_obj)
+    except Exception as e:
+        raise ValueError(f"Error reading Excel file: {e}")
+    for sheet_name in xls.sheet_names:
+        df = xls.parse(sheet_name).astype(str).fillna("")
+        rows = df.apply(lambda row: " | ".join([cell for cell in row if cell.strip()]), axis=1)
+        sheet_text = [f"[{sheet_name}] {line}" for line in rows if line.strip()]
+        all_text.extend(sheet_text)
     return "\n".join(all_text)
 
-
-def split_text_into_chunks(text: str) -> List[str]:
-    lines = text.split("\n")
-    chunks, current, current_tokens = [], [], 0
-    max_tokens = MAX_CHUNK_TOKENS - PROMPT_OVERHEAD
+def split_text_into_chunks(text: str, max_tokens: int = MAX_CHUNK_TOKENS, max_chunks: int = 30) -> List[str]:
+    effective_max = max_tokens - PROMPT_OVERHEAD
+    lines, chunks, curr_chunk, curr_tokens = text.split("\n"), [], [], 0
     for line in lines:
         t = estimate_tokens(line)
-        if current_tokens + t > max_tokens:
-            chunks.append("\n".join(current))
-            current, current_tokens = [line], t
+        if curr_tokens + t > effective_max:
+            if curr_chunk:
+                chunks.append("\n".join(curr_chunk))
+            if len(chunks) >= max_chunks:
+                break
+            curr_chunk, curr_tokens = [line], t
         else:
-            current.append(line)
-            current_tokens += t
-    if current:
-        chunks.append("\n".join(current))
+            curr_chunk.append(line)
+            curr_tokens += t
+    if curr_chunk and len(chunks) < max_chunks:
+        chunks.append("\n".join(curr_chunk))
     return chunks
 
-
 def build_prompt_from_text(chunk: str) -> str:
     return f"""
-### Clinical Records Analysis
+### Unstructured Clinical Records
 
-Please analyze these clinical notes and provide:
-- Key diagnostic indicators
-- Current medications and potential issues
-- Recommended follow-up actions
-- Any inconsistencies or concerns
+Analyze the following clinical notes and provide a detailed, concise summary focusing on:
+- Diagnostic Patterns
+- Medication Issues
+- Missed Opportunities
+- Inconsistencies
+- Follow-up Recommendations
 
 ---
 
 {chunk}
 
 ---
-Provide a structured response with clear medical reasoning.
+Respond in well-structured bullet points with medical reasoning.
 """
 
-
-def clean_and_rewrite_tool_file(original_path: str, cleaned_path: str) -> bool:
-    try:
-        with open(original_path, "r") as f:
-            data = json.load(f)
-        if isinstance(data, dict) and "tools" in data:
-            tools = data["tools"]
-        elif isinstance(data, list):
-            tools = data
-        elif isinstance(data, dict) and "name" in data:
-            tools = [data]
-        else:
-            return False
-        if not all(isinstance(t, dict) and "name" in t for t in tools):
-            return False
-        with open(cleaned_path, "w") as out:
-            json.dump(tools, out)
-        return True
-    except Exception as e:
-        logger.error(f"Failed to clean tool {original_path}: {e}")
-        return False
-
-
-def init_agent() -> TxAgent:
-    new_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
-    if not os.path.exists(new_tool_path):
-        with open(new_tool_path, 'w') as f:
-            json.dump([{"name": "dummy_tool", "description": "test", "version": "1.0"}], f)
-
-    raw_tool_files = {
-        'opentarget': '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/opentarget_tools.json',
-        'fda_drug_label': '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/fda_drug_labeling_tools.json',
-        'special_tools': '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/special_tools.json',
-        'monarch': '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/monarch_tools.json',
-        'new_tool': new_tool_path
-    }
-
-    validated_paths = {}
-    for name, original_path in raw_tool_files.items():
-        cleaned_path = os.path.join(tool_cache_dir, f"{name}_cleaned.json")
-        if clean_and_rewrite_tool_file(original_path, cleaned_path):
-            validated_paths[name] = cleaned_path
-
-    if not validated_paths:
-        raise ValueError("No valid tools found after sanitizing.")
-
+def init_agent():
+    tool_path = os.path.join(tool_cache_dir, "new_tool.json")
+    if not os.path.exists(tool_path):
+        shutil.copy(os.path.abspath("data/new_tool.json"), tool_path)
     agent = TxAgent(
         model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
         rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
-        tool_files_dict=validated_paths,
+        tool_files_dict={"new_tool": tool_path},
         force_finish=True,
         enable_checker=True,
         step_rag_num=4,
@@ -173,65 +108,120 @@ def init_agent() -> TxAgent:
     agent.init_model()
     return agent
 
-
-def stream_report(agent: TxAgent, input_file: Union[str, Dict[str, Any]], full_output: str) -> Generator[Tuple[str, Union[str, None], str], None, None]:
-    accumulated = ""
+def stream_report(agent, input_file: Union[str, 'file'], full_output: str) -> Generator[Tuple[str, Union[str, None], str], None, None]:
+    accumulated_text = ""
     try:
         if input_file is None:
-            yield "❌ Upload a valid Excel file.", None, ""
+            yield "❌ Please upload a valid Excel file.", None, ""
             return
-        text = extract_text_from_excel(input_file)
+
+        if hasattr(input_file, "read"):
+            text = extract_text_from_excel(input_file)
+        elif isinstance(input_file, str) and os.path.exists(input_file):
+            text = extract_text_from_excel(input_file)
+        else:
+            raise ValueError("❌ Invalid or missing file.")
+
         chunks = split_text_into_chunks(text)
+
         for i, chunk in enumerate(chunks):
             prompt = build_prompt_from_text(chunk)
-            result = ""
-            for out in agent.run_gradio_chat(
+            partial = ""
+            for res in agent.run_gradio_chat(
                 message=prompt, history=[], temperature=0.2,
                 max_new_tokens=MAX_NEW_TOKENS, max_token=MAX_MODEL_TOKENS,
-                call_agent=False, conversation=[]):
-                result += out if isinstance(out, str) else out.content
-            cleaned = clean_response(result)
-            accumulated += f"\n\n📄 Part {i+1}:\n{cleaned}"
-            yield accumulated, None, ""
-        summary_prompt = f"Summarize this analysis:\n\n{accumulated}"
-        summary = ""
-        for out in agent.run_gradio_chat(
+                call_agent=False, conversation=[]
+            ):
+                if isinstance(res, str):
+                    partial += res
+                elif hasattr(res, "content"):
+                    partial += res.content
+            cleaned = clean_response(partial)
+            accumulated_text += f"\n\n📄 **Chunk {i+1}**:\n{cleaned}"
+            yield accumulated_text, None, ""
+
+        summary_prompt = f"Summarize this analysis in a final structured report:\n\n" + accumulated_text
+        final_report = ""
+        for res in agent.run_gradio_chat(
             message=summary_prompt, history=[], temperature=0.2,
             max_new_tokens=MAX_NEW_TOKENS, max_token=MAX_MODEL_TOKENS,
-            call_agent=False, conversation=[]):
-            summary += out if isinstance(out, str) else out.content
-        final = clean_response(summary)
+            call_agent=False, conversation=[]
+        ):
+            if isinstance(res, str):
+                final_report += res
+            elif hasattr(res, "content"):
+                final_report += res.content
+
+        cleaned = clean_response(final_report)
+        accumulated_text += f"\n\n📊 **Final Summary**:\n{cleaned}"
         report_path = os.path.join(report_dir, f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md")
         with open(report_path, 'w') as f:
-            f.write(f"# Clinical Report\n\n{final}")
-        yield f"{accumulated}\n\n📊 Final Summary:\n{final}", report_path, final
+            f.write(f"# 🧠 Final Patient Report\n\n{cleaned}")
+
+        yield accumulated_text, report_path, cleaned
+
     except Exception as e:
-        logger.error(f"Stream error: {e}", exc_info=True)
         yield f"❌ Error: {str(e)}", None, ""
 
-def create_ui(agent: TxAgent) -> gr.Blocks:
-    with gr.Blocks(theme=gr.themes.Soft()) as demo:
-        gr.Markdown("# 🏥 Clinical Records Analyzer")
-        with gr.Row():
-            file_upload = gr.File(label="Upload Excel File", file_types=[".xlsx"])
-            analyze_btn = gr.Button("Analyze", variant="primary")
-        with gr.Row():
-            with gr.Column(scale=2):
-                report_output = gr.Markdown()
-            with gr.Column(scale=1):
-                report_file = gr.File(label="Download", visible=False)
-        full_output = gr.State()
-        analyze_btn.click(fn=stream_report, inputs=[file_upload, full_output], outputs=[report_output, report_file, full_output])
+def create_ui(agent):
+    with gr.Blocks(css="""
+    body {
+        background: #10141f;
+        color: #ffffff;
+        font-family: 'Inter', sans-serif;
+        margin: 0;
+        padding: 0;
+    }
+    .gradio-container {
+        padding: 30px;
+        width: 100vw;
+        max-width: 100%;
+        border-radius: 0;
+        background-color: #1a1f2e;
+    }
+    .output-markdown {
+        background-color: #131720;
+        border-radius: 12px;
+        padding: 20px;
+        min-height: 600px;
+        overflow-y: auto;
+        border: 1px solid #2c3344;
+    }
+    .gr-button {
+        background: linear-gradient(135deg, #4b4ced, #37b6e9);
+        color: white;
+        font-weight: 500;
+        border: none;
+        padding: 10px 20px;
+        border-radius: 8px;
+        transition: background 0.3s ease;
+    }
+    .gr-button:hover {
+        background: linear-gradient(135deg, #37b6e9, #4b4ced);
+    }
+    """) as demo:
+        gr.Markdown("""# 🧠 Clinical Reasoning Assistant
+Upload clinical Excel records below and click **Analyze** to generate a medical summary.
+""")
+        file_upload = gr.File(label="Upload Excel File", file_types=[".xlsx"])
+        analyze_btn = gr.Button("Analyze")
+        report_output_markdown = gr.Markdown(elem_classes="output-markdown")
+        report_file = gr.File(label="Download Report", visible=False)
+        full_output = gr.State(value="")
+
+        analyze_btn.click(
+            fn=stream_report,
+            inputs=[file_upload, full_output],
+            outputs=[report_output_markdown, report_file, full_output]
+        )
     return demo
 
-
 if __name__ == "__main__":
     try:
         agent = init_agent()
         demo = create_ui(agent)
-        demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
+        demo.launch(server_name="0.0.0.0", server_port=7860, allowed_paths=["/data/hf_cache/reports"], share=True)
     except Exception as e:
-        logger.error(f"App error: {e}", exc_info=True)
-        print(f"❌ Application error: {e}", file=sys.stderr)
+        print(f"Error: {str(e)}")
         sys.exit(1)
 
 
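For reference: the rewritten extract_text_from_excel hands file_obj straight to pd.ExcelFile, which accepts paths and file-like objects alike, and that is what lets stream_report pass either an uploaded file handle or a path string. A minimal sketch of the same flattening against an in-memory workbook (assumes pandas with openpyxl installed; the sample rows are invented):

```python
import io
import pandas as pd

# Build a one-sheet workbook in memory, then flatten it the way the commit does:
# stringify cells, join non-empty cells with " | ", tag each row with its sheet name.
buf = io.BytesIO()
pd.DataFrame({"patient": ["A-12"], "note": ["metformin 500mg BID"]}).to_excel(buf, index=False)
buf.seek(0)

xls = pd.ExcelFile(buf)  # pd.ExcelFile takes paths or file-like objects alike
for sheet_name in xls.sheet_names:
    df = xls.parse(sheet_name).astype(str).fillna("")
    rows = df.apply(lambda row: " | ".join([cell for cell in row if cell.strip()]), axis=1)
    print([f"[{sheet_name}] {line}" for line in rows if line.strip()])
# -> ['[Sheet1] A-12 | metformin 500mg BID']
```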
 
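The new split_text_into_chunks bounds both chunk size (max_tokens minus PROMPT_OVERHEAD) and chunk count (max_chunks=30), so an arbitrarily large workbook can no longer trigger an unbounded number of model calls. A self-contained sketch of the same logic with the constants inlined; note that len(text) // 3.5 floor-divides by a float and therefore returns a float, so the sketch wraps it in int() to honor the declared return type:

```python
MAX_CHUNK_TOKENS = 8192
PROMPT_OVERHEAD = 500

def estimate_tokens(text: str) -> int:
    # ~3.5 characters per token is the commit's own heuristic; int() repairs the
    # float that // with a float divisor would otherwise return.
    return int(len(text) // 3.5) + 1

def split_text_into_chunks(text: str, max_tokens: int = MAX_CHUNK_TOKENS, max_chunks: int = 30):
    effective_max = max_tokens - PROMPT_OVERHEAD
    chunks, curr_chunk, curr_tokens = [], [], 0
    for line in text.split("\n"):
        t = estimate_tokens(line)
        if curr_tokens + t > effective_max:
            if curr_chunk:
                chunks.append("\n".join(curr_chunk))
            if len(chunks) >= max_chunks:  # hard cap: stop chunking entirely
                break
            curr_chunk, curr_tokens = [line], t
        else:
            curr_chunk.append(line)
            curr_tokens += t
    if curr_chunk and len(chunks) < max_chunks:
        chunks.append("\n".join(curr_chunk))
    return chunks

# A 100,000-line input collapses to at most 30 chunks:
print(len(split_text_into_chunks("row | value\n" * 100_000)))  # 30
```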
 
 
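init_agent no longer sanitizes the tooluniverse JSON files; it assumes the Space ships a data/new_tool.json and mirrors it into the persistent tool cache. A hypothetical minimal version of that file, modeled on the dummy entry the old code generated (the real file may well carry more fields):

```python
import json
import os
import shutil

# Hypothetical minimal data/new_tool.json: a list of tool dicts, each carrying
# at least a "name" (the shape the old validation path required).
os.makedirs("data", exist_ok=True)
with open("data/new_tool.json", "w") as f:
    json.dump([{"name": "dummy_tool", "description": "test", "version": "1.0"}], f)

# The commit's copy step, with a stand-in cache directory:
tool_cache_dir = "/tmp/tool_cache"  # stand-in for /data/hf_cache/tool_cache
os.makedirs(tool_cache_dir, exist_ok=True)
shutil.copy(os.path.abspath("data/new_tool.json"),
            os.path.join(tool_cache_dir, "new_tool.json"))
```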
 
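stream_report is a generator that yields (accumulated_markdown, report_path, final_text) triples; report_path stays None until the final yield, which is how the UI knows when the download is ready. A hypothetical way to drive it outside Gradio, assuming init_agent() succeeds and a local records.xlsx exists:

```python
# Hypothetical headless driver; analyze_btn.click consumes the same triples.
agent = init_agent()

report = None
for markdown, report_path, final_text in stream_report(agent, "records.xlsx", ""):
    print(markdown[-200:])        # tail of the accumulated analysis so far
    if report_path is not None:   # only the last yield carries the file path
        report = report_path

print("report written to:", report)
```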
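Finally, the new launch call adds allowed_paths: Gradio only serves files from whitelisted directories (plus its own temp dir), so without it the gr.File download pointing at /data/hf_cache/reports would be refused. A minimal sketch of the pattern, with share=True additionally requesting a public tunnel as in the commit:

```python
import gradio as gr

# Whitelist the report directory so the gr.File component can serve files
# written there by stream_report.
with gr.Blocks() as demo:
    gr.File(label="Download Report")

demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
    allowed_paths=["/data/hf_cache/reports"],
    share=True,
)
```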