import logging
import math
import os
from typing import Optional, Union

import torch
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig, GenerationConfig, set_seed)

# NOTE: this module targets the pre-1.0 openai SDK, which exposes
# openai.error and openai.ChatCompletion.
import openai
from openai.error import (APIConnectionError, APIError, InvalidRequestError,
                          RateLimitError, ServiceUnavailableError, Timeout)
from tenacity import (before_sleep_log, retry, retry_if_exception_type,
                      stop_after_attempt, stop_after_delay,
                      wait_random_exponential)

logger = logging.getLogger(__name__)
|
|
class Summarizer:

    def __init__(self,
                 inference_mode: str,
                 model_id: str,
                 api_key: str,
                 dtype: str = "bfloat16",
                 seed: int = 42,
                 context_size: int = 1024 * 26,
                 gpu_memory_utilization: float = 0.7,
                 tensor_parallel_size: int = 1
                 ) -> None:
        # dtype, context_size, gpu_memory_utilization and tensor_parallel_size
        # configure the local-model inference path; only the OpenAI path is
        # implemented below.
        self.inference_mode = inference_mode
        self.tokenizer = None
        self.seed = seed
        openai.api_key = api_key
        # For the OpenAI path, self.model holds the model name rather than a
        # loaded model object.
        self.model = model_id
|
    def get_generation_config(self,
                              repetition_penalty: float = 1.2,
                              do_sample: bool = True,
                              temperature: float = 0.1,
                              top_p: float = 0.9,
                              max_tokens: int = 1024) -> GenerationConfig:
        # Map the sampling arguments onto a transformers GenerationConfig
        # (assumed here to be the intended target for the local-model path).
        generation_config = GenerationConfig(repetition_penalty=repetition_penalty,
                                             do_sample=do_sample,
                                             temperature=temperature,
                                             top_p=top_p,
                                             max_new_tokens=max_tokens)
        return generation_config
|
    # Retry transient OpenAI failures with exponential backoff.
    # InvalidRequestError is not retried: a malformed request (e.g. a prompt
    # over the context limit) cannot succeed on a second attempt, so it is
    # handled in the method body instead.
    @retry(retry=retry_if_exception_type((APIError, Timeout, RateLimitError,
                                          ServiceUnavailableError, APIConnectionError)),
           wait=wait_random_exponential(max=60), stop=stop_after_attempt(10),
           before_sleep=before_sleep_log(logger, logging.WARNING))
    def inference_with_gpt(self, prompt: str) -> str:
        prompt_messages = [{"role": "user", "content": prompt}]
        try:
            response = openai.ChatCompletion.create(model=self.model,
                                                    messages=prompt_messages,
                                                    temperature=0.1)
            response = response.choices[0].message.content
        except InvalidRequestError:
            response = ''
        return response
|
|
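# Minimal usage sketch of the OpenAI inference path. The "openai" mode label,
# the "gpt-3.5-turbo" model name and the OPENAI_API_KEY environment variable
# are illustrative assumptions, not values defined by this module.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    summarizer = Summarizer(inference_mode="openai",      # assumed mode label
                            model_id="gpt-3.5-turbo",     # assumed model name
                            api_key=os.environ["OPENAI_API_KEY"])
    print(summarizer.inference_with_gpt(
        "Summarize in one sentence: The quick brown fox jumps over the lazy dog."))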