File size: 3,895 Bytes
c078e68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""
File: vlm.py
Description: Vision language model utility functions.
Author: Didier Guillevic
Date: 2025-03-16
"""

from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from transformers import TextIteratorStreamer
from threading import Thread
import torch

#
# Load the model: google/gemma-3-4b-it
#
# Prefer Apple-silicon 'mps', then CUDA, and fall back to CPU so the module
# can still be imported on machines without an Apple GPU.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
model_id = "google/gemma-3-4b-it"
processor = AutoProcessor.from_pretrained(model_id, use_fast=True, padding_side="left")
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16
).to(device).eval()

#
# Build messages
#
def build_messages(message: dict, history: list[tuple]) -> list[dict]:
    """Build messages given message & history from a **multimodal** chat interface.

    History entries are ``(user_turn, bot_turn)`` tuples. A ``user_turn`` that
    is a tuple is treated as a sequence of image paths/URLs; a ``user_turn``
    that is a string is treated as text; anything else contributes nothing.
    User parts accumulate across entries until an entry with a truthy
    ``bot_turn`` is reached, at which point one user message and one
    assistant message are emitted.

    User parts still pending after the whole history (entries that never got
    a bot reply) are prepended to the new user message instead of being
    dropped. Within every user message, images come before text, matching the
    order in which history entries are accumulated.

    Args:
        message: dictionary with keys 'text' (str or None) and 'files'
            (list of image paths/URLs, or None).
        history: list of tuples with (message, response)

    Returns:
        list of messages (to be sent to the model)
    """
    # Normalize missing / None values so the code below can iterate safely.
    user_text = message.get("text") or ""
    user_images = message.get("files") or []

    messages = []
    pending_user_content = []  # user parts waiting for the next bot reply
    for user_turn, bot_turn in history:
        if isinstance(user_turn, tuple):  # image input
            pending_user_content.extend(
                {"type": "image", "url": image_url} for image_url in user_turn
            )
        elif isinstance(user_turn, str):  # text input
            pending_user_content.append({"type": "text", "text": user_turn})
        if pending_user_content and bot_turn:
            messages.append({'role': 'user', 'content': pending_user_content})
            messages.append(
                {'role': 'assistant', 'content': [{"type": "text", "text": bot_turn}]}
            )
            pending_user_content = []

    # New user message: leftover history parts, then images, then text
    # (same image-before-text order used for history turns above).
    user_content = pending_user_content
    user_content.extend({"type": "image", "url": image} for image in user_images)
    if user_text:
        user_content.append({"type": "text", "text": user_text})

    messages.append({'role': 'user', 'content': user_content})
    return messages


#
# Streaming response
#
@torch.inference_mode()
def stream_response(messages: list[dict], max_new_tokens: int = 1_024):
    """Stream the model's response to the chat interface.

    Generation runs in a background thread that feeds a
    ``TextIteratorStreamer``; this generator yields the accumulated response
    text each time the streamer produces a new chunk.

    Args:
        messages: list of messages to send to the model
        max_new_tokens: maximum number of tokens to generate.

    Yields:
        The response text generated so far (cumulative, not incremental).

    Raises:
        Exception: any exception raised by ``model.generate`` in the
            background thread is re-raised here once streaming stops.
    """
    # Generate model's response
    inputs = processor.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True,
        return_dict=True, return_tensors="pt"
    ).to(model.device, dtype=torch.bfloat16)

    streamer = TextIteratorStreamer(
        processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=False
    )

    errors = []  # exception raised in the worker thread, if any

    def _generate():
        # Inference mode is thread-local, so the decorator above does not
        # cover this worker thread; enable it here explicitly.
        try:
            with torch.inference_mode():
                model.generate(**generation_kwargs)
        except Exception as exc:
            # Without an end-of-stream signal the consumer loop below would
            # block forever waiting on the streamer.
            errors.append(exc)
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()

    partial_message = ""
    for new_text in streamer:
        partial_message += new_text
        yield partial_message

    thread.join()
    if errors:
        raise errors[0]


#
# Response (non-streaming)
#
@torch.inference_mode()
def get_response(messages: list[dict], max_new_tokens: int = 100) -> str:
    """Get the model's complete (non-streamed) response.

    Args:
        messages: list of messages to send to the model
        max_new_tokens: maximum number of tokens to generate. The default of
            100 is kept for backward compatibility; it is much lower than the
            streaming path's limit, so long answers will be truncated unless
            a larger value is passed.

    Returns:
        The decoded response text, excluding the prompt tokens.
    """
    # Generate model's response
    inputs = processor.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True,
        return_dict=True, return_tensors="pt"
    ).to(model.device, dtype=torch.bfloat16)

    # Prompt length, used to slice off the prompt so only new tokens are decoded.
    input_len = inputs["input_ids"].shape[-1]

    # The decorator already enables inference mode; no inner context is needed.
    generation = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    generation = generation[0][input_len:]

    return processor.decode(generation, skip_special_tokens=True)