ruslanmv committed on
Commit ebb25b2 · verified · 1 Parent(s): 15534ac

Update src/app.py

Files changed (1)
  1. src/app.py +83 -52
src/app.py CHANGED
@@ -1,51 +1,29 @@
 """Developed by Ruslan Magana Vsevolodovna"""
-
 from collections.abc import Iterator
 from datetime import datetime
 from pathlib import Path
 from threading import Thread
-
 import gradio as gr
 import spaces
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
 import random
-
 from themes.research_monochrome import theme
 
 # =============================================================================
 # Constants & Prompts
 # =============================================================================
-today_date = datetime.today().strftime("%B %-d, %Y")
-SYS_PROMPT = """
-Today's Date: {today_date}.
-You are a helpful AI assistant.
-Respond in the following format:
-<reasoning>
-...
-</reasoning>
-<answer>
-...
-</answer>
-"""
-
+today_date = datetime.today().strftime("%B %-d, %Y") # noqa: DTZ002
+SYS_PROMPT = f"""Knowledge Cutoff Date: April 2024.Today's Date: {today_date}.You are Granite, developed by IBM. You are a helpful AI assistant. Respond in the following format:<reasoning>Step-by-step reasoning to arrive at the answer.</reasoning><answer>The final answer to the user's query.</answer> If reasoning is not applicable, you can directly provide the <answer>."""
 TITLE = "IBM Granite 3.1 8b Reasoning & Vision Preview"
-DESCRIPTION = """
-<p>Granite 3.1 8b Reasoning is an open‐source LLM supporting a 128k context window and Granite Vision 3.1 2B Preview for vision‐language capabilities. Start with one of the sample prompts
-or enter your own. Keep in mind that AI can occasionally make mistakes.
-<span class="gr_docs_link">
-<a href="https://www.ibm.com/granite/docs/">View Documentation <i class="fa fa-external-link"></i></a>
-</span>
-</p>
-"""
+DESCRIPTION = """<p>Granite 3.1 8b Reasoning is an open‐source LLM supporting a 128k context window and Granite Vision 3.1 2B Preview for vision‐language capabilities. Start with one of the sample promptsor enter your own. Keep in mind that AI can occasionally make mistakes.<span class="gr_docs_link"><a href="https://www.ibm.com/granite/docs/">View Documentation <i class="fa fa-external-link"></i></a></span></p>"""
 MAX_INPUT_TOKEN_LENGTH = 128_000
 MAX_NEW_TOKENS = 1024
 TEMPERATURE = 0.7
 TOP_P = 0.85
 TOP_K = 50
 REPETITION_PENALTY = 1.05
-
 # Vision defaults (advanced settings)
 VISION_TEMPERATURE = 0.2
 VISION_TOP_P = 0.95
@@ -54,18 +32,13 @@ VISION_MAX_TOKENS = 128
 
 if not torch.cuda.is_available():
     print("This demo may not work on CPU.")
-
 # =============================================================================
 # Text Model Loading
 # =============================================================================
-
 #Standard Model
 #granite_text_model="ibm-granite/granite-3.1-8b-instruct"
-
 #With Reasoning
 granite_text_model="ruslanmv/granite-3.1-8b-Reasoning"
-
-
 text_model = AutoModelForCausalLM.from_pretrained(
     granite_text_model,
     torch_dtype=torch.float16,
@@ -73,7 +46,6 @@ text_model = AutoModelForCausalLM.from_pretrained(
 )
 tokenizer = AutoTokenizer.from_pretrained(granite_text_model)
 tokenizer.use_default_system_prompt = False
-
 # =============================================================================
 # Vision Model Loading
 # =============================================================================
@@ -83,9 +55,8 @@ vision_model = LlavaNextForConditionalGeneration.from_pretrained(
     vision_model_path,
     torch_dtype=torch.float16,
     device_map="auto",
-    trust_remote_code=True # Ensure the custom code is used so that weight shapes match.
+    trust_remote_code=True # Ensure the custom code is used so that weight shapes match.)
 )
-
 # =============================================================================
 # Text Generation Function (for text-only chat)
 # =============================================================================
@@ -99,7 +70,7 @@ def generate(
     top_k: float = TOP_K,
     max_new_tokens: int = MAX_NEW_TOKENS,
 ) -> Iterator[str]:
-    """Generate function for text chat demo."""
+    """Generate function for text chat demo with chain of thought display."""
     conversation = []
     conversation.append({"role": "system", "content": SYS_PROMPT})
     conversation.extend(chat_history)
@@ -126,10 +97,45 @@ def generate(
     )
     t = Thread(target=text_model.generate, kwargs=generate_kwargs)
     t.start()
+
     outputs = []
+    reasoning_started = False
+    answer_started = False
+    collected_reasoning = ""
+    collected_answer = ""
+
     for text in streamer:
         outputs.append(text)
-        yield "".join(outputs)
+        current_output = "".join(outputs)
+
+        if "<reasoning>" in current_output and not reasoning_started:
+            reasoning_started = True
+            reasoning_start_index = current_output.find("<reasoning>") + len("<reasoning>")
+            collected_reasoning = current_output[reasoning_start_index:]
+            yield "[Reasoning]: " # Indicate start of reasoning in chatbot
+            outputs = [collected_reasoning] # Reset outputs to only include reasoning part
+
+        elif reasoning_started and "<answer>" in current_output and not answer_started:
+            answer_started = True
+            reasoning_end_index = current_output.find("<answer>")
+            collected_reasoning = current_output[len("<reasoning>"):reasoning_end_index] # Correctly extract reasoning part
+
+            answer_start_index = current_output.find("<answer>") + len("<answer>")
+            collected_answer = current_output[answer_start_index:]
+            yield "\n[Answer]: " # Indicate start of answer in chatbot
+            outputs = [collected_answer] # Reset outputs to only include answer part
+            yield collected_answer # Yield initial part of answer
+
+        elif reasoning_started and not answer_started:
+            collected_reasoning = text # Accumulate reasoning tokens
+            yield text # Stream reasoning tokens
+
+        elif answer_started:
+            collected_answer += text # Accumulate answer tokens
+            yield text # Stream answer tokens
+        else:
+            yield text # In case no tags are found, stream as before
+
 
 # =============================================================================
 # Vision Chat Inference Function (for image+text chat)
@@ -172,7 +178,30 @@ def chat_inference(image, text, conversation, temperature=VISION_TEMPERATURE, to
     }
     output = vision_model.generate(**inputs, **generation_kwargs)
     assistant_response = vision_processor.decode(output[0], skip_special_tokens=True)
-    conversation.append({"role": "assistant", "content": [{"type": "text", "text": assistant_response.strip()}]})
+
+    reasoning = ""
+    answer = ""
+    if "<reasoning>" in assistant_response and "<answer>" in assistant_response:
+        reasoning_start = assistant_response.find("<reasoning>") + len("<reasoning>")
+        reasoning_end = assistant_response.find("</reasoning>")
+        reasoning = assistant_response[reasoning_start:reasoning_end].strip()
+
+        answer_start = assistant_response.find("<answer>") + len("<answer>")
+        answer_end = assistant_response.find("</answer>")
+
+        if answer_end != -1: # Handle cases where answer end tag is present
+            answer = assistant_response[answer_start:answer_end].strip()
+        else: # Fallback if answer end tag is missing (less robust)
+            answer = assistant_response[answer_start:].strip()
+
+
+    formatted_response_content = []
+    if reasoning:
+        formatted_response_content.append({"type": "text", "text": f"[Reasoning]: {reasoning}"})
+    formatted_response_content.append({"type": "text", "text": f"[Answer]: {answer}"})
+
+
+    conversation.append({"role": "assistant", "content": formatted_response_content})
     return display_vision_conversation(conversation), conversation
 
 # =============================================================================
@@ -206,7 +235,12 @@ def display_vision_conversation(conversation):
             assistant_msg = ""
             if i + 1 < len(conversation) and conversation[i+1]["role"] == "assistant":
                 # Extract assistant text; remove any special tokens if present.
-                assistant_msg = conversation[i+1]["content"][0]["text"].split("<|assistant|>")[-1].strip()
+                assistant_content = conversation[i+1]["content"]
+                assistant_text_parts = []
+                for item in assistant_content:
+                    if item["type"] == "text":
+                        assistant_text_parts.append(item["text"])
+                assistant_msg = "\n".join(assistant_text_parts).strip()
                 i += 2
             else:
                 i += 1
@@ -214,7 +248,6 @@
         else:
            i += 1
    return chat_history
-
 # =============================================================================
 # Unified Send-Message Function
 # =============================================================================
@@ -251,29 +284,28 @@ def send_message(image, text,
            top_k=text_top_k,
            max_new_tokens=text_max_new_tokens
        ):
-           output_text = chunk
+           output_text += chunk # Accumulate for display function to process correctly.
+
        conv.append({"role": "user", "content": text})
-       conv.append({"role": "assistant", "content": output_text})
+       conv.append({"role": "assistant", "content": output_text}) # Store full output with tags
        text_state = conv
-       chat_history = display_text_conversation(text_state)
+       chat_history = display_text_conversation(text_state) # Display function handles tag parsing now.
    return chat_history, text_state, vision_state
 
 def clear_chat():
     # Clear the conversation and input fields.
     return [], [], [], None # (chat_history, text_state, vision_state, cleared text and image inputs)
-
 # =============================================================================
 # UI Layout with Gradio
 # =============================================================================
 css_file_path = Path(Path(__file__).parent / "app.css")
 head_file_path = Path(Path(__file__).parent / "app_head.html")
-
 with gr.Blocks(fill_height=True, css_paths=css_file_path, head_paths=head_file_path, theme=theme, title=TITLE) as demo:
     gr.HTML(f"<h1>{TITLE}</h1>", elem_classes=["gr_title"])
     gr.HTML(DESCRIPTION)
-
+
     chatbot = gr.Chatbot(label="Chat History", height=500)
-
+
     with gr.Row():
         with gr.Column(scale=2):
             image_input = gr.Image(type="pil", label="Upload Image (optional)")
@@ -290,14 +322,13 @@ with gr.Blocks(fill_height=True, css_paths=css_file_path, head_paths=head_file_p
             vision_top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=VISION_TOP_P, step=0.01, label="Vision Top p", elem_classes=["gr_accordion_element"])
             vision_top_k_slider = gr.Slider(minimum=0, maximum=100, value=VISION_TOP_K, step=1, label="Vision Top k", elem_classes=["gr_accordion_element"])
             vision_max_tokens_slider = gr.Slider(minimum=10, maximum=300, value=VISION_MAX_TOKENS, step=1, label="Vision Max Tokens", elem_classes=["gr_accordion_element"])
-
-    send_button = gr.Button("Send Message")
+    send_button = gr.Button("Send Message")
     clear_button = gr.Button("Clear Chat")
-
+
     # Conversation state variables for each branch.
    text_state = gr.State([])
    vision_state = gr.State([])
-
+
    send_button.click(
        send_message,
        inputs=[
@@ -308,13 +339,13 @@ with gr.Blocks(fill_height=True, css_paths=css_file_path, head_paths=head_file_p
        ],
        outputs=[chatbot, text_state, vision_state]
    )
-
+
    clear_button.click(
        clear_chat,
        inputs=None,
        outputs=[chatbot, text_state, vision_state, text_input, image_input]
    )
-
+
    gr.Examples(
        examples=[
            ["https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/cheetah1.jpg", "What is in this image?"],
@@ -339,4 +370,4 @@ with gr.Blocks(fill_height=True, css_paths=css_file_path, head_paths=head_file_p
    )
 
 if __name__ == "__main__":
-    demo.queue().launch()
+    demo.queue().launch()