Ali2206 committed on
Commit
a8c9181
·
verified ·
1 Parent(s): a135a34

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +401 -28
app.py CHANGED
@@ -6,16 +6,20 @@ import re
6
  import gc
7
  import time
8
  from datetime import datetime
9
- from typing import List, Tuple, Dict, Union
10
  import pandas as pd
11
  import pdfplumber
12
- import gradio as gr
13
  import torch
14
  import matplotlib.pyplot as plt
15
  from fpdf import FPDF
16
  import unicodedata
 
 
 
 
17
 
18
  # === Configuration ===
 
19
  persistent_dir = "/data/hf_cache"
20
  model_cache_dir = os.path.join(persistent_dir, "txagent_models")
21
  tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
@@ -41,11 +45,45 @@ BATCH_SIZE = 1
41
  PROMPT_OVERHEAD = 300
42
  SAFE_SLEEP = 0.5
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
def estimate_tokens(text: str) -> int:
    """Return a rough token count for *text*.

    The estimate is one token per four characters, plus one so that the
    result is never zero (an empty string yields 1).
    """
    return len(text) // 4 + 1
46
 
47
def clean_response(text: str) -> str:
    """Strip bracketed spans and bare ``None`` tokens from model output.

    Anything between ``[`` and the next ``]`` is removed (the match may span
    newlines), as is the standalone word ``None``. Runs of three or more
    newlines are then collapsed to a single blank line and the result is
    stripped of leading/trailing whitespace.
    """
    text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
    text = re.sub(r"\n{3,}", "\n\n", text)
    return text.strip()
51
 
@@ -60,29 +98,364 @@ def remove_duplicate_paragraphs(text: str) -> str:
60
  seen.add(clean_p)
61
  return "\n\n".join(unique_paragraphs)
62
 
63
- # === FastAPI for mobile API endpoint ===
64
- from fastapi import FastAPI, UploadFile, File
65
- from fastapi.responses import JSONResponse
66
- import uvicorn
67
-
68
app = FastAPI()


@app.post("/analyze")
async def analyze_file_api(file: UploadFile = File(...)):
    """Analyze an uploaded file and return the summary plus the PDF path.

    The upload is written into ``file_cache_dir`` and handed to
    ``process_report``. On success the JSON body contains ``summary`` (the
    content of the second-to-last message) and ``pdf``; otherwise a 400
    response with an ``error`` field is returned.
    """
    agent = init_agent()
    # The filename comes from the client: keep only its final component so it
    # cannot point outside file_cache_dir.
    safe_name = os.path.basename(file.filename or "upload")
    temp_file_path = os.path.join(file_cache_dir, safe_name)
    with open(temp_file_path, "wb") as f:
        f.write(await file.read())
    messages = []
    # Open the saved copy in a context manager so the handle is always closed.
    with open(temp_file_path, "rb") as uploaded:
        messages, pdf_path = process_report(agent, uploaded, messages)
    if pdf_path:
        return JSONResponse(content={"summary": messages[-2]['content'], "pdf": pdf_path})
    return JSONResponse(content={"error": "Processing failed."}, status_code=400)
81
-
82
- # === Original Gradio UI launch preserved ===
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
if __name__ == "__main__":
    import threading

    agent = init_agent()
    ui = create_ui(agent)
    # Serve the Gradio UI on port 7860 from a background thread. The thread is
    # a daemon so it does not keep the process alive once uvicorn returns.
    threading.Thread(
        target=lambda: ui.launch(
            server_name="0.0.0.0",
            server_port=7860,
            allowed_paths=["/data/hf_cache/reports"],
            share=False,
        ),
        daemon=True,
    ).start()
    # The FastAPI app runs in the main thread on port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)
 
6
  import gc
7
  import time
8
  from datetime import datetime
9
+ from typing import List, Tuple, Dict, Union, Optional
10
  import pandas as pd
11
  import pdfplumber
 
12
  import torch
13
  import matplotlib.pyplot as plt
14
  from fpdf import FPDF
15
  import unicodedata
16
+ from fastapi import FastAPI, UploadFile, File, HTTPException
17
+ from fastapi.responses import FileResponse, JSONResponse
18
+ from fastapi.middleware.cors import CORSMiddleware
19
+ from pydantic import BaseModel
20
 
21
  # === Configuration ===
22
+
23
  persistent_dir = "/data/hf_cache"
24
  model_cache_dir = os.path.join(persistent_dir, "txagent_models")
25
  tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
 
45
  PROMPT_OVERHEAD = 300
46
  SAFE_SLEEP = 0.5
47
 
48
+ # === FastAPI App Setup ===
49
app = FastAPI(
    title="Clinical Patient Support System API",
    description="API for analyzing and summarizing unstructured medical files",
)

# CORS configuration for mobile app access.
# Credentials are disabled because every origin is allowed: combining a
# wildcard origin with credentialed requests would let any website make
# authenticated cross-origin calls. None of the endpoints below use cookies
# or HTTP auth, so nothing depends on credentialed CORS.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
60
+
61
+ # === Data Models ===
62
class AnalysisRequest(BaseModel):
    """Request model for file analysis.

    NOTE(review): none of the endpoints visible in this file accept this
    model (``/analyze`` takes a multipart upload) -- confirm whether it is
    still needed.
    """

    # Name of the file being submitted
    filename: str
    # Base64 encoded file content (mobile apps can send this)
    file_content: str
66
+
67
class AnalysisResponse(BaseModel):
    """Response model for analysis results."""

    # Outcome marker, e.g. "success"
    status: str
    # Human-readable status message
    message: str
    # Identifier under which the generated report was stored
    report_id: Optional[str] = None
    # Final summary text produced by the analysis
    summary: Optional[str] = None
    # Error description, when applicable
    error: Optional[str] = None
74
+
75
class ReportResponse(BaseModel):
    """Response model for report download."""

    # Outcome marker, e.g. "success"
    status: str
    # Identifier of the report that was looked up
    report_id: str
    # Relative URL from which the PDF can be fetched
    download_url: str
80
+
81
+ # === Helper Functions (same as original) ===
82
def estimate_tokens(text: str) -> int:
    """Return a rough token count for *text*.

    The estimate is one token per four characters, plus one so that the
    result is never zero (an empty string yields 1).
    """
    return len(text) // 4 + 1
84
 
85
def clean_response(text: str) -> str:
    """Strip bracketed spans and bare ``None`` tokens from model output.

    Anything between ``[`` and the next ``]`` is removed (the match may span
    newlines), as is the standalone word ``None``. Runs of three or more
    newlines are then collapsed to a single blank line and the result is
    stripped of leading/trailing whitespace.
    """
    # The brackets must be escaped: an unescaped "$.*?$" only ever matches the
    # empty string at the end of the input and removes nothing.
    text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
    text = re.sub(r"\n{3,}", "\n\n", text)
    return text.strip()
89
 
 
98
  seen.add(clean_p)
99
  return "\n\n".join(unique_paragraphs)
100
 
101
def extract_text_from_excel(path: str) -> str:
    """Flatten every sheet of an Excel workbook into pipe-separated lines.

    Each row contributes one line of the form ``[sheet] a | b | c`` built
    from its non-empty cells. Rows with fewer than two non-empty cells, or
    whose joined text is 15 characters or shorter, are dropped. Sheets that
    fail to parse are skipped.

    Args:
        path: Location of the workbook on disk.

    Returns:
        The collected lines joined with newlines ("" if nothing qualified).
    """
    all_text = []
    # The context manager closes the workbook handle when done.
    with pd.ExcelFile(path) as xls:
        for sheet_name in xls.sheet_names:
            try:
                raw = xls.parse(sheet_name)
                # Blank out missing cells *by position*: casting to str first
                # would turn NaN into the literal text "nan".
                df = raw.astype(str).where(raw.notna(), "")
            except Exception:
                # Best effort: an unreadable sheet must not abort the file.
                continue
            for _, row in df.iterrows():
                non_empty = [cell.strip() for cell in row if cell.strip()]
                if len(non_empty) >= 2:
                    text_line = " | ".join(non_empty)
                    if len(text_line) > 15:
                        all_text.append(f"[{sheet_name}] {text_line}")
    return "\n".join(all_text)
116
+
117
def extract_text_from_csv(path: str) -> str:
    """Flatten a CSV file into pipe-separated lines, one per data row.

    Each row contributes ``a | b | c`` built from its non-empty cells. Rows
    with fewer than two non-empty cells, or whose joined text is 15
    characters or shorter, are dropped.

    Args:
        path: Location of the CSV file on disk.

    Returns:
        The collected lines joined with newlines, or "" when the file cannot
        be read or no row qualifies.
    """
    try:
        raw = pd.read_csv(path)
        # Blank out missing cells *by position*: casting to str first would
        # turn NaN into the literal text "nan".
        df = raw.astype(str).where(raw.notna(), "")
    except Exception:
        # Best effort: unreadable input yields no text rather than an error.
        return ""
    all_text = []
    for _, row in df.iterrows():
        non_empty = [cell.strip() for cell in row if cell.strip()]
        if len(non_empty) >= 2:
            text_line = " | ".join(non_empty)
            if len(text_line) > 15:
                all_text.append(text_line)
    return "\n".join(all_text)
130
+
131
def extract_text_from_pdf(path: str) -> str:
    """Return the text of every page of a PDF, joined with newlines.

    Pages for which no text is extracted are skipped. Any error while
    opening or reading the document results in an empty string.

    Args:
        path: Location of the PDF on disk.
    """
    import logging

    # Silence pdfminer's logger below ERROR level.
    logging.getLogger("pdfminer").setLevel(logging.ERROR)
    all_text = []
    try:
        with pdfplumber.open(path) as pdf:
            for page in pdf.pages:
                text = page.extract_text()
                if text:
                    all_text.append(text.strip())
    except Exception:
        # Best effort: unreadable input yields no text rather than an error.
        return ""
    return "\n".join(all_text)
144
+
145
def extract_text(file_path: str) -> str:
    """Dispatch to the extractor matching the file's extension.

    ``.xlsx``, ``.csv`` and ``.pdf`` are supported; the comparison ignores
    case so that names such as ``REPORT.PDF`` are handled too. Any other
    extension yields an empty string.
    """
    lowered = file_path.lower()
    if lowered.endswith(".xlsx"):
        return extract_text_from_excel(file_path)
    if lowered.endswith(".csv"):
        return extract_text_from_csv(file_path)
    if lowered.endswith(".pdf"):
        return extract_text_from_pdf(file_path)
    return ""
154
+
155
def split_text(text: str, max_tokens: int = MAX_CHUNK_TOKENS) -> List[str]:
    """Split *text* on newlines into chunks that fit a token budget.

    Lines are accumulated until adding the next one would push the estimated
    token count past ``max_tokens - PROMPT_OVERHEAD``; the accumulated lines
    are then emitted as one chunk and a new chunk starts with that line.
    Lines are never split, so a single line larger than the budget becomes a
    chunk of its own.

    Args:
        text: The text to split.
        max_tokens: Token budget per chunk before the prompt overhead is
            subtracted.

    Returns:
        The chunks, each a newline-joined group of consecutive lines.
    """
    effective_limit = max_tokens - PROMPT_OVERHEAD
    chunks, current, current_tokens = [], [], 0
    for line in text.split("\n"):
        tokens = estimate_tokens(line)
        if current_tokens + tokens > effective_limit:
            if current:
                chunks.append("\n".join(current))
            current, current_tokens = [line], tokens
        else:
            current.append(line)
            current_tokens += tokens
    if current:
        chunks.append("\n".join(current))
    return chunks
170
+
171
def batch_chunks(chunks: List[str], batch_size: int = BATCH_SIZE) -> List[List[str]]:
    """Group *chunks* into consecutive lists of at most *batch_size* items.

    Raises:
        ValueError: If *batch_size* is smaller than 1. (A negative size would
            otherwise silently produce no batches and drop every chunk.)
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    return [chunks[i:i + batch_size] for i in range(0, len(chunks), batch_size)]
173
+
174
def build_prompt(chunk: str) -> str:
    """Wrap one text chunk in the per-chunk clinical analysis prompt."""
    return (
        "### Unstructured Clinical Records\n\n"
        "Analyze the clinical notes below and summarize with:\n"
        "- Diagnostic Patterns\n"
        "- Medication Issues\n"
        "- Missed Opportunities\n"
        "- Inconsistencies\n"
        "- Follow-up Recommendations\n\n"
        "---\n\n"
        f"{chunk}\n\n"
        "---\n"
        "Respond concisely in bullet points with clinical reasoning."
    )
176
+
177
def init_agent() -> TxAgent:
    """Create a TxAgent, load its model and return it.

    The tool definition file is copied from ``data/new_tool.json`` into the
    tool cache directory the first time it is needed.
    """
    tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(tool_path):
        # Make sure the destination directory exists before copying into it.
        os.makedirs(tool_cache_dir, exist_ok=True)
        shutil.copy(os.path.abspath("data/new_tool.json"), tool_path)
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict={"new_tool": tool_path},
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=100,
    )
    agent.init_model()
    return agent
192
+
193
def analyze_batches(agent, batches: List[List[str]]) -> List[str]:
    """Run the agent over every batch and collect one cleaned reply per batch.

    For each batch the chunk prompts are joined into a single message and
    streamed through ``agent.run_gradio_chat``. Streamed items may be plain
    strings, lists of objects with a ``content`` attribute, or single such
    objects; their text is concatenated. A batch that raises is recorded as a
    string starting with "❌ Batch failed:" instead of aborting the run.

    Args:
        agent: Object exposing ``run_gradio_chat``.
        batches: Groups of text chunks, as produced by ``batch_chunks``.

    Returns:
        One string per batch, in input order.
    """
    results = []
    for batch in batches:
        prompt = "\n\n".join(build_prompt(chunk) for chunk in batch)
        try:
            batch_response = ""
            for r in agent.run_gradio_chat(
                message=prompt,
                history=[],
                temperature=0.0,
                max_new_tokens=MAX_NEW_TOKENS,
                max_token=MAX_MODEL_TOKENS,
                call_agent=False,
                conversation=[],
            ):
                if isinstance(r, str):
                    batch_response += r
                elif isinstance(r, list):
                    for m in r:
                        if hasattr(m, "content"):
                            batch_response += m.content
                elif hasattr(r, "content"):
                    batch_response += r.content
            results.append(clean_response(batch_response))
            time.sleep(SAFE_SLEEP)
        except Exception as e:
            results.append(f"❌ Batch failed: {str(e)}")
            # Pause twice as long after a failure before the next batch.
            time.sleep(SAFE_SLEEP * 2)
        # Release cached GPU memory and run the garbage collector after
        # every batch.
        torch.cuda.empty_cache()
        gc.collect()
    return results
224
+
225
def generate_final_summary(agent, combined: str) -> str:
    """Ask the agent to merge the per-batch summaries into one final report.

    Duplicate paragraphs are removed from *combined* before prompting, and
    the agent's streamed reply is cleaned and de-duplicated again before it
    is returned.

    Args:
        agent: Object exposing ``run_gradio_chat``.
        combined: The per-batch summaries joined into one string.

    Returns:
        The final summary text.
    """
    combined = remove_duplicate_paragraphs(combined)
    # NOTE(review): the prompt lines are written flush-left; the indentation
    # inside the original string literal is not recoverable from this view.
    final_prompt = f"""
You are an expert clinical summarizer. Analyze the following summaries carefully and generate a **single final concise structured medical report**, avoiding any repetition or redundancy.
Summaries:
{combined}
Respond with:

* Diagnostic Patterns
* Medication Issues
* Missed Opportunities
* Inconsistencies
* Follow-up Recommendations
Avoid repeating the same points multiple times.
""".strip()

    final_response = ""
    for r in agent.run_gradio_chat(
        message=final_prompt,
        history=[],
        temperature=0.0,
        max_new_tokens=MAX_NEW_TOKENS,
        max_token=MAX_MODEL_TOKENS,
        call_agent=False,
        conversation=[],
    ):
        # Streamed items may be text, a list of message-like objects, or a
        # single message-like object.
        if isinstance(r, str):
            final_response += r
        elif isinstance(r, list):
            for m in r:
                if hasattr(m, "content"):
                    final_response += m.content
        elif hasattr(r, "content"):
            final_response += r.content

    final_response = clean_response(final_response)
    final_response = remove_duplicate_paragraphs(final_response)
    return final_response
263
+
264
def remove_non_ascii(text: str) -> str:
    """Drop every character whose code point is 256 or higher.

    Despite the name, characters in the range 128-255 are kept; only code
    points of 256 and above are removed.
    """
    return ''.join(c for c in text if ord(c) < 256)
266
+
267
def _save_chart(path, figsize, title, draw):
    """Render one matplotlib figure via *draw* and save it to *path*."""
    plt.figure(figsize=figsize)
    draw()
    plt.title(title)
    plt.tight_layout()
    plt.savefig(path)
    plt.close()


def generate_pdf_report_with_charts(summary: str, report_path: str, detailed_batches: Optional[List[str]] = None):
    """Build the PDF report (title page, summary, charts, optional details).

    Three chart images are written to a ``charts`` directory next to
    *report_path*, and the PDF is written beside *report_path* with the
    extension replaced by ``.pdf``.

    Args:
        summary: Final summary text; each non-empty line becomes a paragraph.
        report_path: Path of the markdown report the PDF accompanies.
        detailed_batches: Optional per-batch outputs appended as a
            "Detailed Tool Insights" section.

    Returns:
        The path of the generated PDF.
    """
    chart_dir = os.path.join(os.path.dirname(report_path), "charts")
    os.makedirs(chart_dir, exist_ok=True)

    # Prepare data
    # NOTE(review): these values are hard-coded placeholders; the charts do
    # not reflect the content of `summary`. Confirm whether real counts
    # should be derived before these figures are shown to users.
    categories = ['Diagnostics', 'Medications', 'Missed', 'Inconsistencies', 'Follow-up']
    values = [4, 2, 3, 1, 5]

    # Chart 1: Bar
    bar_chart_path = os.path.join(chart_dir, "bar_chart.png")
    _save_chart(bar_chart_path, (6, 4), 'Clinical Issues Overview',
                lambda: plt.bar(categories, values))

    # Chart 2: Pie
    pie_chart_path = os.path.join(chart_dir, "pie_chart.png")
    _save_chart(pie_chart_path, (6, 6), 'Issue Distribution',
                lambda: plt.pie(values, labels=categories, autopct='%1.1f%%'))

    # Chart 3: Line
    trend_chart_path = os.path.join(chart_dir, "trend_chart.png")
    _save_chart(trend_chart_path, (6, 4), 'Trend Analysis',
                lambda: plt.plot(categories, values, marker='o'))

    # PDF init
    # Swap only the final extension; str.replace would also rewrite any
    # ".md" occurring earlier in the path.
    pdf_path = os.path.splitext(report_path)[0] + '.pdf'
    pdf = FPDF()
    pdf.set_auto_page_break(auto=True, margin=15)

    # === Title Page ===
    pdf.add_page()
    pdf.set_font("Arial", 'B', 24)
    pdf.cell(0, 20, remove_non_ascii("Final Medical Report"), ln=True, align='C')
    pdf.set_font("Arial", '', 14)
    pdf.cell(0, 10, datetime.now().strftime("Generated on %B %d, %Y at %H:%M"), ln=True, align='C')
    pdf.ln(20)
    pdf.set_font("Arial", 'I', 12)
    pdf.multi_cell(0, 10, remove_non_ascii(
        "This report contains a professional summary of clinical observations, potential inconsistencies, and follow-up recommendations based on the uploaded medical document."
    ), align="C")

    # === Summary Section ===
    pdf.add_page()
    pdf.set_font("Arial", 'B', 16)
    pdf.cell(0, 10, remove_non_ascii("Final Summary"), ln=True)
    pdf.set_draw_color(200, 200, 200)
    pdf.line(10, pdf.get_y(), 200, pdf.get_y())
    pdf.ln(5)
    pdf.set_font("Arial", '', 12)
    for line in summary.split("\n"):
        clean_line = remove_non_ascii(line.strip())
        if clean_line:
            pdf.multi_cell(0, 8, txt=clean_line)

    # === Charts Section ===
    pdf.add_page()
    pdf.set_font("Arial", 'B', 16)
    pdf.cell(0, 10, remove_non_ascii("Statistical Overview"), ln=True)
    pdf.line(10, pdf.get_y(), 200, pdf.get_y())
    pdf.ln(5)

    pdf.set_font("Arial", 'B', 12)
    pdf.cell(0, 10, remove_non_ascii("1. Clinical Issues Overview"), ln=True)
    pdf.image(bar_chart_path, w=180)
    pdf.ln(5)

    pdf.cell(0, 10, remove_non_ascii("2. Issue Distribution"), ln=True)
    pdf.image(pie_chart_path, w=150)
    pdf.ln(5)

    pdf.cell(0, 10, remove_non_ascii("3. Trend Analysis"), ln=True)
    pdf.image(trend_chart_path, w=180)

    # === Detailed Tool Outputs ===
    if detailed_batches:
        pdf.add_page()
        pdf.set_font("Arial", 'B', 16)
        pdf.cell(0, 10, remove_non_ascii("Detailed Tool Insights"), ln=True)
        pdf.line(10, pdf.get_y(), 200, pdf.get_y())
        pdf.ln(5)

        for idx, detail in enumerate(detailed_batches):
            pdf.set_font("Arial", 'B', 13)
            pdf.cell(0, 10, remove_non_ascii(f"Tool Output #{idx + 1}"), ln=True)
            pdf.set_font("Arial", '', 11)
            for line in remove_non_ascii(detail).split("\n"):
                pdf.multi_cell(0, 8, txt=line.strip())
            pdf.ln(3)

    pdf.output(pdf_path)
    return pdf_path
369
+
370
+ # === API Endpoints ===
371
_cached_agent = None


def _get_agent():
    """Return the shared agent, creating it with init_agent() on first use."""
    global _cached_agent
    if _cached_agent is None:
        _cached_agent = init_agent()
    return _cached_agent


@app.post("/analyze", response_model=AnalysisResponse)
async def analyze_file(file: UploadFile = File(...)):
    """Endpoint for analyzing medical files.

    Saves the upload to a temporary file, extracts its text, runs the agent
    over it in batches, writes the markdown and PDF reports, and returns the
    report id together with the final summary.

    Raises:
        HTTPException: 400 when the upload has no usable filename, no text
            can be extracted, or no batch produced a valid output; 500 for
            any other failure.
    """
    import uuid

    start_time = time.time()

    # The filename is client-controlled: keep only its final component so the
    # temporary file cannot be placed outside file_cache_dir.
    safe_name = os.path.basename(file.filename or "")
    if not safe_name:
        raise HTTPException(status_code=400, detail="Missing filename")
    # A random prefix keeps simultaneous uploads of the same name apart.
    temp_path = os.path.join(file_cache_dir, f"{uuid.uuid4().hex}_{safe_name}")

    try:
        # Save the uploaded file temporarily
        with open(temp_path, "wb") as f:
            f.write(await file.read())

        # Generate a unique report ID; the random suffix prevents two
        # requests in the same second from overwriting each other's report.
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        report_id = f"report_{timestamp}_{uuid.uuid4().hex[:8]}"

        # Extract first so unusable uploads fail before the model is loaded
        extracted = extract_text(temp_path)
        if not extracted:
            raise HTTPException(status_code=400, detail="Could not extract text from file")

        # Reuse one agent across requests instead of building one per call
        agent = _get_agent()

        chunks = split_text(extracted)
        batches = batch_chunks(chunks, batch_size=BATCH_SIZE)
        batch_results = analyze_batches(agent, batches)
        all_tool_outputs = batch_results.copy()
        valid = [res for res in batch_results if not res.startswith("❌")]

        if not valid:
            raise HTTPException(status_code=400, detail="No valid batch outputs generated")

        summary = generate_final_summary(agent, "\n\n".join(valid))

        # Save report files
        report_path = os.path.join(report_dir, f"{report_id}.md")
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write(f"# Final Medical Report\n\n{summary}")

        generate_pdf_report_with_charts(summary, report_path, detailed_batches=all_tool_outputs)

        elapsed_time = time.time() - start_time

        return {
            "status": "success",
            "message": f"Report generated in {elapsed_time:.2f} seconds",
            "report_id": report_id,
            "summary": summary
        }

    except HTTPException:
        # Keep the intended status code (e.g. 400) instead of masking it as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Clean up temp file on success and on every failure path
        if os.path.exists(temp_path):
            os.remove(temp_path)
426
+
427
@app.get("/report/{report_id}", response_model=ReportResponse)
async def get_report(report_id: str):
    """Endpoint for looking up a generated report.

    Returns the report id and the relative URL from which the PDF can be
    downloaded.

    Raises:
        HTTPException: 404 when the id is malformed or no PDF exists for it.
    """
    # The id is used to build a filesystem path: accept only letters, digits,
    # underscores and hyphens so it cannot escape report_dir.
    if not re.fullmatch(r"[A-Za-z0-9_-]+", report_id):
        raise HTTPException(status_code=404, detail="Report not found")

    pdf_path = os.path.join(report_dir, f"{report_id}.pdf")
    if not os.path.exists(pdf_path):
        raise HTTPException(status_code=404, detail="Report not found")

    return {
        "status": "success",
        "report_id": report_id,
        "download_url": f"/download/{report_id}"
    }
439
+
440
@app.get("/download/{report_id}")
async def download_report(report_id: str):
    """Endpoint for actual file download.

    Streams the PDF stored for *report_id* as
    ``medical_report_<report_id>.pdf``.

    Raises:
        HTTPException: 404 when the id is malformed or no PDF exists for it.
    """
    # The id is used to build a filesystem path: accept only letters, digits,
    # underscores and hyphens so it cannot escape report_dir.
    if not re.fullmatch(r"[A-Za-z0-9_-]+", report_id):
        raise HTTPException(status_code=404, detail="Report not found")

    pdf_path = os.path.join(report_dir, f"{report_id}.pdf")
    if not os.path.exists(pdf_path):
        raise HTTPException(status_code=404, detail="Report not found")

    return FileResponse(
        pdf_path,
        media_type="application/pdf",
        filename=f"medical_report_{report_id}.pdf"
    )
452
+
453
@app.get("/health")
async def health_check():
    """Health check endpoint; always reports a healthy status."""
    return {"status": "healthy"}
457
+
458
+ # === Main Application ===
459
if __name__ == "__main__":
    import uvicorn

    # Serve the FastAPI app on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)