bennyguo commited on
Commit
2c25d73
·
1 Parent(s): c60b074

initial demo release

Browse files
Files changed (2) hide show
  1. app.py +260 -4
  2. requirements.txt +16 -0
app.py CHANGED
@@ -1,7 +1,263 @@
1
  import gradio as gr
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import os
3
+ import sys
4
+ import subprocess
5
+ from huggingface_hub import snapshot_download, HfFolder
6
+ import random # Import random for seed generation
7
 
8
+ # --- Repo Setup ---
9
+ DEFAULT_REPO_DIR = "./TripoSG-repo" # Directory to clone into if not using local path
10
+ REPO_GIT_URL = "github.com/VAST-AI-Research/TripoSG.git" # Base URL without schema/token
11
+ BRANCH = "scribble"
12
 
13
+ code_source_path = None
14
+
15
+ # Option 1: Use local path if TRIPOSG_CODE_PATH env var is set
16
+ local_code_path = os.environ.get("TRIPOSG_CODE_PATH")
17
+ if local_code_path:
18
+ print(f"Attempting to use local code path specified by TRIPOSG_CODE_PATH: {local_code_path}")
19
+ # Basic check: does it exist and seem like a git repo (has .git)?
20
+ if os.path.isdir(local_code_path) and os.path.isdir(os.path.join(local_code_path, ".git")):
21
+ code_source_path = os.path.abspath(local_code_path)
22
+ print(f"Using local TripoSG code directory: {code_source_path}")
23
+ # You might want to add a check here to verify the branch is correct, e.g.:
24
+ # try:
25
+ # current_branch = subprocess.run(["git", "rev-parse", "--abbrev-ref", "HEAD"], cwd=code_source_path, check=True, capture_output=True, text=True).stdout.strip()
26
+ # if current_branch != BRANCH:
27
+ # print(f"Warning: Local repo is on branch '{current_branch}', expected '{BRANCH}'. Attempting checkout...")
28
+ # subprocess.run(["git", "checkout", BRANCH], cwd=code_source_path, check=True)
29
+ # except Exception as e:
30
+ # print(f"Warning: Could not verify or checkout branch '{BRANCH}' in {code_source_path}: {e}")
31
+ else:
32
+ print(f"Warning: TRIPOSG_CODE_PATH '{local_code_path}' not found or not a valid git repository directory. Falling back to cloning.")
33
+
34
+ # Option 2: Clone from GitHub (if local path not used or invalid)
35
+ if not code_source_path:
36
+ repo_url_to_clone = f"https://{REPO_GIT_URL}"
37
+ github_token = os.environ.get("GITHUB_TOKEN")
38
+ if github_token:
39
+ print("Using GITHUB_TOKEN for repository cloning.")
40
+ repo_url_to_clone = f"https://{github_token}@{REPO_GIT_URL}"
41
+ else:
42
+ print("No GITHUB_TOKEN found. Using public HTTPS for cloning.")
43
+
44
+ repo_target_dir = os.path.abspath(DEFAULT_REPO_DIR)
45
+ if not os.path.exists(repo_target_dir):
46
+ print(f"Cloning TripoSG repository ({BRANCH} branch) into {repo_target_dir}...")
47
+ try:
48
+ subprocess.run(["git", "clone", "--branch", BRANCH, "--depth", "1", repo_url_to_clone, repo_target_dir], check=True)
49
+ code_source_path = repo_target_dir
50
+ print("Repository cloned successfully.")
51
+ except subprocess.CalledProcessError as e:
52
+ print(f"Error cloning repository: {e}")
53
+ print("Please ensure the URL is correct, the branch '{BRANCH}' exists, and you have access rights (or provide a GITHUB_TOKEN).")
54
+ sys.exit(1)
55
+ except Exception as e:
56
+ print(f"An unexpected error occurred during cloning: {e}")
57
+ sys.exit(1)
58
+ else:
59
+ print(f"Directory {repo_target_dir} already exists. Assuming it contains the correct code/branch.")
60
+ # Optional: Add checks here like git pull or verifying the branch
61
+ code_source_path = repo_target_dir
62
+
63
+ if not code_source_path:
64
+ print("Error: Could not determine TripoSG code source path.")
65
+ sys.exit(1)
66
+
67
+ # Add repo to Python path
68
+ sys.path.insert(0, code_source_path) # Use the determined absolute path
69
+ print(f"Added {code_source_path} to sys.path")
70
+ # --- End Repo Setup ---
71
+
72
+ # --- ZeroGPU Setup ---
73
+ ENABLE_ZEROGPU = os.environ.get("ENABLE_ZEROGPU", "false").lower() in ("true", "1", "t")
74
+ print(f"ZeroGPU Enabled: {ENABLE_ZEROGPU}")
75
+ # --- End ZeroGPU Setup ---
76
+
77
+ if ENABLE_ZEROGPU:
78
+ import spaces # Import spaces for ZeroGPU
79
+ from PIL import Image
80
+ import numpy as np
81
+ import torch
82
+ from triposg.pipelines.pipeline_triposg_scribble import TripoSGScribblePipeline
83
+ import tempfile
84
+
85
+ # --- Weight Loading Logic ---
86
+ HF_TOKEN = os.environ.get("HF_TOKEN")
87
+ if HF_TOKEN:
88
+ HfFolder.save_token(HF_TOKEN)
89
+ HUGGING_FACE_REPO_ID = "VAST-AI/TripoSG-scribble"
90
+ DEFAULT_CACHE_PATH = "./pretrained_weights/TripoSG-scribble"
91
+
92
+ # Option 1: Use local path if WEIGHTS_PATH env var is set
93
+ local_weights_path = os.environ.get("WEIGHTS_PATH")
94
+ model_load_path = None
95
+
96
+ if local_weights_path:
97
+ print(f"Attempting to load weights from local path specified by WEIGHTS_PATH: {local_weights_path}")
98
+ if os.path.isdir(local_weights_path):
99
+ model_load_path = local_weights_path
100
+ print(f"Using local weights directory: {model_load_path}")
101
+ else:
102
+ print(f"Warning: WEIGHTS_PATH '{local_weights_path}' not found or not a directory. Falling back to Hugging Face download.")
103
+
104
+ # Option 2: Download from Hugging Face (if local path not used or invalid)
105
+ if not model_load_path:
106
+ hf_token = os.environ.get("HF_TOKEN")
107
+ print(f"Attempting to download weights from Hugging Face repo: {HUGGING_FACE_REPO_ID}")
108
+ if hf_token:
109
+ print("Using Hugging Face token for download.")
110
+ auth_token = hf_token
111
+ else:
112
+ print("No Hugging Face token found. Attempting public download.")
113
+ auth_token = None
114
+ try:
115
+ model_load_path = snapshot_download(
116
+ repo_id=HUGGING_FACE_REPO_ID,
117
+ local_dir=DEFAULT_CACHE_PATH,
118
+ local_dir_use_symlinks=False, # Recommended for Spaces
119
+ token=auth_token,
120
+ # revision="main" # Specify branch/commit if needed
121
+ )
122
+ print(f"Weights downloaded/cached to: {model_load_path}")
123
+ except Exception as e:
124
+ print(f"Error downloading weights from Hugging Face: {e}")
125
+ print("Please ensure the repository exists and is accessible, or provide a valid WEIGHTS_PATH.")
126
+ sys.exit(1) # Exit if weights cannot be loaded
127
+
128
+ # Load the pipeline using the determined path
129
+ print(f"Loading pipeline from: {model_load_path}")
130
+ pipe = TripoSGScribblePipeline.from_pretrained(model_load_path)
131
+ pipe.to(dtype=torch.float16, device="cuda")
132
+ print("Pipeline loaded.")
133
+ # --- End Weight Loading Logic ---
134
+
135
+ # Create a white background image and a transparent layer for drawing
136
+ canvas_width, canvas_height = 512, 512
137
+ initial_background = Image.new("RGB", (canvas_width, canvas_height), color="white")
138
+ initial_layer = Image.new("RGBA", (canvas_width, canvas_height), color=(0, 0, 0, 0)) # Transparent layer
139
+ # Prepare the initial value dictionary for ImageEditor
140
+ initial_value = {
141
+ "background": initial_background,
142
+ "layers": [initial_layer], # Add the transparent layer
143
+ "composite": None
144
+ }
145
+
146
+ # --- ZeroGPU Setup ---
147
+ # ... existing ZeroGPU setup ...
148
+
149
+ MAX_SEED = np.iinfo(np.int32).max
150
+
151
+ def get_random_seed():
152
+ return random.randint(0, MAX_SEED)
153
+
154
+ # Apply decorator conditionally
155
+ @spaces.GPU(duration=120) if ENABLE_ZEROGPU else lambda func: func
156
+ def generate_3d(scribble_image_dict, prompt, scribble_confidence, seed): # Added seed parameter back
157
+ print("Generating 3D model...")
158
+ # Extract the composite image from the ImageEditor dictionary
159
+ if scribble_image_dict is None or scribble_image_dict.get("composite") is None:
160
+ print("No scribble image provided.")
161
+ return None # Return None if no image is provided
162
+
163
+ # --- Seed Handling ---
164
+ current_seed = int(seed)
165
+ print(f"Using seed: {current_seed}")
166
+ # --- End Seed Handling ---
167
+
168
+ # Get the composite image which includes the drawing
169
+ # The composite might be RGBA if a layer was involved, ensure RGB for processing
170
+ image = Image.fromarray(scribble_image_dict["composite"]).convert("RGB")
171
+
172
+ # Preprocess the image: invert colors (black on white -> white on black)
173
+ image_np = np.array(image)
174
+ processed_image_np = 255 - image_np
175
+ processed_image = Image.fromarray(processed_image_np)
176
+ print("Image preprocessed.")
177
+
178
+ # Define fixed parameters
179
+ attn_scale_text = 1.0 # As per the example run.py
180
+
181
+ # Set the generator with the provided seed
182
+ generator = torch.Generator(device='cuda').manual_seed(current_seed)
183
+
184
+ # Run the pipeline
185
+ print("Running pipeline...")
186
+ out = pipe(
187
+ processed_image,
188
+ prompt=prompt,
189
+ num_tokens=512,
190
+ guidance_scale=0,
191
+ num_inference_steps=16,
192
+ attention_kwargs={
193
+ "cross_attention_scale": attn_scale_text,
194
+ "cross_attention_2_scale": scribble_confidence
195
+ },
196
+ generator=generator,
197
+ use_flash_decoder=False,
198
+ dense_octree_depth=8,
199
+ hierarchical_octree_depth=8
200
+ )
201
+ print("Pipeline finished.")
202
+
203
+ # Save the output mesh to a temporary file
204
+ if out.meshes and len(out.meshes) > 0:
205
+ # Create a temporary file with .glb extension
206
+ with tempfile.NamedTemporaryFile(suffix=".glb", delete=False) as tmpfile:
207
+ output_path = tmpfile.name
208
+ out.meshes[0].export(output_path)
209
+ print(f"Mesh saved to temporary file: {output_path}")
210
+ return output_path
211
+ else:
212
+ print("Pipeline did not generate any meshes.")
213
+ return None
214
+
215
+ # Create the Gradio interface
216
+ with gr.Blocks() as demo:
217
+ gr.Markdown("# Scribble + Text to 3D Model Generator (TripoSG)")
218
+ gr.Markdown("Draw a scribble (black on white canvas), enter a text prompt, adjust confidence, set a seed, and generate a 3D model.") # Updated guidance
219
+ with gr.Row():
220
+ with gr.Column(scale=1):
221
+ image_input = gr.ImageEditor(
222
+ label="Scribble Input (Draw Black on White)",
223
+ value=initial_value,
224
+ image_mode="RGB",
225
+ brush=gr.Brush(default_color="#000000", color_mode="fixed", default_size=5), # Fixed small brush size
226
+ interactive=True,
227
+ eraser=gr.Brush(default_color="#FFFFFF", color_mode="fixed", default_size=20) # Fixed small eraser size
228
+ )
229
+ prompt_input = gr.Textbox(label="Prompt", placeholder="e.g., a cute cat wearing a hat")
230
+ confidence_input = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, step=0.05, label="Scribble Confidence (attn_scale_image)")
231
+ seed_input = gr.Number(label="Seed", value=0, precision=0) # Added Seed input back
232
+ with gr.Row():
233
+ submit_button = gr.Button("Generate 3D Model", variant="primary", scale=1)
234
+ lucky_button = gr.Button("I'm Feeling Lucky", scale=1)
235
+ with gr.Column(scale=1):
236
+ model_output = gr.Model3D(label="Generated 3D Model", interactive=False)
237
+
238
+ # Define the inputs for the main generation function
239
+ gen_inputs = [image_input, prompt_input, confidence_input, seed_input]
240
+
241
+ submit_button.click(
242
+ fn=generate_3d,
243
+ inputs=gen_inputs, # Include seed_input
244
+ outputs=model_output
245
+ )
246
+
247
+ # Define inputs for the lucky button (same as main button for the final call)
248
+ lucky_gen_inputs = [image_input, prompt_input, confidence_input, seed_input]
249
+
250
+ lucky_button.click(
251
+ fn=get_random_seed, # First, get a random seed
252
+ inputs=[],
253
+ outputs=[seed_input] # Update the seed input field
254
+ ).then(
255
+ fn=generate_3d, # Then, generate the model
256
+ inputs=lucky_gen_inputs, # Use the updated seed from the input field
257
+ outputs=model_output
258
+ )
259
+
260
+ # Launch with queue enabled if using ZeroGPU
261
+ print("Launching Gradio interface...")
262
+ demo.launch(share=False, server_name="0.0.0.0")
263
+ print("Gradio interface launched.")
requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diffusers
2
+ transformers==4.49.0
3
+ einops
4
+ huggingface_hub
5
+ opencv-python
6
+ trimesh==4.5.3
7
+ omegaconf
8
+ scikit-image
9
+ numpy
10
+ peft
11
+ scipy==1.11.4
12
+ jaxtyping
13
+ typeguard
14
+ ninja
15
+ gltflib
16
+ https://huggingface.co/spaces/VAST-AI/TripoSG/resolve/main/diso-0.1.4-cp310-cp310-linux_x86_64.whl?download=true