File size: 5,176 Bytes
4e31b1a
aaa6458
829dfd4
aaa6458
 
 
b33bab2
81914fc
aaa6458
 
 
1087492
b33bab2
 
 
 
 
 
 
 
 
 
 
06be9c8
b33bab2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
829dfd4
 
06be9c8
829dfd4
 
388cf5c
b33bab2
aaa6458
b33bab2
aaa6458
829dfd4
 
aaa6458
829dfd4
 
06be9c8
829dfd4
b33bab2
 
 
 
829dfd4
b33bab2
 
 
 
 
 
 
829dfd4
b33bab2
 
829dfd4
b33bab2
829dfd4
 
b33bab2
aaa6458
b33bab2
06be9c8
 
b33bab2
aaa6458
06be9c8
829dfd4
 
388cf5c
b33bab2
388cf5c
829dfd4
 
 
06be9c8
829dfd4
b33bab2
829dfd4
 
 
 
aaa6458
829dfd4
388cf5c
aaa6458
 
 
b33bab2
 
 
aaa6458
 
 
 
829dfd4
aaa6458
 
 
829dfd4
06be9c8
aaa6458
 
 
 
b33bab2
06be9c8
aaa6458
1087492
aaa6458
48056a7
aaa6458
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import os
import torch
import gradio as gr
import numpy as np
from PIL import Image
import tempfile
import trimesh

# Module-level setup: everything below runs once at import time. It selects
# the compute device, imports Point-E, builds three models, loads their
# checkpoints and wires them into two samplers used by the functions below.

# Prefer CUDA when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Import Point-E lazily inside a try block so a missing installation yields a
# readable hint on stdout before the ImportError is re-raised.
# NOTE(review): plot_point_cloud is imported but not referenced in the visible
# part of this file.
try:
    print("Loading Point-E model...")
    from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
    from point_e.diffusion.sampler import PointCloudSampler
    from point_e.models.configs import MODEL_CONFIGS, model_from_config
    from point_e.models.download import load_checkpoint
    from point_e.util.plotting import plot_point_cloud
except ImportError:
    print("Point-E modules not available. Please make sure Point-E is installed.")
    raise

# Base model + diffusion for the 'base40M-textvec' config, put in eval mode.
# NOTE(review): the config name suggests a text-conditioned model rather than
# an image encoder -- confirm against the Point-E model list.
base_name = 'base40M-textvec'
base_model = model_from_config(MODEL_CONFIGS[base_name], device)
base_model.eval()
base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[base_name])

# Upsampler model + diffusion (second stage of the `sampler` defined below).
upsampler_model = model_from_config(MODEL_CONFIGS['upsample'], device)
upsampler_model.eval()
upsampler_diffusion = diffusion_from_config(DIFFUSION_CONFIGS['upsample'])

# Model + diffusion used by `img2pc_sampler` for image -> point cloud sampling.
img2pc_name = 'base300M'
img2pc_model = model_from_config(MODEL_CONFIGS[img2pc_name], device)
img2pc_model.eval()
img2pc_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[img2pc_name])

# Load weights for all three models onto `device`.
print("Loading model checkpoints...")
base_model.load_state_dict(load_checkpoint(base_name, device))
upsampler_model.load_state_dict(load_checkpoint('upsample', device))
img2pc_model.load_state_dict(load_checkpoint(img2pc_name, device))

# Two-stage sampler (base model followed by the upsampler) with RGB aux channels.
# NOTE(review): `sampler` is not referenced by the functions visible in this
# file; only `img2pc_sampler` is used.
# NOTE(review): the upstream Point-E examples pass `4096 - 1024` as the
# second-stage point count -- confirm whether 4096 is intended here.
sampler = PointCloudSampler(
    device=device,
    models=[base_model, upsampler_model],
    diffusions=[base_diffusion, upsampler_diffusion],
    num_points=[1024, 4096],
    aux_channels=['R', 'G', 'B'],
    guidance_scale=[3.0, 0.0],
)

# Single-stage sampler used by image_to_3d: 1024 points with RGB aux channels.
img2pc_sampler = PointCloudSampler(
    device=device,
    models=[img2pc_model],
    diffusions=[img2pc_diffusion],
    num_points=[1024],
    aux_channels=['R', 'G', 'B'],
    guidance_scale=[3.0],
)

def preprocess_image(image):
    """Return ``image`` resized to the 256x256 input size used by this app.

    Args:
        image: Any object exposing a PIL-style ``resize((width, height))``.

    Returns:
        Whatever ``image.resize`` returns for the target size.
    """
    target_size = (256, 256)
    return image.resize(target_size)

def image_to_3d(image, num_steps=64):
    """
    Convert a single image to a colored 3D point cloud using Point-E.

    Args:
        image: PIL image to condition the sampler on, or None.
        num_steps: Number of sampling steps to run per sampler stage.
            Must be convertible to a positive integer.

    Returns:
        A ``(paths, message)`` tuple. On success ``paths`` is
        ``[obj_path, ply_path]`` (temporary files that the caller owns and
        should eventually delete); on failure ``paths`` is None and
        ``message`` describes the problem. Exceptions are reported through
        ``message`` rather than raised, so the UI can display them.
    """
    if image is None:
        return None, "No image provided"

    try:
        steps = int(num_steps)
        if steps < 1:
            raise ValueError("num_steps must be a positive integer")

        # Preprocess image
        processed_image = preprocess_image(image)

        # The sampler reads its per-stage step counts from `karras_steps`.
        # Keep the existing length so every configured stage gets a value.
        # NOTE: this mutates the shared module-level sampler.
        img2pc_sampler.karras_steps = tuple(
            steps for _ in img2pc_sampler.karras_steps
        )

        # The progressive sampler yields intermediate states; only the last
        # one is the finished sample.
        samples = None
        for samples in img2pc_sampler.sample_batch_progressive(
            batch_size=1,
            model_kwargs=dict(images=[processed_image]),
        ):
            pass
        if samples is None:
            raise RuntimeError("sampler produced no output")

        # The raw sampler output is a tensor; let the sampler decode it into
        # a point cloud with coordinates and named colour channels.
        pc = img2pc_sampler.output_to_point_clouds(samples)[0]
        points = np.asarray(pc.coords)
        rgb = np.stack([pc.channels[name] for name in ('R', 'G', 'B')], axis=-1)
        # Channels are floats in [0, 1]; store explicit 8-bit colours.
        colors_np = np.clip(np.rint(rgb * 255.0), 0, 255).astype(np.uint8)

        # Create a trimesh point cloud (vertices + per-point colours)
        point_cloud = trimesh.PointCloud(vertices=points, colors=colors_np)

        # Reserve a temp path per format (OBJ for editing, PLY for Unity),
        # then export after the handle is closed: writing to a path that is
        # still held open fails on Windows.
        paths = []
        for suffix in ('.obj', '.ply'):
            with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
                path = handle.name
            point_cloud.export(path)
            paths.append(path)

        return paths, "3D model generated successfully!"
    except Exception as e:
        # UI boundary: surface the failure as a message instead of crashing.
        return None, f"Error: {str(e)}"

def process_image(image, num_steps):
    """Gradio callback: run the conversion and unpack it into three outputs.

    Args:
        image: The uploaded image, or None when nothing was uploaded.
        num_steps: Step count forwarded to ``image_to_3d``.

    Returns:
        ``(obj_path, ply_path, message)``. Both paths are None when no image
        was given, when the conversion produced no files, or when an
        exception occurred; ``message`` then carries the explanation.
    """
    if image is None:
        return None, None, "Please upload an image first."

    try:
        paths, status = image_to_3d(image, num_steps=num_steps)
        if not paths:
            return None, None, status
        obj_path, ply_path = paths[0], paths[1]
        return obj_path, ply_path, status
    except Exception as exc:
        return None, None, f"Error: {str(exc)}"

# Build the Gradio UI: a two-column layout with the inputs on the left and
# the generated files plus a status message on the right.
with gr.Blocks(title="Image to 3D Point Cloud Converter") as demo:
    gr.Markdown("# Image to 3D Point Cloud Converter")
    gr.Markdown("Upload an image to convert it to a 3D point cloud that you can use in Unity or other engines.")

    with gr.Row():
        # Left column: image upload, step-count slider and the trigger button.
        with gr.Column(scale=1):
            # type="pil" so the callback receives a PIL image object.
            input_image = gr.Image(type="pil", label="Input Image")
            # The slider value is passed to process_image as `num_steps`.
            num_steps = gr.Slider(minimum=16, maximum=128, value=64, step=8, label="Number of Inference Steps")
            submit_btn = gr.Button("Convert to 3D")

        # Right column: the two downloadable files and the status text.
        with gr.Column(scale=1):
            obj_file = gr.File(label="OBJ File (for editing)")
            ply_file = gr.File(label="PLY File (for Unity)")
            output_message = gr.Textbox(label="Output Message")

    # process_image returns (obj_path, ply_path, message), which maps
    # positionally onto the three output components below.
    submit_btn.click(
        fn=process_image,
        inputs=[input_image, num_steps],
        outputs=[obj_file, ply_file, output_message]
    )

# Launch the app only when run as a script; 0.0.0.0 binds all network
# interfaces, on port 7860.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)