Update app.py
app.py
CHANGED
@@ -22,7 +22,146 @@ tf_client = InferenceClient(endpoint_url, token=hf_token)
 
 MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
 
-
+
+def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
+    image_count = 0
+    video_count = 0
+    for path in paths:
+        if path.endswith(".mp4"):
+            video_count += 1
+        else:
+            image_count += 1
+    return image_count, video_count
+
+
+def count_files_in_history(history: list[dict]) -> tuple[int, int]:
+    image_count = 0
+    video_count = 0
+    for item in history:
+        if item["role"] != "user" or isinstance(item["content"], str):
+            continue
+        if item["content"][0].endswith(".mp4"):
+            video_count += 1
+        else:
+            image_count += 1
+    return image_count, video_count
+
+
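Note: both counters key off the ".mp4" suffix alone, and count_files_in_history assumes the Gradio messages-style history where an attached file arrives as its own user turn whose content is a tuple or list holding the file path. A minimal sketch of the shape they expect (the sample data is illustrative, not from the app):

    # Illustrative history: file turns carry a path in a tuple, text turns a plain str.
    history = [
        {"role": "user", "content": ("photos/cat.png",)},
        {"role": "user", "content": "What is in this image?"},
        {"role": "assistant", "content": "A cat."},
    ]
    assert count_files_in_history(history) == (1, 0)  # one image, no videos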
+def validate_media_constraints(message: dict, history: list[dict]) -> bool:
+    new_image_count, new_video_count = count_files_in_new_message(message["files"])
+    history_image_count, history_video_count = count_files_in_history(history)
+    image_count = history_image_count + new_image_count
+    video_count = history_video_count + new_video_count
+    if video_count > 1:
+        gr.Warning("Only one video is supported.")
+        return False
+    if video_count == 1:
+        if image_count > 0:
+            gr.Warning("Mixing images and videos is not allowed.")
+            return False
+        if "<image>" in message["text"]:
+            gr.Warning("Using <image> tags with video files is not supported.")
+            return False
+    if video_count == 0 and image_count > MAX_NUM_IMAGES:
+        gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images.")
+        return False
+    if "<image>" in message["text"] and message["text"].count("<image>") != new_image_count:
+        gr.Warning("The number of <image> tags in the text does not match the number of images.")
+        return False
+    return True
+
+
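The checks run from most to least restrictive: at most one video overall, no mixing of a video with images or <image> tags, at most MAX_NUM_IMAGES images, and the <image> tag count in the new message must match its image count. A quick illustration with made-up inputs:

    # Illustrative only: mixing a video and an image in one message is rejected.
    message = {"text": "Compare these.", "files": ["clip.mp4", "photo.png"]}
    # count_files_in_new_message(message["files"]) -> (1, 1), so with an empty
    # history video_count == 1 and image_count > 0: the function raises the
    # "Mixing images and videos is not allowed." warning and returns False.
    assert validate_media_constraints(message, history=[]) is False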
+def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
+    vidcap = cv2.VideoCapture(video_path)
+    fps = vidcap.get(cv2.CAP_PROP_FPS)
+    total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+    frame_interval = max(total_frames // MAX_NUM_IMAGES, 1)
+    frames: list[tuple[Image.Image, float]] = []
+
+    for i in range(0, min(total_frames, MAX_NUM_IMAGES * frame_interval), frame_interval):
+        if len(frames) >= MAX_NUM_IMAGES:
+            break
+
+        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
+        success, image = vidcap.read()
+        if success:
+            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+            pil_image = Image.fromarray(image)
+            timestamp = round(i / fps, 2)
+            frames.append((pil_image, timestamp))
+
+    vidcap.release()
+    return frames
+
+
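downsample_video samples at most MAX_NUM_IMAGES evenly spaced frames and pairs each with its timestamp in seconds (frame index divided by fps). A worked example of the index arithmetic, using made-up clip parameters:

    # Illustrative only: what a 300-frame, 30 fps clip yields with MAX_NUM_IMAGES = 5.
    total_frames, fps, max_images = 300, 30.0, 5
    frame_interval = max(total_frames // max_images, 1)  # 60
    indices = range(0, min(total_frames, max_images * frame_interval), frame_interval)
    print(list(indices))                         # [0, 60, 120, 180, 240]
    print([round(i / fps, 2) for i in indices])  # [0.0, 2.0, 4.0, 6.0, 8.0]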
+def process_video(video_path: str) -> list[dict]:
+    content = []
+    frames = downsample_video(video_path)
+    for frame in frames:
+        pil_image, timestamp = frame
+        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
+            pil_image.save(temp_file.name)
+            content.append({"type": "text", "text": f"Frame {timestamp}:"})
+            content.append({"type": "image", "url": temp_file.name})
+    logger.debug(f"{content=}")
+    return content
+
+
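Each sampled frame is written to a PNG created with delete=False, so the file outlives the with block and can still be read downstream; the trade-off is that the PNGs linger until the temp directory is cleaned. The returned list alternates caption and image parts, roughly:

    # Illustrative only: shape of process_video's return value for two frames
    # (the temp-file paths are made up).
    content = [
        {"type": "text", "text": "Frame 0.0:"},
        {"type": "image", "url": "/tmp/tmpab12cd34.png"},
        {"type": "text", "text": "Frame 2.0:"},
        {"type": "image", "url": "/tmp/tmpef56gh78.png"},
    ]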
+def process_interleaved_images(message: dict) -> list[dict]:
+    logger.debug(f"{message['files']=}")
+    parts = re.split(r"(<image>)", message["text"])
+    logger.debug(f"{parts=}")
+
+    content = []
+    image_index = 0
+    for part in parts:
+        logger.debug(f"{part=}")
+        if part == "<image>":
+            content.append({"type": "image", "url": message["files"][image_index]})
+            logger.debug(f"file: {message['files'][image_index]}")
+            image_index += 1
+        elif part.strip():
+            content.append({"type": "text", "text": part.strip()})
+        elif isinstance(part, str) and part != "<image>":
+            content.append({"type": "text", "text": part})
+    logger.debug(f"{content=}")
+    return content
+
+
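process_interleaved_images relies on re.split keeping the delimiters because the pattern is a capturing group; note the final elif also passes empty and whitespace-only segments through as text parts. For example:

    import re

    # The capturing group keeps each <image> delimiter in the result.
    parts = re.split(r"(<image>)", "Compare <image> with <image>.")
    print(parts)  # ['Compare ', '<image>', ' with ', '<image>', '.']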
+def process_new_user_message(message: dict) -> list[dict]:
+    if not message["files"]:
+        return [{"type": "text", "text": message["text"]}]
+
+    if message["files"][0].endswith(".mp4"):
+        return [{"type": "text", "text": message["text"]}, *process_video(message["files"][0])]
+
+    if "<image>" in message["text"]:
+        return process_interleaved_images(message)
+
+    return [
+        {"type": "text", "text": message["text"]},
+        *[{"type": "image", "url": path} for path in message["files"]],
+    ]
+
+
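process_new_user_message dispatches on the message shape: no files yields a bare text part, a leading .mp4 routes through process_video, <image> tags route through process_interleaved_images, and anything else becomes the text followed by every image in upload order. For a plain two-image message:

    # Illustrative only (file names are made up):
    message = {"text": "Describe both images.", "files": ["a.png", "b.png"]}
    # process_new_user_message(message) ->
    # [{"type": "text", "text": "Describe both images."},
    #  {"type": "image", "url": "a.png"},
    #  {"type": "image", "url": "b.png"}]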
+def process_history(history: list[dict]) -> list[dict]:
+    messages = []
+    current_user_content: list[dict] = []
+    for item in history:
+        if item["role"] == "assistant":
+            if current_user_content:
+                messages.append({"role": "user", "content": current_user_content})
+                current_user_content = []
+            messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]})
+        else:
+            content = item["content"]
+            if isinstance(content, str):
+                current_user_content.append({"type": "text", "text": content})
+            else:
+                current_user_content.append({"type": "image", "url": content[0]})
+    return messages
+
 
 def run(message: dict, history: list[dict]) -> Iterator[str]:
     if not validate_media_constraints(message, history):
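A note on process_history, which run() consumes: it buffers consecutive user parts and flushes them as a single user message whenever an assistant turn arrives. A trailing user turn with no assistant reply after it is therefore dropped, presumably harmless here since run() handles the new message separately (the rest of run() lies outside this hunk). With the illustrative history from the earlier note:

    history = [
        {"role": "user", "content": ("photos/cat.png",)},
        {"role": "user", "content": "What is in this image?"},
        {"role": "assistant", "content": "A cat."},
    ]
    messages = process_history(history)
    # messages ->
    # [{"role": "user", "content": [{"type": "image", "url": "photos/cat.png"},
    #                               {"type": "text", "text": "What is in this image?"}]},
    #  {"role": "assistant", "content": [{"type": "text", "text": "A cat."}]}]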