Spaces: Saad0KH
Saad0KH committed · commit f5c7dc7 · verified · 1 parent: e3da8fa

Update app.py

Files changed (1):
  1. app.py +34 -73
app.py CHANGED
@@ -32,107 +32,70 @@ from torchvision.transforms.functional import to_pil_image
 
 app = Flask(__name__)
 
+# Base path for the models
 base_path = 'yisol/IDM-VTON'
-example_path = os.path.join(os.path.dirname(__file__), 'example')
 
+# Load the models
 unet = UNet2DConditionModel.from_pretrained(
     base_path,
     subfolder="unet",
     torch_dtype=torch.float16,
     force_download=False
 )
-unet.requires_grad_(False)
 tokenizer_one = AutoTokenizer.from_pretrained(
     base_path,
     subfolder="tokenizer",
-    revision=None,
     use_fast=False,
     force_download=False
 )
 tokenizer_two = AutoTokenizer.from_pretrained(
     base_path,
     subfolder="tokenizer_2",
-    revision=None,
     use_fast=False,
     force_download=False
 )
 noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
-
-text_encoder_one = CLIPTextModel.from_pretrained(
-    base_path,
-    subfolder="text_encoder",
-    torch_dtype=torch.float16,
-    force_download=False
-)
-text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
-    base_path,
-    subfolder="text_encoder_2",
-    torch_dtype=torch.float16,
-    force_download=False
-)
-image_encoder = CLIPVisionModelWithProjection.from_pretrained(
-    base_path,
-    subfolder="image_encoder",
-    torch_dtype=torch.float16,
-    force_download=False
-)
-vae = AutoencoderKL.from_pretrained(base_path,
-    subfolder="vae",
-    torch_dtype=torch.float16,
-    force_download=False
-)
-
-UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
-    base_path,
-    subfolder="unet_encoder",
-    torch_dtype=torch.float16,
-    force_download=False
-)
+text_encoder_one = CLIPTextModel.from_pretrained(base_path, subfolder="text_encoder", torch_dtype=torch.float16)
+text_encoder_two = CLIPTextModelWithProjection.from_pretrained(base_path, subfolder="text_encoder_2", torch_dtype=torch.float16)
+image_encoder = CLIPVisionModelWithProjection.from_pretrained(base_path, subfolder="image_encoder", torch_dtype=torch.float16)
+vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae", torch_dtype=torch.float16)
+UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(base_path, subfolder="unet_encoder", torch_dtype=torch.float16)
 
 parsing_model = Parsing(0)
 openpose_model = OpenPose(0)
 
-UNet_Encoder.requires_grad_(False)
-image_encoder.requires_grad_(False)
-vae.requires_grad_(False)
-unet.requires_grad_(False)
-text_encoder_one.requires_grad_(False)
-text_encoder_two.requires_grad_(False)
-tensor_transfrom = transforms.Compose(
-    [
-        transforms.ToTensor(),
-        transforms.Normalize([0.5], [0.5]),
-    ]
-)
-
+# Set up the Tryon pipeline
 pipe = TryonPipeline.from_pretrained(
-    base_path,
-    unet=unet,
-    vae=vae,
-    feature_extractor= CLIPImageProcessor(),
-    text_encoder = text_encoder_one,
-    text_encoder_2 = text_encoder_two,
-    tokenizer = tokenizer_one,
-    tokenizer_2 = tokenizer_two,
-    scheduler = noise_scheduler,
-    image_encoder=image_encoder,
-    torch_dtype=torch.float16,
-    force_download=False
+    base_path,
+    unet=unet,
+    vae=vae,
+    feature_extractor=CLIPImageProcessor(),
+    text_encoder=text_encoder_one,
+    text_encoder_2=text_encoder_two,
+    tokenizer=tokenizer_one,
+    tokenizer_2=tokenizer_two,
+    scheduler=noise_scheduler,
+    image_encoder=image_encoder,
+    torch_dtype=torch.float16,
+    force_download=False
 )
 pipe.unet_encoder = UNet_Encoder
 
+# Image transformations
+tensor_transfrom = transforms.Compose([
+    transforms.ToTensor(),
+    transforms.Normalize([0.5], [0.5]),
+])
+
 def pil_to_binary_mask(pil_image, threshold=0):
     np_image = np.array(pil_image)
     grayscale_image = Image.fromarray(np_image).convert("L")
     binary_mask = np.array(grayscale_image) > threshold
     mask = np.zeros(binary_mask.shape, dtype=np.uint8)
-    for i in range(binary_mask.shape[0]):
-        for j in range(binary_mask.shape[1]):
-            if binary_mask[i, j]:
-                mask[i, j] = 1
-    mask = (mask * 255).astype(np.uint8)
-    output_mask = Image.fromarray(mask)
-    return output_mask
+    mask[binary_mask] = 1
+    return Image.fromarray((mask * 255).astype(np.uint8))
+
+
 
 def get_image_from_url(url):
     try:
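The most substantive change in this hunk is in pil_to_binary_mask: the per-pixel Python loop is replaced by a single vectorized boolean-index assignment. A minimal standalone sketch checking the equivalence (the toy array is illustrative, not from the app):

    import numpy as np

    binary_mask = np.array([[True, False], [False, True]])

    # Old approach: nested loops, one Python-level iteration per pixel
    loop_mask = np.zeros(binary_mask.shape, dtype=np.uint8)
    for i in range(binary_mask.shape[0]):
        for j in range(binary_mask.shape[1]):
            if binary_mask[i, j]:
                loop_mask[i, j] = 1

    # New approach: one boolean-index assignment, executed by numpy in C
    vec_mask = np.zeros(binary_mask.shape, dtype=np.uint8)
    vec_mask[binary_mask] = 1

    assert (loop_mask == vec_mask).all()

Note also that the commit drops the explicit requires_grad_(False) calls on the loaded modules. If TryonPipeline follows the usual diffusers pattern of wrapping __call__ in torch.no_grad(), no gradients are tracked at inference anyway, but the removed freezing lines can be reinstated if that assumption does not hold.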
@@ -157,8 +120,7 @@ def encode_image_to_base64(img):
     try:
         buffered = BytesIO()
         img.save(buffered, format="PNG")
-        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
-        return img_str
+        return base64.b64encode(buffered.getvalue()).decode("utf-8")
     except Exception as e:
         logging.error(f"Error encoding image: {e}")
         raise
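The encoding helper now returns the base64 string directly instead of binding it to a temporary. For reference, a round-trip sketch of the inverse operation (decode_base64_to_image is a hypothetical helper, not part of app.py; only base64, io, and Pillow are assumed):

    import base64
    from io import BytesIO
    from PIL import Image

    def decode_base64_to_image(b64_string):
        # Inverse of encode_image_to_base64: base64 text back to a PIL image
        return Image.open(BytesIO(base64.b64decode(b64_string)))

    img = Image.new("RGB", (4, 4), "red")
    buffered = BytesIO()
    img.save(buffered, format="PNG")
    b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
    assert decode_base64_to_image(b64).size == (4, 4)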
@@ -283,7 +245,6 @@ def tryon_v2():
     human_image_data = data['human_image']
     garment_image_data = data['garment_image']
 
-    # Process images (base64 or URL)
     human_image = process_image(human_image_data)
     garment_image = process_image(garment_image_data)
 
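The removed comment said process_image accepts base64 or a URL; the function itself is outside this diff. A plausible sketch of such a dispatcher, consistent with that comment (hypothetical, and relying on the get_image_from_url helper defined earlier in app.py):

    import base64
    from io import BytesIO
    from PIL import Image

    def process_image_sketch(image_data):
        # Hypothetical dispatcher: URL strings are fetched, anything else
        # is treated as a base64-encoded image payload
        if isinstance(image_data, str) and image_data.startswith(("http://", "https://")):
            return get_image_from_url(image_data)  # helper from app.py
        return Image.open(BytesIO(base64.b64decode(image_data)))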
@@ -294,18 +255,18 @@ def tryon_v2():
     seed = int(data.get('seed', random.randint(0, 9999999)))
     categorie = data.get('categorie', 'upper_body')
 
-    # Check whether 'mask_image' is present in the data
     mask_image = None
     if 'mask_image' in data:
        mask_image_data = data['mask_image']
        mask_image = process_image(mask_image_data)
-
+
     human_dict = {
         'background': human_image,
         'layers': [mask_image] if not use_auto_mask else None,
         'composite': None
     }
-    output_image, mask_image = start_tryon(human_dict, garment_image, description, use_auto_mask, use_auto_crop, denoise_steps, seed , categorie)
+
+    output_image, mask_image = start_tryon(human_dict, garment_image, description, use_auto_mask, use_auto_crop, denoise_steps, seed, categorie)
     return jsonify({
         'image_id': save_image(output_image)
     })
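For context, a hedged client-side sketch of exercising this handler with the fields it reads above (the route path and port are assumptions; the route decorator is outside this diff):

    import base64
    import requests

    def file_to_b64(path):
        # Read a local image and base64-encode it for the JSON payload
        with open(path, "rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")

    payload = {
        "human_image": file_to_b64("person.jpg"),   # base64 or URL, per process_image
        "garment_image": file_to_b64("shirt.jpg"),
        "categorie": "upper_body",                  # handler default
        "seed": 42,
        # "mask_image": file_to_b64("mask.png"),    # optional; used when auto-masking is off
    }
    resp = requests.post("http://localhost:7860/tryon_v2", json=payload)  # path/port assumed
    print(resp.json()["image_id"])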
 