nyuuzyou committed
Commit a1d286e · verified · 1 Parent(s): 8640a7d

Update app.py

Files changed (1)
  1. app.py +96 -102
app.py CHANGED
@@ -6,123 +6,118 @@ import time
 import torch
 import spaces
 import subprocess
-subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
-
 from io import BytesIO
 
-processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
-model = AutoModelForImageTextToText.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct",
-        _attn_implementation="flash_attention_2",
-        torch_dtype=torch.bfloat16).to("cuda:0")
+subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+
+processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-256M-Video-Instruct")
+model = AutoModelForImageTextToText.from_pretrained(
+    "HuggingFaceTB/SmolVLM2-256M-Video-Instruct",
+    _attn_implementation="flash_attention_2",
+    torch_dtype=torch.bfloat16
+).to("cuda:0")
 
 
 @spaces.GPU
 def model_inference(
     input_dict, history, max_tokens
-):
-    text = input_dict["text"]
-    images = []
+):
+
+    text = input_dict["text"].strip()
     user_content = []
     media_queue = []
-    if history == []:
-        text = input_dict["text"].strip()
-
-        for file in input_dict.get("files", []):
-            if file.endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")):
-                media_queue.append({"type": "image", "path": file})
-            elif file.endswith((".mp4", ".mov", ".avi", ".mkv", ".flv")):
-                media_queue.append({"type": "video", "path": file})
-
-        if "<image>" in text or "<video>" in text:
-            parts = re.split(r'(<image>|<video>)', text)
-            for part in parts:
-                if part == "<image>" and media_queue:
-                    user_content.append(media_queue.pop(0))
-                elif part == "<video>" and media_queue:
-                    user_content.append(media_queue.pop(0))
-                elif part.strip():
-                    user_content.append({"type": "text", "text": part.strip()})
-        else:
-            user_content.append({"type": "text", "text": text})
-
-        for media in media_queue:
-            user_content.append(media)
-
-        resulting_messages = [{"role": "user", "content": user_content}]
-
-    elif len(history) > 0:
-        resulting_messages = []
-        user_content = []
-        media_queue = []
-        for hist in history:
-            if hist["role"] == "user" and isinstance(hist["content"], tuple):
-                file_name = hist["content"][0]
-                if file_name.endswith((".png", ".jpg", ".jpeg")):
-                    media_queue.append({"type": "image", "path": file_name})
-                elif file_name.endswith(".mp4"):
-                    media_queue.append({"type": "video", "path": file_name})
-
-
-        for hist in history:
-            if hist["role"] == "user" and isinstance(hist["content"], str):
-                text = hist["content"]
-                parts = re.split(r'(<image>|<video>)', text)
-
-                for part in parts:
-                    if part == "<image>" and media_queue:
-                        user_content.append(media_queue.pop(0))
-                    elif part == "<video>" and media_queue:
-                        user_content.append(media_queue.pop(0))
-                    elif part.strip():
-                        user_content.append({"type": "text", "text": part.strip()})
-
-            elif hist["role"] == "assistant":
-                resulting_messages.append({
-                    "role": "user",
-                    "content": user_content
-                })
-                resulting_messages.append({
-                    "role": "assistant",
-                    "content": [{"type": "text", "text": hist["content"]}]
-                })
-                user_content = []
-
-
-    if text == "" and not images:
-        gr.Error("Please input a query and optionally image(s).")
-
-    if text == "" and images:
-        gr.Error("Please input a text query along the images(s).")
+
+    for file_path in input_dict.get("files", []):
+        if file_path.lower().endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")):
+            media_queue.append({"type": "image", "path": file_path})
+        elif file_path.lower().endswith((".mp4", ".mov", ".avi", ".mkv", ".flv")):
+            media_queue.append({"type": "video", "path": file_path})
+
+
+    if not text and not media_queue:
+        gr.Warning("Please input a query and optionally image(s)/video(s).")
+        return
+
+    if not text and media_queue:
+        gr.Warning("Please input a text query along with the image(s)/video(s).")
+        return
+
+
+    if "<image>" in text or "<video>" in text:
+        parts = re.split(r'(<image>|<video>)', text)
+        temp_media_queue = list(media_queue)
+        for part in parts:
+            if part == "<image>" and temp_media_queue:
+                media_item = temp_media_queue.pop(0)
+                if media_item["type"] == "image":
+                    user_content.append(media_item)
+                else:
+                    gr.Warning(f"Placeholder <image> found, but next media is a video: {media_item['path']}. Skipping placeholder.")
+                    user_content.append({"type": "text", "text": part})
+                    temp_media_queue.insert(0, media_item)
+            elif part == "<video>" and temp_media_queue:
+                media_item = temp_media_queue.pop(0)
+                if media_item["type"] == "video":
+                    user_content.append(media_item)
+                else:
+                    gr.Warning(f"Placeholder <video> found, but next media is an image: {media_item['path']}. Skipping placeholder.")
+                    user_content.append({"type": "text", "text": part})
+                    temp_media_queue.insert(0, media_item)
+            elif part.strip():
+                user_content.append({"type": "text", "text": part.strip()})
+            elif part in ["<image>", "<video>"] and not temp_media_queue:
+                gr.Warning(f"Placeholder {part} found, but no more media items available.")
+                user_content.append({"type": "text", "text": part})
+
+        user_content.extend(temp_media_queue)
+
+    else:
+        if text:
+            user_content.append({"type": "text", "text": text})
+        user_content.extend(media_queue)
+
+
+    resulting_messages = [{"role": "user", "content": user_content}]
+
+
     print("resulting_messages", resulting_messages)
-    inputs = processor.apply_chat_template(
-        resulting_messages,
-        add_generation_prompt=True,
-        tokenize=True,
-        return_dict=True,
-        return_tensors="pt",
-    )
+
+    try:
+        inputs = processor.apply_chat_template(
+            resulting_messages,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+        )
+    except Exception as e:
+        gr.Error(f"Error during input processing: {e}")
+        print(f"Processor Error: {e}")
+        print("Problematic message structure:", resulting_messages)
+        return
+
 
     inputs = inputs.to(model.device)
-
 
-    # Generate
-    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+
+    if "pixel_values" in inputs:
+        inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
+
+
+    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
     generation_args = dict(inputs, streamer=streamer, max_new_tokens=max_tokens)
     generated_text = ""
 
     thread = Thread(target=model.generate, kwargs=generation_args)
     thread.start()
 
-    yield "..."
     buffer = ""
-
-
+
     for new_text in streamer:
-
-        buffer += new_text
-        generated_text_without_prompt = buffer#[len(ext_buffer):]
-        time.sleep(0.01)
-        yield buffer
+        buffer += new_text
+        yield buffer
+
+    thread.join()
 
 
 examples=[
@@ -133,16 +128,15 @@ examples=[
     [{"text": "What is the date in this document?", "files": ["example_images/document.jpg"]}],
     [{"text": "What is happening in the video?", "files": ["example_images/short.mp4"]}],
 ]
-demo = gr.ChatInterface(fn=model_inference, title="SmolVLM2: The Smollest Video Model Ever 📺",
-        description="Play with [SmolVLM2-2.2B-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-2.2B-Instruct) in this demo. To get started, upload an image and text or try one of the examples. This demo doesn't use history for the chat, so every chat you start is a new conversation.",
+demo = gr.ChatInterface(fn=model_inference, title="SmolVLM2: The Smollest Video Model Ever 📺",
+        description="Play with [SmolVLM2-256M-Video-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM2-256M-Video-Instruct) in this demo. To get started, upload an image/video and text or try one of the examples. This demo doesn't use history for the chat, so every chat you start is a new conversation.",
         examples=examples,
         textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", ".mp4"], file_count="multiple"), stop_btn="Stop Generation", multimodal=True,
        cache_examples=False,
        additional_inputs=[gr.Slider(minimum=100, maximum=500, step=50, value=200, label="Max Tokens")],
        type="messages"
 )
-
-
 
-demo.launch(debug=True)
-
+
+
+demo.launch(debug=True)
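
Note on the streaming setup: the updated handler runs model.generate on a background thread and reads decoded chunks from a TextIteratorStreamer, yielding the growing buffer back to Gradio. Below is a minimal, self-contained sketch of that same pattern with a plain text-only model; the checkpoint name and prompt are illustrative placeholders, not part of this Space.

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Illustrative checkpoint; any causal LM supported by transformers works the same way.
model_id = "HuggingFaceTB/SmolLM2-135M-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Describe a sunset in one sentence.", return_tensors="pt")

# The streamer yields decoded text chunks as generate() produces tokens on the worker thread.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
thread = Thread(target=model.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=64))
thread.start()

buffer = ""
for new_text in streamer:    # blocks until the next decoded chunk is available
    buffer += new_text
    print(buffer, end="\r")  # the Space yields buffer here so Gradio re-renders the partial reply

thread.join()  # make sure generation has finished before exiting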