# Multi-Tagger / app.py
# Commit 36e1bda by Werli: "Update" (verified)
import os
import io,copy,requests,spaces,gradio as gr,numpy as np
from transformers import AutoProcessor,AutoModelForCausalLM
from PIL import Image,ImageDraw,ImageFont
from unittest.mock import patch
import argparse,huggingface_hub,onnxruntime as rt,pandas as pd,traceback,tempfile,zipfile,re,ast,time
from datetime import datetime,timezone
from collections import defaultdict
from apscheduler.schedulers.background import BackgroundScheduler
import json
from modules.classifyTags import classify_tags,process_tags
from modules.florence2 import process_image,single_task_list,update_task_dropdown
from modules.reorganizer_model import reorganizer_list,reorganizer_class
from modules.tag_enhancer import prompt_enhancer
# Presumably lets PyTorch fall back to CPU for operators the Apple MPS backend
# lacks. NOTE(review): this is set after torch/transformers were imported
# above; confirm PyTorch still honours the variable when set this late.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# --- UI text ----------------------------------------------------------------
TITLE = "Multi-Tagger"
DESCRIPTION = """
Multi-Tagger is a versatile application that combines the Waifu Diffusion and Florence 2 models for advanced image analysis and captioning. Perfect for AI artists and enthusiasts, it offers a range of features:
- Batch processing for multiple images
- Multi-category tagging with structured tag display.
- CUDA or CPU support.
- Image tagging, various captioning tasks which includes: Caption, Detailed Caption, Object Detection with visual outputs and much more.
Example image by [me.](https://huggingface.co/Werli)
"""

# --- Tagger repositories on the Hugging Face Hub ------------------------------
# Dataset v3 series of models:
SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"

# Dataset v2 series of models:
MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"

# IdolSankaku series of models:
EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"

# Files downloaded from each tagger repo.
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"

# Tag names that keep their underscores when labels are loaded
# (every other tag has "_" replaced by a space).
kaomojis = [
    "0_0", "(o)_(o)", "+_+", "+_-", "._.", "<o>_<o>", "<|>_<|>", "=_=", ">_<",
    "3_3", "6_9", ">_o", "@_@", "^_^", "o_o", "u_u", "x_x", "|_|", "||_||",
]
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the app's command-line options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default, used by
            the module-level call) makes argparse read ``sys.argv[1:]``.

    Returns:
        Namespace with ``score_slider_step``, ``score_general_threshold``,
        ``score_character_threshold`` (floats) and ``share`` (bool).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--score-slider-step", type=float, default=0.05)
    parser.add_argument("--score-general-threshold", type=float, default=0.35)
    parser.add_argument("--score-character-threshold", type=float, default=0.85)
    parser.add_argument("--share", action="store_true")
    return parser.parse_args(argv)
def load_labels(dataframe, kaomoji_names=None) -> tuple[list[str], list[int], list[int], list[int]]:
    """Split a tagger's ``selected_tags.csv`` table into names and category indexes.

    Args:
        dataframe: table with a ``name`` column and a numeric ``category``
            column (9 = rating, 0 = general, 4 = character, as used below).
        kaomoji_names: names whose underscores must be kept; defaults to the
            module-level ``kaomojis`` list.

    Returns:
        ``(tag_names, rating_indexes, general_indexes, character_indexes)``:
        the tag names with "_" replaced by " " (kaomojis excepted) and three
        lists of positional row indexes (numpy integers) per category.
    """
    # A set gives O(1) membership for every tag name instead of a list scan.
    keep_underscores = set(kaomojis if kaomoji_names is None else kaomoji_names)
    tag_names = [
        name if name in keep_underscores else name.replace("_", " ")
        for name in dataframe["name"]
    ]
    categories = dataframe["category"]
    rating_indexes = list(np.where(categories == 9)[0])
    general_indexes = list(np.where(categories == 0)[0])
    character_indexes = list(np.where(categories == 4)[0])
    return tag_names, rating_indexes, general_indexes, character_indexes
def mcut_threshold(probs):
    """Return the Maximum Cut (MCut) threshold for a set of probabilities.

    The probabilities are sorted in descending order. The threshold is the
    midpoint of the largest gap between two consecutive values.

    Args:
        probs: 1-D array or sequence of probabilities.

    Returns:
        The midpoint threshold. Returns 0.0 when fewer than two probabilities
        are given, because there is no gap to cut (the previous code raised a
        numpy argmax error there).
    """
    sorted_probs = np.sort(np.asarray(probs))[::-1]
    if sorted_probs.size < 2:
        return 0.0
    gaps = sorted_probs[:-1] - sorted_probs[1:]
    t = gaps.argmax()
    return (sorted_probs[t] + sorted_probs[t + 1]) / 2
class Timer:
    """Record labelled ``time.perf_counter`` checkpoints and print elapsed times."""

    def __init__(self):
        self.start_time = time.perf_counter()
        self.checkpoints = [("Start", self.start_time)]

    def checkpoint(self, label="Checkpoint"):
        """Record the current time under *label*."""
        now = time.perf_counter()
        self.checkpoints.append((label, now))

    def _print_intervals(self):
        """Print the time between consecutive checkpoints; return the label column width."""
        width = max((len(label) for label, _ in self.checkpoints), default=0)
        if not self.checkpoints:
            return width
        # Measure from the first recorded checkpoint, which after report() is the
        # reset marker rather than the original "Start".
        prev_time = self.checkpoints[0][1]
        for label, curr_time in self.checkpoints[1:]:
            print(f"{label.ljust(width)}: {curr_time - prev_time:.3f} seconds")
            prev_time = curr_time
        return width

    def report(self, is_clear_checkpoints=True):
        """Print intervals since the last reset and optionally start a new interval list.

        Safe to call when no checkpoints remain (e.g. after ``report_all``).
        """
        self._print_intervals()
        if is_clear_checkpoints:
            self.checkpoints.clear()
            self.checkpoint()

    def report_all(self):
        """Print the remaining intervals plus the total time since ``start_time``, then clear.

        The total runs from ``start_time`` to the last checkpoint (0 when none remain).
        """
        print("\n> Execution Time Report:")
        width = self._print_intervals()
        end_time = self.checkpoints[-1][1] if self.checkpoints else self.start_time
        total_time = end_time - self.start_time
        print(f"{'Total Execution Time'.ljust(width)}: {total_time:.3f} seconds\n")
        self.checkpoints.clear()

    def restart(self):
        """Reset the start time and the checkpoint list."""
        self.start_time = time.perf_counter()
        self.checkpoints = [("Start", self.start_time)]
class Predictor:
    """Tag images with a WD / IdolSankaku ONNX tagger downloaded from the Hugging Face Hub.

    The ONNX session and its label metadata are kept on the instance and are only
    reloaded when ``load_model`` is asked for a different repo.
    """

    def __init__(self):
        self.model_target_size = None  # side length of the (square) model input
        self.last_loaded_repo = None  # repo whose model is currently in self.model

    def download_model(self, model_repo):
        """Fetch the label CSV and ONNX model of *model_repo* via ``hf_hub_download``.

        Returns:
            ``(csv_path, model_path)`` as local file paths.
        """
        csv_path = huggingface_hub.hf_hub_download(model_repo, LABEL_FILENAME)
        model_path = huggingface_hub.hf_hub_download(model_repo, MODEL_FILENAME)
        return csv_path, model_path

    def load_model(self, model_repo):
        """Load the labels and ONNX session of *model_repo*; no-op if it is already loaded."""
        if model_repo == self.last_loaded_repo:
            return
        csv_path, model_path = self.download_model(model_repo)
        tags_df = pd.read_csv(csv_path)
        sep_tags = load_labels(tags_df)
        self.tag_names = sep_tags[0]
        self.rating_indexes = sep_tags[1]
        self.general_indexes = sep_tags[2]
        self.character_indexes = sep_tags[3]
        model = rt.InferenceSession(model_path)
        # The input shape is unpacked as (batch, height, width, channels). Only the
        # height is kept, which assumes a square input (prepare_image pads to square).
        _, height, _, _ = model.get_inputs()[0].shape
        self.model_target_size = height
        self.last_loaded_repo = model_repo
        self.model = model

    def prepare_image(self, path):
        """Load *path* and convert it into a model-ready batch of one image.

        Transparency is flattened onto white, and the image is padded to a white
        square. It is then resized (bicubic) to ``self.model_target_size`` if
        needed. The result is a float32 BGR array of shape ``(1, size, size, 3)``.
        """
        # Close the file handle once decoded; convert() returns an independent copy.
        with Image.open(path) as source:
            image = source.convert("RGBA")
        target_size = self.model_target_size
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")
        # Pad image to square
        image_shape = image.size
        max_dim = max(image_shape)
        pad_left = (max_dim - image_shape[0]) // 2
        pad_top = (max_dim - image_shape[1]) // 2
        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        padded_image.paste(image, (pad_left, pad_top))
        # Resize
        if max_dim != target_size:
            padded_image = padded_image.resize((target_size, target_size), Image.BICUBIC)
        # Convert to numpy array
        image_array = np.asarray(padded_image, dtype=np.float32)
        # Convert PIL-native RGB to BGR
        image_array = image_array[:, :, ::-1]
        return np.expand_dims(image_array, axis=0)

    def create_file(self, content: str, directory: str, fileName: str) -> str:
        """Write *content* as UTF-8 to ``directory/fileName`` (overwriting) and return the path."""
        file_path = os.path.join(directory, fileName)
        with open(file_path, "w", encoding="utf-8") as file:
            file.write(content)
        return file_path

    def _select_tags(self, preds, general_thresh, general_mcut_enabled,
                     character_thresh, character_mcut_enabled):
        """Split one image's model scores into rating, general and character dicts.

        Returns:
            ``(rating, general_res, character_res)``. ``rating`` holds every rating
            label's score. The other two keep only labels scoring strictly above
            their threshold. When MCut is enabled that threshold is derived per
            image, and the character threshold is never lower than 0.15.
        """
        labels = list(zip(self.tag_names, preds[0].astype(float)))
        rating = dict(labels[i] for i in self.rating_indexes)
        general_names = [labels[i] for i in self.general_indexes]
        if general_mcut_enabled:
            general_thresh = mcut_threshold(np.array([score for _, score in general_names]))
        general_res = dict(x for x in general_names if x[1] > general_thresh)
        character_names = [labels[i] for i in self.character_indexes]
        if character_mcut_enabled:
            character_probs = np.array([score for _, score in character_names])
            character_thresh = max(0.15, mcut_threshold(character_probs))
        character_res = dict(x for x in character_names if x[1] > character_thresh)
        return rating, general_res, character_res

    @staticmethod
    def _zip_outputs(output_dir, txt_infos):
        """Bundle every generated file into one timestamped zip inside *output_dir*.

        Returns:
            A one-element list with the zip path, or ``[]`` when nothing was generated.
        """
        if not txt_infos:
            return []
        zip_path = os.path.join(
            output_dir, "Multi-tagger-" + datetime.now().strftime("%Y%m%d-%H%M%S") + ".zip"
        )
        with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as taggers_zip:
            for info in txt_infos:
                taggers_zip.write(info["path"], arcname=info["name"])
        return [zip_path]

    def predict(
        self,
        gallery,
        model_repo,
        general_thresh,
        general_mcut_enabled,
        character_thresh,
        character_mcut_enabled,
        characters_merge_enabled,
        reorganizer_model_repo,
        additional_tags_prepend,
        additional_tags_append,
        tag_results,
        progress=gr.Progress()
    ):
        """Tag every gallery image, write per-image files and zip them for download.

        Args:
            gallery: list of gallery entries; element 0 of each is the image path.
            model_repo: Hub repo of the ONNX tagger.
            general_thresh, character_thresh: score cut-offs. When the matching
                ``*_mcut_enabled`` flag is set they are replaced per image by MCut.
            characters_merge_enabled: put character tags in front of the string output.
            reorganizer_model_repo: optional repo for ``reorganizer_class``; falsy disables it.
            additional_tags_prepend, additional_tags_append: comma-separated tags
                placed before/after the general tags.
            tag_results: dict (Gradio state). It is cleared and refilled, keyed by image path.
            progress: Gradio progress tracker.

        Returns:
            9-tuple: zip download list, the last image's tag string, the categorized
            strings of all images joined, then the last image's classified tags,
            rating, character results, general results and unclassified tags, and
            ``tag_results``. The per-image values are ``None`` (the string outputs
            ``""``) if no image was tagged.
        """
        # Clear tag_results before starting a new prediction
        tag_results.clear()
        # Assumes an empty Gallery component may deliver None instead of [].
        gallery = gallery or []
        gallery_len = len(gallery)
        print(f"Predict load model: {model_repo}, gallery length: {gallery_len}")
        timer = Timer()
        progress_ratio = 0.5 if reorganizer_model_repo else 1
        progress_total = gallery_len + 1
        current_progress = 0
        self.load_model(model_repo)
        # Loop-invariant: the model's input/output tensor names.
        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name
        current_progress += progress_ratio / progress_total
        progress(current_progress, desc="Initialize wd model finished")
        timer.checkpoint("Initialize wd model")

        txt_infos = []
        output_dir = tempfile.mkdtemp()  # mkdtemp already creates the directory
        # Values returned to the UI; pre-bound so an empty or fully failing gallery still returns.
        sorted_general_strings = ""
        categorized_output_strings = []
        final_categorized_output = ""
        rating = None
        character_res = None
        general_res = None
        classified_tags = None
        unclassified_tags = None

        reorganizer = None
        if reorganizer_model_repo:
            print(f"Reorganizer load model {reorganizer_model_repo}")
            reorganizer = reorganizer_class(reorganizer_model_repo, loadModel=True)
            current_progress += progress_ratio / progress_total
            progress(current_progress, desc="Initialize reorganizer model finished")
            timer.checkpoint("Initialize reorganizer model")
        timer.report()

        prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
        append_list = [tag.strip() for tag in additional_tags_append.split(",") if tag.strip()]
        if prepend_list and append_list:
            append_list = [item for item in append_list if item not in prepend_list]
        # Counts repeated base filenames so output files do not overwrite each other.
        name_counters = defaultdict(int)
        try:
            for idx, value in enumerate(gallery):
                try:
                    image_path = value[0]
                    image_name = os.path.splitext(os.path.basename(image_path))[0]
                    name_counters[image_name] += 1
                    if name_counters[image_name] > 1:
                        image_name = f"{image_name}_{name_counters[image_name]:02d}"
                    image = self.prepare_image(image_path)
                    print(f"Gallery {idx:02d}: Starting run wd model...")
                    preds = self.model.run([label_name], {input_name: image})[0]
                    rating, general_res, character_res = self._select_tags(
                        preds, general_thresh, general_mcut_enabled,
                        character_thresh, character_mcut_enabled,
                    )
                    # General tags ordered by descending confidence.
                    sorted_general_list = [
                        tag for tag, _ in sorted(general_res.items(), key=lambda x: x[1], reverse=True)
                    ]
                    # Drop characters that already appear among the general tags.
                    character_list = [item for item in character_res if item not in sorted_general_list]
                    # Drop general tags that the user prepends/appends explicitly.
                    if prepend_list:
                        sorted_general_list = [item for item in sorted_general_list if item not in prepend_list]
                    if append_list:
                        sorted_general_list = [item for item in sorted_general_list if item not in append_list]
                    sorted_general_list = prepend_list + sorted_general_list + append_list
                    # Escape parentheses (the tag string gets a literal backslash before each).
                    sorted_general_strings = (
                        ", ".join((character_list if characters_merge_enabled else []) + sorted_general_list)
                        .replace("(", "\\(")
                        .replace(")", "\\)")
                    )
                    classified_tags, unclassified_tags = classify_tags(sorted_general_list)
                    # One string with all categorized tags of the current image.
                    categorized_output_string = ", ".join(", ".join(tags) for tags in classified_tags.values())
                    categorized_output_strings.append(categorized_output_string)
                    # Aggregate over all images so far, for the UI output only.
                    final_categorized_output = ", ".join(categorized_output_strings)
                    # Per-image .txt: this image's strings only (previously the aggregate leaked in).
                    txt_content = (
                        f"Output (string): {sorted_general_strings}\n"
                        f"Categorized Output (string): {categorized_output_string}"
                    )
                    txt_file = self.create_file(txt_content, output_dir, f"{image_name}_output.txt")
                    txt_infos.append({"path": txt_file, "name": f"{image_name}_output.txt"})
                    # .json file for "Categorized (tags)"
                    json_content = json.dumps(classified_tags, indent=4)
                    json_file = self.create_file(json_content, output_dir, f"{image_name}_categorized_tags.json")
                    txt_infos.append({"path": json_file, "name": f"{image_name}_categorized_tags.json"})
                    # PNG copy of the uploaded image; the source file handle is closed afterwards.
                    png_path = os.path.join(output_dir, f"{image_name}.png")
                    with Image.open(image_path) as original:
                        original.save(png_path, format="PNG")
                    txt_infos.append({"path": png_path, "name": f"{image_name}.png"})
                    current_progress += progress_ratio / progress_total
                    progress(current_progress, desc=f"image{idx:02d}, predict finished")
                    timer.checkpoint(f"image{idx:02d}, predict finished")
                    if reorganizer is not None:
                        print("Starting reorganizer...")
                        reorganize_strings = reorganizer.reorganize(sorted_general_strings)
                        reorganize_strings = re.sub(r" *Title: *", "", reorganize_strings)
                        reorganize_strings = re.sub(r"\n+", ",", reorganize_strings)
                        reorganize_strings = re.sub(r",,+", ",", reorganize_strings)
                        sorted_general_strings += ",\n\n" + reorganize_strings
                        current_progress += progress_ratio / progress_total
                        progress(current_progress, desc=f"image{idx:02d}, reorganizer finished")
                        timer.checkpoint(f"image{idx:02d}, reorganizer finished")
                    txt_file = self.create_file(sorted_general_strings, output_dir, image_name + ".txt")
                    txt_infos.append({"path": txt_file, "name": image_name + ".txt"})
                    # Keyed by the gallery path so a later gallery selection can look it up.
                    tag_results[image_path] = {
                        "strings": sorted_general_strings,
                        "strings2": categorized_output_string,
                        "classified_tags": classified_tags,
                        "rating": rating,
                        "character_res": character_res,
                        "general_res": general_res,
                        "unclassified_tags": unclassified_tags,
                        "enhanced_tags": "",
                    }
                    timer.report()
                except Exception as e:
                    # Best effort: log and continue with the next image.
                    print(traceback.format_exc())
                    print("Error predict: " + str(e))
            download = self._zip_outputs(output_dir, txt_infos)
        finally:
            # Free the reorganizer's VRAM even if a step above raised.
            if reorganizer is not None:
                reorganizer.release_vram()
                reorganizer = None
        progress(1, desc="Predict completed")
        timer.report_all()
        print("Predict is complete.")
        return (download, sorted_general_strings, final_categorized_output, classified_tags,
                rating, character_res, general_res, unclassified_tags, tag_results)
def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state: gr.SelectData):
    """Return the stored tagging results for the gallery image the user selected.

    Args:
        gallery: current gallery value (unused; part of the event inputs).
        tag_results: dict filled by ``Predictor.predict``, keyed by image path.
        selected_state: Gradio select event. ``value["image"]["path"]`` and
            ``value["caption"]`` identify the clicked item.

    Returns:
        A 9-tuple for the ``gallery.select`` outputs:
        - ``(path, caption)``
        - the string output
        - the categorized string
        - the classified tags
        - the rating
        - the character results
        - the general results
        - the unclassified tags
        - the enhanced tags

        Images that were never tagged get empty placeholders.
    """
    if not selected_state:
        # One no-op update per output; returning a single value would not match the 9 outputs.
        return tuple(gr.update() for _ in range(9))
    image_path = selected_state.value["image"]["path"]
    if image_path in tag_results:
        tag_result = tag_results[image_path]
    else:
        tag_result = {
            "strings": "",
            "strings2": "",
            "classified_tags": "{}",
            "rating": "",
            "character_res": "",
            "general_res": "",
            "unclassified_tags": "{}",
            "enhanced_tags": ""
        }
    return (
        (image_path, selected_state.value["caption"]),
        tag_result["strings"],
        tag_result["strings2"],
        tag_result["classified_tags"],
        tag_result["rating"],
        tag_result["character_res"],
        tag_result["general_res"],
        tag_result["unclassified_tags"],
        tag_result["enhanced_tags"],
    )
def append_gallery(gallery: list, image: str):
    """Add one uploaded image to the gallery and reset the upload widget.

    A ``None`` gallery is replaced by a new list. When *image* is truthy it is
    appended in place.

    Returns:
        ``(gallery, None)``; the ``None`` clears the single-image input.
    """
    items = [] if gallery is None else gallery
    if image:
        items.append(image)
    return items, None
def extend_gallery(gallery: list, images):
    """Add a batch of uploaded images to the gallery.

    A ``None`` gallery is replaced by a new list. A truthy *images* iterable is
    appended in place. The (possibly new) gallery list is returned.
    """
    items = [] if gallery is None else gallery
    if images:
        items.extend(images)
    return items
def remove_image_from_gallery(gallery: list, selected_image: str):
    """Remove the selected entry from *gallery* in place and return the gallery.

    *selected_image* is the text of the hidden "Selected Image" textbox.
    Presumably it is the ``str()`` of the ``(path, caption)`` tuple written by
    ``get_selection_from_gallery``. It is parsed with ``ast.literal_eval``,
    which only accepts literals.

    The gallery is returned unchanged when:
    - the gallery is empty;
    - nothing is selected;
    - the text is not a valid literal;
    - the entry is not in the gallery.
    """
    if not gallery or not selected_image:
        return gallery
    try:
        target = ast.literal_eval(selected_image)
    except (ValueError, SyntaxError, TypeError):
        # Malformed selection text: leave the gallery alone instead of failing the event.
        return gallery
    if target in gallery:
        gallery.remove(target)
    return gallery
# CLI options (the UI below reads the slider step/defaults from them) and the
# shared tagger instance used by the Submit button.
args = parse_args()
predictor = Predictor()

# Tagger repos offered in the "Model" dropdown, grouped by training dataset.
dropdown_list = (
    # Dataset v3 series
    [
        EVA02_LARGE_MODEL_DSV3_REPO,
        SWINV2_MODEL_DSV3_REPO,
        CONV_MODEL_DSV3_REPO,
        VIT_MODEL_DSV3_REPO,
        VIT_LARGE_MODEL_DSV3_REPO,
    ]
    # Dataset v2 series
    + [
        MOAT_MODEL_DSV2_REPO,
        SWIN_MODEL_DSV2_REPO,
        CONV_MODEL_DSV2_REPO,
        CONV2_MODEL_DSV2_REPO,
        VIT_MODEL_DSV2_REPO,
    ]
    # IdolSankaku series
    + [
        SWINV2_MODEL_IS_DSV1_REPO,
        EVA02_LARGE_MODEL_IS_DSV1_REPO,
    ]
)
def _restart_space():
HF_TOKEN=os.getenv('HF_TOKEN')
if not HF_TOKEN:raise ValueError('HF_TOKEN environment variable is not set.')
huggingface_hub.HfApi().restart_space(repo_id='Werli/Multi-Tagger',token=HF_TOKEN,factory_reboot=False)
# Periodic self-restart: every 2 days (2 * 24 * 60 * 60 = 172800 seconds).
scheduler = BackgroundScheduler()
restart_space_job = scheduler.add_job(_restart_space, "interval", seconds=2 * 24 * 60 * 60)
scheduler.start()
# Read next_run_time only after start(); presumably it is not populated while
# the job is still pending. Shown to users in UTC.
next_run_time_utc = restart_space_job.next_run_time.astimezone(timezone.utc)
NEXT_RESTART = (
    f"Next Restart: {next_run_time_utc.strftime('%Y-%m-%d %H:%M:%S')} (UTC) - "
    "The space will restart every 2 days to ensure stability and performance. "
    "It uses a background scheduler to handle the restart process."
)

# Custom CSS for the Blocks app. NOTE(review): the svelte-* class selectors look
# tied to a specific Gradio front-end build; re-check them after upgrades.
css = """
#output {height: 500px; overflow: auto; border: 1px solid #ccc;}
label.float.svelte-i3tvor {position: relative !important;}
.reduced-height.svelte-11chud3 {height: calc(80% - var(--size-10));}
"""
# Gradio UI: WD tagger tab, tag categorizer/enhancer tab and Florence 2 tab.
with gr.Blocks(title=TITLE, css=css, theme="Werli/Multi-Tagger", fill_width=True) as demo:
    gr.Markdown(value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>")
    gr.Markdown(value=DESCRIPTION)
    gr.Markdown(NEXT_RESTART)

    # --- Tab 1: Waifu Diffusion tagger ---------------------------------------
    with gr.Tab(label="Waifu Diffusion"):
        with gr.Row():
            # Left column: inputs and tagger settings.
            with gr.Column():
                submit = gr.Button(value="Submit", variant="primary", size="lg")
                with gr.Column(variant="panel"):
                    # Single-image input; append_gallery moves each new image into the gallery.
                    image_input = gr.Image(
                        label="Upload an Image or clicking paste from clipboard button",
                        type="filepath",
                        sources=["upload", "clipboard"],
                        height=150,
                    )
                    with gr.Row():
                        upload_button = gr.UploadButton(
                            "Upload multiple images",
                            file_types=["image"],
                            file_count="multiple",
                            size="sm",
                        )
                        remove_button = gr.Button("Remove Selected Image", size="sm")
                    gallery = gr.Gallery(
                        columns=5,
                        rows=5,
                        show_share_button=False,
                        interactive=True,
                        height="500px",
                        label="Grid of images",
                    )
                model_repo = gr.Dropdown(dropdown_list, value=EVA02_LARGE_MODEL_DSV3_REPO, label="Model")
                with gr.Row():
                    general_thresh = gr.Slider(
                        0, 1,
                        step=args.score_slider_step,
                        value=args.score_general_threshold,
                        label="General Tags Threshold",
                        scale=3,
                    )
                    general_mcut_enabled = gr.Checkbox(value=False, label="Use MCut threshold", scale=1)
                with gr.Row():
                    character_thresh = gr.Slider(
                        0, 1,
                        step=args.score_slider_step,
                        value=args.score_character_threshold,
                        label="Character Tags Threshold",
                        scale=3,
                    )
                    character_mcut_enabled = gr.Checkbox(value=False, label="Use MCut threshold", scale=1)
                with gr.Row():
                    characters_merge_enabled = gr.Checkbox(
                        value=True, label="Merge characters into the string output", scale=1
                    )
                with gr.Row():
                    reorganizer_model_repo = gr.Dropdown(
                        [None] + reorganizer_list,
                        value=None,
                        label="Reorganizer Model",
                        info="Use a model to create a description for you",
                    )
                with gr.Row():
                    additional_tags_prepend = gr.Text(label="Prepend Additional tags (comma split)")
                    additional_tags_append = gr.Text(label="Append Additional tags (comma split)")
                with gr.Row():
                    settings_to_clear = [
                        gallery, model_repo,
                        general_thresh, general_mcut_enabled,
                        character_thresh, character_mcut_enabled,
                        characters_merge_enabled, reorganizer_model_repo,
                        additional_tags_prepend, additional_tags_append,
                    ]
                    clear = gr.ClearButton(components=settings_to_clear, variant="secondary", size="lg")
            # Right column: outputs.
            with gr.Column(variant="panel"):
                download_file = gr.File(label="Download includes: All outputs* and image(s)")
                character_res = gr.Label(label="Output (characters)")
                sorted_general_strings = gr.Textbox(label="Output (string)*", show_label=True, show_copy_button=True)
                final_categorized_output = gr.Textbox(
                    label="Categorized (string)* - If it's too long, select an image to display tags correctly.",
                    show_label=True,
                    show_copy_button=True,
                )
                pe_generate_btn = gr.Button(value="ENHANCE TAGS", size="lg", variant="primary")
                enhanced_tags = gr.Textbox(label="Enhanced Tags", show_label=True, show_copy_button=True)
                prompt_enhancer_model = gr.Radio(
                    ["Medium", "Long", "Flux"],
                    label="Model Choice",
                    value="Medium",
                    info="Enhance your prompts with Medium or Long answers",
                )
                categorized = gr.JSON(label="Categorized (tags)* - JSON")
                rating = gr.Label(label="Rating")
                general_res = gr.Label(label="Output (tags)")
                unclassified = gr.JSON(label="Unclassified (tags)")
                # The Clear button also resets every output component.
                clear.add([
                    download_file, sorted_general_strings, final_categorized_output,
                    categorized, rating, character_res, general_res, unclassified,
                    prompt_enhancer_model, enhanced_tags,
                ])

        # Per-image results of the last Submit, keyed by image path.
        tag_results = gr.State({})

        # Gallery management.
        image_input.change(append_gallery, inputs=[gallery, image_input], outputs=[gallery, image_input])
        upload_button.upload(extend_gallery, inputs=[gallery, upload_button], outputs=gallery)
        # Hidden holder for the clicked gallery item, used by the remove button.
        selected_image = gr.Textbox(label="Selected Image", visible=False)
        gallery.select(
            get_selection_from_gallery,
            inputs=[gallery, tag_results],
            outputs=[
                selected_image, sorted_general_strings, final_categorized_output, categorized,
                rating, character_res, general_res, unclassified, enhanced_tags,
            ],
        )
        remove_button.click(remove_image_from_gallery, inputs=[gallery, selected_image], outputs=gallery)

        # Prompt enhancer over the categorized string.
        pe_generate_btn.click(
            lambda tags, model: prompt_enhancer('', '', tags, model)[0],
            inputs=[final_categorized_output, prompt_enhancer_model],
            outputs=[enhanced_tags],
        )

        # Main tagging run.
        submit.click(
            predictor.predict,
            inputs=[
                gallery, model_repo,
                general_thresh, general_mcut_enabled,
                character_thresh, character_mcut_enabled,
                characters_merge_enabled, reorganizer_model_repo,
                additional_tags_prepend, additional_tags_append,
                tag_results,
            ],
            outputs=[
                download_file, sorted_general_strings, final_categorized_output, categorized,
                rating, character_res, general_res, unclassified, tag_results,
            ],
        )

        gr.Examples(
            [["images/1girl.png", VIT_LARGE_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
            inputs=[
                image_input, model_repo,
                general_thresh, general_mcut_enabled,
                character_thresh, character_mcut_enabled,
            ],
        )

    # --- Tab 2: standalone tag categorizer + enhancer -----------------------
    with gr.Tab(label="Tag Categorizer + Enhancer"):
        with gr.Row():
            with gr.Column(variant="panel"):
                input_tags = gr.Textbox(
                    label="Input Tags (Danbooru comma-separated)",
                    placeholder="1girl, cat, horns, blue hair, ...",
                )
                submit_button = gr.Button(value="Submit", variant="primary", size="lg")
            with gr.Column(variant="panel"):
                categorized_string = gr.Textbox(
                    label="Categorized (string)", show_label=True, show_copy_button=True, lines=8
                )
                categorized_json = gr.JSON(label="Categorized (tags) - JSON")
            submit_button.click(process_tags, inputs=[input_tags], outputs=[categorized_string, categorized_json])
            # Distinct names so this tab's widgets do not shadow tab 1's.
            with gr.Column(variant="panel"):
                categorizer_pe_btn = gr.Button(value="ENHANCE TAGS", size="lg", variant="primary")
                categorizer_enhanced_tags = gr.Textbox(label="Enhanced Tags", show_label=True, show_copy_button=True)
                categorizer_enhancer_model = gr.Radio(
                    ["Medium", "Long", "Flux"],
                    label="Model Choice",
                    value="Medium",
                    info="Enhance your prompts with Medium or Long answers",
                )
                categorizer_pe_btn.click(
                    lambda tags, model: prompt_enhancer('', '', tags, model)[0],
                    inputs=[categorized_string, categorizer_enhancer_model],
                    outputs=[categorizer_enhanced_tags],
                )

    # --- Tab 3: Florence 2 captioning ----------------------------------------
    with gr.Tab(label="Florence 2 Image Captioning"):
        with gr.Row():
            with gr.Column(variant="panel"):
                input_img = gr.Image(label="Input Picture")
                task_type = gr.Radio(
                    choices=['Single task', 'Cascaded task'],
                    label='Task type selector',
                    value='Single task',
                )
                task_prompt = gr.Dropdown(choices=single_task_list, label="Task Prompt", value="Caption")
                # Switching the task type swaps the available task prompts.
                task_type.change(fn=update_task_dropdown, inputs=task_type, outputs=task_prompt)
                text_input = gr.Textbox(label="Text Input (optional)")
                submit_btn = gr.Button(value="Submit")
            with gr.Column(variant="panel"):
                output_text = gr.Textbox(label="Output Text", show_label=True, show_copy_button=True, lines=8)
                output_img = gr.Image(label="Output Image")
        gr.Examples(
            examples=[
                ["images/image1.png", 'Object Detection'],
                ["images/image2.png", 'OCR with Region'],
            ],
            inputs=[input_img, task_prompt],
            outputs=[output_text, output_img],
            fn=process_image,
            cache_examples=False,
            label='Try examples',
        )
        submit_btn.click(process_image, [input_img, task_prompt, text_input], [output_text, output_img])
# Start the app. queue(max_size=2) caps the event queue at two entries, and the
# --share CLI flag (parsed into args.share but previously ignored) is now honoured.
demo.queue(max_size=2).launch(share=args.share)