# vinhtruong3's picture
# Update app.py
# f18cbce verified
# raw
# history blame contribute delete
# 18.1 kB
import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM
import os
from PIL import Image
import zipfile
import tempfile
import re
import torch
import spaces
# Decide where the model and its input tensors will live: the GPU when torch
# can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")
# Define model configurations and task prompts
model_configs = {
'gokaygokay/Florence-2-Flux': "<DESCRIPTION>",
'gokaygokay/Florence-2-Flux-Large': "<DESCRIPTION>",
'yayayaaa/florence-2-large-ft-moredetailed': "<MORE_DETAILED_CAPTION>",
'MiaoshouAI/Florence-2-large-PromptGen-v2.0': "<MORE_DETAILED_CAPTION>"
}
# Define a description for each model to be shown in UI
model_descriptions = {
'gokaygokay/Florence-2-Flux': "Faster version with good quality captions",
'gokaygokay/Florence-2-Flux-Large': "Provides detailed captions with better image understanding",
'yayayaaa/florence-2-large-ft-moredetailed': "Fine-tuned specifically for more detailed captions",
'MiaoshouAI/Florence-2-large-PromptGen-v2.0': "Memory efficient model with high quality detailed captions"
}
# Load the default checkpoint eagerly at import time so the UI works right away;
# load_model() replaces these globals when the user picks another model.
print("Loading Florence-2 model...")
model_name = 'gokaygokay/Florence-2-Flux'
task_prompt = model_configs[model_name]

# No device_map here: placement is handled by the explicit move below
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
model = model.eval()  # inference mode
if device == "cuda":
    model = model.to("cuda")

processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
print(f"Successfully loaded model: {model_name}")
# Header HTML shown at the top of the app: the page title plus links to each
# model card. The links are centered with the `align` attribute (matching the
# <h1>). A <center> element nested inside <p> is invalid HTML: the parser
# closes the <p> early and leaves a stray </p>.
title = """<h1 align="center">Florence-2 Caption Dataset Creator</h1>
<p align="center">
<a href="https://huggingface.co/gokaygokay/Florence-2-Flux-Large" target="_blank">[Florence-2 Flux Large]</a>
<a href="https://huggingface.co/gokaygokay/Florence-2-Flux" target="_blank">[Florence-2 Flux Base]</a>
<a href="https://huggingface.co/yayayaaa/florence-2-large-ft-moredetailed" target="_blank">[Florence-2 More Detailed]</a>
<a href="https://huggingface.co/MiaoshouAI/Florence-2-large-PromptGen-v2.0" target="_blank">[MiaoshouAI PromptGen v2.0]</a>
</p>"""
# Function to clean caption text
def clean_caption(text):
    """Return *text* with any trailing ``<pad>`` tokens and surrounding whitespace removed.

    ``<pad>`` tokens that appear before the end of the text are left untouched.
    """
    # Group "<pad>" so `+` repeats the whole token. A bare "<pad>+" only repeats
    # the ">" and would strip just one token from a run like "<pad><pad>".
    # The leading \s* also removes whitespace-separated pads.
    text = re.sub(r"(?:\s*<pad>)+\s*$", "", text)
    return text.strip()
# Function to load a specific model
def load_model(selected_model_name):
    """Make ``selected_model_name`` the active captioning model.

    Replaces the module-level ``model``, ``processor``, ``model_name`` and
    ``task_prompt``. Does nothing if the requested model is already active.

    Returns:
        The status string "Model loaded successfully".

    Raises:
        ValueError: if ``selected_model_name`` is not a key of ``model_configs``.
            The current model is kept in that case.
        Exception: whatever ``from_pretrained`` raises. The globals are then
            left cleared (``model_name is None``), so the next request retries
            the load instead of running with a half-switched state.
    """
    import gc

    global model, processor, model_name, task_prompt
    if selected_model_name == model_name:
        return "Model loaded successfully"
    # Validate before tearing anything down, so a bad name keeps the current model
    if selected_model_name not in model_configs:
        raise ValueError(f"Unknown model: {selected_model_name}")

    print(f"Switching to model: {selected_model_name}")
    # Drop every reference to the old model first so its memory can be reclaimed
    # before the new weights are loaded (avoids holding two models at once).
    # Clearing model_name also makes callers retry if the load below fails.
    model = None
    processor = None
    model_name = None
    task_prompt = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Load into locals; the globals are only published once everything succeeded
    new_model = AutoModelForCausalLM.from_pretrained(
        selected_model_name,
        trust_remote_code=True
    ).eval()
    if device == "cuda":
        new_model = new_model.to("cuda")
    new_processor = AutoProcessor.from_pretrained(
        selected_model_name,
        trust_remote_code=True
    )

    model = new_model
    processor = new_processor
    model_name = selected_model_name
    task_prompt = model_configs[selected_model_name]
    print(f"Successfully switched to model: {model_name}")
    return "Model loaded successfully"
# Special function for MiaoshouAI model
def generate_miaoshou_caption(image):
    """Caption *image* with the MiaoshouAI PromptGen checkpoint.

    Reads the module-level ``processor``, ``model``, ``task_prompt`` and
    ``device``. Only ``input_ids`` and ``pixel_values`` are forwarded to
    ``model.generate``, which uses deterministic 3-beam search capped at 512
    new tokens.

    Returns:
        The ``task_prompt`` entry of the post-processed output if it is a
        dict containing that key, ``str()`` of the post-processed output
        otherwise, or a plain decode without special tokens if
        post-processing raises.
    """
    encoded = processor(text=task_prompt, images=image, return_tensors="pt")
    # Plain-dict copy with every tensor placed on the active device
    on_device = {
        name: (value.to(device) if isinstance(value, torch.Tensor) else value)
        for name, value in encoded.items()
    }

    output_ids = model.generate(
        input_ids=on_device["input_ids"],
        pixel_values=on_device["pixel_values"],
        max_new_tokens=512,
        do_sample=False,
        num_beams=3,
    )
    # Keep special tokens: the processor's post-processing expects the raw text
    raw_text = processor.batch_decode(output_ids, skip_special_tokens=False)[0]

    try:
        parsed = processor.post_process_generation(
            raw_text,
            task=task_prompt,
            image_size=(image.width, image.height),
        )
        has_caption_key = isinstance(parsed, dict) and task_prompt in parsed
        return parsed[task_prompt] if has_caption_key else str(parsed)
    except Exception as e:
        print(f"Post-processing error: {e!s}")
        # Post-processing failed: fall back to a decode without special tokens
        return processor.batch_decode(output_ids, skip_special_tokens=True)[0]
# Shared helpers used by both the preview tab and the dataset tab.
def _load_rgb_image(image):
    """Return *image* as an RGB ``PIL.Image``.

    Accepts a file path, a numpy array, or an existing PIL image. Paths are
    opened in a ``with`` block and decoded by ``convert`` before the file is
    closed, so no file handle is left open.
    """
    if isinstance(image, (str, os.PathLike)):
        with Image.open(image) as opened:
            return opened.convert("RGB")
    pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")
    return pil_image


def _caption_rgb_image(image):
    """Caption an RGB PIL image with the currently active model.

    Shared by ``generate_caption`` and ``process_images`` so both tabs always
    produce the same caption for the same image. Uses the module-level
    ``model``, ``processor``, ``model_name``, ``task_prompt`` and ``device``.
    Errors from the processor or ``model.generate`` propagate to the caller.

    Returns:
        The caption after ``clean_caption``.
    """
    # The MiaoshouAI checkpoint has its own generate() call
    if model_name == 'MiaoshouAI/Florence-2-large-PromptGen-v2.0':
        return clean_caption(generate_miaoshou_caption(image))

    # <DESCRIPTION> checkpoints get an explicit instruction appended to the task token
    prompt = task_prompt
    if prompt == "<DESCRIPTION>":
        prompt = prompt + "Describe this image in great detail."

    inputs = processor(text=prompt, images=image, return_tensors="pt")
    # Move every tensor to the same device as the model
    for key in inputs:
        if isinstance(inputs[key], torch.Tensor):
            inputs[key] = inputs[key].to(device)

    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            num_beams=3,
            repetition_penalty=1.10,
        )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    if task_prompt != "<DESCRIPTION>":
        # Other checkpoints: the plain decode is already the caption
        return clean_caption(generated_text)

    # Florence-2-Flux checkpoints: let the processor extract the caption for the task
    try:
        decoded_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed_answer = processor.post_process_generation(
            decoded_text,
            task=task_prompt,
            image_size=(image.width, image.height),
        )
        caption = parsed_answer[task_prompt]
    except Exception as e:
        print(f"Error in post processing: {str(e)}")
        caption = generated_text  # Fallback to direct output
    return clean_caption(caption)


# Function to generate a caption for a single image
@spaces.GPU
def generate_caption(image, selected_model_name):
    """Caption one image for the preview tab.

    Args:
        image: file path string, numpy array, or PIL image; ``None`` if nothing
            was uploaded.
        selected_model_name: key of ``model_configs``; the model is switched
            first if it is not the active one.

    Returns:
        The caption, or a human-readable message if no image was given or
        loading/captioning failed. This function never raises.
    """
    if image is None:
        return "Please upload an image."

    if selected_model_name != model_name:
        try:
            load_model(selected_model_name)
        except Exception as e:
            return f"Error loading model {selected_model_name}: {str(e)}"

    try:
        # Decoding the image happens inside the try so unreadable input is reported, not raised
        return _caption_rgb_image(_load_rgb_image(image))
    except Exception as e:
        error_msg = f"Error generating caption: {str(e)}"
        print(error_msg)
        return error_msg


# Function to process multiple images and create a downloadable zip
@spaces.GPU
def process_images(images, selected_model_name, add_trigger=True, trigger_word="trigger"):
    """Caption every uploaded image and bundle images plus captions into a zip.

    Each supported image (.jpg/.jpeg/.png) is stored in the archive next to a
    UTF-8 ``<stem>.txt`` caption. Stems that would collide (e.g. ``cat.jpg``
    and ``cat.png``) get a ``_1``, ``_2``, ... suffix so no entry overwrites
    another.

    Args:
        images: uploaded files; each item is a path string or an object with
            a ``.name`` path attribute.
        selected_model_name: key of ``model_configs`` to caption with.
        add_trigger: when true, prefix each caption with ``[trigger_word] ``.
        trigger_word: text placed inside the brackets.

    Returns:
        ``(result_text, zip_path)``. ``zip_path`` is ``None`` when no image was
        captioned or the archive could not be created.
    """
    if not images:
        return "No images uploaded.", None

    if selected_model_name != model_name:
        try:
            load_model(selected_model_name)
        except Exception as e:
            return f"Error loading model {selected_model_name}: {str(e)}", None

    # Not removed here: the returned zip must outlive this call for the download component
    temp_dir = tempfile.mkdtemp()
    zip_path = os.path.join(temp_dir, "captions_dataset.zip")
    results = []
    used_stems = set()  # lower-cased stems already written to the archive
    processed = 0

    try:
        with zipfile.ZipFile(zip_path, 'w') as zipf:
            for img_file in images:
                # Fallback label so the error message below never uses an unbound/stale name
                base_name = str(img_file)
                try:
                    # Uploads may be plain path strings or objects exposing `.name`
                    img_path = getattr(img_file, "name", img_file)
                    base_name = os.path.basename(img_path)
                    file_name, file_ext = os.path.splitext(base_name)

                    if file_ext.lower() not in ('.jpg', '.jpeg', '.png'):
                        results.append(f"⚠️ Skipped {base_name}: Unsupported format (only jpg, jpeg, png supported)")
                        continue

                    caption = _caption_rgb_image(_load_rgb_image(img_path))
                    if add_trigger:
                        caption = f"[{trigger_word}] {caption}"

                    # Keep image/caption pairs unique (case-insensitive) inside the archive
                    stem = file_name
                    counter = 1
                    while stem.lower() in used_stems:
                        stem = f"{file_name}_{counter}"
                        counter += 1
                    used_stems.add(stem.lower())
                    txt_filename = f"{stem}.txt"

                    # writestr encodes str as UTF-8; no temporary .txt file is needed
                    zipf.writestr(txt_filename, caption)
                    zipf.write(img_path, f"{stem}{file_ext}")
                    processed += 1

                    caption_preview = f"{caption[:50]}..." if len(caption) > 50 else caption
                    results.append(f"✓ {base_name} → {txt_filename}: {caption_preview}")
                except Exception as e:
                    results.append(f"❌ Error processing {base_name}: {str(e)}")
    except Exception as e:
        error_msg = f"Error creating zip file: {str(e)}"
        print(error_msg)
        return error_msg, None

    # Count only successfully captioned images; skips and errors are listed below
    summary = f"Processed {processed} of {len(images)} images."
    if processed:
        summary += " Ready for download."
    result_text = summary + "\n\n" + "\n".join(results)
    return result_text, (zip_path if processed else None)
# Assemble the Gradio app: a single-image preview tab, a batch dataset tab,
# and an instructions panel that lists the available checkpoints.
with gr.Blocks() as demo:
    gr.HTML(title)

    with gr.Tabs():
        # --- Tab 1: caption a single image (handy for comparing models) ---
        with gr.TabItem("Preview Caption"):
            with gr.Row():
                with gr.Column():
                    input_img = gr.Image(label="Input Picture")
                    model_selector = gr.Dropdown(
                        choices=list(model_configs.keys()),
                        label="Model",
                        value=model_name,
                    )
                    preview_btn = gr.Button(value="Generate Caption")
                with gr.Column():
                    output_text = gr.Textbox(label="Generated Caption", lines=8)
            preview_btn.click(
                fn=generate_caption,
                inputs=[input_img, model_selector],
                outputs=[output_text],
            )

        # --- Tab 2: caption many images and package them as a zip ---
        with gr.TabItem("Create Dataset"):
            with gr.Row():
                with gr.Column(scale=1):
                    batch_images = gr.File(
                        file_count="multiple",
                        label="Upload Multiple Images (JPG, JPEG, PNG)",
                    )
                    batch_model_selector = gr.Dropdown(
                        choices=list(model_configs.keys()),
                        label="Model",
                        value=model_name,
                    )
                    with gr.Row():
                        add_trigger = gr.Checkbox(label="Add Trigger Word", value=True)
                        trigger_word = gr.Textbox(
                            label="Trigger Word",
                            placeholder="trigger",
                            value="trigger",
                        )
                    process_btn = gr.Button(value="Process Images")
                with gr.Column(scale=1):
                    batch_results = gr.Textbox(label="Processing Results", lines=15)
                    download_output = gr.File(label="Download Dataset (Images & Captions)")
            process_btn.click(
                fn=process_images,
                inputs=[batch_images, batch_model_selector, add_trigger, trigger_word],
                outputs=[batch_results, download_output],
            )

    # Usage notes, then one bullet per entry in model_descriptions
    with gr.Accordion("Instructions & Model Information", open=True):
        gr.Markdown("""
## Instructions
### Preview Caption
- Upload a single image and generate a detailed caption
- Try different models to compare results
### Create Dataset
- Upload multiple images to process them all at once
- All images will be captioned and saved with matching .txt files
- By default, captions include `[trigger]` at the beginning (you can modify the trigger word)
- Click "Process Images" to generate captions and create a downloadable dataset
- Use the download button to get a ZIP file containing all images and caption files
## Models Available
""")
        model_md = "".join(
            f"- **{model_id.split('/')[-1]}**: {description}\n"
            for model_id, description in model_descriptions.items()
        )
        gr.Markdown(model_md)
        gr.Markdown("""
### MiaoshouAI/Florence-2-large-PromptGen-v2.0 Features
- Improved caption quality for detailed captions
- Memory efficient (requires only ~1GB VRAM)
- Fast generation while maintaining high quality
- Supports multiple caption formats including detailed captions, tags, and analysis
Supported image formats: JPG, JPEG, PNG
""")

if __name__ == "__main__":
    demo.launch()