Dhruv-Ty commited on
Commit
c044359
·
1 Parent(s): 8080b77
Files changed (1) hide show
  1. app.py +10 -42
app.py CHANGED
@@ -4,8 +4,6 @@ from typing import *
4
  from dotenv import load_dotenv
5
  from transformers import logging
6
 
7
- from langgraph.checkpoint.memory import MemorySaver
8
- from langchain_openai import ChatOpenAI
9
  from langgraph.checkpoint.memory import MemorySaver
10
  from langchain_openai import ChatOpenAI
11
 
@@ -16,12 +14,7 @@ from medrax.utils import *
16
 
17
  warnings.filterwarnings("ignore")
18
  logging.set_verbosity_error()
19
- _ = load_dotenv()
20
-
21
- openai_kwargs = {
22
- "openai_api_key": os.environ.get("OPENAI_API_KEY"),
23
- "openai_api_base": os.environ.get("OPENAI_BASE_URL"),
24
- }
25
 
26
  def initialize_agent(
27
  prompt_file,
@@ -29,28 +22,11 @@ def initialize_agent(
29
  model_dir="./model-weights",
30
  temp_dir="temp",
31
  device="cuda",
32
- model="google/gemini-2.5-pro-exp-03-25:free",
33
  temperature=0.7,
34
  top_p=0.95,
35
- openai_kwargs=openai_kwargs
36
-
37
  ):
38
- """Initialize the MedRAX agent with specified tools and configuration.
39
-
40
- Args:
41
- prompt_file (str): Path to file containing system prompts
42
- tools_to_use (List[str], optional): List of tool names to initialize. If None, all tools are initialized.
43
- model_dir (str, optional): Directory containing model weights. Defaults to "/model-weights".
44
- temp_dir (str, optional): Directory for temporary files. Defaults to "temp".
45
- device (str, optional): Device to run models on. Defaults to "cuda".
46
- model (str, optional): Model to use. Defaults to "chatgpt-4o-latest".
47
- temperature (float, optional): Temperature for the model. Defaults to 0.7.
48
- top_p (float, optional): Top P for the model. Defaults to 0.95.
49
- openai_kwargs (dict, optional): Additional keyword arguments for OpenAI API, such as API key and base URL.
50
-
51
- Returns:
52
- Tuple[Agent, Dict[str, BaseTool]]: Initialized agent and dictionary of tool instances
53
- """
54
  prompts = load_prompts_from_file(prompt_file)
55
  prompt = prompts["MEDICAL_ASSISTANT"]
56
 
@@ -72,7 +48,6 @@ def initialize_agent(
72
  "DicomProcessorTool": lambda: DicomProcessorTool(temp_dir=temp_dir),
73
  }
74
 
75
- # Initialize only selected tools or all if none specified
76
  tools_dict = {}
77
  tools_to_use = tools_to_use or all_tools.keys()
78
  for tool_name in tools_to_use:
@@ -80,6 +55,7 @@ def initialize_agent(
80
  tools_dict[tool_name] = all_tools[tool_name]()
81
 
82
  checkpointer = MemorySaver()
 
83
  model = ChatOpenAI(model=model, temperature=temperature, top_p=top_p, **openai_kwargs)
84
  agent = Agent(
85
  model,
@@ -95,14 +71,8 @@ def initialize_agent(
95
 
96
 
97
  if __name__ == "__main__":
98
- """
99
- This is the main entry point for the MedRAX application.
100
- It initializes the agent with the selected tools and creates the demo.
101
- """
102
  print("Starting server...")
103
 
104
- # Example: initialize with only specific tools
105
- # Here three tools are commented out, you can uncomment them to use them
106
  selected_tools = [
107
  "ImageVisualizerTool",
108
  "DicomProcessorTool",
@@ -115,26 +85,24 @@ if __name__ == "__main__":
115
  # "ChestXRayGeneratorTool",
116
  ]
117
 
118
- # Collect the ENV variables
119
  openai_kwargs = {}
120
  if api_key := os.getenv("OPENAI_API_KEY"):
121
  openai_kwargs["openai_api_key"] = api_key
122
-
123
  if base_url := os.getenv("OPENAI_BASE_URL"):
124
  openai_kwargs["openai_api_base"] = base_url
125
-
126
 
127
  agent, tools_dict = initialize_agent(
128
  "medrax/docs/system_prompts.txt",
129
  tools_to_use=selected_tools,
130
- model_dir="./model-weights", # Change this to the path of the model weights
131
- temp_dir="temp", # Change this to the path of the temporary directory
132
- device="cuda", # Change this to the device you want to use
133
- model="gpt-4o", # Change this to the model you want to use, e.g. gpt-4o-mini
134
  temperature=0.7,
135
  top_p=0.95,
136
  openai_kwargs=openai_kwargs
137
  )
138
- demo = create_demo(agent, tools_dict)
139
 
 
140
  demo.launch(server_name="0.0.0.0", server_port=8585, share=True)
 
4
  from dotenv import load_dotenv
5
  from transformers import logging
6
 
 
 
7
  from langgraph.checkpoint.memory import MemorySaver
8
  from langchain_openai import ChatOpenAI
9
 
 
14
 
15
  warnings.filterwarnings("ignore")
16
  logging.set_verbosity_error()
17
+ _ = load_dotenv() # Loads your .env file (OPENAI_API_KEY and OPENAI_BASE_URL)
 
 
 
 
 
18
 
19
  def initialize_agent(
20
  prompt_file,
 
22
  model_dir="./model-weights",
23
  temp_dir="temp",
24
  device="cuda",
25
+ model="google/gemini-1.5-flash-latest", # ✅ updated model name for OpenRouter
26
  temperature=0.7,
27
  top_p=0.95,
28
+ openai_kwargs=None
 
29
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  prompts = load_prompts_from_file(prompt_file)
31
  prompt = prompts["MEDICAL_ASSISTANT"]
32
 
 
48
  "DicomProcessorTool": lambda: DicomProcessorTool(temp_dir=temp_dir),
49
  }
50
 
 
51
  tools_dict = {}
52
  tools_to_use = tools_to_use or all_tools.keys()
53
  for tool_name in tools_to_use:
 
55
  tools_dict[tool_name] = all_tools[tool_name]()
56
 
57
  checkpointer = MemorySaver()
58
+
59
  model = ChatOpenAI(model=model, temperature=temperature, top_p=top_p, **openai_kwargs)
60
  agent = Agent(
61
  model,
 
71
 
72
 
73
  if __name__ == "__main__":
 
 
 
 
74
  print("Starting server...")
75
 
 
 
76
  selected_tools = [
77
  "ImageVisualizerTool",
78
  "DicomProcessorTool",
 
85
  # "ChestXRayGeneratorTool",
86
  ]
87
 
88
+ # ✅ Collect environment variables and pass to model
89
  openai_kwargs = {}
90
  if api_key := os.getenv("OPENAI_API_KEY"):
91
  openai_kwargs["openai_api_key"] = api_key
 
92
  if base_url := os.getenv("OPENAI_BASE_URL"):
93
  openai_kwargs["openai_api_base"] = base_url
 
94
 
95
  agent, tools_dict = initialize_agent(
96
  "medrax/docs/system_prompts.txt",
97
  tools_to_use=selected_tools,
98
+ model_dir="./model-weights",
99
+ temp_dir="temp",
100
+ device="cuda",
101
+ model="google/gemini-1.5-flash-latest", # ✅ Updated OpenRouter model
102
  temperature=0.7,
103
  top_p=0.95,
104
  openai_kwargs=openai_kwargs
105
  )
 
106
 
107
+ demo = create_demo(agent, tools_dict)
108
  demo.launch(server_name="0.0.0.0", server_port=8585, share=True)