tianyaogavin's picture
add text optimizer
5984435
# optimizer/llm_runner.py
from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline
import torch
import torch.nn.functional as F
MODEL_ID = "Qwen/Qwen1.5-1.8B-Chat"
class LLMRunner:
def __init__(self):
print("[LLM - Qwen 1.8B] Loading model...")
self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
self.model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=torch.float16,
device_map="auto",
trust_remote_code=True
)
self.pipe = TextGenerationPipeline(model=self.model, tokenizer=self.tokenizer, top_p=None)
def build_prompt(self, text: str) -> str:
return (
"本模型用于优化语音识别的转写结果。请保留原始句子的结构,仅修正错别字、语义不通或专业术语使用错误的部分。\n"
"不要增加、删减或合并句子,务必保留原文的信息表达,仅对用词错误做最小修改。\n"
"示例:\n"
"原句:你门现在就得开会了,别迟到了。\n"
"修改:你们现在就得开会了,别迟到了。\n"
"原句:在我们先辱了解这些任务之前。\n"
"修改:在我们深入了解这些任务之前。\n"
"原句:本章节将为你介绍音频数据的基本概念,包括剝形、採揚、綠和平、布圖。\n"
"修改:本章节将为你介绍音频数据的基本概念,包括波形、采样、频谱、图像。\n"
"原句:系统将进入留言模式,请耐行等待。\n"
"修改:系统将进入留言模式,请耐心等待。\n"
# "原句:我们要掌握流逝加载的基本操作。\n"
# "修改:我们要掌握流式加载的基本操作。\n"
f"原句:{text}\n"
f"修改:"
)
def optimize(self, text: str) -> str:
output = self.pipe(self.build_prompt(text), max_new_tokens=128, do_sample=False, top_p=None, return_full_text=False)[0]['generated_text']
prediction = output.strip().splitlines()[0]
return prediction
def trace_until_generated(self, prompt: str, expected_prefix: str, top_k=30):
print(f"\n[Trace] 正在生成,直到输出完整目标前缀:\n→ \"{expected_prefix}\"\n")
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
generated_ids = inputs["input_ids"]
past_key_values = None
generated_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
step_count = 0
while step_count < 128:
step_count += 1
with torch.no_grad():
outputs = self.model(
input_ids=generated_ids[:, -1:],
past_key_values=past_key_values,
use_cache=True
)
logits = outputs.logits[:, -1, :]
past_key_values = outputs.past_key_values
probs = F.softmax(logits[0], dim=-1)
next_token_id = torch.argmax(probs).unsqueeze(0).unsqueeze(0)
generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)
decoded_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
# 只从 prompt 之后截取新生成部分
new_generated = decoded_text[len(prompt):]
print(f"\r[Step {step_count:03}] 生成:{new_generated}", end="")
if expected_prefix in new_generated:
print(f"\n\n✅ 已完整生成目标前缀!当前生成片段:\n\"{new_generated.strip()}\"")
topk = torch.topk(probs, k=top_k)
print(f"\n[Next Token Top-{top_k} 概率分布]")
for token_id, prob in zip(topk.indices, topk.values):
token = self.tokenizer.decode([token_id])
print(f"{token}\t{prob.item():.4f}")
return
print(f"\n❌ 未在 128 步内生成目标 \"{expected_prefix}\"")
if __name__ == "__main__":
runner = LLMRunner()
test_input = "你会学习到如何使用音频数据集,包括音频数据加载,音频数据预处理,以及高效加载大规模音频数据集的流逝加载方法。"
result = runner.optimize(test_input)
print("优化前:", test_input)
print("优化后:", result)
# prompt = runner.build_prompt(test_input)
# target_prefix = "你会学习到如何使用音频数据集,包括音频数据加载,音频数据预处理,以及高效加载大规模音频数据集的"
# runner.trace_until_generated(prompt, expected_prefix=target_prefix, top_k=10)