# Florence-2 Caption Dataset Creator (Hugging Face Space, running on ZeroGPU)
import os
import re
import tempfile
import zipfile

import gradio as gr
import spaces  # provides the @spaces.GPU decorator used on ZeroGPU Spaces
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM

# Check for a GPU and set the device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Map each model to its Florence-2 task prompt
model_configs = {
    'gokaygokay/Florence-2-Flux': "<DESCRIPTION>",
    'gokaygokay/Florence-2-Flux-Large': "<DESCRIPTION>",
    'yayayaaa/florence-2-large-ft-moredetailed': "<MORE_DETAILED_CAPTION>",
    'MiaoshouAI/Florence-2-large-PromptGen-v2.0': "<MORE_DETAILED_CAPTION>"
}
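
# Florence-2 is steered by special task tokens such as <DESCRIPTION> and
# <MORE_DETAILED_CAPTION>; the processor's post_process_generation() returns a
# dict keyed by that same task string, which is why the prompt is reused when
# decoding results below.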

# Short description of each model, shown in the UI
model_descriptions = {
    'gokaygokay/Florence-2-Flux': "Faster version with good-quality captions",
    'gokaygokay/Florence-2-Flux-Large': "Detailed captions with better image understanding",
    'yayayaaa/florence-2-large-ft-moredetailed': "Fine-tuned specifically for more detailed captions",
    'MiaoshouAI/Florence-2-large-PromptGen-v2.0': "Memory-efficient model with high-quality detailed captions"
}

# Load a single model to start with
print("Loading Florence-2 model...")
model_name = 'gokaygokay/Florence-2-Flux'
task_prompt = model_configs[model_name]

# Load the model without device_map
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True
).eval()

# Move to GPU if available
if device == "cuda":
    model = model.to("cuda")

processor = AutoProcessor.from_pretrained(
    model_name,
    trust_remote_code=True
)
print(f"Successfully loaded model: {model_name}")
title = """<h1 align="center">Florence-2 Caption Dataset Creator</h1> | |
<p><center> | |
<a href="https://huggingface.co/gokaygokay/Florence-2-Flux-Large" target="_blank">[Florence-2 Flux Large]</a> | |
<a href="https://huggingface.co/gokaygokay/Florence-2-Flux" target="_blank">[Florence-2 Flux Base]</a> | |
<a href="https://huggingface.co/yayayaaa/florence-2-large-ft-moredetailed" target="_blank">[Florence-2 More Detailed]</a> | |
<a href="https://huggingface.co/MiaoshouAI/Florence-2-large-PromptGen-v2.0" target="_blank">[MiaoshouAI PromptGen v2.0]</a> | |
</center></p>""" | |

# Clean up a raw caption string
def clean_caption(text):
    # Remove any run of <pad> tokens from the end. The non-capturing group is
    # needed: r'<pad>+$' would only strip the final token, since + would bind
    # to '>' alone.
    text = re.sub(r'(?:<pad>)+$', '', text)
    # Remove any extra whitespace
    text = text.strip()
    return text
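
# Illustrative example:
#   clean_caption("A cat on a mat.<pad><pad><pad>")  ->  "A cat on a mat."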

# Load a specific model, replacing the current one
def load_model(selected_model_name):
    global model, processor, model_name, task_prompt
    # Only reload if a different model was requested
    if selected_model_name != model_name:
        print(f"Switching to model: {selected_model_name}")
        # Release memory held by the current model
        del model
        if device == "cuda":
            torch.cuda.empty_cache()
        # Load the new model
        model_name = selected_model_name
        task_prompt = model_configs[model_name]
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True
        ).eval()
        # Move to GPU if available
        if device == "cuda":
            model = model.to("cuda")
        processor = AutoProcessor.from_pretrained(
            model_name,
            trust_remote_code=True
        )
        print(f"Successfully switched to model: {model_name}")
    return "Model loaded successfully"

# Special handling for the MiaoshouAI PromptGen model
def generate_miaoshou_caption(image):
    """Caption a PIL image with MiaoshouAI/Florence-2-large-PromptGen-v2.0."""
    inputs = processor(
        text=task_prompt,
        images=image,
        return_tensors="pt"
    )
    # Move inputs to the model's device
    for key in inputs:
        if isinstance(inputs[key], torch.Tensor):
            inputs[key] = inputs[key].to(device)
    # Generate using only input_ids and pixel_values
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=512,
            do_sample=False,
            num_beams=3
        )
    # Decode, keeping special tokens for the post-processing step
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    # Use the model's own post-processing
    try:
        parsed_answer = processor.post_process_generation(
            generated_text,
            task=task_prompt,
            image_size=(image.width, image.height)
        )
        # post_process_generation returns a dict keyed by the task prompt
        if isinstance(parsed_answer, dict) and task_prompt in parsed_answer:
            return parsed_answer[task_prompt]
        else:
            return str(parsed_answer)
    except Exception as e:
        print(f"Post-processing error: {str(e)}")
        # Fall back to plain decoding if post-processing fails
        return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

# Shared captioning logic, used by both the preview tab and batch processing
def caption_pil_image(image):
    """Caption a PIL image with the currently loaded model and return cleaned text."""
    if model_name == 'MiaoshouAI/Florence-2-large-PromptGen-v2.0':
        caption = generate_miaoshou_caption(image)
    else:
        # The Flux models expect an instruction appended to the task token
        prompt = task_prompt
        if prompt == "<DESCRIPTION>":
            prompt = prompt + "Describe this image in great detail."
        # Process the image
        inputs = processor(text=prompt, images=image, return_tensors="pt")
        # Move inputs to the same device as the model
        for key in inputs:
            if isinstance(inputs[key], torch.Tensor):
                inputs[key] = inputs[key].to(device)
        # Generate the caption
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=512,
                num_beams=3,
                repetition_penalty=1.10,
            )
        # Decode the generated text
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        if task_prompt == "<DESCRIPTION>":
            # Use Florence-2's post-processing for the Flux models
            try:
                decoded_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
                parsed_answer = processor.post_process_generation(
                    decoded_text,
                    task=task_prompt,
                    image_size=(image.width, image.height)
                )
                caption = parsed_answer[task_prompt]
            except Exception as e:
                print(f"Error in post processing: {str(e)}")
                caption = generated_text  # Fall back to the raw decoded output
        else:
            # For other models, use the generated text directly
            caption = generated_text
    # Clean the caption to remove padding tokens
    return clean_caption(caption)

# Generate a caption for a single image (Preview tab).
# @spaces.GPU requests a GPU for the duration of the call on ZeroGPU Spaces
# (this Space runs on Zero); the decorator is a no-op outside Spaces.
@spaces.GPU
def generate_caption(image, selected_model_name):
    if image is None:
        return "Please upload an image."
    # Switch models if a different one was selected
    if selected_model_name != model_name:
        try:
            load_model(selected_model_name)
        except Exception as e:
            return f"Error loading model {selected_model_name}: {str(e)}"
    if isinstance(image, str):
        # Handle file-path input
        image = Image.open(image)
    elif not isinstance(image, Image.Image):
        # Handle numpy-array input from Gradio
        image = Image.fromarray(image)
    # Ensure the image is RGB
    if image.mode != "RGB":
        image = image.convert("RGB")
    try:
        return caption_pil_image(image)
    except Exception as e:
        error_msg = f"Error generating caption: {str(e)}"
        print(error_msg)
        return error_msg
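
# Quick local sanity check (illustrative; the image path is hypothetical):
#   print(generate_caption("sample.jpg", "gokaygokay/Florence-2-Flux"))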

# Process multiple images and create a downloadable zip
@spaces.GPU
def process_images(images, selected_model_name, add_trigger=True, trigger_word="trigger"):
    """Caption multiple images and bundle images plus captions into a zip file."""
    if not images:
        return "No images uploaded.", None
    # Switch models if a different one was selected
    if selected_model_name != model_name:
        try:
            load_model(selected_model_name)
        except Exception as e:
            return f"Error loading model {selected_model_name}: {str(e)}", None
    # Create a temporary directory to store the output files
    temp_dir = tempfile.mkdtemp()
    zip_path = os.path.join(temp_dir, "captions_dataset.zip")
    results = []
    try:
        with zipfile.ZipFile(zip_path, 'w') as zipf:
            for img_file in images:
                # Resolve the name first so the except block below can report it
                base_name = os.path.basename(img_file.name)
                try:
                    img_path = img_file.name
                    file_name, file_ext = os.path.splitext(base_name)
                    # Skip unsupported formats
                    if file_ext.lower() not in ['.jpg', '.jpeg', '.png']:
                        results.append(f"⚠️ Skipped {base_name}: Unsupported format (only jpg, jpeg, png supported)")
                        continue
                    # Open the image once and ensure it is RGB
                    image = Image.open(img_path)
                    if image.mode != "RGB":
                        image = image.convert("RGB")
                    # Same captioning logic as the preview tab
                    caption = caption_pil_image(image)
                    if add_trigger:
                        caption = f"[{trigger_word}] {caption}"
                    # Write the caption to a text file matching the image name
                    txt_filename = f"{file_name}.txt"
                    txt_path = os.path.join(temp_dir, txt_filename)
                    with open(txt_path, "w", encoding="utf-8") as f:
                        f.write(caption)
                    # Add the caption file and the image to the zip
                    zipf.write(txt_path, txt_filename)
                    zipf.write(img_path, base_name)
                    # Record the result with a short caption preview
                    caption_preview = f"{caption[:50]}..." if len(caption) > 50 else caption
                    results.append(f"✅ {base_name} → {file_name}.txt: {caption_preview}")
                except Exception as e:
                    results.append(f"❌ Error processing {base_name}: {str(e)}")
    except Exception as e:
        error_msg = f"Error creating zip file: {str(e)}"
        print(error_msg)
        return error_msg, None
    # Format the results
    summary = f"Processed {len(results)} files. Ready for download.\n\n"
    result_text = summary + "\n".join(results)
    return result_text, zip_path
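
# Resulting zip layout (file names illustrative):
#   captions_dataset.zip
#     photo1.jpg
#     photo1.txt   # "[trigger] A close-up photo of ..."
#     photo2.jpg
#     photo2.txt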

# Create the Gradio interface
with gr.Blocks() as demo:
    gr.HTML(title)
    with gr.Tabs():
        # Single-image preview tab
        with gr.TabItem("Preview Caption"):
            with gr.Row():
                with gr.Column():
                    input_img = gr.Image(label="Input Picture")
                    model_selector = gr.Dropdown(
                        choices=list(model_configs.keys()),
                        label="Model",
                        value=model_name
                    )
                    preview_btn = gr.Button(value="Generate Caption")
                with gr.Column():
                    output_text = gr.Textbox(label="Generated Caption", lines=8)
            preview_btn.click(generate_caption, [input_img, model_selector], [output_text])
        # Dataset creation tab
        with gr.TabItem("Create Dataset"):
            with gr.Row():
                with gr.Column(scale=1):
                    batch_images = gr.File(
                        file_count="multiple",
                        label="Upload Multiple Images (JPG, JPEG, PNG)"
                    )
                    batch_model_selector = gr.Dropdown(
                        choices=list(model_configs.keys()),
                        label="Model",
                        value=model_name
                    )
                    with gr.Row():
                        add_trigger = gr.Checkbox(label="Add Trigger Word", value=True)
                        trigger_word = gr.Textbox(
                            label="Trigger Word",
                            placeholder="trigger",
                            value="trigger"
                        )
                    process_btn = gr.Button(value="Process Images")
                with gr.Column(scale=1):
                    batch_results = gr.Textbox(label="Processing Results", lines=15)
                    download_output = gr.File(label="Download Dataset (Images & Captions)")
            # Connect the process button
            process_btn.click(
                fn=process_images,
                inputs=[batch_images, batch_model_selector, add_trigger, trigger_word],
                outputs=[batch_results, download_output]
            )

    # Instructions and model information
    with gr.Accordion("Instructions & Model Information", open=True):
        gr.Markdown("""
        ## Instructions

        ### Preview Caption
        - Upload a single image and generate a detailed caption
        - Try different models to compare results

        ### Create Dataset
        - Upload multiple images to process them all at once
        - All images will be captioned and saved with matching .txt files
        - By default, captions include `[trigger]` at the beginning (you can modify the trigger word)
        - Click "Process Images" to generate captions and create a downloadable dataset
        - Use the download button to get a ZIP file containing all images and caption files

        ## Models Available
        """)
        # Build a markdown line for each model
        model_md = ""
        for model_id, description in model_descriptions.items():
            model_short_name = model_id.split('/')[-1]
            model_md += f"- **{model_short_name}**: {description}\n"
        gr.Markdown(model_md)
        # Special note for the MiaoshouAI model
        gr.Markdown("""
        ### MiaoshouAI/Florence-2-large-PromptGen-v2.0 Features
        - Improved caption quality for detailed captions
        - Memory efficient (requires only ~1GB VRAM)
        - Fast generation while maintaining high quality
        - Supports multiple caption formats including detailed captions, tags, and analysis

        Supported image formats: JPG, JPEG, PNG
        """)

if __name__ == "__main__":
    demo.launch()