ruslanmv committed on
Commit 35932eb · verified · 1 Parent(s): c1f9bb7

Update src/app.py

Files changed (1)
  1. src/app.py +251 -285
src/app.py CHANGED
@@ -1,106 +1,122 @@
  """Template Demo for IBM Granite Hugging Face spaces."""
 
- from collections.abc import Iterator
- from datetime import datetime
- from pathlib import Path
- from threading import Thread
-
- import gradio as gr
- import spaces
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
-
- from themes.research_monochrome import theme
-
- # Vision imports
- import random
- from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
-
- today_date = datetime.today().strftime("%B %-d, %Y")  # noqa: DTZ002
-
- SYS_PROMPT = f"""Knowledge Cutoff Date: April 2024.
- Today's Date: {today_date}.
- You are Granite, developed by IBM. You are a helpful AI assistant"""
- TITLE = "IBM Granite 3.1 8b Instruct & Vision Preview"
- DESCRIPTION = """
- <p>Granite 3.1 8b instruct is an open-source LLM supporting a 128k context window. Start with one of the sample prompts
- or upload an image and ask a question. Keep in mind that AI can occasionally make mistakes.
- <span class="gr_docs_link">
- <a href="https://www.ibm.com/granite/docs/">View Documentation <i class="fa fa-external-link"></i></a>
- </span>
- </p>
- """
- MAX_INPUT_TOKEN_LENGTH = 128_000
- MAX_NEW_TOKENS = 1024
- TEMPERATURE = 0.7
- TOP_P = 0.85
- TOP_K = 50
- REPETITION_PENALTY = 1.05
-
- if not torch.cuda.is_available():
-     print("This demo may not work on CPU.")
-
- # Text Model and Tokenizer
- text_model = AutoModelForCausalLM.from_pretrained(
-     "ibm-granite/granite-3.1-8b-instruct", torch_dtype=torch.float16, device_map="auto"
- )
- text_tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.1-8b-instruct")
- text_tokenizer.use_default_system_prompt = False
-
- # Vision Model and Processor
- vision_model_path = "ibm-granite/granite-vision-3.1-2b-preview"
- vision_processor = LlavaNextProcessor.from_pretrained(vision_model_path, use_fast=True)
- vision_model = LlavaNextForConditionalGeneration.from_pretrained(vision_model_path, torch_dtype="auto", device_map="auto")
-
-
- @spaces.GPU
- def generate(
-     message: str,
-     chat_history: list[dict],
-     temperature: float = TEMPERATURE,
-     repetition_penalty: float = REPETITION_PENALTY,
-     top_p: float = TOP_P,
-     top_k: float = TOP_K,
-     max_new_tokens: int = MAX_NEW_TOKENS,
- ) -> Iterator[str]:
-     """Generate function for text chat demo."""
-     # Build messages
-     conversation = []
-     conversation.append({"role": "system", "content": SYS_PROMPT})
-     conversation += chat_history
-     conversation.append({"role": "user", "content": message})
-
-     # Convert messages to prompt format
-     input_ids = text_tokenizer.apply_chat_template(
-         conversation,
-         return_tensors="pt",
-         add_generation_prompt=True,
-         truncation=True,
-         max_length=MAX_INPUT_TOKEN_LENGTH - max_new_tokens,
-     )
-
-     input_ids = input_ids.to(text_model.device)
-     streamer = TextIteratorStreamer(text_tokenizer, skip_prompt=True, skip_special_tokens=True)
-     generate_kwargs = dict(
-         {"input_ids": input_ids},
-         streamer=streamer,
-         max_new_tokens=max_new_tokens,
-         do_sample=True,
-         top_p=top_p,
-         top_k=top_k,
-         temperature=temperature,
-         num_beams=1,
-         repetition_penalty=repetition_penalty,
-     )
-
-     t = Thread(target=text_model.generate, kwargs=generate_kwargs)
-     t.start()
-
-     outputs = []
-     for text in streamer:
-         outputs.append(text)
-         yield "".join(outputs)

  def get_text_from_content(content):
      texts = []
      for item in content:
@@ -111,23 +127,17 @@ def get_text_from_content(content):
      return " ".join(texts)
 
  @spaces.GPU
- def chat_inference(image, text, temperature, top_p, top_k, max_tokens, conversation):
      if conversation is None:
          conversation = []
-
      user_content = []
      if image is not None:
          user_content.append({"type": "image", "image": image})
      if text and text.strip():
          user_content.append({"type": "text", "text": text.strip()})
      if not user_content:
-         return conversation_display(conversation), conversation
-
-     conversation.append({
-         "role": "user",
-         "content": user_content
-     })
-
      inputs = vision_processor.apply_chat_template(
          conversation,
          add_generation_prompt=True,
@@ -135,9 +145,7 @@ def chat_inference(image, text, temperature, top_p, top_k, max_tokens, conversat
          return_dict=True,
          return_tensors="pt"
      ).to("cuda")
-
      torch.manual_seed(random.randint(0, 10000))
-
      generation_kwargs = {
          "max_new_tokens": max_tokens,
          "temperature": temperature,
@@ -145,207 +153,165 @@ def chat_inference(image, text, temperature, top_p, top_k, max_tokens, conversat
          "top_k": top_k,
          "do_sample": True,
      }
-
      output = vision_model.generate(**inputs, **generation_kwargs)
      assistant_response = vision_processor.decode(output[0], skip_special_tokens=True)
 
-     conversation.append({
-         "role": "assistant",
-         "content": [{"type": "text", "text": assistant_response.strip()}]
-     })
-
-     return conversation_display(conversation), conversation
-
- def conversation_display(conversation):
      chat_history = []
-     for msg in conversation:
-         if msg["role"] == "user":
-             user_text = get_text_from_content(msg["content"])
-             chat_history.append({"role": "user", "content": user_text})
-         elif msg["role"] == "assistant":
-             assistant_text = msg["content"][0]["text"].split("<|assistant|>")[-1].strip()
-             chat_history.append({"role": "assistant", "content": assistant_text})
      return chat_history
 
  def clear_chat():
-     return [], [], "", None
 
  css_file_path = Path(Path(__file__).parent / "app.css")
  head_file_path = Path(Path(__file__).parent / "app_head.html")
 
- # Advanced settings (displayed in Accordion) - Common settings for both models
- temperature_slider = gr.Slider(
-     minimum=0, maximum=1.0, value=TEMPERATURE, step=0.1, label="Temperature", elem_classes=["gr_accordion_element"]
- )
- top_p_slider = gr.Slider(
-     minimum=0, maximum=1.0, value=TOP_P, step=0.05, label="Top P", elem_classes=["gr_accordion_element"]
- )
- top_k_slider = gr.Slider(
-     minimum=0, maximum=100, value=TOP_K, step=1, label="Top K", elem_classes=["gr_accordion_element"]
- )
-
- # Advanced settings specific to Text model
- repetition_penalty_slider = gr.Slider(
-     minimum=0,
-     maximum=2.0,
-     value=REPETITION_PENALTY,
-     step=0.05,
-     label="Repetition Penalty (Text Model)",
-     elem_classes=["gr_accordion_element"],
- )
- max_new_tokens_slider = gr.Slider(
-     minimum=1,
-     maximum=2000,
-     value=MAX_NEW_TOKENS,
-     step=1,
-     label="Max New Tokens (Text Model)",
-     elem_classes=["gr_accordion_element"],
- )
-
- # Advanced settings specific to Vision model
- max_tokens_slider_vision = gr.Slider(
-     minimum=10,
-     maximum=300,
-     value=128,
-     step=1,
-     label="Max Tokens (Vision Model)",
-     elem_classes=["gr_accordion_element"],
- )
-
- chat_interface_accordion = gr.Accordion(label="Advanced Settings", open=False)
-
  with gr.Blocks(fill_height=True, css_paths=css_file_path, head_paths=head_file_path, theme=theme, title=TITLE) as demo:
      gr.HTML(f"<h1>{TITLE}</h1>", elem_classes=["gr_title"])
      gr.HTML(DESCRIPTION)
-
-     state = gr.State([])  # State for vision chat history
-     chat_history_state = gr.State([])  # State for text chat history
-
      with gr.Row():
          with gr.Column(scale=2):
              image_input = gr.Image(type="pil", label="Upload Image (optional)")
-             with gr.Accordion(label="Vision Model Settings", open=False):
-                 max_tokens_input_vision = max_tokens_slider_vision
-             with gr.Accordion(label="Text Model Settings", open=False):
-                 repetition_penalty_input = repetition_penalty_slider
-                 max_new_tokens_input = max_new_tokens_slider
-             with chat_interface_accordion:  # Common Settings
-                 temperature_input = temperature_slider
-                 top_p_input = top_p_slider
-                 top_k_input = top_k_slider
-
-         with gr.Column(scale=3):
-             chatbot = gr.Chatbot(label="Chat History", elem_id="chatbot", type='messages')
              text_input = gr.Textbox(lines=2, placeholder="Enter your message here", label="Message")
-             with gr.Row():
-                 send_button = gr.Button("Chat")
-                 clear_button = gr.Button("Clear Chat")
-
-     def process_chat(image_input, text_input, temperature_input, top_p_input, top_k_input, repetition_penalty_input, max_new_tokens_input, max_tokens_input_vision, state, chat_history_state):
-         if image_input:
-             # Use Vision model
-             return chat_inference(image_input, text_input, temperature_input, top_p_input, top_k_input, max_tokens_input_vision, state)
-         else:
-             # Use Text model
-             return generate(text_input, chat_history_state, temperature_input, repetition_penalty_input, top_p_input, top_k_input, max_new_tokens_input), None  # Return None for state as text model doesn't use it
-
-     def process_chat_wrapper(image_input_val, text_input_val, temperature_input_val, top_p_input_val, top_k_input_val, repetition_penalty_input_val, max_new_tokens_input_val, max_tokens_input_vision_val, state_val, chat_history_state_val):
-         if image_input_val:
-             chatbot_output, updated_state = process_chat(image_input_val, text_input_val, temperature_input_val, top_p_input_val, top_k_input_val, repetition_penalty_input_val, max_new_tokens_input_val, max_tokens_input_vision_val, state_val, chat_history_state_val)
-             return chatbot_output, updated_state, chat_history_state_val  # Return vision state and keep text state unchanged
-         else:
-             chatbot_output_generator, _ = process_chat(image_input_val, text_input_val, temperature_input_val, top_p_input_val, top_k_input_val, repetition_penalty_input_val, max_new_tokens_input_val, max_tokens_input_vision_val, state_val, chat_history_state_val)
-             updated_chat_history = []
-             full_response = ""
-             for response_chunk in chatbot_output_generator:
-                 full_response = response_chunk
-             if chat_history_state_val is None:
-                 updated_chat_history = []
-             else:
-                 updated_chat_history = chat_history_state_val
-
-             updated_chat_history.append({"role": "user", "content": text_input_val})
-             updated_chat_history.append({"role": "assistant", "content": full_response})
-
-             return updated_chat_history, state_val, updated_chat_history  # Return text chat history, keep vision state unchanged, return updated text history for chatbot display
-
-
      send_button.click(
-         process_chat_wrapper,
-         inputs=[image_input, text_input, temperature_input, top_p_input, top_k_input, repetition_penalty_input, max_new_tokens_input, max_tokens_input_vision, state, chat_history_state],
-         outputs=[chatbot, state, chat_history_state]  # Keep both states as output
      )
-
      clear_button.click(
          clear_chat,
          inputs=None,
-         outputs=[chatbot, state, text_input, image_input]  # clear_chat clears vision state and input. Need to clear text state also.
      )
-
      gr.Examples(
          examples=[
-             ["Explain the concept of quantum computing to someone with no background in physics or computer science."],
-             ["What is OpenShift?"],
-             ["What's the importance of low latency inference?"],
-             ["Help me boost productivity habits."],
-             [
-                 """Explain the following code in a concise manner:
-
- ```java
- import java.util.ArrayList;
- import java.util.List;
-
- public class Main {
-
-     public static void main(String[] args) {
-         int[] arr = {1, 5, 3, 4, 2};
-         int diff = 3;
-         List<Pair> pairs = findPairs(arr, diff);
-         for (Pair pair : pairs) {
-             System.out.println(pair.x + " " + pair.y);
-         }
-     }
-
-     public static List<Pair> findPairs(int[] arr, int diff) {
-         List<Pair> pairs = new ArrayList<>();
-         for (int i = 0; i < arr.length; i++) {
-             for (int j = i + 1; j < arr.length; j++) {
-                 if (Math.abs(arr[i] - arr[j]) < diff) {
-                     pairs.add(new Pair(arr[i], arr[j]));
-                 }
-             }
-         }
-
-         return pairs;
-     }
- }
-
- class Pair {
-     int x;
-     int y;
-     public Pair(int x, int y) {
-         this.x = x;
-         this.y = y;
-     }
- }
- ```"""
-             ],
-             [
-                 """Generate a Java code block from the following explanation:
-
- The code in the Main class finds all pairs in an array whose absolute difference is less than a given value.
-
- The findPairs method takes two arguments: an array of integers and a difference value. It iterates over the array and compares each element to every other element in the array. If the absolute difference between the two elements is less than the difference value, a new Pair object is created and added to a list.
-
- The Pair class is a simple data structure that stores two integers.
-
- The main method creates an array of integers, initializes the difference value, and calls the findPairs method to find all pairs in the array. Finally, the code iterates over the list of pairs and prints each pair to the console."""  # noqa: E501
-             ],
-             ["https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png", "What is this?"]  # Vision example
          ],
-         inputs=[text_input, text_input, text_input, text_input, text_input, text_input, image_input, image_input],  # Duplicated text_input to match example count, last two are image_input for vision example
-         examples_per_page=7
      )
 
  if __name__ == "__main__":
-     demo.queue().launch()
 
  """Template Demo for IBM Granite Hugging Face spaces."""
 
+ from collections.abc import Iterator
+ from datetime import datetime
+ from pathlib import Path
+ from threading import Thread
+
+ import gradio as gr
+ import spaces
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+ from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
+ import random
+
+ from themes.research_monochrome import theme
+
+ # =============================================================================
+ # Constants & Prompts
+ # =============================================================================
+ today_date = datetime.today().strftime("%B %-d, %Y")  # noqa: DTZ002
+ SYS_PROMPT = f"""Knowledge Cutoff Date: April 2024.
+ Today's Date: {today_date}.
+ You are Granite, developed by IBM. You are a helpful AI assistant"""
+ TITLE = "IBM Granite 3.1 8b Instruct & Vision Preview"
+ DESCRIPTION = """
+ <p>Granite 3.1 8b instruct is an open-source LLM supporting a 128k context window, paired with Granite Vision 3.1 2B Preview for vision-language capabilities. Start with one of the sample prompts
+ or enter your own. Keep in mind that AI can occasionally make mistakes.
+ <span class="gr_docs_link">
+ <a href="https://www.ibm.com/granite/docs/">View Documentation <i class="fa fa-external-link"></i></a>
+ </span>
+ </p>
+ """
+ MAX_INPUT_TOKEN_LENGTH = 128_000
+ MAX_NEW_TOKENS = 1024
+ TEMPERATURE = 0.7
+ TOP_P = 0.85
+ TOP_K = 50
+ REPETITION_PENALTY = 1.05
+
+ # Vision defaults (advanced settings)
+ VISION_TEMPERATURE = 0.2
+ VISION_TOP_P = 0.95
+ VISION_TOP_K = 50
+ VISION_MAX_TOKENS = 128
+
+ if not torch.cuda.is_available():
+     print("This demo may not work on CPU.")
+
+ # =============================================================================
+ # Text Model Loading
+ # =============================================================================
+ text_model = AutoModelForCausalLM.from_pretrained(
+     "ibm-granite/granite-3.1-8b-instruct",
+     torch_dtype=torch.float16,
+     device_map="auto"
+ )
+ tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.1-8b-instruct")
+ tokenizer.use_default_system_prompt = False
+
+ # =============================================================================
+ # Vision Model Loading
+ # =============================================================================
+ vision_model_path = "ibm-granite/granite-vision-3.1-2b-preview"
+ vision_processor = LlavaNextProcessor.from_pretrained(vision_model_path, use_fast=True)
+ vision_model = LlavaNextForConditionalGeneration.from_pretrained(
+     vision_model_path,
+     torch_dtype=torch.float16,
+     device_map="auto",
+     trust_remote_code=True  # Ensure the custom code is used so that weight shapes match.
+ )
 
+ # =============================================================================
+ # Text Generation Function (for text-only chat)
+ # =============================================================================
+ @spaces.GPU
+ def generate(
+     message: str,
+     chat_history: list[dict],
+     temperature: float = TEMPERATURE,
+     repetition_penalty: float = REPETITION_PENALTY,
+     top_p: float = TOP_P,
+     top_k: float = TOP_K,
+     max_new_tokens: int = MAX_NEW_TOKENS,
+ ) -> Iterator[str]:
+     """Generate function for text chat demo."""
+     conversation = []
+     conversation.append({"role": "system", "content": SYS_PROMPT})
+     conversation.extend(chat_history)
+     conversation.append({"role": "user", "content": message})
+     input_ids = tokenizer.apply_chat_template(
+         conversation,
+         return_tensors="pt",
+         add_generation_prompt=True,
+         truncation=True,
+         max_length=MAX_INPUT_TOKEN_LENGTH - max_new_tokens,
+     )
+     input_ids = input_ids.to(text_model.device)
+     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         {"input_ids": input_ids},
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         top_p=top_p,
+         top_k=top_k,
+         temperature=temperature,
+         num_beams=1,
+         repetition_penalty=repetition_penalty,
+     )
+     t = Thread(target=text_model.generate, kwargs=generate_kwargs)
+     t.start()
+     outputs = []
+     for text in streamer:
+         outputs.append(text)
+         yield "".join(outputs)
+
+ # =============================================================================
+ # Vision Chat Inference Function (for image+text chat)
+ # =============================================================================
  def get_text_from_content(content):
      texts = []
      for item in content:
  return " ".join(texts)
128
 
129
  @spaces.GPU
130
+ def chat_inference(image, text, conversation, temperature=VISION_TEMPERATURE, top_p=VISION_TOP_P, top_k=VISION_TOP_K, max_tokens=VISION_MAX_TOKENS):
131
  if conversation is None:
132
  conversation = []
 
133
  user_content = []
134
  if image is not None:
135
  user_content.append({"type": "image", "image": image})
136
  if text and text.strip():
137
  user_content.append({"type": "text", "text": text.strip()})
138
  if not user_content:
139
+ return display_vision_conversation(conversation), conversation
140
+ conversation.append({"role": "user", "content": user_content})
 
 
 
 
 
141
  inputs = vision_processor.apply_chat_template(
142
  conversation,
143
  add_generation_prompt=True,
 
145
  return_dict=True,
146
  return_tensors="pt"
147
  ).to("cuda")
 
148
  torch.manual_seed(random.randint(0, 10000))
 
149
  generation_kwargs = {
150
  "max_new_tokens": max_tokens,
151
  "temperature": temperature,
 
153
  "top_k": top_k,
154
  "do_sample": True,
155
  }
 
156
  output = vision_model.generate(**inputs, **generation_kwargs)
157
  assistant_response = vision_processor.decode(output[0], skip_special_tokens=True)
158
+ conversation.append({"role": "assistant", "content": [{"type": "text", "text": assistant_response.strip()}]})
159
+ return display_vision_conversation(conversation), conversation
160
+
161
+ # =============================================================================
162
+ # Helper Functions to Format Conversation for Display
163
+ # =============================================================================
164
+ def display_text_conversation(conversation):
165
+ """Convert a text conversation (list of dicts) into a list of (user, assistant) tuples."""
166
+ chat_history = []
167
+ i = 0
168
+ while i < len(conversation):
169
+ if conversation[i]["role"] == "user":
170
+ user_msg = conversation[i]["content"]
171
+ assistant_msg = ""
172
+ if i + 1 < len(conversation) and conversation[i+1]["role"] == "assistant":
173
+ assistant_msg = conversation[i+1]["content"]
174
+ i += 2
175
+ else:
176
+ i += 1
177
+ chat_history.append((user_msg, assistant_msg))
178
+ else:
179
+ i += 1
180
+ return chat_history
 
+ def display_vision_conversation(conversation):
+     """Convert a vision conversation (with mixed content types) into a list of (user, assistant) tuples."""
      chat_history = []
+     i = 0
+     while i < len(conversation):
+         if conversation[i]["role"] == "user":
+             user_msg = get_text_from_content(conversation[i]["content"])
+             assistant_msg = ""
+             if i + 1 < len(conversation) and conversation[i+1]["role"] == "assistant":
+                 # Extract assistant text; remove any special tokens if present.
+                 assistant_msg = conversation[i+1]["content"][0]["text"].split("<|assistant|>")[-1].strip()
+                 i += 2
+             else:
+                 i += 1
+             chat_history.append((user_msg, assistant_msg))
+         else:
+             i += 1
      return chat_history
 
+ # =============================================================================
+ # Unified Send-Message Function
+ # =============================================================================
+ def send_message(image, text,
+                  text_temperature, text_repetition_penalty, text_top_p, text_top_k, text_max_new_tokens,
+                  vision_temperature, vision_top_p, vision_top_k, vision_max_tokens,
+                  text_state, vision_state):
+     """
+     If an image is uploaded, use the vision model; otherwise, use the text model.
+     Returns updated conversation (as a list of tuples) and state for each branch.
+     """
+     if image is not None:
+         # Vision branch
+         conv = vision_state if vision_state is not None else []
+         chat_history, updated_conv = chat_inference(
+             image, text, conv,
+             temperature=vision_temperature,
+             top_p=vision_top_p,
+             top_k=vision_top_k,
+             max_tokens=vision_max_tokens
+         )
+         vision_state = updated_conv
+         # In vision mode, the conversation display is produced from the vision branch.
+         return chat_history, text_state, vision_state
+     else:
+         # Text branch
+         conv = text_state if text_state is not None else []
+         output_text = ""
+         for chunk in generate(
+             text, conv,
+             temperature=text_temperature,
+             repetition_penalty=text_repetition_penalty,
+             top_p=text_top_p,
+             top_k=text_top_k,
+             max_new_tokens=text_max_new_tokens
+         ):
+             output_text = chunk
+         conv.append({"role": "user", "content": text})
+         conv.append({"role": "assistant", "content": output_text})
+         text_state = conv
+         chat_history = display_text_conversation(text_state)
+         return chat_history, text_state, vision_state
+
  def clear_chat():
+     # Clear the conversation display, both conversation states, and the input fields.
+     return [], [], [], "", None  # (chatbot, text_state, vision_state, text_input, image_input)
 
+ # =============================================================================
+ # UI Layout with Gradio
+ # =============================================================================
  css_file_path = Path(Path(__file__).parent / "app.css")
  head_file_path = Path(Path(__file__).parent / "app_head.html")
 
  with gr.Blocks(fill_height=True, css_paths=css_file_path, head_paths=head_file_path, theme=theme, title=TITLE) as demo:
      gr.HTML(f"<h1>{TITLE}</h1>", elem_classes=["gr_title"])
      gr.HTML(DESCRIPTION)
+
+     chatbot = gr.Chatbot(label="Chat History", height=500)
+
      with gr.Row():
          with gr.Column(scale=2):
              image_input = gr.Image(type="pil", label="Upload Image (optional)")
              text_input = gr.Textbox(lines=2, placeholder="Enter your message here", label="Message")
+         with gr.Column(scale=1):
+             with gr.Accordion("Text Advanced Settings", open=False):
+                 text_temperature_slider = gr.Slider(minimum=0, maximum=1.0, value=TEMPERATURE, step=0.1, label="Temperature", elem_classes=["gr_accordion_element"])
+                 repetition_penalty_slider = gr.Slider(minimum=0, maximum=2.0, value=REPETITION_PENALTY, step=0.05, label="Repetition Penalty", elem_classes=["gr_accordion_element"])
+                 top_p_slider = gr.Slider(minimum=0, maximum=1.0, value=TOP_P, step=0.05, label="Top P", elem_classes=["gr_accordion_element"])
+                 top_k_slider = gr.Slider(minimum=0, maximum=100, value=TOP_K, step=1, label="Top K", elem_classes=["gr_accordion_element"])
+                 max_new_tokens_slider = gr.Slider(minimum=1, maximum=2000, value=MAX_NEW_TOKENS, step=1, label="Max New Tokens", elem_classes=["gr_accordion_element"])
+             with gr.Accordion("Vision Advanced Settings", open=False):
+                 vision_temperature_slider = gr.Slider(minimum=0.0, maximum=2.0, value=VISION_TEMPERATURE, step=0.01, label="Vision Temperature", elem_classes=["gr_accordion_element"])
+                 vision_top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=VISION_TOP_P, step=0.01, label="Vision Top p", elem_classes=["gr_accordion_element"])
+                 vision_top_k_slider = gr.Slider(minimum=0, maximum=100, value=VISION_TOP_K, step=1, label="Vision Top k", elem_classes=["gr_accordion_element"])
+                 vision_max_tokens_slider = gr.Slider(minimum=10, maximum=300, value=VISION_MAX_TOKENS, step=1, label="Vision Max Tokens", elem_classes=["gr_accordion_element"])
+
+     send_button = gr.Button("Send Message")
+     clear_button = gr.Button("Clear Chat")
+
+     # Conversation state variables for each branch.
+     text_state = gr.State([])
+     vision_state = gr.State([])
+
      send_button.click(
+         send_message,
+         inputs=[
+             image_input, text_input,
+             text_temperature_slider, repetition_penalty_slider, top_p_slider, top_k_slider, max_new_tokens_slider,
+             vision_temperature_slider, vision_top_p_slider, vision_top_k_slider, vision_max_tokens_slider,
+             text_state, vision_state
+         ],
+         outputs=[chatbot, text_state, vision_state]
      )
+
      clear_button.click(
          clear_chat,
          inputs=None,
+         outputs=[chatbot, text_state, vision_state, text_input, image_input]
      )
+
      gr.Examples(
          examples=[
+             ["https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png", "What is in this image?"],
+             [None, "Explain quantum computing to a beginner."],
+             [None, "What is OpenShift?"]
+         ],
+         inputs=[image_input, text_input],
+         example_labels=[
+             "Vision Example: What is in this image?",
+             "Explain quantum computing",
+             "What is OpenShift?"
          ],
+         cache_examples=False,
      )
 
  if __name__ == "__main__":
+     demo.queue().launch()
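After this commit, the text and vision branches can also be exercised outside the Gradio UI. A minimal smoke-test sketch, assuming `src/app.py` is importable as `app`, the model downloads succeed, and a GPU is available (the `app` import path and `bus.png` file name are illustrative, not part of the commit):

```python
# Hypothetical smoke test for the two inference branches in src/app.py.
# Importing the module loads both Granite checkpoints, so this needs GPU memory and time.
from PIL import Image

import app  # src/app.py on the import path (illustrative assumption)

# Text branch: generate() streams partial strings; keep the last one.
reply = ""
for partial in app.generate("What is OpenShift?", chat_history=[]):
    reply = partial
print(reply)

# Vision branch: chat_inference() returns (display tuples, updated conversation state).
image = Image.open("bus.png")  # any local test image
history, state = app.chat_inference(image, "What is in this image?", conversation=[])
print(history[-1][1])  # assistant's answer for the latest turn
```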