# optimizer/llm_runner.py from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline import torch import torch.nn.functional as F MODEL_ID = "Qwen/Qwen1.5-1.8B-Chat" class LLMRunner: def __init__(self): print("[LLM - Qwen 1.8B] Loading model...") self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) self.model = AutoModelForCausalLM.from_pretrained( MODEL_ID, torch_dtype=torch.float16, device_map="auto", trust_remote_code=True ) self.pipe = TextGenerationPipeline(model=self.model, tokenizer=self.tokenizer, top_p=None) def build_prompt(self, text: str) -> str: return ( "本模型用于优化语音识别的转写结果。请保留原始句子的结构,仅修正错别字、语义不通或专业术语使用错误的部分。\n" "不要增加、删减或合并句子,务必保留原文的信息表达,仅对用词错误做最小修改。\n" "示例:\n" "原句:你门现在就得开会了,别迟到了。\n" "修改:你们现在就得开会了,别迟到了。\n" "原句:在我们先辱了解这些任务之前。\n" "修改:在我们深入了解这些任务之前。\n" "原句:本章节将为你介绍音频数据的基本概念,包括剝形、採揚、綠和平、布圖。\n" "修改:本章节将为你介绍音频数据的基本概念,包括波形、采样、频谱、图像。\n" "原句:系统将进入留言模式,请耐行等待。\n" "修改:系统将进入留言模式,请耐心等待。\n" # "原句:我们要掌握流逝加载的基本操作。\n" # "修改:我们要掌握流式加载的基本操作。\n" f"原句:{text}\n" f"修改:" ) def optimize(self, text: str) -> str: output = self.pipe(self.build_prompt(text), max_new_tokens=128, do_sample=False, top_p=None, return_full_text=False)[0]['generated_text'] prediction = output.strip().splitlines()[0] return prediction def trace_until_generated(self, prompt: str, expected_prefix: str, top_k=30): print(f"\n[Trace] 正在生成,直到输出完整目标前缀:\n→ \"{expected_prefix}\"\n") inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device) generated_ids = inputs["input_ids"] past_key_values = None generated_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True) step_count = 0 while step_count < 128: step_count += 1 with torch.no_grad(): outputs = self.model( input_ids=generated_ids[:, -1:], past_key_values=past_key_values, use_cache=True ) logits = outputs.logits[:, -1, :] past_key_values = outputs.past_key_values probs = F.softmax(logits[0], dim=-1) next_token_id = torch.argmax(probs).unsqueeze(0).unsqueeze(0) generated_ids = torch.cat([generated_ids, next_token_id], dim=-1) decoded_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True) # 只从 prompt 之后截取新生成部分 new_generated = decoded_text[len(prompt):] print(f"\r[Step {step_count:03}] 生成:{new_generated}", end="") if expected_prefix in new_generated: print(f"\n\n✅ 已完整生成目标前缀!当前生成片段:\n\"{new_generated.strip()}\"") topk = torch.topk(probs, k=top_k) print(f"\n[Next Token Top-{top_k} 概率分布]") for token_id, prob in zip(topk.indices, topk.values): token = self.tokenizer.decode([token_id]) print(f"{token}\t{prob.item():.4f}") return print(f"\n❌ 未在 128 步内生成目标 \"{expected_prefix}\"") if __name__ == "__main__": runner = LLMRunner() test_input = "你会学习到如何使用音频数据集,包括音频数据加载,音频数据预处理,以及高效加载大规模音频数据集的流逝加载方法。" result = runner.optimize(test_input) print("优化前:", test_input) print("优化后:", result) # prompt = runner.build_prompt(test_input) # target_prefix = "你会学习到如何使用音频数据集,包括音频数据加载,音频数据预处理,以及高效加载大规模音频数据集的" # runner.trace_until_generated(prompt, expected_prefix=target_prefix, top_k=10)