mike23415 committed
Commit 8505235 · verified · 1 Parent(s): 9cd02e0

Update app.py

Files changed (1)
app.py +66 -84
app.py CHANGED
@@ -4,121 +4,104 @@ import gradio as gr
 import numpy as np
 from PIL import Image
 import tempfile
+from skimage import measure
 import trimesh
+import torch.nn.functional as F
+import torchvision.transforms as transforms
 
 # Check if CUDA is available, otherwise use CPU
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print(f"Using device: {device}")
 
-# Import Point-E modules
-try:
-    print("Loading Point-E model...")
-    from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
-    from point_e.diffusion.sampler import PointCloudSampler
-    from point_e.models.configs import MODEL_CONFIGS, model_from_config
-    from point_e.models.download import load_checkpoint
-    from point_e.util.plotting import plot_point_cloud
-except ImportError:
-    print("Point-E modules not available. Please make sure Point-E is installed.")
-    raise
-
-# Create base model for image encoder
-base_name = 'base40M-textvec'
-base_model = model_from_config(MODEL_CONFIGS[base_name], device)
-base_model.eval()
-base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[base_name])
-
-# Create upsampler model
-upsampler_model = model_from_config(MODEL_CONFIGS['upsample'], device)
-upsampler_model.eval()
-upsampler_diffusion = diffusion_from_config(DIFFUSION_CONFIGS['upsample'])
-
-# Create image to point cloud model
-img2pc_name = 'base300M'
-img2pc_model = model_from_config(MODEL_CONFIGS[img2pc_name], device)
-img2pc_model.eval()
-img2pc_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[img2pc_name])
-
-# Load checkpoints
-print("Loading model checkpoints...")
-base_model.load_state_dict(load_checkpoint(base_name, device))
-upsampler_model.load_state_dict(load_checkpoint('upsample', device))
-img2pc_model.load_state_dict(load_checkpoint(img2pc_name, device))
-
-# Create samplers
-sampler = PointCloudSampler(
-    device=device,
-    models=[base_model, upsampler_model],
-    diffusions=[base_diffusion, upsampler_diffusion],
-    num_points=[1024, 4096],
-    aux_channels=['R', 'G', 'B'],
-    guidance_scale=[3.0, 0.0],
-)
+# Define a simple neural network to extract depth from images
+class SimpleDepthNet(torch.nn.Module):
+    def __init__(self):
+        super(SimpleDepthNet, self).__init__()
+        self.conv1 = torch.nn.Conv2d(3, 32, kernel_size=3, padding=1)
+        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3, padding=1)
+        self.conv3 = torch.nn.Conv2d(64, 128, kernel_size=3, padding=1)
+        self.conv4 = torch.nn.Conv2d(128, 1, kernel_size=3, padding=1)
+        self.pool = torch.nn.MaxPool2d(2, 2)
+        self.upsample = torch.nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+
+    def forward(self, x):
+        # Encoder
+        x = F.relu(self.conv1(x))
+        x = self.pool(x)
+        x = F.relu(self.conv2(x))
+        x = self.pool(x)
+
+        # Decoder
+        x = self.upsample(x)
+        x = F.relu(self.conv3(x))
+        x = self.upsample(x)
+        x = torch.sigmoid(self.conv4(x))
+        return x
 
-img2pc_sampler = PointCloudSampler(
-    device=device,
-    models=[img2pc_model],
-    diffusions=[img2pc_diffusion],
-    num_points=[1024],
-    aux_channels=['R', 'G', 'B'],
-    guidance_scale=[3.0],
-)
+# Initialize the model
+model = SimpleDepthNet().to(device)
 
-def preprocess_image(image):
-    # Resize to match expected input size
-    image = image.resize((256, 256))
-    return image
+# Define transformation for input images
+transform = transforms.Compose([
+    transforms.Resize((256, 256)),
+    transforms.ToTensor(),
+])
 
-def image_to_3d(image, num_steps=64):
+def image_to_3d(image):
     """
-    Convert a single image to a 3D model using Point-E
+    Convert a single image to a 3D model using a simple depth extraction approach
     """
     if image is None:
         return None, "No image provided"
 
     try:
         # Preprocess image
-        processed_image = preprocess_image(image)
+        img_tensor = transform(image).unsqueeze(0).to(device)
+
+        # Generate depth map
+        with torch.no_grad():
+            depth = model(img_tensor)[0, 0].cpu().numpy()
+
+        # Convert depth map to 3D points
+        h, w = depth.shape
+        y, x = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
 
-        # Generate samples
-        samples = None
-        for i, x in enumerate(img2pc_sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(images=[processed_image]))):
-            samples = x
+        # Normalize coordinates
+        x = (x - w/2) / max(w, h)
+        y = (y - h/2) / max(w, h)
+        z = depth - 0.5 # Center around zero
 
-        # Extract point cloud
-        pc = samples[-1]['pred_pc']
-        colors = samples[-1]['pred_pc_aux']['R', 'G', 'B']
+        # Create point cloud
+        points = np.stack([x.flatten(), y.flatten(), z.flatten()], axis=1)
 
-        # Create colored point cloud
-        points = pc.cpu().numpy()[0]
-        colors_np = colors.cpu().numpy()[0]
+        # Get colors from original image
+        img_np = np.array(image.resize((w, h))) / 255.0
+        colors = img_np.reshape(-1, 3)
 
-        # Create a mesh from point cloud
-        point_cloud = trimesh.PointCloud(vertices=points, colors=colors_np)
+        # Create a mesh from the point cloud (using marching cubes on the depth map)
+        verts, faces, _, _ = measure.marching_cubes(depth, 0.5)
+        mesh = trimesh.Trimesh(vertices=verts, faces=faces)
 
         # Save as OBJ
         with tempfile.NamedTemporaryFile(suffix='.obj', delete=False) as obj_file:
             obj_path = obj_file.name
-            point_cloud.export(obj_path)
+            mesh.export(obj_path)
 
-        # Save as PLY for better Unity compatibility
+        # Also save as PLY for better compatibility with Unity
        with tempfile.NamedTemporaryFile(suffix='.ply', delete=False) as ply_file:
             ply_path = ply_file.name
-            point_cloud.export(ply_path)
+            mesh.export(ply_path)
 
         return [obj_path, ply_path], "3D model generated successfully!"
     except Exception as e:
         return None, f"Error: {str(e)}"
 
-def process_image(image, num_steps):
+def process_image(image):
     try:
         if image is None:
             return None, None, "Please upload an image first."
 
-        results, message = image_to_3d(
-            image,
-            num_steps=num_steps
-        )
+        results, message = image_to_3d(image)
 
         if results:
             return results[0], results[1], message
@@ -128,14 +111,13 @@ def process_image(image, num_steps):
         return None, None, f"Error: {str(e)}"
 
 # Create Gradio interface
-with gr.Blocks(title="Image to 3D Point Cloud Converter") as demo:
-    gr.Markdown("# Image to 3D Point Cloud Converter")
-    gr.Markdown("Upload an image to convert it to a 3D point cloud that you can use in Unity or other engines.")
+with gr.Blocks(title="Simple Image to 3D Converter") as demo:
+    gr.Markdown("# Simple Image to 3D Converter")
+    gr.Markdown("Upload an image to convert it to a simple 3D model that you can use in Unity or other engines.")
 
     with gr.Row():
         with gr.Column(scale=1):
             input_image = gr.Image(type="pil", label="Input Image")
-            num_steps = gr.Slider(minimum=16, maximum=128, value=64, step=8, label="Number of Inference Steps")
             submit_btn = gr.Button("Convert to 3D")
 
         with gr.Column(scale=1):
@@ -145,7 +127,7 @@ with gr.Blocks(title="Image to 3D Point Cloud Converter") as demo:
 
     submit_btn.click(
         fn=process_image,
-        inputs=[input_image, num_steps],
+        inputs=[input_image],
         outputs=[obj_file, ply_file, output_message]
     )
 
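Note on the added code: two things are worth flagging for anyone building on this commit. First, SimpleDepthNet is used with freshly initialized weights and no checkpoint is ever loaded, so the predicted depth map is untrained output. Second, skimage.measure.marching_cubes expects a 3D scalar volume, so passing the 2D depth array directly, as image_to_3d now does, will fail at runtime (scikit-image rejects non-3D input). A minimal sketch of one workaround, extruding the depth map into a binary height-field volume before meshing (the depth_to_mesh helper and n_levels parameter are illustrative, not part of this commit):

    import numpy as np
    import trimesh
    from skimage import measure

    def depth_to_mesh(depth, n_levels=32):
        # Hypothetical helper: marching_cubes needs a 3D array, so stack
        # n_levels thresholded copies of the 2D depth map along a new axis.
        levels = np.linspace(0.0, 1.0, n_levels)
        # volume[i, j, k] == 1 where pixel (i, j) rises at least to level k,
        # i.e. a solid block under the depth surface.
        volume = (depth[:, :, None] >= levels[None, None, :]).astype(np.float32)
        verts, faces, _, _ = measure.marching_cubes(volume, 0.5)
        # Vertices come back in (row, col, level) voxel units; rescale as needed.
        return trimesh.Trimesh(vertices=verts, faces=faces)

Because the network ends in a sigmoid, depth lies strictly inside (0, 1), so the extruded volume always contains both filled and empty voxels and the 0.5 iso-level stays within the data range.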
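Also note that the new image_to_3d still computes points and colors from the depth map but never attaches them to the exported mesh, so the OBJ/PLY files come out uncolored. A sketch of exporting that data as a colored point cloud instead, reusing the trimesh.PointCloud API the removed Point-E path relied on (the export_colored_point_cloud name and default path are placeholders):

    import numpy as np
    import trimesh

    def export_colored_point_cloud(points, colors, path="cloud.ply"):
        # points: (N, 3) floats; colors: (N, 3) floats in [0, 1].
        # trimesh expects per-point colors as 0-255 integers.
        rgb = (np.clip(colors, 0.0, 1.0) * 255).astype(np.uint8)
        cloud = trimesh.PointCloud(vertices=points, colors=rgb)
        cloud.export(path)  # PLY preserves per-vertex colors
        return path

One related caveat carried over from the diff: np.array(image.resize((w, h))) yields four channels for RGBA uploads, which would break the reshape(-1, 3); calling image.convert("RGB") first would be safer.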
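Finally, a quick way to sanity-check the files returned by image_to_3d before importing them into Unity (the path is a stand-in for whatever tempfile generated):

    import trimesh

    mesh = trimesh.load("/tmp/generated.ply")  # placeholder path
    print(type(mesh).__name__)   # Trimesh for the marching-cubes output
    print(mesh.vertices.shape)   # (N, 3) vertex array
    print(mesh.is_watertight)    # extruded height fields usually are not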