|
|
|
|
|
import sys |
|
import os |
|
import pandas as pd |
|
import json |
|
import gradio as gr |
|
from typing import List, Tuple, Union, Generator, Dict, Any |
|
import re |
|
from datetime import datetime |
|
import atexit |
|
import torch.distributed as dist |
|
import logging |
|
|
|
|
|
# Root logging: INFO and above, each record prefixed with its timestamp and
# severity. The module-wide logger is registered under the name "app".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger("app")
|
|
|
|
|
|
|
def cleanup() -> None:
    """Destroy the PyTorch distributed process group, if one exists.

    Registered below as an ``atexit`` hook. Does nothing when the
    distributed package is unavailable or no process group was initialized.
    """
    # Check availability first so that builds of torch without distributed
    # support skip the teardown instead of failing inside the exit hook.
    if dist.is_available() and dist.is_initialized():
        logger.info("Cleaning up PyTorch distributed process group")
        dist.destroy_process_group()


atexit.register(cleanup)
|
|
|
|
|
# Persistent storage layout. Everything lives under one root directory; the
# sub-directories hold model weights, sanitized tool files, a file cache and
# the generated reports.
persistent_dir = "/data/hf_cache"
model_cache_dir, tool_cache_dir, file_cache_dir, report_dir = (
    os.path.join(persistent_dir, leaf)
    for leaf in ("txagent_models", "tool_cache", "cache", "reports")
)

# Create the root first, then each sub-directory (no error if present).
for d in (persistent_dir, model_cache_dir, tool_cache_dir, file_cache_dir, report_dir):
    os.makedirs(d, exist_ok=True)

# Point both cache environment variables at the model cache directory.
os.environ.update(HF_HOME=model_cache_dir, TRANSFORMERS_CACHE=model_cache_dir)
|
|
|
|
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "src"))) |
|
from txagent.txagent import TxAgent |
|
|
|
# Token budgets used throughout this module.
MAX_MODEL_TOKENS = 32_768  # passed to the agent as ``max_token``
MAX_CHUNK_TOKENS = 8_192   # upper bound for one prompt chunk, overhead included
MAX_NEW_TOKENS = 2_048     # passed to the agent as ``max_new_tokens``
PROMPT_OVERHEAD = 500      # subtracted from MAX_CHUNK_TOKENS when chunking text
|
|
|
|
|
def clean_response(text: str) -> str:
    """Strip bracketed spans, stray ``None`` tokens and unwanted symbols.

    Steps, in order:
      1. Remove every ``[...]`` span (non-greedy, may cross line breaks) and
         every standalone word ``None``.
      2. Remove any run of characters other than word characters,
         whitespace, and ``# - * . , : ( )``.
      3. Collapse runs of three or more newlines into exactly two.

    Args:
        text: Raw model output.

    Returns:
        The cleaned text with leading and trailing whitespace removed.
    """
    text = re.sub(r"\[.*?\]|\bNone\b", "", text, flags=re.DOTALL)
    text = re.sub(r"[^\n#\-\*\w\s\.,:\(\)]+", "", text)
    # Collapse blank lines last: stripping a line made only of symbols can
    # itself create a new run of empty lines.
    text = re.sub(r"\n{3,}", "\n\n", text)
    return text.strip()
|
|
|
|
|
def estimate_tokens(text: str) -> int:
    """Return a rough token count for *text*.

    The estimate is one token per 3.5 characters, rounded down, plus one.
    An empty string therefore counts as 1.

    Args:
        text: The text to measure.

    Returns:
        The estimated number of tokens, always an ``int`` >= 1.
    """
    # Floor division by a float yields a float; convert so the result
    # matches the declared return type.
    return int(len(text) // 3.5) + 1
|
|
|
|
|
def extract_text_from_excel(file_obj: Union[str, Dict[str, Any]]) -> str:
    """Flatten every sheet of an Excel workbook into newline-separated text.

    Each row that has at least one non-blank cell becomes one line of the
    form ``[sheet] a | b | c``, built from that row's non-blank cells.
    Sheets that fail to parse are logged and skipped.

    Args:
        file_obj: Path to the workbook, or a dict holding the path under
            the ``'name'`` key.

    Returns:
        All generated lines joined with ``"\\n"`` (empty string if none).

    Raises:
        ValueError: If *file_obj* is neither a string nor such a dict.
        FileNotFoundError: If the resolved path does not exist.
    """
    if isinstance(file_obj, dict) and 'name' in file_obj:
        file_path = file_obj['name']
    elif isinstance(file_obj, str):
        file_path = file_obj
    else:
        raise ValueError("Unsupported file input type")
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    all_text = []
    # The context manager closes the workbook handle once parsing is done.
    with pd.ExcelFile(file_path) as xls:
        for sheet in xls.sheet_names:
            try:
                raw = xls.parse(sheet)
                # Blank out missing cells using the original frame's mask:
                # astype(str) alone would turn NaN into the literal "nan".
                df = raw.astype(str).where(raw.notna(), "")
                sheet_text = []
                for row in df.itertuples(index=False, name=None):
                    cells = [c for c in row if c.strip()]
                    if cells:
                        joined = " | ".join(cells)
                        sheet_text.append(f"[{sheet}] {joined}")
                # Only commit the sheet once it has been read completely.
                all_text.extend(sheet_text)
            except Exception as e:
                logger.warning(f"Failed to parse {sheet}: {e}")
    return "\n".join(all_text)
|
|
|
|
|
def split_text_into_chunks(text: str) -> List[str]:
    """Group the lines of *text* into chunks that fit the per-chunk budget.

    Lines are never split. They are packed greedily, in order, until adding
    the next line would push the estimated token count past
    ``MAX_CHUNK_TOKENS - PROMPT_OVERHEAD``. A single line that exceeds the
    budget on its own becomes a chunk by itself.

    Args:
        text: Newline-separated text.

    Returns:
        The chunks, each a ``"\\n"``-joined group of consecutive lines.
    """
    lines = text.split("\n")
    chunks, current, current_tokens = [], [], 0
    max_tokens = MAX_CHUNK_TOKENS - PROMPT_OVERHEAD
    for line in lines:
        t = estimate_tokens(line)
        # Flush only when something is pending; otherwise an oversized
        # first line would emit an empty chunk ahead of itself.
        if current and current_tokens + t > max_tokens:
            chunks.append("\n".join(current))
            current, current_tokens = [line], t
        else:
            current.append(line)
            current_tokens += t
    if current:
        chunks.append("\n".join(current))
    return chunks
|
|
|
|
|
def build_prompt_from_text(chunk: str) -> str:
    """Wrap one chunk of clinical notes in the analysis instruction prompt.

    The chunk is inserted verbatim between two ``---`` rules. The task
    description comes before it and a request for a structured, reasoned
    answer comes after it.
    """
    preamble = (
        "\n"
        "### Clinical Records Analysis\n"
        "\n"
        "Please analyze these clinical notes and provide:\n"
        "- Key diagnostic indicators\n"
        "- Current medications and potential issues\n"
        "- Recommended follow-up actions\n"
        "- Any inconsistencies or concerns\n"
        "\n"
        "---\n"
        "\n"
    )
    closing = (
        "\n"
        "\n"
        "---\n"
        "Provide a structured response with clear medical reasoning.\n"
    )
    return f"{preamble}{chunk}{closing}"
|
|
|
|
|
def clean_and_rewrite_tool_file(original_path: str, cleaned_path: str) -> bool:
    """Normalize a tool-definition JSON file into a flat list on disk.

    Accepted input shapes are a dict with a ``"tools"`` key, a bare list,
    or a single dict with a ``"name"`` key. Every entry must be a dict that
    has a ``"name"`` key; otherwise nothing is written.

    Args:
        original_path: JSON file to read.
        cleaned_path: Destination for the normalized list.

    Returns:
        True if the list was written, False if the input was rejected or
        any error occurred (errors are logged).
    """
    try:
        with open(original_path, "r") as src:
            payload = json.load(src)

        is_mapping = isinstance(payload, dict)
        if is_mapping and "tools" in payload:
            entries = payload["tools"]
        elif isinstance(payload, list):
            entries = payload
        elif is_mapping and "name" in payload:
            entries = [payload]
        else:
            return False

        # Reject the whole file on the first malformed entry.
        for entry in entries:
            if not (isinstance(entry, dict) and "name" in entry):
                return False

        with open(cleaned_path, "w") as dst:
            json.dump(entries, dst)
    except Exception as e:
        logger.error(f"Failed to clean tool {original_path}: {e}")
        return False
    return True
|
|
|
|
|
def init_agent() -> TxAgent:
    """Sanitize the tool files, then build and initialize the TxAgent.

    A placeholder ``new_tool.json`` is created in the tool cache when it is
    missing. Each known tool file is passed through
    ``clean_and_rewrite_tool_file``; only the files that pass are handed to
    the agent.

    Returns:
        The agent, after ``init_model()`` has been called on it.

    Raises:
        ValueError: If none of the tool files passed sanitizing.
    """
    placeholder_path = os.path.join(tool_cache_dir, "new_tool.json")
    if not os.path.exists(placeholder_path):
        stub = [{"name": "dummy_tool", "description": "test", "version": "1.0"}]
        with open(placeholder_path, 'w') as handle:
            json.dump(stub, handle)

    # Tool definitions shipped with the installed tooluniverse package.
    data_dir = '/home/user/.pyenv/versions/3.10.17/lib/python3.10/site-packages/tooluniverse/data'
    bundled = (
        ('opentarget', 'opentarget_tools.json'),
        ('fda_drug_label', 'fda_drug_labeling_tools.json'),
        ('special_tools', 'special_tools.json'),
        ('monarch', 'monarch_tools.json'),
    )
    raw_tool_files = {key: f"{data_dir}/{fname}" for key, fname in bundled}
    raw_tool_files['new_tool'] = placeholder_path

    # Keep only the files that could be normalized into the tool cache.
    validated_paths = {}
    for key, source in raw_tool_files.items():
        target = os.path.join(tool_cache_dir, f"{key}_cleaned.json")
        if not clean_and_rewrite_tool_file(source, target):
            continue
        validated_paths[key] = target

    if len(validated_paths) == 0:
        raise ValueError("No valid tools found after sanitizing.")

    agent_config = dict(
        model_name="mims-harvard/TxAgent-T1-Llama-3.1-8B",
        rag_model_name="mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B",
        tool_files_dict=validated_paths,
        force_finish=True,
        enable_checker=True,
        step_rag_num=4,
        seed=100,
    )
    agent = TxAgent(**agent_config)
    agent.init_model()
    return agent
|
|
|
|
|
def _collect_agent_text(agent: "TxAgent", prompt: str) -> str:
    """Run one single-turn agent chat and concatenate everything it streams."""
    pieces = []
    for out in agent.run_gradio_chat(
        message=prompt, history=[], temperature=0.2,
        max_new_tokens=MAX_NEW_TOKENS, max_token=MAX_MODEL_TOKENS,
        call_agent=False, conversation=[],
    ):
        # Streamed items are either plain strings or objects with ``.content``.
        pieces.append(out if isinstance(out, str) else out.content)
    return "".join(pieces)


def stream_report(agent: "TxAgent", input_file: Union[str, Dict[str, Any]], full_output: str) -> Generator[Tuple[str, Union[str, None], str], None, None]:
    """Analyze an uploaded Excel file chunk by chunk and stream the report.

    The workbook text is split into chunks and each chunk is analyzed by the
    agent. After each chunk the accumulated partial report is yielded. A
    final summary of all parts is then generated, written to a Markdown file
    in ``report_dir`` and yielded together with that file's path.

    Args:
        agent: Object exposing ``run_gradio_chat``.
        input_file: Uploaded file (path or dict with a ``'name'`` key), or
            ``None`` when nothing was uploaded.
        full_output: Unused; kept so the signature stays stable.

    Yields:
        ``(display_markdown, report_path_or_None, final_summary)`` tuples.
        The path and summary are only populated on the last, successful
        yield. Any exception is logged and reported as a single error tuple.
    """
    accumulated = ""
    try:
        if input_file is None:
            yield "❌ Upload a valid Excel file.", None, ""
            return

        text = extract_text_from_excel(input_file)
        # An empty workbook would otherwise send a prompt with no notes.
        if not text.strip():
            yield "❌ No readable text found in the uploaded Excel file.", None, ""
            return

        chunks = split_text_into_chunks(text)
        for i, chunk in enumerate(chunks):
            prompt = build_prompt_from_text(chunk)
            cleaned = clean_response(_collect_agent_text(agent, prompt))
            accumulated += f"\n\n📄 Part {i+1}:\n{cleaned}"
            yield accumulated, None, ""

        summary_prompt = f"Summarize this analysis:\n\n{accumulated}"
        final = clean_response(_collect_agent_text(agent, summary_prompt))

        report_path = os.path.join(report_dir, f"report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md")
        # Explicit UTF-8 so the write does not depend on the process locale.
        with open(report_path, 'w', encoding="utf-8") as f:
            f.write(f"# Clinical Report\n\n{final}")
        yield f"{accumulated}\n\n📊 Final Summary:\n{final}", report_path, final
    except Exception as e:
        logger.error(f"Stream error: {e}", exc_info=True)
        yield f"❌ Error: {str(e)}", None, ""
|
|
|
|
|
def create_ui(agent: TxAgent) -> gr.Blocks:
    """Build the Gradio interface for uploading a workbook and viewing the report.

    Layout: a title, a row with the file upload and the "Analyze" button,
    then a row with the Markdown report (wide column) and the report
    download (narrow column, hidden until a report file exists).

    Args:
        agent: The agent handed to ``stream_report`` on every click.

    Returns:
        The assembled ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🏥 Clinical Records Analyzer")
        with gr.Row():
            file_upload = gr.File(label="Upload Excel File", file_types=[".xlsx"])
            analyze_btn = gr.Button("Analyze", variant="primary")
        with gr.Row():
            with gr.Column(scale=2):
                report_output = gr.Markdown()
            with gr.Column(scale=1):
                report_file = gr.File(label="Download", visible=False)
        full_output = gr.State()

        def run_analysis(input_file, previous_output):
            # stream_report takes the agent as its first argument, but the
            # click event only supplies the two UI inputs; bind it here.
            for text, path, final in stream_report(agent, input_file, previous_output):
                # Reveal the download component only once a file exists.
                yield text, gr.update(value=path, visible=path is not None), final

        analyze_btn.click(
            fn=run_analysis,
            inputs=[file_upload, full_output],
            outputs=[report_output, report_file, full_output],
        )
    return demo
|
|
|
|
|
if __name__ == "__main__":
    # Build the agent, assemble the UI and serve it on all interfaces at
    # port 7860. Any startup failure is logged with its traceback, echoed
    # to stderr, and turned into exit status 1.
    try:
        agent = init_agent()
        demo = create_ui(agent)
        demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
    except Exception as e:
        logger.error(f"App error: {e}", exc_info=True)
        print(f"❌ Application error: {e}", file=sys.stderr)
        sys.exit(1)
|
|