"""Gradio app that scores an uploaded image as 'safe' or 'unsafe'.

A fine-tuned CLIP checkpoint is loaded once at import time; each request
compares the image against the text prompts in ``CATEGORIES`` and reports
the softmax of the image-text similarity scores.
"""

import io

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# Step 1: Load Fine-Tuned Model from Hugging Face Model Hub
model_name = "quadranttechnologies/retail-content-safety-clip-finetuned"

# Text prompts the image is compared against; order fixes the column order
# of ``logits_per_image``.
CATEGORIES = ["safe", "unsafe"]

print("Initializing the application...")

try:
    print("Loading the model from Hugging Face Model Hub...")
    # NOTE: ``trust_remote_code`` is intentionally not passed. It only applies
    # to the Auto* classes; a concrete ``CLIPModel`` never runs hub code.
    model = CLIPModel.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained(model_name)
    print("Model and processor loaded successfully.")
except Exception as e:
    print(f"Error loading the model or processor: {e}")
    # Chain the original exception so the real cause stays in the traceback.
    raise RuntimeError(f"Failed to load model: {e}") from e


# Step 2: Define the Inference Function
def classify_image(image):
    """Classify an image as 'safe' or 'unsafe'.

    Args:
        image (PIL.Image.Image): Uploaded image, as delivered by
            ``gr.Image(type="pil")``. ``None`` when nothing was uploaded.

    Returns:
        dict: Maps each entry of ``CATEGORIES`` to its probability as a
        float in ``[0, 1]``, which is the mapping ``gr.Label`` renders as
        confidence bars.

    Raises:
        gr.Error: If the input is missing or is not an image, or if
            preprocessing or inference fails. The message is shown in the UI.
    """
    try:
        print("Starting image classification...")

        # Validate input
        if image is None:
            raise ValueError("No image provided. Please upload a valid image.")

        # Validate image format
        if not hasattr(image, "convert"):
            raise ValueError(
                "Invalid image format. "
                "Please upload a valid image (JPEG, PNG, etc.)."
            )

        # Process the image with the processor
        print("Processing the image...")
        inputs = processor(
            text=CATEGORIES, images=image, return_tensors="pt", padding=True
        )
        print(f"Processed inputs: {inputs}")

        # Run inference; no gradients are needed for a forward-only pass.
        print("Running model inference...")
        with torch.no_grad():
            outputs = model(**inputs)
        print(f"Model outputs: {outputs}")

        # Image-text similarity scores -> probabilities over CATEGORIES
        logits_per_image = outputs.logits_per_image
        probs = logits_per_image.softmax(dim=1)
        print(f"Calculated probabilities: {probs}")

        # gr.Label needs numeric confidences, not formatted percentage strings.
        return dict(zip(CATEGORIES, probs[0].tolist()))

    except Exception as e:
        # Log, then surface the failure through Gradio's error channel rather
        # than returning a fake {"Error": ...} label.
        print(f"Error during classification: {e}")
        raise gr.Error(str(e)) from e


# Step 3: Set Up Gradio Interface
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(label="Output"),  # Display probabilities as progress bars
    title="Content Safety Classification",
    description="Upload an image to classify it as 'safe' or 'unsafe' with corresponding probabilities.",
)

# Step 4: Test Before Launch
if __name__ == "__main__":
    print("Testing model locally with a sample image...")
    try:
        # Test with a sample image
        url = "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png"
        # The timeout keeps a stalled download from blocking the launch;
        # the ``with`` block releases the connection.
        with requests.get(url, timeout=10) as response:
            response.raise_for_status()
            test_image = Image.open(io.BytesIO(response.content))

        # Run the classification function
        print("Running local test...")
        result = classify_image(test_image)
        print(f"Local Test Result: {result}")
    except Exception as e:
        # Best-effort smoke test: report the failure but still launch the UI.
        print(f"Error during local test: {e}")

    # Launch Gradio Interface
    print("Launching the Gradio interface...")
    iface.launch()