File size: 8,668 Bytes
aa559b4 f75a23b f394b25 d184610 a57b988 f394b25 1a611b9 d16299c 1c5bd8e da7f195 1a611b9 aa559b4 da7f195 a4b1ab0 aa559b4 da7f195 aa559b4 da7f195 d8282f1 aa559b4 f6e551c a57b988 f6e551c 3ed8d49 4bfbcac 0fb33af f75a23b aa559b4 62ef904 8b1bbeb 1244d40 7a8204e f6e551c aa559b4 d16299c f6e551c d16299c aa559b4 a57b988 aa559b4 7771dd9 1a611b9 ad85a12 1a611b9 8b1bbeb aa559b4 7771dd9 1a611b9 ad85a12 3ed8d49 1a611b9 ad85a12 1a611b9 ad85a12 aa559b4 ad85a12 a57b988 7771dd9 0e6914c 7771dd9 3ed8d49 7771dd9 a57b988 aa559b4 7771dd9 1a611b9 aa559b4 a4b1ab0 aa559b4 73810ec 1a611b9 73810ec a4b1ab0 aa559b4 a4b1ab0 aa559b4 a4b1ab0 1a611b9 aa559b4 1a611b9 a57b988 aa559b4 7771dd9 1a611b9 6762641 aa559b4 1a611b9 aa559b4 1a611b9 aa559b4 6762641 aa559b4 7771dd9 1a611b9 7771dd9 1a611b9 7771dd9 1a611b9 a71a831 55e3db0 aa559b4 abd27cc d8282f1 a57b988 1a611b9 d8282f1 1a611b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 |
# Fully updated app.py for TxAgent with strict tool validation to prevent runtime errors
import sys
import os
import pandas as pd
import json
import gradio as gr
from typing import List, Tuple, Union, Generator, Dict, Any
import re
from datetime import datetime
import atexit
import torch.distributed as dist
import logging
# Logging setup: timestamped INFO-level records; this module logs through the "app" logger.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger("app")
# Cleanup
def cleanup() -> None:
    """Destroy the PyTorch distributed process group, if one exists, at exit.

    ``dist.is_available()`` is checked first: on PyTorch builds compiled
    without distributed support, ``torch.distributed`` may not expose
    ``is_initialized`` at all, and calling it would raise AttributeError
    from inside the atexit handler.
    """
    if dist.is_available() and dist.is_initialized():
        logger.info("Cleaning up PyTorch distributed process group")
        dist.destroy_process_group()


atexit.register(cleanup)
# Directories
# Every cache and report directory lives under one root and is created eagerly at import time.
# NOTE(review): "/data" is presumably the host's persistent volume mount - confirm for the deployment target.
persistent_dir = "/data/hf_cache"
os.makedirs(persistent_dir, exist_ok=True)

model_cache_dir, tool_cache_dir, file_cache_dir, report_dir = (
    os.path.join(persistent_dir, sub)
    for sub in ("txagent_models", "tool_cache", "cache", "reports")
)
for d in (model_cache_dir, tool_cache_dir, file_cache_dir, report_dir):
    os.makedirs(d, exist_ok=True)

# Point the Hugging Face caches at the model directory; this runs before TxAgent is imported below.
os.environ.update(HF_HOME=model_cache_dir, TRANSFORMERS_CACHE=model_cache_dir)
# Import TxAgent
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "src")))
from txagent.txagent import TxAgent
# Token budgets for prompting the model.
MAX_MODEL_TOKENS = 32_768  # passed to run_gradio_chat as max_token
MAX_CHUNK_TOKENS = 8_192   # per-chunk budget before PROMPT_OVERHEAD is subtracted
MAX_NEW_TOKENS = 2_048     # passed to run_gradio_chat as max_new_tokens
PROMPT_OVERHEAD = 500      # subtracted from MAX_CHUNK_TOKENS to leave room for the prompt template
def clean_response(text: str) -> str:
    """Strip bracketed artefacts and unsupported symbols from model output.

    Steps, in order:
      1. Remove every ``[...]`` span (non-greedy, allowed to cross newlines)
         and the standalone word ``None``.
      2. Delete any run of characters outside the allowed set: word
         characters, whitespace, and ``# - * . , : ( )``.
      3. Collapse three or more consecutive newlines into a single blank line.
      4. Trim leading/trailing whitespace.

    The newline collapse runs after the character filter so that lines
    emptied by step 2 do not leave oversized gaps behind.

    NOTE(review): step 2 also deletes ``/``, ``%``, ``?``, ``+``, ``<``, ``>``
    and ``=``, so e.g. "120/80" becomes "12080" and "mg/dL" becomes "mgdL".
    Confirm that this is acceptable for clinical output.
    """
    text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
    text = re.sub(r"[^\n#\-\*\w\s\.,:\(\)]+", "", text)
    text = re.sub(r"\n{3,}", "\n\n", text)
    return text.strip()
def estimate_tokens(text: str) -> int:
    """Cheaply estimate the token count of *text*.

    Heuristic: about 3.5 characters per token, plus 1, so that even an empty
    string counts as one token. Floor division by a float yields a float, so
    the result is converted to ``int`` to match the declared return type. The
    numeric value is unchanged.
    """
    return int(len(text) // 3.5) + 1
def extract_text_from_excel(file_obj: Union[str, Dict[str, Any]]) -> str:
    """Flatten every sheet of an Excel workbook into pipe-separated text lines.

    Accepted inputs:
      * a path string or ``os.PathLike``;
      * a dict with a ``'name'`` key holding the path;
      * any object with a string ``.name`` attribute. Presumably this is the
        tempfile wrapper some Gradio versions pass for uploads; confirm
        against the installed Gradio version.

    Each data row becomes ``"[<sheet>] cell | cell | ..."``. Missing and
    blank cells are dropped, and rows with no content are skipped. The
    header row is consumed by pandas as column names and is not emitted.
    A sheet that fails to parse is logged and skipped.

    Raises:
        ValueError: the input type is not supported.
        FileNotFoundError: the resolved path does not exist.
    """
    if isinstance(file_obj, dict) and 'name' in file_obj:
        file_path = file_obj['name']
    elif isinstance(file_obj, (str, os.PathLike)):
        file_path = os.fspath(file_obj)
    elif isinstance(getattr(file_obj, "name", None), str):
        file_path = file_obj.name
    else:
        raise ValueError("Unsupported file input type")
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    all_text: List[str] = []
    # The context manager releases the workbook handle once every sheet is read.
    with pd.ExcelFile(file_path) as xls:
        for sheet in xls.sheet_names:
            try:
                df = xls.parse(sheet)
                # Stringify, then blank out cells that were missing. astype(str)
                # alone turns NaN into the literal "nan", which a later
                # fillna("") can no longer catch.
                cells = df.astype(str).mask(df.isna(), "")
                sheet_text = []
                for row in cells.itertuples(index=False, name=None):
                    line = " | ".join(c for c in row if c.strip())
                    if line.strip():
                        sheet_text.append(f"[{sheet}] {line}")
                # Extend only after the whole sheet succeeded, so a failing sheet adds nothing.
                all_text.extend(sheet_text)
            except Exception as e:
                logger.warning(f"Failed to parse {sheet}: {e}")
    return "\n".join(all_text)
def split_text_into_chunks(text: str, max_tokens: Union[int, None] = None) -> List[str]:
    """Greedily pack newline-separated lines into chunks under a token budget.

    Lines are never broken apart. Token counts come from ``estimate_tokens``.

    Args:
        text: Text to split.
        max_tokens: Per-chunk estimated-token budget. Defaults to
            ``MAX_CHUNK_TOKENS - PROMPT_OVERHEAD``, which leaves room for the
            analysis prompt wrapped around each chunk.

    Returns:
        Chunks in input order. A chunk made only of whitespace is never
        returned, so empty input yields ``[]``. A single line that alone
        exceeds the budget is still returned, as its own oversized chunk.
    """
    if max_tokens is None:
        max_tokens = MAX_CHUNK_TOKENS - PROMPT_OVERHEAD
    chunks: List[str] = []
    current: List[str] = []
    current_tokens = 0
    for line in text.split("\n"):
        t = estimate_tokens(line)
        # Only flush a non-empty chunk; otherwise an oversized first line
        # would emit an empty chunk ahead of it.
        if current and current_tokens + t > max_tokens:
            chunks.append("\n".join(current))
            current, current_tokens = [line], t
        else:
            current.append(line)
            current_tokens += t
    if current:
        chunks.append("\n".join(current))
    return [chunk for chunk in chunks if chunk.strip()]
def build_prompt_from_text(chunk: str) -> str:
    """Wrap one chunk of spreadsheet text in the clinical-analysis instructions.

    The returned prompt begins and ends with a newline. The chunk sits
    between two ``---`` separator lines.
    """
    prompt_lines = [
        "",
        "### Clinical Records Analysis",
        "Please analyze these clinical notes and provide:",
        "- Key diagnostic indicators",
        "- Current medications and potential issues",
        "- Recommended follow-up actions",
        "- Any inconsistencies or concerns",
        "---",
        f"{chunk}",
        "---",
        "Provide a structured response with clear medical reasoning.",
        "",
    ]
    return "\n".join(prompt_lines)
def clean_and_rewrite_tool_file(original_path: str, cleaned_path: str) -> bool:
    """Normalise a tool-definition JSON file into a plain JSON list of tool dicts.

    Accepted layouts of *original_path*:
      * ``{"tools": [...]}``
      * ``[...]``
      * a single tool object ``{"name": ..., ...}``

    The tool list is accepted only if it is a non-empty list in which every
    element is a dict with a ``"name"`` key. On success the list is written
    to *cleaned_path* and True is returned. Any other shape returns False.
    Any exception (missing file, invalid JSON, write failure) is logged and
    also returns False.
    """
    try:
        with open(original_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        if isinstance(data, dict) and "tools" in data:
            tools = data["tools"]
        elif isinstance(data, list):
            tools = data
        elif isinstance(data, dict) and "name" in data:
            tools = [data]
        else:
            return False
        # Reject empty or non-list values. An empty list would vacuously pass
        # the all(...) check below and be reported as a valid tool file.
        if not isinstance(tools, list) or not tools:
            return False
        if not all(isinstance(t, dict) and "name" in t for t in tools):
            return False
        with open(cleaned_path, "w", encoding="utf-8") as out:
            json.dump(tools, out)
        return True
    except Exception as e:
        logger.error(f"Failed to clean tool {original_path}: {e}")
        return False
# Fallback location of the tooluniverse tool definitions. It is used only when
# the installed package cannot be located through the import system or lacks a file.
_LEGACY_TOOLUNIVERSE_DATA_DIR = "/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data"


def _tooluniverse_data_dirs() -> List[str]:
    """Return candidate ``tooluniverse/data`` directories: installed package first, legacy path last."""
    import importlib.util

    candidates: List[str] = []
    try:
        spec = importlib.util.find_spec("tooluniverse")
    except (ImportError, ValueError):
        spec = None
    if spec is not None and spec.submodule_search_locations:
        candidates.extend(os.path.join(loc, "data") for loc in spec.submodule_search_locations)
    candidates.append(_LEGACY_TOOLUNIVERSE_DATA_DIR)
    return candidates


def _locate_tool_file(filename: str, data_dirs: List[str]) -> str:
    """Return the first existing ``<dir>/<filename>``, else the path under the first candidate dir."""
    for data_dir in data_dirs:
        path = os.path.join(data_dir, filename)
        if os.path.exists(path):
            return path
    return os.path.join(data_dirs[0], filename)


def init_agent() -> TxAgent:
    """Validate the tool-definition files, then build and load the TxAgent.

    On first run, writes a placeholder ``new_tool.json`` (a single
    ``dummy_tool`` entry) into ``tool_cache_dir``. Each tool file is
    normalised into ``<name>_cleaned.json`` via
    ``clean_and_rewrite_tool_file``. Files that fail validation are logged
    and skipped.

    Raises:
        ValueError: none of the tool files passed validation.
    """
    new_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(new_tool_path):
        with open(new_tool_path, 'w', encoding="utf-8") as f:
            json.dump([{"name": "dummy_tool", "description": "test", "version": "1.0"}], f)

    # Resolve tool files against the installed tooluniverse package rather
    # than a single hard-coded interpreter-specific site-packages path.
    data_dirs = _tooluniverse_data_dirs()
    raw_tool_files = {
        'opentarget': _locate_tool_file('opentarget_tools.json', data_dirs),
        'fda_drug_label': _locate_tool_file('fda_drug_labeling_tools.json', data_dirs),
        'special_tools': _locate_tool_file('special_tools.json', data_dirs),
        'monarch': _locate_tool_file('monarch_tools.json', data_dirs),
        'new_tool': new_tool_path,
    }
    validated_paths = {}
    for name, original_path in raw_tool_files.items():
        cleaned_path = os.path.join(tool_cache_dir, f"{name}_cleaned.json")
        if clean_and_rewrite_tool_file(original_path, cleaned_path):
            validated_paths[name] = cleaned_path
        else:
            logger.warning(f"Skipping tool set '{name}': {original_path} failed validation")
    if not validated_paths:
        raise ValueError("No valid tools found after sanitizing.")
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict=validated_paths,
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=100
    )
    agent.init_model()
    return agent
def _collect_agent_output(agent: TxAgent, prompt: str) -> str:
    """Send *prompt* to the agent as one non-agentic chat turn and join the streamed output.

    Every call uses the same settings: temperature 0.2, ``MAX_NEW_TOKENS``
    new tokens, a ``MAX_MODEL_TOKENS`` budget, empty history and
    conversation, and ``call_agent=False``.
    """
    pieces = []
    for out in agent.run_gradio_chat(
            message=prompt, history=[], temperature=0.2,
            max_new_tokens=MAX_NEW_TOKENS, max_token=MAX_MODEL_TOKENS,
            call_agent=False, conversation=[]):
        # NOTE(review): assumes each yielded item is an incremental str, or an
        # object with a str `.content`. If run_gradio_chat yields cumulative
        # text, this concatenation duplicates content; confirm against TxAgent.
        pieces.append(out if isinstance(out, str) else out.content)
    return "".join(pieces)


def stream_report(agent: TxAgent, input_file: Union[str, Dict[str, Any]], full_output: str) -> Generator[Tuple[str, Union[str, None], str], None, None]:
    """Analyse an uploaded workbook chunk by chunk, streaming progress to the UI.

    Yields ``(markdown_text, report_path_or_None, final_summary)`` tuples:
    one after each analysed chunk, then a final tuple containing the summary
    and the path of the Markdown report written under ``report_dir``. On any
    error, a single error tuple is yielded and the exception is logged.

    ``full_output`` is not read; it is accepted so that the function matches
    the Gradio ``State`` input it is wired to.
    """
    accumulated = ""
    try:
        if input_file is None:
            yield "β Upload a valid Excel file.", None, ""
            return
        text = extract_text_from_excel(input_file)
        # Skip the model entirely when the workbook has no usable rows.
        if not text.strip():
            yield "No readable data found in the uploaded file.", None, ""
            return
        chunks = [chunk for chunk in split_text_into_chunks(text) if chunk.strip()]
        for part_no, chunk in enumerate(chunks, start=1):
            cleaned = clean_response(_collect_agent_output(agent, build_prompt_from_text(chunk)))
            accumulated += f"\n\nπ Part {part_no}:\n{cleaned}"
            yield accumulated, None, ""
        # NOTE(review): the summary prompt contains every part's output. With
        # many chunks it may exceed MAX_MODEL_TOKENS; consider truncating.
        summary_prompt = f"Summarize this analysis:\n\n{accumulated}"
        final = clean_response(_collect_agent_output(agent, summary_prompt))
        report_path = os.path.join(report_dir, f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md")
        with open(report_path, 'w', encoding="utf-8") as f:
            f.write(f"# Clinical Report\n\n{final}")
        yield f"{accumulated}\n\nπ Final Summary:\n{final}", report_path, final
    except Exception as e:
        logger.error(f"Stream error: {e}", exc_info=True)
        yield f"β Error: {str(e)}", None, ""
def create_ui(agent: TxAgent) -> gr.Blocks:
    """Build the Gradio interface whose Analyze button runs ``stream_report`` with *agent*.

    ``stream_report`` needs the agent as its first argument, but Gradio only
    passes the component values listed in ``inputs``. A local generator
    therefore binds *agent*. It also reveals the download component once a
    report path is available.
    """
    def _analyze(file_value, state_value):
        # Defined with `def`/`yield` (not a lambda) so Gradio detects a
        # generator function and streams each update.
        for text, report_path, final in stream_report(agent, file_value, state_value):
            yield text, gr.update(value=report_path, visible=report_path is not None), final

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# π₯ Clinical Records Analyzer")
        with gr.Row():
            file_upload = gr.File(label="Upload Excel File", file_types=[".xlsx"])
            analyze_btn = gr.Button("Analyze", variant="primary")
        with gr.Row():
            with gr.Column(scale=2):
                report_output = gr.Markdown()
            with gr.Column(scale=1):
                report_file = gr.File(label="Download", visible=False)
        full_output = gr.State()
        analyze_btn.click(fn=_analyze, inputs=[file_upload, full_output], outputs=[report_output, report_file, full_output])
    # Older Gradio 3.x releases require the queue for generator handlers;
    # on newer releases the queue is already enabled by default.
    demo.queue()
    return demo
if __name__ == "__main__":
    # Entry point: build the agent, wire up the UI, and serve on all interfaces at port 7860.
    # Any startup or serving failure is logged with its traceback and exits with status 1.
    try:
        agent = init_agent()
        demo = create_ui(agent)
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
        )
    except Exception as exc:
        logger.exception(f"App error: {exc}")
        sys.stderr.write(f"β Application error: {exc}\n")
        raise SystemExit(1)
|