|
import streamlit as st |
|
from diffusers import DiffusionPipeline |
|
import torch |
|
from PIL import Image |
|
import io |
|
import gc |
|
|
|
|
|
st.title("Генератор изображений с LCM Dreamshaper") |
|
st.write("Используйте эту модель для быстрой генерации изображений на CPU") |
|
|
|
|
|
with st.sidebar: |
|
st.header("Настройки") |
|
prompt = st.text_area("Введите ваш запрос:", "hoaxx kitty", height=100) |
|
|
|
num_inference_steps = st.slider( |
|
"Количество шагов инференса:", |
|
min_value=1, |
|
max_value=50, |
|
value=5, |
|
help="Больше шагов = выше качество, но медленнее" |
|
) |
|
|
|
guidance_scale = st.slider( |
|
"Guidance Scale:", |
|
min_value=1.0, |
|
max_value=15.0, |
|
value=8.0, |
|
step=0.5, |
|
help="Насколько строго модель следует промпту" |
|
) |
|
|
|
lcm_origin_steps = st.slider( |
|
"LCM Origin Steps:", |
|
min_value=1, |
|
max_value=50, |
|
value=35 |
|
) |
|
|
|
generate_button = st.button("Сгенерировать изображение") |
|
|
|
|
|
@st.cache_resource |
|
def load_model(): |
|
pipe = DiffusionPipeline.from_pretrained( |
|
"SimianLuo/LCM_Dreamshaper_v7", |
|
torch_dtype=torch.float32 |
|
) |
|
pipe.to("cpu") |
|
pipe.enable_attention_slicing() |
|
pipe.safety_checker = None |
|
return pipe |
|
|
|
|
|
def generate_image(pipe, prompt, steps, guidance, lcm_steps): |
|
try: |
|
with torch.inference_mode(): |
|
images = pipe( |
|
prompt=prompt, |
|
num_inference_steps=steps, |
|
guidance_scale=guidance, |
|
lcm_origin_steps=lcm_steps, |
|
output_type="pil" |
|
).images |
|
return images[0] |
|
except Exception as e: |
|
st.error(f"Error generating image: {e}") |
|
return None |
|
|
|
|
|
pipe = load_model() |
|
|
|
|
|
if generate_button: |
|
with st.spinner("Генерация изображения..."): |
|
|
|
result_container = st.empty() |
|
|
|
|
|
image = generate_image( |
|
pipe, |
|
prompt, |
|
num_inference_steps, |
|
guidance_scale, |
|
lcm_origin_steps |
|
) |
|
|
|
|
|
if image: |
|
result_container.image(image, caption=f"Результат для: {prompt}", use_container_width=True) |
|
|
|
|
|
buf = io.BytesIO() |
|
image.save(buf, format="PNG") |
|
byte_im = buf.getvalue() |
|
|
|
st.download_button( |
|
label="Скачать изображение", |
|
data=byte_im, |
|
file_name="generated_image.png", |
|
mime="image/png" |
|
) |
|
|
|
gc.collect() |
|
|
|
|
|
if not generate_button: |
|
st.write("👈 Настройте параметры в боковой панели и нажмите 'Сгенерировать изображение'") |