import logging
import math
import os
from typing import Optional, Union

import torch
# from vllm import LLM, SamplingParams
# from vllm.lora.request import LoRARequest
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig, set_seed)
import openai
from openai.error import (APIConnectionError, APIError, InvalidRequestError,
                          RateLimitError, ServiceUnavailableError, Timeout)
from tenacity import (before_sleep_log, retry, retry_if_exception_type,
                      stop_after_attempt, wait_random_exponential)

logger = logging.getLogger(__name__)

class Summarizer:
    def __init__(self,
                 inference_mode: str,
                 model_id: str,
                 api_key: str,
                 dtype: str = "bfloat16",
                 seed: int = 42,
                 context_size: int = int(1024 * 26),
                 gpu_memory_utilization: float = 0.7,
                 tensor_parallel_size: int = 1
                 ) -> None:
        # dtype, context_size, gpu_memory_utilization and tensor_parallel_size
        # apply only to the local vLLM/Transformers path, which is currently
        # commented out above.
        self.inference_mode = inference_mode
        self.tokenizer = None
        self.seed = seed
        openai.api_key = api_key
        self.model = model_id

    def get_generation_config(self,
                              repetition_penalty: float = 1.2,
                              do_sample: bool = True,
                              temperature: float = 0.1,
                              top_p: float = 0.9,
                              max_tokens: int = 1024) -> dict:
        # Bundle the sampling parameters into a dict so they can be handed
        # to whichever generation backend is in use.
        return {
            "repetition_penalty": repetition_penalty,
            "do_sample": do_sample,
            "temperature": temperature,
            "top_p": top_p,
            "max_tokens": max_tokens,
        }

    # Retry transient API failures with exponential backoff. InvalidRequestError
    # is handled inside the method, so retrying on it would never trigger.
    @retry(retry=retry_if_exception_type((APIError, Timeout, RateLimitError,
                                          ServiceUnavailableError,
                                          APIConnectionError)),
           wait=wait_random_exponential(max=60), stop=stop_after_attempt(10),
           before_sleep=before_sleep_log(logger, logging.WARNING))
    def inference_with_gpt(self, prompt):
        prompt_messages = [{"role": "user", "content": prompt}]
        try:
            response = openai.ChatCompletion.create(model=self.model,
                                                    messages=prompt_messages,
                                                    temperature=0.1)
            # finish_reason = response.choices[0].finish_reason
            response = response.choices[0].message.content
        except InvalidRequestError:
            # Usually means the prompt exceeded the model's context window;
            # fall back to an empty string rather than failing the run.
            response = ''
        return response
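
# Minimal usage sketch. The inference_mode value, the model name and the
# OPENAI_API_KEY environment variable are illustrative assumptions; this
# file does not specify them.
if __name__ == "__main__":
    summarizer = Summarizer(inference_mode="openai",
                            model_id="gpt-3.5-turbo",
                            api_key=os.environ.get("OPENAI_API_KEY", ""))
    print(summarizer.inference_with_gpt(
        "Summarize in one sentence: The quick brown fox jumps over the lazy dog."))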