student0822's picture
Update app.py
e8d514a verified
raw
history blame contribute delete
4.13 kB
import gradio as gr
from transformers import pipeline, AutoModelForImageClassification, AutoFeatureExtractor
from PIL import Image
import torch
import os
import json
# Set up Kaggle API credentials.
def setup_kaggle(token_path="./kaggle.json", kaggle_dir="~/.kaggle"):
    """Install a Kaggle API token so the ``kaggle`` CLI can authenticate.

    Reads the JSON token at *token_path* and writes it to
    ``<kaggle_dir>/kaggle.json`` with owner-only read/write permissions (0o600).

    Args:
        token_path: Source token file. Relative paths resolve against the
            current working directory (default ``./kaggle.json``).
        kaggle_dir: Directory to install ``kaggle.json`` into; ``~`` is
            expanded. Created if missing (default ``~/.kaggle``).

    Raises:
        FileNotFoundError: If *token_path* does not exist.
        json.JSONDecodeError: If *token_path* does not contain valid JSON.
    """
    kaggle_dir = os.path.expanduser(kaggle_dir)
    os.makedirs(kaggle_dir, exist_ok=True)

    with open(token_path, "r", encoding="utf-8") as f:
        kaggle_token = json.load(f)

    dest = os.path.join(kaggle_dir, "kaggle.json")
    # Create the file with mode 0o600 up front so the secret is never readable
    # by other users, not even in the window between writing and chmod.
    fd = os.open(dest, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        json.dump(kaggle_token, f)
    # os.open's mode only applies when the file is newly created; chmod also
    # tightens the permissions of a pre-existing file.
    os.chmod(dest, 0o600)
# Download the model files from Kaggle.
def download_model(kernel_id="sonia0822/20241015", dest_dir="/app"):
    """Download a Kaggle kernel's output, which is expected to contain ``model.pth``.

    Installs the Kaggle credentials via ``setup_kaggle()``, then runs
    ``kaggle kernels output <kernel_id> -p <dest_dir>``.

    Args:
        kernel_id: Kaggle kernel whose output is downloaded.
        dest_dir: Directory the output is written to. It must match the path
            the models are later loaded from.

    Raises:
        FileNotFoundError: If the ``kaggle`` executable cannot be found, or if
            ``<dest_dir>/model.pth`` does not exist after the download.
    """
    import subprocess

    setup_kaggle()
    # Use an argument list (no shell involved) and keep the exit status, which
    # os.system() previously discarded, so failures can be diagnosed.
    result = subprocess.run(
        ["kaggle", "kernels", "output", kernel_id, "-p", dest_dir],
        check=False,
    )
    # Success is still defined by the model file being present, as before.
    model_file = os.path.join(dest_dir, "model.pth")
    if not os.path.exists(model_file):
        raise FileNotFoundError(
            f"模型文件下载失败! (kaggle exit status {result.returncode}, expected {model_file})"
        )
# Paths the models are loaded from. download_model() writes the Kaggle kernel
# output into /app.
classification_model_path = "/app/model.pth"
gpt2_model_path = "/app/gpt2-finetuned"

# Download the model before loading it. Check the same absolute path the loader
# reads from. A relative "model.pth" check depends on the current working
# directory and can disagree with the path actually loaded.
if not os.path.exists(classification_model_path):
    print("Downloading model...")
    download_model()
# Load the image-classification model (BEiT base, 16-label head) and its
# feature extractor.
print("加载分类模型...")
classification_model = AutoModelForImageClassification.from_pretrained(
    "microsoft/beit-base-patch16-224-pt22k", num_labels=16
)
# The checkpoint is downloaded from Kaggle at startup. torch.load unpickles by
# default, so restrict it to tensors/containers with weights_only=True to avoid
# executing code from a tampered file (requires torch >= 1.13).
classification_model.load_state_dict(
    torch.load(classification_model_path, map_location="cpu", weights_only=True)
)
# The model is only used for inference; switch to eval mode once at load time.
classification_model.eval()
feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/beit-base-patch16-224-pt22k")
print("分类模型加载成功")
# Build the GPT-2 text-generation pipeline. The directory at gpt2_model_path
# supplies both the model weights and the tokenizer.
print("加载 GPT-2 模型...")
gpt2_generator = pipeline(
    task="text-generation",
    model=gpt2_model_path,
    tokenizer=gpt2_model_path,
)
print("GPT-2 模型加载成功")
# Human-readable style names, indexed by the classifier's output class (0-15).
# NOTE(review): assumes this order matches the class order used when the model
# was trained; confirm against the training kernel.
art_styles = [
    "现实主义",    # 0  Realism
    "巴洛克",      # 1  Baroque
    "后印象派",    # 2  Post-Impressionism
    "印象派",      # 3  Impressionism
    "浪漫主义",    # 4  Romanticism
    "超现实主义",  # 5  Surrealism
    "表现主义",    # 6  Expressionism
    "立体派",      # 7  Cubism
    "野兽派",      # 8  Fauvism
    "抽象艺术",    # 9  Abstract art
    "新艺术",      # 10 Art Nouveau
    "象征主义",    # 11 Symbolism
    "新古典主义",  # 12 Neoclassicism
    "洛可可",      # 13 Rococo
    "文艺复兴",    # 14 Renaissance
    "极简主义",    # 15 Minimalism
]

# Label mapping: sparse label ids (0-25) -> contiguous class indices 0-15,
# assigned in ascending id order.
label_mapping = {
    sparse_id: class_index
    for class_index, sparse_id in enumerate(
        (0, 2, 3, 4, 7, 9, 10, 12, 15, 17, 18, 20, 21, 23, 24, 25)
    )
}
# Inverse lookup: class index -> sparse label id.
reverse_label_mapping = dict(zip(label_mapping.values(), label_mapping.keys()))
# Classify an image's art style and generate a description of that style.
def classify_and_generate_description(image):
    """Classify the art style of *image* and generate a GPT-2 description of it.

    Args:
        image: The uploaded picture. Either a ``PIL.Image.Image`` or a numpy
            array (``gr.Image`` passes numpy by default); ``None`` when nothing
            has been uploaded yet.

    Returns:
        A ``(predicted_style, description)`` tuple of strings. The style is
        ``"未知"`` ("unknown") when the predicted class has no entry in
        ``art_styles`` or when no image was given.
    """
    if image is None:
        # The button can be clicked before an image is uploaded.
        return "未知", "请先上传一张图片。"
    if not isinstance(image, Image.Image):
        # Numpy arrays (gr.Image's default type) have no .convert(); wrap them.
        image = Image.fromarray(image)
    image = image.convert("RGB")

    inputs = feature_extractor(images=image, return_tensors="pt").to("cpu")
    classification_model.eval()
    with torch.no_grad():
        logits = classification_model(**inputs).logits
    predicted_class = torch.argmax(logits, dim=1).item()
    # Out-of-range class indices fall back to "未知" ("unknown").
    predicted_style = art_styles[predicted_class] if predicted_class < len(art_styles) else "未知"

    # Prompt: "Describe the art style of <style> in detail."
    prompt = f"请详细描述{predicted_style}的艺术风格。"
    description = gpt2_generator(prompt, max_length=100, num_return_sequences=1)[0]["generated_text"]
    return predicted_style, description
def ask_gpt2(question):
    """Answer a free-form *question* with the GPT-2 text-generation pipeline.

    Generates a single sequence (``max_length=100``) and returns its
    ``generated_text`` string.
    """
    outputs = gpt2_generator(
        question,
        max_length=100,
        num_return_sequences=1,
    )
    first_sequence = outputs[0]
    return first_sequence["generated_text"]
# Gradio UI: a row for image classification + description, a row for free-form
# questions to GPT-2, and one button for each action.
with gr.Blocks() as demo:
    gr.Markdown("# 艺术风格分类和生成描述")
    with gr.Row():
        # type="pil" hands the handler a PIL image (it calls .convert("RGB"));
        # gr.Image's default type is "numpy".
        image_input = gr.Image(label="上传一张艺术图片", type="pil")
        style_output = gr.Textbox(label="预测的艺术风格")
        description_output = gr.Textbox(label="生成的风格描述")
    with gr.Row():
        question_input = gr.Textbox(label="输入问题")
        answer_output = gr.Textbox(label="GPT-2 生成的回答")
    classify_btn = gr.Button("生成风格描述")
    question_btn = gr.Button("问 GPT-2 一个问题")
    # Event listeners must be registered inside the Blocks context.
    classify_btn.click(fn=classify_and_generate_description, inputs=image_input, outputs=[style_output, description_output])
    question_btn.click(fn=ask_gpt2, inputs=question_input, outputs=answer_output)
demo.launch()