yusuf commited on
Commit
d08d609
·
1 Parent(s): 5e53a08

aeayüz düzenleme 2

Browse files
Files changed (1) hide show
  1. app.py +136 -7
app.py CHANGED
@@ -9,9 +9,144 @@ from leffa_utils.densepose_predictor import DensePosePredictor
9
  from leffa_utils.utils import resize_and_center, list_dir, get_agnostic_mask_hd, get_agnostic_mask_dc
10
  from preprocess.humanparsing.run_parsing import Parsing
11
  from preprocess.openpose.run_openpose import OpenPose
12
-
13
  import gradio as gr
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  if __name__ == "__main__":
16
  leffa_predictor = LeffaPredictor()
17
  example_dir = "./ckpts/examples"
@@ -47,13 +182,10 @@ if __name__ == "__main__":
47
  """
48
 
49
  with gr.Blocks(theme=theme, title="Dehasoft AI Studio") as demo:
50
- # Başlık ve Açıklama
51
  gr.Markdown(title, elem_classes=["title"])
52
  gr.Markdown(description, elem_classes=["description"])
53
 
54
- # Sekmeler
55
  with gr.Tabs(elem_classes=["tabs"]):
56
- # Virtual Try-on Sekmesi
57
  with gr.TabItem("Virtual Try-On", elem_id="vt_tab"):
58
  with gr.Row(equal_height=True):
59
  with gr.Column(scale=1):
@@ -163,7 +295,6 @@ if __name__ == "__main__":
163
  _js="() => { document.querySelector('.generate-btn').classList.add('loading'); setTimeout(() => document.querySelector('.generate-btn').classList.remove('loading'), 5000); }"
164
  )
165
 
166
- # Pose Transfer Sekmesi
167
  with gr.TabItem("Pose Transfer", elem_id="pt_tab"):
168
  with gr.Row(equal_height=True):
169
  with gr.Column(scale=1):
@@ -256,10 +387,8 @@ if __name__ == "__main__":
256
  _js="() => { document.querySelector('.generate-btn').classList.add('loading'); setTimeout(() => document.querySelector('.generate-btn').classList.remove('loading'), 5000); }"
257
  )
258
 
259
- # Altbilgi
260
  gr.Markdown(footer_note, elem_classes=["footer"])
261
 
262
- # Özel CSS
263
  demo.css = """
264
  .title { text-align: center; font-size: 2.5em; margin-bottom: 10px; color: #4f46e5; }
265
  .description { text-align: center; font-size: 1.2em; margin-bottom: 20px; color: #374151; }
 
9
  from leffa_utils.utils import resize_and_center, list_dir, get_agnostic_mask_hd, get_agnostic_mask_dc
10
  from preprocess.humanparsing.run_parsing import Parsing
11
  from preprocess.openpose.run_openpose import OpenPose
 
12
  import gradio as gr
13
 
14
+ # Download checkpoints
15
+ snapshot_download(repo_id="franciszzj/Leffa", local_dir="./ckpts")
16
+
17
+ class LeffaPredictor(object):
18
+ def __init__(self):
19
+ self.mask_predictor = AutoMasker(
20
+ densepose_path="./ckpts/densepose",
21
+ schp_path="./ckpts/schp",
22
+ )
23
+
24
+ self.densepose_predictor = DensePosePredictor(
25
+ config_path="./ckpts/densepose/densepose_rcnn_R_50_FPN_s1x.yaml",
26
+ weights_path="./ckpts/densepose/model_final_162be9.pkl",
27
+ )
28
+
29
+ self.parsing = Parsing(
30
+ atr_path="./ckpts/humanparsing/parsing_atr.onnx",
31
+ lip_path="./ckpts/humanparsing/parsing_lip.onnx",
32
+ )
33
+
34
+ self.openpose = OpenPose(
35
+ body_model_path="./ckpts/openpose/body_pose_model.pth",
36
+ )
37
+
38
+ vt_model_hd = LeffaModel(
39
+ pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
40
+ pretrained_model="./ckpts/virtual_tryon.pth",
41
+ dtype="float16",
42
+ )
43
+ self.vt_inference_hd = LeffaInference(model=vt_model_hd)
44
+
45
+ vt_model_dc = LeffaModel(
46
+ pretrained_model_name_or_path="./ckpts/stable-diffusion-inpainting",
47
+ pretrained_model="./ckpts/virtual_tryon_dc.pth",
48
+ dtype="float16",
49
+ )
50
+ self.vt_inference_dc = LeffaInference(model=vt_model_dc)
51
+
52
+ pt_model = LeffaModel(
53
+ pretrained_model_name_or_path="./ckpts/stable-diffusion-xl-1.0-inpainting-0.1",
54
+ pretrained_model="./ckpts/pose_transfer.pth",
55
+ dtype="float16",
56
+ )
57
+ self.pt_inference = LeffaInference(model=pt_model)
58
+
59
+ def leffa_predict(
60
+ self,
61
+ src_image_path,
62
+ ref_image_path,
63
+ control_type,
64
+ ref_acceleration=False,
65
+ step=50,
66
+ scale=2.5,
67
+ seed=42,
68
+ vt_model_type="viton_hd",
69
+ vt_garment_type="upper_body",
70
+ vt_repaint=False
71
+ ):
72
+ assert control_type in [
73
+ "virtual_tryon", "pose_transfer"], "Invalid control type: {}".format(control_type)
74
+ src_image = Image.open(src_image_path)
75
+ ref_image = Image.open(ref_image_path)
76
+ src_image = resize_and_center(src_image, 768, 1024)
77
+ ref_image = resize_and_center(ref_image, 768, 1024)
78
+
79
+ src_image_array = np.array(src_image)
80
+
81
+ # Mask
82
+ if control_type == "virtual_tryon":
83
+ src_image = src_image.convert("RGB")
84
+ model_parse, _ = self.parsing(src_image.resize((384, 512)))
85
+ keypoints = self.openpose(src_image.resize((384, 512)))
86
+ if vt_model_type == "viton_hd":
87
+ mask = get_agnostic_mask_hd(
88
+ model_parse, keypoints, vt_garment_type)
89
+ elif vt_model_type == "dress_code":
90
+ mask = get_agnostic_mask_dc(
91
+ model_parse, keypoints, vt_garment_type)
92
+ mask = mask.resize((768, 1024))
93
+ elif control_type == "pose_transfer":
94
+ mask = Image.fromarray(np.ones_like(src_image_array) * 255)
95
+
96
+ # DensePose
97
+ if control_type == "virtual_tryon":
98
+ if vt_model_type == "viton_hd":
99
+ src_image_seg_array = self.densepose_predictor.predict_seg(
100
+ src_image_array)[:, :, ::-1]
101
+ src_image_seg = Image.fromarray(src_image_seg_array)
102
+ densepose = src_image_seg
103
+ elif vt_model_type == "dress_code":
104
+ src_image_iuv_array = self.densepose_predictor.predict_iuv(
105
+ src_image_array)
106
+ src_image_seg_array = src_image_iuv_array[:, :, 0:1]
107
+ src_image_seg_array = np.concatenate(
108
+ [src_image_seg_array] * 3, axis=-1)
109
+ src_image_seg = Image.fromarray(src_image_seg_array)
110
+ densepose = src_image_seg
111
+ elif control_type == "pose_transfer":
112
+ src_image_iuv_array = self.densepose_predictor.predict_iuv(
113
+ src_image_array)[:, :, ::-1]
114
+ src_image_iuv = Image.fromarray(src_image_iuv_array)
115
+ densepose = src_image_iuv
116
+
117
+ # Leffa
118
+ transform = LeffaTransform()
119
+
120
+ data = {
121
+ "src_image": [src_image],
122
+ "ref_image": [ref_image],
123
+ "mask": [mask],
124
+ "densepose": [densepose],
125
+ }
126
+ data = transform(data)
127
+ if control_type == "virtual_tryon":
128
+ if vt_model_type == "viton_hd":
129
+ inference = self.vt_inference_hd
130
+ elif vt_model_type == "dress_code":
131
+ inference = self.vt_inference_dc
132
+ elif control_type == "pose_transfer":
133
+ inference = self.pt_inference
134
+ output = inference(
135
+ data,
136
+ ref_acceleration=ref_acceleration,
137
+ num_inference_steps=step,
138
+ guidance_scale=scale,
139
+ seed=seed,
140
+ repaint=vt_repaint,)
141
+ gen_image = output["generated_image"][0]
142
+ return np.array(gen_image), np.array(mask), np.array(densepose)
143
+
144
+ def dehasoft(self, src_image_path, ref_image_path, ref_acceleration, step, scale, seed, vt_model_type, vt_garment_type, vt_repaint):
145
+ return self.leffa_predict(src_image_path, ref_image_path, "virtual_tryon", ref_acceleration, step, scale, seed, vt_model_type, vt_garment_type, vt_repaint)
146
+
147
+ def leffa_predict_pt(self, src_image_path, ref_image_path, ref_acceleration, step, scale, seed):
148
+ return self.leffa_predict(src_image_path, ref_image_path, "pose_transfer", ref_acceleration, step, scale, seed)
149
+
150
  if __name__ == "__main__":
151
  leffa_predictor = LeffaPredictor()
152
  example_dir = "./ckpts/examples"
 
182
  """
183
 
184
  with gr.Blocks(theme=theme, title="Dehasoft AI Studio") as demo:
 
185
  gr.Markdown(title, elem_classes=["title"])
186
  gr.Markdown(description, elem_classes=["description"])
187
 
 
188
  with gr.Tabs(elem_classes=["tabs"]):
 
189
  with gr.TabItem("Virtual Try-On", elem_id="vt_tab"):
190
  with gr.Row(equal_height=True):
191
  with gr.Column(scale=1):
 
295
  _js="() => { document.querySelector('.generate-btn').classList.add('loading'); setTimeout(() => document.querySelector('.generate-btn').classList.remove('loading'), 5000); }"
296
  )
297
 
 
298
  with gr.TabItem("Pose Transfer", elem_id="pt_tab"):
299
  with gr.Row(equal_height=True):
300
  with gr.Column(scale=1):
 
387
  _js="() => { document.querySelector('.generate-btn').classList.add('loading'); setTimeout(() => document.querySelector('.generate-btn').classList.remove('loading'), 5000); }"
388
  )
389
 
 
390
  gr.Markdown(footer_note, elem_classes=["footer"])
391
 
 
392
  demo.css = """
393
  .title { text-align: center; font-size: 2.5em; margin-bottom: 10px; color: #4f46e5; }
394
  .description { text-align: center; font-size: 1.2em; margin-bottom: 20px; color: #374151; }