File size: 4,786 Bytes
1bf36cc
 
 
 
5984435
 
1bf36cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5984435
1bf36cc
 
5984435
 
1bf36cc
 
 
 
 
 
 
 
 
 
5984435
1bf36cc
 
 
 
 
5984435
 
1bf36cc
 
 
 
 
 
5984435
 
 
 
 
 
 
 
 
 
1bf36cc
 
5984435
 
 
 
 
1bf36cc
 
 
 
 
 
 
 
 
 
 
5984435
1bf36cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5984435
 
1bf36cc
 
 
 
5984435
 
1bf36cc
 
5984435
1bf36cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""
ChatGPT优化器 - 使用OpenAI API优化转写结果
"""

from openai import OpenAI
import os
import logging
import time

# 配置日志
def setup_logger(name, level=logging.INFO):
    """设置日志记录器"""
    logger = logging.getLogger(name)
    # 清除所有已有的handler,避免重复
    if logger.handlers:
        logger.handlers.clear()
    
    # 添加新的handler
    handler = logging.StreamHandler()
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.setLevel(level)
    # 禁止传播到父logger,避免重复日志
    logger.propagate = False
    return logger

# 创建日志记录器
logger = setup_logger("optimizer.api")

# 默认模型
MODEL_NAME = "gpt-3.5-turbo"

class ChatGPTRunner:
    """
    ChatGPT优化器,使用OpenAI API优化转写结果
    """
    
    def __init__(self, model: str = MODEL_NAME):
        """
        初始化ChatGPT优化器
        
        :param model: 使用的模型名称
        """
        self.model = model
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            logger.warning("未设置OPENAI_API_KEY环境变量")
        self.client = OpenAI(api_key=api_key)
        logger.debug(f"ChatGPT优化器初始化完成,使用模型: {model}")

    def build_prompt(self, text: str) -> str:
        """
        构建优化提示
        
        :param text: 需要优化的文本
        :return: 构建好的提示
        """
        return (
            "示例:\n"
            "原句:你门现在就得开会了,别迟到了。\n"
            "修改:你们现在就得开会了,别迟到了。\n"
            "原句:在我们先辱了解这些任务之前。\n"
            "修改:在我们深入了解这些任务之前。\n"
            "原句:本章节将为你介绍音频数据的基本概念,包括剝形、採揚、綠和平、布圖。\n"
            "修改:本章节将为你介绍音频数据的基本概念,包括波形、采样、频谱、图像。\n"
            "原句:系统将进入留言模式,请耐行等待。\n"
            "修改:系统将进入留言模式,请耐心等待。\n"
            "原句:你会学习到如何使用音频数据集,包括音频数据加载,音频数据预处理,以及高效加载大规模音频数据集的流逝加载方法。\n"
            "修改:你会学习到如何使用音频数据集,包括音频数据加载,音频数据预处理,以及高效加载大规模音频数据集的流式加载方法。\n"
            f"原句:{text}\n"
            f"修改:"
        )

    def optimize(self, text: str, max_tokens: int = 256) -> str:
        """
        优化文本
        
        :param text: 需要优化的文本
        :param max_tokens: 最大生成token数
        :return: 优化后的文本
        """
        logger.debug(f"开始优化文本: {text}")
        start_time = time.time()
        
        # 构建提示
        prompt = self.build_prompt(text)
        
        try:
            # 调用API
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": "你是用于优化语音识别的转写结果的校对助手。请保留原始句子的结构,仅修正错别字、语义不通或专业术语使用错误的部分。不要增加、删减或合并句子,务必保留原文的信息表达,仅对用词错误做最小修改。"},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.4,
                max_tokens=max_tokens,
            )
            
            # 提取结果
            result = response.choices[0].message.content.strip()
            
            # 记录耗时
            elapsed_time = time.time() - start_time
            logger.debug(f"优化完成,耗时: {elapsed_time:.2f}秒")
            logger.info(f"优化结果: {result}")
            
            return result
        except Exception as e:
            logger.error(f"优化失败: {str(e)}")
            # 出错时返回原文
            return text
    
if __name__ == "__main__":
    # 设置日志级别为DEBUG以查看详细信息
    logger.setLevel(logging.DEBUG)
    
    # 测试优化
    runner = ChatGPTRunner(MODEL_NAME)
    test_input = "你会学习到如何使用音频数据集,包括音频数据加载,音频数据预处理,以及高效加载大规模音频数据集的流逝加载方法。"
    
    logger.info(f"优化前: {test_input}")
    result = runner.optimize(test_input)
    logger.info(f"优化后: {result}")