import os
import math
from typing import Union, Optional
import torch
import logging

#from vllm import LLM, SamplingParams
#from vllm.lora.request import LoRARequest

from transformers import (AutoModelForCausalLM, AutoTokenizer, AutoConfig,
                          set_seed, BitsAndBytesConfig, GenerationConfig)
# NOTE: openai.error and openai.ChatCompletion are the pre-1.0 openai SDK
# interface; this module requires openai<1.0.
import openai
from openai.error import (APIError, RateLimitError, ServiceUnavailableError,
                          Timeout, APIConnectionError, InvalidRequestError)
from tenacity import (before_sleep_log, retry, retry_if_exception_type,
                      stop_after_delay, wait_random_exponential, stop_after_attempt)


logger = logging.getLogger(__name__)

class Summarizer:
    def __init__(self,
                 inference_mode: str,
                 model_id: str,
                 api_key: str,
                 dtype: str = "bfloat16",
                 seed: int = 42,
                 context_size: int = 1024 * 26,
                 gpu_memory_utilization: float = 0.7,
                 tensor_parallel_size: int = 1
                 ) -> None:

        # dtype, context_size, gpu_memory_utilization and tensor_parallel_size
        # belong to the (currently commented-out) local vLLM path.
        self.inference_mode = inference_mode
        self.tokenizer = None  # only populated for local inference
        self.seed = seed
        openai.api_key = api_key
        self.model = model_id
    
    def get_generation_config(
        self,
        repetition_penalty: float = 1.2,
        do_sample: bool = True,
        temperature: float = 0.1,
        top_p: float = 0.9,
        max_tokens: int = 1024
        ):
        # Bundle the sampling parameters into a transformers GenerationConfig
        # for the local-inference path.
        generation_config = GenerationConfig(
            repetition_penalty=repetition_penalty,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_tokens,
        )
        return generation_config
    
    @retry(retry=retry_if_exception_type((APIError, Timeout, RateLimitError,
                                          ServiceUnavailableError, APIConnectionError)),
           wait=wait_random_exponential(max=60), stop=stop_after_attempt(10),
           before_sleep=before_sleep_log(logger, logging.WARNING))
    def inference_with_gpt(self, prompt):
        prompt_messages = [{"role": "user", "content": prompt}]
        try:
            response = openai.ChatCompletion.create(model=self.model,
                                                    messages=prompt_messages,
                                                    temperature=0.1)
            #finish_reason = response.choices[0].finish_reason
            response = response.choices[0].message.content
        except InvalidRequestError:
            # Non-transient request errors (e.g. context length exceeded)
            # are not retried; return an empty summary instead.
            response = ''

        return response
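

# Example usage: a minimal sketch, assuming the OpenAI chat-completion path
# and that OPENAI_API_KEY is set in the environment. The inference_mode label,
# model id, and prompt below are illustrative assumptions, not values taken
# from this module.
if __name__ == "__main__":
    summarizer = Summarizer(
        inference_mode="api",       # assumed label for the OpenAI path
        model_id="gpt-3.5-turbo",   # any chat model accepted by the pre-1.0 SDK
        api_key=os.environ.get("OPENAI_API_KEY", ""),
    )
    print(summarizer.inference_with_gpt(
        "Summarize in one sentence: The quick brown fox jumps over the lazy dog."))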