Ali2206 committed on
Commit
70839bb
·
verified ·
1 Parent(s): 7bacf67

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +158 -0
app.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import torch
4
+ import gradio as gr
5
+ from txagent import TxAgent
6
+
7
# --- Logging ------------------------------------------------------------
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# --- Model / data configuration -----------------------------------------
MODEL_NAME = "mims-harvard/TxAgent-T1-Llama-3.1-8B"
RAG_MODEL_NAME = "mims-harvard/ToolRAG-T1-GTE-Qwen2-1.5B"
TOOL_FILE = "data/new_tool.json"

# --- Process environment tweaks -----------------------------------------
# Silence tokenizer fork warnings, defer CUDA module loading, and make
# vLLM workers spawn instead of fork.
os.environ.update({
    "TOKENIZERS_PARALLELISM": "false",
    "CUDA_MODULE_LOADING": "LAZY",
    "VLLM_WORKER_MULTIPROC_METHOD": "spawn",
})
23
+
24
class TxAgentSystem:
    """Wraps a TxAgent model behind a Gradio chat UI.

    The model is loaded eagerly in ``__init__``. If loading fails, the UI
    still launches in a degraded mode whose status box reports the failure
    and whose chat handler refuses queries, instead of crashing the process
    before the interface ever appears.
    """

    def __init__(self):
        """Check GPU availability and load the agent.

        Raises:
            RuntimeError: if no CUDA device is available (hard requirement).
        """
        # Agent handle plus the flag the UI checks before serving queries.
        self.agent = None
        self.is_initialized = False
        self.examples = [
            ["A 68-year-old with CKD prescribed metformin. Safe for renal clearance?"],
            ["30-year-old on Prozac diagnosed with WHIM. Safe to take Xolremdi?"]
        ]

        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available - GPU required")

        logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
        logger.info(f"VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB")

        self._initialize_system()

    def _initialize_system(self):
        """Load TxAgent, preferring RAG and falling back to plain generation.

        On success sets ``self.is_initialized`` to True. On failure it logs
        the error and leaves the flag False so ``launch_ui()`` can surface
        the failed status instead of the process dying before launch.
        """
        try:
            # Ensure the tool file exists; an empty JSON list is a valid
            # (no-op) tool registry.
            os.makedirs("data", exist_ok=True)
            if not os.path.exists(TOOL_FILE):
                with open(TOOL_FILE, "w") as f:
                    f.write("[]")

            logger.info("Initializing TxAgent...")

            # Prefer the RAG-enabled configuration; if the RAG model cannot
            # be constructed, retry with RAG fully disabled.
            try:
                self.agent = TxAgent(
                    model_name=MODEL_NAME,
                    rag_model_name=RAG_MODEL_NAME,
                    tool_files_dict={"new_tool": TOOL_FILE},
                    force_finish=True,
                    enable_checker=True,
                    step_rag_num=10,
                    seed=100,
                    enable_rag=True
                )
            except Exception as e:
                logger.warning(f"Failed to initialize with RAG: {str(e)}")
                logger.info("Retrying without RAG...")
                self.agent = TxAgent(
                    model_name=MODEL_NAME,
                    rag_model_name=None,
                    tool_files_dict={"new_tool": TOOL_FILE},
                    force_finish=True,
                    enable_checker=True,
                    step_rag_num=0,
                    seed=100,
                    enable_rag=False
                )

            logger.info("Loading main model...")
            self.agent.init_model()

            self.is_initialized = True
            logger.info("System initialization completed successfully")

        except Exception as e:
            # BUGFIX: this previously re-raised, aborting the process before
            # launch_ui() could show the "Initialization failed" status and
            # making chat_fn's not-initialized guard unreachable. Log and
            # continue so the UI can report the degraded state.
            logger.error(f"System initialization failed: {str(e)}")
            self.is_initialized = False

    def chat_fn(self, message, history, temperature, max_tokens, rag_depth):
        """Gradio callback: answer ``message`` and append to ``history``.

        Returns:
            tuple[str, list]: ("", new_history) — the empty string clears
            the input textbox. Errors are reported as chat messages rather
            than raised into Gradio.
        """
        if not self.is_initialized:
            return "", history + [(message, "System initialization failed. Please check logs.")]

        try:
            # NOTE(review): assumes run_gradio_chat returns a plain string;
            # confirm against the installed txagent version (some versions
            # stream results as a generator).
            response = self.agent.run_gradio_chat(
                message=message,
                history=history,
                temperature=temperature,
                max_new_tokens=max_tokens,
                max_total_tokens=16384,
                enable_multi_agent=False,
                conv_history=history,
                max_steps=rag_depth,
                seed=100
            )
            return "", history + [(message, response)]

        except torch.cuda.OutOfMemoryError:
            # Release cached allocations so a shorter follow-up can succeed.
            torch.cuda.empty_cache()
            return "", history + [(message, "⚠️ GPU memory overflow. Please try a shorter query.")]

        except Exception as e:
            logger.error(f"Chat error: {str(e)}")
            return "", history + [(message, f"🚨 Error: {str(e)}")]

    def launch_ui(self):
        """Build and launch the Gradio Blocks interface (blocking call)."""
        with gr.Blocks(theme=gr.themes.Soft(), title="TxAgent Medical AI") as demo:
            gr.Markdown("## 🧠 TxAgent (A100/H100 Optimized)")

            status = gr.Textbox(
                value="✅ System ready" if self.is_initialized else "❌ Initialization failed",
                label="System Status",
                interactive=False
            )

            with gr.Row():
                with gr.Column(scale=3):
                    chatbot = gr.Chatbot(height=600, label="Conversation History")
                    msg = gr.Textbox(label="Enter Medical Query", placeholder="Type your question here...")
                with gr.Column(scale=1):
                    temp = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
                    max_tokens = gr.Slider(128, 8192, value=2048, label="Max Response Tokens")
                    rag_depth = gr.Slider(1, 20, value=10, label="RAG Depth")
                    clear_btn = gr.Button("Clear History")

            gr.Examples(
                examples=self.examples,
                inputs=msg,
                label="Example Queries"
            )

            msg.submit(
                self.chat_fn,
                inputs=[msg, chatbot, temp, max_tokens, rag_depth],
                outputs=[msg, chatbot]
            )
            clear_btn.click(lambda: None, None, chatbot, queue=False)

        # Bind on all interfaces so the Space/container is reachable.
        demo.launch(
            server_name="0.0.0.0",
            server_port=7860
        )
151
+
152
if __name__ == "__main__":
    # Entry point: build the system, then serve the UI. Fail loudly — log
    # the fatal error and re-raise so the process exits nonzero.
    try:
        app = TxAgentSystem()
        app.launch_ui()
    except Exception as exc:
        logger.critical(f"Fatal error: {str(exc)}")
        raise