Spaces: Running on Zero

Commit af15ec4 · committed by bennyguo · 1 parent: b98ab62

add wd14 tagging if prompt is not given

Browse files:
- app.py +222 -40
- requirements.txt +1 -0
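In short: when the prompt textbox is left empty, the scribble's composite image is run through a WD14 ConvNeXT tagger and the resulting tags are used as the text prompt; a non-empty user prompt is kept as-is. A condensed sketch of that flow follows (the function name resolve_prompt and its parameters are illustrative only; wd14_tagger.interrogate and format_wd14_tags refer to the code added in the diff below):

# Condensed sketch of the behaviour this commit adds; not a drop-in excerpt of app.py.
def resolve_prompt(prompt, composite_image, tagger, format_tags):
    prompt = (prompt or "").strip()
    if prompt:
        return prompt                                 # user-supplied prompt wins
    result = tagger.interrogate(composite_image)      # WD14 ConvNeXT tagger (see diff below)
    if result:
        _ratings, tags = result
        generated = format_tags(tags)                 # keep tags above the confidence threshold
        if generated:
            return generated
    return "object"                                   # fallback prompt used by the commit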
app.py
CHANGED
Before (original app.py):

@@ -1,9 +1,23 @@
 import gradio as gr
 import os
 import sys
 import subprocess
-from huggingface_hub import snapshot_download, HfFolder
 import random # Import random for seed generation

 # --- Repo Setup ---
 DEFAULT_REPO_DIR = "./TripoSG-repo" # Directory to clone into if not using local path
@@ -152,66 +166,234 @@ MAX_SEED = np.iinfo(np.int32).max
 def get_random_seed():
     return random.randint(0, MAX_SEED)

 # Apply decorator conditionally
 @spaces.GPU() if ENABLE_ZEROGPU else lambda func: func
-def generate_3d(scribble_image_dict, prompt, scribble_confidence,
     print("Generating 3D model...")
-    #
     if scribble_image_dict is None or scribble_image_dict.get("composite") is None:
         print("No scribble image provided.")
-        return None # Return None

     # --- Seed Handling ---
     current_seed = int(seed)
     print(f"Using seed: {current_seed}")
     # --- End Seed Handling ---

-    #
     # The composite might be RGBA if a layer was involved, ensure RGB for processing
-    …
-    …
     # Preprocess the image: invert colors (black on white -> white on black)
-    image_np = np.array(
     processed_image_np = 255 - image_np
     processed_image = Image.fromarray(processed_image_np)
-    print("Image preprocessed.")

-    #
-    # attn_scale_text = 1.0 # Replaced by text_confidence input
-    …
-    # Set the generator with the provided seed
     generator = torch.Generator(device='cuda').manual_seed(current_seed)

-    # Run
     print("Running pipeline...")
-    … (removed old lines 187-202 are not captured in this view)

-    # Save
     if out.meshes and len(out.meshes) > 0:
         # Create a temporary file with .glb extension
         with tempfile.NamedTemporaryFile(suffix=".glb", delete=False) as tmpfile:
             output_path = tmpfile.name
             out.meshes[0].export(output_path)
             print(f"Mesh saved to temporary file: {output_path}")
-        return output_path
     else:
         print("Pipeline did not generate any meshes.")
-        return None

 # Create the Gradio interface
 with gr.Blocks() as demo:
@@ -242,21 +424,21 @@ with gr.Blocks() as demo:

     submit_button.click(
         fn=generate_3d,
-        inputs=gen_inputs,
-        outputs=model_output
     )

     # Define inputs for the lucky button (same as main button for the final call)
     lucky_gen_inputs = [image_input, prompt_input, confidence_input, prompt_confidence_input, seed_input] # Added text_confidence_input

     lucky_button.click(
-        fn=get_random_seed,
         inputs=[],
-        outputs=[seed_input]
     ).then(
-        fn=generate_3d,
-        inputs=lucky_gen_inputs,
-        outputs=model_output
     )

 # Launch with queue enabled if using ZeroGPU
After (updated app.py):

@@ -1,9 +1,23 @@
+# --- Environment Variables Used ---
+# ENABLE_ZEROGPU: Set to 'true' or '1' to enable @spaces.GPU decorator (for Hugging Face Spaces).
+# TRIPOSG_CODE_PATH: Absolute path to a local directory containing the checked-out TripoSG repository (scribble branch).
+# GITHUB_TOKEN: A GitHub token used for cloning the TripoSG repo if TRIPOSG_CODE_PATH is not provided.
+# WEIGHTS_PATH: Absolute path to a local directory containing the TripoSG-scribble model weights.
+# HF_TOKEN: A Hugging Face Hub token used for downloading weights/models if local paths (WEIGHTS_PATH, WD14_CONVNEXT_PATH) are not provided.
+# WD14_CONVNEXT_PATH: Absolute path to a local directory containing the WD14 ConvNeXT tagger model.onnx and selected_tags.csv.
+# ----------------------------------
+
 import gradio as gr
 import os
 import sys
 import subprocess
+from huggingface_hub import snapshot_download, HfFolder, hf_hub_download
 import random # Import random for seed generation
+import re # For WD14 tag processing
+import cv2 # For WD14 preprocessing
+import pandas as pd # For WD14 tags
+from onnxruntime import InferenceSession # For WD14 model
+from typing import Mapping, Tuple, Dict # Type hints

 # --- Repo Setup ---
 DEFAULT_REPO_DIR = "./TripoSG-repo" # Directory to clone into if not using local path
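For reference, a minimal sketch of how the environment variables documented at the top of the updated file might be read for a local run (the variable names come from the comment block above; the parsing shown here is illustrative, not a copy of app.py's own setup code):

import os

# Illustrative parsing of the documented environment variables (assumed defaults).
ENABLE_ZEROGPU = os.environ.get("ENABLE_ZEROGPU", "").lower() in ("1", "true")
TRIPOSG_CODE_PATH = os.environ.get("TRIPOSG_CODE_PATH")    # local TripoSG checkout (scribble branch), optional
WEIGHTS_PATH = os.environ.get("WEIGHTS_PATH")              # local TripoSG-scribble weights, optional
WD14_CONVNEXT_PATH = os.environ.get("WD14_CONVNEXT_PATH")  # local WD14 ConvNeXT tagger files, optional
HF_TOKEN = os.environ.get("HF_TOKEN")                      # Hub downloads when local paths are unset
GITHUB_TOKEN = os.environ.get("GITHUB_TOKEN")              # cloning TripoSG when no local checkout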
@@ -152,66 +166,234 @@ MAX_SEED = np.iinfo(np.int32).max
 def get_random_seed():
     return random.randint(0, MAX_SEED)

+# --- WD14 Helper Functions ---
+def make_square(img, target_size):
+    old_size = img.shape[:2]
+    desired_size = max(old_size)
+    desired_size = max(desired_size, target_size)
+
+    delta_w = desired_size - old_size[1]
+    delta_h = desired_size - old_size[0]
+    top, bottom = delta_h // 2, delta_h - (delta_h // 2)
+    left, right = delta_w // 2, delta_w - (delta_w // 2)
+
+    color = [255, 255, 255] # White padding
+    return cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
+
+def smart_resize(img, size):
+    if img.shape[0] > size:
+        img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
+    elif img.shape[0] < size:
+        img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC)
+    return img
+
+RE_SPECIAL = re.compile(r'([\()])')
+
+# --- WD14 Tagger Class ---
+class WaifuDiffusionInterrogator:
+    def __init__(
+        self,
+        repo: str,
+        model_filename='model.onnx',
+        tags_filename='selected_tags.csv',
+        local_model_dir: str | None = None # Added local path option
+    ) -> None:
+        self.__repo = repo
+        self.__model_filename = model_filename
+        self.__tags_filename = tags_filename
+        self.__local_model_dir = local_model_dir
+        self.__initialized = False
+        self._model = None
+        self._tags = None
+
+    def _init(self) -> None:
+        if self.__initialized:
+            return
+
+        model_path = None
+        tags_path = None
+
+        if self.__local_model_dir:
+            print(f"WD14: Attempting to load from local directory: {self.__local_model_dir}")
+            potential_model_path = os.path.join(self.__local_model_dir, self.__model_filename)
+            potential_tags_path = os.path.join(self.__local_model_dir, self.__tags_filename)
+            if os.path.exists(potential_model_path) and os.path.exists(potential_tags_path):
+                model_path = potential_model_path
+                tags_path = potential_tags_path
+                print("WD14: Found local model and tags file.")
+            else:
+                print("WD14: Local files not found. Falling back to Hugging Face download.")
+
+        if model_path is None or tags_path is None:
+            print(f"WD14: Downloading from repo: {self.__repo}")
+            hf_token = os.environ.get("HF_TOKEN") # Reuse HF token if available
+            try:
+                model_path = hf_hub_download(self.__repo, filename=self.__model_filename, token=hf_token)
+                tags_path = hf_hub_download(self.__repo, filename=self.__tags_filename, token=hf_token)
+                print("WD14: Download complete.")
+            except Exception as e:
+                print(f"WD14: Error downloading from Hugging Face: {e}")
+                # Decide how to handle this - maybe raise error or disable tagging?
+                # For now, we'll let it fail later if model is None
+                return # Cannot initialize
+
+        try:
+            self._model = InferenceSession(str(model_path))
+            self._tags = pd.read_csv(tags_path)
+            self.__initialized = True
+            print("WD14: Tagger initialized successfully.")
+        except Exception as e:
+            print(f"WD14: Error initializing ONNX session or reading tags: {e}")
+
+    def _calculation(self, image: Image.Image) -> pd.DataFrame | None:
+        self._init()
+        if not self._model or self._tags is None:
+            print("WD14: Tagger not initialized.")
+            return None
+
+        _, height, _, _ = self._model.get_inputs()[0].shape
+
+        image = image.convert('RGBA')
+        new_image = Image.new('RGBA', image.size, 'WHITE')
+        new_image.paste(image, mask=image)
+        image = new_image.convert('RGB')
+        image.save("image_to_wd.png")
+        image = np.asarray(image)
+        image = image[:, :, ::-1]
+
+        image = make_square(image, height)
+        image = smart_resize(image, height)
+        image = image.astype(np.float32)
+        image = np.expand_dims(image, 0)
+
+        input_name = self._model.get_inputs()[0].name
+        label_name = self._model.get_outputs()[0].name
+        confidence = self._model.run([label_name], {input_name: image})[0]
+
+        full_tags = self._tags[['name', 'category']].copy()
+        full_tags['confidence'] = confidence[0]
+
+        return full_tags
+
+    def interrogate(self, image: Image.Image) -> Tuple[Dict[str, float], Dict[str, float]] | None:
+        full_tags = self._calculation(image)
+        if full_tags is None:
+            return None
+
+        ratings = dict(full_tags[full_tags['category'] == 9][['name', 'confidence']].values)
+        tags = dict(full_tags[full_tags['category'] != 9][['name', 'confidence']].values)
+
+        return ratings, tags
+
+# --- Instantiate WD14 Tagger ---
+WD14_CONVNEXT_REPO = 'SmilingWolf/wd-v1-4-convnext-tagger'
+wd14_local_path = os.environ.get("WD14_CONVNEXT_PATH")
+wd14_tagger = WaifuDiffusionInterrogator(repo=WD14_CONVNEXT_REPO, local_model_dir=wd14_local_path)
+
+# --- Helper to format tags ---
+def format_wd14_tags(tags: Dict[str, float], threshold: float = 0.35) -> str:
+    filtered_tags = {
+        tag: score for tag, score in tags.items()
+        if score >= threshold and "background" not in tag and tag not in {"monochrome", "greyscale", "no_humans", "comic", "solo"}
+    }
+    print(filtered_tags)
+    # Sort by score descending, then alphabetically
+    tags_pairs = sorted(filtered_tags.items(), key=lambda x: (-x[1], x[0]))
+    text_items = [tag.replace('_', ' ') for tag, score in tags_pairs]
+    return ', '.join(text_items)
+
 # Apply decorator conditionally
 @spaces.GPU() if ENABLE_ZEROGPU else lambda func: func
+def generate_3d(scribble_image_dict, prompt, scribble_confidence, text_confidence, seed):
     print("Generating 3D model...")
+    input_prompt = prompt # Keep track of original prompt for return on early exit
     if scribble_image_dict is None or scribble_image_dict.get("composite") is None:
         print("No scribble image provided.")
+        return None, input_prompt # Return None for model, original prompt
+
+    # --- Prompt Handling ---
+    input_prompt = prompt.strip()
+    if not input_prompt:
+        print("Prompt is empty, attempting WD14 tagging...")
+        try:
+            # Get the user drawing (black on white) for tagging
+            user_drawing_img = Image.fromarray(scribble_image_dict["composite"]).convert("RGB")
+            tag_results = wd14_tagger.interrogate(user_drawing_img)
+            if tag_results:
+                ratings, tags = tag_results
+                generated_prompt = format_wd14_tags(tags) # Use default threshold
+                if generated_prompt:
+                    print(f"WD14 generated prompt: {generated_prompt}")
+                    input_prompt = generated_prompt
+                else:
+                    print("WD14 tagging did not produce tags above threshold.")
+                    input_prompt = "object" # Fallback prompt
+            else:
+                print("WD14 tagging failed or tagger not initialized.")
+                input_prompt = "object" # Fallback prompt
+        except Exception as e:
+            print(f"Error during WD14 tagging: {e}")
+            input_prompt = "object" # Fallback prompt
+    else:
+        print(f"Using user provided prompt: {input_prompt}")
+    # --- End Prompt Handling ---

     # --- Seed Handling ---
     current_seed = int(seed)
     print(f"Using seed: {current_seed}")
     # --- End Seed Handling ---

+    # --- Image Preprocessing for TripoSG ---
+    # Get the composite image again (safer in case dict is modified)
     # The composite might be RGBA if a layer was involved, ensure RGB for processing
+    image_for_triposg = Image.fromarray(scribble_image_dict["composite"]).convert("RGB")
     # Preprocess the image: invert colors (black on white -> white on black)
+    image_np = np.array(image_for_triposg)
     processed_image_np = 255 - image_np
     processed_image = Image.fromarray(processed_image_np)
+    print("Image preprocessed for TripoSG.")
+    # --- End Image Preprocessing ---

+    # --- Generator Setup ---
     generator = torch.Generator(device='cuda').manual_seed(current_seed)
+    # --- End Generator Setup ---

+    # --- Run Pipeline ---
     print("Running pipeline...")
+    try:
+        out = pipe(
+            processed_image,
+            prompt=input_prompt, # Use the potentially generated prompt
+            num_tokens=512, # Default value from example
+            guidance_scale=0, # Default value from example
+            num_inference_steps=16, # Default value from example
+            attention_kwargs={
+                "cross_attention_scale": text_confidence,
+                "cross_attention_2_scale": scribble_confidence
+            },
+            generator=generator,
+            use_flash_decoder=False, # Default value from example
+            dense_octree_depth=8, # Default value from example
+            hierarchical_octree_depth=8 # Default value from example
+        )
+        print("Pipeline finished.")
+    except Exception as e:
+        print(f"Error during pipeline execution: {e}")
+        return None, input_prompt # Return None for model, the prompt used
+    # --- End Run Pipeline ---

+    # --- Save Output ---
     if out.meshes and len(out.meshes) > 0:
         # Create a temporary file with .glb extension
         with tempfile.NamedTemporaryFile(suffix=".glb", delete=False) as tmpfile:
             output_path = tmpfile.name
             out.meshes[0].export(output_path)
             print(f"Mesh saved to temporary file: {output_path}")
+        return output_path, input_prompt # Return model path and the prompt used
     else:
         print("Pipeline did not generate any meshes.")
+        return None, input_prompt # Return None for model, the prompt used
+    # --- End Save Output ---

 # Create the Gradio interface
 with gr.Blocks() as demo:
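As a usage illustration, the tagging path added above could be exercised on its own roughly as follows (a sketch assuming the WaifuDiffusionInterrogator class and format_wd14_tags helper from this diff; the input file name scribble.png is hypothetical):

from PIL import Image

# Hypothetical standalone use of the tagger defined in the diff above.
tagger = WaifuDiffusionInterrogator(repo='SmilingWolf/wd-v1-4-convnext-tagger')
drawing = Image.open("scribble.png")         # hypothetical input: black strokes on white
result = tagger.interrogate(drawing)         # returns (ratings, tags) or None on failure
if result:
    _ratings, tags = result
    prompt = format_wd14_tags(tags) or "object"   # same fallback the Space uses
    print(prompt)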
@@ -242,21 +424,21 @@ with gr.Blocks() as demo:

     submit_button.click(
         fn=generate_3d,
+        inputs=gen_inputs,
+        outputs=[model_output, prompt_input] # Add prompt_input to outputs
     )

     # Define inputs for the lucky button (same as main button for the final call)
     lucky_gen_inputs = [image_input, prompt_input, confidence_input, prompt_confidence_input, seed_input] # Added text_confidence_input

     lucky_button.click(
+        fn=get_random_seed,
         inputs=[],
+        outputs=[seed_input]
     ).then(
+        fn=generate_3d,
+        inputs=lucky_gen_inputs,
+        outputs=[model_output, prompt_input] # Add prompt_input to outputs
     )

 # Launch with queue enabled if using ZeroGPU
requirements.txt
CHANGED
@@ -14,3 +14,4 @@ typeguard
 ninja
 gltflib
 https://huggingface.co/spaces/VAST-AI/TripoSG/resolve/main/diso-0.1.4-cp310-cp310-linux_x86_64.whl?download=true
+onnxruntime
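The new onnxruntime entry backs the InferenceSession import in app.py. A quick, illustrative way to confirm the dependency resolves in the Space environment (not part of the commit):

import onnxruntime
# Print the installed version and the execution providers ONNX Runtime can use.
print(onnxruntime.__version__, onnxruntime.get_available_providers())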