Spaces:
Runtime error
Runtime error
File size: 7,455 Bytes
c19ca42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
from threading import Lock
from fastapi.responses import JSONResponse
from modules import errors, shared, scripts, ui
from modules.api import models, script, helpers
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
errors.install()
class APIGenerate():
def __init__(self, queue_lock: Lock):
self.queue_lock = queue_lock
self.default_script_arg_txt2img = []
self.default_script_arg_img2img = []
def sanitize_args(self, args: dict):
args = vars(args)
args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model
args.pop('script_name', None)
args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them
args.pop('alwayson_scripts', None)
args.pop('face', None)
args.pop('face_id', None)
args.pop('ip_adapter', None)
args.pop('save_images', None)
return args
def sanitize_b64(self, request):
def sanitize_str(args: list):
for idx in range(0, len(args)):
if isinstance(args[idx], str) and len(args[idx]) >= 1000:
args[idx] = f"<str {len(args[idx])}>"
if hasattr(request, "alwayson_scripts") and request.alwayson_scripts:
for script_name in request.alwayson_scripts.keys():
script_obj = request.alwayson_scripts[script_name]
if script_obj and "args" in script_obj and script_obj["args"]:
sanitize_str(script_obj["args"])
if hasattr(request, "script_args") and request.script_args:
sanitize_str(request.script_args)
def prepare_face_module(self, request):
if hasattr(request, "face") and request.face and not request.script_name and (not request.alwayson_scripts or "face" not in request.alwayson_scripts.keys()):
request.script_name = "face"
request.script_args = [
request.face.mode,
request.face.source_images,
request.face.ip_model,
request.face.ip_override_sampler,
request.face.ip_cache_model,
request.face.ip_strength,
request.face.ip_structure,
request.face.id_strength,
request.face.id_conditioning,
request.face.id_cache,
request.face.pm_trigger,
request.face.pm_strength,
request.face.pm_start,
request.face.fs_cache
]
del request.face
def post_text2img(self, txt2imgreq: models.ReqTxt2Img):
self.prepare_face_module(txt2imgreq)
script_runner = scripts.scripts_txt2img
if not script_runner.scripts:
script_runner.initialize_scripts(False)
ui.create_ui(None)
if not self.default_script_arg_txt2img:
self.default_script_arg_txt2img = script.init_default_script_args(script_runner)
selectable_scripts, selectable_script_idx = script.get_selectable_script(txt2imgreq.script_name, script_runner)
populate = txt2imgreq.copy(update={ # Override __init__ params
"sampler_name": helpers.validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
"do_not_save_samples": not txt2imgreq.save_images,
"do_not_save_grid": not txt2imgreq.save_images,
})
if populate.sampler_name:
populate.sampler_index = None # prevent a warning later on
args = self.sanitize_args(populate)
send_images = args.pop('send_images', True)
with self.queue_lock:
p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
p.scripts = script_runner
p.outpath_grids = shared.opts.outdir_grids or shared.opts.outdir_txt2img_grids
p.outpath_samples = shared.opts.outdir_samples or shared.opts.outdir_txt2img_samples
shared.state.begin('api-txt2img', api=True)
script_args = script.init_script_args(p, txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner)
if selectable_scripts is not None:
processed = scripts.scripts_txt2img.run(p, *script_args) # Need to pass args as list here
else:
p.script_args = tuple(script_args) # Need to pass args as tuple here
processed = process_images(p)
shared.state.end(api=False)
b64images = list(map(helpers.encode_pil_to_base64, processed.images)) if send_images else []
self.sanitize_b64(txt2imgreq)
return models.ResTxt2Img(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
def post_img2img(self, img2imgreq: models.ReqImg2Img):
self.prepare_face_module(img2imgreq)
init_images = img2imgreq.init_images
if init_images is None:
return JSONResponse(status_code=400, content={"error": "Init image is none"})
mask = img2imgreq.mask
if mask:
mask = helpers.decode_base64_to_image(mask)
script_runner = scripts.scripts_img2img
if not script_runner.scripts:
script_runner.initialize_scripts(True)
ui.create_ui(None)
if not self.default_script_arg_img2img:
self.default_script_arg_img2img = script.init_default_script_args(script_runner)
selectable_scripts, selectable_script_idx = script.get_selectable_script(img2imgreq.script_name, script_runner)
populate = img2imgreq.copy(update={ # Override __init__ params
"sampler_name": helpers.validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index),
"do_not_save_samples": not img2imgreq.save_images,
"do_not_save_grid": not img2imgreq.save_images,
"mask": mask,
})
if populate.sampler_name:
populate.sampler_index = None # prevent a warning later on
args = self.sanitize_args(populate)
send_images = args.pop('send_images', True)
with self.queue_lock:
p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
p.init_images = [helpers.decode_base64_to_image(x) for x in init_images]
p.scripts = script_runner
p.outpath_grids = shared.opts.outdir_img2img_grids
p.outpath_samples = shared.opts.outdir_img2img_samples
shared.state.begin('api-img2img', api=True)
script_args = script.init_script_args(p, img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner)
if selectable_scripts is not None:
processed = scripts.scripts_img2img.run(p, *script_args) # Need to pass args as list here
else:
p.script_args = tuple(script_args) # Need to pass args as tuple here
processed = process_images(p)
shared.state.end(api=False)
b64images = list(map(helpers.encode_pil_to_base64, processed.images)) if send_images else []
if not img2imgreq.include_init_images:
img2imgreq.init_images = None
img2imgreq.mask = None
self.sanitize_b64(img2imgreq)
return models.ResImg2Img(images=b64images, parameters=vars(img2imgreq), info=processed.js())
|