openrouter fix
Browse files
app.py
CHANGED
@@ -5,7 +5,7 @@ from dotenv import load_dotenv
|
|
5 |
from transformers import logging
|
6 |
|
7 |
from langgraph.checkpoint.memory import MemorySaver
|
8 |
-
from
|
9 |
|
10 |
from interface import create_demo
|
11 |
from medrax.agent import *
|
@@ -14,7 +14,7 @@ from medrax.utils import *
|
|
14 |
|
15 |
warnings.filterwarnings("ignore")
|
16 |
logging.set_verbosity_error()
|
17 |
-
_ = load_dotenv()
|
18 |
|
19 |
def initialize_agent(
|
20 |
prompt_file,
|
@@ -22,11 +22,12 @@ def initialize_agent(
|
|
22 |
model_dir="./model-weights",
|
23 |
temp_dir="temp",
|
24 |
device="cuda",
|
25 |
-
model="google/gemini-
|
26 |
temperature=0.7,
|
27 |
-
top_p=0.95
|
28 |
-
openai_kwargs=None
|
29 |
):
|
|
|
|
|
30 |
prompts = load_prompts_from_file(prompt_file)
|
31 |
prompt = prompts["MEDICAL_ASSISTANT"]
|
32 |
|
@@ -55,8 +56,14 @@ def initialize_agent(
|
|
55 |
tools_dict[tool_name] = all_tools[tool_name]()
|
56 |
|
57 |
checkpointer = MemorySaver()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
-
model = ChatOpenAI(model=model, temperature=temperature, top_p=top_p, **openai_kwargs)
|
60 |
agent = Agent(
|
61 |
model,
|
62 |
tools=list(tools_dict.values()),
|
@@ -85,23 +92,15 @@ if __name__ == "__main__":
|
|
85 |
# "ChestXRayGeneratorTool",
|
86 |
]
|
87 |
|
88 |
-
# ✅ Collect environment variables and pass to model
|
89 |
-
openai_kwargs = {}
|
90 |
-
if api_key := os.getenv("OPENAI_API_KEY"):
|
91 |
-
openai_kwargs["openai_api_key"] = api_key
|
92 |
-
if base_url := os.getenv("OPENAI_BASE_URL"):
|
93 |
-
openai_kwargs["openai_api_base"] = base_url
|
94 |
-
|
95 |
agent, tools_dict = initialize_agent(
|
96 |
"medrax/docs/system_prompts.txt",
|
97 |
tools_to_use=selected_tools,
|
98 |
model_dir="./model-weights",
|
99 |
temp_dir="temp",
|
100 |
device="cuda",
|
101 |
-
model="google/gemini-
|
102 |
temperature=0.7,
|
103 |
-
top_p=0.95
|
104 |
-
openai_kwargs=openai_kwargs
|
105 |
)
|
106 |
|
107 |
demo = create_demo(agent, tools_dict)
|
|
|
5 |
from transformers import logging
|
6 |
|
7 |
from langgraph.checkpoint.memory import MemorySaver
|
8 |
+
from langchain_community.chat_models import ChatOpenRouter
|
9 |
|
10 |
from interface import create_demo
|
11 |
from medrax.agent import *
|
|
|
14 |
|
15 |
warnings.filterwarnings("ignore")
|
16 |
logging.set_verbosity_error()
|
17 |
+
_ = load_dotenv()
|
18 |
|
19 |
def initialize_agent(
|
20 |
prompt_file,
|
|
|
22 |
model_dir="./model-weights",
|
23 |
temp_dir="temp",
|
24 |
device="cuda",
|
25 |
+
model="google/gemini-2.5-pro-exp-03-25:free",
|
26 |
temperature=0.7,
|
27 |
+
top_p=0.95
|
|
|
28 |
):
|
29 |
+
"""Initialize the MedRAX agent with specified tools and configuration."""
|
30 |
+
|
31 |
prompts = load_prompts_from_file(prompt_file)
|
32 |
prompt = prompts["MEDICAL_ASSISTANT"]
|
33 |
|
|
|
56 |
tools_dict[tool_name] = all_tools[tool_name]()
|
57 |
|
58 |
checkpointer = MemorySaver()
|
59 |
+
model = ChatOpenRouter(
|
60 |
+
model_name=model,
|
61 |
+
api_key=os.getenv("OPENAI_API_KEY"),
|
62 |
+
base_url=os.getenv("OPENAI_BASE_URL"),
|
63 |
+
temperature=temperature,
|
64 |
+
top_p=top_p,
|
65 |
+
)
|
66 |
|
|
|
67 |
agent = Agent(
|
68 |
model,
|
69 |
tools=list(tools_dict.values()),
|
|
|
92 |
# "ChestXRayGeneratorTool",
|
93 |
]
|
94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
agent, tools_dict = initialize_agent(
|
96 |
"medrax/docs/system_prompts.txt",
|
97 |
tools_to_use=selected_tools,
|
98 |
model_dir="./model-weights",
|
99 |
temp_dir="temp",
|
100 |
device="cuda",
|
101 |
+
model="google/gemini-2.5-pro-exp-03-25:free",
|
102 |
temperature=0.7,
|
103 |
+
top_p=0.95
|
|
|
104 |
)
|
105 |
|
106 |
demo = create_demo(agent, tools_dict)
|