""" File: app.py Description: Chat with the vision language model Gemma3. Author: Didier Guillevic Date: 2025-03-16 """ import gradio as gr from transformers import AutoProcessor, Gemma3ForConditionalGeneration from transformers import TextIteratorStreamer from threading import Thread import torch device = 'cuda' model_id = "google/gemma-3-4b-it" processor = AutoProcessor.from_pretrained(model_id, use_fast=True, padding_side="left") model = Gemma3ForConditionalGeneration.from_pretrained( model_id, torch_dtype=torch.bfloat16 ).to(device).eval() def process(message, history): """Generate the model response in streaming mode given message and history """ print(f"{history=}") # Get the user's text and list of images user_text = message.get("text", "") user_images = message.get("files", []) # List of images # Build the message list including history messages = [] combined_user_input = [] # Combine images and text if found in same turn. for user_turn, bot_turn in history: if isinstance(user_turn, tuple): # Image input image_content = [{"type": "image", "url": image_url} for image_url in user_turn] combined_user_input.extend(image_content) elif isinstance(user_turn, str): # Text input combined_user_input.append({"type":"text", "text": user_turn}) if combined_user_input and bot_turn: messages.append({'role': 'user', 'content': combined_user_input}) messages.append({'role': 'assistant', 'content': [{"type": "text", "text": bot_turn}]}) combined_user_input = [] # reset the combined user input. # Build the user message's content from the provided message user_content = [] if user_text: user_content.append({"type": "text", "text": user_text}) for image in user_images: user_content.append({"type": "image", "url": image}) messages.append({'role': 'user', 'content': user_content}) # Generate model's response inputs = processor.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" ).to(model.device, dtype=torch.bfloat16) streamer = TextIteratorStreamer( processor, skip_prompt=True, skip_special_tokens=True) generation_kwargs = dict( inputs, streamer=streamer, max_new_tokens=1_024, do_sample=False ) thread = Thread(target=model.generate, kwargs=generation_kwargs) thread.start() partial_message = "" for new_text in streamer: partial_message += new_text yield partial_message # # User interface # with gr.Blocks() as demo: chat_interface = gr.ChatInterface( fn=process, title="Multimedia Chat", description="Chat with text or text+image.", multimodal=True, examples=[ "How can we rationalize quantum entanglement?", {'files': ['./sample_ID.jpeg',], 'text': 'Describe this image in a few words.'}, "Peux-tu expliquer le 'quantum spin'?" ] ) demo.launch()