Dhruv-Ty commited on
Commit
3f298d8
·
1 Parent(s): c044359

openrouter fix

Browse files
Files changed (1) hide show
  1. app.py +15 -16
app.py CHANGED
@@ -5,7 +5,7 @@ from dotenv import load_dotenv
5
  from transformers import logging
6
 
7
  from langgraph.checkpoint.memory import MemorySaver
8
- from langchain_openai import ChatOpenAI
9
 
10
  from interface import create_demo
11
  from medrax.agent import *
@@ -14,7 +14,7 @@ from medrax.utils import *
14
 
15
  warnings.filterwarnings("ignore")
16
  logging.set_verbosity_error()
17
- _ = load_dotenv() # Loads your .env file (OPENAI_API_KEY and OPENAI_BASE_URL)
18
 
19
  def initialize_agent(
20
  prompt_file,
@@ -22,11 +22,12 @@ def initialize_agent(
22
  model_dir="./model-weights",
23
  temp_dir="temp",
24
  device="cuda",
25
- model="google/gemini-1.5-flash-latest", # ✅ updated model name for OpenRouter
26
  temperature=0.7,
27
- top_p=0.95,
28
- openai_kwargs=None
29
  ):
 
 
30
  prompts = load_prompts_from_file(prompt_file)
31
  prompt = prompts["MEDICAL_ASSISTANT"]
32
 
@@ -55,8 +56,14 @@ def initialize_agent(
55
  tools_dict[tool_name] = all_tools[tool_name]()
56
 
57
  checkpointer = MemorySaver()
 
 
 
 
 
 
 
58
 
59
- model = ChatOpenAI(model=model, temperature=temperature, top_p=top_p, **openai_kwargs)
60
  agent = Agent(
61
  model,
62
  tools=list(tools_dict.values()),
@@ -85,23 +92,15 @@ if __name__ == "__main__":
85
  # "ChestXRayGeneratorTool",
86
  ]
87
 
88
- # ✅ Collect environment variables and pass to model
89
- openai_kwargs = {}
90
- if api_key := os.getenv("OPENAI_API_KEY"):
91
- openai_kwargs["openai_api_key"] = api_key
92
- if base_url := os.getenv("OPENAI_BASE_URL"):
93
- openai_kwargs["openai_api_base"] = base_url
94
-
95
  agent, tools_dict = initialize_agent(
96
  "medrax/docs/system_prompts.txt",
97
  tools_to_use=selected_tools,
98
  model_dir="./model-weights",
99
  temp_dir="temp",
100
  device="cuda",
101
- model="google/gemini-1.5-flash-latest", # ✅ Updated OpenRouter model
102
  temperature=0.7,
103
- top_p=0.95,
104
- openai_kwargs=openai_kwargs
105
  )
106
 
107
  demo = create_demo(agent, tools_dict)
 
5
  from transformers import logging
6
 
7
  from langgraph.checkpoint.memory import MemorySaver
8
+ from langchain_community.chat_models import ChatOpenRouter
9
 
10
  from interface import create_demo
11
  from medrax.agent import *
 
14
 
15
  warnings.filterwarnings("ignore")
16
  logging.set_verbosity_error()
17
+ _ = load_dotenv()
18
 
19
  def initialize_agent(
20
  prompt_file,
 
22
  model_dir="./model-weights",
23
  temp_dir="temp",
24
  device="cuda",
25
+ model="google/gemini-2.5-pro-exp-03-25:free",
26
  temperature=0.7,
27
+ top_p=0.95
 
28
  ):
29
+ """Initialize the MedRAX agent with specified tools and configuration."""
30
+
31
  prompts = load_prompts_from_file(prompt_file)
32
  prompt = prompts["MEDICAL_ASSISTANT"]
33
 
 
56
  tools_dict[tool_name] = all_tools[tool_name]()
57
 
58
  checkpointer = MemorySaver()
59
+ model = ChatOpenRouter(
60
+ model_name=model,
61
+ api_key=os.getenv("OPENAI_API_KEY"),
62
+ base_url=os.getenv("OPENAI_BASE_URL"),
63
+ temperature=temperature,
64
+ top_p=top_p,
65
+ )
66
 
 
67
  agent = Agent(
68
  model,
69
  tools=list(tools_dict.values()),
 
92
  # "ChestXRayGeneratorTool",
93
  ]
94
 
 
 
 
 
 
 
 
95
  agent, tools_dict = initialize_agent(
96
  "medrax/docs/system_prompts.txt",
97
  tools_to_use=selected_tools,
98
  model_dir="./model-weights",
99
  temp_dir="temp",
100
  device="cuda",
101
+ model="google/gemini-2.5-pro-exp-03-25:free",
102
  temperature=0.7,
103
+ top_p=0.95
 
104
  )
105
 
106
  demo = create_demo(agent, tools_dict)