Ali2206 committed (verified)
Commit a135a34 · 1 Parent(s): 8955687

Update app.py

Files changed (1): app.py (+20 -153)
app.py CHANGED
@@ -60,162 +60,29 @@ def remove_duplicate_paragraphs(text: str) -> str:
         seen.add(clean_p)
     return "\n\n".join(unique_paragraphs)
 
-def extract_text(file_path: str) -> str:
-    if file_path.endswith(".xlsx"):
-        return pd.read_excel(file_path).astype(str).fillna("").to_string(index=False)
-    elif file_path.endswith(".csv"):
-        return pd.read_csv(file_path).astype(str).fillna("").to_string(index=False)
-    elif file_path.endswith(".pdf"):
-        try:
-            with pdfplumber.open(file_path) as pdf:
-                return "\n".join(page.extract_text() or '' for page in pdf.pages)
-        except Exception:
-            return ""
-    else:
-        return ""
+# === FastAPI for mobile API endpoint ===
+from fastapi import FastAPI, UploadFile, File
+from fastapi.responses import JSONResponse
+import uvicorn
 
-def split_text(text: str, max_tokens=MAX_CHUNK_TOKENS) -> List[str]:
-    effective_limit = max_tokens - PROMPT_OVERHEAD
-    chunks, current, current_tokens = [], [], 0
-    for line in text.split("\n"):
-        tokens = estimate_tokens(line)
-        if current_tokens + tokens > effective_limit:
-            if current:
-                chunks.append("\n".join(current))
-            current, current_tokens = [line], tokens
-        else:
-            current.append(line)
-            current_tokens += tokens
-    if current:
-        chunks.append("\n".join(current))
-    return chunks
+app = FastAPI()
 
-def batch_chunks(chunks: List[str], batch_size: int = BATCH_SIZE) -> List[List[str]]:
-    return [chunks[i:i+batch_size] for i in range(0, len(chunks), batch_size)]
-
-def build_prompt(chunk: str) -> str:
-    return f"""### Unstructured Clinical Records\n\nAnalyze the clinical notes below and summarize with:\n- Diagnostic Patterns\n- Medication Issues\n- Missed Opportunities\n- Inconsistencies\n- Follow-up Recommendations\n\n---\n\n{chunk}\n\n---\nRespond concisely in bullet points with clinical reasoning."""
-
-def remove_non_ascii(text):
-    return ''.join(c for c in text if ord(c) < 256)
-
-def init_agent() -> TxAgent:
-    tool_path = os.path.join(tool_cache_dir, "new_tool.json")
-    if not os.path.exists(tool_path):
-        shutil.copy(os.path.abspath("data/new_tool.json"), tool_path)
-    agent = TxAgent(
-        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
-        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
-        tool_files_dict={"new_tool": tool_path},
-        force_finish=True,
-        enable_checker=True,
-        step_rag_num=4,
-        seed=100
-    )
-    agent.init_model()
-    return agent
-
-def analyze_batches(agent, batches: List[List[str]]) -> List[str]:
-    results = []
-    for batch in batches:
-        prompt = "\n\n".join(build_prompt(chunk) for chunk in batch)
-        try:
-            batch_response = ""
-            for r in agent.run_gradio_chat(
-                message=prompt,
-                history=[],
-                temperature=0.0,
-                max_new_tokens=MAX_NEW_TOKENS,
-                max_token=MAX_MODEL_TOKENS,
-                call_agent=False,
-                conversation=[]
-            ):
-                if isinstance(r, str):
-                    batch_response += r
-                elif isinstance(r, list):
-                    for m in r:
-                        if hasattr(m, "content"):
-                            batch_response += m.content
-                elif hasattr(r, "content"):
-                    batch_response += r.content
-            results.append(clean_response(batch_response))
-            time.sleep(SAFE_SLEEP)
-        except Exception as e:
-            results.append(f"❌ Batch failed: {str(e)}")
-            time.sleep(SAFE_SLEEP * 2)
-        torch.cuda.empty_cache()
-        gc.collect()
-    return results
-
-def generate_final_summary(agent, combined: str) -> str:
-    combined = remove_duplicate_paragraphs(combined)
-    final_prompt = f"""
-You are an expert clinical summarizer. Analyze the following summaries carefully and generate a **single final concise structured medical report**, avoiding any repetition or redundancy.
-
-Summaries:
-{combined}
-
-Respond with:
-- Diagnostic Patterns
-- Medication Issues
-- Missed Opportunities
-- Inconsistencies
-- Follow-up Recommendations
-Avoid repeating the same points multiple times.
-""".strip()
-
-    final_response = ""
-    for r in agent.run_gradio_chat(
-        message=final_prompt,
-        history=[],
-        temperature=0.0,
-        max_new_tokens=MAX_NEW_TOKENS,
-        max_token=MAX_MODEL_TOKENS,
-        call_agent=False,
-        conversation=[]
-    ):
-        if isinstance(r, str):
-            final_response += r
-        elif isinstance(r, list):
-            for m in r:
-                if hasattr(m, "content"):
-                    final_response += m.content
-        elif hasattr(r, "content"):
-            final_response += r.content
-
-    final_response = clean_response(final_response)
-    final_response = remove_duplicate_paragraphs(final_response)
-    return final_response
-
-def handle_analysis(file):
+@app.post("/analyze")
+async def analyze_file_api(file: UploadFile = File(...)):
+    agent = init_agent()
+    temp_file_path = os.path.join(file_cache_dir, file.filename)
+    with open(temp_file_path, "wb") as f:
+        f.write(await file.read())
     messages = []
-    if not file or not hasattr(file, "name"):
-        return "❌ Please upload a valid file.", None
-    try:
-        extracted = extract_text(file.name)
-        if not extracted:
-            return "❌ Could not extract text.", None
-
-        chunks = split_text(extracted)
-        batches = batch_chunks(chunks, batch_size=BATCH_SIZE)
-        batch_results = analyze_batches(agent, batches)
-        valid = [res for res in batch_results if not res.startswith("❌")]
-
-        if not valid:
-            return "❌ No valid batch outputs.", None
-
-        summary = generate_final_summary(agent, "\n\n".join(valid))
-        report_path = os.path.join(report_dir, f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt")
-        with open(report_path, 'w', encoding='utf-8') as f:
-            f.write(summary)
-        return summary, report_path
-    except Exception as e:
-        return f"❌ Error: {str(e)}", None
+    messages, pdf_path = process_report(agent, open(temp_file_path, "rb"), messages)
+    if pdf_path:
+        return JSONResponse(content={"summary": messages[-2]['content'], "pdf": pdf_path})
+    return JSONResponse(content={"error": "Processing failed."}, status_code=400)
 
+# === Original Gradio UI launch preserved ===
 if __name__ == "__main__":
     agent = init_agent()
-    gr.Interface(
-        fn=handle_analysis,
-        inputs=gr.File(file_types=[".pdf", ".csv", ".xlsx"]),
-        outputs=[gr.Textbox(label="Summary"), gr.File(label="Download Report")]
-    ).queue().launch(server_name="0.0.0.0", server_port=7860, allowed_paths=["/data/hf_cache/reports"], share=False)
+    ui = create_ui(agent)
+    import threading
+    threading.Thread(target=lambda: ui.launch(server_name="0.0.0.0", server_port=7860, allowed_paths=["/data/hf_cache/reports"], share=False)).start()
+    uvicorn.run(app, host="0.0.0.0", port=8000)
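
For reference, calling the new endpoint from a client could look like the minimal sketch below. This is an illustration, not part of the commit: "labs.pdf" is a hypothetical input file, localhost:8000 assumes the uvicorn server above is running locally, and the "summary"/"pdf"/"error" keys mirror the JSONResponse payloads in the diff.

# Hypothetical client for the /analyze endpoint added in this commit.
# Assumes: the server is running locally on port 8000, and "labs.pdf"
# stands in for whatever report file process_report accepts.
import requests

with open("labs.pdf", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/analyze",
        # The multipart field name must be "file" to match the
        # UploadFile parameter of analyze_file_api.
        files={"file": ("labs.pdf", f, "application/pdf")},
    )

data = resp.json()
if resp.status_code == 200:
    print(data["summary"])  # final summary text
    print(data["pdf"])      # server-side path to the generated report
else:
    print(f"Request failed ({resp.status_code}): {data.get('error')}")

With this launch arrangement the Gradio UI stays reachable on port 7860 in a background thread, while uvicorn serves the JSON API on port 8000 from the main thread.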