# --- Web file-viewer header captured along with the source (not program code) ---
# vincentclaes's picture
# return confidence score
# 5ed6ee0
# raw
# history blame contribute delete
# 1.74 kB
import pathlib
import gradio as gr
from loguru import logger
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
# Module-level setup: everything below runs once, at import time.
logger.info("starting gradio app")

# Directory containing this file and a display name for the app; neither is
# referenced elsewhere in the visible source.
CURRENT_DIR = pathlib.Path(__file__).resolve().parent
APP_NAME = "Mona Lisa Detection"

# Checkpoint id shared by the feature extractor and the classifier, so the two
# cannot drift apart.
_MODEL_ID = "drift-ai/autotrain-mona-lisa-detection-38345101350"

logger.debug("loading processor and model.")
# NOTE(review): use_auth_token=True presumably relies on a Hugging Face access
# token being available in the runtime environment - confirm for deployment.
processor = AutoFeatureExtractor.from_pretrained(_MODEL_ID, use_auth_token=True)
model = AutoModelForImageClassification.from_pretrained(
    _MODEL_ID, use_auth_token=True
)
logger.debug("loading processor and model succeeded.")
def process_image(image, model=model, processor=processor):
    """Classify an image as "Mona Lisa" or "Not Mona Lisa".

    Args:
        image: The image to classify; forwarded to ``processor`` as ``images=``.
        model: Classification model to run. Defaults to the module-level model.
        processor: Callable that turns the image into model inputs. Defaults to
            the module-level feature extractor.

    Returns:
        A one-entry dict mapping the predicted label to its softmax score.
    """
    # Imported locally so the file's import block stays untouched; torch is
    # already required at runtime because tensors are requested as "pt".
    import torch

    logger.info("Making a prediction ...")
    inputs = processor(images=image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping so no graph is built or kept
    # alive for a result that is never back-propagated.
    with torch.no_grad():
        logits = model(**inputs).logits

    # .item() requires a single element, so exactly one image is expected.
    predicted_class_idx = logits.argmax(-1).item()
    label = {1: "Not Mona Lisa", 0: "Mona Lisa"}
    predictions = logits.softmax(dim=-1).tolist()
    # predictions[0] is the score row of the (single) input image.
    result = {label[predicted_class_idx]: predictions[0][predicted_class_idx]}
    print("Predicted class:", result)
    logger.info("Prediction finished.")
    return result
# Example image file names handed to the Gradio interface: three "mona-lisa"
# samples followed by three "not-mona-lisa" samples, numbered 1 to 3.
examples = [
    f"{prefix}-{index}.jpg"
    for prefix in ("mona-lisa", "not-mona-lisa")
    for index in (1, 2, 3)
]
if __name__ == "__main__":
    # Page title; the surrounding newlines of the original triple-quoted
    # literal are preserved.
    title = "\nMona Lisa Detection.\n"

    # Build the UI: one PIL image in, one label (with confidence) out.
    app = gr.Interface(
        fn=process_image,
        # Top-level gr.Image replaces the deprecated gr.inputs.Image namespace
        # and matches the top-level gr.Label used for the output.
        inputs=[gr.Image(type="pil", label="Image")],
        outputs=gr.Label(label="Predictions:", show_label=True),
        examples=examples,
        examples_per_page=32,
        title=title,
        # NOTE(review): enable_queue is deprecated in newer Gradio releases in
        # favour of Interface.queue(); kept as-is until the pinned Gradio
        # version is confirmed.
        enable_queue=True,
    )
    # Launch separately so `app` refers to the Interface itself rather than to
    # whatever launch() returns.
    app.launch()