Upload 5 files
Added "Prompt Enhancer" from [John6666/danbooru-tags-transformer-v2-with-wd-tagger](https://huggingface.co/spaces/John6666/danbooru-tags-transformer-v2-with-wd-tagger/blob/main/tagger/promptenhancer.py) (Thanks!) and cleaned up some code.
- app.py +32 -18
- modules/classifyTags.py +2 -5
- modules/florence2.py +1 -6
- modules/llama_loader.py +1 -5
- modules/tag_enhancer.py +53 -0
app.py
CHANGED
@@ -11,6 +11,7 @@ import json
 from modules.classifyTags import classify_tags,process_tags
 from modules.florence2 import process_image,single_task_list,update_task_dropdown
 from modules.llama_loader import llama_list,llama3reorganize
+from modules.tag_enhancer import prompt_enhancer
 os.environ['PYTORCH_ENABLE_MPS_FALLBACK']='1'

 TITLE = "Multi-Tagger"
@@ -249,9 +250,9 @@ class Predictor:
 reverse=True,
 )
 sorted_general_list = [x[0] for x in sorted_general_list]
-#Remove values from character_list that already exist in sorted_general_list
+# Remove values from character_list that already exist in sorted_general_list
 character_list = [item for item in character_list if item not in sorted_general_list]
-#Remove values from sorted_general_list that already exist in prepend_list or append_list
+# Remove values from sorted_general_list that already exist in prepend_list or append_list
 if prepend_list:
 sorted_general_list = [item for item in sorted_general_list if item not in prepend_list]
 if append_list:
@@ -312,7 +313,8 @@ class Predictor:
 "rating": rating,
 "character_res": character_res,
 "general_res": general_res,
-"unclassified_tags": unclassified_tags
+"unclassified_tags": unclassified_tags,
+"enhanced_tags": "" # Initialize as empty string
 }

 timer.report()
@@ -348,11 +350,12 @@ def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state:
 "rating": "",
 "character_res": "",
 "general_res": "",
-"unclassified_tags": "{}"
+"unclassified_tags": "{}",
+"enhanced_tags": ""
 }
 if selected_state.value["image"]["path"] in tag_results:
 tag_result = tag_results[selected_state.value["image"]["path"]]
-return (selected_state.value["image"]["path"], selected_state.value["caption"]), tag_result["strings"], tag_result["strings2"], tag_result["classified_tags"], tag_result["rating"], tag_result["character_res"], tag_result["general_res"], tag_result["unclassified_tags"]
+return (selected_state.value["image"]["path"], selected_state.value["caption"]), tag_result["strings"], tag_result["strings2"], tag_result["classified_tags"], tag_result["rating"], tag_result["character_res"], tag_result["general_res"], tag_result["unclassified_tags"], tag_result["enhanced_tags"]
 def append_gallery(gallery:list,image:str):
 if gallery is None:gallery=[]
 if not image:return gallery,None
@@ -417,7 +420,6 @@ with gr.Blocks(title=TITLE, css=css, theme="Werli/Multi-Tagger", fill_width=True
 upload_button = gr.UploadButton("Upload multiple images", file_types=["image"], file_count="multiple", size="sm")
 remove_button = gr.Button("Remove Selected Image", size="sm")
 gallery = gr.Gallery(columns=5, rows=5, show_share_button=False, interactive=True, height="500px", label="Grid of images")
-
 model_repo = gr.Dropdown(
 dropdown_list,
 value=EVA02_LARGE_MODEL_DSV3_REPO,
@@ -485,14 +487,17 @@ with gr.Blocks(title=TITLE, css=css, theme="Werli/Multi-Tagger", fill_width=True
 size="lg",
 )
 with gr.Column(variant="panel"):
-download_file = gr.File(label="Download includes: All outputs* and image(s)")
-character_res = gr.Label(label="Output (characters)")
-sorted_general_strings = gr.Textbox(label="Output (string)*", show_label=True, show_copy_button=True)
-final_categorized_output = gr.Textbox(label="Categorized (string)* - If it's too long, select an image to display tags correctly.", show_label=True, show_copy_button=True)
+download_file = gr.File(label="Download includes: All outputs* and image(s)") # 0
+character_res = gr.Label(label="Output (characters)") # 1
+sorted_general_strings = gr.Textbox(label="Output (string)*", show_label=True, show_copy_button=True) # 2
+final_categorized_output = gr.Textbox(label="Categorized (string)* - If it's too long, select an image to display tags correctly.", show_label=True, show_copy_button=True) # 3
+pe_generate_btn = gr.Button(value="ENHANCE TAGS", size="lg", variant="primary") # 4
+enhanced_tags = gr.Textbox(label="Enhanced Tags", show_label=True, show_copy_button=True) # 5
+prompt_enhancer_model = gr.Radio(["Medium", "Long", "Flux"], label="Model Choice", value="Medium", info="Enhance your prompts with Medium or Long answers") # 6
+categorized = gr.JSON(label="Categorized (tags)* - JSON") # 7
+rating = gr.Label(label="Rating") # 8
+general_res = gr.Label(label="Output (tags)") # 9
+unclassified = gr.JSON(label="Unclassified (tags)") # 10
 clear.add(
 [
 download_file,
@@ -503,8 +508,10 @@ with gr.Blocks(title=TITLE, css=css, theme="Werli/Multi-Tagger", fill_width=True
 character_res,
 general_res,
 unclassified,
+prompt_enhancer_model,
+enhanced_tags,
 ]
-)
+)
 tag_results = gr.State({})
 # Define the event listener to add the uploaded image to the gallery
 image_input.change(append_gallery, inputs=[gallery, image_input], outputs=[gallery, image_input])
@@ -512,9 +519,11 @@ with gr.Blocks(title=TITLE, css=css, theme="Werli/Multi-Tagger", fill_width=True
 upload_button.upload(extend_gallery, inputs=[gallery, upload_button], outputs=gallery)
 # Event to update the selected image when an image is clicked in the gallery
 selected_image = gr.Textbox(label="Selected Image", visible=False)
-gallery.select(get_selection_from_gallery,
+gallery.select(get_selection_from_gallery,inputs=[gallery, tag_results],outputs=[selected_image, sorted_general_strings, final_categorized_output, categorized, rating, character_res, general_res, unclassified, enhanced_tags])
 # Event to remove a selected image from the gallery
 remove_button.click(remove_image_from_gallery, inputs=[gallery, selected_image], outputs=gallery)
+# Event for the Prompt Enhancer Button
+pe_generate_btn.click(lambda tags,model:prompt_enhancer('','',tags,model)[0],inputs=[final_categorized_output,prompt_enhancer_model],outputs=[enhanced_tags])
 submit.click(
 predictor.predict,
 inputs=[
@@ -543,7 +552,7 @@ with gr.Blocks(title=TITLE, css=css, theme="Werli/Multi-Tagger", fill_width=True
 character_mcut_enabled,
 ],
 )
-with gr.Tab(label="Tag Categorizer"):
+with gr.Tab(label="Tag Categorizer + Enhancer"):
 with gr.Row():
 with gr.Column(variant="panel"):
 input_tags = gr.Textbox(label="Input Tags (Danbooru comma-separated)", placeholder="1girl, cat, horns, blue hair, ...")
@@ -551,7 +560,12 @@ with gr.Blocks(title=TITLE, css=css, theme="Werli/Multi-Tagger", fill_width=True
 with gr.Column(variant="panel"):
 categorized_string = gr.Textbox(label="Categorized (string)", show_label=True, show_copy_button=True, lines=8)
 categorized_json = gr.JSON(label="Categorized (tags) - JSON")
+submit_button.click(process_tags, inputs=[input_tags], outputs=[categorized_string, categorized_json])
+with gr.Column(variant="panel"):
+pe_generate_btn = gr.Button(value="ENHANCE TAGS", size="lg", variant="primary")
+enhanced_tags = gr.Textbox(label="Enhanced Tags", show_label=True, show_copy_button=True)
+prompt_enhancer_model = gr.Radio(["Medium", "Long", "Flux"], label="Model Choice", value="Medium", info="Enhance your prompts with Medium or Long answers")
+pe_generate_btn.click(lambda tags,model:prompt_enhancer('','',tags,model)[0],inputs=[categorized_string,prompt_enhancer_model],outputs=[enhanced_tags])
 with gr.Tab(label="Florence 2 Image Captioning"):
 with gr.Row():
 with gr.Column(variant="panel"):
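For orientation, here is a minimal, self-contained sketch (not taken from the repository) of the wiring this diff adds to the tagger tab: an ENHANCE TAGS button that feeds the categorized tag string and the selected model choice through `prompt_enhancer`, using the same lambda shape as above. The component labels are illustrative.

```python
# Minimal sketch of the new Prompt Enhancer wiring, assuming the
# modules/tag_enhancer.py added by this commit is importable.
import gradio as gr
from modules.tag_enhancer import prompt_enhancer

with gr.Blocks() as demo:
    final_categorized_output = gr.Textbox(label="Categorized (string)*")
    prompt_enhancer_model = gr.Radio(["Medium", "Long", "Flux"], label="Model Choice", value="Medium")
    pe_generate_btn = gr.Button(value="ENHANCE TAGS", variant="primary")
    enhanced_tags = gr.Textbox(label="Enhanced Tags", show_copy_button=True)

    # prompt_enhancer returns (prompt, gr.update(), gr.update());
    # only the prompt string is routed to the Enhanced Tags textbox.
    pe_generate_btn.click(
        lambda tags, model: prompt_enhancer('', '', tags, model)[0],
        inputs=[final_categorized_output, prompt_enhancer_model],
        outputs=[enhanced_tags],
    )

demo.launch()
```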
modules/classifyTags.py
CHANGED
@@ -171,9 +171,6 @@ def process_tags(input_tags: str):
 categorized_string = ', '.join([tag for category in classified_tags.values() for tag in category])
 categorized_json = {category: tags for category, tags in classified_tags.items()}

-return categorized_string, categorized_json
+return categorized_string, categorized_json, "" # Initialize enhanced_prompt as empty
 
-tags = []
-if __name__ == "__main__":
-classify_tags (tags, True)
-process_tags(input_tags)
+tags = []
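Since `process_tags` now returns a third, empty value, downstream callers unpack three items instead of two. A rough sketch of the new call shape (the variable names are illustrative, and the repository must be on the import path):

```python
# Illustrative only: process_tags now returns an empty placeholder for the
# enhanced prompt alongside the categorized string and JSON.
from modules.classifyTags import process_tags

categorized_string, categorized_json, enhanced_placeholder = process_tags("1girl, cat, horns, blue hair")
print(categorized_string)
print(enhanced_placeholder)  # "" until the Prompt Enhancer produces output
```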
modules/florence2.py
CHANGED
@@ -94,9 +94,4 @@ def update_task_dropdown(choice):
 if choice == 'Cascaded task':
 return gr.Dropdown(choices=cascaded_task_list, value='Caption + Grounding')
 else:
-return gr.Dropdown(choices=single_task_list, value='Caption')
-
-if __name__ == "__main__":
-process_image()
-single_task_list
-update_task_dropdown()
+return gr.Dropdown(choices=single_task_list, value='Caption')
modules/llama_loader.py
CHANGED
@@ -182,8 +182,4 @@ class llama3reorganize:
 except Exception as e:print(traceback.format_exc());print('Error reorganize text: '+str(e))
 return result

-llama_list=[META_LLAMA_3_3B_REPO,META_LLAMA_3_8B_REPO]
-
-if __name__ == "__main__":
-llama3reorganize()
-llama_list
+llama_list=[META_LLAMA_3_3B_REPO,META_LLAMA_3_8B_REPO]
modules/tag_enhancer.py
ADDED
@@ -0,0 +1,53 @@
+import gradio as gr
+from transformers import pipeline,AutoTokenizer,AutoModelForSeq2SeqLM
+import re,torch
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+def load_models():
+    try:
+        enhancer_medium = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device=device)
+        enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)
+        model_checkpoint = "gokaygokay/Flux-Prompt-Enhance"
+        tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
+        model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).eval().to(device=device)
+        enhancer_flux = pipeline('text2text-generation', model=model, tokenizer=tokenizer, repetition_penalty=1.5, device=device)
+    except Exception as e:
+        print(e)
+        enhancer_medium = enhancer_long = enhancer_flux = None
+    return enhancer_medium, enhancer_long, enhancer_flux
+
+enhancer_medium, enhancer_long, enhancer_flux = load_models()
+
+def enhance_prompt(input_prompt, model_choice):
+    if model_choice == "Medium":
+        result = enhancer_medium("Enhance the description: " + input_prompt)
+        enhanced_text = result[0]['summary_text']
+
+        pattern = r'^.*?of\s+(.*?(?:\.|$))'
+        match = re.match(pattern, enhanced_text, re.IGNORECASE | re.DOTALL)
+
+        if match:
+            remaining_text = enhanced_text[match.end():].strip()
+            modified_sentence = match.group(1).capitalize()
+            enhanced_text = modified_sentence + ' ' + remaining_text
+    elif model_choice == "Flux":
+        result = enhancer_flux("enhance prompt: " + input_prompt, max_length=256)
+        enhanced_text = result[0]['generated_text']
+    else: # Long
+        result = enhancer_long("Enhance the description: " + input_prompt)
+        enhanced_text = result[0]['summary_text']
+
+    return enhanced_text
+
+def prompt_enhancer(character: str, series: str, general: str, model_choice: str):
+    characters = character.split(",") if character else []
+    serieses = series.split(",") if series else []
+    generals = general.split(",") if general else []
+    tags = characters + serieses + generals
+    cprompt = ",".join(tags) if tags else ""
+
+    output = enhance_prompt(cprompt, model_choice)
+    prompt = cprompt + ", " + output
+
+    return prompt, gr.update(interactive=True), gr.update(interactive=True)
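As a rough usage sketch outside the Gradio UI (not part of this commit), the new module can also be called directly: the first return value is the joined tag string followed by the enhanced description, and the other two are `gr.update` objects used to re-enable UI components. Note that `load_models()` downloads the three enhancer models on first run.

```python
# Hypothetical standalone check of modules/tag_enhancer.py; requires the
# transformers model weights to be downloadable and the repo on the import path.
from modules.tag_enhancer import prompt_enhancer

prompt, _, _ = prompt_enhancer("", "", "1girl, cat ears, blue hair, looking at viewer", "Medium")
print(prompt)  # original tags, then the "Medium" model's expanded description
```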