# Page-header residue kept as a comment: Ateeqq's picture | Update app.py | c1cda26 verified
import torch
from transformers import AutoImageProcessor, SiglipForImageClassification
from PIL import Image
import torch.nn.functional as F
import gradio as gr
# Checkpoint identifier resolved by `from_pretrained`; the image processor and
# the classifier weights are both loaded from this same path.
model_path = "Ateeqq/nsfw-image-detection"
processor = AutoImageProcessor.from_pretrained(model_path)
# `.eval()` returns the module itself, so chaining it leaves `model` bound to
# the loaded classifier, switched to evaluation mode.
model = SiglipForImageClassification.from_pretrained(model_path).eval()
def predict(image):
    """Classify one image and return a confidence score per class label.

    Args:
        image: Array accepted by ``PIL.Image.fromarray`` (the interface in
            this file supplies it via ``gr.Image(type="numpy")``). Gradio
            passes ``None`` when the form is submitted without an image.

    Returns:
        A dict mapping each label in ``model.config.id2label`` to its softmax
        probability, rounded to 8 decimal places. No extra keys are added, so
        the dict can be fed to ``gr.Label`` directly.

    Raises:
        gr.Error: If ``image`` is ``None``, so the UI shows a readable
            message instead of an opaque traceback from ``Image.fromarray``.
    """
    if image is None:
        raise gr.Error("Please upload an image before submitting.")

    # Convert to RGB and preprocess
    rgb_image = Image.fromarray(image).convert("RGB")
    inputs = processor(images=rgb_image, return_tensors="pt")

    # Inference (no gradients are needed for prediction)
    with torch.no_grad():
        logits = model(**inputs).logits

    # A single image was passed in, so take the first row of the batch.
    probs = F.softmax(logits, dim=1)[0].tolist()

    # Return dictionary: class name -> confidence (no extra keys!)
    return {
        model.config.id2label[idx]: round(prob, 8)
        for idx, prob in enumerate(probs)
    }
# Gradio Interface
def main():
    """Build the Gradio interface around ``predict`` and launch it.

    The interface takes one uploaded image (delivered to ``predict`` as a
    NumPy array) and shows the top three class confidences in a label widget.
    Flagging is disabled. ``iface.launch()`` starts the app.
    """
    description = "NSFW Image Detection using SigLIP2 Safety Classifier"
    model_card_link = "[🧠 View Model on Hugging Face](https://huggingface.co/Ateeqq/nsfw-image-detection)"
    # The book emoji had been stored as mojibake ("πŸ“–", its UTF-8 bytes
    # decoded with the wrong codec); restored to the intended character.
    article_link = "[📖 Read Training Article](https://exnrt.com/blog/ai/fine-tuning-siglip2/)"

    iface = gr.Interface(
        fn=predict,
        inputs=gr.Image(type="numpy", label="Upload Image"),
        outputs=gr.Label(num_top_classes=3, label="Predictions"),
        title="NSFW Image Detector",
        description=description,
        # Both links are rendered below the interface, one per line.
        article=f"{model_card_link}<br>{article_link}",
        allow_flagging="never",
    )
    iface.launch()
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()