Lifeinhockey commited on
Commit
a854895
·
verified ·
1 Parent(s): 6425629

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -13,14 +13,14 @@ model_default = "stable-diffusion-v1-5/stable-diffusion-v1-5"
13
  torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
14
 
15
  def get_lora_sd_pipeline(
16
- ckpt_dir='./lora_man_animestyle',
17
  base_model_name_or_path=None,
18
  dtype=torch.float16,
19
  adapter_name="default"
20
  ):
21
 
22
- unet_sub_dir = os.path.join(ckpt_dir, "unet")
23
- text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
24
 
25
  if os.path.exists(text_encoder_sub_dir) and base_model_name_or_path is None:
26
  config = LoraConfig.from_pretrained(text_encoder_sub_dir)
@@ -56,7 +56,7 @@ def align_embeddings(prompt_embeds, negative_prompt_embeds):
56
  return torch.nn.functional.pad(prompt_embeds, (0, 0, 0, max_length - prompt_embeds.shape[1])), \
57
  torch.nn.functional.pad(negative_prompt_embeds, (0, 0, 0, max_length - negative_prompt_embeds.shape[1]))
58
 
59
- pipe_default = get_lora_sd_pipeline(ckpt_dir='./lora_man_animestyle', base_model_name_or_path=model_default, dtype=torch_dtype).to(device)
60
 
61
  def infer(
62
  prompt,
@@ -67,7 +67,7 @@ def infer(
67
  model='stable-diffusion-v1-5/stable-diffusion-v1-5',
68
  seed=4,
69
  guidance_scale=7.5,
70
- lora_scale=0.05,
71
  progress=gr.Progress(track_tqdm=True)
72
  ):
73
 
 
13
  torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
14
 
15
  def get_lora_sd_pipeline(
16
+ lora_dir='./lora_man_animestyle',
17
  base_model_name_or_path=None,
18
  dtype=torch.float16,
19
  adapter_name="default"
20
  ):
21
 
22
+ unet_sub_dir = os.path.join(lora_dir, "unet")
23
+ text_encoder_sub_dir = os.path.join(lora_dir, "text_encoder")
24
 
25
  if os.path.exists(text_encoder_sub_dir) and base_model_name_or_path is None:
26
  config = LoraConfig.from_pretrained(text_encoder_sub_dir)
 
56
  return torch.nn.functional.pad(prompt_embeds, (0, 0, 0, max_length - prompt_embeds.shape[1])), \
57
  torch.nn.functional.pad(negative_prompt_embeds, (0, 0, 0, max_length - negative_prompt_embeds.shape[1]))
58
 
59
+ pipe_default = get_lora_sd_pipeline(lora_dir='./lora_man_animestyle', base_model_name_or_path=model_default, dtype=torch_dtype).to(device)
60
 
61
  def infer(
62
  prompt,
 
67
  model='stable-diffusion-v1-5/stable-diffusion-v1-5',
68
  seed=4,
69
  guidance_scale=7.5,
70
+ lora_scale=0.5,
71
  progress=gr.Progress(track_tqdm=True)
72
  ):
73