Werli committed on
Commit cdb99b8 · verified · 1 Parent(s): 344e0da

Upload 5 files


Added "Prompt Enhancer" from [John6666/danbooru-tags-transformer-v2-with-wd-tagger](https://huggingface.co/spaces/John6666/danbooru-tags-transformer-v2-with-wd-tagger/blob/main/tagger/promptenhancer.py) (Thanks!) and cleared some code.

app.py CHANGED

@@ -11,6 +11,7 @@ import json
 from modules.classifyTags import classify_tags,process_tags
 from modules.florence2 import process_image,single_task_list,update_task_dropdown
 from modules.llama_loader import llama_list,llama3reorganize
+from modules.tag_enhancer import prompt_enhancer
 os.environ['PYTORCH_ENABLE_MPS_FALLBACK']='1'
 
 TITLE = "Multi-Tagger"

@@ -249,9 +250,9 @@ class Predictor:
     reverse=True,
 )
 sorted_general_list = [x[0] for x in sorted_general_list]
-#Remove values from character_list that already exist in sorted_general_list
+# Remove values from character_list that already exist in sorted_general_list
 character_list = [item for item in character_list if item not in sorted_general_list]
-#Remove values from sorted_general_list that already exist in prepend_list or append_list
+# Remove values from sorted_general_list that already exist in prepend_list or append_list
 if prepend_list:
     sorted_general_list = [item for item in sorted_general_list if item not in prepend_list]
 if append_list:

@@ -312,7 +313,8 @@ class Predictor:
     "rating": rating,
     "character_res": character_res,
     "general_res": general_res,
-    "unclassified_tags": unclassified_tags
+    "unclassified_tags": unclassified_tags,
+    "enhanced_tags": "" # Initialize as empty string
 }
 
 timer.report()

@@ -348,11 +350,12 @@ def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state:
     "rating": "",
     "character_res": "",
    "general_res": "",
-    "unclassified_tags": "{}"
+    "unclassified_tags": "{}",
+    "enhanced_tags": ""
 }
 if selected_state.value["image"]["path"] in tag_results:
     tag_result = tag_results[selected_state.value["image"]["path"]]
-    return (selected_state.value["image"]["path"], selected_state.value["caption"]), tag_result["strings"], tag_result["strings2"], tag_result["classified_tags"], tag_result["rating"], tag_result["character_res"], tag_result["general_res"], tag_result["unclassified_tags"]
+    return (selected_state.value["image"]["path"], selected_state.value["caption"]), tag_result["strings"], tag_result["strings2"], tag_result["classified_tags"], tag_result["rating"], tag_result["character_res"], tag_result["general_res"], tag_result["unclassified_tags"], tag_result["enhanced_tags"]
 def append_gallery(gallery:list,image:str):
     if gallery is None:gallery=[]
     if not image:return gallery,None

@@ -417,7 +420,6 @@ with gr.Blocks(title=TITLE, css=css, theme="Werli/Multi-Tagger", fill_width=True
 upload_button = gr.UploadButton("Upload multiple images", file_types=["image"], file_count="multiple", size="sm")
 remove_button = gr.Button("Remove Selected Image", size="sm")
 gallery = gr.Gallery(columns=5, rows=5, show_share_button=False, interactive=True, height="500px", label="Grid of images")
-
 model_repo = gr.Dropdown(
     dropdown_list,
     value=EVA02_LARGE_MODEL_DSV3_REPO,

@@ -485,14 +487,17 @@ with gr.Blocks(title=TITLE, css=css, theme="Werli/Multi-Tagger", fill_width=True
     size="lg",
 )
 with gr.Column(variant="panel"):
-    download_file = gr.File(label="Download includes: All outputs* and image(s)") # 0
-    character_res = gr.Label(label="Output (characters)") # 1
-    sorted_general_strings = gr.Textbox(label="Output (string)*", show_label=True, show_copy_button=True) # 2
-    final_categorized_output = gr.Textbox(label="Categorized (string)* - If it's too long, select an image to display tags correctly.", show_label=True, show_copy_button=True) # 3
-    categorized = gr.JSON(label="Categorized (tags)* - JSON") # 4
-    rating = gr.Label(label="Rating") # 5
-    general_res = gr.Label(label="Output (tags)") # 6
-    unclassified = gr.JSON(label="Unclassified (tags)") # 7
+    download_file = gr.File(label="Download includes: All outputs* and image(s)") # 0
+    character_res = gr.Label(label="Output (characters)") # 1
+    sorted_general_strings = gr.Textbox(label="Output (string)*", show_label=True, show_copy_button=True) # 2
+    final_categorized_output = gr.Textbox(label="Categorized (string)* - If it's too long, select an image to display tags correctly.", show_label=True, show_copy_button=True) # 3
+    pe_generate_btn = gr.Button(value="ENHANCE TAGS", size="lg", variant="primary") # 4
+    enhanced_tags = gr.Textbox(label="Enhanced Tags", show_label=True, show_copy_button=True) # 5
+    prompt_enhancer_model = gr.Radio(["Medium", "Long", "Flux"], label="Model Choice", value="Medium", info="Enhance your prompts with Medium or Long answers") # 6
+    categorized = gr.JSON(label="Categorized (tags)* - JSON") # 7
+    rating = gr.Label(label="Rating") # 8
+    general_res = gr.Label(label="Output (tags)") # 9
+    unclassified = gr.JSON(label="Unclassified (tags)") # 10
 clear.add(
     [
         download_file,

@@ -503,8 +508,10 @@ with gr.Blocks(title=TITLE, css=css, theme="Werli/Multi-Tagger", fill_width=True
         character_res,
         general_res,
         unclassified,
+        prompt_enhancer_model,
+        enhanced_tags,
     ]
-)
+)
 tag_results = gr.State({})
 # Define the event listener to add the uploaded image to the gallery
 image_input.change(append_gallery, inputs=[gallery, image_input], outputs=[gallery, image_input])

@@ -512,9 +519,11 @@ with gr.Blocks(title=TITLE, css=css, theme="Werli/Multi-Tagger", fill_width=True
 upload_button.upload(extend_gallery, inputs=[gallery, upload_button], outputs=gallery)
 # Event to update the selected image when an image is clicked in the gallery
 selected_image = gr.Textbox(label="Selected Image", visible=False)
-gallery.select(get_selection_from_gallery, inputs=[gallery, tag_results], outputs=[selected_image, sorted_general_strings, final_categorized_output, categorized, rating, character_res, general_res, unclassified])
+gallery.select(get_selection_from_gallery,inputs=[gallery, tag_results],outputs=[selected_image, sorted_general_strings, final_categorized_output, categorized, rating, character_res, general_res, unclassified, enhanced_tags])
 # Event to remove a selected image from the gallery
 remove_button.click(remove_image_from_gallery, inputs=[gallery, selected_image], outputs=gallery)
+# Event to for the Prompt Enhancer Button
+pe_generate_btn.click(lambda tags,model:prompt_enhancer('','',tags,model)[0],inputs=[final_categorized_output,prompt_enhancer_model],outputs=[enhanced_tags])
 submit.click(
     predictor.predict,
     inputs=[

@@ -543,7 +552,7 @@ with gr.Blocks(title=TITLE, css=css, theme="Werli/Multi-Tagger", fill_width=True
         character_mcut_enabled,
     ],
 )
-with gr.Tab(label="Tag Categorizer"):
+with gr.Tab(label="Tag Categorizer + Enhancer"):
     with gr.Row():
         with gr.Column(variant="panel"):
             input_tags = gr.Textbox(label="Input Tags (Danbooru comma-separated)", placeholder="1girl, cat, horns, blue hair, ...")

@@ -551,7 +560,12 @@ with gr.Blocks(title=TITLE, css=css, theme="Werli/Multi-Tagger", fill_width=True
 with gr.Column(variant="panel"):
     categorized_string = gr.Textbox(label="Categorized (string)", show_label=True, show_copy_button=True, lines=8)
     categorized_json = gr.JSON(label="Categorized (tags) - JSON")
-    submit_button.click(process_tags, inputs=[input_tags], outputs=[categorized_string, categorized_json])
+    submit_button.click(process_tags, inputs=[input_tags], outputs=[categorized_string, categorized_json])
+with gr.Column(variant="panel"):
+    pe_generate_btn = gr.Button(value="ENHANCE TAGS", size="lg", variant="primary")
+    enhanced_tags = gr.Textbox(label="Enhanced Tags", show_label=True, show_copy_button=True)
+    prompt_enhancer_model = gr.Radio(["Medium", "Long", "Flux"], label="Model Choice", value="Medium", info="Enhance your prompts with Medium or Long answers")
+    pe_generate_btn.click(lambda tags,model:prompt_enhancer('','',tags,model)[0],inputs=[categorized_string,prompt_enhancer_model],outputs=[enhanced_tags])
 with gr.Tab(label="Florence 2 Image Captioning"):
     with gr.Row():
         with gr.Column(variant="panel"):
modules/classifyTags.py CHANGED

@@ -171,9 +171,6 @@ def process_tags(input_tags: str):
     categorized_string = ', '.join([tag for category in classified_tags.values() for tag in category])
     categorized_json = {category: tags for category, tags in classified_tags.items()}
 
-    return categorized_string, categorized_json
+    return categorized_string, categorized_json, "" # Initialize enhanced_prompt as empty
 
-tags = []
-if __name__ == "__main__":
-    classify_tags (tags, True)
-    process_tags(input_tags)
+tags = []
modules/florence2.py CHANGED

@@ -94,9 +94,4 @@ def update_task_dropdown(choice):
     if choice == 'Cascaded task':
         return gr.Dropdown(choices=cascaded_task_list, value='Caption + Grounding')
     else:
-        return gr.Dropdown(choices=single_task_list, value='Caption')
-
-if __name__ == "__main__":
-    process_image()
-    single_task_list
-    update_task_dropdown()
+        return gr.Dropdown(choices=single_task_list, value='Caption')
modules/llama_loader.py CHANGED

@@ -182,8 +182,4 @@ class llama3reorganize:
         except Exception as e:print(traceback.format_exc());print('Error reorganize text: '+str(e))
         return result
 
-llama_list=[META_LLAMA_3_3B_REPO,META_LLAMA_3_8B_REPO]
-
-if __name__ == "__main__":
-    llama3reorganize()
-    llama_list
+llama_list=[META_LLAMA_3_3B_REPO,META_LLAMA_3_8B_REPO]
modules/tag_enhancer.py ADDED

@@ -0,0 +1,53 @@
+import gradio as gr
+from transformers import pipeline,AutoTokenizer,AutoModelForSeq2SeqLM
+import re,torch
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+def load_models():
+    try:
+        enhancer_medium = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device=device)
+        enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)
+        model_checkpoint = "gokaygokay/Flux-Prompt-Enhance"
+        tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
+        model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).eval().to(device=device)
+        enhancer_flux = pipeline('text2text-generation', model=model, tokenizer=tokenizer, repetition_penalty=1.5, device=device)
+    except Exception as e:
+        print(e)
+        enhancer_medium = enhancer_long = enhancer_flux = None
+    return enhancer_medium, enhancer_long, enhancer_flux
+
+enhancer_medium, enhancer_long, enhancer_flux = load_models()
+
+def enhance_prompt(input_prompt, model_choice):
+    if model_choice == "Medium":
+        result = enhancer_medium("Enhance the description: " + input_prompt)
+        enhanced_text = result[0]['summary_text']
+
+        pattern = r'^.*?of\s+(.*?(?:\.|$))'
+        match = re.match(pattern, enhanced_text, re.IGNORECASE | re.DOTALL)
+
+        if match:
+            remaining_text = enhanced_text[match.end():].strip()
+            modified_sentence = match.group(1).capitalize()
+            enhanced_text = modified_sentence + ' ' + remaining_text
+    elif model_choice == "Flux":
+        result = enhancer_flux("enhance prompt: " + input_prompt, max_length=256)
+        enhanced_text = result[0]['generated_text']
+    else: # Long
+        result = enhancer_long("Enhance the description: " + input_prompt)
+        enhanced_text = result[0]['summary_text']
+
+    return enhanced_text
+
+def prompt_enhancer(character: str, series: str, general: str, model_choice: str):
+    characters = character.split(",") if character else []
+    serieses = series.split(",") if series else []
+    generals = general.split(",") if general else []
+    tags = characters + serieses + generals
+    cprompt = ",".join(tags) if tags else ""
+
+    output = enhance_prompt(cprompt, model_choice)
+    prompt = cprompt + ", " + output
+
+    return prompt, gr.update(interactive=True), gr.update(interactive=True)
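One note on the new module: load_models() swallows download or initialization errors and leaves all three pipelines as None, so a later call to enhance_prompt would fail with a TypeError. A small optional guard, hypothetical and not part of this commit, could surface a clearer message:

```python
# Hypothetical guard, not in the commit: fail early with a clear message
# instead of calling a None pipeline after load_models() has failed.
def enhance_prompt_safe(input_prompt: str, model_choice: str) -> str:
    if None in (enhancer_medium, enhancer_long, enhancer_flux):
        raise RuntimeError(
            "Prompt-enhancer models failed to load; see the error printed by load_models()."
        )
    return enhance_prompt(input_prompt, model_choice)
```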