# Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Processor class for Phi4Multimodal
"""

import base64
import os
import re
from io import BytesIO
from typing import List, Optional, TypedDict, Union

import librosa
import numpy as np
import PIL.Image
import PIL.ImageOps
import requests

from transformers.image_processing_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, TextKwargs, ImagesKwargs, VideosKwargs, AudioKwargs, CommonKwargs, ProcessorChatTemplateKwargs
from transformers.tokenization_utils_base import TextInput
from transformers.utils import logging


from .feature_extraction_phi4_multimodal import AudioInput


logger = logging.get_logger(__name__)


class ChatTemplateLoadKwargs(TypedDict, total=False):
    """
    Keyword arguments used to load multimodal data in processor chat templates.

    num_frames (`int`, *optional*):
        Number of frames to sample uniformly. If not passed, the whole video is loaded.
    video_load_backend (`str`, *optional*, defaults to `"pyav"`):
        The backend used to load videos; it is only used when the conversation contains videos.
        Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "pyav" because it is the only backend
        that supports all types of sources to load from.
    video_fps (`int`, *optional*):
        Number of frames to sample per second. Should be passed only when `num_frames=None`.
        If not specified and `num_frames==None`, all frames are sampled.
    sampling_rate (`int`, *optional*, defaults to 16000):
        Sampling rate, in Hz, at which audio is loaded.
    load_audio_from_video (`bool`, *optional*, defaults to `False`):
        Whether to load audio from the `video` entries of the conversation instead of from its `audio` entries.
    sample_indices_fn (`Callable`, *optional*):
        A callable that returns the indices at which the video should be sampled. Provide it when the video has to
        be sampled with a technique other than the one given by the `num_frames` or `fps` arguments. If not
        provided, simple uniform sampling with fps is performed; otherwise `sample_indices_fn` takes priority over
        the other arguments. The function receives all args and kwargs passed to `load_video` and should return
        valid indices at which the video should be sampled. For example:

            def sample_indices_fn(num_frames, fps, metadata, **kwargs):
                # add your sampling logic here ...
                return np.linspace(start_idx, end_idx, num_frames, dtype=int)
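
    Example (a minimal sketch of fps-based uniform sampling; `sample_indices_by_fps` and its `total_num_frames`,
    `video_fps` and `target_fps` parameters are hypothetical and are not the arguments `load_video` passes to
    `sample_indices_fn`):

    ```python
    import numpy as np


    def sample_indices_by_fps(total_num_frames: int, video_fps: float, target_fps: float) -> np.ndarray:
        # Frames to keep: video duration in seconds times the target rate, at least one and never more than exist
        duration = total_num_frames / video_fps
        num_frames = min(total_num_frames, max(1, int(duration * target_fps)))
        # Spread the kept frames evenly from the first to the last frame
        return np.linspace(0, total_num_frames - 1, num_frames, dtype=int)


    indices = sample_indices_by_fps(total_num_frames=300, video_fps=30.0, target_fps=2.0)
    print(len(indices), indices[0], indices[-1])  # 20 0 299
    ```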
    """

    num_frames: Optional[int] = None
    video_load_backend: Optional[str] = "pyav"
    video_fps: Optional[int] = None
    sampling_rate: Optional[int] = 16_000
    load_audio_from_video: Optional[bool] = False


class AllKwargsForChatTemplate(
    TextKwargs, ImagesKwargs, VideosKwargs, AudioKwargs, CommonKwargs, ProcessorChatTemplateKwargs
):
    processor_kwargs: ProcessingKwargs = {
        **ProcessingKwargs.__annotations__,
    }
    mm_load_kwargs: ChatTemplateLoadKwargs = {
        **TextKwargs.__annotations__,
    }
    template_kwargs: ProcessorChatTemplateKwargs = {
        **ProcessorChatTemplateKwargs.__annotations__,
    }


class Phi4MultimodalProcessorKwargs(ProcessingKwargs, total=False):
    _defaults = {
        "audio_kwargs": {
            "device": "cpu",
        },
    }


def load_audio(audio: Union[str, np.ndarray], sampling_rate=16000, timeout=None) -> np.ndarray:
    """
    Loads `audio` to an np.ndarray object.

    Args:
        audio (`str` or `np.ndarray`):
            The audio to be loaded into the numpy array format: a URL, a path to a local file, or an array.
        sampling_rate (`int`, *optional*, defaults to 16000):
            The sampling rate to be used when loading the audio. It should be the same as the
            sampling rate the model you will be using was trained with.
        timeout (`float`, *optional*):
            The timeout value in seconds for the URL request.

    Returns:
        `np.ndarray`: A numpy array representing the audio.
    """

    if isinstance(audio, str):
        # Load audio from URL (e.g. https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/translate_to_chinese.wav)
        if audio.startswith("http://") or audio.startswith("https://"):
            audio = librosa.load(BytesIO(requests.get(audio, timeout=timeout).content), sr=sampling_rate)[0]
        elif os.path.isfile(audio):
            audio = librosa.load(audio, sr=sampling_rate)[0]
        else:
            raise ValueError(
                f"Incorrect audio source. Must be a valid URL starting with `http://` or `https://` or a valid path to an audio file. Got {audio}."
            )
    elif not isinstance(audio, np.ndarray):
        raise TypeError(
            "Incorrect format used for `audio`. Should be a URL linking to an audio file, a local path, or a numpy array."
        )
    return audio


def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] = None) -> "PIL.Image.Image":
    """
    Loads `image` to a PIL Image.

    Args:
        image (`str` or `PIL.Image.Image`):
            The image to convert to the PIL Image format.
        timeout (`float`, *optional*):
            The timeout value in seconds for the URL request.

    Returns:
        `PIL.Image.Image`: A PIL Image.
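
    Example (a minimal, standard-library-only sketch of how the base64 branch of `load_image` treats its input;
    `decode_base64_image_payload` is a hypothetical helper, not part of this module, and it returns the raw bytes
    instead of opening them with PIL):

    ```python
    import base64


    def decode_base64_image_payload(payload: str) -> bytes:
        # Drop a data-URI header such as "data:image/png;base64," and keep only the encoded part
        if payload.startswith("data:image/"):
            payload = payload.split(",")[1]
        # Same decoder `load_image` uses before handing the bytes to `PIL.Image.open`
        return base64.decodebytes(payload.encode())


    raw = b"not-really-a-png"
    data_uri = "data:image/png;base64," + base64.b64encode(raw).decode()
    assert decode_base64_image_payload(data_uri) == raw
    ```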
    """
    if isinstance(image, str):
        if image.startswith("http://") or image.startswith("https://"):
            # We need to actually check for a real protocol, otherwise it's impossible to use a local file
            # like http_huggingface_co.png
            image = PIL.Image.open(BytesIO(requests.get(image, timeout=timeout).content))
        elif os.path.isfile(image):
            image = PIL.Image.open(image)
        else:
            if image.startswith("data:image/"):
                image = image.split(",")[1]

            # Try to load as base64
            try:
                b64 = base64.decodebytes(image.encode())
                image = PIL.Image.open(BytesIO(b64))
            except Exception as e:
                raise ValueError(
                    f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
                )
    elif not isinstance(image, PIL.Image.Image):
        raise TypeError(
            "Incorrect format used for image. Should be a URL linking to an image, a base64 string, a local path, or a PIL image."
        )
    image = PIL.ImageOps.exif_transpose(image)
    image = image.convert("RGB")
    return image


class Phi4MultimodalProcessor(ProcessorMixin):
    r"""
    Constructs a Phi4Multimodal processor which wraps an image processor, an audio processor, and a GPT tokenizer into a single processor.

    [`Phi4MultimodalProcessor`] offers all the functionalities of [`Phi4MultimodalImageProcessorFast`],
    [`Phi4MultimodalFeatureExtractor`] and [`GPT2TokenizerFast`]. See the [`~Phi4MultimodalProcessor.__call__`]
    and [`~Phi4MultimodalProcessor.decode`] for more information.

    Args:
        image_processor (`Phi4MultimodalImageProcessorFast`):
            The image processor to use for images.
        audio_processor (`Phi4MultimodalFeatureExtractor`):
            The audio processor to use for audio inputs.
        tokenizer (`GPT2TokenizerFast`):
            The tokenizer to use for text. It must expose the `image_token`, `image_token_id`, `audio_token` and
            `audio_token_id` attributes.
        chat_template (`str`, *optional*):
            A Jinja template used by [`~Phi4MultimodalProcessor.apply_chat_template`] to format conversations.
    """

    attributes = ["image_processor", "audio_processor", "tokenizer"]
    tokenizer_class = "GPT2TokenizerFast"
    image_processor_class = "AutoImageProcessor"
    audio_processor_class = "AutoFeatureExtractor"
    valid_kwargs = ["chat_template"]

    def __init__(
        self,
        image_processor,
        audio_processor,
        tokenizer,
        **kwargs,
    ):
        self.image_token = tokenizer.image_token
        self.image_token_id = tokenizer.image_token_id
        self.audio_token = tokenizer.audio_token
        self.audio_token_id = tokenizer.audio_token_id
        super().__init__(image_processor, audio_processor, tokenizer, **kwargs)

    def __call__(
        self,
        text: Union[TextInput, List[TextInput]],
        images: Optional[ImageInput] = None,
        audio: Optional[AudioInput] = None,
        **kwargs: Unpack[ProcessingKwargs],
    ) -> BatchFeature:
        """
        Main method to prepare one or several sequence(s), image(s) and audio(s) for the model. This method forwards
        the `text` and `kwargs` arguments to GPT2TokenizerFast's [`~GPT2TokenizerFast.__call__`] to encode the text.
        Before tokenization, each image token (`tokenizer.image_token`) and audio token (`tokenizer.audio_token`) in
        `text` is repeated as many times as the corresponding image or audio requires. To prepare the image(s), this
        method forwards the `images` and `kwargs` arguments to Phi4MultimodalImageProcessorFast's
        [`~Phi4MultimodalImageProcessorFast.__call__`] if `images` is not `None`; to prepare the audio(s), it forwards
        the `audio` and `kwargs` arguments to Phi4MultimodalFeatureExtractor's
        [`~Phi4MultimodalFeatureExtractor.__call__`] if `audio` is not `None`. Please refer to the docstrings of the
        above methods for more information.

        Args:
            text (`str`, `List[str]`):
                The sequence or batch of sequences to be encoded. The batch as a whole must contain one image token
                per image and one audio token per audio; placeholders are matched to images and audios in order.
            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. Both channels-first and channels-last formats are supported.
            audio (`List[Union[np.ndarray, torch.Tensor]]`):
                List of the audios to be prepared.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **input_ids** -- List of token ids to be fed to a model.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
            - **input_image_embeds** -- Pixel values to be fed to a model.
            - **image_sizes** -- List of tuples specifying the size of each image in `input_image_embeds`.
            - **image_attention_mask** -- List of attention masks for each image in `input_image_embeds`.
            - **input_audio_embeds** -- Audio embeddings to be fed to a model.
            - **audio_embed_sizes** -- List of integers specifying the size of each audio in `input_audio_embeds`.
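
        Example (a minimal sketch of the placeholder expansion performed below; `expand_placeholders` is a
        hypothetical helper, and the `<|image|>` string and the per-image token counts are illustrative assumptions):

        ```python
        import re
        from typing import Iterable, List


        def expand_placeholders(texts: List[str], token: str, counts: Iterable[int]) -> List[str]:
            # Counts are consumed in order across the whole batch, one per occurrence of `token`
            counts_iter = iter(counts)
            return [re.sub(re.escape(token), lambda _: token * next(counts_iter), t) for t in texts]


        batch = ["Describe <|image|>.", "Compare <|image|> and <|image|>."]
        expanded = expand_placeholders(batch, "<|image|>", [2, 1, 3])
        print(expanded[0])  # Describe <|image|><|image|>.
        ```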
        """

        output_kwargs = self._merge_kwargs(Phi4MultimodalProcessorKwargs, self.tokenizer.init_kwargs, **kwargs)
        image_kwargs = output_kwargs["images_kwargs"]
        audio_kwargs = output_kwargs["audio_kwargs"]

        image_inputs = self.image_processor(images, **image_kwargs) if images is not None else {}
        audio_inputs = self.audio_processor(audio, **audio_kwargs) if audio is not None else {}

        # `num_img_tokens` is popped because it is only needed to expand the image tokens, not by the model
        num_img_tokens = image_inputs.pop("num_img_tokens", [])
        audio_embed_sizes = audio_inputs.get("audio_embed_sizes", [])

        # Normalize `text` to a list of strings
        if isinstance(text, str):
            text = [text]
        elif not isinstance(text, (list, tuple)) or not all(isinstance(t, str) for t in text):
            raise ValueError("Invalid input text. Please provide a string, or a list of strings")

        image_token = self.tokenizer.image_token
        audio_token = self.tokenizer.audio_token

        # Check that the number of special tokens is sound
        concatenated_prompt = "".join(text)
        if concatenated_prompt.count(image_token) != len(num_img_tokens):
            raise ValueError(
                f"You should add as many image tokens `{image_token}` in your prompt as you pass `images` to the processor. "
                f"Input contains {concatenated_prompt.count(image_token)} tokens != {len(num_img_tokens)} images"
            )
        if concatenated_prompt.count(audio_token) != len(audio_embed_sizes):
            raise ValueError(
                f"You should add as many audio tokens `{audio_token}` in your prompt as you pass `audio` to the processor. "
                f"Input contains {concatenated_prompt.count(audio_token)} tokens != {len(audio_embed_sizes)} audios"
            )

        # Add appropriate number of image/audio tokens (note that the count of replacement is dynamic)
        image_count_iter = iter(num_img_tokens)
        audio_count_iter = iter(audio_embed_sizes)
        processed_text = [
            re.sub(re.escape(image_token), lambda _: image_token * next(image_count_iter), t) for t in text
        ]
        processed_text = [
            re.sub(re.escape(audio_token), lambda _: audio_token * next(audio_count_iter), t) for t in processed_text
        ]

        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
        text_inputs = self.tokenizer(processed_text, **output_kwargs["text_kwargs"])
        self._check_special_mm_tokens(processed_text, text_inputs, modalities=["image", "audio"])

        # prepare batch feature
        data = {
            **text_inputs,
            **image_inputs,
            **audio_inputs,
        }

        return BatchFeature(data=data, tensor_type=return_tensors)

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to GPT2Tokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to GPT2Tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
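        """
        Names of all inputs the model accepts, merged from the tokenizer, the image processor and the audio
        processor, with duplicates removed while keeping first-seen order.

        Example (a minimal sketch of the de-duplication step; the name lists are illustrative, not the actual values
        reported by the sub-processors):

        ```python
        tokenizer_names = ["input_ids", "attention_mask"]
        image_names = ["input_image_embeds", "image_sizes", "image_attention_mask"]
        audio_names = ["input_audio_embeds", "audio_embed_sizes", "attention_mask"]

        # dict keys preserve insertion order, so this drops repeats and keeps the first occurrence
        merged = list(dict.fromkeys(tokenizer_names + image_names + audio_names))
        print(merged[-1])  # audio_embed_sizes
        ```
        """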
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        audio_processor_input_names = self.audio_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names + audio_processor_input_names))
    
    def _check_special_mm_tokens(self, text: list[str], text_inputs: "BatchFeature", modalities: list[str]):
        """
        Checks that the number of special tokens in the text and in the tokenized `input_ids` is the same. The counts
        can differ if the tokenized text was truncated, which leads to issues in the model code.
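
        Example (a minimal sketch of the comparison on toy data; `has_token_count_mismatch` is a hypothetical helper
        and the token id `7` is an arbitrary stand-in for the real image token id):

        ```python
        from typing import List, Sequence


        def has_token_count_mismatch(
            texts: List[str], input_ids: Sequence[Sequence[int]], token_str: str, token_id: int
        ) -> bool:
            # Per-sample counts of the placeholder in the token ids and in the expanded text
            ids_count = [list(ids).count(token_id) for ids in input_ids]
            text_count = [sample.count(token_str) for sample in texts]
            return ids_count != text_count


        texts = ["<|image|><|image|> hi"]
        print(has_token_count_mismatch(texts, [[7, 7, 1]], "<|image|>", 7))  # False
        print(has_token_count_mismatch(texts, [[7]], "<|image|>", 7))  # True: one placeholder was truncated away
        ```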
        """
        for modality in modalities:
            token_str = getattr(self, f"{modality}_token")
            token_id = getattr(self, f"{modality}_token_id")
            ids_count = [list(ids).count(token_id) for ids in text_inputs["input_ids"]]
            text_count = [sample.count(token_str) for sample in text]

            if ids_count != text_count:
                raise ValueError(
                    f"Mismatch in `{modality}` token count between text and `input_ids`. Got ids={ids_count} and text={text_count}. "
                    "Likely due to `truncation='max_length'`. Please disable truncation or increase `max_length`."
                )
            
    def apply_chat_template(
        self,
        conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]],
        chat_template: Optional[str] = None,
        **kwargs: Unpack[AllKwargsForChatTemplate],
    ) -> str:
        """
        Similar to the `apply_chat_template` method on tokenizers, this method applies a Jinja template to input
        conversations to turn them into a single tokenizable string.

        The input is expected to be in the following format, where each message content is a list consisting of text and
        optionally image, video or audio inputs. Media can be given as a URL or a local path (images also as a base64
        string). When `tokenize=True`, the media is loaded and the images and audio are processed, so that
        `input_image_embeds` and `input_audio_embeds` are returned alongside `input_ids` when `return_dict=True`.
        Otherwise only the formatted text is returned.

        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": "https://www.ilankelman.org/stopsigns/australia.jpg"},
                    {"type": "text", "text": "Please describe this image in detail."},
                ],
            },
        ]

        Args:
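            conversation (`Union[List[Dict[str, str]], List[List[Dict[str, str]]]]`):
                The conversation to format, or a batch of conversations.
            chat_template (`Optional[str]`, *optional*):
                The Jinja template to use for formatting the conversation, or the name of one of the processor's
                chat templates. If not provided, the processor's chat template is used (its `"default"` entry when
                several templates are registered).

        Returns:
            The formatted prompt (`str`, or `List[str]` for a batch) when `tokenize=False`. When `tokenize=True`,
            the output of [`~Phi4MultimodalProcessor.__call__`] if `return_dict=True`, otherwise only its
            `input_ids`.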
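
        Example (a minimal sketch of how media sources are picked out of the message contents when `tokenize=True`;
        `collect_media_sources` is a hypothetical helper that only gathers the source strings without loading
        anything, and `speech.wav` and the URL are placeholder values):

        ```python
        from typing import Dict, List


        def collect_media_sources(conversation: List[dict]) -> Dict[str, list]:
            # Same lookup keys, per content type, as the loading loop in `apply_chat_template`
            keys = {
                "image": ["image", "url", "path", "base64"],
                "video": ["video", "url", "path"],
                "audio": ["audio", "url", "path"],
            }
            sources: Dict[str, list] = {"image": [], "video": [], "audio": []}
            for message in conversation:
                for content in message["content"]:
                    for key in keys.get(content["type"], []):
                        if key in content:
                            sources[content["type"]].append(content[key])
            return sources


        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "audio", "path": "speech.wav"},
                    {"type": "image", "url": "https://example.com/cat.png"},
                    {"type": "text", "text": "Describe both inputs."},
                ],
            },
        ]
        print(collect_media_sources(conversation)["audio"])  # ['speech.wav']
        ```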
        """

        if chat_template is None:
            if isinstance(self.chat_template, dict) and "default" in self.chat_template:
                chat_template = self.chat_template["default"]
            elif isinstance(self.chat_template, dict):
                raise ValueError(
                    'The processor has multiple chat templates but none of them are named "default". You need to specify'
                    " which one to use by passing the `chat_template` argument. Available templates are: "
                    f"{', '.join(self.chat_template.keys())}"
                )
            elif self.chat_template is not None:
                chat_template = self.chat_template
            else:
                raise ValueError(
                    "Cannot use apply_chat_template because this processor does not have a chat template."
                )
        elif isinstance(self.chat_template, dict) and chat_template in self.chat_template:
            # It's the name of a template, not a full template string
            chat_template = self.chat_template[chat_template]
        # Otherwise `chat_template` is a full template string and is rendered as-is

        # Fill sets of kwargs that should be used by different parts of template
        processed_kwargs = {
            "mm_load_kwargs": {},
            "template_kwargs": {},
        }

        for kwarg_type in processed_kwargs:
            for key in AllKwargsForChatTemplate.__annotations__[kwarg_type].__annotations__.keys():
                kwarg_type_defaults = AllKwargsForChatTemplate.__annotations__[kwarg_type]
                default_value = getattr(kwarg_type_defaults, key, None)
                value = kwargs.pop(key, default_value)
                if value is not None and not isinstance(value, dict):
                    processed_kwargs[kwarg_type][key] = value

        if isinstance(conversation, (list, tuple)) and (
            isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "content")
        ):
            is_batched = True
            conversations = conversation
        else:
            is_batched = False
            conversations = [conversation]

        tokenize = processed_kwargs["template_kwargs"].pop("tokenize", False)
        return_dict = processed_kwargs["template_kwargs"].pop("return_dict", False)
        mm_load_kwargs = processed_kwargs["mm_load_kwargs"]

        if tokenize:
            batch_images, batch_videos = [], []
            batch_audios = []
            batch_video_metadata = []
            for conversation in conversations:
                images, videos = [], []
                video_metadata = []
                for message in conversation:
                    visuals = [content for content in message["content"] if content["type"] in ["image", "video"]]
                    audio_fnames = [
                        content[key]
                        for content in message["content"]
                        for key in ["audio", "url", "path"]
                        if key in content and content["type"] == "audio"
                    ]
                    image_fnames = [
                        vision_info[key]
                        for vision_info in visuals
                        for key in ["image", "url", "path", "base64"]
                        if key in vision_info and vision_info["type"] == "image"
                    ]
                    video_fnames = [
                        vision_info[key]
                        for vision_info in visuals
                        for key in ["video", "url", "path"]
                        if key in vision_info and vision_info["type"] == "video"
                    ]

                    for fname in image_fnames:
                        images.append(load_image(fname))

                    # Audio models do not accept nested list of audios (yet!) so we construct a flat input audio list
                    if not mm_load_kwargs["load_audio_from_video"]:
                        for fname in audio_fnames:
                            batch_audios.append(load_audio(fname, sampling_rate=mm_load_kwargs["sampling_rate"]))
                    else:
                        for fname in video_fnames:
                            batch_audios.append(load_audio(fname, sampling_rate=mm_load_kwargs["sampling_rate"]))

                    for fname in video_fnames:
                        if isinstance(fname, (list, tuple)) and isinstance(fname[0], str):
                            video = [np.array(load_image(image_fname)) for image_fname in fname]
                            # create a 4D video because `load_video` always returns a 4D array
                            video = np.stack(video)
                            metadata = None
                            logger.warning(
                                "When loading the video from list of images, we cannot infer metadata such as `fps` or `duration`. "
                                "If your model uses this metadata during processing, please load the whole video and let the model sample frames instead."
                            )
                        else:
                            # TODO: raushan, should be `self.video_processor.load_video_for_model` when API is added
                            video, metadata = self._load_video_for_model(
                                fname,
                                num_frames=mm_load_kwargs.get("num_frames", None),
                                fps=mm_load_kwargs.get("video_fps", None),
                                backend=mm_load_kwargs["video_load_backend"],
                                **kwargs,
                            )
                        videos.append(video)
                        video_metadata.append(metadata)

                # Currently all processors can accept nested list of batches, but not flat list of visuals
                # So we'll make a batched list of images and let the processor handle it
                if images:
                    batch_images.append(images)
                if videos:
                    batch_videos.append(videos)
                    batch_video_metadata.append(video_metadata)

            # Process conversation with video/image information if needed. Then convert into a prompt using Jinja template
            conversations = self._process_messages_for_chat_template(
                conversations,
                batch_images=batch_images,
                batch_videos=batch_videos,
                batch_video_metadata=batch_video_metadata,
                **processed_kwargs["mm_load_kwargs"],
            )

        prompt = self.tokenizer.apply_chat_template(
            conversations,
            chat_template=chat_template,
            tokenize=False,
            return_dict=False,
            **processed_kwargs["template_kwargs"],
        )

        if not is_batched:
            prompt = prompt[0]

        if tokenize:
            # Tokenizer's `apply_chat_template` never adds special tokens when tokenizing.
            # But processor's `apply_chat_template` didn't have an option to tokenize, so users had to format the prompt
            # and pass it to the processor, relying on the processor to handle special tokens internally.
            # The lines below keep backward compatibility for that while still supporting models that put special
            # tokens in the template (consistent with tokenizers). We don't raise a warning because it would flood the
            # command line without an actionable solution for users.
            single_prompt = prompt[0] if is_batched else prompt
            if self.tokenizer.bos_token is not None and single_prompt.startswith(self.tokenizer.bos_token):
                kwargs["add_special_tokens"] = False

            out = self(
                text=prompt,
                images=batch_images if batch_images else None,
                videos=batch_videos if batch_videos else None,
                audio=batch_audios if batch_audios else None,
                **kwargs,
            )
            if return_dict:
                return out
            else:
                return out["input_ids"]
        return prompt


__all__ = ["Phi4MultimodalProcessor"]


Phi4MultimodalProcessor.register_for_auto_class()