KingNish committed
Commit d680493 · verified · 1 Parent(s): 57b80dc

Update app.py

Files changed (1)
  1. app.py +161 -67
app.py CHANGED
@@ -1,37 +1,53 @@
-import gradio as gr
-import spaces
-from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
-from qwen_vl_utils import process_vision_info
-import torch
-from PIL import Image
-import subprocess
-import numpy as np
+# Standard library imports
 import os
-from threading import Thread
+from datetime import datetime
+import subprocess
+import time
 import uuid
 import io
+from threading import Thread
 
-# Model and Processor Loading (Done once at startup)
-MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct"
-model = Qwen2VLForConditionalGeneration.from_pretrained(
-    MODEL_ID,
-    trust_remote_code=True,
-    torch_dtype=torch.float16
-).to("cuda").eval()
-processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
+# Third-party imports
+import numpy as np
+import torch
+from PIL import Image
+import accelerate
+import gradio as gr
+import spaces
+from transformers import (
+    Qwen2_5_VLForConditionalGeneration,
+    AutoTokenizer,
+    AutoProcessor,
+    TextIteratorStreamer
+)
+
+# Local imports
+from qwen_vl_utils import process_vision_info
 
-DESCRIPTION = "[Qwen2-VL-7B Demo](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct)"
+# Set up device-agnostic code
+if torch.cuda.is_available():
+    device = "cuda"
+elif (torch.backends.mps.is_available()) and (torch.backends.mps.is_built()):
+    device = "mps"
+else:
+    device = "cpu"
 
+print(f"[INFO] Using device: {device}")
+
+# Define supported media extensions
 image_extensions = Image.registered_extensions()
-video_extensions = ("avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg", "wav", "gif", "webm", "m4v", "3gp")
+video_extensions = (".avi", ".mp4", ".mov", ".mkv", ".flv", ".wmv", ".mjpeg", ".gif", ".webm", ".m4v", ".3gp")  # Removed .wav (audio, not video); leading dots so os.path.splitext() results match
 
 
 def identify_and_save_blob(blob_path):
-    """Identifies if the blob is an image or video and saves it accordingly."""
+    """
+    Identifies if the blob is an image or video and saves it with a unique name.
+    Returns the saved file path and its media type ("image" or "video").
+    """
     try:
         with open(blob_path, 'rb') as file:
             blob_content = file.read()
-
+
         # Try to identify if it's an image
         try:
             Image.open(io.BytesIO(blob_content)).verify()  # Check if it's a valid image
@@ -39,106 +55,184 @@ def identify_and_save_blob(blob_path):
             media_type = "image"
         except (IOError, SyntaxError):
             # If it's not a valid image, assume it's a video
-            extension = ".mp4"  # Default to MP4 for saving
+            # We can try to get the actual extension from the blob_path,
+            # but for unknown types, MP4 is a good default.
+            _, ext = os.path.splitext(blob_path)
+            if ext.lower() in video_extensions:
+                extension = ext.lower()
+            else:
+                extension = ".mp4"  # Default to MP4 for saving
             media_type = "video"
-
+
         # Create a unique filename
         filename = f"temp_{uuid.uuid4()}_media{extension}"
         with open(filename, "wb") as f:
             f.write(blob_content)
-
+
        return filename, media_type
-
+
     except FileNotFoundError:
         raise ValueError(f"The file {blob_path} was not found.")
     except Exception as e:
         raise ValueError(f"An error occurred while processing the file: {e}")
 
 
+# Model and Processor Loading
+# Define models and processors as dictionaries for easy selection
+models = {
+    "Qwen/Qwen2.5-VL-7B-Instruct": Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct",
+                                                                                      trust_remote_code=True,
+                                                                                      torch_dtype="auto",
+                                                                                      device_map="auto").eval(),
+    "Qwen/Qwen2.5-VL-3B-Instruct": Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct",
+                                                                                      trust_remote_code=True,
+                                                                                      torch_dtype="auto",
+                                                                                      device_map="auto").eval()
+}
+
+processors = {
+    "Qwen/Qwen2.5-VL-7B-Instruct": AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", trust_remote_code=True),
+    "Qwen/Qwen2.5-VL-3B-Instruct": AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", trust_remote_code=True)
+}
+
+DESCRIPTION = "[Qwen2.5-VL Demo](https://huggingface.co/collections/Qwen/qwen25-vl-6795ffac22b334a837c0f9a5)"
+
 @spaces.GPU
-def qwen_inference(media_input, text_input=None):
-    if isinstance(media_input, str):  # If it's a filepath
-        media_path = media_input
-        if media_path.endswith(tuple([i for i, f in image_extensions.items()])):
+def run_example(media_input, text_input=None, model_id=None):
+    if media_input is None:
+        raise gr.Error("No media provided. Please upload an image or video before submitting.")
+    if model_id is None:
+        raise gr.Error("No model selected. Please select a model.")
+
+    start_time = time.time()
+
+    media_path = None
+    media_type = None
+
+    # Determine if it's an image (numpy array from gr.Image) or a file (from gr.File)
+    if isinstance(media_input, np.ndarray):  # This comes from gr.Image
+        img = Image.fromarray(np.uint8(media_input))
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        filename = f"image_{timestamp}.png"
+        img.save(filename)
+        media_path = os.path.abspath(filename)
+        media_type = "image"
+    elif isinstance(media_input, str):  # This comes from gr.File (filepath)
+        path = media_input
+        _, ext = os.path.splitext(path)
+        ext = ext.lower()
+
+        if ext in image_extensions:
+            media_path = path
             media_type = "image"
-        elif media_path.endswith(video_extensions):
+        elif ext in video_extensions:
+            media_path = path
             media_type = "video"
         else:
+            # For blobs or unknown file types, try to identify
             try:
-                media_path, media_type = identify_and_save_blob(media_input)
-                print(media_path, media_type)
+                media_path, media_type = identify_and_save_blob(path)
+                print(f"Identified blob as: {media_type}, saved to: {media_path}")
             except Exception as e:
-                print(e)
-                raise ValueError(
-                    "Unsupported media type. Please upload an image or video."
-                )
-
-
-    print(media_path)
-
-    messages = [
-        {
-            "role": "user",
-            "content": [
-                {
-                    "type": media_type,
-                    media_type: media_path,
-                    **({"fps": 8.0} if media_type == "video" else {}),
-                },
-                {"type": "text", "text": text_input},
-            ],
-        }
-    ]
+                print(f"Error identifying blob: {e}")
+                raise gr.Error("Unsupported media type. Please upload an image (PNG, JPG, etc.) or a video (MP4, AVI, etc.).")
+    else:
+        raise gr.Error("Unsupported input type for media. Please upload an image or video.")
+
+    print(f"[INFO] Processing {media_type} from {media_path}")
+
+    model = models[model_id]
+    processor = processors[model_id]
 
+    # Construct messages list based on media type
+    content_list = []
+    if media_type == "image":
+        content_list.append({"type": "image", "image": media_path})
+    elif media_type == "video":
+        content_list.append({"type": "video", "video": media_path, "fps": 8.0})  # Qwen2.5-VL often uses 8 fps
+
+    if text_input:
+        content_list.append({"type": "text", "text": text_input})
+    else:
+        # Default prompt if no text_input is provided
+        content_list.append({"type": "text", "text": "What is in this image/video?"})
+
+
+    messages = [{"role": "user", "content": content_list}]
+
+    # Preparation for inference
     text = processor.apply_chat_template(
         messages, tokenize=False, add_generation_prompt=True
     )
-    image_inputs, video_inputs = process_vision_info(messages)
+    image_inputs, video_inputs = process_vision_info(messages)  # This utility handles both image and video info
     inputs = processor(
         text=[text],
         images=image_inputs,
         videos=video_inputs,
         padding=True,
         return_tensors="pt",
-    ).to("cuda")
+    ).to(device)
 
+    # Inference: Generation of the output using streaming
     streamer = TextIteratorStreamer(
         processor, skip_prompt=True, **{"skip_special_tokens": True}
     )
     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
 
+    # Start generation in a separate thread to allow streaming
     thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
 
     buffer = ""
     for new_text in streamer:
         buffer += new_text
-        yield buffer
+        yield buffer, None  # Yield partial text, and None for the time taken until generation finishes
+    # Clean up the temporary file after it's processed (optional, depends on use case)
+    # if media_path and os.path.exists(media_path) and "temp_" in os.path.basename(media_path):
+    #     os.remove(media_path)
+
+
+    end_time = time.time()
+    total_time = round(end_time - start_time, 2)
+
+    # Final yield with total time
+    yield buffer, f"{total_time} seconds"
+
+    # Clean up the temporary file after it's fully processed
+    if media_path and os.path.exists(media_path) and "temp_" in os.path.basename(media_path):
+        os.remove(media_path)
+        print(f"[INFO] Cleaned up temporary file: {media_path}")
+
 
 css = """
 #output {
-    height: 500px;
-    overflow: auto;
-    border: 1px solid #ccc;
+    height: 500px;
+    overflow: auto;
+    border: 1px solid #ccc;
 }
 """
 
 with gr.Blocks(css=css) as demo:
     gr.Markdown(DESCRIPTION)
-
-    with gr.Tab(label="Image/Video Input"):
+    with gr.Tab(label="Qwen2.5-VL Input"):
         with gr.Row():
             with gr.Column():
+                # Change input to gr.File to accept both image and video
                 input_media = gr.File(
-                    label="Upload Image or Video", type="filepath"
+                    label="Upload Image or Video (JPG, PNG, MP4, AVI, etc.)",
+                    type="filepath"  # Use 'filepath' to get the path to the temp file
                 )
-                text_input = gr.Textbox(label="Question")
+                model_selector = gr.Dropdown(choices=list(models.keys()),
+                                             label="Model",
+                                             value="Qwen/Qwen2.5-VL-7B-Instruct")
+                text_input = gr.Textbox(label="Text Prompt")
                 submit_btn = gr.Button(value="Submit")
             with gr.Column():
-                output_text = gr.Textbox(label="Output Text")
+                output_text = gr.Textbox(label="Output Text", interactive=False)
+                time_taken = gr.Textbox(label="Time taken for processing + inference", interactive=False)
 
-    submit_btn.click(
-        qwen_inference, [input_media, text_input], [output_text]
-    )
+    submit_btn.click(run_example,
+                     [input_media, text_input, model_selector],
+                     [output_text, time_taken])  # Ensure output components match yield order
 
 demo.launch(debug=True)
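
Since run_example is now a generator yielding (partial_text, time_taken) pairs, Gradio streams each partial result into output_text and fills time_taken only on the final yield; the same generator can be driven outside the UI. A minimal sketch, under the assumptions stated in the comments:

# Hypothetical local smoke test, not part of this commit. Assumes the file is
# saved as app.py with the final demo.launch(debug=True) line commented out
# (there is no __main__ guard, so importing would otherwise start the UI),
# a sample.mp4 in the working directory, and that @spaces.GPU is a no-op
# when running outside HF Spaces.
from app import run_example  # importing app.py loads both models

printed = ""
for partial_text, elapsed in run_example("sample.mp4", "Describe this video.", "Qwen/Qwen2.5-VL-3B-Instruct"):
    print(partial_text[len(printed):], end="", flush=True)  # print only the newly streamed chunk
    printed = partial_text
    if elapsed is not None:  # the final yield carries the total wall-clock time
        print(f"\n[INFO] Total: {elapsed}")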