# CPS-Test-Mobile / app.py
# ✅ Fully updated app.py for TxAgent with strict tool validation to prevent runtime errors
import sys
import os
import pandas as pd
import json
import gradio as gr
from typing import List, Tuple, Union, Generator, Dict, Any
import re
from datetime import datetime
import atexit
import torch.distributed as dist
import logging
# Logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("app")
# Cleanup
def cleanup():
    if dist.is_initialized():
        logger.info("Cleaning up PyTorch distributed process group")
        dist.destroy_process_group()
atexit.register(cleanup)
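# The agent backend may initialize torch.distributed; destroying the process group at exit
# avoids noisy shutdown warnings (rationale assumed, not stated elsewhere in this file).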
# Directories
persistent_dir = "/data/hf_cache"
os.makedirs(persistent_dir, exist_ok=True)
model_cache_dir = os.path.join(persistent_dir, "txagent_models")
tool_cache_dir = os.path.join(persistent_dir, "tool_cache")
file_cache_dir = os.path.join(persistent_dir, "cache")
report_dir = os.path.join(persistent_dir, "reports")
for d in [model_cache_dir, tool_cache_dir, file_cache_dir, report_dir]:
    os.makedirs(d, exist_ok=True)
os.environ["HF_HOME"] = model_cache_dir
os.environ["TRANSFORMERS_CACHE"] = model_cache_dir
# Import TxAgent
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "src")))
from txagent.txagent import TxAgent
MAX_MODEL_TOKENS = 32768
MAX_CHUNK_TOKENS = 8192
MAX_NEW_TOKENS = 2048
PROMPT_OVERHEAD = 500
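# Budget sketch: each chunk gets at most MAX_CHUNK_TOKENS - PROMPT_OVERHEAD = 8192 - 500 = 7692
# estimated tokens of sheet text, which together with the prompt template and up to
# MAX_NEW_TOKENS = 2048 generated tokens stays well inside the 32768-token context window.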
def clean_response(text: str) -> str:
    # Drop any [bracketed] spans and bare "None" tokens, collapse runs of blank lines,
    # then strip characters outside a markdown-friendly whitelist.
    text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
    text = re.sub(r"\n{3,}", "\n\n", text)
    text = re.sub(r"[^\n#\-\*\w\s\.,:\(\)]+", "", text)
    return text.strip()
def estimate_tokens(text: str) -> int:
    # Rough heuristic of ~3.5 characters per token; the cast keeps the declared int return type.
    return int(len(text) / 3.5) + 1
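# Example: a 70-character line is estimated at int(70 / 3.5) + 1 = 21 tokens.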
def extract_text_from_excel(file_obj: Union[str, Dict[str, Any]]) -> str:
    # Resolve the uploaded file path from either a Gradio file dict or a plain string path.
    if isinstance(file_obj, dict) and 'name' in file_obj:
        file_path = file_obj['name']
    elif isinstance(file_obj, str):
        file_path = file_obj
    else:
        raise ValueError("Unsupported file input type")
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")
    xls = pd.ExcelFile(file_path)
    all_text = []
    for sheet in xls.sheet_names:
        try:
            # Fill missing cells before casting to str, otherwise NaN becomes the literal string "nan".
            df = xls.parse(sheet).fillna("").astype(str)
            rows = df.apply(lambda r: " | ".join([c for c in r if c.strip()]), axis=1)
            sheet_text = [f"[{sheet}] {line}" for line in rows if line.strip()]
            all_text.extend(sheet_text)
        except Exception as e:
            logger.warning(f"Failed to parse {sheet}: {e}")
    return "\n".join(all_text)
def split_text_into_chunks(text: str) -> List[str]:
    lines = text.split("\n")
    chunks, current, current_tokens = [], [], 0
    max_tokens = MAX_CHUNK_TOKENS - PROMPT_OVERHEAD
    for line in lines:
        t = estimate_tokens(line)
        # Start a new chunk only if the current one is non-empty and would overflow the budget.
        if current and current_tokens + t > max_tokens:
            chunks.append("\n".join(current))
            current, current_tokens = [line], t
        else:
            current.append(line)
            current_tokens += t
    if current:
        chunks.append("\n".join(current))
    return chunks
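# Example: split_text_into_chunks("a\nb\nc") returns ["a\nb\nc"], since three one-token lines
# fit comfortably inside the ~7692-token chunk budget.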
def build_prompt_from_text(chunk: str) -> str:
    return f"""
### Clinical Records Analysis
Please analyze these clinical notes and provide:
- Key diagnostic indicators
- Current medications and potential issues
- Recommended follow-up actions
- Any inconsistencies or concerns
---
{chunk}
---
Provide a structured response with clear medical reasoning.
"""
def clean_and_rewrite_tool_file(original_path: str, cleaned_path: str) -> bool:
    try:
        with open(original_path, "r") as f:
            data = json.load(f)
        if isinstance(data, dict) and "tools" in data:
            tools = data["tools"]
        elif isinstance(data, list):
            tools = data
        elif isinstance(data, dict) and "name" in data:
            tools = [data]
        else:
            return False
        if not all(isinstance(t, dict) and "name" in t for t in tools):
            return False
        with open(cleaned_path, "w") as out:
            json.dump(tools, out)
        return True
    except Exception as e:
        logger.error(f"Failed to clean tool {original_path}: {e}")
        return False
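# clean_and_rewrite_tool_file accepts any of these layouts and rewrites them as a flat list:
#   {"tools": [{"name": ...}, ...]}   |   [{"name": ...}, ...]   |   {"name": ..., ...}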
def init_agent() -> TxAgent:
    new_tool_path = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(new_tool_path):
        with open(new_tool_path, 'w') as f:
            json.dump([{"name": "dummy_tool", "description": "test", "version": "1.0"}], f)
    # NOTE: these tooluniverse JSON paths are specific to this Space's pyenv environment.
    raw_tool_files = {
        'opentarget': '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/opentarget_tools.json',
        'fda_drug_label': '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/fda_drug_labeling_tools.json',
        'special_tools': '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/special_tools.json',
        'monarch': '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data/monarch_tools.json',
        'new_tool': new_tool_path
    }
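    # Alternative sketch (not from the original code): the data directory could be derived from
    # the installed package instead of hard-coding the pyenv path, e.g.
    #   import tooluniverse
    #   data_dir = os.path.join(os.path.dirname(tooluniverse.__file__), "data")
    # The hard-coded paths above are kept because they match this Space's environment.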
    validated_paths = {}
    for name, original_path in raw_tool_files.items():
        cleaned_path = os.path.join(tool_cache_dir, f"{name}_cleaned.json")
        if clean_and_rewrite_tool_file(original_path, cleaned_path):
            validated_paths[name] = cleaned_path
    if not validated_paths:
        raise ValueError("No valid tools found after sanitizing.")
    agent = TxAgent(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict=validated_paths,
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=100
    )
    agent.init_model()
    return agent
def stream_report(agent: TxAgent, input_file: Union[str, Dict[str, Any]], full_output: str) -> Generator[Tuple[str, Union[str, None], str], None, None]:
    accumulated = ""
    try:
        if input_file is None:
            yield "❌ Upload a valid Excel file.", None, ""
            return
        text = extract_text_from_excel(input_file)
        chunks = split_text_into_chunks(text)
        for i, chunk in enumerate(chunks):
            prompt = build_prompt_from_text(chunk)
            result = ""
            for out in agent.run_gradio_chat(
                    message=prompt, history=[], temperature=0.2,
                    max_new_tokens=MAX_NEW_TOKENS, max_token=MAX_MODEL_TOKENS,
                    call_agent=False, conversation=[]):
                result += out if isinstance(out, str) else out.content
            cleaned = clean_response(result)
            accumulated += f"\n\n📄 Part {i+1}:\n{cleaned}"
            yield accumulated, None, ""
        summary_prompt = f"Summarize this analysis:\n\n{accumulated}"
        summary = ""
        for out in agent.run_gradio_chat(
                message=summary_prompt, history=[], temperature=0.2,
                max_new_tokens=MAX_NEW_TOKENS, max_token=MAX_MODEL_TOKENS,
                call_agent=False, conversation=[]):
            summary += out if isinstance(out, str) else out.content
        final = clean_response(summary)
        report_path = os.path.join(report_dir, f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md")
        with open(report_path, 'w') as f:
            f.write(f"# Clinical Report\n\n{final}")
        yield f"{accumulated}\n\n📊 Final Summary:\n{final}", report_path, final
    except Exception as e:
        logger.error(f"Stream error: {e}", exc_info=True)
        yield f"❌ Error: {str(e)}", None, ""
def create_ui(agent: TxAgent) -> gr.Blocks:
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🏥 Clinical Records Analyzer")
        with gr.Row():
            file_upload = gr.File(label="Upload Excel File", file_types=[".xlsx"])
            analyze_btn = gr.Button("Analyze", variant="primary")
        with gr.Row():
            with gr.Column(scale=2):
                report_output = gr.Markdown()
            with gr.Column(scale=1):
                report_file = gr.File(label="Download", visible=False)
        full_output = gr.State()

        def analyze(file_obj, prev_output):
            # Wrapper generator so the already-initialized agent is bound to the click handler.
            yield from stream_report(agent, file_obj, prev_output)

        analyze_btn.click(fn=analyze, inputs=[file_upload, full_output], outputs=[report_output, report_file, full_output])
    return demo
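# Note: report_file is created with visible=False, so the report path written to it on the final
# yield is stored but the download widget itself is never shown by this UI.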
if __name__ == "__main__":
    try:
        agent = init_agent()
        demo = create_ui(agent)
        demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
    except Exception as e:
        logger.error(f"App error: {e}", exc_info=True)
        print(f"❌ Application error: {e}", file=sys.stderr)
        sys.exit(1)