ariG23498's picture
ariG23498 HF Staff
Update app.py
b714526 verified
raw
history blame contribute delete
2.67 kB
import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor
import torch
from PIL import Image
import io
# Checkpoint identifier shared by the model and processor loaders below.
# No device is selected here, so the weights stay wherever from_pretrained
# places them by default.
folder_path = "diffusers/shot-categorizer-v0"

# trust_remote_code=True permits Python code shipped with the checkpoint to
# run at load time.
model = AutoModelForCausalLM.from_pretrained(
    folder_path,
    trust_remote_code=True,
)
model = model.eval()  # switch to inference mode

processor = AutoProcessor.from_pretrained(
    folder_path,
    trust_remote_code=True,
)
# Define analysis function
def analyze_image(image):
    """Describe an image's color, lighting, lighting type and composition.

    Runs the module-level ``model``/``processor`` pair once per task prompt
    (beam search, no sampling) and formats the four answers as a plain-text
    report.

    Args:
        image: A ``PIL.Image.Image``, encoded image bytes, or ``None``
            (nothing uploaded yet).

    Returns:
        The multi-line report string, or a short message asking for an
        upload when ``image`` is ``None``.
    """
    # Clicking the button before uploading hands us None; answer with a
    # hint instead of crashing inside io.BytesIO(None).
    if image is None:
        return "Please upload an image before running the analysis."

    if isinstance(image, Image.Image):
        img = image.convert("RGB")
    else:
        # Anything that is not already a PIL image is treated as encoded bytes.
        img = Image.open(io.BytesIO(image)).convert("RGB")

    # Task prompt -> label used in the report; insertion order is report order.
    labels = {
        "<COLOR>": "Color",
        "<LIGHTING>": "Lighting",
        "<LIGHTING_TYPE>": "Lighting Type",
        "<COMPOSITION>": "Composition",
    }
    results = {}

    with torch.no_grad():
        for prompt in labels:
            inputs = processor(text=prompt, images=img, return_tensors="pt")
            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=1024,
                early_stopping=False,
                do_sample=False,
                num_beams=3,
            )
            # Special tokens are kept because post-processing receives the
            # raw decoded string together with the task prompt.
            generated_text = processor.batch_decode(
                generated_ids, skip_special_tokens=False
            )[0]
            parsed_answer = processor.post_process_generation(
                generated_text, task=prompt, image_size=(img.width, img.height)
            )
            # Florence-2-style processors return {task: answer}; unwrap that
            # shape so the report shows the answer rather than a dict repr.
            # NOTE(review): confirm against this checkpoint's remote
            # processor code. Other return shapes pass through unchanged.
            if isinstance(parsed_answer, dict) and prompt in parsed_answer:
                parsed_answer = parsed_answer[prompt]
            results[prompt] = parsed_answer

    report = ["Image Analysis Results:", ""]
    report.extend(
        f"{label}: {results[prompt]}" for prompt, label in labels.items()
    )
    return "\n".join(report) + "\n"
# Assemble the UI: a title and description, then a two-column row with the
# upload widget and trigger button on the left and the report on the right.
with gr.Blocks(title="Image Analyzer") as demo:
    gr.Markdown("# Image Analysis Demo")
    gr.Markdown("Upload an image to analyze its color, lighting, and composition characteristics.")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            analyze_button = gr.Button("Analyze Image")
        with gr.Column():
            output_text = gr.Textbox(label="Analysis Results", lines=10)

    # Clickable sample that fills the image input.
    examples = gr.Examples(examples=["shot.jpg"], inputs=image_input, label="Try with this example")

    # A button press sends the current image through analyze_image and
    # writes the returned report into the textbox.
    analyze_button.click(fn=analyze_image, inputs=image_input, outputs=output_text)

# Start serving the app.
demo.launch()