# unidisc/demo/inference_utils.py
from __future__ import annotations

import base64
import copy
import io
import math
import random
from typing import TYPE_CHECKING, Any, Dict, List, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

from image_utils import Im
from decoupled_utils import gprint

if TYPE_CHECKING:
    from demo.server import ChatRequest

def tensor_center_crop(tensor_image, crop_size):
    """Resize and crop a (B, C, H, W) tensor to `crop_size`.

    Note: despite the name, the crop offset is sampled uniformly at random
    within the valid range, so this is a random crop after resizing.
    """
    _, _, h, w = tensor_image.shape
    # Progressively halve the image with area interpolation while it is at least
    # 2x larger than the target in both dimensions (cheap anti-aliasing).
    while h >= 2 * crop_size[0] and w >= 2 * crop_size[1]:
        tensor_image = F.interpolate(tensor_image, size=(h // 2, w // 2), mode='area')
        _, _, h, w = tensor_image.shape

    # Scale so both dimensions cover the crop size, then crop.
    scale = max(crop_size[0] / h, crop_size[1] / w)
    new_h, new_w = round(h * scale), round(w * scale)
    tensor_image = F.interpolate(tensor_image, size=(new_h, new_w), mode='bilinear')

    crop_top = random.randint(0, new_h - crop_size[0])
    crop_left = random.randint(0, new_w - crop_size[1])
    crop_bottom = crop_top + crop_size[0]
    crop_right = crop_left + crop_size[1]
    return tensor_image[:, :, crop_top:crop_bottom, crop_left:crop_right]
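
# Illustrative usage of tensor_center_crop (hypothetical shapes; not part of the
# original module):
#     >>> img = torch.rand(1, 3, 480, 640)            # (B, C, H, W) in [0, 1]
#     >>> tensor_center_crop(img, (256, 256)).shape
#     torch.Size([1, 3, 256, 256])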

def parse_messages(messages: List[dict]) -> Tuple[List[Image.Image], List[List[dict]]]:
    """
    Given a list of message dicts with format:
        [
            {"type": "text", "text": msg},
            {"type": "image_url", "image_url": <PIL Image>}
        ]
    Returns:
        - all_images: a list of the PIL images, in order of appearance
        - all_content: a nested list (a single conversation) of dicts indicating message type,
          with each image replaced by its index into all_images
    """
    all_images: List[Image.Image] = []
    conversation: List[dict] = []
    for msg in messages:
        if msg["type"] == "text":
            conversation.append(msg)
        elif msg["type"] == "image_url":
            # Replace the PIL image with its index so the conversation stays serializable.
            idx = len(all_images)
            all_images.append(msg["image_url"])
            _msg = copy.deepcopy(msg)
            _msg["image_url"] = {"url": idx}
            conversation.append(_msg)
        else:
            raise ValueError(f"Unsupported message type: {msg['type']}. Expected 'text' or 'image_url'.")
    all_content = [conversation]
    return all_images, all_content
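
# Illustrative example for parse_messages (hypothetical inputs; not part of the
# original module):
#     >>> pil_img = Image.new("RGB", (64, 64))
#     >>> imgs, content = parse_messages([
#     ...     {"type": "text", "text": "Describe this image."},
#     ...     {"type": "image_url", "image_url": pil_img},
#     ... ])
#     >>> content
#     [[{'type': 'text', 'text': 'Describe this image.'},
#       {'type': 'image_url', 'image_url': {'url': 0}}]]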

def messages_to_batch(config, tokenizer, model, input_data, resolution):
    """Convert a flat list of role-tagged text/image messages into a model batch."""
    from model import get_image_batch
    from unidisc.tokenizers.tokenize_interleaved import _has_image, preprocess

    # Build conversations and extract images.
    all_images = []
    conversations = []
    for item in input_data:
        role = item["role"]
        assert role in ["user", "assistant"]
        role = "human" if role == "user" else "gpt"
        if item["type"] == "image_url":
            token = "<image>"
            all_images.append(item["image_url"])
        elif item["type"] == "text":
            token = item["text"]
        else:
            continue
        # Merge consecutive turns from the same speaker into a single value.
        if conversations and conversations[-1]["from"] == role:
            conversations[-1]["value"] += " " + token
        else:
            conversations.append({"from": role, "value": token})

    output_list = []
    entry = {"id": "1", "conversations": conversations}
    if all_images:
        # The value is unused; the key's presence marks the entry as containing images.
        entry["image"] = {}
    output_list.append(entry)
    all_content = output_list

    vae = model.get_vae()
    device = model.device
    if not all_images:
        image_ids = None
    else:
        # Convert each PIL image to a (1, 3, H, W) float tensor in [0, 1] and crop
        # it to the model resolution before tokenizing with the VAE.
        _img = torch.cat([
            tensor_center_crop(
                torch.from_numpy(np.array(img.convert("RGB")))[None, :].permute(0, 3, 1, 2) / 255,
                (resolution, resolution)
            ) for img in all_images
        ])
        try:
            # Encode images to discrete token ids in chunks to bound memory use.
            batch_size = 32
            image_ids_list = []
            for i in range(0, len(_img), batch_size):
                batch = _img[i:i+batch_size]
                batch_ids = get_image_batch(config, vae, {"img": batch}, device)
                image_ids_list.append(batch_ids)
            image_ids = torch.cat(image_ids_list)
        except Exception as e:
            gprint(f"{_img.shape}, {e}")
            import traceback
            traceback.print_exc()
            raise

    all_input_ids = []
    all_attention_masks = []
    all_modality = []
    assert len(all_content) == 1
    for sources in all_content:
        has_image = _has_image(sources)
        sources = copy.deepcopy([sources["conversations"]])
        _image_ids = image_ids if has_image else None
        try:
            print(f"Sources: {sources}")
            data_dict = preprocess(sources, tokenizer, has_image=has_image, image_ids=_image_ids)
        except Exception as e:
            import traceback
            traceback.print_exc()
            gprint(f"Error in preprocess: {e}")
            return None, None, None

        input_ids = data_dict["input_ids"][0]
        attention_mask = data_dict["attention_mask"][0]
        modality = data_dict["modality"][0]

        # If the sequence ends with two EOS tokens, drop the duplicate.
        if (input_ids[-2:] == tokenizer.eos_token_id).all():
            input_ids = input_ids[:-1]
            attention_mask = attention_mask[:-1]
            modality = modality[:-1]

        assert config.model.length >= input_ids.shape[0], f"Input ids length {input_ids.shape[0]} is greater than model length {config.model.length}"
        attention_mask = attention_mask.bool()
        print(f"Attention mask: {attention_mask.shape}, input ids: {input_ids.shape}, modality: {modality.shape}")

        # If the sequence ends with an image segment, blank out that trailing segment
        # (mark it as text, mask it out, and fill it with pad tokens) so the model can
        # generate the image at inference time.
        if modality[-1] == 1:
            is_image = modality == 1
            change_points = torch.where(is_image[:-1] != is_image[1:])[0] + 1
            if change_points.numel() > 0:
                start_pos = change_points[-1].item()
                modality[start_pos:] = 0
                attention_mask[start_pos:] = False
                input_ids[start_pos:] = tokenizer.pad_token_id

        all_input_ids.append(input_ids)
        all_attention_masks.append(attention_mask)
        all_modality.append(modality)

    all_input_ids = torch.stack(all_input_ids)
    all_attention_masks = torch.stack(all_attention_masks)
    all_modality = torch.stack(all_modality)

    # sample_ids: 0 for valid positions, -1 for padding.
    all_sample_ids = torch.zeros_like(all_modality, dtype=torch.long)
    all_sample_ids[~all_attention_masks] = -1

    batch = {
        "input_ids": all_input_ids,
        "attention_mask": all_attention_masks,
        "modality": all_modality.long(),
        "sample_ids": all_sample_ids.long(),
    }
    for k in batch:
        batch[k] = batch[k].to(device)

    # Shift image token ids into the image vocabulary range (text and image tokens
    # share a single id space).
    batch["input_ids"] = torch.where(
        (batch["modality"] == 1) & (batch["input_ids"] != -1),
        batch["input_ids"] + config.data.img_token_shift,
        batch["input_ids"]
    )
    return batch
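
# Illustrative call to messages_to_batch (hypothetical objects; not part of the
# original module). `config`, `tokenizer`, and `model` are assumed to come from
# the surrounding demo/server setup:
#     >>> input_data = [
#     ...     {"role": "user", "type": "text", "text": "Draw a red cube."},
#     ...     {"role": "user", "type": "image_url", "image_url": Image.new("RGB", (512, 512))},
#     ... ]
#     >>> batch = messages_to_batch(config, tokenizer, model, input_data, resolution=256)
#     >>> sorted(batch.keys())
#     ['attention_mask', 'input_ids', 'modality', 'sample_ids']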

def pil_to_base64(image: Image.Image) -> str:
    """Encode a PIL image as a base64 JPEG string."""
    buffered = io.BytesIO()
    # JPEG cannot encode alpha or palette modes, so convert to RGB first.
    image.convert("RGB").save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
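
# Illustrative round trip (hypothetical values; not part of the original module):
#     >>> b64 = pil_to_base64(Image.new("RGB", (8, 8)))
#     >>> Image.open(io.BytesIO(base64.b64decode(b64))).size
#     (8, 8)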

def convert_to_model_input(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Flatten chat messages into a list of role-tagged text/image parts."""
    model_input = []
    for msg in messages:
        for part in msg.content:
            if part.type == "text" and part.text:
                model_input.append({
                    "type": "text",
                    "text": part.text,
                    "role": msg.role
                })
            elif part.type == "image_url" and part.image_url:
                model_input.append({
                    "type": "image_url",
                    "image_url": part.image_url,
                    "role": msg.role
                })
    return model_input

def convert_request_pil_to_base64(request: ChatRequest) -> ChatRequest:
    """Replace in-memory PIL images in a request with data-URL base64 JPEG strings."""
    for msg in request.messages:
        for part in msg.content:
            if part.type == "image_url" and isinstance(part.image_url, Image.Image):
                buffered = io.BytesIO()
                part.image_url.convert("RGB").save(buffered, format="JPEG")
                base64_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
                part.image_url = {"url": f"data:image/jpeg;base64,{base64_str}"}
    return request

def convert_request_base64_to_pil(request: ChatRequest) -> ChatRequest:
    """Decode data-URL / base64 image strings in a request back into PIL images."""
    for message in request.messages:
        for part in message.content:
            if part.type == "image_url" and isinstance(part.image_url, dict) and "url" in part.image_url:
                image_data = part.image_url["url"]
                # Remove any data URL header, e.g. "data:image/jpeg;base64,"
                if image_data.startswith("data:"):
                    try:
                        header, image_data = image_data.split(",", 1)
                    except ValueError as e:
                        raise ValueError(
                            f"Invalid image URL format: {image_data}"
                        ) from e
                try:
                    decoded_bytes = base64.b64decode(image_data)
                    part.image_url = Image.open(io.BytesIO(decoded_bytes))
                except Exception as e:
                    raise ValueError(
                        f"Error decoding or loading image. Ensure the base64 string is valid. Details: {e}"
                    ) from e
    return request
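
# Illustrative round trip for the request converters (hypothetical `request` object
# from demo.server; not part of the original module):
#     >>> request = convert_request_pil_to_base64(request)   # PIL images -> data-URL strings
#     >>> request = convert_request_base64_to_pil(request)   # data-URL strings -> PIL images
# After the round trip, every image_url part holds a PIL.Image again.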

def trim_merge_messages(request: ChatRequest) -> ChatRequest:
    # Remove empty text parts from each message
    for msg in request.messages:
        msg.content = [
            part for part in msg.content
            if not (part.type == "text" and part.text.strip() == "")
        ]
    # Remove messages with no content
    request.messages = [
        msg for msg in request.messages
        if msg.content
    ]
    # Merge consecutive messages with the same role
    merged_messages = []
    for msg in request.messages:
        if merged_messages and merged_messages[-1].role == msg.role:
            merged_messages[-1].content.extend(msg.content)
        else:
            merged_messages.append(msg)
    request.messages = merged_messages
    return request

def save_grid_image(input_arr: torch.Tensor, output_name, row_len=None):
    """Visualize a boolean tensor as a black/white grid image (padding shown in red)."""
    # Convert to boolean then to int (0/1)
    x0_bool = input_arr.bool().long()
    n = x0_bool.numel()
    if row_len is None:
        row_len = math.ceil(math.sqrt(n))
    rows = math.ceil(n / row_len)
    total = rows * row_len

    # Pad with -1 to mark padded positions
    padded = torch.full((total,), -1, dtype=torch.long)
    padded[:n] = x0_bool.flatten()
    grid = padded.reshape(rows, row_len)

    # Create an RGB image: false=black, true=white, padded=red
    image = torch.zeros((rows, row_len, 3), dtype=torch.uint8)
    mask_true = (grid == 1)
    mask_padding = (grid == -1)
    image[mask_true] = torch.tensor([255, 255, 255], dtype=torch.uint8)
    image[mask_padding] = torch.tensor([255, 0, 0], dtype=torch.uint8)

    img = Image.fromarray(image.numpy(), mode='RGB')
    from datetime import datetime
    output = Im(img).save(datetime.now().strftime("%Y_%m_%d-%H_%M_%S") + "_" + output_name)
    print(f"Saved visualization to {output}")
    return row_len
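
# Illustrative usage of save_grid_image (hypothetical tensor and filename; not part
# of the original module):
#     >>> mask = torch.rand(1000) > 0.5
#     >>> save_grid_image(mask, "mask_vis.png")   # writes a timestamped grid image
#     32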