masteroko's picture
Update app.py
2dddeae verified
raw
history blame
11.1 kB
import gradio as gr
import requests
import time
import json
import base64
import os
from io import BytesIO
import html
import re
from PIL import Image # Убедитесь, что PIL импортируется для работы с изображениями
class Prodia:
    """Minimal synchronous client for the Prodia v1 REST API.

    Args:
        api_key: Value sent in the ``X-Prodia-Key`` header on every request.
        base: API root URL; defaults to ``https://api.prodia.com/v1``.
        timeout: Per-request timeout in seconds handed to ``requests``
            (``None`` means wait indefinitely).
        retries: Number of attempts made for each HTTP request before giving up.
    """

    def __init__(self, api_key, base=None, timeout=60, retries=3):
        self.base = base or "https://api.prodia.com/v1"
        self.headers = {
            "X-Prodia-Key": api_key
        }
        self.timeout = timeout
        self.retries = retries

    def generate(self, params):
        """POST *params* to ``/sd/generate`` and return the decoded JSON response."""
        return self._post(f"{self.base}/sd/generate", params).json()

    def transform(self, params):
        """POST *params* to ``/sd/transform`` and return the decoded JSON response."""
        return self._post(f"{self.base}/sd/transform", params).json()

    def controlnet(self, params):
        """POST *params* to ``/sd/controlnet`` and return the decoded JSON response."""
        return self._post(f"{self.base}/sd/controlnet", params).json()

    def get_job(self, job_id):
        """GET ``/job/<job_id>`` and return the decoded JSON response."""
        return self._get(f"{self.base}/job/{job_id}").json()

    def wait(self, job, poll_interval=0.25, max_wait=None):
        """Poll a job until its status is ``'succeeded'`` or ``'failed'``.

        Args:
            job: Job dict with at least ``'status'`` and ``'job'`` (the job id)
                keys — presumably the response of generate/transform/controlnet.
            poll_interval: Seconds to sleep between polls.
            max_wait: Optional overall limit in seconds; ``None`` waits forever.

        Returns:
            The last job dict seen (``job`` itself if it was already finished).

        Raises:
            TimeoutError: If ``max_wait`` elapses before the job finishes.
        """
        job_result = job
        deadline = None if max_wait is None else time.monotonic() + max_wait
        while job_result['status'] not in ('succeeded', 'failed'):
            if deadline is not None and time.monotonic() >= deadline:
                raise TimeoutError(
                    f"Prodia job {job['job']} did not finish within {max_wait} s"
                )
            time.sleep(poll_interval)
            job_result = self.get_job(job['job'])
        return job_result

    def list_models(self):
        """GET ``/sd/models`` and return the decoded JSON response."""
        return self._get(f"{self.base}/sd/models").json()

    def list_samplers(self):
        """GET ``/sd/samplers`` and return the decoded JSON response."""
        return self._get(f"{self.base}/sd/samplers").json()

    def _request(self, method, url, **kwargs):
        """Send one HTTP request, retrying on any ``requests`` failure.

        Non-2xx responses count as failures via ``raise_for_status``.

        Raises:
            Exception: After every attempt has failed; the last ``requests``
                error is attached as ``__cause__`` so it is not lost.
        """
        attempts = max(1, self.retries)
        last_error = None
        for attempt in range(1, attempts + 1):
            try:
                response = requests.request(method, url, timeout=self.timeout, **kwargs)
                response.raise_for_status()
                return response
            except requests.exceptions.RequestException as e:
                last_error = e
                if attempt < attempts:
                    print(f"Request failed: {e}, retrying...")
                    time.sleep(1)
                else:
                    # No point sleeping after the final attempt.
                    print(f"Request failed: {e}")
        raise Exception(f"Failed after {attempts} attempts") from last_error

    # Retries on network failures (see _request).
    def _post(self, url, params):
        """POST *params* as a JSON body to *url* with the auth header."""
        headers = {**self.headers, "Content-Type": "application/json"}
        return self._request("POST", url, headers=headers, data=json.dumps(params))

    # Retries on network failures (see _request).
    def _get(self, url):
        """GET *url* with the auth header."""
        return self._request("GET", url, headers=self.headers)
def image_to_base64(image, format="PNG"):
    """Serialise *image* in *format* and return it as a base64 ASCII string.

    The image is written to an in-memory buffer through its PIL-style
    ``save(buffer, format=...)`` method; nothing touches the filesystem.
    """
    with BytesIO() as buf:
        image.save(buf, format=format)
        raw_bytes = buf.getvalue()
    return base64.b64encode(raw_bytes).decode('utf-8')
def remove_id_and_ext(text):
    """Turn a checkpoint name into a short display name.

    Removes a trailing bracketed id (e.g. ``" [2ec7b88b]"``), then a
    ``.safetensors`` or ``.ckpt`` extension if one is present.

    >>> remove_id_and_ext("childrensStories_v1ToonAnime.safetensors [2ec7b88b]")
    'childrensStories_v1ToonAnime'
    >>> remove_id_and_ext("model.ckpt [abc]")
    'model'
    """
    # Only the final bracket group is the id. A greedy ``\[.*\]$`` would also
    # swallow earlier "[...]" groups that belong to the model name itself.
    text = re.sub(r'\s*\[[^\[\]]*\]\s*$', '', text).rstrip()
    for ext in ('.safetensors', '.ckpt'):
        if text.endswith(ext):
            return text[:-len(ext)]
    return text
# Изменение: оптимизация функции get_data
def get_data(text):
    """Parse AUTOMATIC1111-style generation parameters out of *text*.

    Returns:
        dict with keys ``prompt``, ``negative_prompt``, ``steps``, ``seed``,
        ``sampler``, ``model``, ``cfg_scale``, ``size``, ``w`` and ``h``.
        Each value is the matched substring (``str``), or ``None`` when the
        field is absent. ``w`` and ``h`` are ints split out of ``size``
        (``"WxH"``), or ``None`` when there is no size.
    """
    patterns = {
        # The prompt is whatever sits on the first line.
        'prompt': r'(.*)',
        'negative_prompt': r'Negative prompt: (.*)',
        # No trailing comma is required, so these fields still match when
        # they are the last item in the text.
        'steps': r'Steps: (\d+)',
        'seed': r'Seed: (\d+)',
        'sampler': r'Sampler:\s*([^\s,]+(?:\s+[^\s,]+)*)',
        'model': r'Model:\s*([^\s,]+)',
        'cfg_scale': r'CFG scale:\s*([\d\.]+)',
        'size': r'Size:\s*([0-9]+x[0-9]+)'
    }
    results = {}
    for key, pattern in patterns.items():
        # Search once per field (previously each regex ran twice).
        match = re.search(pattern, text)
        results[key] = match.group(1) if match else None
    if results['size']:
        results['w'], results['h'] = map(int, results['size'].split("x"))
    else:
        results['w'], results['h'] = None, None
    return results
# Изменение: оптимизация функции send_to_txt2img
def send_to_txt2img(image):
    """Build a Gradio update dict that selects the txt2img tab and fills in
    parameters read from ``image.info['parameters']``.

    Any error (missing metadata, parse failure, ...) is printed, and whatever
    updates were built so far are returned. At minimum that is the tab switch.

    NOTE(review): apart from ``tabs``, the keys here are plain strings
    ('prompt', 'steps', ...), not Gradio components. Gradio expects component
    keys when an event handler returns a dict, so this likely needs mapping to
    the txt2img components before it is wired to an event. Confirm.
    """
    updates = {tabs: gr.update(selected="t2i")}
    try:
        parsed = get_data(image.info['parameters'])
        for key in ('prompt', 'negative_prompt', 'steps', 'seed', 'cfg_scale', 'w', 'h', 'sampler', 'model'):
            value = parsed[key]
            if value is None:
                # Leave fields absent from the metadata untouched.
                updates[key] = gr.update()
            else:
                updates[key] = gr.update(value=value)
    except Exception as err:
        print(err)
    return updates
# Shared API client; the key is read from the PRODIA_API_KEY environment variable.
prodia_client = Prodia(api_key=os.getenv("PRODIA_API_KEY"))
# Checkpoint names, fetched once at start-up (network call).
model_list = prodia_client.list_models()
# Short display name (id and extension stripped) -> full checkpoint name.
model_names = dict(
    (remove_id_and_ext(full_name), full_name) for full_name in model_list
)
def txt2img(prompt, negative_prompt, model, steps, sampler, cfg_scale, width, height, seed):
    """Run a Prodia text-to-image job and return the resulting image URL.

    Submits the job, then blocks in ``prodia_client.wait`` until it finishes.

    Raises:
        gr.Error: If the job ends with status ``'failed'``. Gradio shows the
            message to the user instead of an opaque ``KeyError``.
    """
    params = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "model": model,
        "steps": steps,
        "sampler": sampler,
        "cfg_scale": cfg_scale,
        "width": width,
        "height": height,
        "seed": seed
    }
    job = prodia_client.wait(prodia_client.generate(params))
    if job.get("status") == "failed":
        raise gr.Error(f"Prodia txt2img job {job.get('job')} failed")
    return job["imageUrl"]
def img2img(input_image, denoising, prompt, negative_prompt, model, steps, sampler, cfg_scale, width, height, seed):
    """Run a Prodia image-to-image job on *input_image* and return the result URL.

    *input_image* may be a PIL image or an array (``gr.Image``'s default
    ``type="numpy"``). Arrays are converted with ``Image.fromarray`` because
    ``image_to_base64`` needs a PIL-style ``.save``.

    Raises:
        gr.Error: If no input image was provided, or if the job ends with
            status ``'failed'``.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image first.")
    if not hasattr(input_image, "save"):
        input_image = Image.fromarray(input_image)
    params = {
        "imageData": image_to_base64(input_image),
        "denoising_strength": denoising,
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "model": model,
        "steps": steps,
        "sampler": sampler,
        "cfg_scale": cfg_scale,
        "width": width,
        "height": height,
        "seed": seed
    }
    job = prodia_client.wait(prodia_client.transform(params))
    if job.get("status") == "failed":
        raise gr.Error(f"Prodia img2img job {job.get('job')} failed")
    return job["imageUrl"]
css = """
#generate {
    height: 100%;
}
"""

# Sampler names are fetched once and shared by both tabs, avoiding a separate
# API round-trip per dropdown. The checkpoint dropdown reuses the
# module-level `model_list` fetched at start-up for the same reason.
sampler_list = prodia_client.list_samplers()

with gr.Blocks(css=css) as demo:
    with gr.Row():
        with gr.Column(scale=6):
            model = gr.Dropdown(interactive=True, value="childrensStories_v1ToonAnime.safetensors [2ec7b88b]", show_label=True, label="Stable Diffusion Checkpoint", choices=model_list)
        with gr.Column(scale=1):
            gr.Markdown(elem_id="powered-by-prodia", value="AUTOMATIC1111 Stable Diffusion Web UI переделано masteroko.<br>Powered by [Prodia](https://prodia.com).<br>For more features and faster generation times check out our [API Docs](https://docs.prodia.com/reference/getting-started-guide).")

    with gr.Tabs() as tabs:
        with gr.Tab("txt2img", id='t2i'):
            with gr.Row():
                with gr.Column(scale=6, min_width=600):
                    prompt = gr.Textbox("nsfw", placeholder="Prompt", show_label=False, lines=3)
                    negative_prompt = gr.Textbox(placeholder="Negative Prompt", show_label=False, lines=3, value="3d, cartoon, anime, (deformed eyes, nose, ears, nose), bad anatomy, ugly")
                with gr.Column():
                    text_button = gr.Button("Сгенерировать", variant='primary', elem_id="generate")

            with gr.Row():
                with gr.Column(scale=3):
                    with gr.Tab("Генерация"):
                        with gr.Row():
                            with gr.Column(scale=1):
                                sampler = gr.Dropdown(value="DPM++ 2M SDE Exponential", show_label=True, label="Sampling Method", choices=sampler_list)
                            with gr.Column(scale=1):
                                # Label: "number of steps".
                                steps = gr.Slider(label="количество обработок", minimum=1, maximum=100, value=20, step=1)
                        with gr.Row():
                            with gr.Column(scale=1):
                                # An explicit minimum keeps a 0-pixel size from
                                # being sent to the API (Slider's default minimum is 0).
                                width = gr.Slider(label="Ширина", minimum=64, maximum=1024, value=512, step=8)
                                height = gr.Slider(label="Высота", minimum=64, maximum=1024, value=512, step=8)
                            with gr.Column(scale=1):
                                batch_size = gr.Slider(label="Batch Size", maximum=1, value=1)
                                batch_count = gr.Slider(label="Batch Count", maximum=1, value=1)
                        cfg_scale = gr.Slider(label="CFG Scale(степень фантазии ии)", minimum=1, maximum=20, value=7, step=1)
                        # precision=0 makes Gradio deliver the seed as an int, not a float.
                        seed = gr.Number(label="Семя рандома", value=-1, precision=0)
                with gr.Column(scale=2):
                    image_output = gr.Image(value="https://images.prodia.xyz/8ede1a7c-c0ee-4ded-987d-6ffed35fc477.png")

            # Allow up to 1024 concurrent runs of this event.
            text_button.click(txt2img, inputs=[prompt, negative_prompt, model, steps, sampler, cfg_scale, width, height, seed], outputs=image_output, concurrency_limit=1024)

        with gr.Tab("img2img", id='i2i'):
            with gr.Row():
                with gr.Column(scale=6, min_width=600):
                    # type="pil" so img2img receives an object with .save().
                    image = gr.Image(show_label=False, type="pil")
                    prompt = gr.Textbox(placeholder="Prompt", show_label=False, lines=3)
                    negative_prompt = gr.Textbox(placeholder="Negative Prompt", show_label=False, lines=3)
                with gr.Column():
                    text_button = gr.Button("Трансформировать", variant='primary', elem_id="generate")

            with gr.Row():
                with gr.Column(scale=3):
                    with gr.Tab("Генерация"):
                        with gr.Row():
                            with gr.Column(scale=1):
                                sampler = gr.Dropdown(value="DPM++ 2M SDE Exponential", show_label=True, label="Sampling Method", choices=sampler_list)
                            with gr.Column(scale=1):
                                steps = gr.Slider(label="количество обработок", minimum=1, maximum=100, value=20, step=1)
                        with gr.Row():
                            with gr.Column(scale=1):
                                width = gr.Slider(label="Ширина", minimum=64, maximum=1024, value=512, step=8)
                                height = gr.Slider(label="Высота", minimum=64, maximum=1024, value=512, step=8)
                            with gr.Column(scale=1):
                                batch_size = gr.Slider(label="Batch Size", maximum=1, value=1)
                                batch_count = gr.Slider(label="Batch Count", maximum=1, value=1)
                        cfg_scale = gr.Slider(label="CFG Scale(степень фантазии ии)", minimum=1, maximum=20, value=7, step=1)
                        seed = gr.Number(label="Семя рандома", value=-1, precision=0)
                        # Denoising strength is a 0..1 fraction; Slider's default
                        # range would otherwise allow values up to 100.
                        denoising = gr.Slider(label="Denoising", minimum=0, maximum=1, value=0.5, step=0.01)
                with gr.Column(scale=2):
                    image_output = gr.Image(value="https://images.prodia.xyz/8ede1a7c-c0ee-4ded-987d-6ffed35fc477.png")

            # Allow up to 1024 concurrent runs of this event.
            text_button.click(img2img, inputs=[image, denoising, prompt, negative_prompt, model, steps, sampler, cfg_scale, width, height, seed], outputs=image_output, concurrency_limit=1024)

# Launch the Gradio interface.
demo.launch()