import base64 import copy import io import math import os import uuid from typing import Dict, List, Optional, Union from urllib.parse import urlparse import av import cv2 import numpy as np import requests import torch from decord import VideoReader, cpu from PIL import Image, UnidentifiedImageError from transformers.image_processing_utils import ( BaseImageProcessor, BatchFeature, get_size_dict, ) from transformers.image_transforms import ( convert_to_rgb, get_resize_output_image_size, resize, to_channel_dimension_format, ) from transformers.image_utils import ( OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, ChannelDimension, ImageInput, PILImageResampling, get_image_size, infer_channel_dimension_format, is_scaled_image, make_list_of_images, to_numpy_array, valid_images, ) from transformers.utils import TensorType, logging logger = logging.get_logger(__name__) def determine_possible_resolutions(anyres: bool, max_num_grids: int, grid_size: int, use_1x1_grid: bool = False): """ Finds and returns possible resolution combinations with a total number of grids less than or equal to max_num_grids. For example, if max_num_grids is 4, the possible grid combinations are: [1x1, 1x2, 1x3, 1x4, 2x1, 2x2, 3x1, 4x1], and the resolutions are calculated accordingly. Example: >>> possible_resolutions = determine_possible_resolutions(anyres=True, max_num_grids=4, grid_size=336) >>> print(possible_resolutions) [[336, 336], [336, 672], [336, 1008], [336, 1344], [672, 336], [672, 672], [1008, 336], [1344, 336]] Args: anyres (bool): Whether to allow any resolution combinations up to the maximum grid count. max_num_grids (int): The maximum number of grids allowed (height x width must be ≤ this value). grid_size (int): The size of each grid in pixels (e.g., 336). use_1x1_grid (bool, optional): Whether to include the 1x1 grid as a valid resolution. Defaults to False. Returns: List[List[int]]: A list of possible [height, width] resolution pairs. """ possible_resolutions = [] if anyres: assert max_num_grids > 0 for i in range(1, max_num_grids + 1): for j in range(1, max_num_grids + 1): if i == 1 and j == 1 and not use_1x1_grid: continue if i * j <= max_num_grids: possible_resolutions.append([i, j]) possible_resolutions = [[ys * grid_size, xs * grid_size] for ys, xs in possible_resolutions] return possible_resolutions def divide_to_grids(image: np.array, grid_size: int, input_data_format=None) -> List[np.array]: """ Divides a local image into grids of size (grid_size x grid_size). Args: image (np.array): Input image as a NumPy array. grid_size (int): The size (in pixels) of each square grid. input_data_format (optional): Optional format specifier (e.g., "channels_first" or "channels_last"). Returns: List[np.array]: A list of image patches, each of size (grid_size x grid_size). """ grids = [] height, width = get_image_size(image, channel_dim=input_data_format) for i in range(0, height, grid_size): for j in range(0, width, grid_size): if input_data_format == ChannelDimension.LAST: grid = image[i : i + grid_size, j : j + grid_size] else: grid = image[:, i : i + grid_size, j : j + grid_size] grids.append(grid) return grids def pad( image: np.array, target_size: tuple, background_color=(127, 127, 127), input_data_format=None, ) -> np.array: """ Pads the input image on the sides (top/bottom and left/right) to match the target height and width. Args: image (np.array): Input image as a NumPy array. target_size (tuple): Target size as (target_height, target_width). background_color (tuple, optional): RGB color value used for padding. 
Defaults to (127, 127, 127). input_data_format (optional): Optional format specifier (e.g., "channels_first" or "channels_last"). Returns: np.array: The padded image with the specified target size. """ target_height, target_width = target_size height, width = get_image_size(image, channel_dim=input_data_format) # result = np.ones((target_height, target_width, image.shape[2]), dtype=image.dtype) * background_color result = np.empty((target_height, target_width, image.shape[2]), dtype=image.dtype) for i in range(image.shape[2]): result[..., i].fill(background_color[i]) paste_x = (target_width - width) // 2 paste_y = (target_height - height) // 2 result[paste_y : paste_y + height, paste_x : paste_x + width, :] = image return result def expand2square( image: np.array, bboxes_dict=None, background_color=(127, 127, 127), input_data_format=None, ) -> np.array: """ Expands the input image to a square shape by placing it at the center of a new square canvas, with padding added to the shorter side (either top/bottom or left/right). The image is always centered on the new canvas, and padding is applied symmetrically. Args: image (np.array): Input image as a NumPy array. bboxes_dict (dict, optional): A dictionary of bounding boxes, where each value is an NDArray of shape (N, 4, 2) with box coordinates in the format [[xtl, ytl], [xtr, ytr], [xbr, ybr], [xbl, ybl]]. Supports multiple categories (e.g., "ocr", "html") simultaneously. background_color (tuple, optional): RGB color to fill the padding area. Defaults to (127, 127, 127). input_data_format (optional): Optional format specifier for image data (e.g., "channels_first" or "channels_last"). Returns: np.array: A square-shaped image with the original image centered and padded as needed. Example: >>> _img = np.ones((80, 100), dtype=np.uint8) * 100 >>> _bboxes_dict = {"words": np.array([[[10, 10], [20, 10], [20, 20], [10, 20]], ... [[30, 30], [40, 30], [40, 40], [30, 40]]])} >>> _img, _bboxes_dict = expand2square(_img, _bboxes_dict, (255, 255, 255)) >>> _img.shape (100, 100) >>> guessed_ocr_bboxes = np.array([[[20, 10], [30, 10], [30, 20], [20, 20]], ... 
[[40, 30], [50, 30], [50, 40], [40, 40]]]) >>> np.testing.assert_array_almost_equal(_bboxes_dict["words"], guessed_ocr_bboxes) is None True """ height, width = get_image_size(image, channel_dim=input_data_format) if width == height: return image, bboxes_dict elif width > height: # result = np.ones((width, width, image.shape[2]), dtype=image.dtype) * background_color result = np.empty((width, width, image.shape[2]), dtype=image.dtype) for i in range(image.shape[2]): result[..., i].fill(background_color[i]) result[(width - height) // 2 : (width - height) // 2 + height, :] = image if bboxes_dict is not None: for key in bboxes_dict: bboxes_dict[key][:, :, 1] += (width - height) // 2 return result, bboxes_dict else: # result = np.ones((height, height, image.shape[2]), dtype=image.dtype) * background_color result = np.empty((height, height, image.shape[2]), dtype=image.dtype) for i in range(image.shape[2]): result[..., i].fill(background_color[i]) result[:, (height - width) // 2 : (height - width) // 2 + width] = image if bboxes_dict is not None: for key in bboxes_dict: bboxes_dict[key][:, :, 0] += (height - width) // 2 return result, bboxes_dict def resize_longside( image: np.array, size: int, resample: PILImageResampling = PILImageResampling.BICUBIC, data_format: Optional[Union[str, ChannelDimension]] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None, ): """ Resizes the image so that its longer side matches the specified size, maintaining the original aspect ratio. Args: image (np.array): Input image as a NumPy array. size (int): Target size for the longer side of the image. resample (PILImageResampling, optional): Resampling method to use during resizing. Defaults to BICUBIC. data_format (str or ChannelDimension, optional): Output data format (e.g., "channels_first" or "channels_last"). input_data_format (str or ChannelDimension, optional): Input data format of the image. Returns: np.array: The resized image with its aspect ratio preserved. """ height, width = get_image_size(image, channel_dim=input_data_format) if width == height: target_height, target_width = size, size elif width > height: target_width = size target_height = math.ceil(height / width * size) else: target_width = math.ceil(width / height * size) target_height = size return resize( image, size=(target_height, target_width), resample=resample, data_format=data_format, input_data_format=input_data_format, ) def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple: """ Selects the best-fit resolution from a list of possible resolutions based on the original image size. This function evaluates each resolution by computing its effective and wasted area compared to the original size. The optimal resolution is the one that maximizes the effective area while minimizing unused (wasted) space. Args: original_size (tuple): The original image size in the format (height, width). possible_resolutions (list): A list of candidate resolutions in the format [(height1, width1), (height2, width2), ...]. Returns: tuple: The best-fit resolution in the format (height, width). This function includes code adapted from the file image_processing_llava_next.py in the LLaVA-Next project(https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/llava_next/image_processing_llava_next.py), which is licensed under apache-2.0. 
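
    Example (illustrative): a wide 336x1000 image (height, width) maps onto the 1x3 layout:

    >>> select_best_resolution((336, 1000), [[336, 336], [336, 672], [336, 1008], [672, 336]])
    (336, 1008)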
""" original_height, original_width = original_size best_fit = None max_effective_resolution = 0 min_wasted_resolution = float("inf") for height, width in possible_resolutions: scale = min(width / original_width, height / original_height) downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale) effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height) wasted_resolution = (width * height) - effective_resolution if effective_resolution > max_effective_resolution or ( effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution ): max_effective_resolution = effective_resolution min_wasted_resolution = wasted_resolution best_fit = (height, width) return best_fit def _get_local_grids_output_size(image: np.array, target_resolution: tuple, input_data_format=None): """ Computes the number of local grids (patches) along the height and width when resizing an image to the target resolution. Args: image (np.array): Input image as a NumPy array. target_resolution (tuple): Target resolution in the format (target_height, target_width). input_data_format (optional): Optional format specifier (e.g., "channels_first" or "channels_last"). Returns: tuple: A tuple (grid_h, grid_w) representing the number of grids along the height and width. """ original_height, original_width = get_image_size(image, channel_dim=input_data_format) target_height, target_width = target_resolution scale_w = target_width / original_width scale_h = target_height / original_height if scale_w < scale_h: new_width = target_width new_height = min(math.ceil(original_height * scale_w), target_height) else: new_height = target_height new_width = min(math.ceil(original_width * scale_h), target_width) return new_height, new_width def determine_anyres_num_vision_patches( num_grids, image_size, grid_size, patch_size, possible_resolutions, anyres=False, unpad=True, num_queries_vis_abstractor=0, num_queries_vis_abstractor_slow=0, is_video=False, first_last_frames_slow=False, # sample-wise option is_first_or_last_frames=False, # grid-wise option ): """ Computes the number of visual tokens (patches) based on image resolution, grid configuration, and patch size. This function supports both fixed-size and any-resolution settings, as well as video-specific configurations such as handling slow frames and frame position flags. Args: num_grids (int): Number of grids per image (e.g., 1 for 1x1, 4 for 2x2, etc.). image_size (tuple): The original image size as (height, width). grid_size (int): Size of each grid in pixels (e.g., 336). patch_size (int): Size of each vision patch (e.g., 14 for ViT models). possible_resolutions (list): List of possible resolution tuples [(h1, w1), (h2, w2), ...]. anyres (bool, optional): Whether to use any-resolution mode. Defaults to False. unpad (bool, optional): Whether to unpad the image before computing patches. Defaults to True. num_queries_vis_abstractor (int, optional): Number of query tokens for vision abstractor (fast path). num_queries_vis_abstractor_slow (int, optional): Number of query tokens for vision abstractor (slow path). is_video (bool, optional): Whether the input is a video. Defaults to False. first_last_frames_slow (bool, optional): Whether to treat first/last video frames as "slow". Defaults to False. is_first_or_last_frames (bool, optional): Whether current grid corresponds to first/last frame. Defaults to False. Returns: int: Total number of visual tokens (patches) after processing. 
""" if not anyres: return num_queries_vis_abstractor if num_queries_vis_abstractor > 0 else (grid_size // patch_size) ** 2 if num_queries_vis_abstractor > 0: num_patch_per_grid = int(num_queries_vis_abstractor**0.5) else: num_patch_per_grid = grid_size // patch_size num_global_per_grid = num_patch_per_grid # In anyres mode, a global image is included, so there are always at least 2 grids. # However, for video inputs, there is no global image, so it's possible to have only 1 grid. # Therefore, the assertion below is commented out: # assert num_grids > 1 # Compute the number of vision patches. height, width = select_best_resolution(image_size, possible_resolutions) num_patch_height = (height // grid_size) * num_patch_per_grid num_patch_width = (width // grid_size) * num_patch_per_grid # local images if unpad: original_height, original_width = image_size original_aspect_ratio = original_width / original_height current_aspect_ratio = num_patch_width / num_patch_height if original_aspect_ratio > current_aspect_ratio: scale_factor = num_patch_width / original_width new_height = int(original_height * scale_factor) padding = (num_patch_height - new_height) // 2 num_patch_height = num_patch_height - padding * 2 else: scale_factor = num_patch_height / original_height new_width = int(original_width * scale_factor) padding = (num_patch_width - new_width) // 2 num_patch_width = num_patch_width - padding * 2 num_patches = num_patch_width * num_patch_height + num_patch_height else: num_patches = num_patch_width * num_patch_height # In the "slow" strategy, when applying to first and last frames only, it is applied exclusively to those two frames. if num_queries_vis_abstractor_slow > 0: if first_last_frames_slow: if is_first_or_last_frames: num_patches += num_queries_vis_abstractor_slow - num_queries_vis_abstractor else: num_patches += num_queries_vis_abstractor_slow - num_queries_vis_abstractor # The slowfast feature is only applicable when unpad is set to False. assert unpad is False # Global image is not included for video inputs. if not is_video: num_patches += num_global_per_grid**2 return num_patches class HCXVisionProcessor(BaseImageProcessor): r""" Constructs a VLM image processor. This processor is based on [`CLIPImageProcessor`] and incorporates additional techniques for handling high-resolution images, such as flexible resolution support (`anyres`), unpadding, square padding, and multi-grid patching strategies. Args: do_resize (bool): Whether to resize the image. size (Dict[str, int], optional): Target size for resizing, typically with keys `"height"` and `"width"`. anyres (bool): Whether to enable the any-resolution (`anyres`) feature, which allows flexible resolution handling via grid division. unpad (bool): When `anyres` is enabled, whether to remove visual tokens corresponding to pure padding regions. max_num_grids (int): Maximum number of grids allowed per image. max_image_cnt (int): Maximum number of images that can be processed at once (used for batching). num_queries_vis_abstractor (int): Number of visual query tokens per grid when using a visual resampler (e.g., Perceiver). num_queries_vis_abstractor_video_fast (int): Number of visual queries for fast-path video frames. num_queries_vis_abstractor_video_slow (int): Number of visual queries for slow-path video frames (e.g., first/last). possible_resolutions (List): List of allowed resolution pairs when `anyres` is enabled. Example: [[336, 336], [336, 672], [672, 336]]. patch_size (int): Patch size for the Vision Transformer (ViT). 
pad_to_square (bool): Whether to pad images to a square shape. If `False`, a center crop is applied to fit ViT input. resample (PILImageResampling): Resampling method to use for resizing. Default is `BICUBIC`. do_center_crop (bool): Whether to apply center cropping. crop_size (Dict[str, int], optional): Size for center cropping. do_rescale (bool): Whether to rescale pixel values. rescale_factor (float or int): Factor to use for rescaling pixel values (typically `1/255`). do_normalize (bool): Whether to normalize pixel values using `image_mean` and `image_std`. image_mean (float or List[float], optional): Mean values for normalization. Can be a single float or list of floats per channel. image_std (float or List[float], optional): Standard deviation values for normalization. Can be a single float or list of floats per channel. do_convert_rgb (bool): Whether to convert the input image to RGB. first_last_frames_slow (bool): Whether to treat the first and last frames of a video as “slow path” (processed differently). Attributes: model_input_names (List[str]): Names of the expected model inputs. Defaults to `["pixel_values"]`. """ model_input_names = ["pixel_values"] def __init__( self, do_resize: bool = True, size: Dict[str, int] = None, anyres: bool = False, unpad: bool = False, max_num_grids: int = 9, max_image_cnt: int = 12, num_queries_vis_abstractor: int = 0, num_queries_vis_abstractor_video_fast: int = 0, num_queries_vis_abstractor_video_slow: int = 0, possible_resolutions: List = [], patch_size: int = 14, pad_to_square: bool = True, resample: PILImageResampling = PILImageResampling.BICUBIC, do_center_crop: bool = True, crop_size: Dict[str, int] = None, do_rescale: bool = True, rescale_factor: Union[int, float] = 1 / 255, do_normalize: bool = True, image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None, do_convert_rgb: bool = True, first_last_frames_slow: bool = False, **kwargs, ) -> None: super().__init__(**kwargs) size = size if size is not None else {"shortest_edge": 512} size = get_size_dict(size, default_to_square=False) crop_size = crop_size if crop_size is not None else {"height": 512, "width": 512} crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size") self.do_resize = do_resize self.size = size self.anyres = anyres self.unpad = unpad self.max_num_grids = max_num_grids self.max_image_cnt = max_image_cnt self.num_queries_vis_abstractor = num_queries_vis_abstractor self.num_queries_vis_abstractor_video_fast = num_queries_vis_abstractor_video_fast self.num_queries_vis_abstractor_video_slow = num_queries_vis_abstractor_video_slow self.possible_resolutions = [_resolution for _resolution in possible_resolutions] self.patch_size = patch_size self.pad_to_square = pad_to_square self.resample = resample self.do_center_crop = do_center_crop self.crop_size = crop_size self.do_rescale = do_rescale self.rescale_factor = rescale_factor self.do_normalize = do_normalize self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD self.do_convert_rgb = do_convert_rgb self.first_last_frames_slow = first_last_frames_slow assert self.crop_size["height"] == self.crop_size["width"] def resize( self, image: np.ndarray, size: Dict[str, int], resample: PILImageResampling = PILImageResampling.BICUBIC, data_format: Optional[Union[str, ChannelDimension]] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None, **kwargs, 
) -> np.ndarray: """ Resizes the input image to the specified target size. Args: image (np.ndarray): The input image to resize. size (Dict[str, int]): A dictionary specifying the target size with keys `"height"` and `"width"`. resample (PILImageResampling, optional): The resampling filter to use. Defaults to `BICUBIC`. data_format (str or ChannelDimension, optional): The desired output data format (e.g., "channels_last"). input_data_format (str or ChannelDimension, optional): The input data format of the image. **kwargs: Additional keyword arguments, if any. Returns: np.ndarray: The resized image as a NumPy array. """ default_to_square = True if "shortest_edge" in size: size = size["shortest_edge"] default_to_square = False elif "height" in size and "width" in size: size = (size["height"], size["width"]) else: raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.") output_size = get_resize_output_image_size( image, size=size, default_to_square=default_to_square, input_data_format=input_data_format, ) return resize( image, size=output_size, resample=resample, data_format=data_format, input_data_format=input_data_format, **kwargs, ) def _preprocess( self, images: ImageInput, do_resize: bool = None, size: Dict[str, int] = None, resample: PILImageResampling = None, do_center_crop: bool = None, crop_size: int = None, do_rescale: bool = None, rescale_factor: float = None, do_normalize: bool = None, image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None, data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, ) -> Image.Image: """ Applies a sequence of preprocessing operations to the input image(s), including resizing, cropping, rescaling, normalization, and format conversion. This method is typically used internally to prepare images for model input. Args: images (ImageInput): A single image or a batch of images to preprocess. do_resize (bool, optional): Whether to resize the image(s). size (Dict[str, int], optional): Target size for resizing, with keys `"height"` and `"width"`. resample (PILImageResampling, optional): Resampling method to use for resizing. do_center_crop (bool, optional): Whether to apply center cropping. crop_size (int, optional): Size of the center crop (applied to both height and width). do_rescale (bool, optional): Whether to rescale the image pixel values. rescale_factor (float, optional): Factor to use when rescaling pixel values (e.g., 1/255). do_normalize (bool, optional): Whether to normalize the image using `image_mean` and `image_std`. image_mean (float or List[float], optional): Mean value(s) used for normalization. image_std (float or List[float], optional): Standard deviation value(s) used for normalization. data_format (ChannelDimension, optional): The desired output data format (e.g., `ChannelDimension.FIRST`). input_data_format (str or ChannelDimension, optional): The format of the input image(s). Returns: Image.Image: The preprocessed image or batch of images, ready for model input. 
""" images = make_list_of_images(images) if do_resize: images = [ self.resize( image=image, size=size, resample=resample, input_data_format=input_data_format, ) for image in images ] if do_center_crop: images = [ self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images ] if do_rescale: images = [ self.rescale( image=image, scale=rescale_factor, input_data_format=input_data_format, ) for image in images ] if do_normalize: images = [ self.normalize( image=image, mean=image_mean, std=image_std, input_data_format=input_data_format, ) for image in images ] images = [ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images ] return images def _resize_for_local_grids( self, image: np.array, target_resolution: tuple, resample, input_data_format: ChannelDimension, ) -> np.array: """ Resizes the image to the given target resolution for use in local grid processing. This function ensures that the image is properly resized to match the (height, width) specified in `target_resolution`, using the provided resampling method. It supports channel-first and channel-last formats based on `input_data_format`. Args: image (np.array): Input image as a NumPy array. target_resolution (tuple): Target resolution as (height, width) for resizing. resample: Resampling method to use (e.g., `PILImageResampling.BICUBIC`). input_data_format (ChannelDimension): Format of the input image (e.g., `ChannelDimension.FIRST` or `LAST`). Returns: np.array: The resized image in NumPy array format. """ new_height, new_width = _get_local_grids_output_size(image, target_resolution, input_data_format) # Resize the image resized_image = resize( image, (new_height, new_width), resample=resample, input_data_format=input_data_format, ) return resized_image def _pad_for_patching( self, image: np.array, target_resolution: tuple, input_data_format: ChannelDimension, ) -> np.array: """ Pads the image to match the target resolution, ensuring compatibility with patch-based models. This is typically used to make sure the image dimensions are divisible by the patch size or to meet specific model input requirements. Padding is applied symmetrically where needed. Args: image (np.array): Input image as a NumPy array. target_resolution (tuple): The desired resolution after padding, in the format (height, width). input_data_format (ChannelDimension): Format of the input image (e.g., `ChannelDimension.FIRST` or `LAST`). Returns: np.array: The padded image as a NumPy array. """ target_height, target_width = target_resolution background_color = tuple(int(x * 255) for x in self.image_mean) padded_image = pad( image, target_size=(target_height, target_width), background_color=background_color, input_data_format=input_data_format, ) return padded_image def get_image_grids( self, image: np.array, possible_resolutions, grid_size: int, resample: PILImageResampling, data_format: ChannelDimension, input_data_format: ChannelDimension, ) -> List[np.array]: """ Splits the input image into multiple local grids based on possible resolutions and grid size. The function selects the best resolution from the provided list, resizes the image accordingly, and divides it into non-overlapping grid patches of size (grid_size x grid_size). It is commonly used for any-resolution (anyres) visual processing. Args: image (np.array): Input image as a NumPy array. possible_resolutions (List[Tuple[int, int]]): List of allowed resolutions to choose from. 
grid_size (int): The size of each grid patch (e.g., 336 pixels). resample (PILImageResampling): Resampling method used during resizing. data_format (ChannelDimension): Output data format (e.g., `ChannelDimension.FIRST`). input_data_format (ChannelDimension): Input data format of the image. Returns: List[np.array]: A list of grid image patches as NumPy arrays. """ if not isinstance(possible_resolutions, list): raise ValueError("possible_resolutions must be a list of possible resolutions.") image_size = get_image_size(image, channel_dim=input_data_format) best_resolution = select_best_resolution(image_size, possible_resolutions) resized_image = self._resize_for_local_grids( image, best_resolution, resample=resample, input_data_format=input_data_format, ) padded_image = self._pad_for_patching(resized_image, best_resolution, input_data_format=input_data_format) local_grids = divide_to_grids(padded_image, grid_size=grid_size, input_data_format=input_data_format) # make sure that all patches are in the input data format local_grids = [ to_channel_dimension_format(grid, channel_dim=data_format, input_channel_dim=input_data_format) for grid in local_grids ] return local_grids def preprocess( self, images: ImageInput, do_resize: bool = None, size: Dict[str, int] = None, anyres: bool = None, unpad: bool = None, is_video_list: List[bool] = None, possible_resolutions: List = None, patch_size: int = None, pad_to_square: bool = None, resample: PILImageResampling = None, do_center_crop: bool = None, crop_size: int = None, do_rescale: bool = None, rescale_factor: float = None, do_normalize: bool = None, image_mean: Optional[Union[float, List[float]]] = None, image_std: Optional[Union[float, List[float]]] = None, do_convert_rgb: bool = None, return_tensors: Optional[Union[str, TensorType]] = None, data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, input_data_format: Optional[Union[str, ChannelDimension]] = None, is_first_or_last_frames: List[bool] = False, ): """ Preprocesses images using HCXVisionProcessor. This method prepares images for visual language models by applying resizing, padding, cropping, normalization, and tokenization into visual patches. In video mode, each frame is converted to a 1D sequence of patches. The `unpad` option is disabled when processing videos. Args: images (ImageInput): A single image or a batch of images (PIL, NumPy, or tensor format). do_resize (bool, optional): Whether to resize the image(s). size (Dict[str, int], optional): Resize target with keys `"height"` and `"width"`. anyres (bool, optional): Whether to use any-resolution processing with grid splitting. unpad (bool, optional): Whether to remove visual tokens that belong to padding areas (only in non-video mode). is_video_list (List[bool], optional): A list indicating which inputs are video frames. possible_resolutions (List, optional): List of resolution pairs allowed in `anyres` mode. patch_size (int, optional): Patch size for the Vision Transformer (ViT). pad_to_square (bool, optional): Whether to pad the image to a square. resample (PILImageResampling, optional): Resampling method to use for resizing. do_center_crop (bool, optional): Whether to apply center cropping. crop_size (int, optional): Target crop size for center cropping. do_rescale (bool, optional): Whether to rescale image pixel values. rescale_factor (float, optional): Factor for pixel rescaling, e.g., `1/255`. do_normalize (bool, optional): Whether to normalize using mean and std. 
image_mean (float or List[float], optional): Mean value(s) for normalization. image_std (float or List[float], optional): Standard deviation(s) for normalization. do_convert_rgb (bool, optional): Whether to convert the image to RGB. return_tensors (str or TensorType, optional): Desired output tensor type (e.g., "pt" for PyTorch). data_format (ChannelDimension, optional): Output data format (e.g., `ChannelDimension.FIRST`). input_data_format (str or ChannelDimension, optional): Format of the input image. is_first_or_last_frames (List[bool], optional): Flags indicating whether each image is a first/last video frame. Returns: Tuple: pixel_values (List[torch.Tensor]): A list of 4D image tensors ready for model input. image_sizes (List[List[int]]): A list of list containing the original width and height [width, height] of each image, e.g., `[[width, height], ...]`. vision_query_lengths (List[int]): A list of integers representing the number of visual tokens each image contributes to the LLM input. """ do_resize = do_resize if do_resize is not None else self.do_resize size = size if size is not None else self.size size = get_size_dict(size, param_name="size", default_to_square=False) anyres = anyres if anyres is not None else self.anyres unpad = unpad if unpad is not None else self.unpad possible_resolutions = possible_resolutions if possible_resolutions is not None else self.possible_resolutions patch_size = patch_size if patch_size is not None else self.patch_size pad_to_square = pad_to_square if pad_to_square is not None else self.pad_to_square resample = resample if resample is not None else self.resample do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop crop_size = crop_size if crop_size is not None else self.crop_size crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True) do_rescale = do_rescale if do_rescale is not None else self.do_rescale rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor do_normalize = do_normalize if do_normalize is not None else self.do_normalize image_mean = image_mean if image_mean is not None else self.image_mean image_std = image_std if image_std is not None else self.image_std do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb images = make_list_of_images(images) if not valid_images(images): raise ValueError( "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " "torch.Tensor, tf.Tensor or jax.ndarray." ) if do_convert_rgb: images = [convert_to_rgb(image) for image in images] # All transformations expect numpy arrays. images = [to_numpy_array(image) for image in images] if is_scaled_image(images[0]) and do_rescale: logger.warning_once( "It looks like you are trying to rescale already rescaled images. If the input" " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." ) if input_data_format is None: # We assume that all images have the same channel dimension format. input_data_format = infer_channel_dimension_format(images[0]) new_images = [] image_sizes = [get_image_size(image, channel_dim=input_data_format) for image in images] vision_query_lengths = [] assert crop_size["height"] == crop_size["width"] # Padding operations for the global image can become a bottleneck when the original image width or height is large. 
# To mitigate this, the image is first resized such that the longest side is scaled proportionally based on size["shortest_edge"], # and then padding is applied to reach the target dimensions. if anyres: anyres_global_images = copy.deepcopy(images) if pad_to_square: background_color = tuple(int(x * 255) for x in self.image_mean) anyres_global_images = [ resize_longside( copy.deepcopy(image), size["shortest_edge"], resample, input_data_format, ) for image in anyres_global_images ] anyres_global_images = [ expand2square( image, background_color=background_color, input_data_format=input_data_format, )[0] for image in anyres_global_images ] else: anyres_global_images = [ self.resize( image=image, size={ "height": size["shortest_edge"], "width": size["shortest_edge"], }, resample=resample, input_data_format=input_data_format, ) for image in anyres_global_images ] else: anyres_global_images = [None for _ in range(len(images))] if pad_to_square: background_color = tuple(int(x * 255) for x in self.image_mean) images = [ resize_longside(image, size["shortest_edge"], resample, input_data_format) for image in images ] images = [ expand2square( image, background_color=background_color, input_data_format=input_data_format, )[0] for image in images ] num_queries_vis_abstractors = [] num_queries_vis_abstractors_slow = [] first_last_frames_slows = [] for image, is_video, anyres_global_image, image_size in zip( images, is_video_list, anyres_global_images, image_sizes ): if is_video: num_queries_vis_abstractor = self.num_queries_vis_abstractor_video_fast num_queries_vis_abstractor_slow = self.num_queries_vis_abstractor_video_slow else: num_queries_vis_abstractor = self.num_queries_vis_abstractor num_queries_vis_abstractor_slow = 0 num_queries_vis_abstractors.append(num_queries_vis_abstractor) num_queries_vis_abstractors_slow.append(num_queries_vis_abstractor_slow) first_last_frames_slows.append(self.first_last_frames_slow) if anyres: # convert image into a list of grids # we intentially use the same data format as the input data format image_grids = self.get_image_grids( image, possible_resolutions, grid_size=crop_size["height"], resample=resample, data_format=input_data_format, input_data_format=input_data_format, ) # Global image (thumbnail) is not used for video inputs. 
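                # For still images the grid list below becomes [global thumbnail, local grid 0, local grid 1, ...];
                # e.g. a 1x3 best-fit layout yields 1 global + 3 local grids, while video frames keep only their local grids.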
                if not is_video:
                    image_grids = [anyres_global_image] + image_grids
            else:
                image_grids = [image]

            pixel_values = self._preprocess(
                image_grids,
                do_resize=do_resize,
                size=size,
                resample=resample,
                do_center_crop=do_center_crop,
                crop_size=crop_size,
                do_rescale=do_rescale,
                rescale_factor=rescale_factor,
                do_normalize=do_normalize,
                image_mean=image_mean,
                image_std=image_std,
                data_format=data_format,
                input_data_format=input_data_format,
            )
            pixel_values = np.array(pixel_values)
            new_images.append(pixel_values)

            num_grids = pixel_values.shape[0]
            vision_query_length = determine_anyres_num_vision_patches(
                num_grids=num_grids,
                image_size=image_size,
                grid_size=crop_size["height"],
                patch_size=patch_size,
                possible_resolutions=possible_resolutions,
                anyres=anyres,
                unpad=False if is_video else unpad,
                num_queries_vis_abstractor=num_queries_vis_abstractor,
                num_queries_vis_abstractor_slow=num_queries_vis_abstractor_slow,
                is_video=is_video,
                first_last_frames_slow=self.first_last_frames_slow,
                is_first_or_last_frames=self.first_last_frames_slow,
            )
            vision_query_lengths.append(vision_query_length)

        data = {
            "pixel_values": [[torch.tensor(new_image) for new_image in new_images]],
            "image_sizes": [[[image_size[1], image_size[0]] for image_size in image_sizes]],
            "vision_query_lengths": [vision_query_lengths],
            "is_videos": [is_video_list],
            "num_queries_vis_abstractors": [num_queries_vis_abstractors],
            "num_queries_vis_abstractors_slow": [num_queries_vis_abstractors_slow],
            "first_last_frames_slows": [first_last_frames_slows],
        }

        return BatchFeature(data=data)

    def load_images_videos(self, vlm_chat):
        """
        Loads and prepares images or video frames from a VLM chat input.

        This function parses the input `vlm_chat` object, extracts image or video sources, and loads them into
        memory as PIL or NumPy images, ready for preprocessing.

        Args:
            vlm_chat: A VLM chat input structure containing multimodal elements (e.g., images, videos, URLs, or
                file paths). The format is typically a list of messages with associated media fields.

        Returns:
            Tuple:
                new_vlm_chat (List[dict]): The chat with each video entry expanded into one message per combined
                    frame grid, each carrying a `video_time_stamp`.
                all_images (List[PIL.Image.Image]): All loaded images, including the combined video frame grids.
                is_video_list (List[bool]): Flags indicating which entries in `all_images` came from videos.
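
        Example (illustrative; only the "content", "image", "video", and "filename" fields are interpreted
        here, other fields are passed through unchanged):

            vlm_chat = [
                {"role": "user", "content": {"image": "https://example.com/cat.jpg"}},
                {"role": "user", "content": {"video": "/path/to/clip.mp4"}},
            ]
            new_vlm_chat, all_images, is_video_list = processor.load_images_videos(vlm_chat)
            # all_images -> [<PIL cat image>, <frame grid 0>, <frame grid 1>, ...]
            # is_video_list -> [False, True, True, ...]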
""" vlm_chat = copy.deepcopy(vlm_chat) new_vlm_chat = [] all_images = [] # images + images_from_videos is_video_list = [] for line in vlm_chat: if "content" in line: content = line["content"] if "image" in content: if "filename" not in content: content["filename"] = f"{uuid.uuid4().hex}.jpg" image_pil = load_image(content["image"]) all_images.append(image_pil) is_video_list.append(False) new_vlm_chat.append(line) elif "video" in content: video_bytesio = load_video_to_bytesio(content["video"]) pil_img_frames, video_time_stamp = process_video( video_bytesio, self.max_num_grids, self.max_image_cnt, self.crop_size["width"] ) all_images.extend(pil_img_frames) is_video_list.extend([True] * len(pil_img_frames)) if "filename" not in content: content["filename"] = f"{uuid.uuid4().hex}.mp4" for i, image_time_stamp in enumerate(video_time_stamp): new_line = copy.deepcopy(line) basename, ext = os.path.splitext(content["filename"]) new_line["content"]["filename"] = f"{basename}-{i}{ext}" new_line["content"]["video_time_stamp"] = image_time_stamp if i == len(video_time_stamp) - 1: new_line["content"]["is_final_grid"] = True for last_frame_target_key in ["lens_keywords", "lens_local_keywords", "speech_to_text"]: if last_frame_target_key in content: new_line["content"][last_frame_target_key] = content[last_frame_target_key] new_vlm_chat.append(new_line) else: new_vlm_chat.append(line) return new_vlm_chat, all_images, is_video_list def process_video(video_bytesio, max_num_grids, max_image_cnt, vit_input_size): """ Processes a video file and extracts frames suitable for vision transformer (ViT) input. The function reads video data from a BytesIO object, extracts a limited number of frames based on `max_num_grids` and `max_image_cnt`, and resizes them to the appropriate ViT input size. Args: video_bytesio (io.BytesIO): A BytesIO object containing the raw video file data. max_num_grids (int): The maximum number of grids allowed (e.g., for tiling or patching). max_image_cnt (int): The maximum number of frames to extract from the video. vit_input_size (int): The desired input size (height and width) for the ViT model. Returns: List[np.ndarray]: A list of processed video frames as NumPy arrays, each resized to (vit_input_size, vit_input_size). """ frames, time_interval = video_decoder( video_bytesio, max_num_grids=max_num_grids, max_image_cnt=max_image_cnt, default_interval=0.4 ) pil_img_frames, video_time_stamp = combine_frames_into_images( frames, time_interval, max_grid_shape=(max_num_grids, 1), vit_input_size=vit_input_size ) return pil_img_frames, video_time_stamp def load_image(image_src): """ Loads an image from various sources (file path, URL, base64 string, or raw bytes) and returns it as a PIL Image object. Args: image_src (str or bytes): The image source. It can be: - A local file path - A URL - A base64-encoded string - Raw image bytes Returns: PIL.Image.Image: The loaded image as a PIL Image object. Raises: ValueError: If the image cannot be loaded or the format is unsupported. TypeError: If the input is not of type str or bytes. """ try: # 1. If input is bytes type if isinstance(image_src, bytes): return Image.open(io.BytesIO(image_src)) # 2. If input is str type (path, URL, base64) if isinstance(image_src, str): # 2a. 
Check if it's a Base64 data URI format ('data:image/...') if image_src.startswith("data:image"): try: # Remove the 'data:image/...;base64,' part and decode header, encoded = image_src.split(",", 1) image_bytes = base64.b64decode(encoded) return Image.open(io.BytesIO(image_bytes)) except (ValueError, base64.binascii.Error) as e: raise ValueError(f"Invalid base64 data URI format: {e}") from e # 2b. Check if it's a URL format ('http://' or 'https://') elif image_src.startswith("http://") or image_src.startswith("https://"): try: response = requests.get(image_src, stream=True, timeout=10) response.raise_for_status() # Raise an exception for HTTP errors image_bytes = response.content return Image.open(io.BytesIO(image_bytes)) except requests.exceptions.RequestException as e: raise ValueError(f"Error loading image from URL '{image_src}': {e}") from e # 2c. Assume it's a local file path else: return Image.open(image_src) else: raise TypeError(f"Unsupported image_src type: {type(image_src)}") # Common exception handling except FileNotFoundError: raise ValueError(f"Image loading error: File not found '{image_src}'") except UnidentifiedImageError: raise ValueError("Image loading error: Cannot identify image file format.") except IOError as e: raise ValueError(f"Image loading error (I/O): {e}") from e except Exception as e: raise ValueError(f"Unexpected error during image loading: {e}") from e def load_video_to_bytesio(video_src): """ Loads video data from various sources (file path, URL, base64 string, or raw bytes) and returns an `io.BytesIO` object containing the raw video content. Args: video_src (str or bytes): The video source. Supported formats include: - Local file path - URL - Base64-encoded data URI string - Raw video bytes Returns: io.BytesIO: A `BytesIO` object containing the loaded video data. Raises: ValueError: If the video cannot be loaded due to issues such as an invalid path, URL failure, malformed base64 string, or unsupported format. TypeError: If the input is not a `str` or `bytes` object. """ video_bytes = None try: # 1. If input is bytes type if isinstance(video_src, bytes): video_bytes = video_src # 2. If input is str type (path, URL, base64) elif isinstance(video_src, str): # 2a. Check if it's a Base64 data URI format ('data:video/...') if video_src.startswith("data:video"): try: # Remove the 'data:video/...;base64,' part and decode header, encoded = video_src.split(",", 1) video_bytes = base64.b64decode(encoded) except (ValueError, base64.binascii.Error) as e: raise ValueError(f"Invalid base64 data URI format: {e}") from e # 2b. Check if it looks like a URL elif urlparse(video_src).scheme in ("http", "https"): try: response = requests.get( video_src, stream=True, timeout=30 ) # Increased timeout for potentially large videos response.raise_for_status() # Raise an exception for HTTP errors (4xx or 5xx) # Read all content from the stream into bytes video_bytes = response.content except requests.exceptions.MissingSchema: # If urlparse thinks it's a scheme but requests disagrees (e.g., "http:/example.com") # Treat it as a potential file path below. pass except requests.exceptions.RequestException as e: raise ValueError(f"Error loading video from URL '{video_src}': {e}") from e # 2c. Assume it's a local file path if not base64 or confirmed URL if video_bytes is None: # Only attempt file read if not already loaded as base64 or URL failed gracefully # Check if it could potentially be a file path # Note: This check is basic. 
                # A string like "http:/path/file" might incorrectly be treated as a path here
                # if the requests call failed due to MissingSchema. More robust path validation could be added.
                if (
                    os.path.exists(video_src) or "/" in video_src or "\\" in video_src
                ):  # Basic check if it resembles a path
                    try:
                        with open(video_src, "rb") as f:
                            video_bytes = f.read()
                    except FileNotFoundError:
                        raise ValueError(f"Video loading error: File not found at path '{video_src}'")
                    except IsADirectoryError:
                        raise ValueError(f"Video loading error: Path '{video_src}' is a directory, not a file.")
                    except IOError as e:
                        raise ValueError(f"Video loading error (I/O) for path '{video_src}': {e}") from e
                else:
                    # If it's not base64, not a valid downloadable URL, and doesn't look like a path/doesn't exist
                    raise ValueError(f"Unsupported string input format or resource not found: '{video_src}'")

        # 3. If the type is unsupported
        else:
            raise TypeError(f"Unsupported video_src type: {type(video_src)}")

        # Final check if video_bytes was successfully obtained
        if video_bytes is None:
            raise ValueError(f"Could not load video data from the provided source: {video_src}")

        # Return the bytes wrapped in BytesIO
        return io.BytesIO(video_bytes)

    # Catch specific exceptions first for better error reporting
    except FileNotFoundError as e:
        # Should be caught above, but as a safeguard
        raise ValueError(f"Video loading error: File not found '{video_src}'") from e
    except requests.exceptions.RequestException as e:
        # Already handled, but for clarity
        raise ValueError(f"Video loading error (Network): {e}") from e
    except (ValueError, TypeError) as e:
        # Re-raise ValueErrors/TypeErrors raised intentionally within the try block
        raise e
    except Exception as e:
        # Catch any other unexpected errors during processing
        raise ValueError(f"Unexpected error during video loading from source '{video_src}': {e}") from e


def video_decoder(video_bytesio, max_num_grids, max_image_cnt, default_interval=0.4):
    """
    Decodes video data from a BytesIO object and returns a list of extracted frames.

    Args:
        video_bytesio (io.BytesIO): A BytesIO object containing the raw video data.
        max_num_grids (int): Maximum number of grids allowed per image. Used to determine how many frames to extract.
        max_image_cnt (int): Maximum number of frames to extract from the video.
        default_interval (float, optional): Default time interval (in seconds) between frames.
            Used when frame rate info is unavailable. TODO: make configurable.

    Returns:
        Tuple:
            frames (List[PIL.Image.Image]): A list of extracted frames as PIL Images.
            time_interval (float): Time interval (in seconds) between selected frames.
    """
    error_messages = []
    frames = []

    # 1. Try decoding the video using Decord.
    try:
        vr = VideoReader(video_bytesio, ctx=cpu(0), num_threads=8)
        fps = vr.get_avg_fps()
        play_time = len(vr) / fps
        total_frames = len(vr)
        frame_indices, time_interval = extract_frame_indices(
            play_time, total_frames, fps, max_num_grids, max_image_cnt, default_interval=default_interval
        )  # Sample every 0.4 seconds; if the video is too long, apply uniform sampling instead.
        if frame_indices is None:
            frame_indices = range(len(vr))  # Convert all frames.
        batch_frames = vr.get_batch(frame_indices).asnumpy()
        frames = [Image.fromarray(frame).convert("RGB") for frame in batch_frames]
        return frames, time_interval
    except Exception as e:
        logger.warning(f"Decoding with Decord failed, falling back to PyAV: {e}")
        error_messages.append(f"Decord failed: {e}")

    # 2. Fallback: Try decoding the video using PyAV.
    try:
        container = av.open(video_bytesio)
        stream = container.streams.video[0]
        # The container object itself has no length; frame count and duration are read from the video stream.
        fps = float(stream.average_rate)
        total_frames = stream.frames
        play_time = total_frames / fps
        frame_indices, time_interval = extract_frame_indices(
            play_time, total_frames, fps, max_num_grids, max_image_cnt, default_interval=default_interval
        )  # Sample frames every 0.4 seconds. If the video is long, use uniform sampling to limit the number of frames.
        # Even if frame_indices were assigned using Decord, reprocess them to be compatible with PyAV.
        target_indices = None if frame_indices is None else set(frame_indices)
        frames = []
        for i, frame in enumerate(container.decode(video=0)):
            if target_indices is not None and i not in target_indices:
                continue  # Skip frames that are not in the required indices.
            pil_frame = Image.fromarray(frame.to_ndarray(format="rgb24")).convert("RGB")
            frames.append(pil_frame)
        if frames:
            return frames, time_interval
        else:
            raise Exception("Decoding with PyAV succeeded, but no frames were extracted.")
    except Exception as e:
        error_messages.append(f"PyAV failed: {e}")

    # 3. Fallback: Try decoding the video using OpenCV.
    try:
        # cv2.VideoCapture cannot read from an in-memory buffer, so the raw bytes are written to a
        # temporary file first (the ".mp4" suffix is only a hint; OpenCV probes the actual container).
        import tempfile

        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_video_file:
            tmp_video_file.write(video_bytesio.getvalue())
            tmp_video_path = tmp_video_file.name
        try:
            cap = cv2.VideoCapture(tmp_video_path)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = cap.get(cv2.CAP_PROP_FPS)
            play_time = total_frames / fps
            frame_indices, time_interval = extract_frame_indices(
                play_time, total_frames, fps, max_num_grids, max_image_cnt, default_interval=default_interval
            )  # Sample frames every 0.4 seconds; if the video is too long, apply uniform sampling to limit the total number of frames.
            if frame_indices is None:
                frame_indices = range(total_frames)  # Convert all frames.
            index_set = set(frame_indices)  # Convert to a set for faster lookup.
            current_index = 0
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                if current_index in index_set:
                    frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).convert("RGB"))
                current_index += 1
                if current_index > max(index_set):
                    # Stop processing once all required indices have been handled.
                    break
            cap.release()
        finally:
            os.remove(tmp_video_path)
        if frames:
            return frames, time_interval
    except Exception as e:
        error_messages.append(f"OpenCV failed: {e}")

    if error_messages:
        raise Exception(f"All decoding attempts have failed: {error_messages}")


def convert_format_for_multi_image(img, json, convert_key_list=["words", "text", "objects", "entities"]):
    """
    Converts the format of image and annotation data from a single-image dataset to a multi-image dataset format.

    Single-image datasets typically return a single image and its associated annotation as individual objects.
    This function wraps them in the dictionary format used by multi-image datasets.

    Args:
        img: The input image (e.g., a PIL Image or NumPy array), or a dict of images for multi-image datasets.
        json: The annotation data associated with the image.
        convert_key_list (List[str], optional): A list of keys whose values should be wrapped per image ID.
            Defaults to ["words", "text", "objects", "entities"].

    Returns:
        Tuple[bool, Dict, Dict]:
            - Whether the input was already in multi-image (dict) format.
            - A dictionary mapping image IDs to images (e.g., {"00": img}).
            - The annotation JSON with the converted keys wrapped per image ID.
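
    Example (illustrative annotation; any key containing "region" is wrapped the same way):

    >>> anno = {"words": ["hello"], "region_qa": ["q"]}
    >>> is_multi, img_dict, anno_dict = convert_format_for_multi_image("fake_image", anno)
    >>> is_multi
    False
    >>> sorted(img_dict.keys()), anno_dict["words"], anno_dict["region_qa"]
    (['00'], {'00': ['hello']}, {'00': ['q']})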
""" is_multi_image_dataset = isinstance(img, dict) if not is_multi_image_dataset: img = {"00": img} for convert_key in convert_key_list: if convert_key in json: json[convert_key] = {"00": json[convert_key]} for json_key in json: if "region" in json_key: json[json_key] = {"00": json[json_key]} return is_multi_image_dataset, img, json def convert_tags_for_video(img, json): """ Converts tags to tags based on the number of video frames. In video datasets, annotations often use a generic tag. This function replaces that tag with frame-specific tags such as , , ..., based on the number of frames in `img`. Args: img: A list of video frames (e.g., list of PIL Images or NumPy arrays). json: The annotation data containing tags to be replaced. Returns: Dict: The updated annotation JSON with frame-specific tags. """ image_tag = "".join([f"" for idx in range(len(img))]) # image_tag = "" # Use this format to construct and insert image-specific tags. for json_key in json: if "qa_pairs" in json_key: new_qa_pairs = [] for qa_pair in json[json_key]: question = qa_pair[0] # Replace tags with corresponding tags. question = question.replace("", image_tag) new_qa_pairs.append([question, qa_pair[1]]) json[json_key] = new_qa_pairs return img, json def split_list(input_list, split_value): """ Splits a list into sublists using a specified delimiter value. Each time `split_value` is encountered in `input_list`, a new sublist is started. The delimiter itself is not included in the output. Args: input_list (List[Any]): The input list to split. split_value (Any): The value used as the delimiter for splitting. Returns: List[List[Any]]: A list of sublists, split by the specified delimiter. Example: >>> split_list(["a", "b", "|", "c", "d", "|", "e"], "|") [['a', 'b'], ['c', 'd'], ['e']] """ temp_list = [] result = [] for value in input_list: if value == split_value: result.append(temp_list) temp_list = [] else: temp_list.append(value) result.append(temp_list) return result def combine_frames_into_images(frames, time_interval, max_grid_shape=(3, 3), vit_input_size=378): """ Combines a sequence of video frames into grid-based images and generates corresponding time range labels. Frames are grouped and arranged into a grid (e.g., 3x3) such that each combined image contains up to `max_grid_shape[0] * max_grid_shape[1]` frames. Each combined image is resized to the given ViT input size. Args: frames (List[PIL.Image.Image]): A list of frames extracted from a video. time_interval (float): Time interval (in seconds) between consecutive frames. max_grid_shape (Tuple[int, int], optional): The maximum grid shape as (rows, cols). Defaults to (3, 3). vit_input_size (int, optional): The target size (height and width) for the Vision Transformer input. Defaults to 378. Returns: Tuple: image_list (List[PIL.Image.Image]): A list of grid-combined images. image_time_stamps (List[str]): A list of time span labels for each combined image, e.g., ["0.00s~1.50s", "1.50s~3.00s", ...]. """ # grid_size = int(np.sqrt(max_num_grids)) # assert grid_size**2 == max_num_grids, "max_num_grids must be a perfect square." max_num_grids = max_grid_shape[0] * max_grid_shape[1] assert ( max_grid_shape[1] == 1 ), f"For video processing, decided to concatenate frames horizontally into a wide image." # List to store the resulting combined images. image_list = [] # Calculate the number of canvases needed. 
num_frames = len(frames) num_canvases = num_frames // max_num_grids leftover_frames = num_frames % max_num_grids time_stamp = 0 # second image_time_stamps = [] for canvas_idx in range(num_canvases): # Initialize the current canvas. combined_image = Image.new( "RGB", (vit_input_size * max_grid_shape[0], vit_input_size * max_grid_shape[1]), color=(0, 0, 0) ) # Determine the frames to fill in the current canvas. start_idx = canvas_idx * max_num_grids end_idx = min(start_idx + max_num_grids, num_frames) for idx in range(start_idx, end_idx): img = frames[idx] # Resize each frame to a square shape. img_resized = img.resize((vit_input_size, vit_input_size)) # Calculate the (row, column) position to place the frame within the grid layout. local_idx = idx - start_idx x_offset = (local_idx % max_grid_shape[0]) * vit_input_size y_offset = (local_idx // max_grid_shape[0]) * vit_input_size # Calculate the position to place the frame in the grid. combined_image.paste(img_resized, (x_offset, y_offset)) # Append the current canvas to the result list. image_list.append(combined_image) frame_cnt = end_idx - start_idx image_time_stamps.append(f"{time_stamp:.2f}s~{time_stamp + frame_cnt * time_interval:.2f}s") time_stamp += frame_cnt * time_interval if leftover_frames > 0: # canvas_idx might be undefined; default to 0 if not previously assigned to avoid "referenced before assignment" error. canvas_idx = num_canvases # Add the remaining frames to the final canvas. combined_image = Image.new("RGB", (vit_input_size * leftover_frames, vit_input_size * 1), color=(0, 0, 0)) for idx in range(leftover_frames): img = frames[num_canvases * max_num_grids + idx] # Resize the frame to a square (equal width and height). img_resized = img.resize((vit_input_size, vit_input_size)) # Calculate the (row, column) position to place the frame within the grid layout. x_offset = (idx % leftover_frames) * vit_input_size y_offset = (idx // leftover_frames) * vit_input_size # Calculate the position to place the frame within the grid layout. combined_image.paste(img_resized, (x_offset, y_offset)) # Add the current canvas to the list of combined images. image_list.append(combined_image) frame_cnt = leftover_frames image_time_stamps.append(f"{time_stamp:.2f}s~{time_stamp + frame_cnt * time_interval:.2f}s") time_stamp += frame_cnt * time_interval return image_list, image_time_stamps def extract_frame_indices(play_time, total_frames, fps, max_num_grids, max_image_cnt, default_interval=0.4): """ Extracts specific frame indices from a video based on duration, frame count, and sampling strategy. The function determines which frames to extract given the video duration (`play_time`), total frame count, and frame rate. It samples frames at regular intervals (default: 0.4s), but if the number of frames exceeds the limit defined by `max_num_grids * max_image_cnt`, it performs uniform sampling to stay within that limit. Args: play_time (float): Total play time of the video in seconds. total_frames (int): Total number of frames in the video. fps (float): Frames per second of the video. max_num_grids (int): Maximum number of grids to display. max_image_cnt (int): Maximum number of images per grid. default_interval (float, optional): Interval in seconds between frame samples. Defaults to 0.4. Returns: Tuple: frame_indices (List[int]): A list of selected frame indices. time_interval (float): Time interval between selected frames (in seconds). 
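
    Example (illustrative): for a 60-second clip at 30 fps (total_frames=1800) with max_num_grids=3 and
        max_image_cnt=12, at most 36 frames may be kept. Sampling every 0.4 s would need 150 frames, so the
        function falls back to uniform sampling and returns every 50th frame index (36 indices, ~1.67 s apart).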
""" # Calculate how many frames to extract with the default interval default_frame_count = int(play_time / default_interval) # Maximum frames allowed based on max_num_grids and max_image_cnt max_frames_allowed = max_num_grids * max_image_cnt # Determine whether we can use the default interval or need uniform sampling if default_frame_count <= max_frames_allowed: # Default interval is sufficient, extract frames every 0.4 seconds frame_interval = int(total_frames / default_frame_count) else: # Use uniform sampling to fit within max_frames_allowed frame_interval = int(total_frames / max_frames_allowed) # Extract frame indices at the calculated interval selected_indices = list(range(0, total_frames, frame_interval)) time_interval = frame_interval / fps # Ensure the number of selected indices does not exceed max_frames_allowed return selected_indices[:max_frames_allowed], time_interval