Commit a92f249 · Eric Botti committed · Parent: 760a529

added Kani agents

Files changed (3):
  1. src/agents.py  +12 -30
  2. src/parser.py  +3 -2
  3. src/player.py  +21 -21
src/agents.py CHANGED
@@ -1,34 +1,16 @@
-from langchain import hub
-from langchain.agents import AgentExecutor, create_openai_functions_agent
-from langchain_openai import ChatOpenAI
-
-from langchain.prompts import PromptTemplate
-from reasoning_tools import animal_tools, extract_vote
-
-# LLM Configuration for each role
-llm_parameters = {
-    "chameleon": {
-        'model': 'gpt-4-turbo-preview',
-        'temperature': 1
-    },
-    "herd": {
-        'model': 'gpt-3.5-turbo',
-        'temperature': 1
-    },
-    "judge": {
-        'model': 'gpt-3.5-turbo',
-        'temperature': 1
-    }
-}
-
-prompt = hub.pull("hwchase17/openai-functions-agent")
-
-
-class PlayerAgent(AgentExecutor):
-
-    def __init__(self, role):
-        llm = ChatOpenAI(**llm_parameters[role])
-
-        agent = create_openai_functions_agent(llm, animal_tools, prompt)
-
-        super().__init__(agent=agent, tools=animal_tools, verbose=True, return_intermediate_steps=True)
+from kani import Kani
+
+
+class LogMessagesKani(Kani):
+    def __init__(self, engine, log_filepath: str = None, *args, **kwargs):
+        super().__init__(engine, *args, **kwargs)
+        self.log_filepath = log_filepath
+
+    async def add_to_history(self, message, *args, **kwargs):
+        await super().add_to_history(message, *args, **kwargs)
+
+        # Logs Message to File
+        if self.log_filepath:
+            with open(self.log_filepath, "a") as log_file:
+                log_file.write(message.model_dump_json())
+                log_file.write("\n")
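Note: the new LogMessagesKani overrides kani's add_to_history hook, so every message that enters the chat history is also appended to the log file as one JSON line. A minimal usage sketch (the log filename and prompt below are illustrative, not from this commit, and an OPENAI_API_KEY environment variable is assumed):

import asyncio

from kani.engines.openai import OpenAIEngine

from agents import LogMessagesKani


async def main():
    engine = OpenAIEngine(model="gpt-3.5-turbo")
    # Every request/response pair lands in game_log.jsonl via add_to_history
    ai = LogMessagesKani(engine, log_filepath="game_log.jsonl")
    reply = await ai.chat_round_str("Pick an animal and describe it in one sentence.")
    print(reply)
    await engine.close()

asyncio.run(main())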
src/parser.py CHANGED
@@ -2,10 +2,11 @@ from typing import Type
 import asyncio
 import json
 
-from kani import Kani
 from kani.engines.openai import OpenAIEngine
 from pydantic import BaseModel, ValidationError
 
+from agents import LogMessagesKani
+
 FORMAT_INSTRUCTIONS = """The output should be reformatted as a JSON instance that conforms to the JSON schema below.
 Here is the output schema:
 ```
@@ -24,7 +25,7 @@ Output:
 """
 
 
-class ParserKani(Kani):
+class ParserKani(LogMessagesKani):
     def __init__(self, engine, *args, **kwargs):
         super().__init__(engine, *args, **kwargs)
 
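Note: ParserKani now extends LogMessagesKani rather than Kani directly, so each reformatting round it runs is logged through the same add_to_history hook. The parse step itself is outside this hunk; the sketch below only illustrates how a pydantic model pairs with FORMAT_INSTRUCTIONS (VoteModel and parse_vote are hypothetical names, not from this commit):

from pydantic import BaseModel, ValidationError


class VoteModel(BaseModel):
    # Hypothetical schema; the real schema classes are not shown in this diff
    vote: str


def parse_vote(raw_output: str) -> VoteModel | None:
    # Validate the JSON the model returned against the schema, as
    # FORMAT_INSTRUCTIONS requests; None signals the output did not conform
    try:
        return VoteModel.model_validate_json(raw_output)
    except ValidationError:
        return None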
src/player.py CHANGED
@@ -1,5 +1,10 @@
 import os
+import asyncio
+
 import openai
+from agents import LogMessagesKani
+from kani.engines.openai import OpenAIEngine
+
 
 # Using TGI Inference Endpoints from Hugging Face
 # api_type = "tgi"
@@ -15,40 +20,35 @@ else:
     model_name = "gpt-3.5-turbo"
 client = openai.Client()
 
+openai_engine = OpenAIEngine(model="gpt-3.5-turbo")
+
+
 class Player:
-    def __init__(self, name: str, controller: str, role: str):
+    def __init__(self, name: str, controller_type: str, role: str, log_filepath: str = None):
         self.name = name
-        self.controller = controller
+        self.controller = controller_type
+        if controller_type == "ai":
+            self.kani = LogMessagesKani(openai_engine, log_filepath=log_filepath)
+
         self.role = role
         self.messages = []
 
-    def collect_input(self, prompt: str) -> str:
-        """Store the input and output in the messages list. Return the output."""
-        self.messages.append({"role": "user", "content": prompt})
-        output = self.respond(prompt)
-        self.messages.append({"role": "assistant", "content": output})
+    async def respond_to(self, prompt: str) -> str:
+        """Makes the player respond to a prompt. Returns the response."""
+        # Generate a response from the controller
+        output = await self.__generate(prompt)
+
         return output
 
-    def respond(self, prompt: str) -> str:
+    async def __generate(self, prompt: str) -> str:
         if self.controller == "human":
             print(prompt)
             return input()
 
         elif self.controller == "ai":
-            chat_completion = client.chat.completions.create(
-                model=model_name,
-                messages=self.messages,
-                stream=False,
-            )
-
-            return chat_completion.choices[0].message.content
-
-
-    def add_message(self, message: str):
-        """Add a message to the messages list. No response required."""
-        self.messages.append({"role": "user", "content": message})
+            output = await self.kani.chat_round_str(prompt)
 
+            return output
 
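Note: respond_to is now a coroutine, so callers need an event loop; human input still blocks inside it. A minimal driver sketch (the player name, role, prompt, and log path are illustrative, not from this commit):

import asyncio

from player import Player


async def main():
    # An "ai" player builds its own LogMessagesKani on the shared engine
    player = Player("lola", controller_type="ai", role="herd", log_filepath="lola.jsonl")
    answer = await player.respond_to("Describe your animal in one sentence.")
    print(answer)

asyncio.run(main())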