Ali2206 committed on
Commit
1bdb280
·
verified ·
1 Parent(s): 13ad0d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -130
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import sys
2
  import os
3
  import pandas as pd
4
- import json
5
  import gradio as gr
6
  from typing import List, Tuple, Dict, Any
7
  import hashlib
@@ -11,20 +10,15 @@ from datetime import datetime
11
  import time
12
  from collections import defaultdict
13
 
14
- # Configuration and setup
15
- persistent_dir = "/data/hf_cache"
16
- os.makedirs(persistent_dir, exist_ok=True)
 
17
 
18
- model_cache_dir = os.path.join(persistent_dir, "txagent_models")
19
- tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
20
- file_cache_dir = os.path.join(persistent_dir, "cache")
21
- report_dir = os.path.join(persistent_dir, "reports")
22
-
23
- for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
24
- os.makedirs(directory, exist_ok=True)
25
-
26
- os.environ["HF_HOME"] = model_cache_dir
27
- os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
28
 
29
  current_dir = os.path.dirname(os.path.abspath(__file__))
30
  src_path = os.path.abspath(os.path.join(current_dir, "src"))
@@ -36,34 +30,20 @@ from txagent.txagent import TxAgent
36
  MAX_TOKENS = 32768
37
  CHUNK_SIZE = 10000
38
  MAX_NEW_TOKENS = 2048
39
- MAX_BOOKINGS_PER_CHUNK = 5
40
-
41
- def file_hash(path: str) -> str:
42
- with open(path, "rb") as f:
43
- return hashlib.md5(f.read()).hexdigest()
44
 
45
  def clean_response(text: str) -> str:
46
- try:
47
- text = text.encode('utf-8', 'surrogatepass').decode('utf-8')
48
- except UnicodeError:
49
- text = text.encode('utf-8', 'replace').decode('utf-8')
50
-
51
  text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
52
  text = re.sub(r"\n{3,}", "\n\n", text)
53
- text = re.sub(r"[^\n#\-\*\w\s\.,:\(\)]+", "", text)
54
  return text.strip()
55
 
56
- def estimate_tokens(text: str) -> int:
57
- return len(text) // 3.5
58
-
59
  def process_patient_data(df: pd.DataFrame) -> Dict[str, Any]:
 
60
  data = {
61
  'bookings': defaultdict(list),
62
  'medications': defaultdict(list),
63
  'diagnoses': defaultdict(list),
64
  'tests': defaultdict(list),
65
- 'procedures': defaultdict(list),
66
- 'doctors': set(),
67
  'timeline': []
68
  }
69
 
@@ -82,100 +62,62 @@ def process_patient_data(df: pd.DataFrame) -> Dict[str, Any]:
82
 
83
  data['bookings'][booking].append(entry)
84
  data['timeline'].append(entry)
85
- data['doctors'].add(entry['doctor'])
86
 
87
  form_lower = entry['form'].lower()
88
- if 'medication' in form_lower or 'drug' in form_lower:
89
  data['medications'][entry['item']].append(entry)
90
- elif 'diagnosis' in form_lower or 'condition' in form_lower:
91
  data['diagnoses'][entry['item']].append(entry)
92
- elif 'test' in form_lower or 'lab' in form_lower or 'result' in form_lower:
93
  data['tests'][entry['item']].append(entry)
94
- elif 'procedure' in form_lower or 'surgery' in form_lower:
95
- data['procedures'][entry['item']].append(entry)
96
 
97
  return data
98
 
99
  def generate_analysis_prompt(patient_data: Dict[str, Any], bookings: List[str]) -> str:
100
- prompt_lines = [
 
101
  "**Comprehensive Patient Analysis**",
102
  f"Analyzing {len(bookings)} bookings",
103
  "",
104
- "**Key Analysis Points:**",
105
- "- Chronological progression of symptoms",
106
- "- Medication changes and interactions",
107
- "- Diagnostic consistency across providers",
108
- "- Missed diagnostic opportunities",
109
- "- Gaps in follow-up",
110
- "",
111
- "**Patient Timeline:**"
112
  ]
113
 
114
  for entry in patient_data['timeline']:
115
  if entry['booking'] in bookings:
116
- prompt_lines.append(
117
- f"- {entry['date']}: {entry['form']} - {entry['item']} = {entry['response']} (by {entry['doctor']})"
118
- )
119
 
120
- prompt_lines.extend([
121
  "",
122
- "**Medication History:**",
123
- *[f"- {med}: " + " → ".join(
124
- f"{e['date']}: {e['response']}"
125
- for e in entries if e['booking'] in bookings
126
- ) for med, entries in patient_data['medications'].items()],
127
  "",
128
- "**Required Analysis Format:**",
129
- "### Diagnostic Patterns",
130
- "### Medication Analysis",
131
- "### Provider Consistency",
132
- "### Missed Opportunities",
133
- "### Recommendations"
134
  ])
135
 
136
- return "\n".join(prompt_lines)
137
-
138
- def chunk_bookings(patient_data: Dict[str, Any]) -> List[List[str]]:
139
- all_bookings = list(patient_data['bookings'].keys())
140
- booking_sizes = []
141
-
142
- for booking in all_bookings:
143
- entries = patient_data['bookings'][booking]
144
- size = sum(estimate_tokens(str(e)) for e in entries)
145
- booking_sizes.append((booking, size))
146
-
147
- booking_sizes.sort(key=lambda x: x[1], reverse=True)
148
- chunks = [[] for _ in range(3)]
149
- chunk_sizes = [0, 0, 0]
150
-
151
- for booking, size in booking_sizes:
152
- min_chunk = chunk_sizes.index(min(chunk_sizes))
153
- chunks[min_chunk].append(booking)
154
- chunk_sizes[min_chunk] += size
155
-
156
- return chunks
157
 
158
  def init_agent():
159
- default_tool_path = os.path.abspath("data/new_tool.json")
160
- target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
161
-
162
- if not os.path.exists(target_tool_path):
163
- shutil.copy(default_tool_path, target_tool_path)
164
 
165
- agent = TxAgent(
166
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
167
  rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
168
- tool_files_dict={"new_tool": target_tool_path},
169
  force_finish=True,
170
  enable_checker=True,
171
  step_rag_num=4,
172
  seed=100,
173
  additional_default_tools=[]
174
  )
175
- agent.init_model()
176
- return agent
177
 
178
  def analyze_with_agent(agent, prompt: str) -> str:
 
179
  try:
180
  response = ""
181
  for result in agent.run_gradio_chat(
@@ -187,11 +129,7 @@ def analyze_with_agent(agent, prompt: str) -> str:
187
  call_agent=False,
188
  conversation=[],
189
  ):
190
- if isinstance(result, list):
191
- for r in result:
192
- if hasattr(r, 'content') and r.content:
193
- response += clean_response(r.content) + "\n"
194
- elif isinstance(result, str):
195
  response += clean_response(result) + "\n"
196
  elif hasattr(result, 'content'):
197
  response += clean_response(result.content) + "\n"
@@ -201,76 +139,64 @@ def analyze_with_agent(agent, prompt: str) -> str:
201
  return f"Error in analysis: {str(e)}"
202
 
203
  def create_ui(agent):
204
- with gr.Blocks(theme=gr.themes.Soft(), title="Patient History Analyzer") as demo:
205
- gr.Markdown("# 🏥 Patient History Analyzer")
206
 
207
  with gr.Tabs():
208
- with gr.TabItem("Analysis"):
209
  with gr.Row():
210
- with gr.Column(scale=1):
211
- file_upload = gr.File(
212
- label="Upload Excel File",
213
- file_types=[".xlsx"],
214
- file_count="single"
215
- )
216
  analyze_btn = gr.Button("Analyze", variant="primary")
217
- status = gr.Markdown("Ready")
218
 
219
- with gr.Column(scale=2):
220
  output = gr.Markdown()
221
- report = gr.File(label="Download Report")
222
 
223
  with gr.TabItem("Instructions"):
224
  gr.Markdown("""
225
- ## How to Use
226
  1. Upload patient history Excel
227
  2. Click Analyze
228
- 3. View/download report
229
 
230
  **Required Columns:**
231
  - Booking Number
232
  - Interview Date
233
  - Interviewer
234
  - Form Name
235
- - Form Item
236
  - Item Response
237
  - Description
238
  """)
239
 
240
  def analyze(file):
241
  if not file:
242
- raise gr.Error("Please upload a file")
243
 
244
  try:
 
245
  df = pd.read_excel(file.name)
246
  patient_data = process_patient_data(df)
247
- chunks = chunk_bookings(patient_data)
248
- full_report = []
249
-
250
- for i, bookings in enumerate(chunks, 1):
251
- prompt = generate_analysis_prompt(patient_data, bookings)
252
- response = analyze_with_agent(agent, prompt)
253
- full_report.append(f"## Chunk {i}\n{response}\n")
254
- yield "\n".join(full_report), None
255
 
256
- # Final summary
257
- if len(chunks) > 1:
258
- summary_prompt = "Create final summary combining all chunks"
259
- summary = analyze_with_agent(agent, summary_prompt)
260
- full_report.append(f"## Final Summary\n{summary}\n")
261
 
262
- report_path = os.path.join(report_dir, f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md")
 
 
263
  with open(report_path, 'w') as f:
264
- f.write("\n".join(full_report))
265
 
266
- yield "\n".join(full_report), report_path
267
 
268
  except Exception as e:
269
- raise gr.Error(f"Error: {str(e)}")
270
 
271
  analyze_btn.click(
272
  analyze,
273
- inputs=file_upload,
274
  outputs=[output, report]
275
  )
276
 
@@ -283,7 +209,8 @@ if __name__ == "__main__":
283
  demo.launch(
284
  server_name="0.0.0.0",
285
  server_port=7860,
286
- show_error=True
 
287
  )
288
  except Exception as e:
289
  print(f"Error: {str(e)}")
 
1
  import sys
2
  import os
3
  import pandas as pd
 
4
  import gradio as gr
5
  from typing import List, Tuple, Dict, Any
6
  import hashlib
 
10
  import time
11
  from collections import defaultdict
12
 
13
+ # Configuration - Use paths that Gradio can access
14
+ WORKING_DIR = os.getcwd()
15
+ REPORT_DIR = os.path.join(WORKING_DIR, "reports")
16
+ os.makedirs(REPORT_DIR, exist_ok=True)
17
 
18
+ # Model configuration
19
+ MODEL_CACHE_DIR = os.path.join(WORKING_DIR, "model_cache")
20
+ os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
21
+ os.environ["HF_HOME"] = MODEL_CACHE_DIR
 
 
 
 
 
 
22
 
23
  current_dir = os.path.dirname(os.path.abspath(__file__))
24
  src_path = os.path.abspath(os.path.join(current_dir, "src"))
 
30
  MAX_TOKENS = 32768
31
  CHUNK_SIZE = 10000
32
  MAX_NEW_TOKENS = 2048
 
 
 
 
 
33
 
34
  def clean_response(text: str) -> str:
35
+ """Clean and normalize text output"""
 
 
 
 
36
  text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
37
  text = re.sub(r"\n{3,}", "\n\n", text)
 
38
  return text.strip()
39
 
 
 
 
40
  def process_patient_data(df: pd.DataFrame) -> Dict[str, Any]:
41
+ """Process patient data into structured format"""
42
  data = {
43
  'bookings': defaultdict(list),
44
  'medications': defaultdict(list),
45
  'diagnoses': defaultdict(list),
46
  'tests': defaultdict(list),
 
 
47
  'timeline': []
48
  }
49
 
 
62
 
63
  data['bookings'][booking].append(entry)
64
  data['timeline'].append(entry)
 
65
 
66
  form_lower = entry['form'].lower()
67
+ if 'medication' in form_lower:
68
  data['medications'][entry['item']].append(entry)
69
+ elif 'diagnosis' in form_lower:
70
  data['diagnoses'][entry['item']].append(entry)
71
+ elif 'test' in form_lower:
72
  data['tests'][entry['item']].append(entry)
 
 
73
 
74
  return data
75
 
76
  def generate_analysis_prompt(patient_data: Dict[str, Any], bookings: List[str]) -> str:
77
+ """Generate analysis prompt for a set of bookings"""
78
+ prompt = [
79
  "**Comprehensive Patient Analysis**",
80
  f"Analyzing {len(bookings)} bookings",
81
  "",
82
+ "**Timeline:**"
 
 
 
 
 
 
 
83
  ]
84
 
85
  for entry in patient_data['timeline']:
86
  if entry['booking'] in bookings:
87
+ prompt.append(f"- {entry['date']}: {entry['form']} - {entry['item']} = {entry['response']}")
 
 
88
 
89
+ prompt.extend([
90
  "",
91
+ "**Analysis Focus:**",
92
+ "1. Identify missed diagnoses",
93
+ "2. Check medication conflicts",
94
+ "3. Note incomplete assessments",
95
+ "4. Flag urgent follow-ups",
96
  "",
97
+ "### Findings"
 
 
 
 
 
98
  ])
99
 
100
+ return "\n".join(prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
  def init_agent():
103
+ """Initialize TxAgent with proper configuration"""
104
+ tool_path = os.path.join(WORKING_DIR, "data", "new_tool.json")
105
+ if not os.path.exists(tool_path):
106
+ raise FileNotFoundError(f"Tool file not found at {tool_path}")
 
107
 
108
+ return TxAgent(
109
  model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
110
  rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
111
+ tool_files_dict={"new_tool": tool_path},
112
  force_finish=True,
113
  enable_checker=True,
114
  step_rag_num=4,
115
  seed=100,
116
  additional_default_tools=[]
117
  )
 
 
118
 
119
  def analyze_with_agent(agent, prompt: str) -> str:
120
+ """Run analysis with error handling"""
121
  try:
122
  response = ""
123
  for result in agent.run_gradio_chat(
 
129
  call_agent=False,
130
  conversation=[],
131
  ):
132
+ if isinstance(result, str):
 
 
 
 
133
  response += clean_response(result) + "\n"
134
  elif hasattr(result, 'content'):
135
  response += clean_response(result.content) + "\n"
 
139
  return f"Error in analysis: {str(e)}"
140
 
141
  def create_ui(agent):
142
+ with gr.Blocks(title="Patient History Analyzer") as demo:
143
+ gr.Markdown("# 🏥 Patient History Analysis")
144
 
145
  with gr.Tabs():
146
+ with gr.TabItem("Analyze"):
147
  with gr.Row():
148
+ with gr.Column():
149
+ file_input = gr.File(label="Upload Excel File", file_types=[".xlsx"])
 
 
 
 
150
  analyze_btn = gr.Button("Analyze", variant="primary")
 
151
 
152
+ with gr.Column():
153
  output = gr.Markdown()
154
+ report = gr.File(label="Download Report", interactive=False)
155
 
156
  with gr.TabItem("Instructions"):
157
  gr.Markdown("""
158
+ **How to Use:**
159
  1. Upload patient history Excel
160
  2. Click Analyze
161
+ 3. View and download report
162
 
163
  **Required Columns:**
164
  - Booking Number
165
  - Interview Date
166
  - Interviewer
167
  - Form Name
168
+ - Form Item
169
  - Item Response
170
  - Description
171
  """)
172
 
173
  def analyze(file):
174
  if not file:
175
+ raise gr.Error("Please upload a file first")
176
 
177
  try:
178
+ # Process file
179
  df = pd.read_excel(file.name)
180
  patient_data = process_patient_data(df)
 
 
 
 
 
 
 
 
181
 
182
+ # Analyze all bookings together (fits within 32k tokens)
183
+ prompt = generate_analysis_prompt(patient_data, list(patient_data['bookings'].keys()))
184
+ analysis = analyze_with_agent(agent, prompt)
 
 
185
 
186
+ # Save report to allowed directory
187
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
188
+ report_path = os.path.join(REPORT_DIR, f"report_{timestamp}.md")
189
  with open(report_path, 'w') as f:
190
+ f.write(analysis)
191
 
192
+ return analysis, report_path
193
 
194
  except Exception as e:
195
+ raise gr.Error(f"Analysis failed: {str(e)}")
196
 
197
  analyze_btn.click(
198
  analyze,
199
+ inputs=file_input,
200
  outputs=[output, report]
201
  )
202
 
 
209
  demo.launch(
210
  server_name="0.0.0.0",
211
  server_port=7860,
212
+ show_error=True,
213
+ allowed_paths=[WORKING_DIR, REPORT_DIR] # Allow access to these paths
214
  )
215
  except Exception as e:
216
  print(f"Error: {str(e)}")