Spaces:
dxdcx
/
Running on CPU Upgrade

dxdcx committed on
Commit
3ecf5ec
·
verified ·
1 Parent(s): 5911439

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +140 -1
app.py CHANGED
@@ -22,7 +22,146 @@ tf_client = InferenceClient(endpoint_url, token=hf_token)
22
 
23
  MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
24
 
25
- # ... [helper functions count_files_in_new_message, count_files_in_history, validate_media_constraints, downsample_video, process_video, process_interleaved_images, process_new_user_message, process_history stay unchanged] ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  def run(message: dict, history: list[dict]) -> Iterator[str]:
28
  if not validate_media_constraints(message, history):
 
22
 
23
  MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
24
 
25
+
26
def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
    """Return (image_count, video_count) for the newly attached file paths.

    A path ending in ".mp4" counts as a video; every other path counts as
    an image.
    """
    video_total = sum(1 for path in paths if path.endswith(".mp4"))
    return len(paths) - video_total, video_total
35
+
36
+
37
def count_files_in_history(history: list[dict]) -> tuple[int, int]:
    """Return (image_count, video_count) for media already in the chat history.

    Only user entries whose content is not a plain string carry a file; the
    first element of such an entry's content is the file path.
    """
    images, videos = 0, 0
    for entry in history:
        content = entry["content"]
        if entry["role"] != "user" or isinstance(content, str):
            continue  # assistant turns and text-only user turns carry no files
        if content[0].endswith(".mp4"):
            videos += 1
        else:
            images += 1
    return images, videos
48
+
49
+
50
def validate_media_constraints(message: dict, history: list[dict]) -> bool:
    """Check that the media in *message* plus *history* obey the app's limits.

    Rules (a gr.Warning is shown and False returned on the first violation):
      * at most one video in the whole conversation;
      * a video may not be mixed with images or with <image> tags;
      * at most MAX_NUM_IMAGES images when there is no video;
      * when <image> tags are used, their count must equal the number of
        newly attached images.
    """
    new_images, new_videos = count_files_in_new_message(message["files"])
    hist_images, hist_videos = count_files_in_history(history)
    total_images = hist_images + new_images
    total_videos = hist_videos + new_videos

    warning = None
    if total_videos > 1:
        warning = "Only one video is supported."
    elif total_videos == 1 and total_images > 0:
        warning = "Mixing images and videos is not allowed."
    elif total_videos == 1 and "<image>" in message["text"]:
        warning = "Using <image> tags with video files is not supported."
    elif total_videos == 0 and total_images > MAX_NUM_IMAGES:
        warning = f"You can upload up to {MAX_NUM_IMAGES} images."
    elif "<image>" in message["text"] and message["text"].count("<image>") != new_images:
        warning = "The number of <image> tags in the text does not match the number of images."

    if warning is not None:
        gr.Warning(warning)
        return False
    return True
72
+
73
+
74
def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
    """Extract up to MAX_NUM_IMAGES evenly spaced frames from a video.

    Returns a list of (PIL image, timestamp-in-seconds) pairs. The capture
    handle is always released, and an unreadable clip (fps or frame count
    reported as 0) yields an empty list instead of raising
    ZeroDivisionError on the timestamp computation.
    """
    vidcap = cv2.VideoCapture(video_path)
    try:
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Broken/unreadable clips report 0 fps and/or 0 frames; bail out
        # early rather than dividing by zero below.
        if fps <= 0 or total_frames <= 0:
            return []

        frame_interval = max(total_frames // MAX_NUM_IMAGES, 1)
        frames: list[tuple[Image.Image, float]] = []

        for i in range(0, min(total_frames, MAX_NUM_IMAGES * frame_interval), frame_interval):
            if len(frames) >= MAX_NUM_IMAGES:
                break

            vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
            success, image = vidcap.read()
            if success:
                # OpenCV decodes as BGR; PIL expects RGB.
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                pil_image = Image.fromarray(image)
                timestamp = round(i / fps, 2)
                frames.append((pil_image, timestamp))

        return frames
    finally:
        # Release the capture even if a read/convert step raises.
        vidcap.release()
96
+
97
+
98
def process_video(video_path: str) -> list[dict]:
    """Expand a video into interleaved text/image content items.

    Each sampled frame is saved to a temporary PNG and emitted as a
    "Frame <timestamp>:" text label followed by an image entry, matching
    the chat-API content format.
    """
    content: list[dict] = []
    for pil_image, timestamp in downsample_video(video_path):
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            pil_image.save(temp_file.name)
            content.append({"type": "text", "text": f"Frame {timestamp}:"})
            content.append({"type": "image", "url": temp_file.name})
    logger.debug(f"{content=}")
    return content
109
+
110
+
111
def process_interleaved_images(message: dict) -> list[dict]:
    """Build chat content from text containing <image> placeholder tags.

    The text is split on literal "<image>" tags; each tag is replaced by
    the next attached file, and the text between tags is kept (stripped).
    Empty fragments produced by the split (re.split with a capture group
    yields "" when the text starts or ends with a tag) are dropped instead
    of being emitted as empty text items.
    """
    logger.debug(f"{message['files']=}")
    parts = re.split(r"(<image>)", message["text"])
    logger.debug(f"{parts=}")

    content = []
    image_index = 0
    for part in parts:
        logger.debug(f"{part=}")
        if part == "<image>":
            content.append({"type": "image", "url": message["files"][image_index]})
            logger.debug(f"file: {message['files'][image_index]}")
            image_index += 1
        elif part.strip():
            content.append({"type": "text", "text": part.strip()})
        elif part:
            # Whitespace-only fragment between tags: keep it verbatim so
            # spacing is preserved, but skip truly empty split artifacts.
            content.append({"type": "text", "text": part})
    logger.debug(f"{content=}")
    return content
130
+
131
+
132
def process_new_user_message(message: dict) -> list[dict]:
    """Convert a newly submitted user message into chat-API content items.

    Dispatch order: plain text when no files are attached; video expansion
    when the first file is an .mp4; interleaved layout when the text uses
    <image> tags; otherwise the text followed by one image entry per file.
    """
    files = message["files"]
    text = message["text"]

    if not files:
        return [{"type": "text", "text": text}]

    if files[0].endswith(".mp4"):
        return [{"type": "text", "text": text}, *process_video(files[0])]

    if "<image>" in text:
        return process_interleaved_images(message)

    image_items = [{"type": "image", "url": path} for path in files]
    return [{"type": "text", "text": text}, *image_items]
146
+
147
+
148
def process_history(history: list[dict]) -> list[dict]:
    """Convert Gradio chat history into the chat-API message format.

    Consecutive user entries (text and file attachments) are merged into a
    single user message; an assistant entry flushes any pending user content
    first. A trailing run of user entries is also flushed after the loop, so
    the last user turn is no longer silently dropped when the history does
    not end with an assistant reply.
    """
    messages = []
    current_user_content: list[dict] = []
    for item in history:
        if item["role"] == "assistant":
            if current_user_content:
                messages.append({"role": "user", "content": current_user_content})
                current_user_content = []
            messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
        else:
            content = item["content"]
            if isinstance(content, str):
                current_user_content.append({"type": "text", "text": content})
            else:
                # Non-string user content is a (path, ...) tuple from a file
                # upload; only the path is forwarded.
                current_user_content.append({"type": "image", "url": content[0]})
    # Flush user content not yet followed by an assistant reply.
    if current_user_content:
        messages.append({"role": "user", "content": current_user_content})
    return messages
164
+
165
 
166
  def run(message: dict, history: list[dict]) -> Iterator[str]:
167
  if not validate_media_constraints(message, history):