|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Processor class for Phi4Multimodal |
|
""" |
|
|
|
import re |
|
import os |
|
import requests |
|
import base64 |
|
from io import BytesIO |
|
from typing import List, Optional, Union, TypedDict |
|
|
|
import librosa |
|
import numpy as np |
|
import PIL.Image
import PIL.ImageOps
|
|
|
from transformers.image_processing_utils import BatchFeature |
|
from transformers.image_utils import ImageInput |
|
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, TextKwargs, ImagesKwargs, VideosKwargs, AudioKwargs, CommonKwargs, ProcessorChatTemplateKwargs |
|
from transformers.tokenization_utils_base import TextInput |
|
from transformers.utils import logging |
|
|
|
|
|
from .feature_extraction_phi4_multimodal import AudioInput |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
class ChatTemplateLoadKwargs(TypedDict, total=False): |
|
""" |
|
Keyword arguments used to load multimodal data in processor chat templates. |
|
|
|
num_frames (`int`, *optional*): |
|
Number of frames to sample uniformly. If not passed, the whole video is loaded. |
|
video_load_backend (`str`, *optional*, defaults to `"pyav"`): |
|
The backend to use when loading the video which will be used only when there are videos in the conversation. |
|
Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "pyav" because it is the only backend |
|
that supports all types of sources to load from. |
|
video_fps (`int`, *optional*): |
|
Number of frames to sample per second. Should be passed only when `num_frames=None`. |
|
If not specified and `num_frames==None`, all frames are sampled. |
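    sampling_rate (`int`, *optional*, defaults to 16000):
        The sampling rate to use when loading audio files referenced in the conversation.
    load_audio_from_video (`bool`, *optional*, defaults to `False`):
        Whether the audio should be loaded from the video files in the conversation instead of from
        separate audio files.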
|
sample_indices_fn (`Callable`, *optional*): |
|
        A callable that returns the indices at which the video should be sampled. If the video has to be
        loaded with a different sampling technique than the one provided by the `num_frames` or `fps`
        arguments, you should provide your own `sample_indices_fn`. If not provided, simple uniform
        sampling with `fps` is performed; otherwise `sample_indices_fn` has priority over the other args.
        The function expects as input all the args and kwargs passed to `load_video` and should output
        valid indices at which the video should be sampled. For example:
|
|
|
def sample_indices_fn(num_frames, fps, metadata, **kwargs): |
|
            # add your sampling logic here ...
|
return np.linspace(start_idx, end_idx, num_frames, dtype=int) |
|
""" |
|
|
|
num_frames: Optional[int] = None |
|
video_load_backend: Optional[str] = "pyav" |
|
video_fps: Optional[int] = None |
|
sampling_rate: Optional[int] = 16_000 |
|
load_audio_from_video: Optional[bool] = False |
|
|
|
|
|
class AllKwargsForChatTemplate( |
|
TextKwargs, ImagesKwargs, VideosKwargs, AudioKwargs, CommonKwargs, ProcessorChatTemplateKwargs |
|
): |
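    """
    Convenience container grouping all keyword arguments accepted by `apply_chat_template`: the per-modality
    processor kwargs, the multimodal loading kwargs, and the chat-template rendering kwargs.
    """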
|
processor_kwargs: ProcessingKwargs = { |
|
**ProcessingKwargs.__annotations__, |
|
} |
|
mm_load_kwargs: ChatTemplateLoadKwargs = { |
|
        **ChatTemplateLoadKwargs.__annotations__,
|
} |
|
template_kwargs: ProcessorChatTemplateKwargs = { |
|
**ProcessorChatTemplateKwargs.__annotations__, |
|
} |
|
|
|
|
|
class Phi4MultimodalProcessorKwargs(ProcessingKwargs, total=False): |
|
_defaults = { |
|
"audio_kwargs": { |
|
"device": "cpu", |
|
}, |
|
} |
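

# The `_defaults` above are merged with the tokenizer's `init_kwargs` and any kwargs passed at call time
# inside `Phi4MultimodalProcessor.__call__` via `ProcessorMixin._merge_kwargs`, so callers can override
# the audio feature-extraction device on a per-call basis.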
|
|
|
|
|
def load_audio(audio: Union[str, np.ndarray], sampling_rate=16000, timeout=None) -> np.ndarray: |
|
""" |
|
Loads `audio` to an np.ndarray object. |
|
|
|
Args: |
|
audio (`str` or `np.ndarray`): |
|
            The audio to be loaded into numpy array format.
|
sampling_rate (`int`, *optional*, defaults to 16000): |
|
            The sampling rate to be used when loading the audio. It should be the same as the
            sampling rate the model you will be using was trained with.
|
timeout (`float`, *optional*): |
|
The timeout value in seconds for the URL request. |
|
|
|
Returns: |
|
        `np.ndarray`: A numpy array representing the audio.
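
    Example (a minimal sketch; the file path is a placeholder):

    ```python
    >>> waveform = load_audio("speech.wav", sampling_rate=16000)
    >>> waveform.ndim
    1
    ```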
|
""" |
|
|
|
if isinstance(audio, str): |
|
|
|
if audio.startswith("http://") or audio.startswith("https://"): |
|
audio = librosa.load(BytesIO(requests.get(audio, timeout=timeout).content), sr=sampling_rate)[0] |
|
elif os.path.isfile(audio): |
|
audio = librosa.load(audio, sr=sampling_rate)[0] |
|
elif isinstance(audio, np.ndarray): |
|
audio = audio |
|
else: |
|
raise TypeError( |
|
"Incorrect format used for `audio`. Should be an url linking to an audio, a local path, or numpy array." |
|
) |
|
return audio |
|
|
|
|
|
def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] = None) -> "PIL.Image.Image": |
|
""" |
|
Loads `image` to a PIL Image. |
|
|
|
Args: |
|
image (`str` or `PIL.Image.Image`): |
|
The image to convert to the PIL Image format. |
|
timeout (`float`, *optional*): |
|
The timeout value in seconds for the URL request. |
|
|
|
Returns: |
|
`PIL.Image.Image`: A PIL Image. |
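
    Example (a minimal sketch; requires network access and the URL is illustrative):

    ```python
    >>> image = load_image("https://www.ilankelman.org/stopsigns/australia.jpg")
    >>> image.mode
    'RGB'
    ```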
|
""" |
|
if isinstance(image, str): |
|
if image.startswith("http://") or image.startswith("https://"): |
|
|
|
|
|
image = PIL.Image.open(BytesIO(requests.get(image, timeout=timeout).content)) |
|
elif os.path.isfile(image): |
|
image = PIL.Image.open(image) |
|
else: |
|
if image.startswith("data:image/"): |
|
image = image.split(",")[1] |
|
|
|
|
|
try: |
|
b64 = base64.decodebytes(image.encode()) |
|
image = PIL.Image.open(BytesIO(b64)) |
|
except Exception as e: |
|
raise ValueError( |
|
f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}" |
|
) |
|
elif isinstance(image, PIL.Image.Image): |
|
image = image |
|
else: |
|
raise TypeError( |
|
"Incorrect format used for image. Should be an url linking to an image, a base64 string, a local path, or a PIL image." |
|
) |
|
image = PIL.ImageOps.exif_transpose(image) |
|
image = image.convert("RGB") |
|
return image |
|
|
|
|
|
class Phi4MultimodalProcessor(ProcessorMixin): |
|
r""" |
|
    Constructs a Phi4Multimodal processor which wraps an image processor, an audio processor, and a GPT tokenizer into a single processor.
|
|
|
    [`Phi4MultimodalProcessor`] offers all the functionalities of [`Phi4MultimodalImageProcessorFast`], [`Phi4MultimodalFeatureExtractor`] and [`GPT2TokenizerFast`]. See the
|
[`~Phi4MultimodalProcessor.__call__`] and [`~Phi4MultimodalProcessor.decode`] for more information. |
|
|
|
Args: |
|
image_processor (`Phi4MultimodalImageProcessorFast`): |
|
The image processor to use for images. |
|
audio_processor (`Phi4MultimodalFeatureExtractor`): |
|
The audio processor to use for audio inputs. |
|
tokenizer (`GPT2TokenizerFast`): |
|
The tokenizer to use for text. |
|
fake_image_token_pattern (`str`, *optional*, defaults to `r"<\|image_\d+\|>"`): |
|
The fake image token pattern. |
|
fake_audio_token_pattern (`str`, *optional*, defaults to `r"<\|audio_\d+\|>"`): |
|
The fake audio token pattern. |
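
    Example (a minimal usage sketch; the checkpoint name is an assumption, substitute the one you actually use):

    ```python
    >>> import requests
    >>> from PIL import Image
    >>> from transformers import AutoProcessor

    >>> processor = AutoProcessor.from_pretrained("microsoft/Phi-4-multimodal-instruct")
    >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)
    >>> inputs = processor(text="<|image|>Describe this image.", images=image, return_tensors="pt")
    ```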
|
""" |
|
|
|
attributes = ["image_processor", "audio_processor", "tokenizer"] |
|
tokenizer_class = "GPT2TokenizerFast" |
|
image_processor_class = "AutoImageProcessor" |
|
audio_processor_class = "AutoFeatureExtractor" |
|
valid_kwargs = ["chat_template"] |
|
|
|
def __init__( |
|
self, |
|
image_processor, |
|
audio_processor, |
|
tokenizer, |
|
**kwargs, |
|
): |
|
self.image_token = tokenizer.image_token |
|
self.image_token_id = tokenizer.image_token_id |
|
self.audio_token = tokenizer.audio_token |
|
self.audio_token_id = tokenizer.audio_token_id |
|
super().__init__(image_processor, audio_processor, tokenizer, **kwargs) |
|
|
|
def __call__( |
|
self, |
|
text: Union[TextInput, List[TextInput]], |
|
images: Optional[ImageInput] = None, |
|
audio: Optional[AudioInput] = None, |
|
**kwargs: Unpack[ProcessingKwargs], |
|
) -> BatchFeature: |
|
""" |
|
        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
        and `kwargs` arguments to GPT2Tokenizer's [`~GPT2Tokenizer.__call__`] if `text` is not `None` to encode
        the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
        Phi4MultimodalImageProcessorFast's [`~Phi4MultimodalImageProcessorFast.__call__`] if `images` is not `None`. Please refer to the docstring
        of the above two methods for more information.
|
|
|
Args: |
|
text (`str`, `List[str]`, `List[List[str]]`): |
|
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings |
|
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set |
|
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences). |
|
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): |
|
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch |
|
tensor. Both channels-first and channels-last formats are supported. |
|
audio (`List[Union[np.ndarray, torch.Tensor]]`): |
|
List of the audios to be prepared. |
|
|
|
Returns: |
|
[`BatchFeature`]: A [`BatchFeature`] with the following fields: |
|
|
|
- **input_ids** -- List of token ids to be fed to a model. |
|
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model. |
|
- **input_image_embeds** -- Pixel values to be fed to a model. |
|
- **image_sizes** -- List of tuples specifying the size of each image in `input_image_embeds`. |
|
- **image_attention_mask** -- List of attention masks for each image in `input_image_embeds`. |
|
- **input_audio_embeds** -- Audio embeddings to be fed to a model. |
|
- **audio_embed_sizes** -- List of integers specifying the size of each audio in `input_audio_embeds`. |
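
        Example (a minimal sketch; the checkpoint name is an assumption and the audio is synthetic silence):

        ```python
        >>> import numpy as np
        >>> from transformers import AutoProcessor

        >>> processor = AutoProcessor.from_pretrained("microsoft/Phi-4-multimodal-instruct")
        >>> audio = np.zeros(16000, dtype=np.float32)  # 1 second of silence at 16 kHz, purely illustrative
        >>> inputs = processor(text="<|audio|>Transcribe this audio clip.", audio=[audio], return_tensors="pt")
        ```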
|
""" |
|
|
|
output_kwargs = self._merge_kwargs(Phi4MultimodalProcessorKwargs, self.tokenizer.init_kwargs, **kwargs) |
|
image_kwargs = output_kwargs["images_kwargs"] |
|
audio_kwargs = output_kwargs["audio_kwargs"] |
|
|
|
image_inputs = self.image_processor(images, **image_kwargs) if images is not None else {} |
|
audio_inputs = self.audio_processor(audio, **audio_kwargs) if audio is not None else {} |
|
|
|
|
|
num_img_tokens = image_inputs.pop("num_img_tokens", []) |
|
audio_embed_sizes = audio_inputs.get("audio_embed_sizes", []) |
|
|
|
|
|
if isinstance(text, str): |
|
text = [text] |
|
        elif not isinstance(text, list) or not isinstance(text[0], str):
|
raise ValueError("Invalid input text. Please provide a string, or a list of strings") |
|
|
|
image_token = self.tokenizer.image_token |
|
audio_token = self.tokenizer.audio_token |
|
|
|
|
|
concatenated_prompt = "".join(text) |
|
        if concatenated_prompt.count(image_token) != len(num_img_tokens):
            raise ValueError(
                "You should add as many image tokens `<|image|>` in your prompt as the number of `images` passed to the processor. "
                f"Input contains {concatenated_prompt.count(image_token)} tokens != {len(num_img_tokens)} images"
            )
        if concatenated_prompt.count(audio_token) != len(audio_embed_sizes):
            raise ValueError(
                "You should add as many audio tokens `<|audio|>` in your prompt as the number of `audios` passed to the processor. "
                f"Input contains {concatenated_prompt.count(audio_token)} tokens != {len(audio_embed_sizes)} audios"
            )
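
        # Expand each image/audio placeholder token so that it is repeated once per image token / audio
        # embedding slot reported by the image processor / feature extractor.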
|
|
|
|
|
image_count_iter = iter(num_img_tokens) |
|
audio_count_iter = iter(audio_embed_sizes) |
|
processed_text = [ |
|
re.sub(re.escape(image_token), lambda _: image_token * next(image_count_iter), t) for t in text |
|
] |
|
processed_text = [ |
|
re.sub(re.escape(audio_token), lambda _: audio_token * next(audio_count_iter), t) for t in processed_text |
|
] |
|
|
|
return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) |
|
text_inputs = self.tokenizer(processed_text, **output_kwargs["text_kwargs"]) |
|
self._check_special_mm_tokens(processed_text, text_inputs, modalities=["image"]) |
|
|
|
|
|
data = { |
|
**text_inputs, |
|
**image_inputs, |
|
**audio_inputs, |
|
} |
|
|
|
return BatchFeature(data=data, tensor_type=return_tensors) |
|
|
|
def batch_decode(self, *args, **kwargs): |
|
""" |
|
This method forwards all its arguments to GPT2Tokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please |
|
refer to the docstring of this method for more information. |
|
""" |
|
return self.tokenizer.batch_decode(*args, **kwargs) |
|
|
|
def decode(self, *args, **kwargs): |
|
""" |
|
This method forwards all its arguments to GPT2Tokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to |
|
the docstring of this method for more information. |
|
""" |
|
return self.tokenizer.decode(*args, **kwargs) |
|
|
|
@property |
|
def model_input_names(self): |
|
tokenizer_input_names = self.tokenizer.model_input_names |
|
image_processor_input_names = self.image_processor.model_input_names |
|
audio_processor_input_names = self.audio_processor.model_input_names |
|
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names + audio_processor_input_names)) |
|
|
|
def _check_special_mm_tokens(self, text: list[str], text_inputs: "BatchFeature", modalities: list[str]): |
|
""" |
|
        Checks that the number of special tokens in the text and in the tokenized `input_ids` is the same. The counts
        can differ if the tokenized text was truncated, leading to issues in the model code.
|
""" |
|
for modality in modalities: |
|
token_str = getattr(self, f"{modality}_token") |
|
token_id = getattr(self, f"{modality}_token_id") |
|
ids_count = [list(ids).count(token_id) for ids in text_inputs["input_ids"]] |
|
text_count = [sample.count(token_str) for sample in text] |
|
|
|
if ids_count != text_count: |
|
raise ValueError( |
|
f"Mismatch in `{modality}` token count between text and `input_ids`. Got ids={ids_count} and text={text_count}. " |
|
"Likely due to `truncation='max_length'`. Please disable truncation or increase `max_length`." |
|
) |
|
|
|
def apply_chat_template( |
|
self, |
|
conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]], |
|
chat_template: Optional[str] = None, |
|
**kwargs: Unpack[AllKwargsForChatTemplate], |
|
) -> str: |
|
""" |
|
Similar to the `apply_chat_template` method on tokenizers, this method applies a Jinja template to input |
|
conversations to turn them into a single tokenizable string. |
|
|
|
The input is expected to be in the following format, where each message content is a list consisting of text and |
|
optionally image or video inputs. One can also provide an image, video, URL or local path which will be used to form |
|
        `pixel_values` when `return_dict=True`. If no media are provided, you will get only the formatted text, optionally tokenized.
|
|
|
conversation = [ |
|
{ |
|
"role": "user", |
|
"content": [ |
|
{"type": "image", "image": "https://www.ilankelman.org/stopsigns/australia.jpg"}, |
|
{"type": "text", "text": "Please describe this image in detail."}, |
|
], |
|
}, |
|
] |
|
|
|
Args: |
|
            conversation (`Union[List[Dict[str, str]], List[List[Dict[str, str]]]]`):
|
The conversation to format. |
|
chat_template (`Optional[str]`, *optional*): |
|
The Jinja template to use for formatting the conversation. If not provided, the tokenizer's |
|
chat template is used. |
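
        Example (a minimal sketch; the checkpoint name is an assumption and `conversation` is the example above):

        ```python
        >>> from transformers import AutoProcessor

        >>> processor = AutoProcessor.from_pretrained("microsoft/Phi-4-multimodal-instruct")
        >>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        >>> inputs = processor.apply_chat_template(
        ...     conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
        ... )
        ```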
|
""" |
|
|
|
if chat_template is None: |
|
if isinstance(self.chat_template, dict) and "default" in self.chat_template: |
|
chat_template = self.chat_template["default"] |
|
elif isinstance(self.chat_template, dict): |
|
raise ValueError( |
|
'The processor has multiple chat templates but none of them are named "default". You need to specify' |
|
" which one to use by passing the `chat_template` argument. Available templates are: " |
|
f"{', '.join(self.chat_template.keys())}" |
|
) |
|
elif self.chat_template is not None: |
|
chat_template = self.chat_template |
|
else: |
|
raise ValueError( |
|
"Cannot use apply_chat_template because this processor does not have a chat template." |
|
) |
|
else: |
|
if isinstance(self.chat_template, dict) and chat_template in self.chat_template: |
|
|
|
chat_template = self.chat_template[chat_template] |
|
else: |
|
|
|
chat_template = chat_template |
|
|
|
|
|
processed_kwargs = { |
|
"mm_load_kwargs": {}, |
|
"template_kwargs": {}, |
|
} |
|
|
|
for kwarg_type in processed_kwargs: |
|
for key in AllKwargsForChatTemplate.__annotations__[kwarg_type].__annotations__.keys(): |
|
kwarg_type_defaults = AllKwargsForChatTemplate.__annotations__[kwarg_type] |
|
default_value = getattr(kwarg_type_defaults, key, None) |
|
value = kwargs.pop(key, default_value) |
|
if value is not None and not isinstance(value, dict): |
|
processed_kwargs[kwarg_type][key] = value |
|
|
|
if isinstance(conversation, (list, tuple)) and ( |
|
isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "content") |
|
): |
|
is_batched = True |
|
conversations = conversation |
|
else: |
|
is_batched = False |
|
conversations = [conversation] |
|
|
|
tokenize = processed_kwargs["template_kwargs"].pop("tokenize", False) |
|
return_dict = processed_kwargs["template_kwargs"].pop("return_dict", False) |
|
mm_load_kwargs = processed_kwargs["mm_load_kwargs"] |
|
|
|
if tokenize: |
|
batch_images, batch_videos = [], [] |
|
batch_audios = [] |
|
batch_video_metadata = [] |
|
for conversation in conversations: |
|
images, videos = [], [] |
|
video_metadata = [] |
|
for message in conversation: |
|
visuals = [content for content in message["content"] if content["type"] in ["image", "video"]] |
|
audio_fnames = [ |
|
content[key] |
|
for content in message["content"] |
|
for key in ["audio", "url", "path"] |
|
if key in content and content["type"] == "audio" |
|
] |
|
image_fnames = [ |
|
vision_info[key] |
|
for vision_info in visuals |
|
for key in ["image", "url", "path", "base64"] |
|
if key in vision_info and vision_info["type"] == "image" |
|
] |
|
video_fnames = [ |
|
vision_info[key] |
|
for vision_info in visuals |
|
for key in ["video", "url", "path"] |
|
if key in vision_info and vision_info["type"] == "video" |
|
] |
|
|
|
for fname in image_fnames: |
|
images.append(load_image(fname)) |
|
|
|
|
|
if not mm_load_kwargs["load_audio_from_video"]: |
|
for fname in audio_fnames: |
|
batch_audios.append(load_audio(fname, sampling_rate=mm_load_kwargs["sampling_rate"])) |
|
else: |
|
for fname in video_fnames: |
|
batch_audios.append(load_audio(fname, sampling_rate=mm_load_kwargs["sampling_rate"])) |
|
|
|
for fname in video_fnames: |
|
if isinstance(fname, (list, tuple)) and isinstance(fname[0], str): |
|
video = [np.array(load_image(image_fname)) for image_fname in fname] |
|
|
|
video = np.stack(video) |
|
metadata = None |
|
logger.warning( |
|
"When loading the video from list of images, we cannot infer metadata such as `fps` or `duration`. " |
|
"If your model uses this metadata during processing, please load the whole video and let the model sample frames instead." |
|
) |
|
else: |
|
|
|
video, metadata = self._load_video_for_model( |
|
fname, |
|
num_frames=mm_load_kwargs.get("num_frames", None), |
|
fps=mm_load_kwargs.get("video_fps", None), |
|
backend=mm_load_kwargs["video_load_backend"], |
|
**kwargs, |
|
) |
|
videos.append(video) |
|
video_metadata.append(metadata) |
|
|
|
|
|
|
|
if images: |
|
batch_images.append(images) |
|
if videos: |
|
batch_videos.append(videos) |
|
batch_video_metadata.append(video_metadata) |
|
|
|
|
|
conversations = self._process_messages_for_chat_template( |
|
conversations, |
|
batch_images=batch_images, |
|
batch_videos=batch_videos, |
|
batch_video_metadata=batch_video_metadata, |
|
**processed_kwargs["mm_load_kwargs"], |
|
) |
|
|
|
prompt = self.tokenizer.apply_chat_template( |
|
conversations, |
|
chat_template=chat_template, |
|
tokenize=False, |
|
return_dict=False, |
|
**processed_kwargs["template_kwargs"], |
|
) |
|
|
|
if not is_batched: |
|
prompt = prompt[0] |
|
|
|
if tokenize: |
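
            # If the chat template already rendered a BOS token at the start of the prompt, avoid adding
            # special tokens a second time when tokenizing below.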
|
|
|
|
|
|
|
|
|
|
|
|
|
single_prompt = prompt[0] if is_batched else prompt |
|
if self.tokenizer.bos_token is not None and single_prompt.startswith(self.tokenizer.bos_token): |
|
kwargs["add_special_tokens"] = False |
|
|
|
out = self( |
|
text=prompt, |
|
images=batch_images if batch_images else None, |
|
videos=batch_videos if batch_videos else None, |
|
audio=batch_audios if batch_audios else None, |
|
**kwargs, |
|
) |
|
if return_dict: |
|
return out |
|
else: |
|
return out["input_ids"] |
|
return prompt |
|
|
|
|
|
__all__ = ["Phi4MultimodalProcessor"] |
|
|
|
|
|
Phi4MultimodalProcessor.register_for_auto_class() |