Spaces:
Runtime error
Runtime error
File size: 6,573 Bytes
c19ca42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
from typing import Optional
from fastapi.exceptions import HTTPException
from modules import shared
from modules.api import models, helpers
def get_samplers():
    """List every sampler registered in ``sd_samplers.all_samplers``.

    Each registry entry is read positionally: element 0 is reported as
    ``name``, element 2 as ``aliases`` and element 3 as ``options``.
    """
    from modules import sd_samplers
    result = []
    for entry in sd_samplers.all_samplers:
        result.append({"name": entry[0], "aliases": entry[2], "options": entry[3]})
    return result
def get_sd_vaes():
    """List known VAEs as ``{"model_name", "filename"}`` dicts built from ``sd_vae.vae_dict``."""
    from modules.sd_vae import vae_dict
    return [{"model_name": vae_name, "filename": vae_file} for vae_name, vae_file in vae_dict.items()]
def get_upscalers():
    """Describe each entry of ``shared.sd_upscalers``.

    ``model_url`` is always reported as ``None``.
    """
    upscalers = []
    for entry in shared.sd_upscalers:
        upscalers.append({
            "name": entry.name,
            "model_name": entry.scaler.model_name,
            "model_path": entry.data_path,
            "model_url": None,
            "scale": entry.scale,
        })
    return upscalers
def get_sd_models():
    """Describe every checkpoint in ``sd_models.checkpoints_list``.

    The ``config`` field is whatever
    ``sd_models_config.find_checkpoint_config_near_filename`` returns for the
    checkpoint entry.
    """
    from modules import sd_models, sd_models_config

    def describe(info):
        # One response dict per checkpoint entry.
        return {
            "title": info.title,
            "model_name": info.name,
            "filename": info.filename,
            "type": info.type,
            "hash": info.shorthash,
            "sha256": info.sha256,
            "config": sd_models_config.find_checkpoint_config_near_filename(info),
        }

    return [describe(info) for info in sd_models.checkpoints_list.values()]
def get_hypernetworks():
    """List entries of ``shared.hypernetworks`` as ``{"name", "path"}`` dicts (key -> looked-up value)."""
    registry = shared.hypernetworks
    entries = []
    for hn_name in registry:
        entries.append({"name": hn_name, "path": registry[hn_name]})
    return entries
def get_face_restorers():
    """Describe each entry of ``shared.face_restorers``.

    ``name`` comes from calling the restorer's ``name()`` method; ``cmd_dir``
    is the restorer's ``cmd_dir`` attribute, or ``None`` when it has none.
    """
    restorers = []
    for restorer in shared.face_restorers:
        restorers.append({"name": restorer.name(), "cmd_dir": getattr(restorer, "cmd_dir", None)})
    return restorers
def get_prompt_styles():
    """Serialize every style in ``shared.prompt_styles.styles`` into a plain dict."""
    def serialize(style):
        return {
            'name': style.name,
            'prompt': style.prompt,
            'negative_prompt': style.negative_prompt,
            'extra': style.extra,
            'filename': style.filename,
            'preview': style.preview,
        }

    return [serialize(style) for style in shared.prompt_styles.styles.values()]
def get_embeddings():
    """Report embeddings held by ``sd_hijack.model_hijack.embedding_db``.

    Returns a dict with ``loaded`` (from ``word_embeddings``) and ``skipped``
    (from ``skipped_embeddings``). Each maps an embedding's ``name`` to a dict of
    its ``step``, ``sd_checkpoint``, ``sd_checkpoint_name``, ``shape`` and
    ``vectors`` attributes.
    """
    from modules import sd_hijack
    db = sd_hijack.model_hijack.embedding_db
    fields = ("step", "sd_checkpoint", "sd_checkpoint_name", "shape", "vectors")

    def summarize(collection):
        return {emb.name: {field: getattr(emb, field) for field in fields} for emb in collection.values()}

    loaded = summarize(db.word_embeddings)
    skipped = summarize(db.skipped_embeddings)
    return {"loaded": loaded, "skipped": skipped}
def get_extra_networks(page: Optional[str] = None, name: Optional[str] = None, filename: Optional[str] = None, title: Optional[str] = None, fullname: Optional[str] = None, hash: Optional[str] = None): # pylint: disable=redefined-builtin
    """List items from the pages in ``shared.extra_networks``, optionally filtered.

    Filters left as ``None`` are ignored. ``page`` is lower-cased before being
    compared with each page's ``name``. ``name``/``title``/``filename``/``fullname``
    must equal the item's field exactly (a missing field compares as ``''``).
    The item hash used for both filtering and output is ``shorthash`` when truthy,
    otherwise ``hash``.
    """
    exact_filters = (('name', name), ('title', title), ('filename', filename), ('fullname', fullname))
    matches = []
    for network_page in shared.extra_networks:
        if page is not None and network_page.name != page.lower():
            continue
        for entry in network_page.items:
            if any(wanted is not None and entry.get(key, '') != wanted for key, wanted in exact_filters):
                continue
            entry_hash = entry.get('shorthash', None) or entry.get('hash')
            if hash is not None and entry_hash != hash:
                continue
            matches.append({
                'name': entry.get('name', ''),
                'type': network_page.name,
                'title': entry.get('title', None),
                'fullname': entry.get('fullname', None),
                'filename': entry.get('filename', None),
                'hash': entry_hash,
                "preview": entry.get('preview', None),
            })
    return matches
def get_interrogate():
    """List interrogator names: ``clip`` and ``deepdanbooru`` followed by ``get_clip_models()``."""
    from modules.interrogate import get_clip_models
    builtin = ['clip', 'deepdanbooru']
    return builtin + get_clip_models()
def post_interrogate(req: models.ReqInterrogate):
    """Caption an image with the interrogator selected by ``req.model``.

    - ``clip``: runs ``shared.interrogator.interrogate``. On failure the exception
      text is returned as the caption.
    - ``deepdanbooru`` / ``deepbooru``: runs ``deepbooru.model.tag``.
    - anything else must be listed by ``get_clip_models()``. It is run through
      ``interrogate_image`` with ``req.mode``. When ``req.analyze`` is set, the
      ``analyze_image`` breakdown (medium/artist/movement/trending/flavor) is
      added to the response.

    Raises:
        HTTPException(404): the payload is missing, shorter than 64 characters,
            or does not decode to an image; or the model name is unknown.
    """
    # Payloads shorter than 64 characters are rejected as "not an image".
    if req.image is None or len(req.image) < 64:
        raise HTTPException(status_code=404, detail="Image not found")
    image = helpers.decode_base64_to_image(req.image)
    # post_pnginfo guards the same helper against a None result. Do the same here,
    # so a bad payload yields a 404 instead of an AttributeError on .convert().
    if image is None:
        raise HTTPException(status_code=404, detail="Image not found")
    image = image.convert('RGB')

    if req.model == "clip":
        try:
            caption = shared.interrogator.interrogate(image)
        except Exception as e:
            # Best-effort: the error text is reported back as the caption.
            caption = str(e)
        return models.ResInterrogate(caption=caption)

    if req.model in ("deepdanbooru", "deepbooru"):
        from modules import deepbooru
        caption = deepbooru.model.tag(image)
        return models.ResInterrogate(caption=caption)

    from modules.interrogate import interrogate_image, analyze_image, get_clip_models
    if req.model not in get_clip_models():
        raise HTTPException(status_code=404, detail="Model not found")
    try:
        caption = interrogate_image(image, model=req.model, mode=req.mode)
    except Exception as e:
        # Best-effort: the error text is reported back as the caption.
        caption = str(e)
    if not req.analyze:
        return models.ResInterrogate(caption=caption)
    medium, artist, movement, trending, flavor = analyze_image(image, model=req.model)
    return models.ResInterrogate(caption=caption, medium=medium, artist=artist, movement=movement, trending=trending, flavor=flavor)
def post_unload_checkpoint():
    """Unload the main model weights, then the refiner weights; respond with an empty dict."""
    from modules import sd_models as checkpoint_models
    for target in ('model', 'refiner'):
        checkpoint_models.unload_model_weights(op=target)
    return {}
def post_reload_checkpoint():
    """Reload model weights via ``sd_models.reload_model_weights``.

    Returns an empty dict as the response body.
    """
    from modules import sd_models as checkpoint_models
    checkpoint_models.reload_model_weights()
    response = {}
    return response
def post_refresh_checkpoints():
    """Call ``shared.refresh_checkpoints()`` and return its result unchanged."""
    result = shared.refresh_checkpoints()
    return result
def post_refresh_vae():
    """Call ``shared.refresh_vaes()`` and return its result unchanged."""
    result = shared.refresh_vaes()
    return result
def get_extensions_list():
    """Rescan extensions and describe those that have a git remote.

    ``extensions.list_extensions()`` runs first. ``read_info()`` is then called
    on every extension, including ones without a remote, before the filtering.
    """
    from modules import extensions
    extensions.list_extensions()
    described = []
    for extension in extensions.extensions:
        extension: extensions.Extension
        extension.read_info()
        if extension.remote is None:
            continue
        described.append({
            "name": extension.name,
            "remote": extension.remote,
            "branch": extension.branch,
            "commit_hash": extension.commit_hash,
            "commit_date": extension.commit_date,
            "version": extension.version,
            "enabled": extension.enabled,
        })
    return described
def post_pnginfo(req: models.ReqImageInfo):
    """Extract the generation infotext embedded in a base64 image.

    A blank payload, or one that decodes to ``None``, yields an empty ``info``.
    Otherwise the infotext from ``images.read_info_from_image`` (``""`` when
    absent) is parsed and handed to ``infotext_pasted_callback``. The response
    carries the infotext, the extra items and the parsed parameters.
    """
    from modules import images, script_callbacks, generation_parameters_copypaste
    payload = req.image.strip()
    if not payload:
        return models.ResImageInfo(info="")
    decoded = helpers.decode_base64_to_image(payload)
    if decoded is None:
        return models.ResImageInfo(info="")
    infotext, extra_items = images.read_info_from_image(decoded)
    infotext = "" if infotext is None else infotext
    parsed = generation_parameters_copypaste.parse_generation_parameters(infotext)
    # The callback receives the same dict object that is returned below.
    script_callbacks.infotext_pasted_callback(infotext, parsed)
    return models.ResImageInfo(info=infotext, items=extra_items, parameters=parsed)
|