File size: 15,511 Bytes
7395889
 
 
 
8550385
7395889
 
 
8550385
7395889
 
 
 
 
 
 
 
 
 
 
 
 
8550385
7395889
 
 
 
 
 
8550385
7395889
 
 
 
 
 
8550385
 
7395889
 
1b7eead
7395889
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b7eead
 
 
 
7395889
 
c9543c7
 
 
 
 
 
7395889
 
 
 
 
 
 
 
 
c9543c7
7395889
c9543c7
7395889
 
 
 
 
 
1b7eead
9cbf886
7395889
69cb715
 
 
 
 
7395889
 
 
 
 
 
 
69cb715
7395889
 
 
 
 
 
 
 
 
 
 
 
 
 
1b7eead
 
 
 
 
 
7395889
69cb715
 
 
 
 
 
 
 
 
 
7395889
ead11a7
7395889
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ff9cbf
7395889
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b7eead
 
 
7395889
69cb715
 
 
 
 
 
 
7395889
 
 
 
 
 
 
 
 
 
 
 
 
 
1b7eead
7395889
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69cb715
7395889
69cb715
8550385
7395889
 
 
 
1b7eead
 
 
 
 
 
7395889
4ff9cbf
c9543c7
 
 
 
 
 
 
4ff9cbf
 
 
 
7395889
c9543c7
 
 
4ff9cbf
 
 
 
7395889
4ff9cbf
 
7395889
4ff9cbf
c9543c7
 
69cb715
 
 
 
 
 
7395889
4ff9cbf
69cb715
 
 
7395889
 
4ff9cbf
7395889
 
 
 
 
 
 
 
1b7eead
7395889
 
 
 
 
 
1b7eead
7395889
 
4ff9cbf
7395889
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69cb715
 
 
7395889
 
 
 
 
 
 
c9543c7
7395889
 
 
 
 
 
 
 
 
 
1b7eead
 
 
7395889
 
8550385
 
 
 
 
 
 
 
d3df147
8550385
7395889
8550385
7395889
9cbf886
8550385
 
7395889
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
import asyncio
import zlib
from functools import lru_cache
from io import BytesIO
from pathlib import Path
from typing import Sequence, List, Tuple, Generator, Iterable, TYPE_CHECKING

import numpy as np
from hfendpoints.errors.config import UnsupportedModelArchitecture
from hfendpoints.openai import Context, run
from hfendpoints.openai.audio import (
    AutomaticSpeechRecognitionEndpoint,
    TranscriptionRequest,
    TranscriptionResponse,
    TranscriptionResponseKind,
    SegmentBuilder,
    Segment,
    Transcription,
    VerboseTranscription,
)
from librosa import load as load_audio, get_duration
from loguru import logger
from transformers import AutoConfig
from vllm import (
    AsyncEngineArgs,
    AsyncLLMEngine,
    SamplingParams,
)

from hfendpoints import EndpointConfig, Handler, ensure_supported_architectures

if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer
    from vllm import CompletionOutput, RequestOutput
    from vllm.sequence import SampleLogprobs

# Architecture names this handler accepts; `entrypoint` passes this list to
# `ensure_supported_architectures` when a local config.json for the model is available.
SUPPORTED_MODEL_ARCHITECTURES = ["WhisperForConditionalGeneration"]


def chunk_audio_with_duration(
        audio: np.ndarray, maximum_duration_sec: int, sampling_rate: int
) -> Sequence[np.ndarray]:
    """
    Chunk a mono audio timeseries so that each chunk is as long as `maximum_duration_sec`.
    Chunks are evenly distributed except the last one which might be shorter
    :param audio: The mono timeseries waveform of the audio
    :param maximum_duration_sec: The maximum length, in seconds, for each chunk
    :param sampling_rate: The number of samples to represent one second of audio
    :return: List of numpy array representing the chunk
    """

    # We pad the input so that every chunk length is `max_duration_sec`
    max_duration_samples = sampling_rate * maximum_duration_sec
    padding = max_duration_samples - np.remainder(len(audio), max_duration_samples)
    audio = np.pad(audio, (0, padding), constant_values=0.0)
    return np.split(audio, len(audio) // max_duration_samples)


def compression_ratio(text: str) -> float:
    """
    Measure how much zlib can shrink `text`; the result fills each segment's `compression_ratio` field.
    :param text: The decoded segment text
    :return: Ratio between the UTF-8 byte length of `text` and the length of its zlib-compressed form
    """
    raw = text.encode("utf-8")
    compressed_size = len(zlib.compress(raw))
    return len(raw) / compressed_size


def create_prompt(
        audio: np.ndarray,
        sampling_rate: int,
        language: int,
        timestamp_marker: int,
):
    """
    Build the encoder/decoder prompt pair submitted to vLLM for one Whisper audio chunk.
    :param audio: PCM data containing audio signal representation
    :param sampling_rate: Number of samples in one second of audio
    :param language: Token id representing the language of the audio content
    :param timestamp_marker: Token id placed last in the decoder prompt (a timestamp or "no timestamps" marker)
    :return: Dictionary with an `encoder_prompt` carrying the audio and a `decoder_prompt` carrying token ids
    """
    # NOTE(review): hard-coded control token ids. They are presumably <|startoftranscript|> (50258) and
    # <|transcribe|> (50360) in the targeted checkpoint's vocabulary; confirm against the tokenizer.
    leading_token_id = 50258
    task_token_id = 50360

    encoder_part = {
        "prompt": "",
        "multi_modal_data": {"audio": (audio, sampling_rate)},
    }
    decoder_part = {
        "prompt_token_ids": [leading_token_id, language, task_token_id, timestamp_marker],
    }
    return {"encoder_prompt": encoder_part, "decoder_prompt": decoder_part}


def create_params(
        max_tokens: int, temperature: float, is_verbose: bool
) -> "SamplingParams":
    """
    Assemble the vLLM sampling parameters used for every audio chunk of a request.
    Special tokens are kept and detokenization is disabled because the raw token ids (timestamps included)
    are decoded afterwards with the tokenizer.
    :param max_tokens: Upper bound on the number of generated tokens
    :param temperature: Softmax sampling temperature
    :param is_verbose: Whether the verbose response is requested; only then are per-token logprobs returned
    :return: `SamplingParams`
    """
    # NOTE: `output_kind` keeps vLLM's default; RequestOutputKind.FINAL_ONLY is an option (change if streaming)
    requested_logprobs = 1 if is_verbose else None
    options = dict(
        max_tokens=max_tokens,
        skip_special_tokens=False,
        detokenize=False,
        temperature=temperature,
        logprobs=requested_logprobs,
    )
    return SamplingParams.from_optional(**options)


def get_avg_logprob(logprobs: "SampleLogprobs") -> float:
    """
    Average, over all generation steps, the log probability of the token selected at each step.
    For every step, the first entry of its mapping is taken as that step's `Logprob`.
    NOTE(review): this assumes vLLM places the sampled token first in each step mapping; confirm.
    :param logprobs: Sequence of per-step mappings (token id -> object exposing `.logprob`)
    :return: Mean log probability, or NaN when there is no generation step
    """
    num_steps = len(logprobs)
    if num_steps == 0:
        # The average is undefined for an empty generation; NaN matches the caller's own fallback
        return float("nan")

    total = sum(next(iter(step.values())).logprob for step in logprobs)
    return total / float(num_steps)


def process_chunk(
        tokenizer: "PreTrainedTokenizer",
        ids: np.ndarray,
        logprobs: "SampleLogprobs",
        request: TranscriptionRequest,
        segment_offset: int,
        timestamp_offset: int,
) -> Generator:
    """
    Decode a single transcribed audio chunk and yield every segment delimited by its timestamp tokens.

    Timestamp tokens are detected as ids >= token_id(<|0.00|>). In verbose mode the decoder prompt already
    ends with <|0.00|>, so this function assumes the generated ids are laid out as
    `text <|end|> <|start|> text <|end|> ...`. Under that assumption, timestamps of even rank (0, 2, ...)
    close a segment and timestamps of odd rank open the next one. Segment tokens exclude the timestamp
    tokens themselves.

    :param tokenizer: Tokenizer used to look up timestamp tokens and decode segment text
    :param ids: Token ids generated for this chunk
    :param logprobs: Per-step log probabilities aligned with `ids`, or a falsy value when not requested
    :param request: The user request; its temperature is reported on each segment
    :param segment_offset: Identifier given to the first segment yielded for this chunk
    :param timestamp_offset: Start time of the chunk in seconds, added to every segment boundary
    :return: Generator of `(Segment, is_single_ending_timestamp)` tuples
    """
    k_timestamp_token = tokenizer.convert_tokens_to_ids("<|0.00|>")

    # Timestamps are expected to have ids greater than or equal to token_id(<|0.00|>)
    timestamps_mask = ids >= k_timestamp_token
    timestamp_positions = np.flatnonzero(timestamps_mask)
    if timestamp_positions.size == 0:
        return

    # Whether the generation ends with a lone timestamp (non-timestamp token followed by a final timestamp)
    is_single_ending_timestamp = np.array_equal(timestamps_mask[-2:], [False, True])

    # The average is computed over the whole chunk and is therefore shared by all its segments: compute once
    avg_logprob = get_avg_logprob(logprobs) if logprobs else float("nan")

    # The first segment's opening timestamp lives in the prompt, so its text starts at index 0
    timestamp_start = 0.0
    slice_start = 0

    for rank, position in enumerate(timestamp_positions):
        timestamp = float(tokenizer.convert_ids_to_tokens([int(ids[position])])[0][2:-2])

        if rank % 2 == 1:
            # Opening timestamp: the next segment's text begins right after this token
            timestamp_start = timestamp
            slice_start = position + 1
            continue

        # Closing timestamp: the segment spans the text tokens since the opening timestamp
        segment_ids = ids[slice_start:position]
        segment_text = tokenizer.decode(segment_ids)

        segment = (
            SegmentBuilder()
            # One segment per closing timestamp: `rank // 2` keeps ids contiguous across chunks
            .id(segment_offset + rank // 2)
            .start(timestamp_offset + timestamp_start)
            .end(timestamp_offset + timestamp)
            .text(segment_text)
            .tokens(segment_ids.tolist())
            .temperature(request.temperature)
            .avg_logprob(avg_logprob)
            .compression_ratio(compression_ratio(segment_text))
            .build()
        )

        yield segment, is_single_ending_timestamp

        # Never carry this closing timestamp over into the next segment's tokens
        slice_start = position + 1


def process_chunks(
        tokenizer: "PreTrainedTokenizer",
        chunks: List["RequestOutput"],
        request: TranscriptionRequest,
) -> Tuple[List[Segment], str]:
    """
    Merge the outputs of every transcribed audio chunk into segments and a single transcript.
    Segments are produced for each chunk through `process_chunk`, shifted by the chunk's start time, while
    the token ids of all chunks are concatenated and decoded once for the full text.
    :param tokenizer: The tokenizer to use for decoding tokens
    :param chunks: Transcribed audio chunks, in audio order
    :param request: Received request from the user
    :return: `Tuple[List[Segment], str]` with all the consolidated segments and the full transcribed text
    """
    segments: List[Segment] = []
    all_token_ids = []

    for chunk_index, chunk in enumerate(chunks):
        chunk_start_sec = chunk_index * WhisperHandler.WHISPER_SEGMENT_DURATION_SEC
        first_segment_id = len(segments)

        # Only the last completion of each request output is used
        completion: "CompletionOutput" = chunk.outputs[-1]
        chunk_ids: np.ndarray = np.asarray(completion.token_ids)

        segments.extend(
            segment
            for segment, _ in process_chunk(
                tokenizer,
                chunk_ids,
                completion.logprobs,
                request,
                first_segment_id,
                chunk_start_sec,
            )
        )

        all_token_ids.extend(completion.token_ids)

    full_text = tokenizer.decode(
        all_token_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return segments, full_text


class WhisperHandler(Handler[TranscriptionRequest, TranscriptionResponse]):
    WHISPER_SEGMENT_DURATION_SEC = 30
    WHISPER_SAMPLING_RATE = 22050

    __slots__ = ("_engine",)

    def __init__(self, model_id_or_path: str):
        super().__init__(model_id_or_path)

        self._engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(
                model_id_or_path,
                task="transcription",
                device="auto",
                dtype="bfloat16",
                kv_cache_dtype="fp8",
                enforce_eager=False,
                enable_prefix_caching=True,
                max_logprobs=1,  # TODO(mfuntowicz) : Set from config?
                disable_log_requests=True,
            )
        )

    async def transcribe(
            self,
            ctx: Context,
            request: TranscriptionRequest,
            tokenizer: "PreTrainedTokenizer",
            audio_chunks: Iterable[np.ndarray],
            params: "SamplingParams",
    ) -> (List[Segment], str):
        async def __agenerate__(request_id: str, prompt, params):
            """
            Helper method to unroll asynchronous generator and return the last element
            :param request_id: Unique identifier for this request
            :param prompt: The prompt to submit for inference on vLLM through `generate(...)`
            :param params: The parameters passed along with the prompt for inference on vLLM through `generate(...)`
            :return: `CompletionOutput`
            """
            # Submit for inference on the segment & keep track of the background task
            async for step in self._engine.generate(prompt, params, request_id):
                pass
            return step

        # Wrap tokenizer results with LRU cache to avoid vocabulary lookup
        convert_tokens_to_ids = lru_cache(tokenizer.convert_tokens_to_ids)

        coro_handles = []
        for audio_chunk_id, audio_chunk in enumerate(audio_chunks):
            # Generate suffixed request-id to submit and identify through vLLM scheduler
            request_id = f"{ctx.request_id}-{audio_chunk_id}"

            # Compute the starting time of the chunk as each consecutive chunk represents 30s worth of audio
            timestamp = audio_chunk_id * WhisperHandler.WHISPER_SEGMENT_DURATION_SEC

            # Compute initial prompt for the segment
            is_verbose = request.response_kind == TranscriptionResponseKind.VERBOSE_JSON
            language = convert_tokens_to_ids(f"<|{request.language}|>")
            timestamp = convert_tokens_to_ids(
                f"<|0.00|>" if is_verbose else "<|notimestamps|>"
            )
            prompt = create_prompt(
                audio_chunk, WhisperHandler.WHISPER_SAMPLING_RATE, language, timestamp
            )

            # Submit the task
            coro_handles += [
                asyncio.create_task(__agenerate__(request_id, prompt, params))
            ]

        # Wait for all the segment to complete
        text_chunks = await asyncio.gather(*coro_handles)

        # if not is_cancelled.cancel_called:
        segments, text = await asyncio.get_event_loop().run_in_executor(
            None, process_chunks, tokenizer, text_chunks, request
        )
        return segments, text

    async def __call__(
            self, request: TranscriptionRequest, ctx: Context
    ) -> TranscriptionResponse:
        with logger.contextualize(request_id=ctx.request_id):
            with memoryview(request) as audio:

                # Check if we need to enable the verbose path
                is_verbose = (
                        request.response_kind == TranscriptionResponseKind.VERBOSE_JSON
                )

                # Retrieve the tokenizer and model config asynchronously while we decode audio
                tokenizer = asyncio.create_task(self._engine.get_tokenizer())
                model_config = asyncio.create_task(self._engine.get_model_config())

                # Decode audio from librosa (for now)
                # TODO: Use native (Rust provided) decoding
                (waveform, sampling) = load_audio(BytesIO(audio), sr=22050, mono=True)
                logger.debug(
                    f"Successfully decoded {len(waveform)} bytes PCM audio chunk"
                )

                # Create parameters
                max_tokens = (await model_config).max_model_len - 4
                params = create_params(max_tokens, request.temperature, is_verbose)

                # Chunk audio in pieces
                audio_chunks = chunk_audio_with_duration(
                    waveform,
                    maximum_duration_sec=WhisperHandler.WHISPER_SEGMENT_DURATION_SEC,
                    sampling_rate=WhisperHandler.WHISPER_SAMPLING_RATE,
                )

                # Submit audio pieces to the batcher and gather them all
                segments, text = await self.transcribe(
                    ctx, request, await tokenizer, audio_chunks, params
                )

                match request.response_kind:
                    case TranscriptionResponseKind.VERBOSE_JSON:
                        return TranscriptionResponse.verbose(
                            VerboseTranscription(
                                text=text,
                                duration=get_duration(y=waveform, sr=sampling),
                                language=request.language,
                                segments=segments,
                                # word=None
                            )
                        )
                    case TranscriptionResponseKind.JSON:
                        return TranscriptionResponse.json(text)

                    case TranscriptionResponseKind.TEXT:
                        return TranscriptionResponse.text(text)

                # I don't forsee any case this would happen but at least we are safe
                raise ValueError(f"Invalid response_kind ({request.response_kind})")


def entrypoint():
    """
    Serve the Whisper transcription endpoint described by the environment configuration.
    When the model is available locally with a config.json, its architecture is validated with
    `ensure_supported_architectures` before the engine is created.
    """
    endpoint_config = EndpointConfig.from_env()

    # Architecture validation is only possible when the model (and its config.json) is already on disk
    model_local_path = Path(endpoint_config.model_id)
    config_local_path = model_local_path / "config.json"
    if model_local_path.exists() and config_local_path.exists():
        model_config = AutoConfig.from_pretrained(config_local_path)
        ensure_supported_architectures(model_config, SUPPORTED_MODEL_ARCHITECTURES)

    handler = WhisperHandler(endpoint_config.model_id)
    endpoint = AutomaticSpeechRecognitionEndpoint(handler)

    # Blocks serving requests on the configured interface and port
    run(endpoint, endpoint_config.interface, endpoint_config.port)


# Start the endpoint when this module is executed as a script
if __name__ == "__main__":
    entrypoint()