Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -13,14 +13,14 @@ model_default = "stable-diffusion-v1-5/stable-diffusion-v1-5"
|
|
13 |
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
14 |
|
15 |
def get_lora_sd_pipeline(
|
16 |
-
|
17 |
base_model_name_or_path=None,
|
18 |
dtype=torch.float16,
|
19 |
adapter_name="default"
|
20 |
):
|
21 |
|
22 |
-
unet_sub_dir = os.path.join(
|
23 |
-
text_encoder_sub_dir = os.path.join(
|
24 |
|
25 |
if os.path.exists(text_encoder_sub_dir) and base_model_name_or_path is None:
|
26 |
config = LoraConfig.from_pretrained(text_encoder_sub_dir)
|
@@ -56,7 +56,7 @@ def align_embeddings(prompt_embeds, negative_prompt_embeds):
|
|
56 |
return torch.nn.functional.pad(prompt_embeds, (0, 0, 0, max_length - prompt_embeds.shape[1])), \
|
57 |
torch.nn.functional.pad(negative_prompt_embeds, (0, 0, 0, max_length - negative_prompt_embeds.shape[1]))
|
58 |
|
59 |
-
pipe_default = get_lora_sd_pipeline(
|
60 |
|
61 |
def infer(
|
62 |
prompt,
|
@@ -67,7 +67,7 @@ def infer(
|
|
67 |
model='stable-diffusion-v1-5/stable-diffusion-v1-5',
|
68 |
seed=4,
|
69 |
guidance_scale=7.5,
|
70 |
-
lora_scale=0.
|
71 |
progress=gr.Progress(track_tqdm=True)
|
72 |
):
|
73 |
|
|
|
13 |
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
14 |
|
15 |
def get_lora_sd_pipeline(
|
16 |
+
lora_dir='./lora_man_animestyle',
|
17 |
base_model_name_or_path=None,
|
18 |
dtype=torch.float16,
|
19 |
adapter_name="default"
|
20 |
):
|
21 |
|
22 |
+
unet_sub_dir = os.path.join(lora_dir, "unet")
|
23 |
+
text_encoder_sub_dir = os.path.join(lora_dir, "text_encoder")
|
24 |
|
25 |
if os.path.exists(text_encoder_sub_dir) and base_model_name_or_path is None:
|
26 |
config = LoraConfig.from_pretrained(text_encoder_sub_dir)
|
|
|
56 |
return torch.nn.functional.pad(prompt_embeds, (0, 0, 0, max_length - prompt_embeds.shape[1])), \
|
57 |
torch.nn.functional.pad(negative_prompt_embeds, (0, 0, 0, max_length - negative_prompt_embeds.shape[1]))
|
58 |
|
59 |
+
pipe_default = get_lora_sd_pipeline(lora_dir='./lora_man_animestyle', base_model_name_or_path=model_default, dtype=torch_dtype).to(device)
|
60 |
|
61 |
def infer(
|
62 |
prompt,
|
|
|
67 |
model='stable-diffusion-v1-5/stable-diffusion-v1-5',
|
68 |
seed=4,
|
69 |
guidance_scale=7.5,
|
70 |
+
lora_scale=0.5,
|
71 |
progress=gr.Progress(track_tqdm=True)
|
72 |
):
|
73 |
|