Didier committed on
Commit
37f622b
·
verified ·
1 Parent(s): 012d6a6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -87
app.py CHANGED
@@ -5,7 +5,6 @@ Author: Didier Guillevic
5
  Date: 2025-03-16
6
  """
7
 
8
- import spaces
9
  from huggingface_hub import login, whoami
10
  import os
11
  token = os.getenv('HF_TOKEN')
@@ -17,91 +16,13 @@ from transformers import TextIteratorStreamer
17
  from threading import Thread
18
  import torch
19
 
20
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
21
- model_id = "google/gemma-3-4b-it"
22
- processor = AutoProcessor.from_pretrained(model_id, use_fast=True, padding_side="left")
23
- model = Gemma3ForConditionalGeneration.from_pretrained(
24
- model_id,
25
- torch_dtype=torch.bfloat16
26
- ).to(device).eval()
27
 
28
- @torch.inference_mode()
29
- @spaces.GPU
30
- def process(message, history):
31
- """Generate the model response in streaming mode given message and history
32
- """
33
- print(f"{history=}")
34
- # Get the user's text and list of images
35
- user_text = message.get("text", "")
36
- user_images = message.get("files", []) # List of images
37
 
38
- # Build the message list including history
39
- messages = []
40
- combined_user_input = [] # Combine images and text if found in same turn.
41
- for user_turn, bot_turn in history:
42
- if isinstance(user_turn, tuple): # Image input
43
- image_content = [{"type": "image", "url": image_url} for image_url in user_turn]
44
- combined_user_input.extend(image_content)
45
- elif isinstance(user_turn, str): # Text input
46
- combined_user_input.append({"type":"text", "text": user_turn})
47
- if combined_user_input and bot_turn:
48
- messages.append({'role': 'user', 'content': combined_user_input})
49
- messages.append({'role': 'assistant', 'content': [{"type": "text", "text": bot_turn}]})
50
- combined_user_input = [] # reset the combined user input.
51
-
52
- # Build the user message's content from the provided message
53
- user_content = []
54
- if user_text:
55
- user_content.append({"type": "text", "text": user_text})
56
- for image in user_images:
57
- user_content.append({"type": "image", "url": image})
58
-
59
- messages.append({'role': 'user', 'content': user_content})
60
-
61
- # Generate model's response
62
- inputs = processor.apply_chat_template(
63
- messages, add_generation_prompt=True, tokenize=True,
64
- return_dict=True, return_tensors="pt"
65
- ).to(model.device, dtype=torch.bfloat16)
66
-
67
- streamer = TextIteratorStreamer(
68
- processor, skip_prompt=True, skip_special_tokens=True)
69
- generation_kwargs = dict(
70
- inputs,
71
- streamer=streamer,
72
- max_new_tokens=1_024,
73
- do_sample=False
74
- )
75
-
76
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
77
- thread.start()
78
-
79
- partial_message = ""
80
- for new_text in streamer:
81
- partial_message += new_text
82
- yield partial_message
83
-
84
-
85
- #
86
- # User interface
87
- #
88
- with gr.Blocks() as demo:
89
- chat_interface = gr.ChatInterface(
90
- fn=process,
91
- title="Multimedia Chat",
92
- description="Chat with text or text+image.",
93
- multimodal=True,
94
- examples=[
95
- "How can we rationalize quantum entanglement?",
96
- "Peux-tu expliquer le terme 'quantum spin'?",
97
- {'files': ['./sample_ID.jpeg',], 'text': 'Describe this image in a few words.'},
98
- {
99
- 'files': ['./sample_ID.jpeg',],
100
- 'text': (
101
- 'Could you extract the information present in the image '
102
- 'and present it as a bulleted list?')
103
- },
104
- ]
105
- )
106
-
107
- demo.launch()
 
5
  Date: 2025-03-16
6
  """
7
 
 
8
  from huggingface_hub import login, whoami
9
  import os
10
  token = os.getenv('HF_TOKEN')
 
16
  from threading import Thread
17
  import torch
18
 
19
+ from module_chat import demo as chat_block
20
+ from module_translation import demo as translation_block
 
 
 
 
 
21
 
22
+ demo = gr.TabbedInterface(
23
+ interface_list=[chat_block, translation_block],
24
+ tab_names=["Chat", "Translation"],
25
+ title="Chat with a vision language model"
26
+ )
 
 
 
 
27
 
28
+ demo.launch()