File size: 1,716 Bytes
778c3d7 172af0f a92f249 778c3d7 a92f249 172af0f 778c3d7 5de0b8a a92f249 5de0b8a 172af0f 5de0b8a 172af0f a92f249 5de0b8a 172af0f a92f249 5de0b8a a92f249 5de0b8a a92f249 5de0b8a a92f249 5de0b8a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import os
import json
import asyncio
import openai
from agents import LogMessagesKani
from kani.engines.openai import OpenAIEngine
from game_utils import log
# Using TGI Inference Endpoints from Hugging Face
# api_type = "tgi"
api_type = "openai"

if api_type == "tgi":
    # Hugging Face TGI endpoints expose an OpenAI-compatible API surface,
    # so the stock openai client works once pointed at the endpoint URL.
    model_name = "tgi"
    client = openai.Client(
        base_url=os.environ['HF_ENDPOINT_URL'] + "/v1/",
        api_key=os.environ['HF_API_TOKEN']
    )
else:
    model_name = "gpt-3.5-turbo"
    client = openai.Client()

# Keep the kani engine consistent with the backend selected above instead of
# hard-coding "gpt-3.5-turbo" (identical behavior for api_type == "openai").
# NOTE(review): for the TGI path the engine may additionally need the custom
# base URL / API key — confirm against kani's OpenAIEngine configuration.
openai_engine = OpenAIEngine(model=model_name)
class Player:
    """A game participant controlled either by a human (stdin) or an AI (kani).

    Attributes:
        name: Display name of the player.
        id: Optional external identifier.
        controller: Controller type string, "human" or "ai".
        role: The player's in-game role.
        messages: Conversation history buffer (unused by this class itself).
        kani: Only present for AI-controlled players; wraps the OpenAI engine.
    """

    def __init__(self, name: str, controller_type: str, role: str, id: str = None, log_filepath: str = None):
        """Create a player.

        Args:
            name: Display name.
            controller_type: "human" for stdin-driven play, "ai" for a kani agent.
            role: The player's in-game role.
            id: Optional external identifier.
            log_filepath: If given, player info is logged there and, for AI
                players, passed to the kani agent for message logging.
        """
        self.name = name
        self.id = id
        self.controller = controller_type
        if controller_type == "ai":
            # Only AI players get a kani agent; human players read from stdin.
            self.kani = LogMessagesKani(openai_engine, log_filepath=log_filepath)
        self.role = role
        self.messages = []

        if log_filepath:
            player_info = {
                "id": self.id,
                "name": self.name,
                "role": self.role,
                "controller": controller_type,
            }
            log(player_info, log_filepath)

    async def respond_to(self, prompt: str) -> str:
        """Makes the player respond to a prompt. Returns the response."""
        # Generate a response from the controller
        output = await self.__generate(prompt)
        return output

    async def __generate(self, prompt: str) -> str:
        """Produce a response via stdin (human) or the kani agent (ai).

        Raises:
            ValueError: If the controller type is neither "human" nor "ai".
                (Previously an unknown controller fell through and silently
                returned None.)
        """
        if self.controller == "human":
            print(prompt)
            return input()
        elif self.controller == "ai":
            output = await self.kani.chat_round_str(prompt)
            return output
        raise ValueError(f"Unknown controller type: {self.controller!r}")
|