import os
import math
from typing import Union, Optional
import torch
import logging
#from vllm import LLM, SamplingParams
#from vllm.lora.request import LoRARequest
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, set_seed, BitsAndBytesConfig
import openai
from openai.error import (APIError, RateLimitError, ServiceUnavailableError,
                          Timeout, APIConnectionError, InvalidRequestError)
from tenacity import (before_sleep_log, retry, retry_if_exception_type,
                      wait_random_exponential, stop_after_attempt)

logger = logging.getLogger(__name__)


class Summarizer:
    def __init__(self,
                 inference_mode: str,
                 model_id: str,
                 api_key: str,
                 dtype: str = "bfloat16",
                 seed: int = 42,
                 context_size: int = int(1024 * 26),
                 gpu_memory_utilization: float = 0.7,
                 tensor_parallel_size: int = 1
                 ) -> None:
        self.inference_mode = inference_mode
        self.seed = seed
        # The local (transformers/vLLM) path is stubbed in this file, so no
        # tokenizer or weights are loaded; only the OpenAI path is implemented.
        self.tokenizer = None
        openai.api_key = api_key
        # For the OpenAI path, self.model holds the model name string passed
        # to the ChatCompletion API.
        self.model = model_id

    def get_generation_config(
        self,
        repetition_penalty: float = 1.2,
        do_sample: bool = True,
        temperature: float = 0.1,
        top_p: float = 0.9,
        max_tokens: int = 1024,
    ):
        # Assumption: the parameters are collected into a plain dict so they
        # can be unpacked into a generate()-style call.
        generation_config = {
            "repetition_penalty": repetition_penalty, "do_sample": do_sample,
            "temperature": temperature, "top_p": top_p,
            "max_new_tokens": max_tokens,  # transformers' name for max_tokens
        }
        return generation_config
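    # Usage sketch (an assumption, not from the original file): once the local
    # path is implemented, the dict above can be unpacked into a Hugging Face
    # generate() call, e.g.
    #   output_ids = self.model.generate(**inputs, **self.get_generation_config())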

    # InvalidRequestError is handled inside the method body, so it is not retried.
    @retry(retry=retry_if_exception_type((APIError, Timeout, RateLimitError,
                                          ServiceUnavailableError, APIConnectionError)),
           wait=wait_random_exponential(max=60), stop=stop_after_attempt(10),
           before_sleep=before_sleep_log(logger, logging.WARNING))
    def inference_with_gpt(self, prompt):
        prompt_messages = [{"role": "user", "content": prompt}]
        try:
            response = openai.ChatCompletion.create(
                model=self.model, messages=prompt_messages, temperature=0.1)
            response = response.choices[0].message.content
        except InvalidRequestError:
            # Typically raised when the prompt exceeds the model's context
            # window; fall back to an empty summary instead of crashing.
            response = ''
        return response
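

# --- Usage sketch (an assumption, not part of the original file) ---
# The mode string and model id below are placeholders; any ChatCompletion-
# capable model name works.
if __name__ == "__main__":
    summarizer = Summarizer(
        inference_mode="openai",               # hypothetical; the class never branches on it
        model_id="gpt-3.5-turbo",              # placeholder model name
        api_key=os.environ["OPENAI_API_KEY"],  # read the key from the environment
    )
    print(summarizer.inference_with_gpt("Summarize in one sentence: ..."))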