import gradio as gr
from transformers import pipeline, AutoModelForImageClassification, AutoFeatureExtractor
from PIL import Image
import torch
import os
import json

# Set up Kaggle API credentials
def setup_kaggle():
    # Create the ~/.kaggle directory the Kaggle CLI expects
    os.makedirs(os.path.expanduser("~/.kaggle"), exist_ok=True)
    # Copy the credentials from the bundled /app/kaggle.json into ~/.kaggle/kaggle.json
    with open("/app/kaggle.json", "r") as f:  # read credentials directly from /app/kaggle.json
        kaggle_token = json.load(f)
    with open(os.path.expanduser("~/.kaggle/kaggle.json"), "w") as f:
        json.dump(kaggle_token, f)
    # Restrict permissions; the Kaggle CLI warns if the key is readable by other users
    os.chmod(os.path.expanduser("~/.kaggle/kaggle.json"), 0o600)

# Download the model file from Kaggle
def download_model():
    # Set up Kaggle API credentials
    setup_kaggle()

    # Download the kernel output with the Kaggle CLI (provided by the `kaggle` pip package);
    # os.system does not raise on failure, so the existence check below is the actual guard.
    os.system("kaggle kernels output sonia0822/20241015 -p /app")  # replace with your kernel ID and output path

    # Make sure the model file was actually downloaded
    if not os.path.exists("/app/model.pth"):
        raise FileNotFoundError("Model file download failed!")

# Download the checkpoint before loading it (check the same /app path the download writes to)
if not os.path.exists("/app/model.pth"):
    print("Downloading model...")
    download_model()

# Paths to the saved models
classification_model_path = "/app/model.pth"
gpt2_model_path = "/app/gpt2-finetuned"

# Load the classification model and feature extractor
print("Loading classification model...")
classification_model = AutoModelForImageClassification.from_pretrained(
    "microsoft/beit-base-patch16-224-pt22k", num_labels=16
)
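# load_state_dict is strict by default, so the checkpoint must have been saved from this
# same BEiT architecture with a 16-class head. On recent PyTorch versions, passing
# weights_only=True to torch.load is the safer way to load a plain state dict.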
classification_model.load_state_dict(torch.load(classification_model_path, map_location="cpu"))
feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/beit-base-patch16-224-pt22k")
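# Recent transformers releases favor AutoImageProcessor over AutoFeatureExtractor for
# vision models; this call still works but may emit a deprecation warning.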
print("分类模型加载成功")

# Load the fine-tuned GPT-2 text-generation model
print("Loading GPT-2 model...")
gpt2_generator = pipeline("text-generation", model=gpt2_model_path, tokenizer=gpt2_model_path)
print("GPT-2 model loaded successfully")

# Art style labels (kept in Chinese; they are interpolated into the Chinese GPT-2 prompt below)
art_styles = [
    "现实主义", "巴洛克", "后印象派", "印象派", "浪漫主义", "超现实主义",
    "表现主义", "立体派", "野兽派", "抽象艺术", "新艺术", "象征主义",
    "新古典主义", "洛可可", "文艺复兴", "极简主义"
]
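# English equivalents, in order: Realism, Baroque, Post-Impressionism, Impressionism,
# Romanticism, Surrealism, Expressionism, Cubism, Fauvism, Abstract Art, Art Nouveau,
# Symbolism, Neoclassicism, Rococo, Renaissance, Minimalism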

# Label mapping: presumably maps the original dataset's label IDs to the 16 classes the model
# was trained on; reverse_label_mapping recovers the original dataset ID from a prediction.
label_mapping = {0: 0, 2: 1, 3: 2, 4: 3, 7: 4, 9: 5, 10: 6, 12: 7, 15: 8, 17: 9, 18: 10, 20: 11, 21: 12, 23: 13, 24: 14, 25: 15}
reverse_label_mapping = {v: k for k, v in label_mapping.items()}

# Classify an image and generate a description of its style
def classify_and_generate_description(image):
    image = image.convert("RGB")
    inputs = feature_extractor(images=image, return_tensors="pt").to("cpu")
    classification_model.eval()
    with torch.no_grad():
        outputs = classification_model(**inputs).logits
        predicted_class = torch.argmax(outputs, dim=1).item()

    predicted_label = reverse_label_mapping.get(predicted_class, "未知")  # original dataset label ID ("未知" = "unknown"); not used in the UI
    predicted_style = art_styles[predicted_class] if predicted_class < len(art_styles) else "未知"

    prompt = f"请详细描述{predicted_style}的艺术风格。"
    description = gpt2_generator(prompt, max_length=100, num_return_sequences=1)[0]["generated_text"]
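    # Note: max_length counts the prompt plus the generated tokens; max_new_tokens could be
    # used instead to bound only the newly generated text.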
    return predicted_style, description

def ask_gpt2(question):
    response = gpt2_generator(question, max_length=100, num_return_sequences=1)[0]["generated_text"]
    return response

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# 艺术风格分类和生成描述")  # "Art Style Classification and Description Generation"
    with gr.Row():
        image_input = gr.Image(type="pil", label="上传一张艺术图片")  # "Upload an artwork image"; type="pil" so .convert("RGB") works
        style_output = gr.Textbox(label="预测的艺术风格")  # "Predicted art style"
        description_output = gr.Textbox(label="生成的风格描述")  # "Generated style description"
    with gr.Row():
        question_input = gr.Textbox(label="输入问题")  # "Enter a question"
        answer_output = gr.Textbox(label="GPT-2 生成的回答")  # "Answer generated by GPT-2"
    classify_btn = gr.Button("生成风格描述")  # "Generate style description"
    question_btn = gr.Button("问 GPT-2 一个问题")  # "Ask GPT-2 a question"
    classify_btn.click(fn=classify_and_generate_description, inputs=image_input, outputs=[style_output, description_output])
    question_btn.click(fn=ask_gpt2, inputs=question_input, outputs=answer_output)

demo.launch()
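
# If this app runs inside a container (the /app paths suggest Docker), launching with
# demo.launch(server_name="0.0.0.0", server_port=7860) may be needed so the server is
# reachable from outside the container.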