import gradio as gr
from huggingface_hub import InferenceClient
import spaces
import torch
import os
from huggingface_hub import login
from PIL import Image
from threading import Thread
import platform
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
import time
from transformers import pipeline




"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
# client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
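# A hosted-inference sketch with the commented-out client above (not used here; this Space loads
# the model locally below instead):
# client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
# print(client.text_generation("Who are you?", max_new_tokens=50))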
print(f"Is CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
print(f"CUDA version: {torch.version.cuda}")
print(f"Python version: {platform.python_version()}")
print(f"Pytorch version: {torch.__version__}")
print(f"Gradio version: {gr. __version__}")
duration=None

login(token=os.getenv('gemma'))  # authenticate with the token from the 'gemma' environment variable (Gemma is gated)

# messages = [
#     {"role": "user", "content": "Who are you?"},
# ]
# pipe = pipeline("image-text-to-text", model="google/gemma-3-4b-it")
# print(pipe(messages))


ckpt = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(ckpt, torch_dtype=torch.bfloat16).to("cuda")
processor = AutoProcessor.from_pretrained(ckpt)
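# Design note: bfloat16 halves the memory footprint versus float32; for this ~4B-parameter
# checkpoint that is roughly 8 GB of weights, which is assumed to fit the GPU that spaces.GPU allocates.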

@spaces.GPU(duration=duration)
def bot_streaming(message, history, max_new_tokens=250):
    
    txt = message["text"]
    ext_buffer = f"{txt}"
    
    messages= [] 
    images = []
    

    # Gradio "tuples" history: each entry is a [user, assistant] pair. An image upload shows up as
    # an entry whose user part is a (filepath,) tuple; its caption and reply sit in the next entry.
    for i, msg in enumerate(history):
        if isinstance(msg[0], tuple):
            messages.append({"role": "user", "content": [{"type": "text", "text": history[i+1][0]}, {"type": "image"}]})
            messages.append({"role": "assistant", "content": [{"type": "text", "text": history[i+1][1]}]})
            images.append(Image.open(msg[0][0]).convert("RGB"))
        elif isinstance(history[i-1][0], tuple) and isinstance(msg[0], str):
            # this text turn was already folded into the preceding image turn
            pass
        elif isinstance(history[i-1][0], str) and isinstance(msg[0], str):  # text-only turn
            messages.append({"role": "user", "content": [{"type": "text", "text": msg[0]}]})
            messages.append({"role": "assistant", "content": [{"type": "text", "text": msg[1]}]})

    # add current message
    if len(message["files"]) == 1:
        
        if isinstance(message["files"][0], str): # examples
            image = Image.open(message["files"][0]).convert("RGB")
        else: # regular input
            image = Image.open(message["files"][0]["path"]).convert("RGB")
        images.append(image)
        messages.append({"role": "user", "content": [{"type": "text", "text": txt}, {"type": "image"}]})
    else:
        messages.append({"role": "user", "content": [{"type": "text", "text": txt}]})


    # Render the chat template to a prompt string, then tokenize (with pixel inputs when images exist)
    texts = processor.apply_chat_template(messages, add_generation_prompt=True)

    if not images:
        inputs = processor(text=texts, return_tensors="pt").to("cuda")
    else:
        inputs = processor(text=texts, images=images, return_tensors="pt").to("cuda")
    streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)

    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
    
    # Run generate() on a background thread; the streamer below is consumed here as tokens are produced
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    buffer = ""
    
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)  # small pause so UI updates arrive at a steady rate
        yield buffer


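# bot_streaming is a generator: gr.ChatInterface consumes each yielded string and redraws the
# assistant message with it, which produces the incremental, token-by-token display.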
demo = gr.ChatInterface(
    fn=bot_streaming,
    title="Multimodal Gemma 3 Model by Google",
    textbox=gr.MultimodalTextbox(),
    additional_inputs=[
        gr.Slider(
            minimum=10,
            maximum=3000,
            value=500,
            step=10,
            label="Maximum number of new tokens to generate",
        )
    ],
    cache_examples=False,
    description="Upload an image and start chatting about it, or just enter any text into the prompt to start.",
    stop_btn="Stop Generation",
    fill_height=True,
    multimodal=True,
)

demo.launch(debug=True, share=True)
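# To run outside of Spaces (assuming a CUDA GPU and a Hugging Face token exported as the 'gemma'
# environment variable), save this file as e.g. app.py and start it with: python app.py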