"""Post-process ASR transcripts with Qwen1.5-1.8B-Chat.

LLMRunner wraps a few-shot correction prompt that fixes typos and
misrecognized terms in a transcript while preserving sentence structure.
trace_until_generated replays greedy decoding token by token and prints
the top-k next-token distribution once an expected prefix has appeared.
"""
from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline
import torch
import torch.nn.functional as F

MODEL_ID = "Qwen/Qwen1.5-1.8B-Chat"
|
|
class LLMRunner:
    def __init__(self):
        print("[LLM - Qwen 1.8B] Loading model...")
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )
        # top_p=None suppresses the sampling-parameter warning, since all
        # generation below is greedy (do_sample=False).
        self.pipe = TextGenerationPipeline(model=self.model, tokenizer=self.tokenizer, top_p=None)
|
    def build_prompt(self, text: str) -> str:
        # Few-shot prompt, kept in Chinese to match the transcripts this model
        # corrects. The instruction asks the model to fix typos, garbled words,
        # and misused terms while preserving sentence structure; each example
        # pairs an erroneous transcript (原句) with its correction (修改).
        return (
            "本模型用于优化语音识别的转写结果。请保留原始句子的结构,仅修正错别字、语义不通或专业术语使用错误的部分。\n"
            "不要增加、删减或合并句子,务必保留原文的信息表达,仅对用词错误做最小修改。\n"
            "示例:\n"
            "原句:你门现在就得开会了,别迟到了。\n"
            "修改:你们现在就得开会了,别迟到了。\n"
            "原句:在我们先辱了解这些任务之前。\n"
            "修改:在我们深入了解这些任务之前。\n"
            "原句:本章节将为你介绍音频数据的基本概念,包括剝形、採揚、綠和平、布圖。\n"
            "修改:本章节将为你介绍音频数据的基本概念,包括波形、采样、频谱、图像。\n"
            "原句:系统将进入留言模式,请耐行等待。\n"
            "修改:系统将进入留言模式,请耐心等待。\n"
            f"原句:{text}\n"
            f"修改:"
        )
|
    def optimize(self, text: str) -> str:
        # Greedy decoding; return_full_text=False strips the prompt from the output.
        output = self.pipe(
            self.build_prompt(text),
            max_new_tokens=128,
            do_sample=False,
            top_p=None,
            return_full_text=False,
        )[0]["generated_text"]
        # The model may go on to echo further examples; keep only the first line.
        lines = output.strip().splitlines()
        return lines[0] if lines else ""
|
    def trace_until_generated(self, prompt: str, expected_prefix: str, top_k=30):
        print(f"\n[Trace] Generating until the full target prefix appears:\n→ \"{expected_prefix}\"\n")

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        generated_ids = inputs["input_ids"]
        past_key_values = None
        step_count = 0

        while step_count < 128:
            step_count += 1

            with torch.no_grad():
                # On the first step, feed the whole prompt to build the KV
                # cache; afterwards only the newest token is needed.
                input_ids = generated_ids if past_key_values is None else generated_ids[:, -1:]
                outputs = self.model(
                    input_ids=input_ids,
                    past_key_values=past_key_values,
                    use_cache=True,
                )
            logits = outputs.logits[:, -1, :]
            past_key_values = outputs.past_key_values

            # Greedy decoding: always take the most probable next token.
            probs = F.softmax(logits[0], dim=-1)
            next_token_id = torch.argmax(probs).unsqueeze(0).unsqueeze(0)
            generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)

            if next_token_id.item() == self.tokenizer.eos_token_id:
                break

            # Assumes decoding the prompt tokens reproduces the prompt string
            # exactly, so slicing by len(prompt) isolates the new text.
            decoded_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
            new_generated = decoded_text[len(prompt):]
            print(f"\r[Step {step_count:03}] Generated: {new_generated}", end="")

            if expected_prefix in new_generated:
                print(f"\n\n✅ Target prefix fully generated! Current fragment:\n\"{new_generated.strip()}\"")
                topk = torch.topk(probs, k=top_k)
                print(f"\n[Top-{top_k} next-token probability distribution]")
                for token_id, prob in zip(topk.indices, topk.values):
                    token = self.tokenizer.decode([token_id])
                    print(f"{token}\t{prob.item():.4f}")
                return

        print(f"\n❌ Target \"{expected_prefix}\" did not appear within {step_count} steps")
|
|
if __name__ == "__main__":
    runner = LLMRunner()
    # Sample ASR transcript with a homophone error: "流逝" (passing of time)
    # should be "流式" (streaming).
    test_input = "你会学习到如何使用音频数据集,包括音频数据加载,音频数据预处理,以及高效加载大规模音频数据集的流逝加载方法。"
    result = runner.optimize(test_input)
    print("Before:", test_input)
    print("After: ", result)