import os

import torch
import torchaudio
from transformers import AutoTokenizer, AutoModelForCausalLM

from tokenizer import StepAudioTokenizer
from tts import StepAudioTTS
from utils import load_audio, speech_adjust, volumn_adjust


class StepAudio:
    def __init__(self, tokenizer_path: str, tts_path: str, llm_path: str):
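        """Speech chat pipeline: speech tokenizer (encoder), causal LM, and TTS decoder.

        Args:
            tokenizer_path: path to the speech tokenizer that turns audio into discrete tokens.
            tts_path: path to the TTS model that renders the reply text as speech.
            llm_path: path to the causal LM checkpoint (also provides the text tokenizer).
        """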
        # Load optimus_ths for flash attention. Make sure LD_LIBRARY_PATH contains `nvidia/cuda_nvrtc/lib`;
        # if it does not, set it manually, e.g. LD_LIBRARY_PATH=xxx/python3.10/site-packages/nvidia/cuda_nvrtc/lib
        try:
            if torch.__version__ >= "2.5":
                torch.ops.load_library(os.path.join(llm_path, 'lib/liboptimus_ths-torch2.5-cu124.cpython-310-x86_64-linux-gnu.so'))
            elif torch.__version__ >= "2.3":
                torch.ops.load_library(os.path.join(llm_path, 'lib/liboptimus_ths-torch2.3-cu121.cpython-310-x86_64-linux-gnu.so'))
            elif torch.__version__ >= "2.2":
                torch.ops.load_library(os.path.join(llm_path, 'lib/liboptimus_ths-torch2.2-cu121.cpython-310-x86_64-linux-gnu.so'))
            else:
                raise RuntimeError(f"no prebuilt optimus_ths library for torch {torch.__version__}")
            print("Loaded optimus_ths successfully; flash attention will be enabled")
        except Exception as err:
            print(f"Failed to load optimus_ths; flash attention is disabled: {err}")

        self.llm_tokenizer = AutoTokenizer.from_pretrained(
            llm_path, trust_remote_code=True
        )
        self.encoder = StepAudioTokenizer(tokenizer_path)
        self.decoder = StepAudioTTS(tts_path, self.encoder)
        self.llm = AutoModelForCausalLM.from_pretrained(
            llm_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )

    def __call__(
        self,
        messages: list,
        speaker_id: str,
        speed_ratio: float = 1.0,
        volumn_ratio: float = 1.0,
    ):
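        """Run one chat turn end to end.

        Builds the prompt from `messages`, generates a text reply with the LLM,
        synthesizes speech for `speaker_id`, and optionally applies speed and
        volume adjustment. Returns (output_text, output_audio, sample_rate).
        """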
        text_with_audio = self.apply_chat_template(messages)
        token_ids = self.llm_tokenizer.encode(text_with_audio, return_tensors="pt")
        outputs = self.llm.generate(
            token_ids, max_new_tokens=2048, temperature=0.7, top_p=0.9, do_sample=True
        )
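        # Keep only the newly generated tokens: drop the prompt prefix and the trailing stop token.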
        output_token_ids = outputs[:, token_ids.shape[-1] : -1].tolist()[0]
        output_text = self.llm_tokenizer.decode(output_token_ids)
        output_audio, sr = self.decoder(output_text, speaker_id)
        if speed_ratio != 1.0:
            output_audio = speech_adjust(output_audio, sr, speed_ratio)
        if volumn_ratio != 1.0:
            output_audio = volumn_adjust(output_audio, volumn_ratio)
        return output_text, output_audio, sr

    def encode_audio(self, audio_path):
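        """Load a waveform from disk and encode it into discrete audio tokens."""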
        audio_wav, sr = load_audio(audio_path)
        audio_tokens = self.encoder(audio_wav, sr)
        return audio_tokens

    def apply_chat_template(self, messages: list):
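        """Serialize chat messages into the <|BOT|>role ... <|EOT|> prompt format.

        The "user" role is rewritten to "human". String content is inserted as-is,
        {"type": "text"} content uses its "text" field, and {"type": "audio"} content
        is first encoded into audio tokens. The prompt always ends with an open
        "<|BOT|>assistant" turn so the model continues as the assistant.
        """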
        text_with_audio = ""
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            if role == "user":
                role = "human"
            if isinstance(content, str):
                text_with_audio += f"<|BOT|>{role}\n{content}<|EOT|>"
            elif isinstance(content, dict):
                if content["type"] == "text":
                    text_with_audio += f"<|BOT|>{role}\n{content['text']}<|EOT|>"
                elif content["type"] == "audio":
                    audio_tokens = self.encode_audio(content["audio"])
                    text_with_audio += f"<|BOT|>{role}\n{audio_tokens}<|EOT|>"
                else:
                    raise ValueError(f"Unsupported content type: {content['type']}")
            elif content is None:
                text_with_audio += f"<|BOT|>{role}\n"
            else:
                raise ValueError(f"Unsupported content type: {type(content)}")
        if not text_with_audio.endswith("<|BOT|>assistant\n"):
            text_with_audio += "<|BOT|>assistant\n"
        return text_with_audio


if __name__ == "__main__":
    model = StepAudio(
        tokenizer_path="/mnt/ys-shai-jfs/open-step1o-audio/step1o-audio-encoder",
        tts_path="/mnt/ys-shai-jfs/open-step1o-audio/step1o-audio-decoder",
        llm_path="/mnt/ys-shai-jfs/open-step1o-audio/step1o-audio-v18",
    )

    os.makedirs("output", exist_ok=True)

    # Text-query example with speaker "闫雨婷" (Yan Yuting).
    # Prompt: "Hello, I'm your friend. My name is Xiaoming. What's your name?"
    text, audio, sr = model(
        [{"role": "user", "content": "你好,我是你的朋友,我叫小明,你叫什么名字?"}],
        "闫雨婷",
    )
    torchaudio.save("output/output_e2e_tqta.wav", audio, sr)
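
    # Audio-query example: feed the wav generated above back in as the user's spoken input.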
    text, audio, sr = model(
        [
            {
                "role": "user",
                "content": {"type": "audio", "audio": "output/output_e2e_tqta.wav"},
            }
        ],
        "闫雨婷",
    )
    torchaudio.save("output/output_e2e_aqta.wav", audio, sr)