faster-whisper-small / optimizer /llm_api_runner.py
tianyaogavin's picture
init main framework
1bf36cc
"""
ChatGPT优化器 - 使用OpenAI API优化转写结果
"""
from openai import OpenAI
import os
import logging
import time
# 配置日志
def setup_logger(name, level=logging.INFO):
"""设置日志记录器"""
logger = logging.getLogger(name)
# 清除所有已有的handler,避免重复
if logger.handlers:
logger.handlers.clear()
# 添加新的handler
handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(level)
# 禁止传播到父logger,避免重复日志
logger.propagate = False
return logger
# 创建日志记录器
logger = setup_logger("optimizer.api")
# 默认模型
MODEL_NAME = "gpt-3.5-turbo"
class ChatGPTRunner:
"""
ChatGPT优化器,使用OpenAI API优化转写结果
"""
def __init__(self, model: str = MODEL_NAME):
"""
初始化ChatGPT优化器
:param model: 使用的模型名称
"""
self.model = model
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
logger.warning("未设置OPENAI_API_KEY环境变量")
self.client = OpenAI(api_key=api_key)
logger.debug(f"ChatGPT优化器初始化完成,使用模型: {model}")
def build_prompt(self, text: str) -> str:
"""
构建优化提示
:param text: 需要优化的文本
:return: 构建好的提示
"""
return (
"示例:\n"
"原句:你门现在就得开会了,别迟到了。\n"
"修改:你们现在就得开会了,别迟到了。\n"
"原句:在我们先辱了解这些任务之前。\n"
"修改:在我们深入了解这些任务之前。\n"
"原句:本章节将为你介绍音频数据的基本概念,包括剝形、採揚、綠和平、布圖。\n"
"修改:本章节将为你介绍音频数据的基本概念,包括波形、采样、频谱、图像。\n"
"原句:系统将进入留言模式,请耐行等待。\n"
"修改:系统将进入留言模式,请耐心等待。\n"
"原句:你会学习到如何使用音频数据集,包括音频数据加载,音频数据预处理,以及高效加载大规模音频数据集的流逝加载方法。\n"
"修改:你会学习到如何使用音频数据集,包括音频数据加载,音频数据预处理,以及高效加载大规模音频数据集的流式加载方法。\n"
f"原句:{text}\n"
f"修改:"
)
def optimize(self, text: str, max_tokens: int = 256) -> str:
"""
优化文本
:param text: 需要优化的文本
:param max_tokens: 最大生成token数
:return: 优化后的文本
"""
logger.debug(f"开始优化文本: {text}")
start_time = time.time()
# 构建提示
prompt = self.build_prompt(text)
try:
# 调用API
response = self.client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": "你是用于优化语音识别的转写结果的校对助手。请保留原始句子的结构,仅修正错别字、语义不通或专业术语使用错误的部分。不要增加、删减或合并句子,务必保留原文的信息表达,仅对用词错误做最小修改。"},
{"role": "user", "content": prompt}
],
temperature=0.4,
max_tokens=max_tokens,
)
# 提取结果
result = response.choices[0].message.content.strip()
# 记录耗时
elapsed_time = time.time() - start_time
logger.debug(f"优化完成,耗时: {elapsed_time:.2f}秒")
logger.info(f"优化结果: {result}")
return result
except Exception as e:
logger.error(f"优化失败: {str(e)}")
# 出错时返回原文
return text
if __name__ == "__main__":
# 设置日志级别为DEBUG以查看详细信息
logger.setLevel(logging.DEBUG)
# 测试优化
runner = ChatGPTRunner(MODEL_NAME)
test_input = "你会学习到如何使用音频数据集,包括音频数据加载,音频数据预处理,以及高效加载大规模音频数据集的流逝加载方法。"
logger.info(f"优化前: {test_input}")
result = runner.optimize(test_input)
logger.info(f"优化后: {result}")