Spaces:
Running
Running
import torch | |
from transformers import AutoImageProcessor, SiglipForImageClassification | |
from PIL import Image | |
import torch.nn.functional as F | |
import gradio as gr | |
# Hugging Face Hub repository holding both the classifier weights and the
# matching image-processor configuration.
model_path = "Ateeqq/nsfw-image-detection"
processor = AutoImageProcessor.from_pretrained(model_path)
# Switch to inference mode immediately; nn.Module.eval() returns the module
# itself, so it can be chained onto the loading call.
model = SiglipForImageClassification.from_pretrained(model_path).eval()
def predict(image):
    """Classify one uploaded image with the SigLIP NSFW classifier.

    Args:
        image: The value delivered by ``gr.Image(type="numpy")``: an array
            accepted by ``PIL.Image.fromarray`` (presumably ``(H, W[, C])``
            uint8 -- TODO confirm), or ``None`` when the user submits
            without uploading an image.

    Returns:
        A dict mapping every label in ``model.config.id2label`` to its
        softmax probability rounded to 8 decimal places. Returns an empty
        dict when ``image`` is ``None``.
    """
    # Gradio sends None for an empty/cleared input; Image.fromarray(None)
    # would raise, so return an empty result instead of crashing.
    if image is None:
        return {}

    # Normalize to 3-channel RGB and run the model's own preprocessing.
    pil_image = Image.fromarray(image).convert("RGB")
    inputs = processor(images=pil_image, return_tensors="pt")

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Single-image batch: take the first (only) row of class probabilities.
    probs = F.softmax(logits, dim=1)[0].tolist()

    # Class name -> confidence only (no extra keys!).
    id2label = model.config.id2label
    return {id2label[i]: round(p, 8) for i, p in enumerate(probs)}
# Gradio Interface
def main():
    """Build the Gradio UI around ``predict`` and launch it.

    The interface accepts a single uploaded image (passed to ``predict`` as a
    NumPy array) and displays at most the three highest-scoring labels.
    """
    description = "NSFW Image Detection using SigLIP2 Safety Classifier"
    # Plain-text link labels: the previous emoji prefixes were mis-encoded
    # bytes that rendered as stray characters in the page footer.
    model_card_link = "[View Model on Hugging Face](https://huggingface.co/Ateeqq/nsfw-image-detection)"
    article_link = "[Read Training Article](https://exnrt.com/blog/ai/fine-tuning-siglip2/)"
    iface = gr.Interface(
        fn=predict,
        inputs=gr.Image(type="numpy", label="Upload Image"),
        outputs=gr.Label(num_top_classes=3, label="Predictions"),
        title="NSFW Image Detector",
        description=description,
        # Footer content; <br> puts the two links on separate lines.
        article=f"{model_card_link}<br>{article_link}",
        # NOTE(review): newer Gradio releases deprecate `allow_flagging` in
        # favour of `flagging_mode`; confirm against the pinned Gradio version.
        allow_flagging="never",
    )
    iface.launch()
if __name__ == "__main__":
    # Launch the web UI only when this file is executed as a script,
    # not when it is imported as a module.
    main()