dexhunter commited on
Commit
a4d58d9
·
unverified ·
1 Parent(s): 67fa666

feat: Add openrouter backend (#55) (#56)

Browse files

* Update router to determine api provider
* Add backend for openrouter
* Update webui app
* Update README
* Update requirements for new openai python package

README.md CHANGED
@@ -99,6 +99,8 @@ Set up your OpenAI (or Anthropic) API key:
99
  export OPENAI_API_KEY=<your API key>
100
  # or
101
  export ANTHROPIC_API_KEY=<your API key>
 
 
102
  ```
103
 
104
  To run AIDE:
 
99
  export OPENAI_API_KEY=<your API key>
100
  # or
101
  export ANTHROPIC_API_KEY=<your API key>
102
+ # or
103
+ export OPENROUTER_API_KEY=<your API key>
104
  ```
105
 
106
  To run AIDE:
aide/backend/__init__.py CHANGED
@@ -1,5 +1,26 @@
1
- from . import backend_anthropic, backend_openai
2
  from .utils import FunctionSpec, OutputType, PromptType, compile_prompt_to_md
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
 
5
  def query(
@@ -35,13 +56,14 @@ def query(
35
 
36
  # Handle models with beta limitations
37
  # ref: https://platform.openai.com/docs/guides/reasoning/beta-limitations
38
- if model.startswith("o1"):
39
  if system_message:
40
  user_message = system_message
41
  system_message = None
42
  model_kwargs["temperature"] = 1
43
 
44
- query_func = backend_anthropic.query if "claude-" in model else backend_openai.query
 
45
  output, req_time, in_tok_count, out_tok_count, info = query_func(
46
  system_message=compile_prompt_to_md(system_message) if system_message else None,
47
  user_message=compile_prompt_to_md(user_message) if user_message else None,
 
1
+ from . import backend_anthropic, backend_openai, backend_openrouter
2
  from .utils import FunctionSpec, OutputType, PromptType, compile_prompt_to_md
3
+ import re
4
+ import logging
5
+
6
+ logger = logging.getLogger("aide")
7
+
8
+
9
+ def determine_provider(model: str) -> str:
10
+ if model.startswith("gpt-") or re.match(r"^o\d", model):
11
+ return "openai"
12
+ elif model.startswith("claude-"):
13
+ return "anthropic"
14
+ # all other models are handle by openrouter
15
+ else:
16
+ return "openrouter"
17
+
18
+
19
+ provider_to_query_func = {
20
+ "openai": backend_openai.query,
21
+ "anthropic": backend_anthropic.query,
22
+ "openrouter": backend_openrouter.query,
23
+ }
24
 
25
 
26
  def query(
 
56
 
57
  # Handle models with beta limitations
58
  # ref: https://platform.openai.com/docs/guides/reasoning/beta-limitations
59
+ if re.match(r"^o\d", model):
60
  if system_message:
61
  user_message = system_message
62
  system_message = None
63
  model_kwargs["temperature"] = 1
64
 
65
+ provider = determine_provider(model)
66
+ query_func = provider_to_query_func[provider]
67
  output, req_time, in_tok_count, out_tok_count, info = query_func(
68
  system_message=compile_prompt_to_md(system_message) if system_message else None,
69
  user_message=compile_prompt_to_md(user_message) if user_message else None,
aide/backend/backend_openrouter.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Backend for OpenRouter API"""
2
+
3
+ import logging
4
+ import os
5
+ import time
6
+
7
+ from funcy import notnone, once, select_values
8
+ import openai
9
+
10
+ from .utils import FunctionSpec, OutputType, backoff_create
11
+
12
+ logger = logging.getLogger("aide")
13
+
14
+ _client: openai.OpenAI = None # type: ignore
15
+
16
+ OPENAI_TIMEOUT_EXCEPTIONS = (
17
+ openai.RateLimitError,
18
+ openai.APIConnectionError,
19
+ openai.APITimeoutError,
20
+ openai.InternalServerError,
21
+ )
22
+
23
+
24
+ @once
25
+ def _setup_openrouter_client():
26
+ global _client
27
+ _client = openai.OpenAI(
28
+ base_url="https://openrouter.ai/api/v1",
29
+ api_key=os.getenv("OPENROUTER_API_KEY"),
30
+ max_retries=0,
31
+ )
32
+
33
+
34
+ def query(
35
+ system_message: str | None,
36
+ user_message: str | None,
37
+ func_spec: FunctionSpec | None = None,
38
+ **model_kwargs,
39
+ ) -> tuple[OutputType, float, int, int, dict]:
40
+ _setup_openrouter_client()
41
+ filtered_kwargs: dict = select_values(notnone, model_kwargs) # type: ignore
42
+
43
+ if func_spec is not None:
44
+ raise NotImplementedError(
45
+ "We are not supporting function calling in OpenRouter for now."
46
+ )
47
+
48
+ # in case some backends dont support system roles, just convert everything to user
49
+ messages = [
50
+ {"role": "user", "content": message}
51
+ for message in [system_message, user_message]
52
+ if message
53
+ ]
54
+
55
+ t0 = time.time()
56
+ completion = backoff_create(
57
+ _client.chat.completions.create,
58
+ OPENAI_TIMEOUT_EXCEPTIONS,
59
+ messages=messages,
60
+ extra_body={
61
+ "provider": {
62
+ "order": ["Fireworks"],
63
+ "ignore": ["Together", "DeepInfra", "Hyperbolic"],
64
+ },
65
+ },
66
+ **filtered_kwargs,
67
+ )
68
+ req_time = time.time() - t0
69
+
70
+ output = completion.choices[0].message.content
71
+
72
+ in_tokens = completion.usage.prompt_tokens
73
+ out_tokens = completion.usage.completion_tokens
74
+
75
+ info = {
76
+ "system_fingerprint": completion.system_fingerprint,
77
+ "model": completion.model,
78
+ "created": completion.created,
79
+ }
80
+
81
+ return output, req_time, in_tokens, out_tokens, info
aide/webui/app.py CHANGED
@@ -51,6 +51,7 @@ class WebUI:
51
  return {
52
  "openai_key": os.getenv("OPENAI_API_KEY", ""),
53
  "anthropic_key": os.getenv("ANTHROPIC_API_KEY", ""),
 
54
  }
55
 
56
  @staticmethod
@@ -127,9 +128,20 @@ class WebUI:
127
  type="password",
128
  label_visibility="collapsed",
129
  )
 
 
 
 
 
 
 
 
 
 
130
  if st.button("Save API Keys", use_container_width=True):
131
  st.session_state.openai_key = openai_key
132
  st.session_state.anthropic_key = anthropic_key
 
133
  st.success("API keys saved!")
134
 
135
  def render_input_section(self, results_col):
@@ -340,6 +352,8 @@ class WebUI:
340
  os.environ["OPENAI_API_KEY"] = st.session_state.openai_key
341
  if st.session_state.get("anthropic_key"):
342
  os.environ["ANTHROPIC_API_KEY"] = st.session_state.anthropic_key
 
 
343
 
344
  def prepare_input_directory(self, files):
345
  """
 
51
  return {
52
  "openai_key": os.getenv("OPENAI_API_KEY", ""),
53
  "anthropic_key": os.getenv("ANTHROPIC_API_KEY", ""),
54
+ "openrouter_key": os.getenv("OPENROUTER_API_KEY", ""),
55
  }
56
 
57
  @staticmethod
 
128
  type="password",
129
  label_visibility="collapsed",
130
  )
131
+ st.markdown(
132
+ "<p style='text-align: center;'>OpenRouter API Key</p>",
133
+ unsafe_allow_html=True,
134
+ )
135
+ openrouter_key = st.text_input(
136
+ "OpenRouter API Key",
137
+ value=self.env_vars["openrouter_key"],
138
+ type="password",
139
+ label_visibility="collapsed",
140
+ )
141
  if st.button("Save API Keys", use_container_width=True):
142
  st.session_state.openai_key = openai_key
143
  st.session_state.anthropic_key = anthropic_key
144
+ st.session_state.openrouter_key = openrouter_key
145
  st.success("API keys saved!")
146
 
147
  def render_input_section(self, results_col):
 
352
  os.environ["OPENAI_API_KEY"] = st.session_state.openai_key
353
  if st.session_state.get("anthropic_key"):
354
  os.environ["ANTHROPIC_API_KEY"] = st.session_state.anthropic_key
355
+ if st.session_state.get("openrouter_key"):
356
+ os.environ["OPENROUTER_API_KEY"] = st.session_state.openrouter_key
357
 
358
  def prepare_input_directory(self, files):
359
  """
requirements.txt CHANGED
@@ -4,7 +4,7 @@ funcy==2.0
4
  humanize==4.8.0
5
  jsonschema==4.19.2
6
  numpy==1.26.2
7
- openai>=1.3.5
8
  anthropic>=0.20.0
9
  pandas==2.1.4
10
  pytest==7.4.3
 
4
  humanize==4.8.0
5
  jsonschema==4.19.2
6
  numpy==1.26.2
7
+ openai>=1.69.0
8
  anthropic>=0.20.0
9
  pandas==2.1.4
10
  pytest==7.4.3