H1017 committed
Commit bd7463f · verified · 1 Parent(s): 2181670

Upload folder using huggingface_hub

Files changed (6)
  1. .gitignore +63 -0
  2. .gradio/certificate.pem +31 -0
  3. README.md +196 -7
  4. app.py +1251 -0
  5. download_resources.py +112 -0
  6. requirements.txt +0 -0
.gitignore ADDED
@@ -0,0 +1,63 @@
+ # Python compiled files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ env/
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # Virtual environments
+ venv/
+ ENV/
+ env/
+ .env
+
+ # Downloaded resource files
+ resources/
+ *.pt
+ *.pth
+ *.bin
+ *.safetensors
+ *.onnx
+ model_cache/
+
+ # Generated images
+ *.png
+ *.jpg
+ *.jpeg
+ *.gif
+ *.bmp
+ *.tiff
+ sample_input.png
+
+ # Log files
+ *.log
+ logs/
+
+ # IDE files
+ .idea/
+ .vscode/
+ *.swp
+ *.swo
+ .DS_Store
+
+ # Temporary files
+ .ipynb_checkpoints/
+ .pytest_cache/
+ .coverage
+ htmlcov/
+ .tox/
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
README.md CHANGED
@@ -1,12 +1,201 @@
  ---
  title: AiRoom
- emoji: 🦀
- colorFrom: green
- colorTo: blue
- sdk: gradio
- sdk_version: 5.25.2
  app_file: app.py
- pinned: false
+ sdk: gradio
+ sdk_version: 5.20.1
  ---
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
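+ # AiRoom - AI-Assisted Interior Design Tool
+
+ ## Overview
+
+ AiRoom is an AI-assisted interior design tool. It combines ControlNet with Stable Diffusion to restyle an entire interior scene or only selected regions of it. An interactive interface lets users explore design ideas and style changes, and a similar-image search helps them find inspiration.
+
+ ## Features
+
+ - **Global style adjustment**: ControlNet preserves the original spatial layout while Stable Diffusion changes the overall style
+ - **Local style adjustment**: Restyle a specific region (walls, floor, furniture, and so on) while leaving the rest of the image unchanged
+ - **Similar-image search**: Efficient image similarity search built on CLIP and FAISS for finding similar designs
+ - **Interactive interface**: A Gradio-based UI with live preview and parameter controls
+ - **Multiple candidates**: Each run generates several designs to choose from, shown in a 2x2 grid
+ - **Automatic region detection**: Functional regions in the image are identified automatically, with no manual annotation
+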
+ ## Installation
+
+ ### Requirements
+
+ - Python 3.8+
+ - A CUDA-capable GPU (8 GB or more of VRAM recommended)
+
+ ### Steps
+
+ 1. Clone this repository:
+
+ ```bash
+ git clone https://github.com/yourusername/AiRoom.git
+ cd AiRoom
+ ```
+
+ 2. Create and activate a virtual environment (recommended):
+
+ ```bash
+ # Create a virtual environment with Conda
+ conda create -n Airoom python=3.10
+ conda activate Airoom
+
+ # Or create one with venv
+ python -m venv Airoom
+ # Activate on Windows
+ Airoom\Scripts\activate
+ # Activate on Linux/macOS
+ source Airoom/bin/activate
+ ```
+
+ 3. Install dependencies:
+
+ ```bash
+ # Install PyTorch (pick the command that matches your CUDA version)
+ pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
+
+ # Install project dependencies
+ pip install -r requirements.txt
+ ```
+
+ 4. Download the required resources:
+
+ ```bash
+ python download_resources.py
+ ```
+
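+ ## Usage
+
+ 1. Start the application:
+
+ ```bash
+ python app.py
+ ```
+
+ 2. Open the local URL shown in the terminal in your browser (usually http://127.0.0.1:7860)
+
+ 3. Workflow:
+    - Click "Load Models" first and wait for all models to finish loading
+    - Choose a mode from the tabs at the top: global style adjustment, local style adjustment, or similar-image search
+    - Upload a photo of an interior scene or use a sample image
+    - Click "Analyze Image Structure" to process the input image
+    - Adjust the parameters as needed
+    - Click "Generate Designs" to create new designs
+    - Pick the results you like from the generated designs
+    - Optionally save a design for later reference or search
+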
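+ ## Feature Details
+
+ ### Global Style Adjustment
+
+ Global style adjustment changes the design style of the whole scene while keeping the original spatial layout. You can:
+
+ - Enter a detailed style prompt (pick one from the preset list or write your own)
+ - Choose the room type (bedroom, living room, kitchen, and so on)
+ - Choose the style theme (modern, Nordic, industrial, and so on)
+ - Adjust the number of inference steps (affects quality and generation time)
+ - Adjust the guidance scale (how closely the result follows the prompt)
+ - Generate 4 different designs at once for comparison
+ - Select and save the designs you like
+
+ How it works:
+ - The MLSD detector extracts the room's line structure to produce a control image
+ - ControlNet keeps the generated image consistent with the original layout and structure
+ - Stable Diffusion generates the requested style from the prompt and the control image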
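
The prompt that reaches Stable Diffusion in this mode is assembled from the two dropdowns and the free-text field. The following is a minimal sketch modeled on the `split(" - ")` and f-string logic in this commit's `app.py`; the function name `build_prompt` and the sample dropdown labels are hypothetical and only illustrate the assumed "english - translation" label format.

```python
def build_prompt(room_type: str, style_theme: str, prompt: str) -> str:
    """Combine the dropdown selections and the free-text prompt."""
    # Dropdown labels are assumed to look like "english - translation";
    # only the English part before " - " is sent to the model.
    room = room_type.split(" - ")[0]
    style = style_theme.split(" - ")[0]
    return f"A {style} style {room}, {prompt}"


# Hypothetical labels, for illustration only.
full_prompt = build_prompt("bedroom - 卧室", "nordic - 北欧", "warm lighting")
print(full_prompt)
```

Because the same prompt is repeated for every image in the batch, the variation between the four candidates comes from the per-image random seeds rather than from the text.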
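+
+ ### Local Style Adjustment
+
+ Local style adjustment restyles a specific region of the scene and leaves the rest unchanged. You can:
+
+ - Pick the region to adjust from a dropdown (walls, floor, furniture, and so on)
+ - Preview the mask of the selected region (shown as a semi-transparent red overlay)
+ - Enter a style prompt for that region
+ - Adjust the strength and detail of the change
+ - Generate a local style change that preserves the overall structure
+
+ How it works:
+ - A Mask2Former model performs semantic segmentation to identify the functional regions in the image
+ - Each detected region is converted into a mask the user can select
+ - ControlNet and Stable Diffusion Inpainting are combined to restyle the selected region
+ - Unselected regions are left unchanged; only the selected region is modified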
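
The mask handling above involves two opposite conventions, which is easy to get wrong. The sketch below uses plain nested lists instead of NumPy or PyTorch arrays so it runs anywhere; the helper names `region_mask` and `inpaint_mask` are hypothetical. It assumes, as `app.py` in this commit does, that a region mask is 0 inside the selected class and 1 elsewhere, and that the inpainting mask is its inverse (1 marks pixels to repaint).

```python
def region_mask(seg_map, class_id):
    """Return 0 where the pixel belongs to class_id and 1 everywhere else."""
    return [[0 if pixel == class_id else 1 for pixel in row] for row in seg_map]


def inpaint_mask(mask):
    """Invert a region mask so that 1 marks the pixels to be repainted."""
    return [[1 - value for value in row] for row in mask]


# Tiny 2x3 segmentation map with made-up class ids (0, 3 and 5).
seg_map = [[0, 0, 3],
           [3, 3, 5]]
selected = region_mask(seg_map, 3)
repaint = inpaint_mask(selected)
print(selected)
print(repaint)
```

Keeping the inversion in one small function makes it harder to pass the wrong polarity to the inpainting pipeline, which would repaint everything except the selected region.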
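+
+ ### Similar-Image Search
+
+ Similar-image search uses a CLIP model and a FAISS index to find designs that resemble a reference image. You can:
+
+ - Upload a reference image
+ - Set the number of results (2-8)
+ - View the similar images in a 2x2 grid
+ - See the similarity percentage of each result
+ - Click "Rebuild Image Index" to update the index so it includes newly generated designs
+
+ How it works:
+ - A CLIP model extracts a semantic feature vector from each image
+ - A FAISS index stores the feature vectors of all generated designs
+ - A search computes the cosine similarity between the query image and every image in the index
+ - The images with the highest similarity are returned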
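
The search relies on the fact that the inner product of two unit-length vectors equals their cosine similarity, which is why an inner-product index can rank by cosine similarity once the features are normalized. The following is a pure-Python sketch of that idea, not the FAISS implementation; the names `normalize` and `top_k_similar` and the toy 2-dimensional vectors are hypothetical.

```python
import math


def normalize(vec):
    """Scale a vector to unit length (assumes it is not all zeros)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def top_k_similar(query, stored, k):
    """Return up to k (row_index, cosine_similarity) pairs, best first."""
    q = normalize(query)
    scored = []
    for row, vec in enumerate(stored):
        # Inner product of two unit vectors equals their cosine similarity.
        score = sum(a * b for a, b in zip(q, normalize(vec)))
        scored.append((score, row))
    scored.sort(reverse=True)
    return [(row, score) for score, row in scored[:k]]


# Toy feature vectors standing in for CLIP embeddings.
results = top_k_similar([2.0, 0.0], [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], k=2)
print(results)
```

A flat index like this compares the query with every stored vector, so the cost grows linearly with the number of saved designs; that is usually acceptable for a personal gallery.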
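+
+ ## Project Structure
+
+ - `app.py`: Main application with the Gradio interface and the core features (global style adjustment, local style adjustment, similar-image search)
+ - `download_resources.py`: Utility script that downloads the required models and resources
+ - `requirements.txt`: Project dependencies
+ - `resources/`: Models, images, and label data
+   - `models/`: AI models (Mask2Former, ControlNet, Stable Diffusion, and others)
+   - `images/`: Sample and generated images
+   - `labels/`: Label data (such as the ADE20K label map)
+   - `output/`: Generated designs
+     - `global_style/`: Images produced by global style adjustment
+     - `local_style/`: Images produced by local style adjustment
+   - `features/`: Image features and index files (used by similar-image search)
+     - `image_features.index`: FAISS index file
+     - `image_metadata.pkl`: Image metadata file
+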
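+ ## Technical Implementation
+
+ The project builds on the following AI models and libraries:
+
+ - **Mask2Former**: Semantic segmentation of the scene to identify functional regions (walls, floor, furniture, and so on)
+ - **ControlNet (MLSD)**: Preserves the structure and layout of the original scene, guided by line detection
+ - **Stable Diffusion**: Generates image content that matches the prompt
+ - **Stable Diffusion Inpainting**: Repaints and restyles a specific region
+ - **CLIP**: Extracts image features for similarity search and semantic understanding
+ - **FAISS**: Efficient vector similarity search library that scales to large image collections
+ - **Gradio**: Builds the user interface, with interactive controls and live preview
+ - **PyTorch**: Deep learning framework with GPU-accelerated inference
+
+ Model loading strategy:
+ - Uses `torch.float16` precision to reduce memory usage
+ - Offloads models to the CPU to optimize memory usage
+ - Supports xformers memory optimizations (if installed)
+ - Loads models from the local cache to avoid repeated downloads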
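
To make the `torch.float16` point concrete: half precision stores each weight in 2 bytes instead of 4, so the memory needed for the weights alone is halved. The sketch below is a back-of-the-envelope estimate only; the function name and the parameter count are hypothetical and do not describe any specific model in this project, and activations and framework overhead are ignored.

```python
BYTES_PER_PARAM = {"float32": 4, "float16": 2}


def weight_memory_gib(num_params: int, dtype: str) -> float:
    """Estimate the memory taken by model weights alone, in GiB."""
    return num_params * BYTES_PER_PARAM[dtype] / 2**30


# Hypothetical model size chosen so the arithmetic is easy to follow.
num_params = 2**30
fp32_gib = weight_memory_gib(num_params, "float32")
fp16_gib = weight_memory_gib(num_params, "float16")
print(fp32_gib, fp16_gib)
```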
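+
+ ## Notes
+
+ - The first run downloads large model files (about 10 GB), so make sure you have enough disk space and a stable network connection
+ - Generation can take a while depending on your hardware (an NVIDIA GPU is recommended)
+ - For best results, use a clear photo of an interior scene as input
+ - Similar-image search only becomes useful after you have generated and saved some designs
+ - The number of inference steps trades quality against speed; 20-30 steps usually give good results
+ - The guidance scale controls how much creative freedom the result has; higher values (7-9) follow the prompt more strictly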
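
The effect of the guidance scale can be illustrated with the standard classifier-free guidance formula, in which each denoising step combines a prediction made without the prompt and one made with it. This is a sketch of that general formula on plain lists, not code taken from `app.py`; the function name `guided_noise` and the toy numbers are hypothetical.

```python
def guided_noise(uncond, cond, guidance_scale):
    """Classifier-free guidance: move from the unconditional prediction
    toward the prompt-conditioned one, scaled by guidance_scale."""
    return [u + guidance_scale * (c - u) for u, c in zip(uncond, cond)]


# Toy two-element "noise predictions".
uncond = [0.0, 1.0]
cond = [1.0, 1.0]
print(guided_noise(uncond, cond, 1.0))  # scale 1 reproduces the conditioned prediction
print(guided_noise(uncond, cond, 7.5))  # larger scales amplify the prompt's influence
```

Where the two predictions agree, the scale has no effect; where they differ, a larger scale pushes the result further in the direction the prompt suggests.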
+
+ ## License
+
+ [Add your license information here]
+
+ ## Acknowledgements
+
+ This project builds on the following open-source projects and models:
+
+ - [Hugging Face Diffusers](https://github.com/huggingface/diffusers)
+ - [ControlNet](https://github.com/lllyasviel/ControlNet)
+ - [Mask2Former](https://github.com/facebookresearch/Mask2Former)
+ - [CLIP](https://github.com/openai/CLIP)
+ - [FAISS](https://github.com/facebookresearch/faiss)
+ - [Gradio](https://github.com/gradio-app/gradio)
+
+ ## Contact
 
+ [Add your contact information here]
app.py ADDED
@@ -0,0 +1,1251 @@
+ import os
+ import torch
+ import numpy as np
+ from PIL import Image
+ import json
+ import gradio as gr
+ import torchvision.transforms as transforms
+ from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
+ from transformers import CLIPProcessor, CLIPModel
+ from controlnet_aux import MLSDdetector
+ from diffusers import (
+     ControlNetModel,
+     StableDiffusionControlNetPipeline,
+     StableDiffusionControlNetInpaintPipeline,
+     UniPCMultistepScheduler
+ )
+ from diffusers.utils import load_image
+ import cv2
+ import pickle
+ import faiss
+ import datetime
+ import glob
+
+ # Resource paths
+ RESOURCE_DIR = "resources"
+ MODELS_DIR = os.path.join(RESOURCE_DIR, "models")
+ IMAGES_DIR = os.path.join(RESOURCE_DIR, "images")
+ LABELS_DIR = os.path.join(RESOURCE_DIR, "labels")
+ OUTPUT_DIR = os.path.join(RESOURCE_DIR, "output")
+ GLOBAL_SAVE_DIR = os.path.join(OUTPUT_DIR, "global_style")  # saved global style results
+ LOCAL_SAVE_DIR = os.path.join(OUTPUT_DIR, "local_style")  # saved local style results
+ FEATURES_DIR = os.path.join(RESOURCE_DIR, "features")  # image feature storage
+ INDEX_PATH = os.path.join(FEATURES_DIR, "image_features.index")  # FAISS index file
+ METADATA_PATH = os.path.join(FEATURES_DIR, "image_metadata.pkl")  # image metadata file
+
+ # Make sure the output directories exist
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
+ os.makedirs(GLOBAL_SAVE_DIR, exist_ok=True)
+ os.makedirs(LOCAL_SAVE_DIR, exist_ok=True)
+ os.makedirs(FEATURES_DIR, exist_ok=True)
+
+ # Load the ADE20K id-to-label map from a local JSON file
+ labels_path = os.path.join(LABELS_DIR, "ade20k-id2label.json")
+ if os.path.exists(labels_path):
+     with open(labels_path, 'r') as f:
+         LABELS = json.load(f)
+ else:
+     # Fall back to downloading the file if it is not available locally
+     import requests
+     print("Local label file not found, downloading...")
+     LABELS = requests.get("https://huggingface.co/datasets/huggingface/label-files/raw/main/ade20k-id2label.json").json()
+     # Make sure the directory exists
+     os.makedirs(LABELS_DIR, exist_ok=True)
+     # Cache a local copy
+     with open(labels_path, 'w') as f:
+         json.dump(LABELS, f)
+
+ # Module-level handles for the loaded models
+ processor = None
+ mask2former_model = None
+ mlsd_processor = None
+ controlnet = None
+ global_pipe = None
+ inpaint_pipe = None
+ segmentation_result = None
+ clip_processor = None
+ clip_model = None
+ faiss_index = None
+ image_metadata = {}
+
+ def load_models():
+     """Load every model the application needs."""
+     global processor, mask2former_model, mlsd_processor, controlnet, global_pipe, inpaint_pipe, clip_processor, clip_model, faiss_index, image_metadata
+
+     # Load the Mask2Former model
+     print("Loading Mask2Former model...")
+     processor = AutoImageProcessor.from_pretrained(
+         "facebook/mask2former-swin-large-ade-semantic",
+         cache_dir=MODELS_DIR
+     )
+     mask2former_model = Mask2FormerForUniversalSegmentation.from_pretrained(
+         "facebook/mask2former-swin-large-ade-semantic",
+         cache_dir=MODELS_DIR
+     )
+
+     # Load the MLSD line detector
+     print("Loading MLSD detector...")
+     mlsd_processor = MLSDdetector.from_pretrained(
+         "lllyasviel/Annotators",
+         cache_dir=MODELS_DIR
+     )
+
+     # Load the ControlNet model
+     print("Loading ControlNet model...")
+     controlnet = ControlNetModel.from_pretrained(
+         "lllyasviel/control_v11p_sd15_mlsd",
+         torch_dtype=torch.float16,
+         cache_dir=MODELS_DIR,
+         use_safetensors=False
+     )
+
+     # Load the pipeline for global style adjustment
+     print("Loading Stable Diffusion model for global style adjustment...")
+     global_pipe = StableDiffusionControlNetPipeline.from_pretrained(
+         "runwayml/stable-diffusion-v1-5",
+         controlnet=controlnet,
+         torch_dtype=torch.float16,
+         cache_dir=MODELS_DIR,
+         use_safetensors=False
+     )
+     global_pipe.scheduler = UniPCMultistepScheduler.from_config(global_pipe.scheduler.config)
+     global_pipe.enable_model_cpu_offload()
+
+     # Load the pipeline for local style adjustment
+     print("Loading Stable Diffusion Inpainting model for local style adjustment...")
+     inpaint_pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
+         "runwayml/stable-diffusion-inpainting",
+         controlnet=controlnet,
+         torch_dtype=torch.float16,
+         cache_dir=MODELS_DIR,
+         use_safetensors=False
+     )
+     inpaint_pipe.scheduler = UniPCMultistepScheduler.from_config(inpaint_pipe.scheduler.config)
+     inpaint_pipe.enable_model_cpu_offload()
+
+     # Load the CLIP model used for image feature extraction
+     print("Loading CLIP model...")
+     clip_processor = CLIPProcessor.from_pretrained(
+         "openai/clip-vit-base-patch32",
+         cache_dir=MODELS_DIR
+     )
+     clip_model = CLIPModel.from_pretrained(
+         "openai/clip-vit-base-patch32",
+         cache_dir=MODELS_DIR
+     )
+
+     # Load the FAISS index, or create it if it does not exist yet
+     load_or_create_index()
+
+     # The standard attention implementation is used by default
+     print("Using the default attention mechanism")
+
+     return "All models loaded!"
+
+ def extract_image_features(image):
+     """
+     Extract image features with the CLIP model.
+
+     Args:
+         image: PIL image or NumPy array
+
+     Returns:
+         (features, status): a float32 NumPy array of shape (1, feature_dim)
+         holding the L2-normalized feature vector (None on failure), and a
+         status message.
+     """
+     global clip_processor, clip_model
+
+     if clip_processor is None or clip_model is None:
+         return None, "Please load the models first!"
+
+     # Make sure the image is a PIL image
+     if not isinstance(image, Image.Image):
+         image = Image.fromarray(image)
+
+     # Run the image through CLIP
+     with torch.no_grad():
+         inputs = clip_processor(images=image, return_tensors="pt")
+         image_features = clip_model.get_image_features(**inputs)
+
+     # L2-normalize so that inner product equals cosine similarity
+     image_features = image_features / image_features.norm(dim=1, keepdim=True)
+
+     # Convert to a NumPy array
+     features = image_features.cpu().numpy().astype('float32')
+
+     return features, "Feature extraction succeeded"
+
+ def load_or_create_index():
+     """
+     Load the existing FAISS index, or create a new one.
+     """
+     global faiss_index, image_metadata
+
+     # Check whether the index files exist
+     if os.path.exists(INDEX_PATH) and os.path.exists(METADATA_PATH):
+         print("Loading existing image feature index...")
+         try:
+             faiss_index = faiss.read_index(INDEX_PATH)
+             with open(METADATA_PATH, 'rb') as f:
+                 image_metadata = pickle.load(f)
+             print(f"Index loaded with {faiss_index.ntotal} images")
+         except Exception as e:
+             print(f"Failed to load index: {e}")
+             create_new_index()
+     else:
+         print("Creating a new image feature index...")
+         create_new_index()
+
+ def create_new_index():
+     """
+     Create a new FAISS index and scan the existing images.
+     """
+     global faiss_index, image_metadata
+
+     # Create a fresh index and metadata dictionary
+     feature_dim = 512  # feature dimension of CLIP ViT-B/32
+     # Inner-product index; equals cosine similarity because features are L2-normalized
+     faiss_index = faiss.IndexFlatIP(feature_dim)
+     image_metadata = {}
+
+     # Scan and index the existing images
+     index_existing_images()
+
+ def index_existing_images():
+     """
+     Scan the saved design images and add them to the index.
+     """
+     global faiss_index, image_metadata, clip_processor, clip_model
+
+     if clip_processor is None or clip_model is None:
+         print("CLIP model not loaded, cannot index images")
+         return
+
+     # Collect all saved images
+     global_images = glob.glob(os.path.join(GLOBAL_SAVE_DIR, "*.png"))
+     local_images = glob.glob(os.path.join(LOCAL_SAVE_DIR, "*.png"))
+     all_images = global_images + local_images
+
+     print(f"Found {len(all_images)} existing images")
+
+     # Extract and index the features of each image
+     new_features = []
+     new_metadata = []
+
+     for img_path in all_images:
+         # Skip images that are already indexed
+         if img_path in image_metadata:
+             continue
+
+         try:
+             # Load the image
+             img = Image.open(img_path)
+
+             # Extract features
+             features, _ = extract_image_features(img)
+             if features is not None:
+                 # Build the metadata record
+                 metadata = {
+                     "path": img_path,
+                     "filename": os.path.basename(img_path),
+                     "type": "global" if img_path in global_images else "local",
+                     "timestamp": datetime.datetime.fromtimestamp(os.path.getmtime(img_path)).strftime('%Y-%m-%d %H:%M:%S'),
+                     # Row this vector will occupy in the FAISS index; read by
+                     # search_similar_images via meta.get("index")
+                     "index": faiss_index.ntotal + len(new_features)
+                 }
+
+                 # Parse the filename for extra information
+                 filename = os.path.basename(img_path)
+                 parts = filename.split('_')
+                 if len(parts) >= 3:
+                     metadata["room_type"] = parts[0]
+                     metadata["style_theme"] = parts[1]
+
+                 # Queue the features for indexing
+                 new_features.append(features[0])
+                 new_metadata.append(metadata)
+
+                 # Update the metadata dictionary
+                 image_metadata[img_path] = metadata
+         except Exception as e:
+             print(f"Error while processing image {img_path}: {e}")
+
+     # Add the new features to the index
+     if new_features:
+         new_features = np.array(new_features).astype('float32')
+         faiss_index.add(new_features)
+         print(f"Indexed {len(new_features)} new images")
+
+         # Persist the index and metadata
+         save_index()
+
+ def save_index():
+     """
+     Save the FAISS index and the metadata to disk.
+     """
+     global faiss_index, image_metadata
+
+     if faiss_index is not None and image_metadata:
+         try:
+             faiss.write_index(faiss_index, INDEX_PATH)
+             with open(METADATA_PATH, 'wb') as f:
+                 pickle.dump(image_metadata, f)
+             print(f"Index saved with {faiss_index.ntotal} images")
+         except Exception as e:
+             print(f"Failed to save index: {e}")
+
+ def add_image_to_index(image_path, image=None):
+     """
+     Add a new image to the index.
+
+     Args:
+         image_path: path of the image file
+         image: optional PIL image; loaded from image_path when omitted
+     """
+     global faiss_index, image_metadata, clip_processor, clip_model
+
+     if clip_processor is None or clip_model is None:
+         print("CLIP model not loaded, cannot add image to index")
+         return
+
+     # Skip images that are already indexed
+     if image_path in image_metadata:
+         print(f"Image {image_path} is already in the index")
+         return
+
+     try:
+         # Load the image if it was not provided
+         if image is None:
+             image = Image.open(image_path)
+
+         # Extract features
+         features, _ = extract_image_features(image)
+         if features is not None:
+             # Build the metadata record
+             is_global = "global_style" in image_path
+             metadata = {
+                 "path": image_path,
+                 "filename": os.path.basename(image_path),
+                 "type": "global" if is_global else "local",
+                 "timestamp": datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
+                 # Row this vector will occupy in the FAISS index; read by
+                 # search_similar_images via meta.get("index")
+                 "index": faiss_index.ntotal
+             }
+
+             # Parse the filename for extra information
+             filename = os.path.basename(image_path)
+             parts = filename.split('_')
+             if len(parts) >= 3:
+                 metadata["room_type"] = parts[0]
+                 metadata["style_theme"] = parts[1]
+
+             # Add to the index
+             faiss_index.add(features)
+
+             # Update the metadata dictionary
+             image_metadata[image_path] = metadata
+
+             # Persist the index and metadata
+             save_index()
+
+             print(f"Image {image_path} added to the index")
+     except Exception as e:
+         print(f"Error while adding image {image_path} to the index: {e}")
+
+ def search_similar_images(query_image, top_k=8):
+     """
+     Search for images similar to the query image.
+
+     Args:
+         query_image: PIL image or NumPy array
+         top_k: number of most similar images to return
+
+     Returns:
+         (paths, metadata, status): paths of the similar images, their
+         metadata records (each with a "similarity" score), and a status message.
+     """
+     global faiss_index, image_metadata, clip_processor, clip_model
+
+     if faiss_index is None or clip_processor is None or clip_model is None:
+         return [], [], "Please load the models first!"
+
+     if faiss_index.ntotal == 0:
+         return [], [], "The index is empty; generate and save some designs first"
+
+     # Extract the features of the query image
+     query_features, status = extract_image_features(query_image)
+     if query_features is None:
+         return [], [], status
+
+     # Run the similarity search
+     scores, indices = faiss_index.search(query_features, min(top_k, faiss_index.ntotal))
+
+     # Collect the paths and metadata of the results
+     result_paths = []
+     result_metadata = []
+
+     for i, idx in enumerate(indices[0]):
+         # FAISS pads missing results with -1
+         if idx < 0:
+             continue
+
+         # Look up the image path by its stored index row
+         paths = [path for path, meta in image_metadata.items() if meta.get("index", -1) == idx]
+
+         # Fall back to insertion order when no record stores this row
+         if not paths:
+             all_paths = list(image_metadata.keys())
+             if idx < len(all_paths):
+                 paths = [all_paths[idx]]
+
+         if paths:
+             result_paths.append(paths[0])
+             # Copy so the score is not written into the persisted metadata
+             meta = dict(image_metadata.get(paths[0], {}))
+             meta["similarity"] = float(scores[0][i])  # attach the similarity score
+             result_metadata.append(meta)
+
+     return result_paths, result_metadata, "Search complete"
+
+ def get_mask_from_segmentation_map(seg_map):
+     """Build one mask per class present in the segmentation map."""
+     masks, labels, label_names = [], [], []
+
+     # Chinese translations of the ADE20K labels
+     chinese_labels = {
+         "wall": "墙壁", "building": "建筑", "sky": "天空", "floor": "地板", "tree": "树",
+         "ceiling": "天花板", "road": "道路", "bed": "床", "windowpane": "窗户", "grass": "草地",
+         "cabinet": "柜子", "sidewalk": "人行道", "person": "人", "earth": "土地", "door": "门",
+         "table": "桌子", "mountain": "山", "plant": "植物", "curtain": "窗帘", "chair": "椅子",
+         "car": "汽车", "water": "水", "painting": "画", "sofa": "沙发", "shelf": "架子",
+         "house": "房子", "sea": "海", "mirror": "镜子", "rug": "地毯", "field": "田野",
+         "armchair": "扶手椅", "seat": "座位", "fence": "栅栏", "desk": "书桌", "rock": "岩石",
+         "wardrobe": "衣柜", "lamp": "灯", "bathtub": "浴缸", "railing": "栏杆", "cushion": "靠垫",
+         "base": "底座", "box": "盒子", "column": "柱子", "signboard": "招牌", "chest of drawers": "抽屉柜",
+         "counter": "柜台", "sand": "沙子", "sink": "水槽", "skyscraper": "摩天大楼", "fireplace": "壁炉",
+         "refrigerator": "冰箱", "grandstand": "看台", "path": "小路", "stairs": "楼梯", "runway": "跑道",
+         "case": "箱子", "pool table": "台球桌", "pillow": "枕头", "screen door": "纱门", "stairway": "阶梯",
+         "river": "河流", "bridge": "桥", "bookcase": "书柜", "blind": "百叶窗", "coffee table": "咖啡桌",
+         "toilet": "马桶", "flower": "花", "book": "书", "hill": "山丘", "bench": "长凳",
+         "countertop": "台面", "stove": "炉子", "palm": "棕榈树", "kitchen island": "厨房中岛", "computer": "电脑",
+         "swivel chair": "旋转椅", "boat": "船", "bar": "吧台", "arcade machine": "街机", "hovel": "小屋",
+         "bus": "公交车", "towel": "毛巾", "light": "灯光", "truck": "卡车", "tower": "塔",
+         "chandelier": "吊灯", "awning": "遮阳篷", "streetlight": "路灯", "booth": "摊位", "television receiver": "电视机",
+         "airplane": "飞机", "dirt track": "泥路", "apparel": "服装", "pole": "杆子", "land": "陆地",
+         "bannister": "栏杆", "escalator": "自动扶梯", "ottoman": "脚凳", "bottle": "瓶子", "buffet": "自助餐",
+         "poster": "海报", "stage": "舞台", "van": "货车", "ship": "轮船", "fountain": "喷泉",
+         "conveyer belt": "传送带", "canopy": "天篷", "washer": "洗衣机", "plaything": "玩具", "swimming pool": "游泳池",
+         "stool": "凳子", "barrel": "桶", "basket": "篮子", "waterfall": "瀑布", "tent": "帐篷",
+         "bag": "包", "minibike": "小型摩托车", "cradle": "摇篮", "oven": "烤箱", "ball": "球",
+         "food": "食物", "step": "台阶", "tank": "水箱", "trade name": "商标", "microwave": "微波炉",
+         "pot": "锅", "animal": "动物", "bicycle": "自行车", "lake": "湖", "dishwasher": "洗碗机",
+         "screen": "屏幕", "blanket": "毯子", "sculpture": "雕塑", "hood": "引擎盖", "sconce": "壁灯",
+         "vase": "花瓶", "traffic light": "交通灯", "tray": "托盘", "ashcan": "垃圾桶", "fan": "风扇",
+         "pier": "码头", "crt screen": "显示器", "plate": "盘子", "monitor": "显示器", "bulletin board": "公告板",
+         "shower": "淋浴", "radiator": "暖气片", "glass": "玻璃", "clock": "时钟", "flag": "旗帜"
+     }
+
+     for label in range(150):  # ADE20K has 150 classes
+         mask = np.ones((seg_map.shape[0], seg_map.shape[1]), dtype=np.uint8)
+         indices = (seg_map == label)
+         mask[indices] = 0  # target region is 0, background is 1
+         if indices.sum() > 0:  # the class is present in the image
+             masks.append(mask)
+             labels.append(label)
+
+             # English label
+             english_label = LABELS[str(label)]
+
+             # Chinese translation, falling back to the English label
+             chinese_label = chinese_labels.get(english_label, english_label)
+
+             # Display name with both the English label and its translation
+             label_names.append(f"{label}: {english_label} - {chinese_label}")
+
+     print(f"Created {len(masks)} masks")
+     for idx, label in enumerate(labels):
+         print(f"Index: {idx}\tClass ID: {label}\tLabel: {LABELS[str(label)]}")
+
+     return masks, labels, label_names
+
+ def segment_image(image):
+     """Run semantic segmentation on the image."""
+     global segmentation_result, processor, mask2former_model, mlsd_processor
+
+     if processor is None or mask2former_model is None or mlsd_processor is None:
+         return None, "Please load the models first!", []
+
+     # Resize the image
+     image_pil = Image.fromarray(image) if not isinstance(image, Image.Image) else image
+     image_pil = image_pil.resize((768, 512))
+
+     # Run semantic segmentation (inference only, so no gradients are needed)
+     inputs = processor(images=[image_pil], return_tensors="pt")
+     with torch.no_grad():
+         outputs = mask2former_model(**inputs)
+     predicted_semantic_map = processor.post_process_semantic_segmentation(
+         outputs, target_sizes=[image_pil.size[::-1]]
+     )[0]
+
+     # Build the segmentation masks
+     masks, labels, label_names = get_mask_from_segmentation_map(predicted_semantic_map)
+
+     # Keep the segmentation result for later steps
+     segmentation_result = {
+         "image": image_pil,
+         "masks": masks,
+         "labels": labels,
+         "label_names": label_names,
+         "semantic_map": predicted_semantic_map
+     }
+
+     # Generate the control image
+     control_image = mlsd_processor(image_pil)
+
+     print(f"Segmentation complete, found {len(label_names)} regions: {label_names}")
+
+     return control_image, f"Segmentation complete, found {len(label_names)} adjustable regions", label_names
+
+ def adjust_global_style(prompt, negative_prompt, room_type, style_theme, num_steps, guidance_scale, num_images=4):
+     """Global style adjustment."""
+     global segmentation_result, global_pipe, mlsd_processor
+
+     if segmentation_result is None:
+         return [None] * num_images + ["Please segment the image first!"]
+
+     if global_pipe is None or mlsd_processor is None:
+         return [None] * num_images + ["Please load the models first!"]
+
+     # Get the original image
+     image = segmentation_result["image"]
+
+     # Generate the control image
+     control_image = mlsd_processor(image)
+
+     # Keep the English part of each label (drop the Chinese description)
+     room_type = room_type.split(" - ")[0]
+     style_theme = style_theme.split(" - ")[0]
+
+     # Build the full prompt from the room type and style theme
+     full_prompt = f"A {style_theme} style {room_type}, {prompt}"
+
+     # Set up the generation parameters
+     prompts = [full_prompt] * num_images
+     negative_prompts = [negative_prompt] * num_images
+     generator = [torch.Generator(device="cuda").manual_seed(int(i)) for i in np.random.randint(1000, size=num_images)]
+
+     # Run image generation
+     output = global_pipe(
+         prompts,
+         image=control_image,  # use the control image directly
+         negative_prompt=negative_prompts,
+         num_inference_steps=num_steps,
+         generator=generator,
+         guidance_scale=guidance_scale
+     )
+
+     # Save the generated images to a temporary location
+     for i, img in enumerate(output.images):
+         img.save(os.path.join(OUTPUT_DIR, f"global_style_{i+1}.png"))
+
+     # Keep the generated images on the pipeline object so they can be saved later
+     global_pipe._last_images = output.images
+
+     # Return each image and the status text as separate values; unpacking
+     # avoids an IndexError when num_images is not 4
+     return (*output.images, "Global style adjustment complete!")
+
545
+ def adjust_local_style(prompt, negative_prompt, mask_label, room_type, style_theme, num_steps, guidance_scale, num_images=4):
546
+ """局部风格调整(Inpainting)"""
547
+ global segmentation_result, inpaint_pipe, mlsd_processor
548
+
549
+ if segmentation_result is None:
550
+ return [None] * num_images + ["请先进行图像分割!"]
551
+
552
+ if inpaint_pipe is None or mlsd_processor is None:
553
+ return [None] * num_images + ["请先加载模型!"]
554
+
555
+ # 获取原始图像和选定的掩码
556
+ image = segmentation_result["image"]
557
+ masks = segmentation_result["masks"]
558
+ labels = segmentation_result["labels"]
559
+ label_names = segmentation_result["label_names"]
560
+
561
+ # 找到选定标签对应的掩码索引
562
+ try:
563
+ if mask_label is None or mask_label == "":
564
+ return [None] * num_images + ["请选择要调整的区域"]
565
+
566
+ # 找到选中的标签在label_names中的索引
567
+ mask_id = label_names.index(mask_label)
568
+ except (ValueError, IndexError, AttributeError):
569
+ return [None] * num_images + ["无效的区域选择,请重新选择"]
570
+
571
+ # 生成控制图像
572
+ control_image = mlsd_processor(image)
573
+
574
+ # 将控制图像和原始图像混合,创建更自然的控制引导
575
+ control_tensor = transforms.ToTensor()(control_image)
576
+ image_tensor = transforms.ToTensor()(image)
577
+ mixed_control_tensor = control_tensor * 0.5 + image_tensor * 0.5
578
+ mixed_control_image = transforms.ToPILImage()(mixed_control_tensor)
579
+
580
+ # 处理掩码并创建用于修复的遮罩图像
581
+ mask = torch.Tensor(masks[mask_id])
582
+ object_mask = 1 - mask # 反转掩码,0变为1,1变为0
583
+ mask_image = transforms.ToPILImage()(object_mask.unsqueeze(0))
584
+
585
+ # 提取英文部分(去除中文描述)
586
+ room_type = room_type.split(" - ")[0]
587
+ style_theme = style_theme.split(" - ")[0]
588
+
589
+ # 构建完整提示词,结合房间类型和风格主题
590
+ full_prompt = f"A {style_theme} style {room_type}, {prompt}"
591
+
592
+ # 设置生成参数
593
+ prompts = [full_prompt] * num_images
594
+ negative_prompts = [negative_prompt] * num_images
595
+ generator = [torch.Generator(device="cuda").manual_seed(int(i)) for i in np.random.randint(1000, size=num_images)]
596
+
597
+ # 执行图像生成
598
+ output = inpaint_pipe(
599
+ prompts,
600
+ image=image,
601
+ mask_image=mask_image,
602
+ control_image=mixed_control_image,
603
+ negative_prompt=negative_prompts,
604
+ num_inference_steps=num_steps,
605
+ generator=generator,
606
+ controlnet_conditioning_scale=0.7,
607
+ guidance_scale=guidance_scale
608
+ )
609
+
610
+ # 保存生成的图像到临时位置
611
+ for i, img in enumerate(output.images):
612
+ img.save(os.path.join(OUTPUT_DIR, f"local_style_{i+1}.png"))
613
+
614
+ # 保存生成的图像到管道对象,以便后续保存
615
+ inpaint_pipe._last_images = output.images
616
+
617
+ # 返回单独的图像和状态文本,而不是列表+文本
618
+ return output.images[0], output.images[1], output.images[2], output.images[3], "局部风格调整完成!"
619
+
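Review note: `adjust_local_style` returns `[None] * num_images` on its early exits but indexes `output.images[0]` through `output.images[3]` on success, so the two paths agree only when `num_images` is 4. A minimal sketch of a helper that always yields one value per output slot follows; `pad_outputs` and its parameters are hypothetical names, not part of this commit.

```python
from typing import Any, List, Sequence


def pad_outputs(images: Sequence[Any], num_slots: int, status: str) -> List[Any]:
    """Return exactly num_slots image entries followed by the status text."""
    # Truncate if the pipeline produced more images than there are slots.
    slots: List[Any] = list(images[:num_slots])
    # Pad with None so that every output component still receives a value.
    slots.extend([None] * (num_slots - len(slots)))
    return slots + [status]
```

With a helper like this, both the error paths and the success path could return `pad_outputs(..., num_images, message)`, which keeps the number of returned values tied to one variable.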
620
+ # 显示选定区域的掩码
621
+ def display_selected_mask(mask_label):
622
+ """根据选择的区域标签显示对应的掩码图像"""
623
+ global segmentation_result
624
+
625
+ if segmentation_result is None:
626
+ return None, "请先进行图像分割!"
627
+
628
+ if mask_label is None or mask_label == "":
629
+ return None, "请选择要调整的区域"
630
+
631
+ try:
632
+ # 获取掩码和标签
633
+ masks = segmentation_result["masks"]
634
+ label_names = segmentation_result["label_names"]
635
+ image = segmentation_result["image"]
636
+
637
+ # 找到选中的标签在label_names中的索引
638
+ mask_id = label_names.index(mask_label)
639
+
640
+ # 获取对应的掩码
641
+ mask = masks[mask_id]
642
+
643
+ # 创建彩色掩码图像以便更好地可视化
644
+ # 创建RGB图像,将选中区域标记为红色
645
+ mask_rgb = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
646
+ mask_rgb[mask == 0] = [255, 0, 0] # 红色表示选中的区域
647
+
648
+ # 将原始图像和掩码混合,使掩码半透明
649
+ image_np = np.array(image)
650
+ image_np = cv2.resize(image_np, (mask.shape[1], mask.shape[0]))
651
+
652
+ # 创建混合图像
653
+ alpha = 0.5
654
+ mask_overlay = cv2.addWeighted(image_np, 1 - alpha, mask_rgb, alpha, 0)
655
+
656
+ # 将NumPy数组转换为PIL图像
657
+ mask_image = Image.fromarray(mask_overlay)
658
+
659
+ return mask_image, f"已选择区域: {mask_label}"
660
+ except (ValueError, IndexError, AttributeError) as e:
661
+ print(f"显示掩码时出错: {e}")
662
+ return None, f"无法显示所选区域: {str(e)}"
663
+
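Review note: `mask_rgb` is black everywhere outside the selected region, so `cv2.addWeighted(image_np, 1 - alpha, mask_rgb, alpha, 0)` also dims the unselected pixels (to half brightness at `alpha = 0.5`). The sketch below uses plain Python on nested lists instead of NumPy/OpenCV and tints only the selected pixels. `overlay_selected` is a hypothetical helper, and its rounding may differ slightly from OpenCV's.

```python
from typing import List, Tuple

Pixel = Tuple[int, int, int]


def overlay_selected(
    image: List[List[Pixel]],
    selected: List[List[bool]],
    color: Pixel = (255, 0, 0),
    alpha: float = 0.5,
) -> List[List[Pixel]]:
    """Blend `color` into `image` only where `selected` is True."""
    result = []
    for img_row, sel_row in zip(image, selected):
        out_row = []
        for pixel, is_selected in zip(img_row, sel_row):
            if is_selected:
                # Per-channel weighted sum, the usual alpha-blend formula.
                pixel = tuple(
                    round((1 - alpha) * p + alpha * c) for p, c in zip(pixel, color)
                )
            # Unselected pixels are copied through unchanged.
            out_row.append(pixel)
        result.append(out_row)
    return result
```

The same idea carries over to arrays: compute the blended image once, then copy it into the output only at the masked positions.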
664
+ # 保存设计方案
665
+ def save_global_style(image_indices, room_type, style_theme):
666
+ """保存全局风格调整的设计方案"""
667
+ global global_pipe
668
+
669
+ if global_pipe is None:
670
+ return "请先加载模型!"
671
+
672
+ if not hasattr(global_pipe, "_last_images") or not global_pipe._last_images:
673
+ return "没有可保存的图像!"
674
+
675
+ # 提取房间类型和风格主题的英文部分
676
+ room_type_en = room_type.split(" - ")[0]
677
+ style_theme_en = style_theme.split(" - ")[0]
678
+
679
+ # 获取当前保存目录中的文件数量,用于自增编号
680
+ existing_files = [f for f in os.listdir(GLOBAL_SAVE_DIR) if f.startswith(f"{room_type_en}_{style_theme_en}_")]
681
+ start_index = len(existing_files) + 1
682
+
683
+ saved_paths = []
684
+ for i, idx in enumerate(image_indices):
685
+ if 0 <= idx - 1 < len(global_pipe._last_images):
686
+ image = global_pipe._last_images[idx - 1]
687
+
688
+ # 构建简洁的文件名: room_style_number.png
689
+ filename = f"{room_type_en}_{style_theme_en}_{start_index + i}.png"
690
+ save_path = os.path.join(GLOBAL_SAVE_DIR, filename)
691
+
692
+ # 保存图像
693
+ image.save(save_path)
694
+ saved_paths.append(save_path)
695
+
696
+ # 将图像添加到索引
697
+ add_image_to_index(save_path, image)
698
+
699
+ if saved_paths:
700
+ return f"已保存 {len(saved_paths)} 张设计方案到 {GLOBAL_SAVE_DIR}"
701
+ else:
702
+ return "没有保存任何图像"
703
+
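Review note: numbering new files from `len(existing_files) + 1` can reuse an index after an earlier file has been deleted (with `_1` and `_3` on disk, the next save would target `_3` again). The sketch below assumes the `<prefix><n>.png` naming used above and continues from the highest index found; `next_free_index` is a hypothetical name.

```python
import os
import re


def next_free_index(directory: str, prefix: str, ext: str = ".png") -> int:
    """Return one more than the highest index used by `<prefix><n><ext>` files."""
    pattern = re.compile(re.escape(prefix) + r"(\d+)" + re.escape(ext) + r"$")
    highest = 0
    for name in os.listdir(directory):
        match = pattern.match(name)
        if match:
            highest = max(highest, int(match.group(1)))
    return highest + 1
```

A call such as `next_free_index(GLOBAL_SAVE_DIR, f"{room_type_en}_{style_theme_en}_")` could then stand in for the count-based `start_index`.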
704
+ def save_local_style(image_indices, room_type, style_theme, mask_label):
705
+ """保存局部风格调整的设计方案"""
706
+ global inpaint_pipe
707
+
708
+ if inpaint_pipe is None:
709
+ return "请先加载模型!"
710
+
711
+ if not hasattr(inpaint_pipe, "_last_images") or not inpaint_pipe._last_images:
712
+ return "没有可保存的图像!"
713
+
714
+ # 提取房间类型和风格主题的英文部分
715
+ room_type_en = room_type.split(" - ")[0]
716
+ style_theme_en = style_theme.split(" - ")[0]
717
+
718
+ # 提取区域标签
719
+ region_label = "unknown"
720
+ if mask_label:
721
+ try:
722
+ region_label = mask_label.split(":")[1].split("-")[0].strip()
723
+ except (IndexError, AttributeError):
724
+ pass
725
+
726
+ # 获取当前保存目录中的文件数量,用于自增编号
727
+ existing_files = [f for f in os.listdir(LOCAL_SAVE_DIR) if f.startswith(f"{room_type_en}_{style_theme_en}_{region_label}_")]
728
+ start_index = len(existing_files) + 1
729
+
730
+ saved_paths = []
731
+ for i, idx in enumerate(image_indices):
732
+ if 0 <= idx - 1 < len(inpaint_pipe._last_images):
733
+ image = inpaint_pipe._last_images[idx - 1]
734
+
735
+ # 构建简洁的文件名: room_style_region_number.png
736
+ filename = f"{room_type_en}_{style_theme_en}_{region_label}_{start_index + i}.png"
737
+ save_path = os.path.join(LOCAL_SAVE_DIR, filename)
738
+
739
+ # 保存图像
740
+ image.save(save_path)
741
+ saved_paths.append(save_path)
742
+
743
+ # 将图像添加到索引
744
+ add_image_to_index(save_path, image)
745
+
746
+ if saved_paths:
747
+ return f"已保存 {len(saved_paths)} 张设计方案到 {LOCAL_SAVE_DIR}"
748
+ else:
749
+ return "没有保存任何图像"
750
+
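Review note: the region label and the dropdown values end up in file names, and values such as `children's room` contain spaces and an apostrophe. The sketch below assumes labels shaped like `<id>: <name> - <note>` (inferred from the `split(":")` / `split("-")` chain above, not confirmed by the visible code). Both function names are hypothetical.

```python
import re


def region_from_label(mask_label: str, default: str = "unknown") -> str:
    """Extract the region name from a label shaped like '<id>: <name> - <note>'."""
    if not mask_label or ":" not in mask_label:
        return default
    name = mask_label.split(":", 1)[1].split("-", 1)[0].strip()
    return name or default


def safe_filename_part(text: str) -> str:
    """Replace runs of characters other than ASCII letters and digits with '_'."""
    return re.sub(r"[^A-Za-z0-9]+", "_", text).strip("_")
```

Checking the label shape up front makes the fallback explicit, so no exception handling is needed for the common "no selection" case.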
751
+ def perform_image_search(query_image, top_k=8):
752
+ """
753
+ 执行图像相似度搜索并返回结果
754
+
755
+ Args:
756
+ query_image: 查询图像
757
+ top_k: 返回的结果数量
758
+
759
+ Returns:
760
+ 相似图像列表、相似度分数列表和状态信息
761
+ """
762
+ # 执行相似度搜索
763
+ result_paths, result_metadata, status = search_similar_images(query_image, top_k)
764
+
765
+ if not result_paths:
766
+ return [], [], status
767
+
768
+ # 加载结果图像和相似度分数
769
+ result_images = []
770
+ similarity_scores = []
771
+
772
+ for i, path in enumerate(result_paths):
773
+ try:
774
+ img = Image.open(path)
775
+ result_images.append(img)
776
+
777
+ # 获取相似度分数(转换为百分比)
778
+ similarity = result_metadata[i].get("similarity", 0)
779
+ similarity_percentage = f"相似度: {similarity * 100:.1f}%"
780
+ similarity_scores.append(similarity_percentage)
781
+ except Exception as e:
782
+ print(f"加载图像 {path} 时出错: {e}")
783
+
784
+ # 确保只返回请求的数量
785
+ if len(result_images) > top_k:
786
+ result_images = result_images[:top_k]
787
+ similarity_scores = similarity_scores[:top_k]
788
+
789
+ return result_images, similarity_scores, status
790
+
791
+ # 创建Gradio界面
792
+ def create_interface():
793
+ with gr.Blocks(title="AI房间设计助手", css="""
794
+ #region-dropdown .wrap {
795
+ max-height: 300px;
796
+ overflow-y: auto;
797
+ z-index: 999;
798
+ position: relative;
799
+ }
800
+ #region-dropdown .wrap::-webkit-scrollbar {
801
+ width: 10px;
802
+ }
803
+ #region-dropdown .wrap::-webkit-scrollbar-track {
804
+ background: #f1f1f1;
805
+ }
806
+ #region-dropdown .wrap::-webkit-scrollbar-thumb {
807
+ background: #888;
808
+ }
809
+ #region-dropdown .wrap::-webkit-scrollbar-thumb:hover {
810
+ background: #555;
811
+ }
812
+ .similar-image {
813
+ border: 1px solid #ddd;
814
+ border-radius: 8px;
815
+ padding: 5px;
816
+ transition: transform 0.2s;
817
+ }
818
+ .similar-image:hover {
819
+ transform: scale(1.05);
820
+ box-shadow: 0 0 10px rgba(0,0,0,0.2);
821
+ }
822
+ /* 相似图像结果滚动窗口样式 */
823
+ .similar-results-container {
824
+ max-height: 600px;
825
+ overflow-y: auto;
826
+ padding: 10px;
827
+ border: 1px solid #eee;
828
+ border-radius: 8px;
829
+ background-color: #f9f9f9;
830
+ }
831
+ .similar-results-container::-webkit-scrollbar {
832
+ width: 10px;
833
+ }
834
+ .similar-results-container::-webkit-scrollbar-track {
835
+ background: #f1f1f1;
836
+ border-radius: 8px;
837
+ }
838
+ .similar-results-container::-webkit-scrollbar-thumb {
839
+ background: #888;
840
+ border-radius: 8px;
841
+ }
842
+ .similar-results-container::-webkit-scrollbar-thumb:hover {
843
+ background: #555;
844
+ }
845
+ .result-item {
846
+ margin-bottom: 15px;
847
+ }
848
+ """) as app:
849
+ gr.Markdown("# AI房间设计助手")
850
+ gr.Markdown("## 使用ControlNet和Stable Diffusion进行房间风格调整")
851
+
852
+ # 定义房间类型和风格主题选项
853
+ room_types = [
854
+ "living room - 客厅",
855
+ "bedroom - 卧室",
856
+ "kitchen - 厨房",
857
+ "bathroom - 浴室",
858
+ "dining room - 餐厅",
859
+ "office - 办公室",
860
+ "study room - 书房",
861
+ "children's room - 儿童房"
862
+ ]
863
+
864
+ style_themes = [
865
+ "modern - 现代",
866
+ "minimalist - 极简",
867
+ "Scandinavian - 北欧",
868
+ "industrial - 工业风",
869
+ "rustic - 乡村",
870
+ "traditional - 传统",
871
+ "contemporary - 当代",
872
+ "mid-century modern - 中世纪现代",
873
+ "bohemian - 波西米亚",
874
+ "coastal - 海岸风",
875
+ "farmhouse - 农舍",
876
+ "luxury - 奢华"
877
+ ]
878
+
879
+ # 定义提示词预设
880
+ prompt_presets = {
881
+ "简约舒适": "clean lines, comfortable seating, natural light, warm tones, simple decor",
882
+ "奢华典雅": "elegant furnishings, crystal chandelier, marble surfaces, plush seating, gold accents",
883
+ "自然原木": "wooden furniture, plants, natural materials, earth tones, organic textures",
884
+ "明亮通透": "large windows, white walls, light wood floors, minimal furniture, airy space",
885
+ "复古怀旧": "vintage furniture, retro color palette, antique accessories, classic patterns",
886
+ "工业风格": "exposed brick, metal fixtures, concrete floors, raw materials, minimal decor",
887
+ "温馨家庭": "comfortable seating, soft textiles, family photos, warm lighting, cozy atmosphere",
888
+ "艺术创意": "colorful accents, unique art pieces, creative lighting, bold patterns, artistic elements"
889
+ }
890
+
891
+ # 定义负面提示词预设
892
+ negative_prompt_presets = {
893
+ "标准负面提示词": "cluttered, dark, oversaturated, poor quality, blurry, unrealistic",
894
+ "避免过度装饰": "over decorated, cluttered, busy, chaotic, messy, disorganized",
895
+ "避免昏暗效果": "dark, gloomy, dim, shadowy, poorly lit, murky",
896
+ "避免不真实效果": "unrealistic, cartoon, anime, illustration, painting, drawing, 3d render",
897
+ "避免低质量": "poor quality, low resolution, blurry, noisy, distorted, deformed",
898
+ "避免人物": "people, person, human, face, hands, fingers",
899
+ "避免文字": "text, letters, words, signage, labels, logos",
900
+ "避免奇怪构图": "cropped, cut off, weird angle, distorted perspective, bad composition"
901
+ }
902
+
903
+ # 模型加载按钮
904
+ with gr.Row():
905
+ load_models_btn = gr.Button("加载模型")
906
+ model_status = gr.Textbox(label="模型状态", value="未加载")
907
+
908
+ # 创建选项卡界面
909
+ with gr.Tabs() as tabs:
910
+ # 全局风格调整选项卡
911
+ with gr.TabItem("全局风格调整"):
912
+ with gr.Row():
913
+ with gr.Column(scale=1):
914
+ # 输入区域
915
+ input_image = gr.Image(label="输入图像", type="pil")
916
+ segment_btn = gr.Button("分析图像结构")
917
+
918
+ # 参数设置
919
+ room_type = gr.Dropdown(label="房间类型", choices=room_types, value="living room - 客厅")
920
+ style_theme = gr.Dropdown(label="主题风格", choices=style_themes, value="modern - 现代")
921
+
922
+ # 提示词预设和输入
923
+ prompt_preset = gr.Dropdown(label="提示词预设", choices=list(prompt_presets.keys()), value="简约舒适")
924
+ prompt = gr.Textbox(label="提示词", value=prompt_presets["简约舒适"])
925
+
926
+ # 负面提示词预设和输入
927
+ negative_prompt_preset = gr.Dropdown(label="负面提示词预设", choices=list(negative_prompt_presets.keys()), value="标准负面提示词")
928
+ negative_prompt = gr.Textbox(label="负面提示词", value=negative_prompt_presets["标准负面提示词"])
929
+
930
+ num_steps = gr.Slider(label="推理步数", minimum=10, maximum=50, step=1, value=30)
931
+ guidance_scale = gr.Slider(label="引导比例", minimum=1.0, maximum=15.0, step=0.1, value=7.5)
932
+
933
+ # 生成按钮
934
+ generate_btn = gr.Button("生成设计方案")
935
+
936
+ with gr.Column(scale=1):
937
+ # 预览区域
938
+ control_image = gr.Image(label="结构控制图像")
939
+ status_text = gr.Textbox(label="状态信息")
940
+
941
+ # 结果展示区域
942
+ gr.Markdown("### 设计方案")
943
+ with gr.Row():
944
+ output_images = [gr.Image(label=f"方案 {i+1}") for i in range(2)]
945
+ with gr.Row():
946
+ output_images.extend([gr.Image(label=f"方案 {i+3}") for i in range(2)])
947
+
948
+ # 保存按钮区域
949
+ gr.Markdown("### 保存设计方案")
950
+ with gr.Row():
951
+ save_image_index = gr.CheckboxGroup(label="选择要保存的方案", choices=["方案 1", "方案 2", "方案 3", "方案 4"], value=[])
952
+ save_btn = gr.Button("保存选中的设计方案")
953
+ save_status = gr.Textbox(label="保存状态")
954
+
955
+ # 局部风格调整选项卡
956
+ with gr.TabItem("局部风格调整"):
957
+ with gr.Row():
958
+ with gr.Column(scale=1):
959
+ # 输入区域
960
+ input_image_local = gr.Image(label="输入图像", type="pil")
961
+ segment_btn_local = gr.Button("分析图像结构")
962
+
963
+ # 参数设置
964
+ region_choices = gr.Textbox(visible=False) # 隐藏的文本框用于存储区域选项
965
+ with gr.Row(elem_id="region-dropdown"):
966
+ mask_label_local = gr.Dropdown(label="选择调整区域", choices=[], interactive=True)
967
+ room_type_local = gr.Dropdown(label="房间类型", choices=room_types, value="living room - 客厅")
968
+ style_theme_local = gr.Dropdown(label="主题风格", choices=style_themes, value="modern - 现代")
969
+
970
+ # 提示词预设和输入
971
+ prompt_preset_local = gr.Dropdown(label="提示词预设", choices=list(prompt_presets.keys()), value="简约舒适")
972
+ prompt_local = gr.Textbox(label="提示词", value=prompt_presets["简约舒适"])
973
+
974
+ # 负面提示词预设和输入
975
+ negative_prompt_preset_local = gr.Dropdown(label="负面提示词预设", choices=list(negative_prompt_presets.keys()), value="标准负面提示词")
976
+ negative_prompt_local = gr.Textbox(label="负面提示词", value=negative_prompt_presets["标准负面提示词"])
977
+
978
+ num_steps_local = gr.Slider(label="推理步数", minimum=10, maximum=50, step=1, value=30)
979
+ guidance_scale_local = gr.Slider(label="引导比例", minimum=1.0, maximum=15.0, step=0.1, value=7.5)
980
+
981
+ # 生成按钮
982
+ generate_btn_local = gr.Button("生成设计方案")
983
+ update_regions_btn = gr.Button("更新区域列表", visible=False) # 隐藏的按钮用于触发更新
984
+
985
+ with gr.Column(scale=1):
986
+ # 预览区域
987
+ control_image_local = gr.Image(label="区域掩码图像")
988
+ status_text_local = gr.Textbox(label="状态信息")
989
+
990
+ # 结果展示区域
991
+ gr.Markdown("### 设计方案")
992
+ with gr.Row():
993
+ output_images_local = [gr.Image(label=f"方案 {i+1}") for i in range(2)]
994
+ with gr.Row():
995
+ output_images_local.extend([gr.Image(label=f"方案 {i+3}") for i in range(2)])
996
+
997
+ # 保存按钮区域
998
+ gr.Markdown("### 保存设计方案")
999
+ with gr.Row():
1000
+ save_image_index_local = gr.CheckboxGroup(label="选择要保存的方案", choices=["方案 1", "方案 2", "方案 3", "方案 4"], value=[])
1001
+ save_btn_local = gr.Button("保存选中的设计方案")
1002
+ save_status_local = gr.Textbox(label="保存状态")
1003
+
1004
+ # 图像相似性搜索选项卡
1005
+ with gr.TabItem("相似图像搜索"):
1006
+ with gr.Row():
1007
+ with gr.Column(scale=1):
1008
+ # 输入区域
1009
+ gr.Markdown("### 上传参考图像")
1010
+ reference_image = gr.Image(label="参考图像", type="pil")
1011
+
1012
+ # 搜索参数
1013
+ num_results = gr.Slider(label="搜索结果数量", minimum=2, maximum=8, step=2, value=4)
1014
+
1015
+ # 搜索按钮
1016
+ search_btn = gr.Button("搜索相似图像")
1017
+ search_status = gr.Textbox(label="搜索状态")
1018
+
1019
+ # 索引管理
1020
+ gr.Markdown("### 索引管理")
1021
+ rebuild_index_btn = gr.Button("重建图像索引")
1022
+ index_status = gr.Textbox(label="索引状态")
1023
+
1024
+ with gr.Column(scale=1):
1025
+ # 结果展示区域
1026
+ gr.Markdown("### 相似图像结果")
1027
+
1028
+ # 创建一个带滚动条的容器来动态显示结果
1029
+ with gr.Column(elem_classes="similar-results-container") as result_container:
1030
+ # 创建所有可能的结果行(最多8个结果,2x2布局)
1031
+ # 第一行(结果1-2)
1032
+ with gr.Row(visible=True, elem_classes="result-item") as row1:
1033
+ similar_images_row1 = [gr.Image(label=f"结果 {i+1}", elem_classes="similar-image") for i in range(2)]
1034
+ with gr.Row(visible=True, elem_classes="result-item") as score_row1:
1035
+ similarity_scores_row1 = [gr.Textbox(label="相似度", elem_classes="similarity-score") for _ in range(2)]
1036
+
1037
+ # 第二行(结果3-4)
1038
+ with gr.Row(visible=True, elem_classes="result-item") as row2:
1039
+ similar_images_row2 = [gr.Image(label=f"结果 {i+3}", elem_classes="similar-image") for i in range(2)]
1040
+ with gr.Row(visible=True, elem_classes="result-item") as score_row2:
1041
+ similarity_scores_row2 = [gr.Textbox(label="相似度", elem_classes="similarity-score") for _ in range(2)]
1042
+
1043
+ # 第三行(结果5-6)
1044
+ with gr.Row(visible=True, elem_classes="result-item") as row3:
1045
+ similar_images_row3 = [gr.Image(label=f"结果 {i+5}", elem_classes="similar-image") for i in range(2)]
1046
+ with gr.Row(visible=True, elem_classes="result-item") as score_row3:
1047
+ similarity_scores_row3 = [gr.Textbox(label="相似度", elem_classes="similarity-score") for _ in range(2)]
1048
+
1049
+ # 第四行(结果7-8)
1050
+ with gr.Row(visible=True, elem_classes="result-item") as row4:
1051
+ similar_images_row4 = [gr.Image(label=f"结果 {i+7}", elem_classes="similar-image") for i in range(2)]
1052
+ with gr.Row(visible=True, elem_classes="result-item") as score_row4:
1053
+ similarity_scores_row4 = [gr.Textbox(label="相似度", elem_classes="similarity-score") for _ in range(2)]
1054
+
1055
+ # 合并所有结果图像组件和相似度分数组件
1056
+ similar_images = similar_images_row1 + similar_images_row2 + similar_images_row3 + similar_images_row4
1057
+ similarity_scores = similarity_scores_row1 + similarity_scores_row2 + similarity_scores_row3 + similarity_scores_row4
1058
+
1059
+ # 保存所有行的引用,用于控制可见性
1060
+ image_rows = [row1, row2, row3, row4]
1061
+ score_rows = [score_row1, score_row2, score_row3, score_row4]
1062
+
1063
+ # 设置事件处理
1064
+ load_models_btn.click(load_models, inputs=[], outputs=[model_status])
1065
+
1066
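+         # Global style adjustment events
+         def process_segmentation_global(image):
+             control_img, status, label_choices = segment_image(image)
+             # region_choices is a Textbox, so the label list is joined into one string
+             # (the same "|||" separator that update_dropdown splits on).
+             return control_img, status, "|||".join(label_choices)
+
+         segment_btn.click(
+             process_segmentation_global,
+             inputs=[input_image],
+             outputs=[control_image, status_text, region_choices]
+         )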
1072
+
1073
+ # 提示词预设选择事件
1074
+ def update_prompt(preset_name):
1075
+ return prompt_presets.get(preset_name, "")
1076
+
1077
+ def update_negative_prompt(preset_name):
1078
+ return negative_prompt_presets.get(preset_name, "")
1079
+
1080
+ prompt_preset.change(
1081
+ update_prompt,
1082
+ inputs=[prompt_preset],
1083
+ outputs=[prompt]
1084
+ )
1085
+
1086
+ negative_prompt_preset.change(
1087
+ update_negative_prompt,
1088
+ inputs=[negative_prompt_preset],
1089
+ outputs=[negative_prompt]
1090
+ )
1091
+
1092
+ # 局部风格调整的提示词预设选择事件
1093
+ prompt_preset_local.change(
1094
+ update_prompt,
1095
+ inputs=[prompt_preset_local],
1096
+ outputs=[prompt_local]
1097
+ )
1098
+
1099
+ negative_prompt_preset_local.change(
1100
+ update_negative_prompt,
1101
+ inputs=[negative_prompt_preset_local],
1102
+ outputs=[negative_prompt_local]
1103
+ )
1104
+
1105
+ generate_btn.click(
1106
+ adjust_global_style,
1107
+ inputs=[prompt, negative_prompt, room_type, style_theme, num_steps, guidance_scale],
1108
+ outputs=output_images + [status_text]
1109
+ )
1110
+
1111
+ # 局部风格调整事件
1112
+ # 分割图像并存储区域列表
1113
+ def process_segmentation_local(image):
1114
+ control_img, status, label_choices = segment_image(image)
1115
+ # 将选项列表转换为字符串存储
1116
+ choices_str = "|||".join(label_choices)
1117
+ return control_img, status, choices_str
1118
+
1119
+ # 更新下拉菜单选项
1120
+ def update_dropdown(choices_str):
1121
+ if not choices_str:
1122
+ return gr.Dropdown(choices=[])
1123
+ choices = choices_str.split("|||")
1124
+ return gr.Dropdown(choices=choices)
1125
+
1126
+ segment_btn_local.click(
1127
+ process_segmentation_local,
1128
+ inputs=[input_image_local],
1129
+ outputs=[control_image_local, status_text_local, region_choices]
1130
+ )
1131
+
1132
+ # 使用region_choices更新下拉菜单
1133
+ region_choices.change(
1134
+ update_dropdown,
1135
+ inputs=[region_choices],
1136
+ outputs=[mask_label_local]
1137
+ )
1138
+
1139
+ # 当用户选择区域时,更新掩码图像
1140
+ mask_label_local.change(
1141
+ display_selected_mask,
1142
+ inputs=[mask_label_local],
1143
+ outputs=[control_image_local, status_text_local]
1144
+ )
1145
+
1146
+ generate_btn_local.click(
1147
+ adjust_local_style,
1148
+ inputs=[prompt_local, negative_prompt_local, mask_label_local, room_type_local, style_theme_local, num_steps_local, guidance_scale_local],
1149
+ outputs=output_images_local + [status_text_local]
1150
+ )
1151
+
1152
+ # 保存设计方案事件
1153
+ def process_save_global(image_indices, room_type, style_theme):
1154
+ # 从选择的方案中提取索引号
1155
+ indices = [int(idx.split(" ")[1]) for idx in image_indices]
1156
+ return save_global_style(indices, room_type, style_theme)
1157
+
1158
+ def process_save_local(image_indices, room_type, style_theme, mask_label):
1159
+ # 从选择的方案中提取索引号
1160
+ indices = [int(idx.split(" ")[1]) for idx in image_indices]
1161
+ return save_local_style(indices, room_type, style_theme, mask_label)
1162
+
1163
+ # 全局风格调整保存按钮事件
1164
+ save_btn.click(
1165
+ process_save_global,
1166
+ inputs=[save_image_index, room_type, style_theme],
1167
+ outputs=[save_status]
1168
+ )
1169
+
1170
+ # 局部风格调整保存按钮事件
1171
+ save_btn_local.click(
1172
+ process_save_local,
1173
+ inputs=[save_image_index_local, room_type_local, style_theme_local, mask_label_local],
1174
+ outputs=[save_status_local]
1175
+ )
1176
+
1177
+ # 图像相似性搜索事件
1178
+         def handle_image_search(query_image, num_results):
+             """Handle an image similarity search request."""
+             num_slots = len(similar_images)  # 8 image components, 2 per row
+
+             def row_updates(num_filled):
+                 # Row i holds results 2*i and 2*i + 1, so it is shown only if its first slot is filled.
+                 # Visibility changes take effect only when they are returned through `outputs`.
+                 return [gr.update(visible=(2 * i < num_filled)) for i in range(len(image_rows))]
+
+             if query_image is None:
+                 # No query image: clear every slot and keep only the first row visible.
+                 return [None] * num_slots + [""] * num_slots + row_updates(1) + row_updates(1) + ["请先上传参考图像"]
+
+             # Run the similarity search for the requested number of results.
+             result_images, result_scores, status = perform_image_search(query_image, int(num_results))
+
+             # Debug output
+             print(f"请求的结果数量: {num_results}")
+             print(f"实际返回的结果数量: {len(result_images)}")
+
+             # Start from empty slots, then fill in the actual results.
+             padded_results = [None] * num_slots
+             padded_scores = [""] * num_slots
+             for i in range(min(len(result_images), num_slots)):
+                 padded_results[i] = result_images[i]
+                 padded_scores[i] = result_scores[i]
+
+             num_filled = len(result_images)
+
+             # Return images, scores, row visibility updates and the status text, in the order of `outputs`.
+             return padded_results + padded_scores + row_updates(num_filled) + row_updates(num_filled) + [f"找到 {len(result_images)} 个相似图像"]
+
+         # Bind the search button
+         search_btn.click(
+             handle_image_search,
+             inputs=[reference_image, num_results],
+             outputs=similar_images + similarity_scores + image_rows + score_rows + [search_status]
+         )
1224
+
1225
+ # 重建索引事件
1226
+ def rebuild_image_index():
1227
+ """重建图像特征索引"""
1228
+ global faiss_index, image_metadata
1229
+
1230
+ # 创建新的索引
1231
+ create_new_index()
1232
+
1233
+ # 返回索引状态
1234
+ if faiss_index is not None:
1235
+ return f"索引重建完成,共索引了 {faiss_index.ntotal} 张图像"
1236
+ else:
1237
+ return "索引重建失败"
1238
+
1239
+ # 绑定重建索引按钮事件
1240
+ rebuild_index_btn.click(
1241
+ rebuild_image_index,
1242
+ inputs=[],
1243
+ outputs=[index_status]
1244
+ )
1245
+
1246
+ return app
1247
+
1248
+ # 启动应用
1249
+ if __name__ == "__main__":
1250
+ app = create_interface()
1251
+ app.launch(share=True)
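Review note: each result row in the search tab holds two images, so a row has something to show only when the index of its first slot is below the number of results. A minimal sketch of that rule follows; `row_visibility` and its defaults (2 per row, 4 rows, matching the layout above) are illustrative, not part of this commit.

```python
from typing import List


def row_visibility(num_results: int, per_row: int = 2, num_rows: int = 4) -> List[bool]:
    """Return, for each row, whether it holds at least one result."""
    return [row * per_row < num_results for row in range(num_rows)]
```

The resulting flags could be turned into `gr.update(visible=flag)` values and returned through the event's `outputs`, since Gradio applies property changes only when a handler returns them.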
download_resources.py ADDED
@@ -0,0 +1,112 @@
1
+ import os
2
+ import json
3
+ import requests
4
+ import torch
5
+ from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation, CLIPProcessor, CLIPModel
6
+ from controlnet_aux import MLSDdetector
7
+ from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, StableDiffusionControlNetInpaintPipeline
8
+ import urllib.request
9
+ import shutil
10
+
11
+ # 创建资源目录
12
+ def create_directories():
13
+ directories = [
14
+ "resources",
15
+ "resources/models",
16
+ "resources/images",
17
+ "resources/labels",
18
+ "resources/output"
19
+ ]
20
+ for directory in directories:
21
+ os.makedirs(directory, exist_ok=True)
22
+ print("目录结构创建完成")
23
+
24
+ # 下载ADE20K标签文件
25
+ def download_labels():
26
+ url = "https://huggingface.co/datasets/huggingface/label-files/raw/main/ade20k-id2label.json"
27
+ labels_path = "resources/labels/ade20k-id2label.json"
28
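+     response = requests.get(url, timeout=30)
+     response.raise_for_status()  # fail loudly instead of saving an error page as the label file
+     with open(labels_path, "w", encoding="utf-8") as f:
+         f.write(response.text)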
31
+ print(f"标签文件已保存到: {labels_path}")
32
+
33
+ # 下载示例图片
34
+ def download_sample_image():
35
+ raw_url = "https://raw.githubusercontent.com/naderAsadi/DesignGenie/main/examples/images/sample_input.png"
36
+ img_path = "resources/images/sample_input.png"
37
+ try:
38
+ urllib.request.urlretrieve(raw_url, img_path)
39
+ print(f"示例图片已保存到: {img_path}")
40
+ # 同时拷贝到根目录,保持原脚本兼容
41
+ shutil.copy(img_path, "sample_input.png")
42
+ except Exception as e:
43
+ print(f"图片下载失败: {e}")
44
+
45
+ # 下载模型文件
46
+ def download_models():
47
+ print("正在下载模型,这可能需要一些时间...")
48
+
49
+ # 1. 下载 Mask2Former 模型
50
+ print("下载 Mask2Former 模型...")
51
+ processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-large-ade-semantic", cache_dir="resources/models")
52
+ model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-large-ade-semantic", cache_dir="resources/models")
53
+ print("Mask2Former 模型下载完成")
54
+
55
+ # 2. 下载 MLSD 检测器
56
+ print("下载 MLSD 检测器...")
57
+ processor = MLSDdetector.from_pretrained("lllyasviel/Annotators", cache_dir="resources/models")
58
+ print("MLSD 检测器下载完成")
59
+
60
+ # 3. 下载 ControlNet 模型
61
+ print("下载 ControlNet 模型...")
62
+ controlnet = ControlNetModel.from_pretrained(
63
+ "lllyasviel/control_v11p_sd15_mlsd",
64
+ torch_dtype=torch.float16,
65
+ cache_dir="resources/models",
66
+ use_safetensors=False
67
+ )
68
+ print("ControlNet 模型下载完成")
69
+
70
+ # 4. 下载 Stable Diffusion 模型
71
+ print("下载 Stable Diffusion 模型...")
72
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(
73
+ "runwayml/stable-diffusion-v1-5",
74
+ controlnet=controlnet,
75
+ torch_dtype=torch.float16,
76
+ cache_dir="resources/models",
77
+ use_safetensors=False
78
+ )
79
+ print("Stable Diffusion 模型下载完成")
80
+
81
+ # 5. 下载 Stable Diffusion Inpainting 模型 (用于 inpaint.py)
82
+ print("下载 Stable Diffusion Inpainting 模型...")
83
+ pipe_inpaint = StableDiffusionControlNetInpaintPipeline.from_pretrained(
84
+ "runwayml/stable-diffusion-inpainting",
85
+ controlnet=controlnet,
86
+ torch_dtype=torch.float16,
87
+ cache_dir="resources/models",
88
+ use_safetensors=False
89
+ )
90
+ print("Stable Diffusion Inpainting 模型下载完成")
91
+
92
+ # 6. 下载图像特征提取模型 (用于相似性搜索)
93
+ print("下载图像特征提取模型...")
94
+ try:
95
+ clip_model = CLIPModel.from_pretrained(
96
+ "openai/clip-vit-base-patch32",
97
+ cache_dir="resources/models"
98
+ )
99
+ clip_processor = CLIPProcessor.from_pretrained(
100
+ "openai/clip-vit-base-patch32",
101
+ cache_dir="resources/models"
102
+ )
103
+ print("图像特征提取模型下载完成")
104
+ except Exception as e:
105
+ print(f"图像特征提取模型下载失败: {e}")
106
+
107
+ if __name__ == "__main__":
108
+ create_directories()
109
+ download_labels()
110
+ download_sample_image()
111
+ download_models()
112
+ print("所有资源下载完成!您可以将整个 'resources' 文件夹保存到本地使用。")
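Review note: a label file is only useful if it parses as JSON, and a half-written file is worse than none. The following stdlib-only sketch validates the body before writing and replaces the destination atomically. `fetch_json_to_file` is a hypothetical helper, and it uses `urllib` rather than the `requests` call in this script.

```python
import json
import os
import tempfile
import urllib.request
from typing import Any


def fetch_json_to_file(url: str, dest: str, timeout: float = 30.0) -> Any:
    """Download `url`, check that it parses as JSON, then write it to `dest`."""
    # urlopen raises an error for HTTP failures, so no explicit status check is needed.
    with urllib.request.urlopen(url, timeout=timeout) as response:
        text = response.read().decode("utf-8")
    data = json.loads(text)  # raises ValueError if the body is not valid JSON

    # Write to a temporary file first so a failed run never leaves a partial file.
    dest_dir = os.path.dirname(os.path.abspath(dest))
    fd, tmp_path = tempfile.mkstemp(dir=dest_dir, suffix=".tmp")
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        f.write(text)
    os.replace(tmp_path, dest)
    return data
```

Returning the parsed data lets the caller sanity-check it (for example, that the expected number of labels is present) before the models are downloaded.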
requirements.txt ADDED
Binary file (4.08 kB).