# Model-Demo / app.py
# Mariam-Elz's picture
# Update app.py
# 200bf7b verified
# raw
# history blame
# 10.6 kB
# # Not ready to use yet
# import spaces
# import argparse
# import numpy as np
# import gradio as gr
# from omegaconf import OmegaConf
# import torch
# from PIL import Image
# import PIL
# from pipelines import TwoStagePipeline
# from huggingface_hub import hf_hub_download
# import os
# import rembg
# from typing import Any
# import json
# import os
# import json
# import argparse
# from model import CRM
# from inference import generate3d
# pipeline = None
# rembg_session = rembg.new_session()
# def expand_to_square(image, bg_color=(0, 0, 0, 0)):
# # expand image to 1:1
# width, height = image.size
# if width == height:
# return image
# new_size = (max(width, height), max(width, height))
# new_image = Image.new("RGBA", new_size, bg_color)
# paste_position = ((new_size[0] - width) // 2, (new_size[1] - height) // 2)
# new_image.paste(image, paste_position)
# return new_image
# def check_input_image(input_image):
# if input_image is None:
# raise gr.Error("No image uploaded!")
# def remove_background(
# image: PIL.Image.Image,
# rembg_session: Any = None,
# force: bool = False,
# **rembg_kwargs,
# ) -> PIL.Image.Image:
# do_remove = True
# if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
# # explain why current do not rm bg
# print("alhpa channl not enpty, skip remove background, using alpha channel as mask")
# background = Image.new("RGBA", image.size, (0, 0, 0, 0))
# image = Image.alpha_composite(background, image)
# do_remove = False
# do_remove = do_remove or force
# if do_remove:
# image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
# return image
# def do_resize_content(original_image: Image, scale_rate):
# # resize image content wile retain the original image size
# if scale_rate != 1:
# # Calculate the new size after rescaling
# new_size = tuple(int(dim * scale_rate) for dim in original_image.size)
# # Resize the image while maintaining the aspect ratio
# resized_image = original_image.resize(new_size)
# # Create a new image with the original size and black background
# padded_image = Image.new("RGBA", original_image.size, (0, 0, 0, 0))
# paste_position = ((original_image.width - resized_image.width) // 2, (original_image.height - resized_image.height) // 2)
# padded_image.paste(resized_image, paste_position)
# return padded_image
# else:
# return original_image
# def add_background(image, bg_color=(255, 255, 255)):
# # given an RGBA image, alpha channel is used as mask to add background color
# background = Image.new("RGBA", image.size, bg_color)
# return Image.alpha_composite(background, image)
# def preprocess_image(image, background_choice, foreground_ratio, backgroud_color):
# """
# input image is a pil image in RGBA, return RGB image
# """
# print(background_choice)
# if background_choice == "Alpha as mask":
# background = Image.new("RGBA", image.size, (0, 0, 0, 0))
# image = Image.alpha_composite(background, image)
# else:
# image = remove_background(image, rembg_session, force=True)
# image = do_resize_content(image, foreground_ratio)
# image = expand_to_square(image)
# image = add_background(image, backgroud_color)
# return image.convert("RGB")
# @spaces.GPU
# def gen_image(input_image, seed, scale, step):
# global pipeline, model, args
# pipeline.set_seed(seed)
# rt_dict = pipeline(input_image, scale=scale, step=step)
# stage1_images = rt_dict["stage1_images"]
# stage2_images = rt_dict["stage2_images"]
# np_imgs = np.concatenate(stage1_images, 1)
# np_xyzs = np.concatenate(stage2_images, 1)
# glb_path = generate3d(model, np_imgs, np_xyzs, args.device)
# return Image.fromarray(np_imgs), Image.fromarray(np_xyzs), glb_path#, obj_path
# parser = argparse.ArgumentParser()
# parser.add_argument(
# "--stage1_config",
# type=str,
# default="configs/nf7_v3_SNR_rd_size_stroke.yaml",
# help="config for stage1",
# )
# parser.add_argument(
# "--stage2_config",
# type=str,
# default="configs/stage2-v2-snr.yaml",
# help="config for stage2",
# )
# parser.add_argument("--device", type=str, default="cuda")
# args = parser.parse_args()
# crm_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="CRM.pth")
# specs = json.load(open("configs/specs_objaverse_total.json"))
# model = CRM(specs)
# model.load_state_dict(torch.load(crm_path, map_location="cpu"), strict=False)
# model = model.to(args.device)
# stage1_config = OmegaConf.load(args.stage1_config).config
# stage2_config = OmegaConf.load(args.stage2_config).config
# stage2_sampler_config = stage2_config.sampler
# stage1_sampler_config = stage1_config.sampler
# stage1_model_config = stage1_config.models
# stage2_model_config = stage2_config.models
# xyz_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="ccm-diffusion.pth")
# pixel_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="pixel-diffusion.pth")
# stage1_model_config.resume = pixel_path
# stage2_model_config.resume = xyz_path
# pipeline = TwoStagePipeline(
# stage1_model_config,
# stage2_model_config,
# stage1_sampler_config,
# stage2_sampler_config,
# device=args.device,
# dtype=torch.float32
# )
# _DESCRIPTION = '''
# * Our [official implementation](https://github.com/thu-ml/CRM) uses UV texture instead of vertex color. It has better texture than this online demo.
# * Project page of CRM: https://ml.cs.tsinghua.edu.cn/~zhengyi/CRM/
# * If you find the output unsatisfying, try using different seeds:)
# '''
# with gr.Blocks() as demo:
# gr.Markdown("# CRM: Single Image to 3D Textured Mesh with Convolutional Reconstruction Model")
# gr.Markdown(_DESCRIPTION)
# with gr.Row():
# with gr.Column():
# with gr.Row():
# image_input = gr.Image(
# label="Image input",
# image_mode="RGBA",
# sources="upload",
# type="pil",
# )
# processed_image = gr.Image(label="Processed Image", interactive=False, type="pil", image_mode="RGB")
# with gr.Row():
# with gr.Column():
# with gr.Row():
# background_choice = gr.Radio([
# "Alpha as mask",
# "Auto Remove background"
# ], value="Auto Remove background",
# label="backgroud choice")
# # do_remove_background = gr.Checkbox(label=, value=True)
# # force_remove = gr.Checkbox(label=, value=False)
# back_groud_color = gr.ColorPicker(label="Background Color", value="#7F7F7F", interactive=False)
# foreground_ratio = gr.Slider(
# label="Foreground Ratio",
# minimum=0.5,
# maximum=1.0,
# value=1.0,
# step=0.05,
# )
# with gr.Column():
# seed = gr.Number(value=1234, label="seed", precision=0)
# guidance_scale = gr.Number(value=5.5, minimum=3, maximum=10, label="guidance_scale")
# step = gr.Number(value=30, minimum=30, maximum=100, label="sample steps", precision=0)
# text_button = gr.Button("Generate 3D shape")
# gr.Examples(
# examples=[os.path.join("examples", i) for i in os.listdir("examples")],
# inputs=[image_input],
# examples_per_page = 20,
# )
# with gr.Column():
# image_output = gr.Image(interactive=False, label="Output RGB image")
# xyz_ouput = gr.Image(interactive=False, label="Output CCM image")
# output_model = gr.Model3D(
# label="Output OBJ",
# interactive=False,
# )
# gr.Markdown("Note: Ensure that the input image is correctly pre-processed into a grey background, otherwise the results will be unpredictable.")
# inputs = [
# processed_image,
# seed,
# guidance_scale,
# step,
# ]
# outputs = [
# image_output,
# xyz_ouput,
# output_model,
# # output_obj,
# ]
# text_button.click(fn=check_input_image, inputs=[image_input]).success(
# fn=preprocess_image,
# inputs=[image_input, background_choice, foreground_ratio, back_groud_color],
# outputs=[processed_image],
# ).success(
# fn=gen_image,
# inputs=inputs,
# outputs=outputs,
# )
# demo.queue().launch()
import torch
import gradio as gr
import requests
import os
# Download model weights from the Hugging Face model repo (if not already present).
model_repo = "Mariam-Elz/CRM"  # Your Hugging Face model repo
# Maps each filename in the repo to the local filename stored under models/.
model_files = {
    "ccm-diffusion.pth": "ccm-diffusion.pth",
    "pixel-diffusion.pth": "pixel-diffusion.pth",
    "CRM.pth": "CRM.pth",
}
os.makedirs("models", exist_ok=True)
for filename, output_path in model_files.items():
    file_path = f"models/{output_path}"
    if os.path.exists(file_path):
        continue
    url = f"https://huggingface.co/{model_repo}/resolve/main/{filename}"
    print(f"Downloading {filename}...")
    # Write to a temporary name and rename only after a complete, successful
    # download. Otherwise a failed or interrupted transfer would leave a
    # truncated file (or an HTML error page) at file_path, and the existence
    # check above would skip it forever after. A stale .part file left by a
    # crash is harmless: the next attempt reopens it with "wb".
    tmp_path = file_path + ".part"
    # stream=True avoids holding a whole checkpoint in memory; the timeout
    # bounds each connect/read so a stalled connection cannot hang startup.
    with requests.get(url, stream=True, timeout=60) as response:
        # Fail loudly on 401/404/5xx instead of saving the error body as weights.
        response.raise_for_status()
        with open(tmp_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    os.replace(tmp_path, file_path)
# Load model (this part depends on how the model is defined).
device = "cuda" if torch.cuda.is_available() else "cpu"
def load_model(model_path="models/CRM.pth"):
    """Load a pickled model object from *model_path* and switch it to eval mode.

    Args:
        model_path: path of the checkpoint file; defaults to the CRM.pth
            downloaded at startup.

    Returns:
        The loaded model, mapped onto the module-level ``device``, after
        ``.eval()`` has been called on it.

    Raises:
        TypeError: if the file holds a dict (e.g. a ``state_dict``) rather
            than a model object. Such a dict cannot be used directly; the
            model class must be constructed first and ``load_state_dict``
            called on it.
    """
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. On PyTorch >= 2.6 torch.load defaults to weights_only=True,
    # which refuses to unpickle a full nn.Module object. Confirm which kind of
    # checkpoint this repo actually ships.
    model = torch.load(model_path, map_location=device)
    if isinstance(model, dict):
        # The upstream CRM demo loads CRM.pth via CRM(specs).load_state_dict(...),
        # i.e. the file is a state dict. Calling .eval() on it would raise an
        # opaque AttributeError, so explain the real problem instead.
        raise TypeError(
            f"{model_path} contains a dict (likely a state_dict), not a model "
            "object; construct the model class and call load_state_dict() on it"
        )
    model.eval()
    return model
model = load_model()
# Define inference function
def infer(image):
    """Run the loaded model on an uploaded image and return its output as a NumPy array.

    Args:
        image: the array Gradio supplies for ``gr.Image(type="numpy")``, or
            ``None`` when no image was provided.

    Returns:
        The model output, moved to CPU and converted with ``.numpy()``.

    Raises:
        gr.Error: if no image was uploaded, so the UI shows a readable
            message instead of a raw tensor-conversion traceback.
    """
    if image is None:
        raise gr.Error("No image uploaded!")
    with torch.no_grad():
        # NOTE(review): the raw array (presumably uint8 HxWxC from Gradio) is
        # fed to the model as-is, with no batch dimension, dtype conversion or
        # normalisation. Confirm this matches the model's expected input.
        image_tensor = torch.tensor(image).to(device)
        output = model(image_tensor)
    return output.cpu().numpy()
# Gradio front end: a single numpy image in, a single numpy image out, served by infer().
demo = gr.Interface(
    infer,
    gr.Image(type="numpy"),
    gr.Image(type="numpy"),
    description="Upload an image to get the reconstructed output.",
    title="Convolutional Reconstruction Model",
)
# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()