# imgprivllm/app.py
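"""Gradio app (Hugging Face Spaces / ZeroGPU) for detecting and anonymizing private data in images.

A fine-tuned Qwen2.5-VL model analyses an uploaded image and reports privacy findings as JSON
with bounding boxes; the findings are then visualized and blurred out of the image.
"""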
import gradio as gr
import spaces
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from PIL import Image
import numpy as np
import traceback
from typing import Any, Optional
import utils
from utils import BoundingBox
import blurnonymize
MODEL_NAME = "cborg/qwen2.5VL-3b-privacydetector"
MAX_NEW_TOKENS = 2048
TEMPERATURE = 1.0
MIN_P = 0.1
SYSTEM_PROMPT = """You are a helpful assistant for privacy analysis of images. Please always answer in English. Please obey the user's instructions and follow the provided format."""
DEFAULT_PROMPT = """
You are an expert in pixel-perfect image analysis and in privacy.
First write down your thoughts within a <think> block.
Please go through all objects in the image and consider whether they are private data or not.
End this with a </think> block.
After going through everything, output your findings in an <output></output> block as a json list with the following keys:
{"label": <|object_ref_start|>str<|object_ref_end|>, "description": str, "explanation": str, "bounding_box": <|box_start|>[x_min, y_min, x_max, y_max]<|box_end|>, "severity": int}
Some things to remember:
- private data is only data that's linked to a human person; common examples are a person's face, name, address, or license plate
- whenever something can be used to identify a unique human person, it is private data
- report sensitive data as well, such as a nude person
- Severity is a number between 0 and 10, with 0 being not private data and 10 being extremely sensitive private data.
- don't report items which don't contain private data in the final output; you may mention them in your thoughts
- animals and animal faces are not personal data, so a giraffe or a dog is not private data
- you can use whatever format you want within the <think> </think> blocks
- only output valid JSON in between the <output> </output> blocks, adhering to the schema provided
- output the bounding box always as an array of form [x_min, y_min, x_max, y_max]
- private data has a severity greater than 0, so a human face would have severity 6
- go through the image step by step and report the private data; it's better to be a bit too sensitive than to miss anything
- when reporting a person as private data, put the bounding box around their face rather than the entire person
- Think step by step, take your time.
Here is the image to analyse, start your analysis directly after:
"""
def build_messages(image, history: Optional[list[dict[str, Any]]] = None, prompt: Optional[str] = None):
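    """Build the chat message list: append the prompt to an existing history, or start a
    fresh conversation with the system prompt and a user turn containing text + image."""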
if not prompt:
prompt = DEFAULT_PROMPT
if history:
return [
*history,
{"role": "user", "content": [{"type": "text", "text": prompt}]},
]
return [
{
"role": "system",
"content": [
{
"type": "text",
"text": SYSTEM_PROMPT,
}
],
},
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{"type": "image", "image": image},
],
},
]
# --- Model Loading ---
# Load the fine-tuned Qwen2.5-VL model in bfloat16 together with its processor
try:
    # Load the model weights and move the model to the GPU in eval mode
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
MODEL_NAME,
torch_dtype=torch.bfloat16,
trust_remote_code=True
).to("cuda").eval()
    # Despite the name, this is the model's AutoProcessor (tokenizer + image processor)
    tokenizer = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
print("Model loaded successfully.")
except Exception as e:
print(f"Error loading model: {e}")
print(traceback.format_exc())
    # Re-raise as a Gradio error so the app does not launch with a broken model
raise gr.Error(f"Failed to load model {MODEL_NAME}. Check logs. Error: {e}")
# --- Blurnonymizer Instance ---
try:
blurnonymizer_instance = blurnonymize.ImageBlurnonymizer()
print("Blurnonymizer initialized successfully.")
except Exception as e:
print(f"Error initializing Blurnonymizer: {e}")
print(traceback.format_exc())
raise gr.Error(f"Failed to initialize Blurnonymizer. Check logs. Error: {e}")
# --- Core Processing Function ---
@spaces.GPU(duration=20)  # Needed so that the SAM segmentation used for blurring runs on the GPU
def anonymise_image(input_image_np: np.ndarray, boxes: list[BoundingBox]):
"""Calls the blurnonymizer instance to censor the image."""
if not blurnonymizer_instance:
raise gr.Error("Blurnonymizer not initialized.")
return blurnonymizer_instance.censor_image_blur_easy(
input_image_np, boxes, method="segmentation", verbose=False # Set verbose as needed
)
def run_model_inference(input_image_pil: Image.Image, prompt_text: str):
"""
Runs model inference on the input image and prompt.
"""
# 1. Run Model Inference
print("Running model inference...")
messages = build_messages(
input_image_pil,
prompt=prompt_text)
input_text = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
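    # qwen_vl_utils.process_vision_info collects the image (and any video) inputs referenced in the messages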
image_inputs, video_inputs = process_vision_info(messages)
inputs = tokenizer(
text=[input_text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
).to("cuda")
    out_tokens = model.generate(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        use_cache=True,
        do_sample=True,  # temperature and min_p only take effect when sampling is enabled
        temperature=TEMPERATURE,
        min_p=MIN_P,
    )
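    # Strip the prompt tokens from each sequence so that only the newly generated tokens are decoded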
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, out_tokens)
]
raw_model_output = tokenizer.batch_decode(
generated_ids_trimmed,
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
)[0]
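    # image_grid_thw holds the (t, h, w) vision grid in patches; Qwen2.5-VL uses 14x14 pixel patches,
    # so multiplying by 14 recovers the image size the processor actually fed to the model
    # (this may differ from the original image size due to resizing)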
input_height = inputs['image_grid_thw'][0][1]*14
input_width = inputs['image_grid_thw'][0][2]*14
if input_height != input_image_pil.height:
print("[!] tokenized image height differs from actual height:")
print(f"Actual: {input_image_pil.height}, processed: {input_height}")
if input_width != input_image_pil.width:
print("[!] tokenized image width differs from actual width:")
print(f"Actual: {input_image_pil.width}, processed: {input_width}")
print("[+] Model inference completed.")
print("[*] Raw output:")
print(raw_model_output)
return raw_model_output, input_height, input_width
@spaces.GPU(duration=90)  # Request a GPU for this function for up to 90 seconds
def analyze_image(input_image_pil: Image.Image, prompt_text: str):
"""
Analyzes the input image using the VLM, visualizes findings, and anonymizes.
"""
if input_image_pil is None:
raise gr.Error("Please upload an image.")
if not prompt_text:
raise gr.Error("Please provide a prompt.")
original_image_np = np.array(input_image_pil)
# 1. Run Model Inference
try:
raw_model_output, image_height, image_width = run_model_inference(input_image_pil, prompt_text)
except Exception as e:
print(f"Error during model inference: {e}")
print(traceback.format_exc())
raise gr.Error(f"Model inference failed: {e}")
# 2. Parse Findings
try:
print("Parsing findings...")
# Use the provided utility functions
parsed_findings = utils.parse_into_models(
utils.parse_json_response(raw_model_output)
)
print(f"[+] Parsed {len(parsed_findings)} findings.")
if not parsed_findings:
print("[*] No findings were parsed from the model output.")
except Exception as e:
print(f"Error parsing model output: {e}")
print(traceback.format_exc())
        # Don't raise here; fall back to an empty findings list so visualization and anonymization can still run
gr.Warning(
f"Could not parse findings from model output: {e}. Visualization and anonymization might be incomplete."
)
# Fallback: visualize/anonymize based on empty findings list if needed
parsed_findings = [] # Ensure it's an empty list for downstream steps
# Initialize boxes_for_viz before the try block
boxes_for_viz = []
try:
# 3. Visualize Findings
print("Visualizing findings...")
if parsed_findings:
# Convert Findings to BoundingBox for visualization function
boxes_for_viz = [BoundingBox.from_finding(f) for f in parsed_findings]
# Ensure image is in the correct format (np array) for visualize_boxes_annotated
visualized_image_np = utils.visualize_boxes_annotated(
original_image_np, boxes_for_viz
)
print("Visualization generated.")
else:
print("No findings to visualize, using original image.")
visualized_image_np = (
original_image_np.copy()
) # Show original if no findings
except Exception as e:
print(f"Error during visualization: {e}")
print(traceback.format_exc())
gr.Warning(f"Failed to visualize findings: {e}")
visualized_image_np = original_image_np.copy() # Fallback to original
try:
# 4. Anonymize Image
print("Anonymizing image...")
        # Blur the regions given by the parsed bounding boxes; skip if no boxes were found
if boxes_for_viz:
anonymized_image_np = anonymise_image(original_image_np, boxes_for_viz)
print("Anonymization generated.")
else:
print("No boxes found for anonymization, using original image.")
anonymized_image_np = original_image_np.copy()
except Exception as e:
print(f"Error during anonymization: {e}")
print(traceback.format_exc())
gr.Warning(f"Failed to anonymize image: {e}")
anonymized_image_np = original_image_np.copy() # Fallback to original
    # gr.Image outputs accept numpy arrays directly, so no conversion back to PIL is needed
return raw_model_output, visualized_image_np, anonymized_image_np
# --- Gradio Interface ---
with gr.Blocks() as demo:
gr.Markdown("# Private Data Detection & Anonymization UI")
gr.Markdown(f"Using model: `{MODEL_NAME}` on ZeroGPU.")
with gr.Row():
with gr.Column(scale=1):
input_image = gr.Image(type="pil", label="Upload Image")
prompt_textbox = gr.Textbox(
label="Analysis Prompt", value=DEFAULT_PROMPT, lines=4
)
analyze_button = gr.Button("Analyze Image")
with gr.Column(scale=2):
with gr.Column():
raw_output = gr.Textbox(
label="Raw Model Output", interactive=False
)
output_visualized = gr.Image(
label="Detected Privacy Findings", type="numpy", interactive=False
)
output_anonymized = gr.Image(
label="Anonymized", type="numpy", interactive=False
)
analyze_button.click(
fn=analyze_image,
inputs=[input_image, prompt_textbox],
outputs=[raw_output, output_visualized, output_anonymized],
)
# --- Launch App ---
if __name__ == "__main__":
demo.queue().launch(
debug=True
) # Enable queue for handling multiple requests, debug mode for logs