import gradio as gr | |
from transformers import pipeline | |
# Build the classifier once at module load so every request reuses the same
# loaded model. It is bound to `classifier` (not `pipeline`) so the
# `transformers.pipeline` factory imported above is not shadowed.
classifier = pipeline(task="image-classification", model="julien-c/hotdog-not-hotdog")


def predict(image) -> dict[str, float]:
    """Classify an uploaded image as hot dog or not hot dog.

    Args:
        image: Path to the uploaded image file (the input component below uses
            ``type="filepath"``), or ``None`` when the user submits without
            uploading anything.

    Returns:
        A mapping from each predicted label to its score, the dict form that
        ``gr.Label`` renders.

    Raises:
        gr.Error: If no image was provided, so the UI shows a clear message
            instead of an error coming from inside the pipeline.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    predictions = classifier(image)
    # Each prediction is a dict carrying "label" and "score" keys.
    return {p["label"]: p["score"] for p in predictions}


demo = gr.Interface(
    predict,
    inputs=gr.Image(label="Upload hot dog candidate", type="filepath"),
    outputs=gr.Label(num_top_classes=2),
    title="Hot Dog? Or Not?",
    # NOTE(review): newer Gradio releases deprecate `allow_flagging` in favour
    # of `flagging_mode`; kept as-is because the installed Gradio version is
    # not visible here — confirm before migrating.
    allow_flagging="manual",
)

if __name__ == "__main__":
    demo.launch()
# import matplotlib.pyplot as plt | |
# import torch | |
# from PIL import Image | |
# from torchvision import transforms | |
# import torch.nn.functional as F | |
# from typing import Literal, Any | |
# import gradio as gr | |
# import spaces | |
# from io import BytesIO | |
# class Classifier: | |
# LABELS = [ | |
# "Panoramic", | |
# "Feature", | |
# "Detail", | |
# "Enclosed", | |
# "Focal", | |
# "Ephemeral", | |
# "Canopied", | |
# ] | |
# @spaces.GPU(duration=60) | |
# def __init__( | |
# self, model_path="Litton-7type-visual-landscape-model.pth", device="cuda:0" | |
# ): | |
# self.device = device | |
# self.model = torch.load( | |
# model_path, map_location=self.device, weights_only=False | |
# ) | |
# if hasattr(self.model, "module"): | |
# self.model = self.model.module | |
# self.model.eval() | |
# self.preprocess = transforms.Compose( | |
# [ | |
# transforms.Resize(256), | |
# transforms.CenterCrop(224), | |
# transforms.ToTensor(), | |
# transforms.Normalize( | |
# mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] | |
# ), | |
# ] | |
# ) | |
# @spaces.GPU(duration=60) | |
# def predict(self, image: Image.Image) -> tuple[Literal["Failed", "Success"], Any]: | |
# image = image.convert("RGB") | |
# input_tensor = self.preprocess(image).unsqueeze(0).to(self.device) | |
# with torch.no_grad(): | |
# logits = self.model(input_tensor) | |
# probs = F.softmax(logits[:, :7], dim=1).cpu() | |
# return draw_bar_chart( | |
# { | |
# "class": self.LABELS, | |
# "probs": probs[0] * 100, | |
# } | |
# ) | |
# def draw_bar_chart(data: dict[str, list[str | float]]): | |
# classes = data["class"] | |
# probabilities = data["probs"] | |
# plt.figure(figsize=(8, 6)) | |
# plt.bar(classes, probabilities, color="skyblue") | |
# plt.xlabel("Class") | |
# plt.ylabel("Probability (%)") | |
# plt.title("Class Probabilities") | |
# for i, prob in enumerate(probabilities): | |
# plt.text(i, prob + 0.01, f"{prob:.2f}", ha="center", va="bottom") | |
# plt.tight_layout() | |
# return plt | |
# def get_layout(): | |
# demo = gr.Interface(fn=Classifier().predict, inputs="image", outputs="plot") | |
# return demo | |
# css = """ | |
# .main-title { | |
# font-size: 24px; | |
# font-weight: bold; | |
# text-align: center; | |
# margin-bottom: 20px; | |
# } | |
# .reference { | |
# text-align: center; | |
# font-size: 1.2em; | |
# color: #d1d5db; | |
# margin-bottom: 20px; | |
# } | |
# .reference a { | |
# color: #FB923C; | |
# text-decoration: none; | |
# } | |
# .reference a:hover { | |
# text-decoration: underline; | |
# color: #FB923C; | |
# } | |
# .title { | |
# border-bottom: 1px solid; | |
# } | |
# .footer { | |
# text-align: center; | |
# margin-top: 30px; | |
# padding-top: 20px; | |
# border-top: 1px solid #ddd; | |
# color: #d1d5db; | |
# font-size: 14px; | |
# } | |
# """ | |
# theme = gr.themes.Base( | |
# primary_hue="orange", | |
# secondary_hue="orange", | |
# neutral_hue="gray", | |
# font=gr.themes.GoogleFont("Source Sans Pro"), | |
# ).set( | |
# background_fill_primary="*neutral_950", # 主背景色(深黑) | |
# button_primary_background_fill="*primary_500", # 按鈕顏色(橘色) | |
# body_text_color="*neutral_200", # 文字顏色(淺色) | |
# ) | |
# # with gr.Blocks(css=css, theme=theme) as demo: | |
# with gr.Blocks() as demo: | |
# with gr.Column(): | |
# gr.HTML( | |
# value=( | |
# '<div class="main-title">Litton7景觀分類模型</div>' | |
# '<div class="reference">引用資料:' | |
# '<a href="https://www.airitilibrary.com/Article/Detail/10125434-N202406210003-00003" target="_blank">' | |
# "何立智、李沁築、邱浩修(2024)。Litton7:Litton視覺景觀分類深度學習模型。戶外遊憩研究,37(2)" | |
# "</a>" | |
# "</div>" | |
# ), | |
# ) | |
# with gr.Row(equal_height=True): | |
# image_input = gr.Image(label="上傳影像", type="pil") | |
# chart = gr.Image(label="分類結果") | |
# start_button = gr.Button("開始分類", variant="primary") | |
# gr.HTML( | |
# '<div class="footer">© 2024 LCL 版權所有<br>開發者:何立智、楊哲睿</div>', | |
# ) | |
# start_button.click( | |
# fn=Classifier().predict, | |
# inputs=image_input, | |
# outputs=chart, | |
# ) | |
# return demo | |
# if __name__ == "__main__": | |
# get_layout().launch() | |