Spaces:
Running
New features!
Browse filesBig changes to the code!
Add a functionality to save and include some outputs to the zip file:
- Modified the `predict` method in the `Predictor` class to create and save:
- A `.txt` file containing "Output (string)" as well the "Categorized Output (string)".
- A `categorized_tag.json` file containing "Categorized (tags)".
- A copy of the uploaded image(s) in PNG format.
- Updated the `create_file` method to handle both text and JSON files.
- Ensured that all the created files are included in the zip file during the download process.
Also:
- Fixed the images grid position by creating a custom CSS, so it shows all the images correctly.
- Used a minify to clean some parts of the code without changing its functionality.
- Removed unnecessary comments and duplicate code.
@@ -1,56 +1,30 @@
|
|
1 |
import os
|
2 |
-
|
3 |
-
import
|
4 |
-
import copy
|
5 |
-
import requests
|
6 |
-
import numpy as np
|
7 |
-
import spaces
|
8 |
-
import gradio as gr
|
9 |
-
from transformers import AutoProcessor, AutoModelForCausalLM
|
10 |
-
from transformers import AutoModelForCausalLM, AutoProcessor
|
11 |
from transformers.dynamic_module_utils import get_imports
|
12 |
-
from PIL import Image,
|
13 |
-
import matplotlib.pyplot as plt
|
14 |
-
import matplotlib.patches as patches
|
15 |
from unittest.mock import patch
|
16 |
-
|
17 |
-
import
|
18 |
-
import huggingface_hub
|
19 |
-
import onnxruntime as rt
|
20 |
-
import pandas as pd
|
21 |
-
import traceback
|
22 |
-
import tempfile
|
23 |
-
import zipfile
|
24 |
-
import re
|
25 |
-
import ast
|
26 |
-
import time
|
27 |
-
from datetime import datetime, timezone
|
28 |
from collections import defaultdict
|
29 |
from classifyTags import classify_tags
|
30 |
-
# Add scheduler code here
|
31 |
from apscheduler.schedulers.background import BackgroundScheduler
|
32 |
-
|
33 |
-
os.environ[
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
imports.remove("flash_attn")
|
41 |
-
return imports
|
42 |
-
|
43 |
@spaces.GPU
|
44 |
def get_device_type():
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
return "mps"
|
51 |
-
else:
|
52 |
-
return "cpu"
|
53 |
-
|
54 |
model_id = 'MiaoshouAI/Florence-2-base-PromptGen-v2.0'
|
55 |
|
56 |
import subprocess
|
@@ -61,7 +35,6 @@ if (device == "cuda"):
|
|
61 |
processor = AutoProcessor.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v2.0", trust_remote_code=True)
|
62 |
model.to(device)
|
63 |
else:
|
64 |
-
#https://huggingface.co/microsoft/Florence-2-base-ft/discussions/4
|
65 |
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
|
66 |
model = AutoModelForCausalLM.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v2.0", trust_remote_code=True)
|
67 |
processor = AutoProcessor.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v2.0", trust_remote_code=True)
|
@@ -75,13 +48,12 @@ Features:
|
|
75 |
- Supports batch processing of multiple images.
|
76 |
- Tags images with multiple categories: general tags, character tags, and ratings.
|
77 |
- Displays categorized tags in a structured format.
|
78 |
-
- Includes a separate tab for image captioning using Florence 2.
|
79 |
-
- Supports various captioning tasks (e.g., Caption, Detailed Caption, Object Detection),
|
80 |
|
81 |
Example image by [me.](https://huggingface.co/Werli)
|
82 |
"""
|
83 |
-
colormap
|
84 |
-
'lime','indigo','violet','aqua','magenta','coral','gold','tan','skyblue']
|
85 |
|
86 |
# Dataset v3 series of models:
|
87 |
SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
|
@@ -89,277 +61,87 @@ CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
|
|
89 |
VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
|
90 |
VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
|
91 |
EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
|
92 |
-
|
93 |
# Dataset v2 series of models:
|
94 |
MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
|
95 |
SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
|
96 |
CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
|
97 |
CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
|
98 |
VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
|
99 |
-
|
100 |
# IdolSankaku series of models:
|
101 |
EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
|
102 |
SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"
|
103 |
-
|
104 |
# Files to download from the repos
|
105 |
MODEL_FILENAME = "model.onnx"
|
106 |
LABEL_FILENAME = "selected_tags.csv"
|
107 |
-
|
108 |
# LLAMA model
|
109 |
META_LLAMA_3_3B_REPO = "jncraton/Llama-3.2-3B-Instruct-ct2-int8"
|
110 |
META_LLAMA_3_8B_REPO = "avans06/Meta-Llama-3.2-8B-Instruct-ct2-int8_float16"
|
111 |
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
"+_-",
|
118 |
-
"._.",
|
119 |
-
"<o>_<o>",
|
120 |
-
"<|>_<|>",
|
121 |
-
"=_=",
|
122 |
-
">_<",
|
123 |
-
"3_3",
|
124 |
-
"6_9",
|
125 |
-
">_o",
|
126 |
-
"@_@",
|
127 |
-
"^_^",
|
128 |
-
"o_o",
|
129 |
-
"u_u",
|
130 |
-
"x_x",
|
131 |
-
"|_|",
|
132 |
-
"||_||",
|
133 |
-
]
|
134 |
-
def parse_args() -> argparse.Namespace:
|
135 |
-
parser = argparse.ArgumentParser()
|
136 |
-
parser.add_argument("--score-slider-step", type=float, default=0.05)
|
137 |
-
parser.add_argument("--score-general-threshold", type=float, default=0.35)
|
138 |
-
parser.add_argument("--score-character-threshold", type=float, default=0.85)
|
139 |
-
parser.add_argument("--share", action="store_true")
|
140 |
-
return parser.parse_args()
|
141 |
-
def load_labels(dataframe) -> list[str]:
|
142 |
-
name_series = dataframe["name"]
|
143 |
-
name_series = name_series.map(
|
144 |
-
lambda x: x.replace("_", " ") if x not in kaomojis else x
|
145 |
-
)
|
146 |
-
tag_names = name_series.tolist()
|
147 |
-
|
148 |
-
rating_indexes = list(np.where(dataframe["category"] == 9)[0])
|
149 |
-
general_indexes = list(np.where(dataframe["category"] == 0)[0])
|
150 |
-
character_indexes = list(np.where(dataframe["category"] == 4)[0])
|
151 |
-
return tag_names, rating_indexes, general_indexes, character_indexes
|
152 |
-
def mcut_threshold(probs):
|
153 |
-
"""
|
154 |
-
Maximum Cut Thresholding (MCut)
|
155 |
-
Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
|
156 |
-
for Multi-label Classification. In 11th International Symposium, IDA 2012
|
157 |
-
(pp. 172-183).
|
158 |
-
"""
|
159 |
-
sorted_probs = probs[probs.argsort()[::-1]]
|
160 |
-
difs = sorted_probs[:-1] - sorted_probs[1:]
|
161 |
-
t = difs.argmax()
|
162 |
-
thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
|
163 |
-
return thresh
|
164 |
class Timer:
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
max_label_length = max(len(label) for label, _ in self.checkpoints)
|
177 |
-
|
178 |
-
prev_time = self.checkpoints[0][1]
|
179 |
-
for label, curr_time in self.checkpoints[1:]:
|
180 |
-
elapsed = curr_time - prev_time
|
181 |
-
print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
|
182 |
-
prev_time = curr_time
|
183 |
-
|
184 |
-
if is_clear_checkpoints:
|
185 |
-
self.checkpoints.clear()
|
186 |
-
self.checkpoint() # Store checkpoints
|
187 |
-
|
188 |
-
def report_all(self):
|
189 |
-
"""Print all recorded checkpoints and total execution time with aligned formatting."""
|
190 |
-
print("\n> Execution Time Report:")
|
191 |
-
|
192 |
-
# Determine the max label width for alignment
|
193 |
-
max_label_length = max(len(label) for label, _ in self.checkpoints) if len(self.checkpoints) > 0 else 0
|
194 |
-
|
195 |
-
prev_time = self.start_time
|
196 |
-
for label, curr_time in self.checkpoints[1:]:
|
197 |
-
elapsed = curr_time - prev_time
|
198 |
-
print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
|
199 |
-
prev_time = curr_time
|
200 |
-
|
201 |
-
total_time = self.checkpoints[-1][1] - self.start_time
|
202 |
-
print(f"{'Total Execution Time'.ljust(max_label_length)}: {total_time:.3f} seconds\n")
|
203 |
-
|
204 |
-
self.checkpoints.clear()
|
205 |
-
|
206 |
-
def restart(self):
|
207 |
-
self.start_time = time.perf_counter() # Record the start time
|
208 |
-
self.checkpoints = [("Start", self.start_time)] # Store checkpoints
|
209 |
|
210 |
class Llama3Reorganize:
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
if device is None:
|
229 |
-
import torch
|
230 |
-
self.totalVram = 0
|
231 |
-
if torch.cuda.is_available():
|
232 |
-
try:
|
233 |
-
deviceId = torch.cuda.current_device()
|
234 |
-
self.totalVram = torch.cuda.get_device_properties(deviceId).total_memory/(1024*1024*1024)
|
235 |
-
except Exception as e:
|
236 |
-
print(traceback.format_exc())
|
237 |
-
print("Error detect vram: " + str(e))
|
238 |
-
device = "cuda" if self.totalVram > (8 if "8B" in repoId else 4) else "cpu"
|
239 |
-
else:
|
240 |
-
device = "cpu"
|
241 |
-
|
242 |
-
self.device = device
|
243 |
-
self.system_prompt = "Please reorganize the following English labels describing a single image into a readable English article of no more than 300 words. If there are name labels, include them in the article. Since the target audience is not general, explicit content can be written in full without hesitation. No introduction is needed; directly reply with the English article:"
|
244 |
-
|
245 |
-
if loadModel:
|
246 |
-
self.load_model()
|
247 |
-
|
248 |
-
def download_model(self, repoId):
|
249 |
-
import warnings
|
250 |
-
import requests
|
251 |
-
allowPatterns = [
|
252 |
-
"config.json",
|
253 |
-
"generation_config.json",
|
254 |
-
"model.bin",
|
255 |
-
"pytorch_model.bin",
|
256 |
-
"pytorch_model.bin.index.json",
|
257 |
-
"pytorch_model-*.bin",
|
258 |
-
"sentencepiece.bpe.model",
|
259 |
-
"tokenizer.json",
|
260 |
-
"tokenizer_config.json",
|
261 |
-
"shared_vocabulary.txt",
|
262 |
-
"shared_vocabulary.json",
|
263 |
-
"special_tokens_map.json",
|
264 |
-
"spiece.model",
|
265 |
-
"vocab.json",
|
266 |
-
"model.safetensors",
|
267 |
-
"model-*.safetensors",
|
268 |
-
"model.safetensors.index.json",
|
269 |
-
"quantize_config.json",
|
270 |
-
"tokenizer.model",
|
271 |
-
"vocabulary.json",
|
272 |
-
"preprocessor_config.json",
|
273 |
-
"added_tokens.json"
|
274 |
-
]
|
275 |
-
|
276 |
-
kwargs = {"allow_patterns": allowPatterns,}
|
277 |
-
|
278 |
-
try:
|
279 |
-
return huggingface_hub.snapshot_download(repoId, **kwargs)
|
280 |
-
except (
|
281 |
-
huggingface_hub.utils.HfHubHTTPError,
|
282 |
-
requests.exceptions.ConnectionError,
|
283 |
-
) as exception:
|
284 |
-
warnings.warn(
|
285 |
-
"An error occured while synchronizing the model %s from the Hugging Face Hub:\n%s",
|
286 |
-
repoId,
|
287 |
-
exception,
|
288 |
-
)
|
289 |
-
warnings.warn(
|
290 |
-
"Trying to load the model directly from the local cache, if it exists."
|
291 |
-
)
|
292 |
-
|
293 |
-
kwargs["local_files_only"] = True
|
294 |
-
return huggingface_hub.snapshot_download(repoId, **kwargs)
|
295 |
-
|
296 |
|
297 |
def load_model(self):
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
print('\n\nLoading model: %s\n\n' % self.modelPath)
|
302 |
-
kwargsTokenizer = {"pretrained_model_name_or_path": self.modelPath}
|
303 |
-
kwargsModel = {"device": self.device, "model_path": self.modelPath, "compute_type": "auto"}
|
304 |
-
self.roleSystem = {"role": "system", "content": self.system_prompt}
|
305 |
-
self.Model = ctranslate2.Generator(**kwargsModel)
|
306 |
-
|
307 |
-
self.Tokenizer = transformers.AutoTokenizer.from_pretrained(**kwargsTokenizer)
|
308 |
-
self.terminators = [self.Tokenizer.eos_token_id, self.Tokenizer.convert_tokens_to_ids("<|eot_id|>")]
|
309 |
-
|
310 |
-
except Exception as e:
|
311 |
-
self.release_vram()
|
312 |
-
raise e
|
313 |
-
|
314 |
|
315 |
def release_vram(self):
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
output = None
|
340 |
-
result = None
|
341 |
-
try:
|
342 |
-
input_ids = self.Tokenizer.apply_chat_template([self.roleSystem, {"role": "user", "content": text + "\n\nHere's the reorganized English article:"}], tokenize=False, add_generation_prompt=True)
|
343 |
-
source = self.Tokenizer.convert_ids_to_tokens(self.Tokenizer.encode(input_ids))
|
344 |
-
output = self.Model.generate_batch([source], max_length=max_length, max_batch_size=2, no_repeat_ngram_size=3, beam_size=2, sampling_temperature=0.7, sampling_topp=0.9, include_prompt_in_result=False, end_token=self.terminators)
|
345 |
-
target = output[0]
|
346 |
-
result = self.Tokenizer.decode(target.sequences_ids[0])
|
347 |
-
|
348 |
-
if len(result) > 2:
|
349 |
-
if result[0] == "\"" and result[len(result) - 1] == "\"":
|
350 |
-
result = result[1:-1]
|
351 |
-
elif result[0] == "'" and result[len(result) - 1] == "'":
|
352 |
-
result = result[1:-1]
|
353 |
-
elif result[0] == "「" and result[len(result) - 1] == "」":
|
354 |
-
result = result[1:-1]
|
355 |
-
elif result[0] == "『" and result[len(result) - 1] == "』":
|
356 |
-
result = result[1:-1]
|
357 |
-
except Exception as e:
|
358 |
-
print(traceback.format_exc())
|
359 |
-
print("Error reorganize text: " + str(e))
|
360 |
-
|
361 |
-
return result
|
362 |
-
|
363 |
|
364 |
class Predictor:
|
365 |
def __init__(self):
|
@@ -412,7 +194,7 @@ class Predictor:
|
|
412 |
|
413 |
padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
|
414 |
padded_image.paste(image, (pad_left, pad_top))
|
415 |
-
|
416 |
# Resize
|
417 |
if max_dim != target_size:
|
418 |
padded_image = padded_image.resize(
|
@@ -421,18 +203,21 @@ class Predictor:
|
|
421 |
)
|
422 |
# Convert to numpy array
|
423 |
image_array = np.asarray(padded_image, dtype=np.float32)
|
424 |
-
|
425 |
# Convert PIL-native RGB to BGR
|
426 |
image_array = image_array[:, :, ::-1]
|
427 |
-
|
428 |
return np.expand_dims(image_array, axis=0)
|
429 |
|
430 |
-
def create_file(self,
|
431 |
-
# Write the
|
432 |
-
|
433 |
-
|
|
|
|
|
|
|
|
|
|
|
434 |
|
435 |
-
return
|
436 |
|
437 |
def predict(
|
438 |
self,
|
@@ -465,14 +250,13 @@ class Predictor:
|
|
465 |
progress(current_progress, desc="Initialize wd model finished")
|
466 |
timer.checkpoint(f"Initialize wd model")
|
467 |
|
468 |
-
# Result
|
469 |
txt_infos = []
|
470 |
output_dir = tempfile.mkdtemp()
|
471 |
if not os.path.exists(output_dir):
|
472 |
os.makedirs(output_dir)
|
473 |
|
474 |
sorted_general_strings = ""
|
475 |
-
#
|
476 |
categorized_output_strings = []
|
477 |
rating = None
|
478 |
character_res = None
|
@@ -567,6 +351,22 @@ class Predictor:
|
|
567 |
# Collect all categorized output strings into a single string
|
568 |
final_categorized_output = ', '.join(categorized_output_strings)
|
569 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
570 |
current_progress += progressRatio/progressTotal;
|
571 |
progress(current_progress, desc=f"image{idx:02d}, predict finished")
|
572 |
timer.checkpoint(f"image{idx:02d}, predict finished")
|
@@ -602,15 +402,16 @@ class Predictor:
|
|
602 |
print(traceback.format_exc())
|
603 |
print("Error predict: " + str(e))
|
604 |
# Result
|
|
|
605 |
download = []
|
606 |
if txt_infos is not None and len(txt_infos) > 0:
|
607 |
-
downloadZipPath = os.path.join(output_dir, "
|
608 |
with zipfile.ZipFile(downloadZipPath, 'w', zipfile.ZIP_DEFLATED) as taggers_zip:
|
609 |
for info in txt_infos:
|
610 |
# Get file name from lookup
|
611 |
taggers_zip.write(info["path"], arcname=info["name"])
|
612 |
download.append(downloadZipPath)
|
613 |
-
|
614 |
if llama3_reorganize_model_repo:
|
615 |
llama3_reorganize.release_vram()
|
616 |
del llama3_reorganize
|
@@ -620,11 +421,9 @@ class Predictor:
|
|
620 |
print("Predict is complete.")
|
621 |
|
622 |
return download, sorted_general_strings, final_categorized_output, classified_tags, rating, character_res, general_res, unclassified_tags, tag_results
|
623 |
-
|
624 |
def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state: gr.SelectData):
|
625 |
if not selected_state:
|
626 |
return selected_state
|
627 |
-
|
628 |
tag_result = {
|
629 |
"strings": "",
|
630 |
"strings2": "",
|
@@ -636,272 +435,80 @@ def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state:
|
|
636 |
}
|
637 |
if selected_state.value["image"]["path"] in tag_results:
|
638 |
tag_result = tag_results[selected_state.value["image"]["path"]]
|
639 |
-
|
640 |
return (selected_state.value["image"]["path"], selected_state.value["caption"]), tag_result["strings"], tag_result["strings2"], tag_result["classified_tags"], tag_result["rating"], tag_result["character_res"], tag_result["general_res"], tag_result["unclassified_tags"]
|
641 |
-
|
642 |
-
|
643 |
-
|
644 |
-
|
645 |
-
|
646 |
-
|
647 |
-
|
648 |
-
|
649 |
-
|
650 |
-
|
651 |
-
|
652 |
-
|
653 |
-
|
654 |
-
if gallery is None:
|
655 |
-
gallery = []
|
656 |
-
if not images:
|
657 |
-
return gallery
|
658 |
-
|
659 |
-
# Combine the new images with the existing gallery images
|
660 |
-
gallery.extend(images)
|
661 |
-
|
662 |
-
return gallery
|
663 |
-
|
664 |
-
def remove_image_from_gallery(gallery: list, selected_image: str):
|
665 |
-
if not gallery or not selected_image:
|
666 |
-
return gallery
|
667 |
-
|
668 |
-
selected_image = ast.literal_eval(selected_image) # Use ast.literal_eval to parse text into a tuple.
|
669 |
-
# Remove the selected image from the gallery
|
670 |
-
if selected_image in gallery:
|
671 |
-
gallery.remove(selected_image)
|
672 |
-
return gallery
|
673 |
-
|
674 |
# END
|
675 |
|
676 |
-
def fig_to_pil(fig):
|
677 |
-
buf = io.BytesIO()
|
678 |
-
fig.savefig(buf, format='png')
|
679 |
-
buf.seek(0)
|
680 |
-
return Image.open(buf)
|
681 |
-
|
682 |
@spaces.GPU
|
683 |
-
def run_example(task_prompt,
|
684 |
-
|
685 |
-
|
686 |
-
|
687 |
-
|
688 |
-
|
689 |
-
|
690 |
-
|
691 |
-
|
692 |
-
|
693 |
-
|
694 |
-
|
695 |
-
|
696 |
-
|
697 |
-
|
698 |
-
|
699 |
-
|
700 |
-
|
701 |
-
|
702 |
-
|
703 |
-
|
704 |
-
|
705 |
-
|
706 |
-
|
707 |
-
|
708 |
-
|
709 |
-
|
710 |
-
|
711 |
-
|
712 |
-
|
713 |
-
|
714 |
-
|
715 |
-
|
716 |
-
|
717 |
-
|
718 |
-
|
719 |
-
|
720 |
-
|
721 |
-
|
722 |
-
|
723 |
-
|
724 |
-
|
725 |
-
|
726 |
-
|
727 |
-
|
728 |
-
|
729 |
-
|
730 |
-
|
731 |
-
|
732 |
-
|
733 |
-
|
734 |
-
|
735 |
-
|
736 |
-
|
737 |
-
|
738 |
-
od_results = {
|
739 |
-
'bboxes': bboxes,
|
740 |
-
'labels': labels
|
741 |
-
}
|
742 |
-
return od_results
|
743 |
-
|
744 |
-
def draw_ocr_bboxes(image, prediction):
|
745 |
-
scale = 1
|
746 |
-
draw = ImageDraw.Draw(image)
|
747 |
-
bboxes, labels = prediction['quad_boxes'], prediction['labels']
|
748 |
-
for box, label in zip(bboxes, labels):
|
749 |
-
color = random.choice(colormap)
|
750 |
-
new_box = (np.array(box) * scale).tolist()
|
751 |
-
draw.polygon(new_box, width=3, outline=color)
|
752 |
-
draw.text((new_box[0]+8, new_box[1]+2),
|
753 |
-
"{}".format(label),
|
754 |
-
align="right",
|
755 |
-
fill=color)
|
756 |
-
return image
|
757 |
-
|
758 |
-
def convert_to_od_format(data):
|
759 |
-
bboxes = data.get('bboxes', [])
|
760 |
-
labels = data.get('bboxes_labels', [])
|
761 |
-
od_results = {
|
762 |
-
'bboxes': bboxes,
|
763 |
-
'labels': labels
|
764 |
-
}
|
765 |
-
return od_results
|
766 |
-
|
767 |
-
def draw_ocr_bboxes(image, prediction):
|
768 |
-
scale = 1
|
769 |
-
draw = ImageDraw.Draw(image)
|
770 |
-
bboxes, labels = prediction['quad_boxes'], prediction['labels']
|
771 |
-
for box, label in zip(bboxes, labels):
|
772 |
-
color = random.choice(colormap)
|
773 |
-
new_box = (np.array(box) * scale).tolist()
|
774 |
-
draw.polygon(new_box, width=3, outline=color)
|
775 |
-
draw.text((new_box[0]+8, new_box[1]+2),
|
776 |
-
"{}".format(label),
|
777 |
-
align="right",
|
778 |
-
fill=color)
|
779 |
-
return image
|
780 |
-
def process_image(image, task_prompt, text_input=None):
|
781 |
-
# Test
|
782 |
-
if isinstance(image, str): # If image is a file path
|
783 |
-
image = Image.open(image) # Load image from file path
|
784 |
-
else: # If image is a NumPy array
|
785 |
-
image = Image.fromarray(image) # Convert NumPy array to PIL Image
|
786 |
-
if task_prompt == 'Caption':
|
787 |
-
task_prompt = '<CAPTION>'
|
788 |
-
results = run_example(task_prompt, image)
|
789 |
-
return results[task_prompt], None
|
790 |
-
elif task_prompt == 'Detailed Caption':
|
791 |
-
task_prompt = '<DETAILED_CAPTION>'
|
792 |
-
results = run_example(task_prompt, image)
|
793 |
-
return results[task_prompt], None
|
794 |
-
elif task_prompt == 'More Detailed Caption':
|
795 |
-
task_prompt = '<MORE_DETAILED_CAPTION>'
|
796 |
-
results = run_example(task_prompt, image)
|
797 |
-
return results, None
|
798 |
-
elif task_prompt == 'Caption + Grounding':
|
799 |
-
task_prompt = '<CAPTION>'
|
800 |
-
results = run_example(task_prompt, image)
|
801 |
-
text_input = results[task_prompt]
|
802 |
-
task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
|
803 |
-
results = run_example(task_prompt, image, text_input)
|
804 |
-
results['<CAPTION>'] = text_input
|
805 |
-
fig = plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
|
806 |
-
return results, fig_to_pil(fig)
|
807 |
-
elif task_prompt == 'Detailed Caption + Grounding':
|
808 |
-
task_prompt = '<DETAILED_CAPTION>'
|
809 |
-
results = run_example(task_prompt, image)
|
810 |
-
text_input = results[task_prompt]
|
811 |
-
task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
|
812 |
-
results = run_example(task_prompt, image, text_input)
|
813 |
-
results['<DETAILED_CAPTION>'] = text_input
|
814 |
-
fig = plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
|
815 |
-
return results, fig_to_pil(fig)
|
816 |
-
elif task_prompt == 'More Detailed Caption + Grounding':
|
817 |
-
task_prompt = '<MORE_DETAILED_CAPTION>'
|
818 |
-
results = run_example(task_prompt, image)
|
819 |
-
text_input = results[task_prompt]
|
820 |
-
task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
|
821 |
-
results = run_example(task_prompt, image, text_input)
|
822 |
-
results['<MORE_DETAILED_CAPTION>'] = text_input
|
823 |
-
fig = plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
|
824 |
-
return results, fig_to_pil(fig)
|
825 |
-
elif task_prompt == 'Object Detection':
|
826 |
-
task_prompt = '<OD>'
|
827 |
-
results = run_example(task_prompt, image)
|
828 |
-
fig = plot_bbox(image, results['<OD>'])
|
829 |
-
return results, fig_to_pil(fig)
|
830 |
-
elif task_prompt == 'Dense Region Caption':
|
831 |
-
task_prompt = '<DENSE_REGION_CAPTION>'
|
832 |
-
results = run_example(task_prompt, image)
|
833 |
-
fig = plot_bbox(image, results['<DENSE_REGION_CAPTION>'])
|
834 |
-
return results, fig_to_pil(fig)
|
835 |
-
elif task_prompt == 'Region Proposal':
|
836 |
-
task_prompt = '<REGION_PROPOSAL>'
|
837 |
-
results = run_example(task_prompt, image)
|
838 |
-
fig = plot_bbox(image, results['<REGION_PROPOSAL>'])
|
839 |
-
return results, fig_to_pil(fig)
|
840 |
-
elif task_prompt == 'Caption to Phrase Grounding':
|
841 |
-
task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
|
842 |
-
results = run_example(task_prompt, image, text_input)
|
843 |
-
fig = plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
|
844 |
-
return results, fig_to_pil(fig)
|
845 |
-
elif task_prompt == 'Referring Expression Segmentation':
|
846 |
-
task_prompt = '<REFERRING_EXPRESSION_SEGMENTATION>'
|
847 |
-
results = run_example(task_prompt, image, text_input)
|
848 |
-
output_image = copy.deepcopy(image)
|
849 |
-
output_image = draw_polygons(output_image, results['<REFERRING_EXPRESSION_SEGMENTATION>'], fill_mask=True)
|
850 |
-
return results, output_image
|
851 |
-
elif task_prompt == 'Region to Segmentation':
|
852 |
-
task_prompt = '<REGION_TO_SEGMENTATION>'
|
853 |
-
results = run_example(task_prompt, image, text_input)
|
854 |
-
output_image = copy.deepcopy(image)
|
855 |
-
output_image = draw_polygons(output_image, results['<REGION_TO_SEGMENTATION>'], fill_mask=True)
|
856 |
-
return results, output_image
|
857 |
-
elif task_prompt == 'Open Vocabulary Detection':
|
858 |
-
task_prompt = '<OPEN_VOCABULARY_DETECTION>'
|
859 |
-
results = run_example(task_prompt, image, text_input)
|
860 |
-
bbox_results = convert_to_od_format(results['<OPEN_VOCABULARY_DETECTION>'])
|
861 |
-
fig = plot_bbox(image, bbox_results)
|
862 |
-
return results, fig_to_pil(fig)
|
863 |
-
elif task_prompt == 'Region to Category':
|
864 |
-
task_prompt = '<REGION_TO_CATEGORY>'
|
865 |
-
results = run_example(task_prompt, image, text_input)
|
866 |
-
return results, None
|
867 |
-
elif task_prompt == 'Region to Description':
|
868 |
-
task_prompt = '<REGION_TO_DESCRIPTION>'
|
869 |
-
results = run_example(task_prompt, image, text_input)
|
870 |
-
return results, None
|
871 |
-
elif task_prompt == 'OCR':
|
872 |
-
task_prompt = '<OCR>'
|
873 |
-
results = run_example(task_prompt, image)
|
874 |
-
return results, None
|
875 |
-
elif task_prompt == 'OCR with Region':
|
876 |
-
task_prompt = '<OCR_WITH_REGION>'
|
877 |
-
results = run_example(task_prompt, image)
|
878 |
-
output_image = copy.deepcopy(image)
|
879 |
-
output_image = draw_ocr_bboxes(output_image, results['<OCR_WITH_REGION>'])
|
880 |
-
return results, output_image
|
881 |
-
else:
|
882 |
-
return "", None # Return empty string and None for unknown task prompts
|
883 |
|
884 |
-
# Custom CSS to set the height of the gr.Dropdown menu
|
885 |
-
css = """
|
886 |
-
div.progress-level div.progress-level-inner {
|
887 |
-
text-align: left !important;
|
888 |
-
width: 55.5% !important;
|
889 |
-
#output {
|
890 |
-
height: 500px;
|
891 |
-
overflow: auto;
|
892 |
-
border: 1px solid #ccc;
|
893 |
-
}
|
894 |
-
"""
|
895 |
-
single_task_list =[
|
896 |
-
'Caption', 'Detailed Caption', 'More Detailed Caption', 'Object Detection',
|
897 |
-
'Dense Region Caption', 'Region Proposal', 'Caption to Phrase Grounding',
|
898 |
-
'Referring Expression Segmentation', 'Region to Segmentation',
|
899 |
-
'Open Vocabulary Detection', 'Region to Category', 'Region to Description',
|
900 |
-
'OCR', 'OCR with Region'
|
901 |
-
]
|
902 |
-
cascaded_task_list =[
|
903 |
-
'Caption + Grounding', 'Detailed Caption + Grounding', 'More Detailed Caption + Grounding'
|
904 |
-
]
|
905 |
def update_task_dropdown(choice):
|
906 |
if choice == 'Cascaded task':
|
907 |
return gr.Dropdown(choices=cascaded_task_list, value='Caption + Grounding')
|
@@ -909,7 +516,6 @@ def update_task_dropdown(choice):
|
|
909 |
return gr.Dropdown(choices=single_task_list, value='Caption')
|
910 |
|
911 |
args = parse_args()
|
912 |
-
|
913 |
predictor = Predictor()
|
914 |
|
915 |
dropdown_list = [
|
@@ -933,19 +539,23 @@ llama_list = [
|
|
933 |
META_LLAMA_3_8B_REPO,
|
934 |
]
|
935 |
|
936 |
-
# This is workaround will make the space restart every 2 days. (for test).
|
937 |
def _restart_space():
|
938 |
-
|
939 |
-
|
940 |
-
|
941 |
-
|
942 |
-
scheduler = BackgroundScheduler()
|
943 |
# Add a job to restart the space every 2 days (172800 seconds)
|
944 |
restart_space_job = scheduler.add_job(_restart_space, "interval", seconds=172800)
|
945 |
-
# Start the scheduler
|
946 |
scheduler.start()
|
947 |
-
next_run_time_utc
|
948 |
-
NEXT_RESTART
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
949 |
|
950 |
with gr.Blocks(title=TITLE, css=css, theme="Werli/Multi-Tagger", fill_width=True) as demo:
|
951 |
gr.Markdown(value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>")
|
@@ -1030,7 +640,7 @@ with gr.Blocks(title=TITLE, css=css, theme="Werli/Multi-Tagger", fill_width=True
|
|
1030 |
size="lg",
|
1031 |
)
|
1032 |
with gr.Column(variant="panel"):
|
1033 |
-
download_file = gr.File(label="Output (Download)") # 0
|
1034 |
character_res = gr.Label(label="Output (characters)") # 1
|
1035 |
sorted_general_strings = gr.Textbox(label="Output (string)", show_label=True, show_copy_button=True) # 2
|
1036 |
final_categorized_output = gr.Textbox(label="Categorized Output (string)", show_label=True, show_copy_button=True) # 3
|
@@ -1112,5 +722,4 @@ with gr.Blocks(title=TITLE, css=css, theme="Werli/Multi-Tagger", fill_width=True
|
|
1112 |
label='Try examples'
|
1113 |
)
|
1114 |
submit_btn.click(process_image, [input_img, task_prompt, text_input], [output_text, output_img])
|
1115 |
-
|
1116 |
demo.queue(max_size=2).launch()
|
|
|
1 |
import os
|
2 |
+
import io,copy,requests,numpy as np,spaces,gradio as gr
|
3 |
+
from transformers import AutoProcessor,AutoModelForCausalLM,AutoModelForCausalLM,AutoProcessor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
from transformers.dynamic_module_utils import get_imports
|
5 |
+
from PIL import Image,ImageDraw,ImageFont
|
6 |
+
import matplotlib.pyplot as plt,matplotlib.patches as patches
|
|
|
7 |
from unittest.mock import patch
|
8 |
+
import argparse,huggingface_hub,onnxruntime as rt,pandas as pd,traceback,tempfile,zipfile,re,ast,time
|
9 |
+
from datetime import datetime,timezone
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
from collections import defaultdict
|
11 |
from classifyTags import classify_tags
|
|
|
12 |
from apscheduler.schedulers.background import BackgroundScheduler
|
13 |
+
import json
|
14 |
+
os.environ['PYTORCH_ENABLE_MPS_FALLBACK']='1'
|
15 |
+
|
16 |
+
def fixed_get_imports(filename:str|os.PathLike)->list[str]:
|
17 |
+
if not str(filename).endswith('/modeling_florence2.py'):return get_imports(filename)
|
18 |
+
imports=get_imports(filename)
|
19 |
+
if'flash_attn'in imports:imports.remove('flash_attn')
|
20 |
+
return imports
|
|
|
|
|
|
|
21 |
@spaces.GPU
|
22 |
def get_device_type():
|
23 |
+
import torch
|
24 |
+
if torch.cuda.is_available():return'cuda'
|
25 |
+
elif torch.backends.mps.is_available()and torch.backends.mps.is_built():return'mps'
|
26 |
+
else:return'cpu'
|
27 |
+
|
|
|
|
|
|
|
|
|
28 |
model_id = 'MiaoshouAI/Florence-2-base-PromptGen-v2.0'
|
29 |
|
30 |
import subprocess
|
|
|
35 |
processor = AutoProcessor.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v2.0", trust_remote_code=True)
|
36 |
model.to(device)
|
37 |
else:
|
|
|
38 |
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
|
39 |
model = AutoModelForCausalLM.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v2.0", trust_remote_code=True)
|
40 |
processor = AutoProcessor.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v2.0", trust_remote_code=True)
|
|
|
48 |
- Supports batch processing of multiple images.
|
49 |
- Tags images with multiple categories: general tags, character tags, and ratings.
|
50 |
- Displays categorized tags in a structured format.
|
51 |
+
- Includes a separate tab for image captioning using Florence 2. Supports CUDA, MPS or CPU if one of them is available.
|
52 |
+
- Supports various captioning tasks (e.g., Caption, Detailed Caption, Object Detection), it can display output text and images for tasks that generate visual outputs.
|
53 |
|
54 |
Example image by [me.](https://huggingface.co/Werli)
|
55 |
"""
|
56 |
+
colormap=['blue','orange','green','purple','brown','pink','gray','olive','cyan','red','lime','indigo','violet','aqua','magenta','coral','gold','tan','skyblue']
|
|
|
57 |
|
58 |
# Dataset v3 series of models:
|
59 |
SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
|
|
|
61 |
VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
|
62 |
VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
|
63 |
EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
|
|
|
64 |
# Dataset v2 series of models:
|
65 |
MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
|
66 |
SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
|
67 |
CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
|
68 |
CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
|
69 |
VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
|
|
|
70 |
# IdolSankaku series of models:
|
71 |
EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
|
72 |
SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"
|
|
|
73 |
# Files to download from the repos
|
74 |
MODEL_FILENAME = "model.onnx"
|
75 |
LABEL_FILENAME = "selected_tags.csv"
|
|
|
76 |
# LLAMA model
|
77 |
META_LLAMA_3_3B_REPO = "jncraton/Llama-3.2-3B-Instruct-ct2-int8"
|
78 |
META_LLAMA_3_8B_REPO = "avans06/Meta-Llama-3.2-8B-Instruct-ct2-int8_float16"
|
79 |
|
80 |
+
kaomojis=['0_0','(o)_(o)','+_+','+_-','._.','<o>_<o>','<|>_<|>','=_=','>_<','3_3','6_9','>_o','@_@','^_^','o_o','u_u','x_x','|_|','||_||']
|
81 |
+
def parse_args()->argparse.Namespace:parser=argparse.ArgumentParser();parser.add_argument('--score-slider-step',type=float,default=.05);parser.add_argument('--score-general-threshold',type=float,default=.35);parser.add_argument('--score-character-threshold',type=float,default=.85);parser.add_argument('--share',action='store_true');return parser.parse_args()
|
82 |
+
def load_labels(dataframe)->list[str]:name_series=dataframe['name'];name_series=name_series.map(lambda x:x.replace('_',' ')if x not in kaomojis else x);tag_names=name_series.tolist();rating_indexes=list(np.where(dataframe['category']==9)[0]);general_indexes=list(np.where(dataframe['category']==0)[0]);character_indexes=list(np.where(dataframe['category']==4)[0]);return tag_names,rating_indexes,general_indexes,character_indexes
|
83 |
+
def mcut_threshold(probs):sorted_probs=probs[probs.argsort()[::-1]];difs=sorted_probs[:-1]-sorted_probs[1:];t=difs.argmax();thresh=(sorted_probs[t]+sorted_probs[t+1])/2;return thresh
|
84 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
class Timer:
|
86 |
+
def __init__(self):self.start_time=time.perf_counter();self.checkpoints=[('Start',self.start_time)]
|
87 |
+
def checkpoint(self,label='Checkpoint'):now=time.perf_counter();self.checkpoints.append((label,now))
|
88 |
+
def report(self,is_clear_checkpoints=True):
|
89 |
+
max_label_length=max(len(label)for(label,_)in self.checkpoints);prev_time=self.checkpoints[0][1]
|
90 |
+
for(label,curr_time)in self.checkpoints[1:]:elapsed=curr_time-prev_time;print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds");prev_time=curr_time
|
91 |
+
if is_clear_checkpoints:self.checkpoints.clear();self.checkpoint()
|
92 |
+
def report_all(self):
|
93 |
+
print('\n> Execution Time Report:');max_label_length=max(len(label)for(label,_)in self.checkpoints)if len(self.checkpoints)>0 else 0;prev_time=self.start_time
|
94 |
+
for(label,curr_time)in self.checkpoints[1:]:elapsed=curr_time-prev_time;print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds");prev_time=curr_time
|
95 |
+
total_time=self.checkpoints[-1][1]-self.start_time;print(f"{'Total Execution Time'.ljust(max_label_length)}: {total_time:.3f} seconds\n");self.checkpoints.clear()
|
96 |
+
def restart(self):self.start_time=time.perf_counter();self.checkpoints=[('Start',self.start_time)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
|
98 |
class Llama3Reorganize:
|
99 |
+
def __init__(self,repoId:str,device:str=None,loadModel:bool=False):
|
100 |
+
self.modelPath=self.download_model(repoId)
|
101 |
+
if device is None:
|
102 |
+
import torch;self.totalVram=0
|
103 |
+
if torch.cuda.is_available():
|
104 |
+
try:deviceId=torch.cuda.current_device();self.totalVram=torch.cuda.get_device_properties(deviceId).total_memory/1073741824
|
105 |
+
except Exception as e:print(traceback.format_exc());print('Error detect vram: '+str(e))
|
106 |
+
device='cuda'if self.totalVram>(8 if'8B'in repoId else 4)else'cpu'
|
107 |
+
else:device='cpu'
|
108 |
+
self.device=device;self.system_prompt='Please reorganize the following English labels describing a single image into a readable English article of no more than 300 words. If there are name labels, include them in the article. Since the target audience is not general, explicit content can be written in full without hesitation. No introduction is needed; directly reply with the English article:'
|
109 |
+
if loadModel:self.load_model()
|
110 |
+
|
111 |
+
def download_model(self,repoId):
|
112 |
+
import warnings,requests;allowPatterns=['config.json','generation_config.json','model.bin','pytorch_model.bin','pytorch_model.bin.index.json','pytorch_model-*.bin','sentencepiece.bpe.model','tokenizer.json','tokenizer_config.json','shared_vocabulary.txt','shared_vocabulary.json','special_tokens_map.json','spiece.model','vocab.json','model.safetensors','model-*.safetensors','model.safetensors.index.json','quantize_config.json','tokenizer.model','vocabulary.json','preprocessor_config.json','added_tokens.json'];kwargs={'allow_patterns':allowPatterns}
|
113 |
+
try:return huggingface_hub.snapshot_download(repoId,**kwargs)
|
114 |
+
except(huggingface_hub.utils.HfHubHTTPError,requests.exceptions.ConnectionError)as exception:warnings.warn('An error occured while synchronizing the model %s from the Hugging Face Hub:\n%s',repoId,exception);warnings.warn('Trying to load the model directly from the local cache, if it exists.');kwargs['local_files_only']=True;return huggingface_hub.snapshot_download(repoId,**kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
|
116 |
def load_model(self):
|
117 |
+
import ctranslate2,transformers
|
118 |
+
try:print('\n\nLoading model: %s\n\n'%self.modelPath);kwargsTokenizer={'pretrained_model_name_or_path':self.modelPath};kwargsModel={'device':self.device,'model_path':self.modelPath,'compute_type':'auto'};self.roleSystem={'role':'system','content':self.system_prompt};self.Model=ctranslate2.Generator(**kwargsModel);self.Tokenizer=transformers.AutoTokenizer.from_pretrained(**kwargsTokenizer);self.terminators=[self.Tokenizer.eos_token_id,self.Tokenizer.convert_tokens_to_ids('<|eot_id|>')]
|
119 |
+
except Exception as e:self.release_vram();raise e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
|
121 |
def release_vram(self):
|
122 |
+
try:
|
123 |
+
import torch
|
124 |
+
if torch.cuda.is_available():
|
125 |
+
if getattr(self,'Model',None)is not None and getattr(self.Model,'unload_model',None)is not None:self.Model.unload_model()
|
126 |
+
if getattr(self,'Tokenizer',None)is not None:del self.Tokenizer
|
127 |
+
if getattr(self,'Model',None)is not None:del self.Model
|
128 |
+
import gc;gc.collect()
|
129 |
+
try:torch.cuda.empty_cache()
|
130 |
+
except Exception as e:print(traceback.format_exc());print('\tcuda empty cache, error: '+str(e))
|
131 |
+
print('release vram end.')
|
132 |
+
except Exception as e:print(traceback.format_exc());print('Error release vram: '+str(e))
|
133 |
+
|
134 |
+
def reorganize(self,text:str,max_length:int=400):
|
135 |
+
output=None;result=None
|
136 |
+
try:
|
137 |
+
input_ids=self.Tokenizer.apply_chat_template([self.roleSystem,{'role':'user','content':text+"\n\nHere's the reorganized English article:"}],tokenize=False,add_generation_prompt=True);source=self.Tokenizer.convert_ids_to_tokens(self.Tokenizer.encode(input_ids));output=self.Model.generate_batch([source],max_length=max_length,max_batch_size=2,no_repeat_ngram_size=3,beam_size=2,sampling_temperature=.7,sampling_topp=.9,include_prompt_in_result=False,end_token=self.terminators);target=output[0];result=self.Tokenizer.decode(target.sequences_ids[0])
|
138 |
+
if len(result)>2:
|
139 |
+
if result[0]=='"'and result[len(result)-1]=='"':result=result[1:-1]
|
140 |
+
elif result[0]=="'"and result[len(result)-1]=="'":result=result[1:-1]
|
141 |
+
elif result[0]=='「'and result[len(result)-1]=='」':result=result[1:-1]
|
142 |
+
elif result[0]=='『'and result[len(result)-1]=='』':result=result[1:-1]
|
143 |
+
except Exception as e:print(traceback.format_exc());print('Error reorganize text: '+str(e))
|
144 |
+
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
|
146 |
class Predictor:
|
147 |
def __init__(self):
|
|
|
194 |
|
195 |
padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
|
196 |
padded_image.paste(image, (pad_left, pad_top))
|
197 |
+
|
198 |
# Resize
|
199 |
if max_dim != target_size:
|
200 |
padded_image = padded_image.resize(
|
|
|
203 |
)
|
204 |
# Convert to numpy array
|
205 |
image_array = np.asarray(padded_image, dtype=np.float32)
|
|
|
206 |
# Convert PIL-native RGB to BGR
|
207 |
image_array = image_array[:, :, ::-1]
|
|
|
208 |
return np.expand_dims(image_array, axis=0)
|
209 |
|
210 |
+
def create_file(self, content: str, directory: str, fileName: str) -> str:
|
211 |
+
# Write the content to a file
|
212 |
+
file_path = os.path.join(directory, fileName)
|
213 |
+
if fileName.endswith('.json'):
|
214 |
+
with open(file_path, 'w', encoding="utf-8") as file:
|
215 |
+
file.write(content)
|
216 |
+
else:
|
217 |
+
with open(file_path, 'w+', encoding="utf-8") as file:
|
218 |
+
file.write(content)
|
219 |
|
220 |
+
return file_path
|
221 |
|
222 |
def predict(
|
223 |
self,
|
|
|
250 |
progress(current_progress, desc="Initialize wd model finished")
|
251 |
timer.checkpoint(f"Initialize wd model")
|
252 |
|
|
|
253 |
txt_infos = []
|
254 |
output_dir = tempfile.mkdtemp()
|
255 |
if not os.path.exists(output_dir):
|
256 |
os.makedirs(output_dir)
|
257 |
|
258 |
sorted_general_strings = ""
|
259 |
+
# Create categorized output string
|
260 |
categorized_output_strings = []
|
261 |
rating = None
|
262 |
character_res = None
|
|
|
351 |
# Collect all categorized output strings into a single string
|
352 |
final_categorized_output = ', '.join(categorized_output_strings)
|
353 |
|
354 |
+
# Create a .txt file for "Output (string)" and "Categorized Output (string)"
|
355 |
+
txt_content = f"Output (string): {sorted_general_strings}\nCategorized Output (string): {final_categorized_output}"
|
356 |
+
txt_file = self.create_file(txt_content, output_dir, f"{image_name}_output.txt")
|
357 |
+
txt_infos.append({"path": txt_file, "name": f"{image_name}_output.txt"})
|
358 |
+
|
359 |
+
# Create a .json file for "Categorized (tags)"
|
360 |
+
json_content = json.dumps(classified_tags, indent=4)
|
361 |
+
json_file = self.create_file(json_content, output_dir, f"{image_name}_categorized_tags.json")
|
362 |
+
txt_infos.append({"path": json_file, "name": f"{image_name}_categorized_tags.json"})
|
363 |
+
|
364 |
+
# Save a copy of the uploaded image in PNG format
|
365 |
+
image_path = value[0]
|
366 |
+
image = Image.open(image_path)
|
367 |
+
image.save(os.path.join(output_dir, f"{image_name}.png"), format="PNG")
|
368 |
+
txt_infos.append({"path": os.path.join(output_dir, f"{image_name}.png"), "name": f"{image_name}.png"})
|
369 |
+
|
370 |
current_progress += progressRatio/progressTotal;
|
371 |
progress(current_progress, desc=f"image{idx:02d}, predict finished")
|
372 |
timer.checkpoint(f"image{idx:02d}, predict finished")
|
|
|
402 |
print(traceback.format_exc())
|
403 |
print("Error predict: " + str(e))
|
404 |
# Result
|
405 |
+
# Zip creation logic:
|
406 |
download = []
|
407 |
if txt_infos is not None and len(txt_infos) > 0:
|
408 |
+
downloadZipPath = os.path.join(output_dir, "Multi-tagger-" + datetime.now().strftime("%Y%m%d-%H%M%S") + ".zip")
|
409 |
with zipfile.ZipFile(downloadZipPath, 'w', zipfile.ZIP_DEFLATED) as taggers_zip:
|
410 |
for info in txt_infos:
|
411 |
# Get file name from lookup
|
412 |
taggers_zip.write(info["path"], arcname=info["name"])
|
413 |
download.append(downloadZipPath)
|
414 |
+
# End zip creation logic
|
415 |
if llama3_reorganize_model_repo:
|
416 |
llama3_reorganize.release_vram()
|
417 |
del llama3_reorganize
|
|
|
421 |
print("Predict is complete.")
|
422 |
|
423 |
return download, sorted_general_strings, final_categorized_output, classified_tags, rating, character_res, general_res, unclassified_tags, tag_results
|
|
|
424 |
def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state: gr.SelectData):
|
425 |
if not selected_state:
|
426 |
return selected_state
|
|
|
427 |
tag_result = {
|
428 |
"strings": "",
|
429 |
"strings2": "",
|
|
|
435 |
}
|
436 |
if selected_state.value["image"]["path"] in tag_results:
|
437 |
tag_result = tag_results[selected_state.value["image"]["path"]]
|
|
|
438 |
return (selected_state.value["image"]["path"], selected_state.value["caption"]), tag_result["strings"], tag_result["strings2"], tag_result["classified_tags"], tag_result["rating"], tag_result["character_res"], tag_result["general_res"], tag_result["unclassified_tags"]
|
439 |
+
def append_gallery(gallery:list,image:str):
|
440 |
+
if gallery is None:gallery=[]
|
441 |
+
if not image:return gallery,None
|
442 |
+
gallery.append(image);return gallery,None
|
443 |
+
def extend_gallery(gallery:list,images):
|
444 |
+
if gallery is None:gallery=[]
|
445 |
+
if not images:return gallery
|
446 |
+
gallery.extend(images);return gallery
|
447 |
+
def remove_image_from_gallery(gallery:list,selected_image:str):
|
448 |
+
if not gallery or not selected_image:return gallery
|
449 |
+
selected_image=ast.literal_eval(selected_image)
|
450 |
+
if selected_image in gallery:gallery.remove(selected_image)
|
451 |
+
return gallery
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
452 |
# END
|
453 |
|
454 |
+
def fig_to_pil(fig):buf=io.BytesIO();fig.savefig(buf,format='png');buf.seek(0);return Image.open(buf)
|
|
|
|
|
|
|
|
|
|
|
455 |
@spaces.GPU
|
456 |
+
def run_example(task_prompt,image,text_input=None):
|
457 |
+
if text_input is None:prompt=task_prompt
|
458 |
+
else:prompt=task_prompt+text_input
|
459 |
+
inputs=processor(text=prompt,images=image,return_tensors='pt').to(device);generated_ids=model.generate(input_ids=inputs['input_ids'],pixel_values=inputs['pixel_values'],max_new_tokens=1024,early_stopping=False,do_sample=False,num_beams=3);generated_text=processor.batch_decode(generated_ids,skip_special_tokens=False)[0];parsed_answer=processor.post_process_generation(generated_text,task=task_prompt,image_size=(image.width,image.height));return parsed_answer
|
460 |
+
def plot_bbox(image,data):
|
461 |
+
fig,ax=plt.subplots();ax.imshow(image)
|
462 |
+
for(bbox,label)in zip(data['bboxes'],data['labels']):x1,y1,x2,y2=bbox;rect=patches.Rectangle((x1,y1),x2-x1,y2-y1,linewidth=1,edgecolor='r',facecolor='none');ax.add_patch(rect);plt.text(x1,y1,label,color='white',fontsize=8,bbox=dict(facecolor='red',alpha=.5))
|
463 |
+
ax.axis('off');return fig
|
464 |
+
def draw_polygons(image,prediction,fill_mask=False):
|
465 |
+
draw=ImageDraw.Draw(image);scale=1
|
466 |
+
for(polygons,label)in zip(prediction['polygons'],prediction['labels']):
|
467 |
+
color=random.choice(colormap);fill_color=random.choice(colormap)if fill_mask else None
|
468 |
+
for _polygon in polygons:
|
469 |
+
_polygon=np.array(_polygon).reshape(-1,2)
|
470 |
+
if len(_polygon)<3:print('Invalid polygon:',_polygon);continue
|
471 |
+
_polygon=(_polygon*scale).reshape(-1).tolist()
|
472 |
+
if fill_mask:draw.polygon(_polygon,outline=color,fill=fill_color)
|
473 |
+
else:draw.polygon(_polygon,outline=color)
|
474 |
+
draw.text((_polygon[0]+8,_polygon[1]+2),label,fill=color)
|
475 |
+
return image
|
476 |
+
def convert_to_od_format(data):bboxes=data.get('bboxes',[]);labels=data.get('bboxes_labels',[]);od_results={'bboxes':bboxes,'labels':labels};return od_results
|
477 |
+
def draw_ocr_bboxes(image,prediction):
|
478 |
+
scale=1;draw=ImageDraw.Draw(image);bboxes,labels=prediction['quad_boxes'],prediction['labels']
|
479 |
+
for(box,label)in zip(bboxes,labels):color=random.choice(colormap);new_box=(np.array(box)*scale).tolist();draw.polygon(new_box,width=3,outline=color);draw.text((new_box[0]+8,new_box[1]+2),'{}'.format(label),align='right',fill=color)
|
480 |
+
return image
|
481 |
+
def convert_to_od_format(data):bboxes=data.get('bboxes',[]);labels=data.get('bboxes_labels',[]);od_results={'bboxes':bboxes,'labels':labels};return od_results
|
482 |
+
def draw_ocr_bboxes(image,prediction):
|
483 |
+
scale=1;draw=ImageDraw.Draw(image);bboxes,labels=prediction['quad_boxes'],prediction['labels']
|
484 |
+
for(box,label)in zip(bboxes,labels):color=random.choice(colormap);new_box=(np.array(box)*scale).tolist();draw.polygon(new_box,width=3,outline=color);draw.text((new_box[0]+8,new_box[1]+2),'{}'.format(label),align='right',fill=color)
|
485 |
+
return image
|
486 |
+
|
487 |
+
def process_image(image,task_prompt,text_input=None):
|
488 |
+
if isinstance(image,str):image=Image.open(image)
|
489 |
+
else:image=Image.fromarray(image)
|
490 |
+
if task_prompt=='Caption':task_prompt='<CAPTION>';results=run_example(task_prompt,image);return results[task_prompt],None
|
491 |
+
elif task_prompt=='Detailed Caption':task_prompt='<DETAILED_CAPTION>';results=run_example(task_prompt,image);return results[task_prompt],None
|
492 |
+
elif task_prompt=='More Detailed Caption':task_prompt='<MORE_DETAILED_CAPTION>';results=run_example(task_prompt,image);return results,None
|
493 |
+
elif task_prompt=='Caption + Grounding':task_prompt='<CAPTION>';results=run_example(task_prompt,image);text_input=results[task_prompt];task_prompt='<CAPTION_TO_PHRASE_GROUNDING>';results=run_example(task_prompt,image,text_input);results['<CAPTION>']=text_input;fig=plot_bbox(image,results['<CAPTION_TO_PHRASE_GROUNDING>']);return results,fig_to_pil(fig)
|
494 |
+
elif task_prompt=='Detailed Caption + Grounding':task_prompt='<DETAILED_CAPTION>';results=run_example(task_prompt,image);text_input=results[task_prompt];task_prompt='<CAPTION_TO_PHRASE_GROUNDING>';results=run_example(task_prompt,image,text_input);results['<DETAILED_CAPTION>']=text_input;fig=plot_bbox(image,results['<CAPTION_TO_PHRASE_GROUNDING>']);return results,fig_to_pil(fig)
|
495 |
+
elif task_prompt=='More Detailed Caption + Grounding':task_prompt='<MORE_DETAILED_CAPTION>';results=run_example(task_prompt,image);text_input=results[task_prompt];task_prompt='<CAPTION_TO_PHRASE_GROUNDING>';results=run_example(task_prompt,image,text_input);results['<MORE_DETAILED_CAPTION>']=text_input;fig=plot_bbox(image,results['<CAPTION_TO_PHRASE_GROUNDING>']);return results,fig_to_pil(fig)
|
496 |
+
elif task_prompt=='Object Detection':task_prompt='<OD>';results=run_example(task_prompt,image);fig=plot_bbox(image,results['<OD>']);return results,fig_to_pil(fig)
|
497 |
+
elif task_prompt=='Dense Region Caption':task_prompt='<DENSE_REGION_CAPTION>';results=run_example(task_prompt,image);fig=plot_bbox(image,results['<DENSE_REGION_CAPTION>']);return results,fig_to_pil(fig)
|
498 |
+
elif task_prompt=='Region Proposal':task_prompt='<REGION_PROPOSAL>';results=run_example(task_prompt,image);fig=plot_bbox(image,results['<REGION_PROPOSAL>']);return results,fig_to_pil(fig)
|
499 |
+
elif task_prompt=='Caption to Phrase Grounding':task_prompt='<CAPTION_TO_PHRASE_GROUNDING>';results=run_example(task_prompt,image,text_input);fig=plot_bbox(image,results['<CAPTION_TO_PHRASE_GROUNDING>']);return results,fig_to_pil(fig)
|
500 |
+
elif task_prompt=='Referring Expression Segmentation':task_prompt='<REFERRING_EXPRESSION_SEGMENTATION>';results=run_example(task_prompt,image,text_input);output_image=copy.deepcopy(image);output_image=draw_polygons(output_image,results['<REFERRING_EXPRESSION_SEGMENTATION>'],fill_mask=True);return results,output_image
|
501 |
+
elif task_prompt=='Region to Segmentation':task_prompt='<REGION_TO_SEGMENTATION>';results=run_example(task_prompt,image,text_input);output_image=copy.deepcopy(image);output_image=draw_polygons(output_image,results['<REGION_TO_SEGMENTATION>'],fill_mask=True);return results,output_image
|
502 |
+
elif task_prompt=='Open Vocabulary Detection':task_prompt='<OPEN_VOCABULARY_DETECTION>';results=run_example(task_prompt,image,text_input);bbox_results=convert_to_od_format(results['<OPEN_VOCABULARY_DETECTION>']);fig=plot_bbox(image,bbox_results);return results,fig_to_pil(fig)
|
503 |
+
elif task_prompt=='Region to Category':task_prompt='<REGION_TO_CATEGORY>';results=run_example(task_prompt,image,text_input);return results,None
|
504 |
+
elif task_prompt=='Region to Description':task_prompt='<REGION_TO_DESCRIPTION>';results=run_example(task_prompt,image,text_input);return results,None
|
505 |
+
elif task_prompt=='OCR':task_prompt='<OCR>';results=run_example(task_prompt,image);return results,None
|
506 |
+
elif task_prompt=='OCR with Region':task_prompt='<OCR_WITH_REGION>';results=run_example(task_prompt,image);output_image=copy.deepcopy(image);output_image=draw_ocr_bboxes(output_image,results['<OCR_WITH_REGION>']);return results,output_image
|
507 |
+
else:return'',None # Return empty string and None for unknown task prompts
|
508 |
+
|
509 |
+
single_task_list=['Caption','Detailed Caption','More Detailed Caption','Object Detection','Dense Region Caption','Region Proposal','Caption to Phrase Grounding','Referring Expression Segmentation','Region to Segmentation','Open Vocabulary Detection','Region to Category','Region to Description','OCR','OCR with Region']
|
510 |
+
cascaded_task_list=['Caption + Grounding','Detailed Caption + Grounding','More Detailed Caption + Grounding']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
511 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
512 |
def update_task_dropdown(choice):
|
513 |
if choice == 'Cascaded task':
|
514 |
return gr.Dropdown(choices=cascaded_task_list, value='Caption + Grounding')
|
|
|
516 |
return gr.Dropdown(choices=single_task_list, value='Caption')
|
517 |
|
518 |
args = parse_args()
|
|
|
519 |
predictor = Predictor()
|
520 |
|
521 |
dropdown_list = [
|
|
|
539 |
META_LLAMA_3_8B_REPO,
|
540 |
]
|
541 |
|
|
|
542 |
def _restart_space():
|
543 |
+
HF_TOKEN=os.getenv('HF_TOKEN')
|
544 |
+
if not HF_TOKEN:raise ValueError('HF_TOKEN environment variable is not set.')
|
545 |
+
huggingface_hub.HfApi().restart_space(repo_id='Werli/Multi-Tagger',token=HF_TOKEN,factory_reboot=False)
|
546 |
+
scheduler=BackgroundScheduler()
|
|
|
547 |
# Add a job to restart the space every 2 days (172800 seconds)
|
548 |
restart_space_job = scheduler.add_job(_restart_space, "interval", seconds=172800)
|
|
|
549 |
scheduler.start()
|
550 |
+
next_run_time_utc=restart_space_job.next_run_time.astimezone(timezone.utc)
|
551 |
+
NEXT_RESTART=f"Next Restart: {next_run_time_utc.strftime('%Y-%m-%d %H:%M:%S')} (UTC) - The space will restart every 2 days to ensure stability and performance. It uses a background scheduler to handle the restart process."
|
552 |
+
|
553 |
+
css = """
|
554 |
+
div.progress-level div.progress-level-inner {text-align: left !important; width: 55.5% !important;}
|
555 |
+
#output {height: 500px; overflow: auto; border: 1px solid #ccc;}
|
556 |
+
label.float.svelte-i3tvor {position: relative !important;}
|
557 |
+
.reduced-height.svelte-11chud3 {height: calc(80% - var(--size-10));}
|
558 |
+
"""
|
559 |
|
560 |
with gr.Blocks(title=TITLE, css=css, theme="Werli/Multi-Tagger", fill_width=True) as demo:
|
561 |
gr.Markdown(value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>")
|
|
|
640 |
size="lg",
|
641 |
)
|
642 |
with gr.Column(variant="panel"):
|
643 |
+
download_file = gr.File(label="Output (Download All)") # 0
|
644 |
character_res = gr.Label(label="Output (characters)") # 1
|
645 |
sorted_general_strings = gr.Textbox(label="Output (string)", show_label=True, show_copy_button=True) # 2
|
646 |
final_categorized_output = gr.Textbox(label="Categorized Output (string)", show_label=True, show_copy_button=True) # 3
|
|
|
722 |
label='Try examples'
|
723 |
)
|
724 |
submit_btn.click(process_image, [input_img, task_prompt, text_input], [output_text, output_img])
|
|
|
725 |
demo.queue(max_size=2).launch()
|