from huggingface_hub import snapshot_download
from PIL import Image
import gradio as gr
import torch

# Fetch the Leffa checkpoints from the Hugging Face Hub into ./ckpts.
snapshot_download(repo_id="franciszzj/Leffa", local_dir="./ckpts")
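
# Note: if only the virtual try-on weights are needed, the download can be
# narrowed (a sketch; the pattern assumes the checkpoint sits at the
# repository root):
#
#   snapshot_download(
#       repo_id="franciszzj/Leffa",
#       local_dir="./ckpts",
#       allow_patterns=["virtual_tryon_dc.pth"],
#   )
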
class LeffaPredictor:

    def __init__(self):
        # Load the virtual try-on checkpoint that snapshot_download placed in
        # ./ckpts. map_location keeps the weights on the CPU.
        self.model = torch.load(
            "./ckpts/virtual_tryon_dc.pth", map_location=torch.device("cpu")
        )

    def leffa_predict_vt(self, human_img_path, garment_img_path, garment_type):
        # Placeholder inference: blend the person and garment images instead of
        # running the full Leffa try-on pipeline (garment_type is unused here).
        human_img = Image.open(human_img_path).convert("RGB")
        garment_img = Image.open(garment_img_path).convert("RGB")
        blended_img = Image.blend(human_img, garment_img.resize(human_img.size), alpha=0.5)
        return blended_img


leffa_predictor = LeffaPredictor()


def api_tryon(human_file, garment_file, garment_type):
    """Process the uploaded person and garment images with the Leffa predictor."""
    # Gradio passes PIL images (type="pil" below); save them to temporary files
    # because the predictor works with file paths.
    human_path = "human_temp.jpg"
    garment_path = "garment_temp.jpg"
    human_file.save(human_path)
    garment_file.save(garment_path)

    output_img = leffa_predictor.leffa_predict_vt(human_path, garment_path, garment_type)
    return output_img

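# The fixed temp-file names above are fine for a single-user demo but would
# collide under concurrent requests. A safer variant could write to unique
# files (a sketch using only the standard library):
#
#   import tempfile
#
#   with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
#       human_path = tmp.name
#   human_file.save(human_path)
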

# type="pil" hands PIL.Image objects to api_tryon, which saves them to disk itself.
interface = gr.Interface(
    fn=api_tryon,
    inputs=[
        gr.Image(type="pil", label="Person Image"),
        gr.Image(type="pil", label="Garment Image"),
        gr.Radio(choices=["upper", "bottom", "dresses"], label="Garment Category", value="upper"),
    ],
    outputs=gr.Image(type="pil", label="Try-On Result"),
    title="Virtual Try-On API",
)

if __name__ == "__main__":
    # Bind to all interfaces so the API is reachable from other machines/containers.
    interface.launch(server_name="0.0.0.0", server_port=7860)
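
# Example client call against the running server (a sketch, assuming the
# gradio_client package is installed; file names are placeholders, and
# "/predict" is the default endpoint name exposed by gr.Interface):
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://localhost:7860")
#   result = client.predict(
#       handle_file("person.jpg"),   # person photo
#       handle_file("garment.jpg"),  # garment photo
#       "upper",                     # garment category
#       api_name="/predict",
#   )
#   print(result)  # local path of the generated try-on image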