File size: 4,230 Bytes
bd7463f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import json
import requests
import torch
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation, CLIPProcessor, CLIPModel
from controlnet_aux import MLSDdetector
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, StableDiffusionControlNetInpaintPipeline
import urllib.request
import shutil

# 创建资源目录
def create_directories():
    """Create the ``resources`` directory tree used by the download steps.

    Directories that already exist are left as they are, so the function is
    safe to call repeatedly. A confirmation line is printed once the whole
    tree is in place.
    """
    root = "resources"
    os.makedirs(root, exist_ok=True)
    # Sub-directories are created in the same order as before:
    # models, images, labels, output.
    for leaf in ("models", "images", "labels", "output"):
        os.makedirs(f"{root}/{leaf}", exist_ok=True)
    print("目录结构创建完成")

# 下载ADE20K标签文件
def download_labels():
    """Download the ADE20K id-to-label JSON file into ``resources/labels``.

    The response body is stored byte-for-byte, so the saved file does not
    depend on the charset guessed for the response or on the platform's
    default text encoding.

    Raises:
        requests.RequestException: If the request fails, times out, or the
            server answers with an HTTP error status.
        OSError: If the destination file cannot be written.
    """
    url = "https://huggingface.co/datasets/huggingface/label-files/raw/main/ade20k-id2label.json"
    labels_path = "resources/labels/ade20k-id2label.json"

    # Make the function usable on its own, even if the directory tree has
    # not been created yet.
    os.makedirs(os.path.dirname(labels_path), exist_ok=True)

    # A timeout keeps the script from hanging forever on a stalled connection.
    response = requests.get(url, timeout=30)
    # Fail loudly on 4xx/5xx instead of saving an error page as the label file.
    response.raise_for_status()

    with open(labels_path, "wb") as f:
        f.write(response.content)
    print(f"标签文件已保存到: {labels_path}")

# 下载示例图片
def download_sample_image():
    """Fetch the sample input image and mirror it into the working directory.

    The image is saved under ``resources/images`` and then copied to
    ``sample_input.png`` in the current directory. This step is best-effort:
    any exception raised while downloading or copying is reported on stdout
    and swallowed, so the caller continues either way.
    """
    source_url = "https://raw.githubusercontent.com/naderAsadi/DesignGenie/main/examples/images/sample_input.png"
    saved_to = "resources/images/sample_input.png"
    try:
        urllib.request.urlretrieve(source_url, saved_to)
        print(f"示例图片已保存到: {saved_to}")
        # Keep a copy in the working directory for scripts that still
        # expect the image there.
        shutil.copy(saved_to, "sample_input.png")
    except Exception as err:
        print(f"图片下载失败: {err}")

# 下载模型文件
def download_models():
    """Load every pretrained model once so its files land in ``resources/models``.

    Each ``from_pretrained`` call is given ``cache_dir="resources/models"``.
    The models are loaded in this order: Mask2Former (image processor and
    segmentation model), the MLSD detector, the MLSD ControlNet, the Stable
    Diffusion ControlNet pipeline, the Stable Diffusion ControlNet inpainting
    pipeline, and finally CLIP. Progress is printed before and after each
    step.

    Only the CLIP step is best-effort: an exception there is printed and
    swallowed. An exception from any earlier step propagates to the caller.
    """
    cache_root = "resources/models"
    # Keyword arguments shared by the ControlNet and both diffusion pipelines.
    half_precision_opts = {
        "torch_dtype": torch.float16,
        "cache_dir": cache_root,
        "use_safetensors": False,
    }

    print("正在下载模型,这可能需要一些时间...")

    # 1. Mask2Former: image processor plus segmentation model.
    print("下载 Mask2Former 模型...")
    segmentation_id = "facebook/mask2former-swin-large-ade-semantic"
    image_processor = AutoImageProcessor.from_pretrained(segmentation_id, cache_dir=cache_root)
    segmentation_model = Mask2FormerForUniversalSegmentation.from_pretrained(
        segmentation_id, cache_dir=cache_root
    )
    print("Mask2Former 模型下载完成")

    # 2. MLSD line detector.
    print("下载 MLSD 检测器...")
    line_detector = MLSDdetector.from_pretrained("lllyasviel/Annotators", cache_dir=cache_root)
    print("MLSD 检测器下载完成")

    # 3. ControlNet; the same instance is handed to both pipelines below.
    print("下载 ControlNet 模型...")
    controlnet = ControlNetModel.from_pretrained(
        "lllyasviel/control_v11p_sd15_mlsd", **half_precision_opts
    )
    print("ControlNet 模型下载完成")

    # 4. Stable Diffusion text-to-image pipeline with ControlNet attached.
    print("下载 Stable Diffusion 模型...")
    generation_pipe = StableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=controlnet,
        **half_precision_opts,
    )
    print("Stable Diffusion 模型下载完成")

    # 5. Stable Diffusion inpainting pipeline (used by inpaint.py).
    print("下载 Stable Diffusion Inpainting 模型...")
    inpaint_pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
        controlnet=controlnet,
        **half_precision_opts,
    )
    print("Stable Diffusion Inpainting 模型下载完成")

    # 6. CLIP image feature extractor (used for similarity search).
    print("下载图像特征提取模型...")
    clip_id = "openai/clip-vit-base-patch32"
    try:
        clip_model = CLIPModel.from_pretrained(clip_id, cache_dir=cache_root)
        clip_processor = CLIPProcessor.from_pretrained(clip_id, cache_dir=cache_root)
        print("图像特征提取模型下载完成")
    except Exception as err:
        print(f"图像特征提取模型下载失败: {err}")

if __name__ == "__main__":
    # Run the setup steps in order: the directory tree first, then the label
    # file, the sample image, and finally the model downloads.
    for step in (
        create_directories,
        download_labels,
        download_sample_image,
        download_models,
    ):
        step()
    print("所有资源下载完成!您可以将整个 'resources' 文件夹保存到本地使用。")