Eric Botti
commited on
Commit
·
a92f249
1
Parent(s):
760a529
added Kani agents
Browse files- src/agents.py +12 -30
- src/parser.py +3 -2
- src/player.py +21 -21
src/agents.py
CHANGED
@@ -1,34 +1,16 @@
|
|
1 |
-
from
|
2 |
-
from langchain.agents import AgentExecutor, create_openai_functions_agent
|
3 |
-
from langchain_openai import ChatOpenAI
|
4 |
|
5 |
-
from langchain.prompts import PromptTemplate
|
6 |
-
from reasoning_tools import animal_tools, extract_vote
|
7 |
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
'temperature': 1
|
13 |
-
},
|
14 |
-
"herd": {
|
15 |
-
'model': 'gpt-3.5-turbo',
|
16 |
-
'temperature': 1
|
17 |
-
},
|
18 |
-
"judge": {
|
19 |
-
'model': 'gpt-3.5-turbo',
|
20 |
-
'temperature': 1
|
21 |
-
}
|
22 |
-
}
|
23 |
|
24 |
-
|
|
|
25 |
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
agent = create_openai_functions_agent(llm, animal_tools, prompt)
|
33 |
-
|
34 |
-
super().__init__(agent=agent, tools=animal_tools, verbose=True, return_intermediate_steps=True)
|
|
|
1 |
+
from kani import Kani
|
|
|
|
|
2 |
|
|
|
|
|
3 |
|
4 |
+
class LogMessagesKani(Kani):
|
5 |
+
def __init__(self, engine, log_filepath: str = None, *args, **kwargs):
|
6 |
+
super().__init__(engine, *args, **kwargs)
|
7 |
+
self.log_filepath = log_filepath
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
+
async def add_to_history(self, message, *args, **kwargs):
|
10 |
+
await super().add_to_history(message, *args, **kwargs)
|
11 |
|
12 |
+
# Logs Message to File
|
13 |
+
if self.log_filepath:
|
14 |
+
with open(self.log_filepath, "a") as log_file:
|
15 |
+
log_file.write(message.model_dump_json())
|
16 |
+
log_file.write("\n")
|
|
|
|
|
|
|
|
src/parser.py
CHANGED
@@ -2,10 +2,11 @@ from typing import Type
|
|
2 |
import asyncio
|
3 |
import json
|
4 |
|
5 |
-
from kani import Kani
|
6 |
from kani.engines.openai import OpenAIEngine
|
7 |
from pydantic import BaseModel, ValidationError
|
8 |
|
|
|
|
|
9 |
FORMAT_INSTRUCTIONS = """The output should be reformatted as a JSON instance that conforms to the JSON schema below.
|
10 |
Here is the output schema:
|
11 |
```
|
@@ -24,7 +25,7 @@ Output:
|
|
24 |
"""
|
25 |
|
26 |
|
27 |
-
class ParserKani(
|
28 |
def __init__(self, engine, *args, **kwargs):
|
29 |
super().__init__(engine, *args, **kwargs)
|
30 |
|
|
|
2 |
import asyncio
|
3 |
import json
|
4 |
|
|
|
5 |
from kani.engines.openai import OpenAIEngine
|
6 |
from pydantic import BaseModel, ValidationError
|
7 |
|
8 |
+
from agents import LogMessagesKani
|
9 |
+
|
10 |
FORMAT_INSTRUCTIONS = """The output should be reformatted as a JSON instance that conforms to the JSON schema below.
|
11 |
Here is the output schema:
|
12 |
```
|
|
|
25 |
"""
|
26 |
|
27 |
|
28 |
+
class ParserKani(LogMessagesKani):
|
29 |
def __init__(self, engine, *args, **kwargs):
|
30 |
super().__init__(engine, *args, **kwargs)
|
31 |
|
src/player.py
CHANGED
@@ -1,5 +1,10 @@
|
|
1 |
import os
|
|
|
|
|
2 |
import openai
|
|
|
|
|
|
|
3 |
|
4 |
# Using TGI Inference Endpoints from Hugging Face
|
5 |
# api_type = "tgi"
|
@@ -15,40 +20,35 @@ else:
|
|
15 |
model_name = "gpt-3.5-turbo"
|
16 |
client = openai.Client()
|
17 |
|
|
|
|
|
|
|
18 |
class Player:
|
19 |
-
def __init__(self, name: str,
|
20 |
self.name = name
|
21 |
-
self.controller =
|
|
|
|
|
|
|
22 |
self.role = role
|
23 |
self.messages = []
|
24 |
|
25 |
-
def
|
26 |
-
"""
|
27 |
-
|
28 |
-
output = self.
|
29 |
-
|
30 |
return output
|
31 |
|
32 |
-
def
|
33 |
if self.controller == "human":
|
34 |
print(prompt)
|
35 |
return input()
|
36 |
|
37 |
elif self.controller == "ai":
|
38 |
-
|
39 |
-
model=model_name,
|
40 |
-
messages=self.messages,
|
41 |
-
stream=False,
|
42 |
-
)
|
43 |
-
|
44 |
-
return chat_completion.choices[0].message.content
|
45 |
-
|
46 |
-
|
47 |
-
def add_message(self, message: str):
|
48 |
-
"""Add a message to the messages list. No response required."""
|
49 |
-
self.messages.append({"role": "user", "content": message})
|
50 |
-
|
51 |
|
|
|
52 |
|
53 |
|
54 |
|
|
|
1 |
import os
|
2 |
+
import asyncio
|
3 |
+
|
4 |
import openai
|
5 |
+
from agents import LogMessagesKani
|
6 |
+
from kani.engines.openai import OpenAIEngine
|
7 |
+
|
8 |
|
9 |
# Using TGI Inference Endpoints from Hugging Face
|
10 |
# api_type = "tgi"
|
|
|
20 |
model_name = "gpt-3.5-turbo"
|
21 |
client = openai.Client()
|
22 |
|
23 |
+
openai_engine = OpenAIEngine(model="gpt-3.5-turbo")
|
24 |
+
|
25 |
+
|
26 |
class Player:
|
27 |
+
def __init__(self, name: str, controller_type: str, role: str, log_filepath: str = None):
|
28 |
self.name = name
|
29 |
+
self.controller = controller_type
|
30 |
+
if controller_type == "ai":
|
31 |
+
self.kani = LogMessagesKani(openai_engine, log_filepath=log_filepath)
|
32 |
+
|
33 |
self.role = role
|
34 |
self.messages = []
|
35 |
|
36 |
+
async def respond_to(self, prompt: str) -> str:
|
37 |
+
"""Makes the player respond to a prompt. Returns the response."""
|
38 |
+
# Generate a response from the controller
|
39 |
+
output = await self.__generate(prompt)
|
40 |
+
|
41 |
return output
|
42 |
|
43 |
+
async def __generate(self, prompt: str) -> str:
|
44 |
if self.controller == "human":
|
45 |
print(prompt)
|
46 |
return input()
|
47 |
|
48 |
elif self.controller == "ai":
|
49 |
+
output = await self.kani.chat_round_str(prompt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
+
return output
|
52 |
|
53 |
|
54 |
|