# optimizer/llm_runner.py
from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline
import torch
import torch.nn.functional as F

MODEL_ID = "Qwen/Qwen1.5-1.8B-Chat"

class LLMRunner:
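    """Post-edits Chinese ASR transcripts with a local Qwen1.5-1.8B-Chat model,
    correcting mis-recognized characters and terms while preserving sentence structure.
    """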
    def __init__(self):
        print("[LLM - Qwen 1.8B] Loading model...")
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
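        # Load weights in float16 and let device_map="auto" place them on the available GPU(s)/CPU.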
        self.model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True
        )
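        # top_p is unset because decoding here is greedy (do_sample=False); this overrides
        # the sampling defaults in the chat model's generation config.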
        self.pipe = TextGenerationPipeline(model=self.model, tokenizer=self.tokenizer, top_p=None)

    def build_prompt(self, text: str) -> str:
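        """Builds a Chinese few-shot prompt: an instruction to make minimal edits,
        several original/corrected example pairs, and the transcript to be fixed.
        """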
        return (
            "本模型用于优化语音识别的转写结果。请保留原始句子的结构,仅修正错别字、语义不通或专业术语使用错误的部分。\n"
            "不要增加、删减或合并句子,务必保留原文的信息表达,仅对用词错误做最小修改。\n"
            "示例:\n"
            "原句:你门现在就得开会了,别迟到了。\n"
            "修改:你们现在就得开会了,别迟到了。\n"
            "原句:在我们先辱了解这些任务之前。\n"
            "修改:在我们深入了解这些任务之前。\n"
            "原句:本章节将为你介绍音频数据的基本概念,包括剝形、採揚、綠和平、布圖。\n"
            "修改:本章节将为你介绍音频数据的基本概念,包括波形、采样、频谱、图像。\n"
            "原句:系统将进入留言模式,请耐行等待。\n"
            "修改:系统将进入留言模式,请耐心等待。\n"
            # "原句:我们要掌握流逝加载的基本操作。\n"
            # "修改:我们要掌握流式加载的基本操作。\n"
            f"原句:{text}\n"
            f"修改:"
        )


    def optimize(self, text: str) -> str:
        output = self.pipe(self.build_prompt(text), max_new_tokens=128, do_sample=False,
                           top_p=None, return_full_text=False)[0]["generated_text"]
        # Greedy decoding may still emit more than one line; keep the first line only and
        # fall back to the original text if the completion is empty.
        lines = output.strip().splitlines()
        return lines[0] if lines else text

    def trace_until_generated(self, prompt: str, expected_prefix: str, top_k=30):
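        """Greedily decodes token by token until the generated text contains
        expected_prefix, then prints the top_k probabilities for the next token.
        """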
        print(f"\n[Trace] 正在生成,直到输出完整目标前缀:\n→ \"{expected_prefix}\"\n")

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        generated_ids = inputs["input_ids"]
        past_key_values = None

        step_count = 0

        while step_count < 128:
            step_count += 1

            with torch.no_grad():
                # On the first step feed the entire prompt to build the KV cache;
                # afterwards only the newest token is needed.
                step_input = generated_ids if past_key_values is None else generated_ids[:, -1:]
                outputs = self.model(
                    input_ids=step_input,
                    past_key_values=past_key_values,
                    use_cache=True
                )
                logits = outputs.logits[:, -1, :]
                past_key_values = outputs.past_key_values

            probs = F.softmax(logits[0], dim=-1)
            next_token_id = torch.argmax(probs).unsqueeze(0).unsqueeze(0)
            generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)

            decoded_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)

            # Keep only the newly generated portion after the prompt.
            new_generated = decoded_text[len(prompt):]
            print(f"\r[Step {step_count:03}] Generated: {new_generated}", end="")

            if expected_prefix in new_generated:
                print(f"\n\n✅ 已完整生成目标前缀!当前生成片段:\n\"{new_generated.strip()}\"")
                topk = torch.topk(probs, k=top_k)
                print(f"\n[Next Token Top-{top_k} 概率分布]")
                for token_id, prob in zip(topk.indices, topk.values):
                    token = self.tokenizer.decode([token_id])
                    print(f"{token}\t{prob.item():.4f}")
                return

        print(f"\n❌ 未在 128 步内生成目标 \"{expected_prefix}\"")






if __name__ == "__main__":
    runner = LLMRunner()
    test_input = "你会学习到如何使用音频数据集,包括音频数据加载,音频数据预处理,以及高效加载大规模音频数据集的流逝加载方法。"
    result = runner.optimize(test_input)
    print("优化前:", test_input)
    print("优化后:", result)


    # prompt = runner.build_prompt(test_input)
    # target_prefix = "你会学习到如何使用音频数据集,包括音频数据加载,音频数据预处理,以及高效加载大规模音频数据集的"
    # runner.trace_until_generated(prompt, expected_prefix=target_prefix, top_k=10)