Lifeinhockey commited on
Commit
82259b8
·
verified ·
1 Parent(s): bdedce5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -3
app.py CHANGED
@@ -4,6 +4,7 @@ import torch
4
  from diffusers import StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline
5
  from peft import PeftModel, LoraConfig
6
  import os
 
7
 
8
  MAX_SEED = np.iinfo(np.int32).max
9
  MAX_IMAGE_SIZE = 1024
@@ -65,6 +66,18 @@ pipe_controlnet = StableDiffusionControlNetPipeline.from_pretrained(
65
  torch_dtype=torch_dtype
66
  ).to(device)
67
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  def infer(
69
  prompt,
70
  negative_prompt,
@@ -84,13 +97,16 @@ def infer(
84
  generator = torch.Generator(device).manual_seed(seed)
85
 
86
  if use_control_net and control_image is not None and source_image is not None:
 
 
 
 
87
  # Используем ControlNet
88
  image = pipe_controlnet(
89
  prompt=prompt,
90
  negative_prompt=negative_prompt,
91
- image=source_image, ####################
92
- control_image=control_image, ###############
93
- #image=control_image, # Используем загруженное изображение как карту позы
94
  width=width,
95
  height=height,
96
  num_inference_steps=num_inference_steps,
 
4
  from diffusers import StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline
5
  from peft import PeftModel, LoraConfig
6
  import os
7
+ from PIL import Image
8
 
9
  MAX_SEED = np.iinfo(np.int32).max
10
  MAX_IMAGE_SIZE = 1024
 
66
  torch_dtype=torch_dtype
67
  ).to(device)
68
 
69
+ def preprocess_image(image, target_width, target_height):
70
+ """
71
+ Преобразует изображение в формат, подходящий для модели.
72
+ """
73
+ if isinstance(image, np.ndarray):
74
+ image = Image.fromarray(image)
75
+ image = image.resize((target_width, target_height), Image.LANCZOS)
76
+ image = np.array(image).astype(np.float32) / 255.0 # Нормализация [0, 1]
77
+ image = image[None].transpose(0, 3, 1, 2) # Преобразуем в (batch, channels, height, width)
78
+ image = torch.from_numpy(image).to(device)
79
+ return image
80
+
81
  def infer(
82
  prompt,
83
  negative_prompt,
 
97
  generator = torch.Generator(device).manual_seed(seed)
98
 
99
  if use_control_net and control_image is not None and source_image is not None:
100
+ # Преобразуем изображения
101
+ source_image = preprocess_image(source_image, width, height)
102
+ control_image = preprocess_image(control_image, width, height)
103
+
104
  # Используем ControlNet
105
  image = pipe_controlnet(
106
  prompt=prompt,
107
  negative_prompt=negative_prompt,
108
+ image=source_image,
109
+ control_image=control_image,
 
110
  width=width,
111
  height=height,
112
  num_inference_steps=num_inference_steps,