bennyguo commited on
Commit
af15ec4
·
1 Parent(s): b98ab62

add wd14 tagging if prompt is not given

Browse files
Files changed (2) hide show
  1. app.py +222 -40
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,9 +1,23 @@
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import os
3
  import sys
4
  import subprocess
5
- from huggingface_hub import snapshot_download, HfFolder
6
  import random # Import random for seed generation
 
 
 
 
 
7
 
8
  # --- Repo Setup ---
9
  DEFAULT_REPO_DIR = "./TripoSG-repo" # Directory to clone into if not using local path
@@ -152,66 +166,234 @@ MAX_SEED = np.iinfo(np.int32).max
152
  def get_random_seed():
153
  return random.randint(0, MAX_SEED)
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  # Apply decorator conditionally
156
  @spaces.GPU() if ENABLE_ZEROGPU else lambda func: func
157
- def generate_3d(scribble_image_dict, prompt, scribble_confidence, prompt_confidence, seed): # Added text_confidence parameter
158
  print("Generating 3D model...")
159
- # Extract the composite image from the ImageEditor dictionary
160
  if scribble_image_dict is None or scribble_image_dict.get("composite") is None:
161
  print("No scribble image provided.")
162
- return None # Return None if no image is provided
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
  # --- Seed Handling ---
165
  current_seed = int(seed)
166
  print(f"Using seed: {current_seed}")
167
  # --- End Seed Handling ---
168
 
169
- # Get the composite image which includes the drawing
 
170
  # The composite might be RGBA if a layer was involved, ensure RGB for processing
171
- image = Image.fromarray(scribble_image_dict["composite"]).convert("RGB")
172
-
173
  # Preprocess the image: invert colors (black on white -> white on black)
174
- image_np = np.array(image)
175
  processed_image_np = 255 - image_np
176
  processed_image = Image.fromarray(processed_image_np)
177
- print("Image preprocessed.")
 
178
 
179
- # Define fixed parameters
180
- # attn_scale_text = 1.0 # Replaced by text_confidence input
181
-
182
- # Set the generator with the provided seed
183
  generator = torch.Generator(device='cuda').manual_seed(current_seed)
 
184
 
185
- # Run the pipeline
186
  print("Running pipeline...")
187
- out = pipe(
188
- processed_image,
189
- prompt=prompt,
190
- num_tokens=512, # Default value from example
191
- guidance_scale=0, # Default value from example
192
- num_inference_steps=16, # Default value from example
193
- attention_kwargs={
194
- "cross_attention_scale": prompt_confidence, # Use input parameter
195
- "cross_attention_2_scale": scribble_confidence
196
- },
197
- generator=generator,
198
- use_flash_decoder=False,
199
- dense_octree_depth=8,
200
- hierarchical_octree_depth=8
201
- )
202
- print("Pipeline finished.")
 
 
 
 
 
203
 
204
- # Save the output mesh to a temporary file
205
  if out.meshes and len(out.meshes) > 0:
206
  # Create a temporary file with .glb extension
207
  with tempfile.NamedTemporaryFile(suffix=".glb", delete=False) as tmpfile:
208
  output_path = tmpfile.name
209
  out.meshes[0].export(output_path)
210
  print(f"Mesh saved to temporary file: {output_path}")
211
- return output_path
212
  else:
213
  print("Pipeline did not generate any meshes.")
214
- return None
 
215
 
216
  # Create the Gradio interface
217
  with gr.Blocks() as demo:
@@ -242,21 +424,21 @@ with gr.Blocks() as demo:
242
 
243
  submit_button.click(
244
  fn=generate_3d,
245
- inputs=gen_inputs, # Include seed_input and text_confidence_input
246
- outputs=model_output
247
  )
248
 
249
  # Define inputs for the lucky button (same as main button for the final call)
250
  lucky_gen_inputs = [image_input, prompt_input, confidence_input, prompt_confidence_input, seed_input] # Added text_confidence_input
251
 
252
  lucky_button.click(
253
- fn=get_random_seed, # First, get a random seed
254
  inputs=[],
255
- outputs=[seed_input] # Update the seed input field
256
  ).then(
257
- fn=generate_3d, # Then, generate the model
258
- inputs=lucky_gen_inputs, # Use the updated seed from the input field
259
- outputs=model_output
260
  )
261
 
262
  # Launch with queue enabled if using ZeroGPU
 
1
+ # --- Environment Variables Used ---
2
+ # ENABLE_ZEROGPU: Set to 'true' or '1' to enable @spaces.GPU decorator (for Hugging Face Spaces).
3
+ # TRIPOSG_CODE_PATH: Absolute path to a local directory containing the checked-out TripoSG repository (scribble branch).
4
+ # GITHUB_TOKEN: A GitHub token used for cloning the TripoSG repo if TRIPOSG_CODE_PATH is not provided.
5
+ # WEIGHTS_PATH: Absolute path to a local directory containing the TripoSG-scribble model weights.
6
+ # HF_TOKEN: A Hugging Face Hub token used for downloading weights/models if local paths (WEIGHTS_PATH, WD14_CONVNEXT_PATH) are not provided.
7
+ # WD14_CONVNEXT_PATH: Absolute path to a local directory containing the WD14 ConvNeXT tagger model.onnx and selected_tags.csv.
8
+ # ----------------------------------
9
+
10
  import gradio as gr
11
  import os
12
  import sys
13
  import subprocess
14
+ from huggingface_hub import snapshot_download, HfFolder, hf_hub_download
15
  import random # Import random for seed generation
16
+ import re # For WD14 tag processing
17
+ import cv2 # For WD14 preprocessing
18
+ import pandas as pd # For WD14 tags
19
+ from onnxruntime import InferenceSession # For WD14 model
20
+ from typing import Mapping, Tuple, Dict # Type hints
21
 
22
  # --- Repo Setup ---
23
  DEFAULT_REPO_DIR = "./TripoSG-repo" # Directory to clone into if not using local path
 
166
  def get_random_seed():
167
  return random.randint(0, MAX_SEED)
168
 
169
+ # --- WD14 Helper Functions ---
170
+ def make_square(img, target_size):
171
+ old_size = img.shape[:2]
172
+ desired_size = max(old_size)
173
+ desired_size = max(desired_size, target_size)
174
+
175
+ delta_w = desired_size - old_size[1]
176
+ delta_h = desired_size - old_size[0]
177
+ top, bottom = delta_h // 2, delta_h - (delta_h // 2)
178
+ left, right = delta_w // 2, delta_w - (delta_w // 2)
179
+
180
+ color = [255, 255, 255] # White padding
181
+ return cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
182
+
183
+ def smart_resize(img, size):
184
+ if img.shape[0] > size:
185
+ img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
186
+ elif img.shape[0] < size:
187
+ img = cv2.resize(img, (size, size), interpolation=cv2.INTER_CUBIC)
188
+ return img
189
+
190
+ RE_SPECIAL = re.compile(r'([\()])')
191
+
192
+ # --- WD14 Tagger Class ---
193
+ class WaifuDiffusionInterrogator:
194
+ def __init__(
195
+ self,
196
+ repo: str,
197
+ model_filename='model.onnx',
198
+ tags_filename='selected_tags.csv',
199
+ local_model_dir: str | None = None # Added local path option
200
+ ) -> None:
201
+ self.__repo = repo
202
+ self.__model_filename = model_filename
203
+ self.__tags_filename = tags_filename
204
+ self.__local_model_dir = local_model_dir
205
+ self.__initialized = False
206
+ self._model = None
207
+ self._tags = None
208
+
209
+ def _init(self) -> None:
210
+ if self.__initialized:
211
+ return
212
+
213
+ model_path = None
214
+ tags_path = None
215
+
216
+ if self.__local_model_dir:
217
+ print(f"WD14: Attempting to load from local directory: {self.__local_model_dir}")
218
+ potential_model_path = os.path.join(self.__local_model_dir, self.__model_filename)
219
+ potential_tags_path = os.path.join(self.__local_model_dir, self.__tags_filename)
220
+ if os.path.exists(potential_model_path) and os.path.exists(potential_tags_path):
221
+ model_path = potential_model_path
222
+ tags_path = potential_tags_path
223
+ print("WD14: Found local model and tags file.")
224
+ else:
225
+ print("WD14: Local files not found. Falling back to Hugging Face download.")
226
+
227
+ if model_path is None or tags_path is None:
228
+ print(f"WD14: Downloading from repo: {self.__repo}")
229
+ hf_token = os.environ.get("HF_TOKEN") # Reuse HF token if available
230
+ try:
231
+ model_path = hf_hub_download(self.__repo, filename=self.__model_filename, token=hf_token)
232
+ tags_path = hf_hub_download(self.__repo, filename=self.__tags_filename, token=hf_token)
233
+ print("WD14: Download complete.")
234
+ except Exception as e:
235
+ print(f"WD14: Error downloading from Hugging Face: {e}")
236
+ # Decide how to handle this - maybe raise error or disable tagging?
237
+ # For now, we'll let it fail later if model is None
238
+ return # Cannot initialize
239
+
240
+ try:
241
+ self._model = InferenceSession(str(model_path))
242
+ self._tags = pd.read_csv(tags_path)
243
+ self.__initialized = True
244
+ print("WD14: Tagger initialized successfully.")
245
+ except Exception as e:
246
+ print(f"WD14: Error initializing ONNX session or reading tags: {e}")
247
+
248
+ def _calculation(self, image: Image.Image) -> pd.DataFrame | None:
249
+ self._init()
250
+ if not self._model or self._tags is None:
251
+ print("WD14: Tagger not initialized.")
252
+ return None
253
+
254
+ _, height, _, _ = self._model.get_inputs()[0].shape
255
+
256
+ image = image.convert('RGBA')
257
+ new_image = Image.new('RGBA', image.size, 'WHITE')
258
+ new_image.paste(image, mask=image)
259
+ image = new_image.convert('RGB')
260
+ image.save("image_to_wd.png")
261
+ image = np.asarray(image)
262
+ image = image[:, :, ::-1]
263
+
264
+ image = make_square(image, height)
265
+ image = smart_resize(image, height)
266
+ image = image.astype(np.float32)
267
+ image = np.expand_dims(image, 0)
268
+
269
+ input_name = self._model.get_inputs()[0].name
270
+ label_name = self._model.get_outputs()[0].name
271
+ confidence = self._model.run([label_name], {input_name: image})[0]
272
+
273
+ full_tags = self._tags[['name', 'category']].copy()
274
+ full_tags['confidence'] = confidence[0]
275
+
276
+ return full_tags
277
+
278
+ def interrogate(self, image: Image.Image) -> Tuple[Dict[str, float], Dict[str, float]] | None:
279
+ full_tags = self._calculation(image)
280
+ if full_tags is None:
281
+ return None
282
+
283
+ ratings = dict(full_tags[full_tags['category'] == 9][['name', 'confidence']].values)
284
+ tags = dict(full_tags[full_tags['category'] != 9][['name', 'confidence']].values)
285
+
286
+ return ratings, tags
287
+
288
+ # --- Instantiate WD14 Tagger ---
289
+ WD14_CONVNEXT_REPO = 'SmilingWolf/wd-v1-4-convnext-tagger'
290
+ wd14_local_path = os.environ.get("WD14_CONVNEXT_PATH")
291
+ wd14_tagger = WaifuDiffusionInterrogator(repo=WD14_CONVNEXT_REPO, local_model_dir=wd14_local_path)
292
+
293
+ # --- Helper to format tags ---
294
+ def format_wd14_tags(tags: Dict[str, float], threshold: float = 0.35) -> str:
295
+ filtered_tags = {
296
+ tag: score for tag, score in tags.items()
297
+ if score >= threshold and "background" not in tag and tag not in {"monochrome", "greyscale", "no_humans", "comic", "solo"}
298
+ }
299
+ print(filtered_tags)
300
+ # Sort by score descending, then alphabetically
301
+ tags_pairs = sorted(filtered_tags.items(), key=lambda x: (-x[1], x[0]))
302
+ text_items = [tag.replace('_', ' ') for tag, score in tags_pairs]
303
+ return ', '.join(text_items)
304
+
305
  # Apply decorator conditionally
306
  @spaces.GPU() if ENABLE_ZEROGPU else lambda func: func
307
+ def generate_3d(scribble_image_dict, prompt, scribble_confidence, text_confidence, seed):
308
  print("Generating 3D model...")
309
+ input_prompt = prompt # Keep track of original prompt for return on early exit
310
  if scribble_image_dict is None or scribble_image_dict.get("composite") is None:
311
  print("No scribble image provided.")
312
+ return None, input_prompt # Return None for model, original prompt
313
+
314
+ # --- Prompt Handling ---
315
+ input_prompt = prompt.strip()
316
+ if not input_prompt:
317
+ print("Prompt is empty, attempting WD14 tagging...")
318
+ try:
319
+ # Get the user drawing (black on white) for tagging
320
+ user_drawing_img = Image.fromarray(scribble_image_dict["composite"]).convert("RGB")
321
+ tag_results = wd14_tagger.interrogate(user_drawing_img)
322
+ if tag_results:
323
+ ratings, tags = tag_results
324
+ generated_prompt = format_wd14_tags(tags) # Use default threshold
325
+ if generated_prompt:
326
+ print(f"WD14 generated prompt: {generated_prompt}")
327
+ input_prompt = generated_prompt
328
+ else:
329
+ print("WD14 tagging did not produce tags above threshold.")
330
+ input_prompt = "object" # Fallback prompt
331
+ else:
332
+ print("WD14 tagging failed or tagger not initialized.")
333
+ input_prompt = "object" # Fallback prompt
334
+ except Exception as e:
335
+ print(f"Error during WD14 tagging: {e}")
336
+ input_prompt = "object" # Fallback prompt
337
+ else:
338
+ print(f"Using user provided prompt: {input_prompt}")
339
+ # --- End Prompt Handling ---
340
 
341
  # --- Seed Handling ---
342
  current_seed = int(seed)
343
  print(f"Using seed: {current_seed}")
344
  # --- End Seed Handling ---
345
 
346
+ # --- Image Preprocessing for TripoSG ---
347
+ # Get the composite image again (safer in case dict is modified)
348
  # The composite might be RGBA if a layer was involved, ensure RGB for processing
349
+ image_for_triposg = Image.fromarray(scribble_image_dict["composite"]).convert("RGB")
 
350
  # Preprocess the image: invert colors (black on white -> white on black)
351
+ image_np = np.array(image_for_triposg)
352
  processed_image_np = 255 - image_np
353
  processed_image = Image.fromarray(processed_image_np)
354
+ print("Image preprocessed for TripoSG.")
355
+ # --- End Image Preprocessing ---
356
 
357
+ # --- Generator Setup ---
 
 
 
358
  generator = torch.Generator(device='cuda').manual_seed(current_seed)
359
+ # --- End Generator Setup ---
360
 
361
+ # --- Run Pipeline ---
362
  print("Running pipeline...")
363
+ try:
364
+ out = pipe(
365
+ processed_image,
366
+ prompt=input_prompt, # Use the potentially generated prompt
367
+ num_tokens=512, # Default value from example
368
+ guidance_scale=0, # Default value from example
369
+ num_inference_steps=16, # Default value from example
370
+ attention_kwargs={
371
+ "cross_attention_scale": text_confidence,
372
+ "cross_attention_2_scale": scribble_confidence
373
+ },
374
+ generator=generator,
375
+ use_flash_decoder=False, # Default value from example
376
+ dense_octree_depth=8, # Default value from example
377
+ hierarchical_octree_depth=8 # Default value from example
378
+ )
379
+ print("Pipeline finished.")
380
+ except Exception as e:
381
+ print(f"Error during pipeline execution: {e}")
382
+ return None, input_prompt # Return None for model, the prompt used
383
+ # --- End Run Pipeline ---
384
 
385
+ # --- Save Output ---
386
  if out.meshes and len(out.meshes) > 0:
387
  # Create a temporary file with .glb extension
388
  with tempfile.NamedTemporaryFile(suffix=".glb", delete=False) as tmpfile:
389
  output_path = tmpfile.name
390
  out.meshes[0].export(output_path)
391
  print(f"Mesh saved to temporary file: {output_path}")
392
+ return output_path, input_prompt # Return model path and the prompt used
393
  else:
394
  print("Pipeline did not generate any meshes.")
395
+ return None, input_prompt # Return None for model, the prompt used
396
+ # --- End Save Output ---
397
 
398
  # Create the Gradio interface
399
  with gr.Blocks() as demo:
 
424
 
425
  submit_button.click(
426
  fn=generate_3d,
427
+ inputs=gen_inputs,
428
+ outputs=[model_output, prompt_input] # Add prompt_input to outputs
429
  )
430
 
431
  # Define inputs for the lucky button (same as main button for the final call)
432
  lucky_gen_inputs = [image_input, prompt_input, confidence_input, prompt_confidence_input, seed_input] # Added text_confidence_input
433
 
434
  lucky_button.click(
435
+ fn=get_random_seed,
436
  inputs=[],
437
+ outputs=[seed_input]
438
  ).then(
439
+ fn=generate_3d,
440
+ inputs=lucky_gen_inputs,
441
+ outputs=[model_output, prompt_input] # Add prompt_input to outputs
442
  )
443
 
444
  # Launch with queue enabled if using ZeroGPU
requirements.txt CHANGED
@@ -14,3 +14,4 @@ typeguard
14
  ninja
15
  gltflib
16
  https://huggingface.co/spaces/VAST-AI/TripoSG/resolve/main/diso-0.1.4-cp310-cp310-linux_x86_64.whl?download=true
 
 
14
  ninja
15
  gltflib
16
  https://huggingface.co/spaces/VAST-AI/TripoSG/resolve/main/diso-0.1.4-cp310-cp310-linux_x86_64.whl?download=true
17
+ onnxruntime