zaghamrasool's picture
Update app.py
af6493c verified
raw
history blame
7.35 kB
import os
import cv2
import gradio as gr
import numpy as np
import random
import base64
import requests
import json
import time
MAX_SEED = 999999  # inclusive upper bound for random.randint in tryon and for the Seed slider

# Bundled example images live in ./assets next to this script (not relative to the process CWD).
example_path = os.path.join(os.path.dirname(__file__), 'assets')


def _visible_entries(folder):
    """Return the names in *folder*, sorted, skipping hidden files such as .DS_Store or .gitkeep.

    os.listdir returns entries in arbitrary order; sorting keeps the example gallery stable
    across runs, and dropping dot-files keeps non-images out of gr.Examples.
    """
    return sorted(name for name in os.listdir(folder) if not name.startswith('.'))


garm_list = _visible_entries(os.path.join(example_path, "cloth"))
garm_list_path = [os.path.join(example_path, "cloth", garm) for garm in garm_list]
human_list = _visible_entries(os.path.join(example_path, "human"))
human_list_path = [os.path.join(example_path, "human", human) for human in human_list]

# API details
# NOTE(review): the default is the Hugging Face Space *page* URL; confirm it actually serves the
# task endpoints below. TRYON_API_BASE_URL lets a deployment point at another backend
# without editing code; when unset, behaviour is unchanged.
base_url = os.environ.get(
    "TRYON_API_BASE_URL", "https://huggingface.co/spaces/zaghamrasool/Z-Virtual-Try-On"
).rstrip("/")
upload_image_url = f"{base_url}/upload_image"  # not referenced in the visible code
create_save_task_url = f"{base_url}/create_save_task"
execute_task_url = f"{base_url}/execute_task"  # not referenced in the visible code
query_task_url = f"{base_url}/query_task"
def _encode_image(img):
    """Return *img* (an RGB numpy array from gr.Image(type="numpy")) as a base64 JPEG string."""
    # OpenCV encodes from BGR channel order, so convert first.
    ok, buf = cv2.imencode('.jpg', cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
    if not ok:
        raise gr.Error("Failed to encode image")
    return base64.b64encode(buf.tobytes()).decode('utf-8')


def _create_task(data):
    """POST *data* to the create-task endpoint and return the backend's ``taskId`` (may be None).

    Raises:
        gr.Error: on any HTTP, network or payload failure; the underlying cause is printed
            for the server logs and chained onto the raised error.
    """
    try:
        response = requests.post(create_save_task_url, data=json.dumps(data), timeout=50)
        if response.status_code != 200:
            raise RuntimeError(f"Failed to create task. Status Code: {response.status_code}")
        result = response.json().get('result', {})
        if result.get('status') != "success":
            raise RuntimeError("Failed to create task, no task ID received.")
        return result.get('taskId')
    except Exception as err:  # deliberate boundary: every failure maps to one user-facing error
        print(f"Post Exception Error: {err}")
        raise gr.Error("Too many users, please try again later") from err


def _decode_result(b64_payload):
    """Decode a base64-encoded image returned by the backend into an RGB numpy array.

    Raises:
        ValueError: if the payload is empty, not valid base64, or not a decodable image.
    """
    raw = np.frombuffer(base64.b64decode(b64_payload), np.uint8)
    if raw.size == 0:
        raise ValueError("empty image payload")
    # IMREAD_COLOR always yields an 8-bit 3-channel BGR image (grayscale is expanded, alpha is
    # dropped), so the colour conversion below cannot fail on an unexpected channel count.
    img = cv2.imdecode(raw, cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError("undecodable image payload")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


def _poll_task(task_id, max_retry=20, interval=5):
    """Poll the query endpoint for *task_id* until it finishes or *max_retry* attempts are used.

    Returns:
        (result_img, info, err_log): result_img is an RGB array on success, else None; info is
        the user-facing status ("Success" on success); err_log is a diagnostic for the logs.
    """
    url = f"{query_task_url}?taskId={task_id}"
    info = ""
    err_log = ""
    for attempt in range(max_retry):
        try:
            response = requests.get(url, timeout=20)
            if response.status_code != 200:
                message = "URL error, please contact the admin"
                return None, message, message
            result = response.json()['result']
            status = result['status']
            if status == "success":
                try:
                    return _decode_result(result['result']), "Success", ""
                except (ValueError, cv2.error) as err:
                    # Re-polling would return the same bad payload; stop instead of burning retries.
                    return None, "Error", f"Decode Error: {err}"
            if status == "error":
                return None, "Error", "Status is Error"
            # Any other status (presumably queued/running): fall through and poll again.
        except requests.exceptions.Timeout:  # covers both connect and read timeouts
            err_log = "Http Timeout"
            info = "Http Timeout, please try again later"
        except Exception as err:  # malformed/transient response: record it and keep polling
            err_log = f"Get Exception Error: {err}"
        if attempt < max_retry - 1:  # no point sleeping after the final attempt
            time.sleep(interval)
    if info == "":
        err_log = f"No image after {max_retry} retries"
        info = "Too many users, please try again later"
    return None, info, err_log


def tryon(person_img, garment_img, seed, randomize_seed):
    """Run one virtual try-on request against the remote backend.

    Args:
        person_img: person photo as an RGB numpy array (gr.Image(type="numpy")), or None.
        garment_img: garment photo as an RGB numpy array, or None.
        seed: seed forwarded to the backend.
        randomize_seed: if true, replace *seed* with random.randint(0, MAX_SEED).

    Returns:
        (result_img, seed_used, info): the RGB result image (None on failure), the seed that was
        sent (None if an input image is missing), and a short status message ("Success" on success).

    Raises:
        gr.Error: if the backend task could not be created.
    """
    post_start_time = time.time()
    if person_img is None or garment_img is None:
        gr.Warning("Empty image")
        return None, None, "Empty image"
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    data = {
        "clothImage": _encode_image(garment_img),
        "humanImage": _encode_image(person_img),
        "seed": seed,
    }
    task_id = _create_task(data)
    post_end_time = time.time()
    print(f"post time used: {post_end_time - post_start_time}")

    get_start_time = time.time()
    if task_id:
        # Initial wait before the first poll (presumably gives the backend time to start).
        time.sleep(5)
        result_img, info, err_log = _poll_task(task_id)
    else:
        # Nothing to poll, so skip the initial wait entirely.
        result_img = None
        err_log = "No task ID received from backend."
        info = "Failed to get task ID from backend"
    get_end_time = time.time()
    print(f"get time used: {get_end_time - get_start_time}")
    print(f"all time used: {get_end_time - get_start_time + post_end_time - post_start_time}")

    if info != "Success":
        print(f"Error Log: {err_log}")
        gr.Warning(info)
    return result_img, seed, info
def load_description(fp):
    """Read the UTF-8 text file at *fp* and return its entire contents as a string."""
    with open(fp, mode='r', encoding='utf-8') as handle:
        contents = handle.read()
    return contents
# Stylesheet passed to gr.Blocks: centres each layout column (matched by elem_id) and caps its
# width; #button targets the Run button.
css = "\n".join([
    "",
    "#col-left { margin: 0 auto; max-width: 430px; }",
    "#col-mid { margin: 0 auto; max-width: 430px; }",
    "#col-right { margin: 0 auto; max-width: 430px; }",
    "#col-showcase { margin: 0 auto; max-width: 1100px; }",
    "#button { color: blue; }",
    "",
])
# Build the try-on UI: person and garment inputs with example galleries, a result panel with
# seed controls, and a showcase of (person, garment, result) example triples.
with gr.Blocks(css=css) as Tryon:
    # Resolve assets relative to this script (like example_path), not the process CWD.
    gr.HTML(load_description(os.path.join(example_path, "title.md")))

    with gr.Row():
        with gr.Column(elem_id="col-left"):
            gr.HTML("<div style='text-align: center; font-size: 20px;'>Step 1. Upload a person image ⬇️</div>")
        with gr.Column(elem_id="col-mid"):
            gr.HTML("<div style='text-align: center; font-size: 20px;'>Step 2. Upload a garment image ⬇️</div>")
        with gr.Column(elem_id="col-right"):
            gr.HTML("<div style='text-align: center; font-size: 20px;'>Step 3. Press “Run” to get try-on results</div>")

    with gr.Row():
        with gr.Column(elem_id="col-left"):
            imgs = gr.Image(label="Person image", sources='upload', type="numpy")
            gr.Examples(inputs=imgs, examples_per_page=12, examples=human_list_path)
        with gr.Column(elem_id="col-mid"):
            garm_img = gr.Image(label="Garment image", sources='upload', type="numpy")
            gr.Examples(inputs=garm_img, examples_per_page=12, examples=garm_list_path)
        with gr.Column(elem_id="col-right"):
            image_out = gr.Image(label="Result", show_share_button=False)
            with gr.Row():
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                randomize_seed = gr.Checkbox(label="Random seed", value=True)
            with gr.Row():
                seed_used = gr.Number(label="Seed used")
                result_info = gr.Text(label="Response")
            test_button = gr.Button(value="Run", elem_id="button")

    # api_name=False keeps this event out of the Gradio API; concurrency_limit caps how many
    # tryon calls run at once.
    test_button.click(
        fn=tryon,
        inputs=[imgs, garm_img, seed, randomize_seed],
        outputs=[image_out, seed_used, result_info],
        api_name=False,
        concurrency_limit=45,
    )

    with gr.Column(elem_id="col-showcase"):
        gr.HTML("<div style='text-align: center; font-size: 20px;'>Virtual try-on examples in pairs of person and garment images</div>")
        showcase_dir = os.path.join(example_path, "examples")
        gr.Examples(
            examples=[
                [os.path.join(showcase_dir, "model2.png"), os.path.join(showcase_dir, "garment2.png"), os.path.join(showcase_dir, "result2.png")],
                [os.path.join(showcase_dir, "model3.png"), os.path.join(showcase_dir, "garment3.png"), os.path.join(showcase_dir, "result3.png")],
                [os.path.join(showcase_dir, "model1.png"), os.path.join(showcase_dir, "garment1.png"), os.path.join(showcase_dir, "result1.png")],
            ],
            inputs=[imgs, garm_img, image_out],
        )

print("Starting Gradio app...")
print("Please open the link in your browser to access the app.")
# launch() blocks the main thread when run as a script. Launch exactly once so the queue and
# API-visibility settings here always apply; a second bare launch() would drop them.
Tryon.queue(api_open=False).launch(show_api=False)