|
""" |
|
ChatGPT优化器 - 使用OpenAI API优化转写结果 |
|
""" |
|
|
|
from openai import OpenAI |
|
import os |
|
import logging |
|
import time |
|
|
|
|
|
def setup_logger(name, level=logging.INFO): |
|
"""设置日志记录器""" |
|
logger = logging.getLogger(name) |
|
|
|
if logger.handlers: |
|
logger.handlers.clear() |
|
|
|
|
|
handler = logging.StreamHandler() |
|
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') |
|
handler.setFormatter(formatter) |
|
logger.addHandler(handler) |
|
logger.setLevel(level) |
|
|
|
logger.propagate = False |
|
return logger |
|
|
|
|
|
logger = setup_logger("optimizer.api") |
|
|
|
|
|
MODEL_NAME = "gpt-3.5-turbo" |
|
|
|
class ChatGPTRunner: |
|
""" |
|
ChatGPT优化器,使用OpenAI API优化转写结果 |
|
""" |
|
|
|
def __init__(self, model: str = MODEL_NAME): |
|
""" |
|
初始化ChatGPT优化器 |
|
|
|
:param model: 使用的模型名称 |
|
""" |
|
self.model = model |
|
api_key = os.getenv("OPENAI_API_KEY") |
|
if not api_key: |
|
logger.warning("未设置OPENAI_API_KEY环境变量") |
|
self.client = OpenAI(api_key=api_key) |
|
logger.debug(f"ChatGPT优化器初始化完成,使用模型: {model}") |
|
|
|
def build_prompt(self, text: str) -> str: |
|
""" |
|
构建优化提示 |
|
|
|
:param text: 需要优化的文本 |
|
:return: 构建好的提示 |
|
""" |
|
return ( |
|
"示例:\n" |
|
"原句:你门现在就得开会了,别迟到了。\n" |
|
"修改:你们现在就得开会了,别迟到了。\n" |
|
"原句:在我们先辱了解这些任务之前。\n" |
|
"修改:在我们深入了解这些任务之前。\n" |
|
"原句:本章节将为你介绍音频数据的基本概念,包括剝形、採揚、綠和平、布圖。\n" |
|
"修改:本章节将为你介绍音频数据的基本概念,包括波形、采样、频谱、图像。\n" |
|
"原句:系统将进入留言模式,请耐行等待。\n" |
|
"修改:系统将进入留言模式,请耐心等待。\n" |
|
"原句:你会学习到如何使用音频数据集,包括音频数据加载,音频数据预处理,以及高效加载大规模音频数据集的流逝加载方法。\n" |
|
"修改:你会学习到如何使用音频数据集,包括音频数据加载,音频数据预处理,以及高效加载大规模音频数据集的流式加载方法。\n" |
|
f"原句:{text}\n" |
|
f"修改:" |
|
) |
|
|
|
def optimize(self, text: str, max_tokens: int = 256) -> str: |
|
""" |
|
优化文本 |
|
|
|
:param text: 需要优化的文本 |
|
:param max_tokens: 最大生成token数 |
|
:return: 优化后的文本 |
|
""" |
|
logger.debug(f"开始优化文本: {text}") |
|
start_time = time.time() |
|
|
|
|
|
prompt = self.build_prompt(text) |
|
|
|
try: |
|
|
|
response = self.client.chat.completions.create( |
|
model=self.model, |
|
messages=[ |
|
{"role": "system", "content": "你是用于优化语音识别的转写结果的校对助手。请保留原始句子的结构,仅修正错别字、语义不通或专业术语使用错误的部分。不要增加、删减或合并句子,务必保留原文的信息表达,仅对用词错误做最小修改。"}, |
|
{"role": "user", "content": prompt} |
|
], |
|
temperature=0.4, |
|
max_tokens=max_tokens, |
|
) |
|
|
|
|
|
result = response.choices[0].message.content.strip() |
|
|
|
|
|
elapsed_time = time.time() - start_time |
|
logger.debug(f"优化完成,耗时: {elapsed_time:.2f}秒") |
|
logger.info(f"优化结果: {result}") |
|
|
|
return result |
|
except Exception as e: |
|
logger.error(f"优化失败: {str(e)}") |
|
|
|
return text |
|
|
|
if __name__ == "__main__": |
|
|
|
logger.setLevel(logging.DEBUG) |
|
|
|
|
|
runner = ChatGPTRunner(MODEL_NAME) |
|
test_input = "你会学习到如何使用音频数据集,包括音频数据加载,音频数据预处理,以及高效加载大规模音频数据集的流逝加载方法。" |
|
|
|
logger.info(f"优化前: {test_input}") |
|
result = runner.optimize(test_input) |
|
logger.info(f"优化后: {result}") |
|
|