Ali2206 committed
Commit f6e551c · verified · 1 Parent(s): 1bdb280

Update app.py

Files changed (1):
  1. app.py +130 -56
app.py CHANGED
@@ -1,6 +1,7 @@
 import sys
 import os
 import pandas as pd
+import json
 import gradio as gr
 from typing import List, Tuple, Dict, Any
 import hashlib
@@ -10,15 +11,20 @@ from datetime import datetime
 import time
 from collections import defaultdict
 
-# Configuration - Use paths that Gradio can access
-WORKING_DIR = os.getcwd()
-REPORT_DIR = os.path.join(WORKING_DIR, "reports")
-os.makedirs(REPORT_DIR, exist_ok=True)
-
-# Model configuration
-MODEL_CACHE_DIR = os.path.join(WORKING_DIR, "model_cache")
-os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
-os.environ["HF_HOME"] = MODEL_CACHE_DIR
+# Configuration and setup
+persistent_dir = "/data/hf_cache"
+os.makedirs(persistent_dir, exist_ok=True)
+
+model_cache_dir = os.path.join(persistent_dir, "txagent_models")
+tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
+file_cache_dir = os.path.join(persistent_dir, "cache")
+report_dir = os.path.join(persistent_dir, "reports")
+
+for directory in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
+    os.makedirs(directory, exist_ok=True)
+
+os.environ["HF_HOME"] = model_cache_dir
+os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
 
 current_dir = os.path.dirname(os.path.abspath(__file__))
 src_path = os.path.abspath(os.path.join(current_dir, "src"))
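A note on the cache hunk above: huggingface_hub and transformers pick these variables up when they are first imported, so the assignments only take effect because they run before the TxAgent import further down in the file. Recent transformers releases also warn that TRANSFORMERS_CACHE is deprecated in favor of HF_HOME. A minimal sketch of the pattern (the model name is an arbitrary example, not from this commit):

```python
import os

# Redirect all Hugging Face caches to a persistent volume.
# This must run before transformers/huggingface_hub are imported,
# because they read the environment when first loaded.
os.environ["HF_HOME"] = "/data/hf_cache/txagent_models"

from transformers import AutoTokenizer

# The download now lands under /data/hf_cache/txagent_models
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
```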
@@ -30,20 +36,34 @@ from txagent.txagent import TxAgent
 MAX_TOKENS = 32768
 CHUNK_SIZE = 10000
 MAX_NEW_TOKENS = 2048
+MAX_BOOKINGS_PER_CHUNK = 5
+
+def file_hash(path: str) -> str:
+    with open(path, "rb") as f:
+        return hashlib.md5(f.read()).hexdigest()
 
 def clean_response(text: str) -> str:
-    """Clean and normalize text output"""
+    try:
+        text = text.encode('utf-8', 'surrogatepass').decode('utf-8')
+    except UnicodeError:
+        text = text.encode('utf-8', 'replace').decode('utf-8')
+
     text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
     text = re.sub(r"\n{3,}", "\n\n", text)
+    text = re.sub(r"[^\n#\-\*\w\s\.,:\(\)]+", "", text)
     return text.strip()
 
+def estimate_tokens(text: str) -> int:
+    return len(text) // 3.5
+
 def process_patient_data(df: pd.DataFrame) -> Dict[str, Any]:
-    """Process patient data into structured format"""
     data = {
         'bookings': defaultdict(list),
         'medications': defaultdict(list),
         'diagnoses': defaultdict(list),
         'tests': defaultdict(list),
+        'procedures': defaultdict(list),
+        'doctors': set(),
         'timeline': []
     }
 
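One wart in the hunk above: `len(text) // 3.5` floor-divides by a float, so `estimate_tokens` actually returns a float (e.g. `10.0`) despite its `-> int` annotation. The size sums in the new `chunk_bookings` still compare correctly, but a version that honors the annotation is a one-line change:

```python
def estimate_tokens(text: str) -> int:
    # ~3.5 characters per token is a rough heuristic for English text.
    # int() makes the return type match the annotation, since
    # len(text) // 3.5 evaluates to a float in Python.
    return int(len(text) / 3.5)
```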
@@ -62,62 +82,100 @@ def process_patient_data(df: pd.DataFrame) -> Dict[str, Any]:
 
         data['bookings'][booking].append(entry)
         data['timeline'].append(entry)
+        data['doctors'].add(entry['doctor'])
 
         form_lower = entry['form'].lower()
-        if 'medication' in form_lower:
+        if 'medication' in form_lower or 'drug' in form_lower:
             data['medications'][entry['item']].append(entry)
-        elif 'diagnosis' in form_lower:
+        elif 'diagnosis' in form_lower or 'condition' in form_lower:
             data['diagnoses'][entry['item']].append(entry)
-        elif 'test' in form_lower:
+        elif 'test' in form_lower or 'lab' in form_lower or 'result' in form_lower:
             data['tests'][entry['item']].append(entry)
+        elif 'procedure' in form_lower or 'surgery' in form_lower:
+            data['procedures'][entry['item']].append(entry)
 
     return data
 
 def generate_analysis_prompt(patient_data: Dict[str, Any], bookings: List[str]) -> str:
-    """Generate analysis prompt for a set of bookings"""
-    prompt = [
+    prompt_lines = [
         "**Comprehensive Patient Analysis**",
         f"Analyzing {len(bookings)} bookings",
         "",
-        "**Timeline:**"
+        "**Key Analysis Points:**",
+        "- Chronological progression of symptoms",
+        "- Medication changes and interactions",
+        "- Diagnostic consistency across providers",
+        "- Missed diagnostic opportunities",
+        "- Gaps in follow-up",
+        "",
+        "**Patient Timeline:**"
     ]
 
     for entry in patient_data['timeline']:
         if entry['booking'] in bookings:
-            prompt.append(f"- {entry['date']}: {entry['form']} - {entry['item']} = {entry['response']}")
+            prompt_lines.append(
+                f"- {entry['date']}: {entry['form']} - {entry['item']} = {entry['response']} (by {entry['doctor']})"
+            )
 
-    prompt.extend([
+    prompt_lines.extend([
         "",
-        "**Analysis Focus:**",
-        "1. Identify missed diagnoses",
-        "2. Check medication conflicts",
-        "3. Note incomplete assessments",
-        "4. Flag urgent follow-ups",
+        "**Medication History:**",
+        *[f"- {med}: " + " → ".join(
+            f"{e['date']}: {e['response']}"
+            for e in entries if e['booking'] in bookings
+        ) for med, entries in patient_data['medications'].items()],
         "",
-        "### Findings"
+        "**Required Analysis Format:**",
+        "### Diagnostic Patterns",
+        "### Medication Analysis",
+        "### Provider Consistency",
+        "### Missed Opportunities",
+        "### Recommendations"
     ])
 
-    return "\n".join(prompt)
+    return "\n".join(prompt_lines)
+
+def chunk_bookings(patient_data: Dict[str, Any]) -> List[List[str]]:
+    all_bookings = list(patient_data['bookings'].keys())
+    booking_sizes = []
+
+    for booking in all_bookings:
+        entries = patient_data['bookings'][booking]
+        size = sum(estimate_tokens(str(e)) for e in entries)
+        booking_sizes.append((booking, size))
+
+    booking_sizes.sort(key=lambda x: x[1], reverse=True)
+    chunks = [[] for _ in range(3)]
+    chunk_sizes = [0, 0, 0]
+
+    for booking, size in booking_sizes:
+        min_chunk = chunk_sizes.index(min(chunk_sizes))
+        chunks[min_chunk].append(booking)
+        chunk_sizes[min_chunk] += size
+
+    return chunks
 
 def init_agent():
-    """Initialize TxAgent with proper configuration"""
-    tool_path = os.path.join(WORKING_DIR, "data", "new_tool.json")
-    if not os.path.exists(tool_path):
-        raise FileNotFoundError(f"Tool file not found at {tool_path}")
+    default_tool_path = os.path.abspath("data/new_tool.json")
+    target_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
+
+    if not os.path.exists(target_tool_path):
+        shutil.copy(default_tool_path, target_tool_path)
 
-    return TxAgent(
+    agent = TxAgent(
         model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
         rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
-        tool_files_dict={"new_tool": tool_path},
+        tool_files_dict={"new_tool": target_tool_path},
         force_finish=True,
         enable_checker=True,
         step_rag_num=4,
         seed=100,
         additional_default_tools=[]
     )
+    agent.init_model()
+    return agent
 
 def analyze_with_agent(agent, prompt: str) -> str:
-    """Run analysis with error handling"""
     try:
         response = ""
         for result in agent.run_gradio_chat(
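The new `chunk_bookings` in the hunk above is a greedy largest-first bin packing over three fixed chunks (note that the new `MAX_BOOKINGS_PER_CHUNK` constant is defined but not consulted here). A self-contained sketch of the same balancing logic, with made-up token counts:

```python
# Hypothetical booking sizes (estimated tokens), as chunk_bookings would compute.
sizes = {"B1": 900, "B2": 700, "B3": 650, "B4": 300, "B5": 250}

chunks = [[] for _ in range(3)]
loads = [0, 0, 0]
# Largest first, each booking into the currently lightest chunk.
for booking, size in sorted(sizes.items(), key=lambda kv: kv[1], reverse=True):
    i = loads.index(min(loads))
    chunks[i].append(booking)
    loads[i] += size

print(chunks)  # [['B1'], ['B2', 'B5'], ['B3', 'B4']]
print(loads)   # [900, 950, 950]
```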
@@ -129,7 +187,11 @@ def analyze_with_agent(agent, prompt: str) -> str:
             call_agent=False,
             conversation=[],
         ):
-            if isinstance(result, str):
+            if isinstance(result, list):
+                for r in result:
+                    if hasattr(r, 'content') and r.content:
+                        response += clean_response(r.content) + "\n"
+            elif isinstance(result, str):
                 response += clean_response(result) + "\n"
             elif hasattr(result, 'content'):
                 response += clean_response(result.content) + "\n"
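The streaming loop above now handles three shapes of value from `run_gradio_chat`: a list of message objects, a plain string, or a single object with a `.content` attribute. If that branching grows further, it could be factored out; a hypothetical helper (the name and factoring are mine, not the commit's), mirroring the branches in `analyze_with_agent`:

```python
def extract_text(result) -> str:
    # Flatten whatever run_gradio_chat yields into plain text.
    if isinstance(result, list):
        return "\n".join(r.content for r in result
                         if hasattr(r, 'content') and r.content)
    if isinstance(result, str):
        return result
    return getattr(result, 'content', "") or ""
```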
@@ -139,64 +201,76 @@ def analyze_with_agent(agent, prompt: str) -> str:
         return f"Error in analysis: {str(e)}"
 
 def create_ui(agent):
-    with gr.Blocks(title="Patient History Analyzer") as demo:
-        gr.Markdown("# 🏥 Patient History Analysis")
+    with gr.Blocks(theme=gr.themes.Soft(), title="Patient History Analyzer") as demo:
+        gr.Markdown("# 🏥 Patient History Analyzer")
 
         with gr.Tabs():
-            with gr.TabItem("Analyze"):
+            with gr.TabItem("Analysis"):
                 with gr.Row():
-                    with gr.Column():
-                        file_input = gr.File(label="Upload Excel File", file_types=[".xlsx"])
+                    with gr.Column(scale=1):
+                        file_upload = gr.File(
+                            label="Upload Excel File",
+                            file_types=[".xlsx"],
+                            file_count="single"
+                        )
                         analyze_btn = gr.Button("Analyze", variant="primary")
+                        status = gr.Markdown("Ready")
 
-                    with gr.Column():
+                    with gr.Column(scale=2):
                         output = gr.Markdown()
-                        report = gr.File(label="Download Report", interactive=False)
+                        report = gr.File(label="Download Report")
 
             with gr.TabItem("Instructions"):
                 gr.Markdown("""
-                **How to Use:**
+                ## How to Use
                 1. Upload patient history Excel
                 2. Click Analyze
-                3. View and download report
+                3. View/download report
 
                 **Required Columns:**
                 - Booking Number
                 - Interview Date
                 - Interviewer
                 - Form Name
-                - Form Item
+                - Form Item
                 - Item Response
                 - Description
                 """)
 
         def analyze(file):
             if not file:
-                raise gr.Error("Please upload a file first")
+                raise gr.Error("Please upload a file")
 
             try:
-                # Process file
                 df = pd.read_excel(file.name)
                 patient_data = process_patient_data(df)
+                chunks = chunk_bookings(patient_data)
+                full_report = []
+
+                for i, bookings in enumerate(chunks, 1):
+                    prompt = generate_analysis_prompt(patient_data, bookings)
+                    response = analyze_with_agent(agent, prompt)
+                    full_report.append(f"## Chunk {i}\n{response}\n")
+                    yield "\n".join(full_report), None
 
-                # Analyze all bookings together (fits within 32k tokens)
-                prompt = generate_analysis_prompt(patient_data, list(patient_data['bookings'].keys()))
-                analysis = analyze_with_agent(agent, prompt)
+                # Final summary
+                if len(chunks) > 1:
+                    summary_prompt = "Create final summary combining all chunks"
+                    summary = analyze_with_agent(agent, summary_prompt)
+                    full_report.append(f"## Final Summary\n{summary}\n")
 
-                # Save report to allowed directory
-                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-                report_path = os.path.join(REPORT_DIR, f"report_{timestamp}.md")
+                report_path = os.path.join(report_dir, f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md")
                 with open(report_path, 'w') as f:
-                    f.write(analysis)
+                    f.write("\n".join(full_report))
 
-                return analysis, report_path
+                yield "\n".join(full_report), report_path
 
             except Exception as e:
-                raise gr.Error(f"Analysis failed: {str(e)}")
+                raise gr.Error(f"Error: {str(e)}")
 
         analyze_btn.click(
             analyze,
-            inputs=file_input,
+            inputs=file_upload,
             outputs=[output, report]
         )
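Worth noting about the hunk above: `analyze` changed from a plain function (`return analysis, report_path`) into a generator that `yield`s after every chunk. Gradio supports generator callbacks, so each yield re-renders the Markdown output in place, and the download file only appears with the final yield. A self-contained sketch of the same streaming pattern:

```python
import time
import gradio as gr

def slow_steps(n):
    lines = []
    for i in range(1, int(n) + 1):
        time.sleep(1)                  # stand-in for one chunk's analysis
        lines.append(f"## Step {i} done")
        yield "\n".join(lines)         # each yield updates the UI

with gr.Blocks() as demo:
    n = gr.Number(value=3, label="Steps")
    btn = gr.Button("Run")
    out = gr.Markdown()
    btn.click(slow_steps, inputs=n, outputs=out)

# demo.launch()
```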
@@ -210,7 +284,7 @@ if __name__ == "__main__":
            server_name="0.0.0.0",
            server_port=7860,
            show_error=True,
-            allowed_paths=[WORKING_DIR, REPORT_DIR]  # Allow access to these paths
+            allowed_paths=["/data/hf_cache/reports"]
        )
    except Exception as e:
        print(f"Error: {str(e)}")
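`allowed_paths` tells Gradio which directories its file server may expose, which is what lets the Download Report `gr.File` component serve files from the persistent volume. Since `report_dir` already holds the same path, passing the variable instead of a string literal would keep the two from drifting apart; an equivalent sketch:

```python
demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
    show_error=True,
    allowed_paths=[report_dir],  # report_dir == "/data/hf_cache/reports"
)
```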
 