Spaces:
Running
Running
import os | |
import io,copy,requests,spaces,gradio as gr,numpy as np | |
from transformers import AutoProcessor,AutoModelForCausalLM | |
from PIL import Image,ImageDraw,ImageFont | |
from unittest.mock import patch | |
import argparse,huggingface_hub,onnxruntime as rt,pandas as pd,traceback,tempfile,zipfile,re,ast,time | |
from datetime import datetime,timezone | |
from collections import defaultdict | |
from apscheduler.schedulers.background import BackgroundScheduler | |
import json | |
from modules.classifyTags import classify_tags,process_tags | |
from modules.florence2 import process_image,single_task_list,update_task_dropdown | |
from modules.reorganizer_model import reorganizer_list,reorganizer_class | |
from modules.tag_enhancer import prompt_enhancer | |
# Let PyTorch fall back to the CPU for operators the Apple MPS backend does not implement.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# UI title and the Markdown description rendered under it.
TITLE = "Multi-Tagger"
DESCRIPTION = """
Multi-Tagger is a versatile application that combines the Waifu Diffusion and Florence 2 models for advanced image analysis and captioning. Perfect for AI artists and enthusiasts, it offers a range of features:
- Batch processing for multiple images
- Multi-category tagging with structured tag display.
- CUDA or CPU support.
- Image tagging, various captioning tasks which includes: Caption, Detailed Caption, Object Detection with visual outputs and much more.
Example image by [me.](https://huggingface.co/Werli)
"""

# --- Tagger repositories on the Hugging Face Hub ---
# SmilingWolf, dataset v3 series
SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
# SmilingWolf, dataset v2 series
MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
# deepghs, IdolSankaku series
EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"

# File names downloaded from every tagger repo.
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

# Emoticon tags whose underscores load_labels keeps instead of turning them into spaces.
kaomojis = [
    "0_0", "(o)_(o)", "+_+", "+_-", "._.", "<o>_<o>", "<|>_<|>",
    "=_=", ">_<", "3_3", "6_9", ">_o", "@_@", "^_^",
    "o_o", "u_u", "x_x", "|_|", "||_||",
]
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the tagger UI.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, exactly as before;
            passing a list allows programmatic use and testing.

    Returns:
        Namespace with ``score_slider_step``, ``score_general_threshold`` and
        ``score_character_threshold`` (floats) and ``share`` (bool).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--score-slider-step', type=float, default=0.05)
    parser.add_argument('--score-general-threshold', type=float, default=0.35)
    parser.add_argument('--score-character-threshold', type=float, default=0.85)
    parser.add_argument('--share', action='store_true')
    return parser.parse_args(argv)
def load_labels(dataframe, kaomoji_names=None) -> tuple[list[str], list[int], list[int], list[int]]:
    """Split a tagger ``selected_tags.csv`` table into tag names and category indexes.

    Underscores in tag names are replaced by spaces, except for emoticon tags
    listed in ``kaomoji_names``, which are kept verbatim.

    Args:
        dataframe: DataFrame with at least ``name`` and ``category`` columns.
        kaomoji_names: Tags to leave untouched; defaults to the module-level
            ``kaomojis`` list.

    Returns:
        ``(tag_names, rating_indexes, general_indexes, character_indexes)``;
        the index lists hold the row positions whose ``category`` is 9, 0 and
        4 respectively, as plain Python ints.
    """
    keep_verbatim = set(kaomojis if kaomoji_names is None else kaomoji_names)
    tag_names = [
        name if name in keep_verbatim else name.replace('_', ' ')
        for name in dataframe['name']
    ]
    categories = dataframe['category']
    rating_indexes = np.where(categories == 9)[0].tolist()
    general_indexes = np.where(categories == 0)[0].tolist()
    character_indexes = np.where(categories == 4)[0].tolist()
    return tag_names, rating_indexes, general_indexes, character_indexes
def mcut_threshold(probs):
    """Return the Maximum Cut (MCut) threshold for a 1-D array of scores.

    The scores are ordered from highest to lowest, the widest drop between
    neighbouring scores is located, and the midpoint of that drop is returned.

    Args:
        probs: 1-D NumPy array of scores. With fewer than two entries there is
            no gap, and ``argmax`` on the empty gap array raises ``ValueError``.

    Returns:
        The value halfway across the widest gap.
    """
    descending = np.sort(probs)[::-1]
    gaps = descending[:-1] - descending[1:]
    cut = gaps.argmax()
    upper, lower = descending[cut], descending[cut + 1]
    return (upper + lower) / 2
class Timer:
    """Record labelled ``time.perf_counter()`` checkpoints and print the gaps between them."""

    def __init__(self):
        self.start_time = time.perf_counter()
        # (label, timestamp) pairs; the first entry is the baseline for the next report.
        self.checkpoints = [('Start', self.start_time)]

    def checkpoint(self, label='Checkpoint'):
        """Append a checkpoint named ``label`` stamped with the current time."""
        now = time.perf_counter()
        self.checkpoints.append((label, now))

    def _print_intervals(self):
        """Print the time between each pair of consecutive checkpoints (no-op when empty)."""
        if not self.checkpoints:
            return
        width = max(len(label) for label, _ in self.checkpoints)
        prev_time = self.checkpoints[0][1]
        for label, curr_time in self.checkpoints[1:]:
            print(f"{label.ljust(width)}: {curr_time - prev_time:.3f} seconds")
            prev_time = curr_time

    def report(self, is_clear_checkpoints=True):
        """Print the intervals recorded since the last clear.

        When ``is_clear_checkpoints`` is true, the list is reset and re-seeded
        with a fresh checkpoint so the next report measures from now.
        """
        self._print_intervals()
        if is_clear_checkpoints:
            self.checkpoints.clear()
            self.checkpoint()

    def report_all(self):
        """Print the remaining intervals and the total time since ``start_time``, then clear.

        Intervals are measured from the first remaining checkpoint, which is the
        re-seeded baseline after a ``report()`` call, not from ``start_time``.
        """
        print('\n> Execution Time Report:')
        self._print_intervals()
        width = max((len(label) for label, _ in self.checkpoints), default=0)
        # With nothing recorded (e.g. a second report_all call), measure up to now
        # instead of failing on an empty list.
        end_time = self.checkpoints[-1][1] if self.checkpoints else time.perf_counter()
        print(f"{'Total Execution Time'.ljust(width)}: {end_time - self.start_time:.3f} seconds\n")
        self.checkpoints.clear()

    def restart(self):
        """Reset the start time and discard all checkpoints."""
        self.start_time = time.perf_counter()
        self.checkpoints = [('Start', self.start_time)]
class Predictor:
    """Tag images with a WD-style ONNX tagger downloaded from the Hugging Face Hub.

    The ONNX session and its label metadata are cached on the instance and are
    only reloaded when ``load_model`` is called with a different repo id.
    """

    def __init__(self):
        # Square input side length read from the loaded model's input shape.
        self.model_target_size = None
        # Repo id of the currently loaded model, used to skip redundant reloads.
        self.last_loaded_repo = None

    def download_model(self, model_repo):
        """Download the label CSV and the ONNX model of ``model_repo``.

        Returns:
            ``(csv_path, model_path)`` local paths from ``hf_hub_download``.
        """
        csv_path = huggingface_hub.hf_hub_download(model_repo, LABEL_FILENAME)
        model_path = huggingface_hub.hf_hub_download(model_repo, MODEL_FILENAME)
        return csv_path, model_path

    def load_model(self, model_repo):
        """Make ``model_repo`` the active model, downloading and loading it if needed."""
        if model_repo == self.last_loaded_repo:
            return
        csv_path, model_path = self.download_model(model_repo)
        tags_df = pd.read_csv(csv_path)
        sep_tags = load_labels(tags_df)
        self.tag_names = sep_tags[0]
        self.rating_indexes = sep_tags[1]
        self.general_indexes = sep_tags[2]
        self.character_indexes = sep_tags[3]
        model = rt.InferenceSession(model_path)
        # prepare_image feeds a (1, size, size, 3) array, so the height alone
        # gives the square target size.
        _, height, _, _ = model.get_inputs()[0].shape
        self.model_target_size = height
        self.last_loaded_repo = model_repo
        self.model = model

    def prepare_image(self, path):
        """Load ``path`` as a 1 x size x size x 3 float32 BGR array for the active model.

        Transparency is flattened onto white, the image is padded to a white
        square and, if needed, resized (bicubic) to ``self.model_target_size``.
        """
        target_size = self.model_target_size
        # Close the source file as soon as the pixels have been converted.
        with Image.open(path) as source:
            image = source.convert("RGBA")
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")
        # Pad image to square
        width, height = image.size
        max_dim = max(width, height)
        pad_left = (max_dim - width) // 2
        pad_top = (max_dim - height) // 2
        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        padded_image.paste(image, (pad_left, pad_top))
        # Resize
        if max_dim != target_size:
            padded_image = padded_image.resize(
                (target_size, target_size),
                Image.BICUBIC,
            )
        image_array = np.asarray(padded_image, dtype=np.float32)
        # Convert PIL-native RGB to BGR
        image_array = image_array[:, :, ::-1]
        return np.expand_dims(image_array, axis=0)

    def create_file(self, content: str, directory: str, fileName: str) -> str:
        """Write ``content`` as UTF-8 text to ``directory/fileName`` and return that path."""
        file_path = os.path.join(directory, fileName)
        with open(file_path, 'w', encoding="utf-8") as file:
            file.write(content)
        return file_path

    def _score_tags(self, image_path, general_thresh, general_mcut_enabled,
                    character_thresh, character_mcut_enabled):
        """Run the active model on one image.

        Returns:
            ``(rating, general_res, character_res)``. ``rating`` maps every
            rating label to its score. ``general_res`` and ``character_res``
            keep only tags scoring above their threshold. When MCut is
            enabled, it replaces the threshold; the character MCut threshold
            is floored at 0.15.
        """
        image = self.prepare_image(image_path)
        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name
        preds = self.model.run([label_name], {input_name: image})[0]
        labels = list(zip(self.tag_names, preds[0].astype(float)))
        # All rating scores are kept.
        rating = dict(labels[i] for i in self.rating_indexes)
        general_names = [labels[i] for i in self.general_indexes]
        if general_mcut_enabled:
            general_thresh = mcut_threshold(np.array([score for _, score in general_names]))
        general_res = dict(pair for pair in general_names if pair[1] > general_thresh)
        character_names = [labels[i] for i in self.character_indexes]
        if character_mcut_enabled:
            character_thresh = max(0.15, mcut_threshold(np.array([score for _, score in character_names])))
        character_res = dict(pair for pair in character_names if pair[1] > character_thresh)
        return rating, general_res, character_res

    @staticmethod
    def _clean_reorganizer_output(text):
        """Drop ``Title:`` markers, turn newline runs into commas and collapse repeated commas."""
        text = re.sub(r" *Title: *", "", text)
        text = re.sub(r"\n+", ",", text)
        return re.sub(r",,+", ",", text)

    @staticmethod
    def _build_zip(output_dir, txt_infos):
        """Pack the files listed in ``txt_infos`` into a zip in ``output_dir``.

        Returns:
            ``[zip_path]``, or ``[]`` when there is nothing to pack.
        """
        if not txt_infos:
            return []
        zip_path = os.path.join(output_dir, "Multi-tagger-" + datetime.now().strftime("%Y%m%d-%H%M%S") + ".zip")
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as taggers_zip:
            for info in txt_infos:
                taggers_zip.write(info["path"], arcname=info["name"])
        return [zip_path]

    def predict(
        self,
        gallery,
        model_repo,
        general_thresh,
        general_mcut_enabled,
        character_thresh,
        character_mcut_enabled,
        characters_merge_enabled,
        reorganizer_model_repo,
        additional_tags_prepend,
        additional_tags_append,
        tag_results,
        progress=gr.Progress()
    ):
        """Tag every image in ``gallery`` and bundle the per-image outputs into a zip.

        Args:
            gallery: Gallery value; the first element of each entry is an image
                path. ``None`` is treated as an empty gallery.
            model_repo: Hub repo id of the tagger to use.
            general_thresh, character_thresh: Score cut-offs; each is replaced
                by the MCut threshold when its ``*_mcut_enabled`` flag is set.
            characters_merge_enabled: Put character tags in front of the string output.
            reorganizer_model_repo: Optional reorganizer model. When set, its
                cleaned rewrite is appended to each image's tag string.
            additional_tags_prepend, additional_tags_append: Comma-separated
                tags placed before/after the general tags (deduplicated).
            tag_results: Dict cleared and refilled in place, keyed by image path.
            progress: Gradio progress tracker.

        Returns:
            ``(download, sorted_general_strings, final_categorized_output,
            classified_tags, rating, character_res, general_res,
            unclassified_tags, tag_results)``. ``download`` holds the zip path
            (if any) and ``final_categorized_output`` joins all images. The
            other per-image values come from the most recently processed image.
        """
        # An empty gallery may arrive as None.
        gallery = gallery or []
        if tag_results is None:
            tag_results = {}
        # Discard results from a previous run.
        tag_results.clear()
        gallery_len = len(gallery)
        print(f"Predict load model: {model_repo}, gallery length: {gallery_len}")
        timer = Timer()
        # Each stage advances the bar by one step. With a reorganizer, every image
        # has two stages, so the step is halved to keep the total at 1.
        progress_step = (0.5 if reorganizer_model_repo else 1) / (gallery_len + 1)
        current_progress = 0

        self.load_model(model_repo)
        current_progress += progress_step
        progress(current_progress, desc="Initialize wd model finished")
        timer.checkpoint("Initialize wd model")

        txt_infos = []  # {"path", "name"} entries packed into the zip
        output_dir = tempfile.mkdtemp()  # mkdtemp creates the directory itself
        # Defaults so the return statement is valid even when the gallery is
        # empty or every image fails.
        sorted_general_strings = ""
        categorized_output_strings = []
        final_categorized_output = ""
        classified_tags = {}
        unclassified_tags = {}
        rating = None
        character_res = None
        general_res = None
        download = []

        reorganizer = None
        if reorganizer_model_repo:
            print(f"Reorganizer load model {reorganizer_model_repo}")
            reorganizer = reorganizer_class(reorganizer_model_repo, loadModel=True)
            current_progress += progress_step
            progress(current_progress, desc="Initialize reoganizer model finished")
            timer.checkpoint("Initialize reoganizer model")
        timer.report()

        prepend_list = [tag.strip() for tag in (additional_tags_prepend or "").split(",") if tag.strip()]
        append_list = [tag.strip() for tag in (additional_tags_append or "").split(",") if tag.strip()]
        if prepend_list and append_list:
            append_list = [item for item in append_list if item not in prepend_list]
        extra_tags = set(prepend_list) | set(append_list)

        # Second and later images sharing a base name get a _02, _03, ... suffix.
        name_counters = defaultdict(int)
        try:
            for idx, value in enumerate(gallery):
                try:
                    image_path = value[0]
                    image_name = os.path.splitext(os.path.basename(image_path))[0]
                    name_counters[image_name] += 1
                    if name_counters[image_name] > 1:
                        image_name = f"{image_name}_{name_counters[image_name]:02d}"
                    print(f"Gallery {idx:02d}: Starting run wd model...")
                    rating, general_res, character_res = self._score_tags(
                        image_path, general_thresh, general_mcut_enabled,
                        character_thresh, character_mcut_enabled,
                    )
                    # General tags by descending score.
                    sorted_general_list = [
                        tag for tag, _ in sorted(general_res.items(), key=lambda x: x[1], reverse=True)
                    ]
                    # Character tags that also appear as general tags are not repeated.
                    general_set = set(sorted_general_list)
                    character_list = [tag for tag in character_res if tag not in general_set]
                    # User-supplied tags go first/last and are removed from the middle.
                    sorted_general_list = (
                        prepend_list
                        + [tag for tag in sorted_general_list if tag not in extra_tags]
                        + append_list
                    )
                    merged = (character_list if characters_merge_enabled else []) + sorted_general_list
                    # Backslash-escape parentheses ("\\(" is the text the old "\(" literal produced).
                    sorted_general_strings = ", ".join(merged).replace("(", "\\(").replace(")", "\\)")

                    classified_tags, unclassified_tags = classify_tags(sorted_general_list)
                    categorized_output_string = ', '.join(', '.join(tags) for tags in classified_tags.values())
                    categorized_output_strings.append(categorized_output_string)
                    # Batch-wide concatenation, returned for the UI.
                    final_categorized_output = ', '.join(categorized_output_strings)

                    # Per-image summary file: uses this image's categorized string only,
                    # not the running concatenation over the batch.
                    txt_content = (
                        f"Output (string): {sorted_general_strings}\n"
                        f"Categorized Output (string): {categorized_output_string}"
                    )
                    output_name = f"{image_name}_output.txt"
                    txt_infos.append({"path": self.create_file(txt_content, output_dir, output_name), "name": output_name})
                    json_name = f"{image_name}_categorized_tags.json"
                    json_content = json.dumps(classified_tags, indent=4)
                    txt_infos.append({"path": self.create_file(json_content, output_dir, json_name), "name": json_name})
                    # Copy of the uploaded image, re-encoded as PNG.
                    png_name = f"{image_name}.png"
                    png_path = os.path.join(output_dir, png_name)
                    with Image.open(image_path) as uploaded:
                        uploaded.save(png_path, format="PNG")
                    txt_infos.append({"path": png_path, "name": png_name})

                    current_progress += progress_step
                    progress(current_progress, desc=f"image{idx:02d}, predict finished")
                    timer.checkpoint(f"image{idx:02d}, predict finished")
                    if reorganizer is not None:
                        print("Starting reorganizer...")
                        reorganized = self._clean_reorganizer_output(reorganizer.reorganize(sorted_general_strings))
                        sorted_general_strings += ",\n\n" + reorganized
                        current_progress += progress_step
                        progress(current_progress, desc=f"image{idx:02d}, reorganizer finished")
                        timer.checkpoint(f"image{idx:02d}, reorganizer finished")
                    caption_name = image_name + ".txt"
                    txt_infos.append({"path": self.create_file(sorted_general_strings, output_dir, caption_name), "name": caption_name})
                    # Keyed by image path so a gallery selection can look results up.
                    tag_results[image_path] = {
                        "strings": sorted_general_strings,
                        "strings2": categorized_output_string,
                        "classified_tags": classified_tags,
                        "rating": rating,
                        "character_res": character_res,
                        "general_res": general_res,
                        "unclassified_tags": unclassified_tags,
                        "enhanced_tags": ""
                    }
                    timer.report()
                except Exception as e:
                    # Best effort: log this image's failure and continue with the next one.
                    print(traceback.format_exc())
                    print("Error predict: " + str(e))
            download = self._build_zip(output_dir, txt_infos)
        finally:
            # Release the reorganizer even if zipping fails.
            if reorganizer is not None:
                reorganizer.release_vram()
                reorganizer = None
        progress(1, desc="Predict completed")
        timer.report_all()
        print("Predict is complete.")
        return download, sorted_general_strings, final_categorized_output, classified_tags, rating, character_res, general_res, unclassified_tags, tag_results
def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state: gr.SelectData):
    """Return the stored tagging results for the image clicked in the gallery.

    The handler feeds nine output components: the selected ``(path, caption)``
    pair followed by the eight stored result fields. Images without stored
    results get empty placeholders.

    Args:
        gallery: Current gallery value (unused; part of the event inputs).
        tag_results: Mapping from image path to the result dict stored by
            ``Predictor.predict``; ``None`` is treated as empty.
        selected_state: Gradio selection event data (injected via the annotation).
    """
    result_keys = (
        "strings", "strings2", "classified_tags", "rating",
        "character_res", "general_res", "unclassified_tags", "enhanced_tags",
    )
    # Placeholders for images that have not been tagged yet.
    tag_result = {
        "strings": "",
        "strings2": "",
        "classified_tags": "{}",
        "rating": "",
        "character_res": "",
        "general_res": "",
        "unclassified_tags": "{}",
        "enhanced_tags": ""
    }
    if not selected_state:
        # Still return nine values so every output component receives one.
        return (None, *(tag_result[key] for key in result_keys))
    image_path = selected_state.value["image"]["path"]
    tag_result = (tag_results or {}).get(image_path, tag_result)
    return ((image_path, selected_state.value["caption"]), *(tag_result[key] for key in result_keys))
def append_gallery(gallery: list, image: str):
    """Add one uploaded image to the gallery and reset the upload widget.

    The given gallery list is extended in place; ``None`` starts a new list.

    Returns:
        ``(gallery, None)``: the gallery list, and ``None`` to clear the image input.
    """
    items = [] if gallery is None else gallery
    if image:
        items.append(image)
    return items, None
def extend_gallery(gallery: list, images):
    """Append every file of a multi-file upload to the gallery and return it.

    The given gallery list is extended in place; ``None`` starts a new list.
    """
    items = [] if gallery is None else gallery
    if images:
        items.extend(images)
    return items
def remove_image_from_gallery(gallery: list, selected_image: str):
    """Remove the gallery entry described by ``selected_image``.

    ``selected_image`` is the hidden textbox value set on gallery selection,
    presumably the string form of a ``(path, caption)`` tuple. It is parsed
    with ``ast.literal_eval``, which accepts Python literals only and never
    runs code. A value that is not a valid literal leaves the gallery
    unchanged instead of raising.

    Returns:
        The gallery, modified in place when a matching entry was found.
    """
    if not gallery or not selected_image:
        return gallery
    try:
        target = ast.literal_eval(selected_image)
    except (ValueError, TypeError, SyntaxError):
        return gallery
    if target in gallery:
        gallery.remove(target)
    return gallery
# Module-level objects shared by the UI callbacks: CLI options and the tagger instance.
args = parse_args()
predictor = Predictor()
# Tagger repos offered in the "Model" dropdown, grouped by model series.
dropdown_list = (
    # Dataset v3 series
    [
        EVA02_LARGE_MODEL_DSV3_REPO,
        SWINV2_MODEL_DSV3_REPO,
        CONV_MODEL_DSV3_REPO,
        VIT_MODEL_DSV3_REPO,
        VIT_LARGE_MODEL_DSV3_REPO,
    ]
    # Dataset v2 series
    + [
        MOAT_MODEL_DSV2_REPO,
        SWIN_MODEL_DSV2_REPO,
        CONV_MODEL_DSV2_REPO,
        CONV2_MODEL_DSV2_REPO,
        VIT_MODEL_DSV2_REPO,
    ]
    # IdolSankaku series
    + [
        SWINV2_MODEL_IS_DSV1_REPO,
        EVA02_LARGE_MODEL_IS_DSV1_REPO,
    ]
)
def _restart_space(): | |
HF_TOKEN=os.getenv('HF_TOKEN') | |
if not HF_TOKEN:raise ValueError('HF_TOKEN environment variable is not set.') | |
huggingface_hub.HfApi().restart_space(repo_id='Werli/Multi-Tagger',token=HF_TOKEN,factory_reboot=False) | |
# Restart the Space periodically from a background scheduler thread.
scheduler = BackgroundScheduler()
# Interval trigger: every 2 days (2 * 24 * 60 * 60 = 172800 seconds).
restart_space_job = scheduler.add_job(_restart_space, "interval", seconds=2 * 24 * 60 * 60)
scheduler.start()
# Read the job's next run time after start() and show it in UTC in the page header.
next_run_time_utc = restart_space_job.next_run_time.astimezone(timezone.utc)
NEXT_RESTART = (
    f"Next Restart: {next_run_time_utc.strftime('%Y-%m-%d %H:%M:%S')} (UTC) - "
    "The space will restart every 2 days to ensure stability and performance. "
    "It uses a background scheduler to handle the restart process."
)
# Extra CSS passed to gr.Blocks below.
# NOTE(review): no component in this file sets elem_id="output", so the #output
# rule may be unused. Confirm before relying on it.
# NOTE(review): the svelte-* class names look like Gradio build-generated names
# and may stop matching after a Gradio upgrade. Verify after upgrading.
css = """
#output {height: 500px; overflow: auto; border: 1px solid #ccc;}
label.float.svelte-i3tvor {position: relative !important;}
.reduced-height.svelte-11chud3 {height: calc(80% - var(--size-10));}
"""
# Build the Gradio UI. It has three tabs: Waifu Diffusion tagging, a tag
# categorizer/enhancer and Florence 2 captioning. Components are created inside
# nested Row/Column contexts, so statement order here determines the layout.
with gr.Blocks(title=TITLE, css=css, theme="Werli/Multi-Tagger", fill_width=True) as demo:
    # Page header: title, description and the next scheduled restart time.
    gr.Markdown(value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>")
    gr.Markdown(value=DESCRIPTION)
    gr.Markdown(NEXT_RESTART)
    # --- Tab 1: batch tagging with a WD ONNX model (Predictor.predict) ---
    with gr.Tab(label="Waifu Diffusion"):
        with gr.Row():
            # Left column: inputs and settings.
            with gr.Column():
                submit = gr.Button(value="Submit", variant="primary", size="lg")
                with gr.Column(variant="panel"):
                    # Create an Image component for uploading images; each upload is
                    # appended to the gallery by the image_input.change handler below.
                    image_input = gr.Image(label="Upload an Image or clicking paste from clipboard button", type="filepath", sources=["upload", "clipboard"], height=150)
                    with gr.Row():
                        upload_button = gr.UploadButton("Upload multiple images", file_types=["image"], file_count="multiple", size="sm")
                        remove_button = gr.Button("Remove Selected Image", size="sm")
                    # Images queued for tagging; predict reads the first element of each entry as the image path.
                    gallery = gr.Gallery(columns=5, rows=5, show_share_button=False, interactive=True, height="500px", label="Grid of images")
                # Tagger model choice (repo ids from dropdown_list).
                model_repo = gr.Dropdown(
                    dropdown_list,
                    value=EVA02_LARGE_MODEL_DSV3_REPO,
                    label="Model",
                )
                # General-tag threshold; when the MCut box is ticked, predict replaces
                # the slider value with an MCut threshold computed per image.
                with gr.Row():
                    general_thresh = gr.Slider(
                        0,
                        1,
                        step=args.score_slider_step,
                        value=args.score_general_threshold,
                        label="General Tags Threshold",
                        scale=3,
                    )
                    general_mcut_enabled = gr.Checkbox(
                        value=False,
                        label="Use MCut threshold",
                        scale=1,
                    )
                # Character-tag threshold, with the same MCut option.
                with gr.Row():
                    character_thresh = gr.Slider(
                        0,
                        1,
                        step=args.score_slider_step,
                        value=args.score_character_threshold,
                        label="Character Tags Threshold",
                        scale=3,
                    )
                    character_mcut_enabled = gr.Checkbox(
                        value=False,
                        label="Use MCut threshold",
                        scale=1,
                    )
                with gr.Row():
                    characters_merge_enabled = gr.Checkbox(
                        value=True,
                        label="Merge characters into the string output",
                        scale=1,
                    )
                # Optional reorganizer model (None disables it); its output is appended
                # to each image's tag string by predict.
                with gr.Row():
                    reorganizer_model_repo = gr.Dropdown(
                        [None] + reorganizer_list,
                        value=None,
                        label="Reorganizer Model",
                        info="Use a model to create a description for you",
                    )
                with gr.Row():
                    additional_tags_prepend = gr.Text(label="Prepend Additional tags (comma split)")
                    additional_tags_append = gr.Text(label="Append Additional tags (comma split)")
                # Clear button for the inputs; the output components are registered via
                # clear.add below. tag_results is in neither list, so stored per-image
                # results survive a clear.
                with gr.Row():
                    clear = gr.ClearButton(
                        components=[
                            gallery,
                            model_repo,
                            general_thresh,
                            general_mcut_enabled,
                            character_thresh,
                            character_mcut_enabled,
                            characters_merge_enabled,
                            reorganizer_model_repo,
                            additional_tags_prepend,
                            additional_tags_append,
                        ],
                        variant="secondary",
                        size="lg",
                    )
            # Right column: outputs (the "# N" comments number the components in this column).
            with gr.Column(variant="panel"):
                download_file = gr.File(label="Download includes: All outputs* and image(s)") # 0
                character_res = gr.Label(label="Output (characters)") # 1
                sorted_general_strings = gr.Textbox(label="Output (string)*", show_label=True, show_copy_button=True) # 2
                final_categorized_output = gr.Textbox(label="Categorized (string)* - If it's too long, select an image to display tags correctly.", show_label=True, show_copy_button=True) # 3
                pe_generate_btn = gr.Button(value="ENHANCE TAGS", size="lg", variant="primary") # 4
                enhanced_tags = gr.Textbox(label="Enhanced Tags", show_label=True, show_copy_button=True) # 5
                prompt_enhancer_model = gr.Radio(["Medium", "Long", "Flux"], label="Model Choice", value="Medium", info="Enhance your prompts with Medium or Long answers") # 6
                categorized = gr.JSON(label="Categorized (tags)* - JSON") # 7
                rating = gr.Label(label="Rating") # 8
                general_res = gr.Label(label="Output (tags)") # 9
                unclassified = gr.JSON(label="Unclassified (tags)") # 10
                clear.add(
                    [
                        download_file,
                        sorted_general_strings,
                        final_categorized_output,
                        categorized,
                        rating,
                        character_res,
                        general_res,
                        unclassified,
                        prompt_enhancer_model,
                        enhanced_tags,
                    ]
                )
    # Per-session dict of image path -> results, filled by predict and read on gallery selection.
    tag_results = gr.State({})
    # Define the event listener to add the uploaded image to the gallery
    image_input.change(append_gallery, inputs=[gallery, image_input], outputs=[gallery, image_input])
    # When the upload button is clicked, add the new images to the gallery
    upload_button.upload(extend_gallery, inputs=[gallery, upload_button], outputs=gallery)
    # Hidden textbox that receives the selected (path, caption) pair from
    # get_selection_from_gallery; remove_image_from_gallery parses it back.
    selected_image = gr.Textbox(label="Selected Image", visible=False)
    gallery.select(get_selection_from_gallery,inputs=[gallery, tag_results],outputs=[selected_image, sorted_general_strings, final_categorized_output, categorized, rating, character_res, general_res, unclassified, enhanced_tags])
    # Event to remove a selected image from the gallery
    remove_button.click(remove_image_from_gallery, inputs=[gallery, selected_image], outputs=gallery)
    # Prompt Enhancer button: enhances the categorized string. Only element [0] of
    # prompt_enhancer's result is shown (assumes it returns a sequence; confirm in
    # modules.tag_enhancer).
    pe_generate_btn.click(lambda tags,model:prompt_enhancer('','',tags,model)[0],inputs=[final_categorized_output,prompt_enhancer_model],outputs=[enhanced_tags])
    # The outputs order must match the tuple returned by Predictor.predict.
    submit.click(
        predictor.predict,
        inputs=[
            gallery,
            model_repo,
            general_thresh,
            general_mcut_enabled,
            character_thresh,
            character_mcut_enabled,
            characters_merge_enabled,
            reorganizer_model_repo,
            additional_tags_prepend,
            additional_tags_append,
            tag_results,
        ],
        outputs=[download_file, sorted_general_strings, final_categorized_output, categorized, rating, character_res, general_res, unclassified, tag_results,],
    )
    # Example row: fills image_input (whose change handler adds it to the gallery)
    # and the model/threshold settings.
    gr.Examples(
        [["images/1girl.png", VIT_LARGE_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
        inputs=[
            image_input,
            model_repo,
            general_thresh,
            general_mcut_enabled,
            character_thresh,
            character_mcut_enabled,
        ],
    )
    # --- Tab 2: categorize a pasted tag list and optionally enhance it ---
    with gr.Tab(label="Tag Categorizer + Enhancer"):
        with gr.Row():
            with gr.Column(variant="panel"):
                input_tags = gr.Textbox(label="Input Tags (Danbooru comma-separated)", placeholder="1girl, cat, horns, blue hair, ...")
                submit_button = gr.Button(value="Submit", variant="primary", size="lg")
            with gr.Column(variant="panel"):
                categorized_string = gr.Textbox(label="Categorized (string)", show_label=True, show_copy_button=True, lines=8)
                categorized_json = gr.JSON(label="Categorized (tags) - JSON")
            submit_button.click(process_tags, inputs=[input_tags], outputs=[categorized_string, categorized_json])
            # These assignments rebind the Tab 1 names pe_generate_btn, enhanced_tags and
            # prompt_enhancer_model. Tab 1's handlers are already attached, so they are unaffected.
            with gr.Column(variant="panel"):
                pe_generate_btn = gr.Button(value="ENHANCE TAGS", size="lg", variant="primary")
                enhanced_tags = gr.Textbox(label="Enhanced Tags", show_label=True, show_copy_button=True)
                prompt_enhancer_model = gr.Radio(["Medium", "Long", "Flux"], label="Model Choice", value="Medium", info="Enhance your prompts with Medium or Long answers")
                pe_generate_btn.click(lambda tags,model:prompt_enhancer('','',tags,model)[0],inputs=[categorized_string,prompt_enhancer_model],outputs=[enhanced_tags])
    # --- Tab 3: Florence 2 captioning / detection via modules.florence2.process_image ---
    with gr.Tab(label="Florence 2 Image Captioning"):
        with gr.Row():
            with gr.Column(variant="panel"):
                input_img = gr.Image(label="Input Picture")
                task_type = gr.Radio(choices=['Single task', 'Cascaded task'], label='Task type selector', value='Single task')
                task_prompt = gr.Dropdown(choices=single_task_list, label="Task Prompt", value="Caption")
                # Switching the task type swaps the prompt dropdown's choices.
                task_type.change(fn=update_task_dropdown, inputs=task_type, outputs=task_prompt)
                text_input = gr.Textbox(label="Text Input (optional)")
                submit_btn = gr.Button(value="Submit")
            with gr.Column(variant="panel"):
                output_text = gr.Textbox(label="Output Text", show_label=True, show_copy_button=True, lines=8)
                output_img = gr.Image(label="Output Image")
        # Examples supply only image + task prompt (no text_input). With
        # cache_examples=False they presumably just fill the inputs; confirm
        # whether process_image is ever invoked with these two arguments alone.
        gr.Examples(
            examples=[
                ["images/image1.png", 'Object Detection'],
                ["images/image2.png", 'OCR with Region']
            ],
            inputs=[input_img, task_prompt],
            outputs=[output_text, output_img],
            fn=process_image,
            cache_examples=False,
            label='Try examples'
        )
        submit_btn.click(process_image, [input_img, task_prompt, text_input], [output_text, output_img])
# Queue events (queue capacity of 2) and start the app. The --share flag from
# parse_args is forwarded so it actually takes effect.
demo.queue(max_size=2).launch(share=args.share)