Superigni commited on
Commit
0adeb3f
·
verified ·
1 Parent(s): c9cbd31

Add app.py with model download and Gradio interface

Browse files
Files changed (1) hide show
  1. app.py +201 -0
app.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import gradio as gr
3
+ import torch
4
+ import requests
5
+ from PIL import Image
6
+ from io import BytesIO
7
+ import numpy as np
8
+ import os
9
+ from tqdm import tqdm # Добавляем импорт tqdm
10
+
11
+ # Импорты из diffusers
12
+ from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler, StableDiffusionPipeline
13
+ # from diffusers.utils import load_image # Не нужен для этого кода
14
+ # from huggingface_hub import hf_hub_download # Не нужен для этого кода
15
+
16
+ # --- Вспомогательная функция для скачивания файлов (например, с Civitai) ---
17
+ # Эта функция будет скачивать модель SafeTensor внутри Space при первом запуске
18
+ def download_file(url, local_filename):
19
+ """Скачивает файл по URL с индикатором прогресса."""
20
+ print(f"Скачиваю {url} в {local_filename}...")
21
+ # Проверяем, существует ли файл, чтобы не скачивать его каждый раз
22
+ if os.path.exists(local_filename):
23
+ print(f"Файл уже существует: {local_filename}. Пропускаю скачивание.")
24
+ return local_filename
25
+
26
+ try:
27
+ response = requests.get(url, stream=True)
28
+ response.raise_for_status() # Проверка на ошибки HTTP
29
+
30
+ total_size_in_bytes = int(response.headers.get('content-length', 0))
31
+ block_size = 8192 # 8 Kibibytes
32
+
33
+ with tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True, desc=f"Скачивание {local_filename}") as progress_bar:
34
+ with open(local_filename, 'wb') as f:
35
+ for chunk in response.iter_content(chunk_size=block_size):
36
+ progress_bar.update(len(chunk))
37
+ f.write(chunk)
38
+
39
+ print(f"Скачивание завершено: {local_filename}")
40
+ return local_filename
41
+ except requests.exceptions.RequestException as e:
42
+ print(f"Ошибка скачивания {url}: {e}")
43
+ return None
44
+ except Exception as e:
45
+ print(f"Произошла другая ошибка при скачивании: {e}")
46
+ return None
47
+
48
+
49
+ # --- Определение путей/ID моделей ---
50
+ # URL вашей SafeTensor модели с Civitai
51
+ CIVITAI_SAFETENSOR_URL = "https://civitai.com/api/download/models/1413133?type=Model&format=SafeTensor&size=full&fp=fp8"
52
+ # Локальное имя файла для сохранения SafeTensor модели внутри Space
53
+ LOCAL_SAFETENSOR_FILENAME = "ultrareal_fine_tune_fp8_full.safetensors"
54
+
55
+ # ControlNet модель с Hugging Face
56
+ CONTROLNET_MODEL_ID = "ABDALLALSWAITI/FLUX.1-dev-ControlNet-Union-Pro-2.0-fp8"
57
+
58
+ # --- Скачиваем SafeTensor модель (выполнится при запуске скрипта в Space) ---
59
+ print("Начинаю скачивание базовой модели...")
60
+ downloaded_base_model_path = download_file(CIVITAI_SAFETENSOR_URL, LOCAL_SAFETENSOR_FILENAME)
61
+
62
+ if not downloaded_base_model_path or not os.path.exists(downloaded_base_model_path):
63
+ # Если скачивание не удалось или файл не существует после попытки
64
+ print(f"Критическая ошибка: Не удалось получить файл базовой модели по пути: {LOCAL_SAFETENSOR_FILENAME}")
65
+ print("Проверьте логи Space на наличие ошибок скачивания.")
66
+ # Возможно, здесь стоит выбросить исключение или как-то иначе остановить приложение
67
+ # Для примера, просто присвоим None и приложение не сможет загрузить пайплайн
68
+ pipeline = None
69
+ else:
70
+ # --- Загрузка моделей и создание пайплайна ---
71
+ def load_pipeline_components(base_model_path, controlnet_model_id):
72
+ """Загружает базовую модель из локального файла, ControlNet и собирает пайплайн."""
73
+ print(f"Загрузка ControlNet модели: {controlnet_model_id}")
74
+ # Загрузка ControlNet с Hugging Face Hub - кешируется автоматически Space
75
+ controlnet = ControlNetModel.from_pretrained(controlnet_model_id, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
76
+
77
+ print(f"Загрузка базовой модели из локального файла: {base_model_path}")
78
+ # Загружаем базовую модель из локального SafeTensor файла
79
+ # diffusers умеет загружать локальные файлы .safetensors
80
+ pipe = StableDiffusionPipeline.from_pretrained(
81
+ base_model_path, # Указываем путь к локальн��му файлу внутри Space
82
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
83
+ safety_checker=None # Отключение safety checker для скорости (используйте с осторожностью!)
84
+ )
85
+
86
+ # Теперь объединяем базовый пайплайн с ControlNet
87
+ # Создаем StableDiffusionControlNetPipeline на основе загруженного базового пайплайна
88
+ print("Создание пайплайна StableDiffusionControlNetPipeline...")
89
+ controlnet_pipe = StableDiffusionControlNetPipeline(
90
+ vae=pipe.vae,
91
+ text_encoder=pipe.text_encoder,
92
+ tokenizer=pipe.tokenizer,
93
+ unet=pipe.unet,
94
+ controlnet=controlnet, # Передаем загруженный ControlNet
95
+ scheduler=pipe.scheduler, # Используем планировщик из базового пайплайна
96
+ safety_checker=None,
97
+ feature_extractor=pipe.feature_extractor
98
+ )
99
+
100
+ # Рекомендуется использовать планировщик UniPC для ControlNet (или обновить существующий)
101
+ # Обновляем планировщик в новом ControlNet пайплайне
102
+ controlnet_pipe.scheduler = UniPCMultistepScheduler.from_config(controlnet_pipe.scheduler.config)
103
+
104
+ # Удаляем старый базовый пайплайн для освобождения памяти
105
+ del pipe
106
+ if torch.cuda.is_available():
107
+ torch.cuda.empty_cache()
108
+
109
+ # Перемещаем ControlNet пайплайн на GPU, если доступно
110
+ if torch.cuda.is_available():
111
+ controlnet_pipe = controlnet_pipe.to("cuda")
112
+ print("Пайплайн перемещен на GPU.")
113
+ else:
114
+ print("GPU не найдено. Пайплайн будет работать на CPU (будет медленно).") # В Space на CPU работать не будет эффективно
115
+
116
+ return controlnet_pipe
117
+
118
+ # Загружаем пайплайн при запуске скрипта, только если файл модели успешно скачан
119
+ pipeline = load_pipeline_components(downloaded_base_model_path, CONTROLNET_MODEL_ID)
120
+
121
+
122
+ # --- Функция рендеринга для Gradio ---
123
+ # Эта функция будет вызываться интерфейсом Gradio в Space
124
+ def generate_image_gradio(controlnet_image: np.ndarray, prompt: str, negative_prompt: str = "", guidance_scale: float = 7.5, num_inference_steps: int = 30, controlnet_conditioning_scale: float = 1.0):
125
+ """
126
+ Генерирует изображение с использованием Stable Diffusion ControlNet.
127
+ Принимает изображение NumPy, текст промта и другие параметры.
128
+ Возвращает сгенерированное изображение в формате PIL Image.
129
+ """
130
+ # Проверяем, успешно ли загрузился пайплайн
131
+ if pipeline is None:
132
+ return None, "Ошибка: Пайплайн модели не загружен. Проверьте логи Space."
133
+
134
+ if controlnet_image is None:
135
+ return None, "Ошибка: необходимо загрузить изображение для ControlNet."
136
+
137
+ print(f"Генерация изображения с промтом: '{prompt}'")
138
+ print(f"Размер входного изображения: {controlnet_image.shape}")
139
+
140
+ # Gradio возвращает изображение как numpy array. Преобразуем в PIL Image для пайплайна.
141
+ input_image_pil = Image.fromarray(controlnet_image).convert("RGB")
142
+
143
+ # Выполняем рендеринг с помощью пайплайна
144
+ try:
145
+ output = pipeline(
146
+ prompt=prompt,
147
+ image=input_image_pil, # Входное изображение для ControlNet
148
+ negative_prompt=negative_prompt,
149
+ guidance_scale=guidance_scale,
150
+ num_inference_steps=num_inference_steps,
151
+ controlnet_conditioning_scale=controlnet_conditioning_scale
152
+ # Здесь можно добавить generator=... (для сидов), width=..., height=..., etc.
153
+ )
154
+
155
+ # Результат находится в output.images[0]
156
+ generated_image_pil = output.images[0]
157
+
158
+ print("Генерация завершена.")
159
+ return generated_image_pil, "Успех!"
160
+ except Exception as e:
161
+ print(f"Ошибка при генерации: {e}")
162
+ return None, f"Ошибка при генерации: {e}"
163
+
164
+
165
+ # --- Настройка интерфейса Gradio ---
166
+ # Определяем входные и выходные элементы
167
+ input_image_comp = gr.Image(type="numpy", label="Изображение для ControlNet (набросок, карта глубины и т.д.)")
168
+ prompt_comp = gr.Textbox(label="Промт (Prompt)")
169
+ negative_prompt_comp = gr.Textbox(label="Негативный промт (Negative Prompt)")
170
+ guidance_scale_comp = gr.Slider(minimum=1.0, maximum=20.0, value=7.5, step=0.1, label="Степень соответствия промту (Guidance Scale)")
171
+ num_inference_steps_comp = gr.Slider(minimum=10, maximum=150, value=30, step=1, label="Количество шагов (Inference Steps)")
172
+ controlnet_conditioning_scale_comp = gr.Slider(minimum=0.0, maximum=2.0, value=1.0, step=0.05, label="Вес ControlNet (ControlNet Scale)")
173
+
174
+ output_image_comp = gr.Image(type="pil", label="Сгенерированное изображение")
175
+ status_text_comp = gr.Textbox(label="Статус")
176
+
177
+
178
+ # Создаем интерфейс Gradio
179
+ # Поскольку мы в Space, Gradio SDK сам вызовет interface.launch()
180
+ # Нам просто нужно определить интерфейс
181
+ interface = gr.Interface(
182
+ fn=generate_image_gradio,
183
+ inputs=[
184
+ input_image_comp,
185
+ prompt_comp,
186
+ negative_prompt_comp,
187
+ guidance_scale_comp,
188
+ num_inference_steps_comp,
189
+ controlnet_conditioning_scale_comp
190
+ ],
191
+ outputs=[output_image_comp, status_text_comp],
192
+ title="Stable Diffusion ControlNet Interface (SafeTensor Base Model)",
193
+ description="Загрузите изображение для ControlNet, введите промт и нажмите 'Generate'. Используется локальная SafeTensor модель и ControlNet с Hugging Face."
194
+ )
195
+
196
+ # Важно: Не вызывайте interface.launch() в блоке if __name__ == "__main__":
197
+ # Gradio SDK в Space сделает это автоматически.
198
+ # Если вы оставите if __name__ == "__main__": interface.launch(), оно тоже будет работать,
199
+ # но в среде Space это менее критично, чем при локальном запуске.
200
+ # Для ясности в Space можно убрать блок if __name__ == "__main__":
201
+ # Я оставил его в коде выше, но знайте, что SDK вызовет interface.launch() независимо от него.