Werli committed on
Commit 6350bdc · verified · 1 Parent(s): 35b60c7

New features!

Big changes to the code!

Added functionality to save additional outputs and include them in the downloadable zip file (a minimal sketch follows this list):
- Modified the `predict` method in the `Predictor` class to create and save:
- A `.txt` file containing the "Output (string)" as well as the "Categorized Output (string)".
- A `categorized_tags.json` file containing the "Categorized (tags)".
- A copy of the uploaded image(s) in PNG format.
- Updated the `create_file` method to handle both text and JSON files.
- Ensured that all created files are included in the zip file during the download process.
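
A minimal sketch of the new save-and-zip flow, assuming the variable names used in the diff below (`sorted_general_strings`, `final_categorized_output`, `classified_tags`, `image_name`, `output_dir`). The `save_outputs` and `zip_outputs` helpers here are hypothetical; the actual code does this inline inside `Predictor.predict`:

```python
import json
import os
import zipfile
from datetime import datetime

from PIL import Image


def create_file(content: str, directory: str, file_name: str) -> str:
    """Write text or JSON content to a file and return its path."""
    file_path = os.path.join(directory, file_name)
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(content)
    return file_path


def save_outputs(image_path: str, image_name: str, sorted_general_strings: str,
                 final_categorized_output: str, classified_tags: dict,
                 output_dir: str) -> list:
    """Hypothetical helper mirroring what predict() now does per image."""
    infos = []
    # .txt file with both output strings
    txt = (f"Output (string): {sorted_general_strings}\n"
           f"Categorized Output (string): {final_categorized_output}")
    infos.append({"path": create_file(txt, output_dir, f"{image_name}_output.txt"),
                  "name": f"{image_name}_output.txt"})
    # .json file with the categorized tags
    infos.append({"path": create_file(json.dumps(classified_tags, indent=4), output_dir,
                                      f"{image_name}_categorized_tags.json"),
                  "name": f"{image_name}_categorized_tags.json"})
    # Copy of the uploaded image saved as PNG
    png_path = os.path.join(output_dir, f"{image_name}.png")
    Image.open(image_path).save(png_path, format="PNG")
    infos.append({"path": png_path, "name": f"{image_name}.png"})
    return infos


def zip_outputs(infos: list, output_dir: str) -> str:
    """Bundle every collected file into the zip offered for download."""
    zip_path = os.path.join(
        output_dir, "Multi-tagger-" + datetime.now().strftime("%Y%m%d-%H%M%S") + ".zip")
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as z:
        for info in infos:
            z.write(info["path"], arcname=info["name"])
    return zip_path
```

In `predict`, the per-image file entries are accumulated in `txt_infos` and the resulting zip path is appended to the `download` list returned to the `gr.File` output.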

Also:
- Fixed the image grid position with custom CSS so that all images display correctly.
- Minified parts of the code without changing its functionality.
- Removed unnecessary comments and duplicate code.

Files changed (1)
  1. app.py +194 -585
app.py CHANGED
@@ -1,56 +1,30 @@
1
  import os
2
- os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
3
- import io
4
- import copy
5
- import requests
6
- import numpy as np
7
- import spaces
8
- import gradio as gr
9
- from transformers import AutoProcessor, AutoModelForCausalLM
10
- from transformers import AutoModelForCausalLM, AutoProcessor
11
  from transformers.dynamic_module_utils import get_imports
12
- from PIL import Image, ImageDraw, ImageFont
13
- import matplotlib.pyplot as plt
14
- import matplotlib.patches as patches
15
  from unittest.mock import patch
16
-
17
- import argparse
18
- import huggingface_hub
19
- import onnxruntime as rt
20
- import pandas as pd
21
- import traceback
22
- import tempfile
23
- import zipfile
24
- import re
25
- import ast
26
- import time
27
- from datetime import datetime, timezone
28
  from collections import defaultdict
29
  from classifyTags import classify_tags
30
- # Add scheduler code here
31
  from apscheduler.schedulers.background import BackgroundScheduler
32
-
33
- os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
34
- def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
35
- """Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72."""
36
- if not str(filename).endswith("/modeling_florence2.py"):
37
- return get_imports(filename)
38
- imports = get_imports(filename)
39
- if "flash_attn" in imports:
40
- imports.remove("flash_attn")
41
- return imports
42
-
43
  @spaces.GPU
44
  def get_device_type():
45
- import torch
46
- if torch.cuda.is_available():
47
- return "cuda"
48
- else:
49
- if (torch.backends.mps.is_available() and torch.backends.mps.is_built()):
50
- return "mps"
51
- else:
52
- return "cpu"
53
-
54
  model_id = 'MiaoshouAI/Florence-2-base-PromptGen-v2.0'
55
 
56
  import subprocess
@@ -61,7 +35,6 @@ if (device == "cuda"):
61
  processor = AutoProcessor.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v2.0", trust_remote_code=True)
62
  model.to(device)
63
  else:
64
- #https://huggingface.co/microsoft/Florence-2-base-ft/discussions/4
65
  with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
66
  model = AutoModelForCausalLM.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v2.0", trust_remote_code=True)
67
  processor = AutoProcessor.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v2.0", trust_remote_code=True)
@@ -75,13 +48,12 @@ Features:
75
  - Supports batch processing of multiple images.
76
  - Tags images with multiple categories: general tags, character tags, and ratings.
77
  - Displays categorized tags in a structured format.
78
- - Includes a separate tab for image captioning using Florence 2. This supports CUDA, MPS or CPU if one of them is available.
79
- - Supports various captioning tasks (e.g., Caption, Detailed Caption, Object Detection), as well it can display output text and images for tasks that generate visual outputs.
80
 
81
  Example image by [me.](https://huggingface.co/Werli)
82
  """
83
- colormap = ['blue','orange','green','purple','brown','pink','gray','olive','cyan','red',
84
- 'lime','indigo','violet','aqua','magenta','coral','gold','tan','skyblue']
85
 
86
  # Dataset v3 series of models:
87
  SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
@@ -89,277 +61,87 @@ CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
89
  VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
90
  VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
91
  EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
92
-
93
  # Dataset v2 series of models:
94
  MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
95
  SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
96
  CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
97
  CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
98
  VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
99
-
100
  # IdolSankaku series of models:
101
  EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
102
  SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"
103
-
104
  # Files to download from the repos
105
  MODEL_FILENAME = "model.onnx"
106
  LABEL_FILENAME = "selected_tags.csv"
107
-
108
  # LLAMA model
109
  META_LLAMA_3_3B_REPO = "jncraton/Llama-3.2-3B-Instruct-ct2-int8"
110
  META_LLAMA_3_8B_REPO = "avans06/Meta-Llama-3.2-8B-Instruct-ct2-int8_float16"
111
 
112
- # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/blob/a9eacb1eff904552d3012babfa28b57e1d3e295c/tagger/ui.py#L368
113
- kaomojis = [
114
- "0_0",
115
- "(o)_(o)",
116
- "+_+",
117
- "+_-",
118
- "._.",
119
- "<o>_<o>",
120
- "<|>_<|>",
121
- "=_=",
122
- ">_<",
123
- "3_3",
124
- "6_9",
125
- ">_o",
126
- "@_@",
127
- "^_^",
128
- "o_o",
129
- "u_u",
130
- "x_x",
131
- "|_|",
132
- "||_||",
133
- ]
134
- def parse_args() -> argparse.Namespace:
135
- parser = argparse.ArgumentParser()
136
- parser.add_argument("--score-slider-step", type=float, default=0.05)
137
- parser.add_argument("--score-general-threshold", type=float, default=0.35)
138
- parser.add_argument("--score-character-threshold", type=float, default=0.85)
139
- parser.add_argument("--share", action="store_true")
140
- return parser.parse_args()
141
- def load_labels(dataframe) -> list[str]:
142
- name_series = dataframe["name"]
143
- name_series = name_series.map(
144
- lambda x: x.replace("_", " ") if x not in kaomojis else x
145
- )
146
- tag_names = name_series.tolist()
147
-
148
- rating_indexes = list(np.where(dataframe["category"] == 9)[0])
149
- general_indexes = list(np.where(dataframe["category"] == 0)[0])
150
- character_indexes = list(np.where(dataframe["category"] == 4)[0])
151
- return tag_names, rating_indexes, general_indexes, character_indexes
152
- def mcut_threshold(probs):
153
- """
154
- Maximum Cut Thresholding (MCut)
155
- Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
156
- for Multi-label Classification. In 11th International Symposium, IDA 2012
157
- (pp. 172-183).
158
- """
159
- sorted_probs = probs[probs.argsort()[::-1]]
160
- difs = sorted_probs[:-1] - sorted_probs[1:]
161
- t = difs.argmax()
162
- thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
163
- return thresh
164
  class Timer:
165
- def __init__(self):
166
- self.start_time = time.perf_counter() # Record the start time
167
- self.checkpoints = [("Start", self.start_time)] # Store checkpoints
168
-
169
- def checkpoint(self, label="Checkpoint"):
170
- """Record a checkpoint with a given label."""
171
- now = time.perf_counter()
172
- self.checkpoints.append((label, now))
173
-
174
- def report(self, is_clear_checkpoints = True):
175
- # Determine the max label width for alignment
176
- max_label_length = max(len(label) for label, _ in self.checkpoints)
177
-
178
- prev_time = self.checkpoints[0][1]
179
- for label, curr_time in self.checkpoints[1:]:
180
- elapsed = curr_time - prev_time
181
- print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
182
- prev_time = curr_time
183
-
184
- if is_clear_checkpoints:
185
- self.checkpoints.clear()
186
- self.checkpoint() # Store checkpoints
187
-
188
- def report_all(self):
189
- """Print all recorded checkpoints and total execution time with aligned formatting."""
190
- print("\n> Execution Time Report:")
191
-
192
- # Determine the max label width for alignment
193
- max_label_length = max(len(label) for label, _ in self.checkpoints) if len(self.checkpoints) > 0 else 0
194
-
195
- prev_time = self.start_time
196
- for label, curr_time in self.checkpoints[1:]:
197
- elapsed = curr_time - prev_time
198
- print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
199
- prev_time = curr_time
200
-
201
- total_time = self.checkpoints[-1][1] - self.start_time
202
- print(f"{'Total Execution Time'.ljust(max_label_length)}: {total_time:.3f} seconds\n")
203
-
204
- self.checkpoints.clear()
205
-
206
- def restart(self):
207
- self.start_time = time.perf_counter() # Record the start time
208
- self.checkpoints = [("Start", self.start_time)] # Store checkpoints
209
 
210
  class Llama3Reorganize:
211
- def __init__(
212
- self,
213
- repoId: str,
214
- device: str = None,
215
- loadModel: bool = False,
216
- ):
217
- """Initializes the Llama model.
218
-
219
- Args:
220
- repoId: LLAMA model repo.
221
- device: Device to use for computation (cpu, cuda, ipu, xpu, mkldnn, opengl, opencl,
222
- ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia).
223
- localFilesOnly: If True, avoid downloading the file and return the path to the
224
- local cached file if it exists.
225
- """
226
- self.modelPath = self.download_model(repoId)
227
-
228
- if device is None:
229
- import torch
230
- self.totalVram = 0
231
- if torch.cuda.is_available():
232
- try:
233
- deviceId = torch.cuda.current_device()
234
- self.totalVram = torch.cuda.get_device_properties(deviceId).total_memory/(1024*1024*1024)
235
- except Exception as e:
236
- print(traceback.format_exc())
237
- print("Error detect vram: " + str(e))
238
- device = "cuda" if self.totalVram > (8 if "8B" in repoId else 4) else "cpu"
239
- else:
240
- device = "cpu"
241
-
242
- self.device = device
243
- self.system_prompt = "Please reorganize the following English labels describing a single image into a readable English article of no more than 300 words. If there are name labels, include them in the article. Since the target audience is not general, explicit content can be written in full without hesitation. No introduction is needed; directly reply with the English article:"
244
-
245
- if loadModel:
246
- self.load_model()
247
-
248
- def download_model(self, repoId):
249
- import warnings
250
- import requests
251
- allowPatterns = [
252
- "config.json",
253
- "generation_config.json",
254
- "model.bin",
255
- "pytorch_model.bin",
256
- "pytorch_model.bin.index.json",
257
- "pytorch_model-*.bin",
258
- "sentencepiece.bpe.model",
259
- "tokenizer.json",
260
- "tokenizer_config.json",
261
- "shared_vocabulary.txt",
262
- "shared_vocabulary.json",
263
- "special_tokens_map.json",
264
- "spiece.model",
265
- "vocab.json",
266
- "model.safetensors",
267
- "model-*.safetensors",
268
- "model.safetensors.index.json",
269
- "quantize_config.json",
270
- "tokenizer.model",
271
- "vocabulary.json",
272
- "preprocessor_config.json",
273
- "added_tokens.json"
274
- ]
275
-
276
- kwargs = {"allow_patterns": allowPatterns,}
277
-
278
- try:
279
- return huggingface_hub.snapshot_download(repoId, **kwargs)
280
- except (
281
- huggingface_hub.utils.HfHubHTTPError,
282
- requests.exceptions.ConnectionError,
283
- ) as exception:
284
- warnings.warn(
285
- "An error occured while synchronizing the model %s from the Hugging Face Hub:\n%s",
286
- repoId,
287
- exception,
288
- )
289
- warnings.warn(
290
- "Trying to load the model directly from the local cache, if it exists."
291
- )
292
-
293
- kwargs["local_files_only"] = True
294
- return huggingface_hub.snapshot_download(repoId, **kwargs)
295
-
296
 
297
  def load_model(self):
298
- import ctranslate2
299
- import transformers
300
- try:
301
- print('\n\nLoading model: %s\n\n' % self.modelPath)
302
- kwargsTokenizer = {"pretrained_model_name_or_path": self.modelPath}
303
- kwargsModel = {"device": self.device, "model_path": self.modelPath, "compute_type": "auto"}
304
- self.roleSystem = {"role": "system", "content": self.system_prompt}
305
- self.Model = ctranslate2.Generator(**kwargsModel)
306
-
307
- self.Tokenizer = transformers.AutoTokenizer.from_pretrained(**kwargsTokenizer)
308
- self.terminators = [self.Tokenizer.eos_token_id, self.Tokenizer.convert_tokens_to_ids("<|eot_id|>")]
309
-
310
- except Exception as e:
311
- self.release_vram()
312
- raise e
313
-
314
 
315
  def release_vram(self):
316
- try:
317
- import torch
318
- if torch.cuda.is_available():
319
- if getattr(self, "Model", None) is not None and getattr(self.Model, "unload_model", None) is not None:
320
- self.Model.unload_model()
321
-
322
- if getattr(self, "Tokenizer", None) is not None:
323
- del self.Tokenizer
324
- if getattr(self, "Model", None) is not None:
325
- del self.Model
326
- import gc
327
- gc.collect()
328
- try:
329
- torch.cuda.empty_cache()
330
- except Exception as e:
331
- print(traceback.format_exc())
332
- print("\tcuda empty cache, error: " + str(e))
333
- print("release vram end.")
334
- except Exception as e:
335
- print(traceback.format_exc())
336
- print("Error release vram: " + str(e))
337
-
338
- def reorganize(self, text: str, max_length: int = 400):
339
- output = None
340
- result = None
341
- try:
342
- input_ids = self.Tokenizer.apply_chat_template([self.roleSystem, {"role": "user", "content": text + "\n\nHere's the reorganized English article:"}], tokenize=False, add_generation_prompt=True)
343
- source = self.Tokenizer.convert_ids_to_tokens(self.Tokenizer.encode(input_ids))
344
- output = self.Model.generate_batch([source], max_length=max_length, max_batch_size=2, no_repeat_ngram_size=3, beam_size=2, sampling_temperature=0.7, sampling_topp=0.9, include_prompt_in_result=False, end_token=self.terminators)
345
- target = output[0]
346
- result = self.Tokenizer.decode(target.sequences_ids[0])
347
-
348
- if len(result) > 2:
349
- if result[0] == "\"" and result[len(result) - 1] == "\"":
350
- result = result[1:-1]
351
- elif result[0] == "'" and result[len(result) - 1] == "'":
352
- result = result[1:-1]
353
- elif result[0] == "「" and result[len(result) - 1] == "」":
354
- result = result[1:-1]
355
- elif result[0] == "『" and result[len(result) - 1] == "』":
356
- result = result[1:-1]
357
- except Exception as e:
358
- print(traceback.format_exc())
359
- print("Error reorganize text: " + str(e))
360
-
361
- return result
362
-
363
 
364
  class Predictor:
365
  def __init__(self):
@@ -412,7 +194,7 @@ class Predictor:
412
 
413
  padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
414
  padded_image.paste(image, (pad_left, pad_top))
415
-
416
  # Resize
417
  if max_dim != target_size:
418
  padded_image = padded_image.resize(
@@ -421,18 +203,21 @@ class Predictor:
421
  )
422
  # Convert to numpy array
423
  image_array = np.asarray(padded_image, dtype=np.float32)
424
-
425
  # Convert PIL-native RGB to BGR
426
  image_array = image_array[:, :, ::-1]
427
-
428
  return np.expand_dims(image_array, axis=0)
429
 
430
- def create_file(self, text: str, directory: str, fileName: str) -> str:
431
- # Write the text to a file
432
- with open(os.path.join(directory, fileName), 'w+', encoding="utf-8") as file:
433
- file.write(text)
434
 
435
- return file.name
436
 
437
  def predict(
438
  self,
@@ -465,14 +250,13 @@ class Predictor:
465
  progress(current_progress, desc="Initialize wd model finished")
466
  timer.checkpoint(f"Initialize wd model")
467
 
468
- # Result
469
  txt_infos = []
470
  output_dir = tempfile.mkdtemp()
471
  if not os.path.exists(output_dir):
472
  os.makedirs(output_dir)
473
 
474
  sorted_general_strings = ""
475
- # New code to create categorized output string
476
  categorized_output_strings = []
477
  rating = None
478
  character_res = None
@@ -567,6 +351,22 @@ class Predictor:
567
  # Collect all categorized output strings into a single string
568
  final_categorized_output = ', '.join(categorized_output_strings)
569
 
570
  current_progress += progressRatio/progressTotal;
571
  progress(current_progress, desc=f"image{idx:02d}, predict finished")
572
  timer.checkpoint(f"image{idx:02d}, predict finished")
@@ -602,15 +402,16 @@ class Predictor:
602
  print(traceback.format_exc())
603
  print("Error predict: " + str(e))
604
  # Result
 
605
  download = []
606
  if txt_infos is not None and len(txt_infos) > 0:
607
- downloadZipPath = os.path.join(output_dir, "images-tagger-" + datetime.now().strftime("%Y%m%d-%H%M%S") + ".zip")
608
  with zipfile.ZipFile(downloadZipPath, 'w', zipfile.ZIP_DEFLATED) as taggers_zip:
609
  for info in txt_infos:
610
  # Get file name from lookup
611
  taggers_zip.write(info["path"], arcname=info["name"])
612
  download.append(downloadZipPath)
613
-
614
  if llama3_reorganize_model_repo:
615
  llama3_reorganize.release_vram()
616
  del llama3_reorganize
@@ -620,11 +421,9 @@ class Predictor:
620
  print("Predict is complete.")
621
 
622
  return download, sorted_general_strings, final_categorized_output, classified_tags, rating, character_res, general_res, unclassified_tags, tag_results
623
-
624
  def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state: gr.SelectData):
625
  if not selected_state:
626
  return selected_state
627
-
628
  tag_result = {
629
  "strings": "",
630
  "strings2": "",
@@ -636,272 +435,80 @@ def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state:
636
  }
637
  if selected_state.value["image"]["path"] in tag_results:
638
  tag_result = tag_results[selected_state.value["image"]["path"]]
639
-
640
  return (selected_state.value["image"]["path"], selected_state.value["caption"]), tag_result["strings"], tag_result["strings2"], tag_result["classified_tags"], tag_result["rating"], tag_result["character_res"], tag_result["general_res"], tag_result["unclassified_tags"]
641
-
642
- def append_gallery(gallery: list, image: str):
643
- if gallery is None:
644
- gallery = []
645
- if not image:
646
- return gallery, None
647
-
648
- gallery.append(image)
649
-
650
- return gallery, None
651
-
652
-
653
- def extend_gallery(gallery: list, images):
654
- if gallery is None:
655
- gallery = []
656
- if not images:
657
- return gallery
658
-
659
- # Combine the new images with the existing gallery images
660
- gallery.extend(images)
661
-
662
- return gallery
663
-
664
- def remove_image_from_gallery(gallery: list, selected_image: str):
665
- if not gallery or not selected_image:
666
- return gallery
667
-
668
- selected_image = ast.literal_eval(selected_image) # Use ast.literal_eval to parse text into a tuple.
669
- # Remove the selected image from the gallery
670
- if selected_image in gallery:
671
- gallery.remove(selected_image)
672
- return gallery
673
-
674
  # END
675
 
676
- def fig_to_pil(fig):
677
- buf = io.BytesIO()
678
- fig.savefig(buf, format='png')
679
- buf.seek(0)
680
- return Image.open(buf)
681
-
682
  @spaces.GPU
683
- def run_example(task_prompt, image, text_input=None):
684
- if text_input is None:
685
- prompt = task_prompt
686
- else:
687
- prompt = task_prompt + text_input
688
- inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
689
- generated_ids = model.generate(
690
- input_ids=inputs["input_ids"],
691
- pixel_values=inputs["pixel_values"],
692
- max_new_tokens=1024,
693
- early_stopping=False,
694
- do_sample=False,
695
- num_beams=3,
696
- )
697
- generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
698
- parsed_answer = processor.post_process_generation(
699
- generated_text,
700
- task=task_prompt,
701
- image_size=(image.width, image.height)
702
- )
703
- return parsed_answer
704
-
705
- def plot_bbox(image, data):
706
- fig, ax = plt.subplots()
707
- ax.imshow(image)
708
- for bbox, label in zip(data['bboxes'], data['labels']):
709
- x1, y1, x2, y2 = bbox
710
- rect = patches.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=1, edgecolor='r', facecolor='none')
711
- ax.add_patch(rect)
712
- plt.text(x1, y1, label, color='white', fontsize=8, bbox=dict(facecolor='red', alpha=0.5))
713
- ax.axis('off')
714
- return fig
715
-
716
- def draw_polygons(image, prediction, fill_mask=False):
717
- draw = ImageDraw.Draw(image)
718
- scale = 1
719
- for polygons, label in zip(prediction['polygons'], prediction['labels']):
720
- color = random.choice(colormap)
721
- fill_color = random.choice(colormap) if fill_mask else None
722
- for _polygon in polygons:
723
- _polygon = np.array(_polygon).reshape(-1, 2)
724
- if len(_polygon) < 3:
725
- print('Invalid polygon:', _polygon)
726
- continue
727
- _polygon = (_polygon * scale).reshape(-1).tolist()
728
- if fill_mask:
729
- draw.polygon(_polygon, outline=color, fill=fill_color)
730
- else:
731
- draw.polygon(_polygon, outline=color)
732
- draw.text((_polygon[0] + 8, _polygon[1] + 2), label, fill=color)
733
- return image
734
-
735
- def convert_to_od_format(data):
736
- bboxes = data.get('bboxes', [])
737
- labels = data.get('bboxes_labels', [])
738
- od_results = {
739
- 'bboxes': bboxes,
740
- 'labels': labels
741
- }
742
- return od_results
743
-
744
- def draw_ocr_bboxes(image, prediction):
745
- scale = 1
746
- draw = ImageDraw.Draw(image)
747
- bboxes, labels = prediction['quad_boxes'], prediction['labels']
748
- for box, label in zip(bboxes, labels):
749
- color = random.choice(colormap)
750
- new_box = (np.array(box) * scale).tolist()
751
- draw.polygon(new_box, width=3, outline=color)
752
- draw.text((new_box[0]+8, new_box[1]+2),
753
- "{}".format(label),
754
- align="right",
755
- fill=color)
756
- return image
757
-
758
- def convert_to_od_format(data):
759
- bboxes = data.get('bboxes', [])
760
- labels = data.get('bboxes_labels', [])
761
- od_results = {
762
- 'bboxes': bboxes,
763
- 'labels': labels
764
- }
765
- return od_results
766
-
767
- def draw_ocr_bboxes(image, prediction):
768
- scale = 1
769
- draw = ImageDraw.Draw(image)
770
- bboxes, labels = prediction['quad_boxes'], prediction['labels']
771
- for box, label in zip(bboxes, labels):
772
- color = random.choice(colormap)
773
- new_box = (np.array(box) * scale).tolist()
774
- draw.polygon(new_box, width=3, outline=color)
775
- draw.text((new_box[0]+8, new_box[1]+2),
776
- "{}".format(label),
777
- align="right",
778
- fill=color)
779
- return image
780
- def process_image(image, task_prompt, text_input=None):
781
- # Test
782
- if isinstance(image, str): # If image is a file path
783
- image = Image.open(image) # Load image from file path
784
- else: # If image is a NumPy array
785
- image = Image.fromarray(image) # Convert NumPy array to PIL Image
786
- if task_prompt == 'Caption':
787
- task_prompt = '<CAPTION>'
788
- results = run_example(task_prompt, image)
789
- return results[task_prompt], None
790
- elif task_prompt == 'Detailed Caption':
791
- task_prompt = '<DETAILED_CAPTION>'
792
- results = run_example(task_prompt, image)
793
- return results[task_prompt], None
794
- elif task_prompt == 'More Detailed Caption':
795
- task_prompt = '<MORE_DETAILED_CAPTION>'
796
- results = run_example(task_prompt, image)
797
- return results, None
798
- elif task_prompt == 'Caption + Grounding':
799
- task_prompt = '<CAPTION>'
800
- results = run_example(task_prompt, image)
801
- text_input = results[task_prompt]
802
- task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
803
- results = run_example(task_prompt, image, text_input)
804
- results['<CAPTION>'] = text_input
805
- fig = plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
806
- return results, fig_to_pil(fig)
807
- elif task_prompt == 'Detailed Caption + Grounding':
808
- task_prompt = '<DETAILED_CAPTION>'
809
- results = run_example(task_prompt, image)
810
- text_input = results[task_prompt]
811
- task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
812
- results = run_example(task_prompt, image, text_input)
813
- results['<DETAILED_CAPTION>'] = text_input
814
- fig = plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
815
- return results, fig_to_pil(fig)
816
- elif task_prompt == 'More Detailed Caption + Grounding':
817
- task_prompt = '<MORE_DETAILED_CAPTION>'
818
- results = run_example(task_prompt, image)
819
- text_input = results[task_prompt]
820
- task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
821
- results = run_example(task_prompt, image, text_input)
822
- results['<MORE_DETAILED_CAPTION>'] = text_input
823
- fig = plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
824
- return results, fig_to_pil(fig)
825
- elif task_prompt == 'Object Detection':
826
- task_prompt = '<OD>'
827
- results = run_example(task_prompt, image)
828
- fig = plot_bbox(image, results['<OD>'])
829
- return results, fig_to_pil(fig)
830
- elif task_prompt == 'Dense Region Caption':
831
- task_prompt = '<DENSE_REGION_CAPTION>'
832
- results = run_example(task_prompt, image)
833
- fig = plot_bbox(image, results['<DENSE_REGION_CAPTION>'])
834
- return results, fig_to_pil(fig)
835
- elif task_prompt == 'Region Proposal':
836
- task_prompt = '<REGION_PROPOSAL>'
837
- results = run_example(task_prompt, image)
838
- fig = plot_bbox(image, results['<REGION_PROPOSAL>'])
839
- return results, fig_to_pil(fig)
840
- elif task_prompt == 'Caption to Phrase Grounding':
841
- task_prompt = '<CAPTION_TO_PHRASE_GROUNDING>'
842
- results = run_example(task_prompt, image, text_input)
843
- fig = plot_bbox(image, results['<CAPTION_TO_PHRASE_GROUNDING>'])
844
- return results, fig_to_pil(fig)
845
- elif task_prompt == 'Referring Expression Segmentation':
846
- task_prompt = '<REFERRING_EXPRESSION_SEGMENTATION>'
847
- results = run_example(task_prompt, image, text_input)
848
- output_image = copy.deepcopy(image)
849
- output_image = draw_polygons(output_image, results['<REFERRING_EXPRESSION_SEGMENTATION>'], fill_mask=True)
850
- return results, output_image
851
- elif task_prompt == 'Region to Segmentation':
852
- task_prompt = '<REGION_TO_SEGMENTATION>'
853
- results = run_example(task_prompt, image, text_input)
854
- output_image = copy.deepcopy(image)
855
- output_image = draw_polygons(output_image, results['<REGION_TO_SEGMENTATION>'], fill_mask=True)
856
- return results, output_image
857
- elif task_prompt == 'Open Vocabulary Detection':
858
- task_prompt = '<OPEN_VOCABULARY_DETECTION>'
859
- results = run_example(task_prompt, image, text_input)
860
- bbox_results = convert_to_od_format(results['<OPEN_VOCABULARY_DETECTION>'])
861
- fig = plot_bbox(image, bbox_results)
862
- return results, fig_to_pil(fig)
863
- elif task_prompt == 'Region to Category':
864
- task_prompt = '<REGION_TO_CATEGORY>'
865
- results = run_example(task_prompt, image, text_input)
866
- return results, None
867
- elif task_prompt == 'Region to Description':
868
- task_prompt = '<REGION_TO_DESCRIPTION>'
869
- results = run_example(task_prompt, image, text_input)
870
- return results, None
871
- elif task_prompt == 'OCR':
872
- task_prompt = '<OCR>'
873
- results = run_example(task_prompt, image)
874
- return results, None
875
- elif task_prompt == 'OCR with Region':
876
- task_prompt = '<OCR_WITH_REGION>'
877
- results = run_example(task_prompt, image)
878
- output_image = copy.deepcopy(image)
879
- output_image = draw_ocr_bboxes(output_image, results['<OCR_WITH_REGION>'])
880
- return results, output_image
881
- else:
882
- return "", None # Return empty string and None for unknown task prompts
883
 
884
- # Custom CSS to set the height of the gr.Dropdown menu
885
- css = """
886
- div.progress-level div.progress-level-inner {
887
- text-align: left !important;
888
- width: 55.5% !important;
889
- #output {
890
- height: 500px;
891
- overflow: auto;
892
- border: 1px solid #ccc;
893
- }
894
- """
895
- single_task_list =[
896
- 'Caption', 'Detailed Caption', 'More Detailed Caption', 'Object Detection',
897
- 'Dense Region Caption', 'Region Proposal', 'Caption to Phrase Grounding',
898
- 'Referring Expression Segmentation', 'Region to Segmentation',
899
- 'Open Vocabulary Detection', 'Region to Category', 'Region to Description',
900
- 'OCR', 'OCR with Region'
901
- ]
902
- cascaded_task_list =[
903
- 'Caption + Grounding', 'Detailed Caption + Grounding', 'More Detailed Caption + Grounding'
904
- ]
905
  def update_task_dropdown(choice):
906
  if choice == 'Cascaded task':
907
  return gr.Dropdown(choices=cascaded_task_list, value='Caption + Grounding')
@@ -909,7 +516,6 @@ def update_task_dropdown(choice):
909
  return gr.Dropdown(choices=single_task_list, value='Caption')
910
 
911
  args = parse_args()
912
-
913
  predictor = Predictor()
914
 
915
  dropdown_list = [
@@ -933,19 +539,23 @@ llama_list = [
933
  META_LLAMA_3_8B_REPO,
934
  ]
935
 
936
- # This is workaround will make the space restart every 2 days. (for test).
937
  def _restart_space():
938
- HF_TOKEN = os.getenv("HF_TOKEN")
939
- if not HF_TOKEN:
940
- raise ValueError("HF_TOKEN environment variable is not set.")
941
- huggingface_hub.HfApi().restart_space(repo_id="Werli/Multi-Tagger", token=HF_TOKEN, factory_reboot=False)
942
- scheduler = BackgroundScheduler()
943
  # Add a job to restart the space every 2 days (172800 seconds)
944
  restart_space_job = scheduler.add_job(_restart_space, "interval", seconds=172800)
945
- # Start the scheduler
946
  scheduler.start()
947
- next_run_time_utc = restart_space_job.next_run_time.astimezone(timezone.utc)
948
- NEXT_RESTART = f"Next Restart: {next_run_time_utc.strftime('%Y-%m-%d %H:%M:%S')} (UTC) - The space will restart every 2 days to ensure stability and performance. It uses a background scheduler to handle the restart process."
949
 
950
  with gr.Blocks(title=TITLE, css=css, theme="Werli/Multi-Tagger", fill_width=True) as demo:
951
  gr.Markdown(value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>")
@@ -1030,7 +640,7 @@ with gr.Blocks(title=TITLE, css=css, theme="Werli/Multi-Tagger", fill_width=True
1030
  size="lg",
1031
  )
1032
  with gr.Column(variant="panel"):
1033
- download_file = gr.File(label="Output (Download)") # 0
1034
  character_res = gr.Label(label="Output (characters)") # 1
1035
  sorted_general_strings = gr.Textbox(label="Output (string)", show_label=True, show_copy_button=True) # 2
1036
  final_categorized_output = gr.Textbox(label="Categorized Output (string)", show_label=True, show_copy_button=True) # 3
@@ -1112,5 +722,4 @@ with gr.Blocks(title=TITLE, css=css, theme="Werli/Multi-Tagger", fill_width=True
1112
  label='Try examples'
1113
  )
1114
  submit_btn.click(process_image, [input_img, task_prompt, text_input], [output_text, output_img])
1115
-
1116
  demo.queue(max_size=2).launch()
 
1
  import os
2
+ import io,copy,requests,numpy as np,spaces,gradio as gr
3
+ from transformers import AutoProcessor,AutoModelForCausalLM,AutoModelForCausalLM,AutoProcessor
4
  from transformers.dynamic_module_utils import get_imports
5
+ from PIL import Image,ImageDraw,ImageFont
6
+ import matplotlib.pyplot as plt,matplotlib.patches as patches
 
7
  from unittest.mock import patch
8
+ import argparse,huggingface_hub,onnxruntime as rt,pandas as pd,traceback,tempfile,zipfile,re,ast,time
9
+ from datetime import datetime,timezone
10
  from collections import defaultdict
11
  from classifyTags import classify_tags
 
12
  from apscheduler.schedulers.background import BackgroundScheduler
13
+ import json
14
+ os.environ['PYTORCH_ENABLE_MPS_FALLBACK']='1'
15
+
16
+ def fixed_get_imports(filename:str|os.PathLike)->list[str]:
17
+ if not str(filename).endswith('/modeling_florence2.py'):return get_imports(filename)
18
+ imports=get_imports(filename)
19
+ if'flash_attn'in imports:imports.remove('flash_attn')
20
+ return imports
21
  @spaces.GPU
22
  def get_device_type():
23
+ import torch
24
+ if torch.cuda.is_available():return'cuda'
25
+ elif torch.backends.mps.is_available()and torch.backends.mps.is_built():return'mps'
26
+ else:return'cpu'
27
+
28
  model_id = 'MiaoshouAI/Florence-2-base-PromptGen-v2.0'
29
 
30
  import subprocess
 
35
  processor = AutoProcessor.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v2.0", trust_remote_code=True)
36
  model.to(device)
37
  else:
 
38
  with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
39
  model = AutoModelForCausalLM.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v2.0", trust_remote_code=True)
40
  processor = AutoProcessor.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v2.0", trust_remote_code=True)
 
48
  - Supports batch processing of multiple images.
49
  - Tags images with multiple categories: general tags, character tags, and ratings.
50
  - Displays categorized tags in a structured format.
51
+ - Includes a separate tab for image captioning using Florence 2. Supports CUDA, MPS or CPU if one of them is available.
52
+ - Supports various captioning tasks (e.g., Caption, Detailed Caption, Object Detection), it can display output text and images for tasks that generate visual outputs.
53
 
54
  Example image by [me.](https://huggingface.co/Werli)
55
  """
56
+ colormap=['blue','orange','green','purple','brown','pink','gray','olive','cyan','red','lime','indigo','violet','aqua','magenta','coral','gold','tan','skyblue']
 
57
 
58
  # Dataset v3 series of models:
59
  SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
 
61
  VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
62
  VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
63
  EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
 
64
  # Dataset v2 series of models:
65
  MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
66
  SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
67
  CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
68
  CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
69
  VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
 
70
  # IdolSankaku series of models:
71
  EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
72
  SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"
 
73
  # Files to download from the repos
74
  MODEL_FILENAME = "model.onnx"
75
  LABEL_FILENAME = "selected_tags.csv"
 
76
  # LLAMA model
77
  META_LLAMA_3_3B_REPO = "jncraton/Llama-3.2-3B-Instruct-ct2-int8"
78
  META_LLAMA_3_8B_REPO = "avans06/Meta-Llama-3.2-8B-Instruct-ct2-int8_float16"
79
 
80
+ kaomojis=['0_0','(o)_(o)','+_+','+_-','._.','<o>_<o>','<|>_<|>','=_=','>_<','3_3','6_9','>_o','@_@','^_^','o_o','u_u','x_x','|_|','||_||']
81
+ def parse_args()->argparse.Namespace:parser=argparse.ArgumentParser();parser.add_argument('--score-slider-step',type=float,default=.05);parser.add_argument('--score-general-threshold',type=float,default=.35);parser.add_argument('--score-character-threshold',type=float,default=.85);parser.add_argument('--share',action='store_true');return parser.parse_args()
82
+ def load_labels(dataframe)->list[str]:name_series=dataframe['name'];name_series=name_series.map(lambda x:x.replace('_',' ')if x not in kaomojis else x);tag_names=name_series.tolist();rating_indexes=list(np.where(dataframe['category']==9)[0]);general_indexes=list(np.where(dataframe['category']==0)[0]);character_indexes=list(np.where(dataframe['category']==4)[0]);return tag_names,rating_indexes,general_indexes,character_indexes
83
+ def mcut_threshold(probs):sorted_probs=probs[probs.argsort()[::-1]];difs=sorted_probs[:-1]-sorted_probs[1:];t=difs.argmax();thresh=(sorted_probs[t]+sorted_probs[t+1])/2;return thresh
84
+
85
  class Timer:
86
+ def __init__(self):self.start_time=time.perf_counter();self.checkpoints=[('Start',self.start_time)]
87
+ def checkpoint(self,label='Checkpoint'):now=time.perf_counter();self.checkpoints.append((label,now))
88
+ def report(self,is_clear_checkpoints=True):
89
+ max_label_length=max(len(label)for(label,_)in self.checkpoints);prev_time=self.checkpoints[0][1]
90
+ for(label,curr_time)in self.checkpoints[1:]:elapsed=curr_time-prev_time;print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds");prev_time=curr_time
91
+ if is_clear_checkpoints:self.checkpoints.clear();self.checkpoint()
92
+ def report_all(self):
93
+ print('\n> Execution Time Report:');max_label_length=max(len(label)for(label,_)in self.checkpoints)if len(self.checkpoints)>0 else 0;prev_time=self.start_time
94
+ for(label,curr_time)in self.checkpoints[1:]:elapsed=curr_time-prev_time;print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds");prev_time=curr_time
95
+ total_time=self.checkpoints[-1][1]-self.start_time;print(f"{'Total Execution Time'.ljust(max_label_length)}: {total_time:.3f} seconds\n");self.checkpoints.clear()
96
+ def restart(self):self.start_time=time.perf_counter();self.checkpoints=[('Start',self.start_time)]
97
 
98
  class Llama3Reorganize:
99
+ def __init__(self,repoId:str,device:str=None,loadModel:bool=False):
100
+ self.modelPath=self.download_model(repoId)
101
+ if device is None:
102
+ import torch;self.totalVram=0
103
+ if torch.cuda.is_available():
104
+ try:deviceId=torch.cuda.current_device();self.totalVram=torch.cuda.get_device_properties(deviceId).total_memory/1073741824
105
+ except Exception as e:print(traceback.format_exc());print('Error detect vram: '+str(e))
106
+ device='cuda'if self.totalVram>(8 if'8B'in repoId else 4)else'cpu'
107
+ else:device='cpu'
108
+ self.device=device;self.system_prompt='Please reorganize the following English labels describing a single image into a readable English article of no more than 300 words. If there are name labels, include them in the article. Since the target audience is not general, explicit content can be written in full without hesitation. No introduction is needed; directly reply with the English article:'
109
+ if loadModel:self.load_model()
110
+
111
+ def download_model(self,repoId):
112
+ import warnings,requests;allowPatterns=['config.json','generation_config.json','model.bin','pytorch_model.bin','pytorch_model.bin.index.json','pytorch_model-*.bin','sentencepiece.bpe.model','tokenizer.json','tokenizer_config.json','shared_vocabulary.txt','shared_vocabulary.json','special_tokens_map.json','spiece.model','vocab.json','model.safetensors','model-*.safetensors','model.safetensors.index.json','quantize_config.json','tokenizer.model','vocabulary.json','preprocessor_config.json','added_tokens.json'];kwargs={'allow_patterns':allowPatterns}
113
+ try:return huggingface_hub.snapshot_download(repoId,**kwargs)
114
+ except(huggingface_hub.utils.HfHubHTTPError,requests.exceptions.ConnectionError)as exception:warnings.warn('An error occured while synchronizing the model %s from the Hugging Face Hub:\n%s',repoId,exception);warnings.warn('Trying to load the model directly from the local cache, if it exists.');kwargs['local_files_only']=True;return huggingface_hub.snapshot_download(repoId,**kwargs)
115
 
116
  def load_model(self):
117
+ import ctranslate2,transformers
118
+ try:print('\n\nLoading model: %s\n\n'%self.modelPath);kwargsTokenizer={'pretrained_model_name_or_path':self.modelPath};kwargsModel={'device':self.device,'model_path':self.modelPath,'compute_type':'auto'};self.roleSystem={'role':'system','content':self.system_prompt};self.Model=ctranslate2.Generator(**kwargsModel);self.Tokenizer=transformers.AutoTokenizer.from_pretrained(**kwargsTokenizer);self.terminators=[self.Tokenizer.eos_token_id,self.Tokenizer.convert_tokens_to_ids('<|eot_id|>')]
119
+ except Exception as e:self.release_vram();raise e
120
 
121
  def release_vram(self):
122
+ try:
123
+ import torch
124
+ if torch.cuda.is_available():
125
+ if getattr(self,'Model',None)is not None and getattr(self.Model,'unload_model',None)is not None:self.Model.unload_model()
126
+ if getattr(self,'Tokenizer',None)is not None:del self.Tokenizer
127
+ if getattr(self,'Model',None)is not None:del self.Model
128
+ import gc;gc.collect()
129
+ try:torch.cuda.empty_cache()
130
+ except Exception as e:print(traceback.format_exc());print('\tcuda empty cache, error: '+str(e))
131
+ print('release vram end.')
132
+ except Exception as e:print(traceback.format_exc());print('Error release vram: '+str(e))
133
+
134
+ def reorganize(self,text:str,max_length:int=400):
135
+ output=None;result=None
136
+ try:
137
+ input_ids=self.Tokenizer.apply_chat_template([self.roleSystem,{'role':'user','content':text+"\n\nHere's the reorganized English article:"}],tokenize=False,add_generation_prompt=True);source=self.Tokenizer.convert_ids_to_tokens(self.Tokenizer.encode(input_ids));output=self.Model.generate_batch([source],max_length=max_length,max_batch_size=2,no_repeat_ngram_size=3,beam_size=2,sampling_temperature=.7,sampling_topp=.9,include_prompt_in_result=False,end_token=self.terminators);target=output[0];result=self.Tokenizer.decode(target.sequences_ids[0])
138
+ if len(result)>2:
139
+ if result[0]=='"'and result[len(result)-1]=='"':result=result[1:-1]
140
+ elif result[0]=="'"and result[len(result)-1]=="'":result=result[1:-1]
141
+ elif result[0]=='「'and result[len(result)-1]=='」':result=result[1:-1]
142
+ elif result[0]=='『'and result[len(result)-1]=='』':result=result[1:-1]
143
+ except Exception as e:print(traceback.format_exc());print('Error reorganize text: '+str(e))
144
+ return result
145
 
146
  class Predictor:
147
  def __init__(self):
 
194
 
195
  padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
196
  padded_image.paste(image, (pad_left, pad_top))
197
+
198
  # Resize
199
  if max_dim != target_size:
200
  padded_image = padded_image.resize(
 
203
  )
204
  # Convert to numpy array
205
  image_array = np.asarray(padded_image, dtype=np.float32)
 
206
  # Convert PIL-native RGB to BGR
207
  image_array = image_array[:, :, ::-1]
 
208
  return np.expand_dims(image_array, axis=0)
209
 
210
+ def create_file(self, content: str, directory: str, fileName: str) -> str:
211
+ # Write the content to a file
212
+ file_path = os.path.join(directory, fileName)
213
+ if fileName.endswith('.json'):
214
+ with open(file_path, 'w', encoding="utf-8") as file:
215
+ file.write(content)
216
+ else:
217
+ with open(file_path, 'w+', encoding="utf-8") as file:
218
+ file.write(content)
219
 
220
+ return file_path
221
 
222
  def predict(
223
  self,
 
250
  progress(current_progress, desc="Initialize wd model finished")
251
  timer.checkpoint(f"Initialize wd model")
252
 
 
253
  txt_infos = []
254
  output_dir = tempfile.mkdtemp()
255
  if not os.path.exists(output_dir):
256
  os.makedirs(output_dir)
257
 
258
  sorted_general_strings = ""
259
+ # Create categorized output string
260
  categorized_output_strings = []
261
  rating = None
262
  character_res = None
 
351
  # Collect all categorized output strings into a single string
352
  final_categorized_output = ', '.join(categorized_output_strings)
353
 
354
+ # Create a .txt file for "Output (string)" and "Categorized Output (string)"
355
+ txt_content = f"Output (string): {sorted_general_strings}\nCategorized Output (string): {final_categorized_output}"
356
+ txt_file = self.create_file(txt_content, output_dir, f"{image_name}_output.txt")
357
+ txt_infos.append({"path": txt_file, "name": f"{image_name}_output.txt"})
358
+
359
+ # Create a .json file for "Categorized (tags)"
360
+ json_content = json.dumps(classified_tags, indent=4)
361
+ json_file = self.create_file(json_content, output_dir, f"{image_name}_categorized_tags.json")
362
+ txt_infos.append({"path": json_file, "name": f"{image_name}_categorized_tags.json"})
363
+
364
+ # Save a copy of the uploaded image in PNG format
365
+ image_path = value[0]
366
+ image = Image.open(image_path)
367
+ image.save(os.path.join(output_dir, f"{image_name}.png"), format="PNG")
368
+ txt_infos.append({"path": os.path.join(output_dir, f"{image_name}.png"), "name": f"{image_name}.png"})
369
+
370
  current_progress += progressRatio/progressTotal;
371
  progress(current_progress, desc=f"image{idx:02d}, predict finished")
372
  timer.checkpoint(f"image{idx:02d}, predict finished")
 
402
  print(traceback.format_exc())
403
  print("Error predict: " + str(e))
404
  # Result
405
+ # Zip creation logic:
406
  download = []
407
  if txt_infos is not None and len(txt_infos) > 0:
408
+ downloadZipPath = os.path.join(output_dir, "Multi-tagger-" + datetime.now().strftime("%Y%m%d-%H%M%S") + ".zip")
409
  with zipfile.ZipFile(downloadZipPath, 'w', zipfile.ZIP_DEFLATED) as taggers_zip:
410
  for info in txt_infos:
411
  # Get file name from lookup
412
  taggers_zip.write(info["path"], arcname=info["name"])
413
  download.append(downloadZipPath)
414
+ # End zip creation logic
415
  if llama3_reorganize_model_repo:
416
  llama3_reorganize.release_vram()
417
  del llama3_reorganize
 
421
  print("Predict is complete.")
422
 
423
  return download, sorted_general_strings, final_categorized_output, classified_tags, rating, character_res, general_res, unclassified_tags, tag_results
 
424
  def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state: gr.SelectData):
425
  if not selected_state:
426
  return selected_state
 
427
  tag_result = {
428
  "strings": "",
429
  "strings2": "",
 
435
  }
436
  if selected_state.value["image"]["path"] in tag_results:
437
  tag_result = tag_results[selected_state.value["image"]["path"]]
 
438
  return (selected_state.value["image"]["path"], selected_state.value["caption"]), tag_result["strings"], tag_result["strings2"], tag_result["classified_tags"], tag_result["rating"], tag_result["character_res"], tag_result["general_res"], tag_result["unclassified_tags"]
439
+ def append_gallery(gallery:list,image:str):
440
+ if gallery is None:gallery=[]
441
+ if not image:return gallery,None
442
+ gallery.append(image);return gallery,None
443
+ def extend_gallery(gallery:list,images):
444
+ if gallery is None:gallery=[]
445
+ if not images:return gallery
446
+ gallery.extend(images);return gallery
447
+ def remove_image_from_gallery(gallery:list,selected_image:str):
448
+ if not gallery or not selected_image:return gallery
449
+ selected_image=ast.literal_eval(selected_image)
450
+ if selected_image in gallery:gallery.remove(selected_image)
451
+ return gallery
452
  # END
453
 
454
+ def fig_to_pil(fig):buf=io.BytesIO();fig.savefig(buf,format='png');buf.seek(0);return Image.open(buf)
455
  @spaces.GPU
456
+ def run_example(task_prompt,image,text_input=None):
457
+ if text_input is None:prompt=task_prompt
458
+ else:prompt=task_prompt+text_input
459
+ inputs=processor(text=prompt,images=image,return_tensors='pt').to(device);generated_ids=model.generate(input_ids=inputs['input_ids'],pixel_values=inputs['pixel_values'],max_new_tokens=1024,early_stopping=False,do_sample=False,num_beams=3);generated_text=processor.batch_decode(generated_ids,skip_special_tokens=False)[0];parsed_answer=processor.post_process_generation(generated_text,task=task_prompt,image_size=(image.width,image.height));return parsed_answer
460
+ def plot_bbox(image,data):
461
+ fig,ax=plt.subplots();ax.imshow(image)
462
+ for(bbox,label)in zip(data['bboxes'],data['labels']):x1,y1,x2,y2=bbox;rect=patches.Rectangle((x1,y1),x2-x1,y2-y1,linewidth=1,edgecolor='r',facecolor='none');ax.add_patch(rect);plt.text(x1,y1,label,color='white',fontsize=8,bbox=dict(facecolor='red',alpha=.5))
463
+ ax.axis('off');return fig
464
+ def draw_polygons(image,prediction,fill_mask=False):
465
+ draw=ImageDraw.Draw(image);scale=1
466
+ for(polygons,label)in zip(prediction['polygons'],prediction['labels']):
467
+ color=random.choice(colormap);fill_color=random.choice(colormap)if fill_mask else None
468
+ for _polygon in polygons:
469
+ _polygon=np.array(_polygon).reshape(-1,2)
470
+ if len(_polygon)<3:print('Invalid polygon:',_polygon);continue
471
+ _polygon=(_polygon*scale).reshape(-1).tolist()
472
+ if fill_mask:draw.polygon(_polygon,outline=color,fill=fill_color)
473
+ else:draw.polygon(_polygon,outline=color)
474
+ draw.text((_polygon[0]+8,_polygon[1]+2),label,fill=color)
475
+ return image
476
+ def convert_to_od_format(data):bboxes=data.get('bboxes',[]);labels=data.get('bboxes_labels',[]);od_results={'bboxes':bboxes,'labels':labels};return od_results
477
+ def draw_ocr_bboxes(image,prediction):
478
+ scale=1;draw=ImageDraw.Draw(image);bboxes,labels=prediction['quad_boxes'],prediction['labels']
479
+ for(box,label)in zip(bboxes,labels):color=random.choice(colormap);new_box=(np.array(box)*scale).tolist();draw.polygon(new_box,width=3,outline=color);draw.text((new_box[0]+8,new_box[1]+2),'{}'.format(label),align='right',fill=color)
480
+ return image
481
+ def convert_to_od_format(data):bboxes=data.get('bboxes',[]);labels=data.get('bboxes_labels',[]);od_results={'bboxes':bboxes,'labels':labels};return od_results
482
+ def draw_ocr_bboxes(image,prediction):
483
+ scale=1;draw=ImageDraw.Draw(image);bboxes,labels=prediction['quad_boxes'],prediction['labels']
484
+ for(box,label)in zip(bboxes,labels):color=random.choice(colormap);new_box=(np.array(box)*scale).tolist();draw.polygon(new_box,width=3,outline=color);draw.text((new_box[0]+8,new_box[1]+2),'{}'.format(label),align='right',fill=color)
485
+ return image
486
+
487
+ def process_image(image,task_prompt,text_input=None):
488
+ if isinstance(image,str):image=Image.open(image)
489
+ else:image=Image.fromarray(image)
490
+ if task_prompt=='Caption':task_prompt='<CAPTION>';results=run_example(task_prompt,image);return results[task_prompt],None
491
+ elif task_prompt=='Detailed Caption':task_prompt='<DETAILED_CAPTION>';results=run_example(task_prompt,image);return results[task_prompt],None
492
+ elif task_prompt=='More Detailed Caption':task_prompt='<MORE_DETAILED_CAPTION>';results=run_example(task_prompt,image);return results,None
493
+ elif task_prompt=='Caption + Grounding':task_prompt='<CAPTION>';results=run_example(task_prompt,image);text_input=results[task_prompt];task_prompt='<CAPTION_TO_PHRASE_GROUNDING>';results=run_example(task_prompt,image,text_input);results['<CAPTION>']=text_input;fig=plot_bbox(image,results['<CAPTION_TO_PHRASE_GROUNDING>']);return results,fig_to_pil(fig)
494
+ elif task_prompt=='Detailed Caption + Grounding':task_prompt='<DETAILED_CAPTION>';results=run_example(task_prompt,image);text_input=results[task_prompt];task_prompt='<CAPTION_TO_PHRASE_GROUNDING>';results=run_example(task_prompt,image,text_input);results['<DETAILED_CAPTION>']=text_input;fig=plot_bbox(image,results['<CAPTION_TO_PHRASE_GROUNDING>']);return results,fig_to_pil(fig)
495
+ elif task_prompt=='More Detailed Caption + Grounding':task_prompt='<MORE_DETAILED_CAPTION>';results=run_example(task_prompt,image);text_input=results[task_prompt];task_prompt='<CAPTION_TO_PHRASE_GROUNDING>';results=run_example(task_prompt,image,text_input);results['<MORE_DETAILED_CAPTION>']=text_input;fig=plot_bbox(image,results['<CAPTION_TO_PHRASE_GROUNDING>']);return results,fig_to_pil(fig)
496
+ elif task_prompt=='Object Detection':task_prompt='<OD>';results=run_example(task_prompt,image);fig=plot_bbox(image,results['<OD>']);return results,fig_to_pil(fig)
497
+ elif task_prompt=='Dense Region Caption':task_prompt='<DENSE_REGION_CAPTION>';results=run_example(task_prompt,image);fig=plot_bbox(image,results['<DENSE_REGION_CAPTION>']);return results,fig_to_pil(fig)
498
+ elif task_prompt=='Region Proposal':task_prompt='<REGION_PROPOSAL>';results=run_example(task_prompt,image);fig=plot_bbox(image,results['<REGION_PROPOSAL>']);return results,fig_to_pil(fig)
499
+ elif task_prompt=='Caption to Phrase Grounding':task_prompt='<CAPTION_TO_PHRASE_GROUNDING>';results=run_example(task_prompt,image,text_input);fig=plot_bbox(image,results['<CAPTION_TO_PHRASE_GROUNDING>']);return results,fig_to_pil(fig)
500
+ elif task_prompt=='Referring Expression Segmentation':task_prompt='<REFERRING_EXPRESSION_SEGMENTATION>';results=run_example(task_prompt,image,text_input);output_image=copy.deepcopy(image);output_image=draw_polygons(output_image,results['<REFERRING_EXPRESSION_SEGMENTATION>'],fill_mask=True);return results,output_image
501
+ elif task_prompt=='Region to Segmentation':task_prompt='<REGION_TO_SEGMENTATION>';results=run_example(task_prompt,image,text_input);output_image=copy.deepcopy(image);output_image=draw_polygons(output_image,results['<REGION_TO_SEGMENTATION>'],fill_mask=True);return results,output_image
502
+ elif task_prompt=='Open Vocabulary Detection':task_prompt='<OPEN_VOCABULARY_DETECTION>';results=run_example(task_prompt,image,text_input);bbox_results=convert_to_od_format(results['<OPEN_VOCABULARY_DETECTION>']);fig=plot_bbox(image,bbox_results);return results,fig_to_pil(fig)
503
+ elif task_prompt=='Region to Category':task_prompt='<REGION_TO_CATEGORY>';results=run_example(task_prompt,image,text_input);return results,None
504
+ elif task_prompt=='Region to Description':task_prompt='<REGION_TO_DESCRIPTION>';results=run_example(task_prompt,image,text_input);return results,None
505
+ elif task_prompt=='OCR':task_prompt='<OCR>';results=run_example(task_prompt,image);return results,None
506
+ elif task_prompt=='OCR with Region':task_prompt='<OCR_WITH_REGION>';results=run_example(task_prompt,image);output_image=copy.deepcopy(image);output_image=draw_ocr_bboxes(output_image,results['<OCR_WITH_REGION>']);return results,output_image
507
+ else:return'',None # Return empty string and None for unknown task prompts
508
+
509
+ single_task_list=['Caption','Detailed Caption','More Detailed Caption','Object Detection','Dense Region Caption','Region Proposal','Caption to Phrase Grounding','Referring Expression Segmentation','Region to Segmentation','Open Vocabulary Detection','Region to Category','Region to Description','OCR','OCR with Region']
510
+ cascaded_task_list=['Caption + Grounding','Detailed Caption + Grounding','More Detailed Caption + Grounding']
511
 
512
  def update_task_dropdown(choice):
513
  if choice == 'Cascaded task':
514
  return gr.Dropdown(choices=cascaded_task_list, value='Caption + Grounding')
 
516
  return gr.Dropdown(choices=single_task_list, value='Caption')
517
 
518
  args = parse_args()
 
519
  predictor = Predictor()
520
 
521
  dropdown_list = [
 
539
  META_LLAMA_3_8B_REPO,
540
  ]
541
 
 
542
  def _restart_space():
543
+ HF_TOKEN=os.getenv('HF_TOKEN')
544
+ if not HF_TOKEN:raise ValueError('HF_TOKEN environment variable is not set.')
545
+ huggingface_hub.HfApi().restart_space(repo_id='Werli/Multi-Tagger',token=HF_TOKEN,factory_reboot=False)
546
+ scheduler=BackgroundScheduler()
 
547
  # Add a job to restart the space every 2 days (172800 seconds)
548
  restart_space_job = scheduler.add_job(_restart_space, "interval", seconds=172800)
 
549
  scheduler.start()
550
+ next_run_time_utc=restart_space_job.next_run_time.astimezone(timezone.utc)
551
+ NEXT_RESTART=f"Next Restart: {next_run_time_utc.strftime('%Y-%m-%d %H:%M:%S')} (UTC) - The space will restart every 2 days to ensure stability and performance. It uses a background scheduler to handle the restart process."
552
+
553
+ css = """
554
+ div.progress-level div.progress-level-inner {text-align: left !important; width: 55.5% !important;}
555
+ #output {height: 500px; overflow: auto; border: 1px solid #ccc;}
556
+ label.float.svelte-i3tvor {position: relative !important;}
557
+ .reduced-height.svelte-11chud3 {height: calc(80% - var(--size-10));}
558
+ """
559
 
560
  with gr.Blocks(title=TITLE, css=css, theme="Werli/Multi-Tagger", fill_width=True) as demo:
561
  gr.Markdown(value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>")
 
640
  size="lg",
641
  )
642
  with gr.Column(variant="panel"):
643
+ download_file = gr.File(label="Output (Download All)") # 0
644
  character_res = gr.Label(label="Output (characters)") # 1
645
  sorted_general_strings = gr.Textbox(label="Output (string)", show_label=True, show_copy_button=True) # 2
646
  final_categorized_output = gr.Textbox(label="Categorized Output (string)", show_label=True, show_copy_button=True) # 3
 
722
  label='Try examples'
723
  )
724
  submit_btn.click(process_image, [input_img, task_prompt, text_input], [output_text, output_img])
 
725
  demo.queue(max_size=2).launch()