import abc
import asyncio
import math
from abc import abstractmethod

import backoff
import openai
import tiktoken


class LLM(abc.ABC):
    """Abstract interface for a chat-completion LLM backend."""

    prompt_percent = 0.9

    @abstractmethod
    def __init__(self):
        raise NotImplementedError("Subclasses should implement this!")

    @abstractmethod
    def infer(self, prompts):
        raise NotImplementedError("Subclasses should implement this!")

    @abstractmethod
    def split_input(
        self, fixed_instruction, few_shot_examples, splittable_input, input_header, output_header
    ):
        raise NotImplementedError("Subclasses should implement this!")


class GPT(LLM):
    prompt_percent = 0.8

    # Azure OpenAI connection settings; replace the placeholders with your own
    # endpoint and API key (or load them from a secure source).
    openai_cxn_dict = {
        "default": {
            "endpoint": "INSERT YOUR AZURE OPENAI ENDPOINT HERE",
            "api_key": "INSERT YOUR AZURE OPENAI API KEY HERE",
        },
    }

    # Maximum context length (in tokens) for each supported Azure deployment name.
    deployment_max_length_dict = {
        "gpt-4": 8192,
        "gpt-4-0314": 8192,
        "gpt-4-32k": 32768,
        "gpt-35-turbo": 4096,
        "gpt-35-turbo-16k": 16385,
    }

    def __init__(self, model_id):
        self.temperature = 0.0
        self.top_k = 1  # Not used in the API call below.
        # Map the Azure deployment name (e.g. "gpt-35-turbo-16k") to a tiktoken
        # model name (e.g. "gpt-3.5") to select the matching tokenizer.
        self.encoding = tiktoken.encoding_for_model(
            "-".join(model_id.split("-", 2)[:2]).replace("5", ".5")
        )
        self.openai_api = "default"
        self.model_id = model_id
        self.max_length = self.deployment_max_length_dict[model_id]
        self.client = openai.AsyncAzureOpenAI(
            api_key=self.openai_cxn_dict[self.openai_api]["api_key"],
            api_version="2023-12-01-preview",
            azure_endpoint=self.openai_cxn_dict[self.openai_api]["endpoint"],
        )

    def gen_messages(
        self, fixed_instruction, few_shot_examples, input, input_header, output_header
    ):
        # Build the chat message list: system instruction, few-shot user/assistant
        # pairs, then the actual input wrapped in the input/output headers.
        messages = [
            {
                "role": "system",
                "content": fixed_instruction,
            },
        ]
        for example in few_shot_examples:
            messages.extend(
                [
                    {
                        "role": "user",
                        "content": input_header + "\n" + example["user"] + "\n\n" + output_header,
                    },
                    {
                        "role": "assistant",
                        "content": example["assistant"],
                    },
                ]
            )
        messages.extend(
            [
                {
                    "role": "user",
                    "content": input_header + "\n" + input + "\n\n" + output_header,
                },
            ]
        )
        return messages

    # Coroutine for a single chat-completion API call, retried with exponential
    # backoff whenever the endpoint returns a rate-limit error.
    @backoff.on_exception(backoff.expo, openai.RateLimitError)
    async def make_api_call_to_gpt(self, messages):
        response = await self.client.chat.completions.create(
            model=self.model_id,
            messages=messages,
            temperature=self.temperature,
        )
        return response.choices[0].message.content

    async def dispatch_openai_requests(
        self,
        messages_list,
    ):
        # Create one API-call coroutine per message list...
        tasks = [self.make_api_call_to_gpt(messages) for messages in messages_list]
        # ...and run them concurrently, gathering the results in order.
        results = await asyncio.gather(*tasks)
        return results

    def infer(
        self,
        messages_list,
    ):
        return asyncio.run(self.dispatch_openai_requests(messages_list))

    def split_input(
        self, fixed_instruction, few_shot_examples, splittable_input, input_header, output_header
    ):
        # Tokenize the fixed instruction together with all few-shot examples.
        fixed_token_ids = self.encoding.encode(
            fixed_instruction
            + " ".join([x["user"] + " " + x["assistant"] for x in few_shot_examples])
        )
        # Budget the remaining prompt tokens for the splittable input, leaving the
        # rest of the context window for the completion.
        remaining_token_len = math.ceil(
            (self.prompt_percent * self.max_length) - len(fixed_token_ids)
        )
        # Tokenize the splittable input.
        split_token_ids = self.encoding.encode(splittable_input)
        # Chunk the token IDs into pieces of at most remaining_token_len + 10 tokens;
        # consecutive chunks overlap by 10 tokens.
        split_token_ids_list = [
            split_token_ids[i : i + remaining_token_len + 10]
            for i in range(0, len(split_token_ids), remaining_token_len)
        ]
        split_input_list = [
            self.encoding.decode(chunk) for chunk in split_token_ids_list
        ]
        # Combine the fixed instruction, few-shot examples, each split input, and the
        # input/output headers into a list of chat message lists.
        return [
            self.gen_messages(
                fixed_instruction, few_shot_examples, split_input, input_header, output_header
            )
            for split_input in split_input_list
        ]
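

# Illustrative usage sketch: one way the GPT class might be driven end to end.
# The deployment name, instruction, few-shot example, and document text below are
# placeholder assumptions, and real Azure credentials must first be filled into
# openai_cxn_dict for the API calls to succeed.
if __name__ == "__main__":
    model = GPT("gpt-35-turbo-16k")
    fixed_instruction = "You are an assistant that answers questions about the provided text."
    few_shot_examples = [
        {"user": "Example input text.", "assistant": "Example answer."},
    ]
    messages_list = model.split_input(
        fixed_instruction=fixed_instruction,
        few_shot_examples=few_shot_examples,
        splittable_input="Some long document that may need to be chunked into several prompts.",
        input_header="Text:",
        output_header="Answer:",
    )
    for answer in model.infer(messages_list):
        print(answer)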