multimodalart (HF Staff) committed
Commit d07e660 (verified) · Parent(s): badae07

Update app.py

Files changed (1)
  1. app.py +19 -6
app.py CHANGED
@@ -399,14 +399,26 @@ def generate_dream_response(
 
             # Normalize probabilities if they don't sum to 1
             prob_sum = conf_probs.sum()
-            if not torch.isclose(prob_sum, torch.tensor(1.0, device=device), atol=1e-4) and prob_sum > 0:
+            # --- START FIX ---
+            # Ensure the comparison tensor has the same dtype as prob_sum
+            target_sum_tensor = torch.tensor(1.0, device=device, dtype=prob_sum.dtype)
+            if not torch.isclose(prob_sum, target_sum_tensor, atol=1e-4) and prob_sum > 0:
+            # --- END FIX ---
                 # print(f"Warning step {i}: Confidence probabilities sum {prob_sum:.4f} != 1. Re-normalizing.")
-                conf_probs = conf_probs / prob_sum
-
-            if conf_probs.numel() > 0 and num_samples > 0 and torch.all(conf_probs >= 0) and torch.isclose(conf_probs.sum(), torch.tensor(1.0, device=device)):
+                # Avoid division by zero if prob_sum is extremely small or zero
+                safe_prob_sum = torch.max(prob_sum, torch.tensor(1e-12, device=device, dtype=prob_sum.dtype))
+                conf_probs = conf_probs / safe_prob_sum # Use safe_prob_sum
+
+            # Ensure num_samples is valid and probabilities are okay for multinomial
+            # --- START FIX ---
+            # Check sum again after potential normalization
+            final_prob_sum_check = conf_probs.sum()
+            if conf_probs.numel() > 0 and num_samples > 0 and torch.all(conf_probs >= 0) and torch.isclose(final_prob_sum_check, target_sum_tensor, atol=1e-4):
+            # --- END FIX ---
                 try:
                     transfer_indices_relative = torch.multinomial(conf_probs, num_samples=num_samples, replacement=False)
                 except RuntimeError as e:
+                    # [Fallback logic remains the same]
                     print(f"Warning step {i}: Multinomial sampling failed ('{e}'). Falling back to top-k.")
                     sort_metric = confidence if alg != 'entropy' else -confidence
                     k_multinomial_fallback = min(num_samples, sort_metric.numel())
@@ -414,8 +426,9 @@ def generate_dream_response(
                         _, transfer_indices_relative = torch.topk(sort_metric, k=k_multinomial_fallback)
                     else:
                         transfer_indices_relative = torch.tensor([], dtype=torch.long, device=device)
-            else: # Handle cases where multinomial is not possible
-                # print(f"Warning step {i}: Invalid probabilities for multinomial sampling. Falling back to top-k.")
+            else: # Handle cases where multinomial is not possible (e.g., bad probabilities)
+                # [Fallback logic remains the same]
+                # print(f"Warning step {i}: Invalid probabilities for multinomial sampling (sum={final_prob_sum_check:.4f}). Falling back to top-k.")
                 sort_metric = confidence if alg != 'entropy' else -confidence
                 k_multinomial_fallback = min(num_samples, sort_metric.numel())
                 if k_multinomial_fallback > 0:
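
The change boils down to two guards around the multinomial step: the 1.0 comparison tensor is now built with prob_sum's own dtype (relevant when the model runs in float16/bfloat16, where a bare torch.tensor(1.0) defaults to float32 and torch.isclose can fail on the mismatch in some PyTorch versions), and the re-normalization divides by a floored sum rather than the raw one. The sketch below restates that pattern as a standalone helper for illustration only: sample_transfer_indices is a hypothetical name, and its fallback ranks conf_probs directly instead of app.py's separate confidence/entropy metric, purely to keep the example self-contained.

import torch

def sample_transfer_indices(conf_probs: torch.Tensor, num_samples: int, atol: float = 1e-4) -> torch.Tensor:
    """Hypothetical helper mirroring the fixed sampling path above (not part of app.py)."""
    device = conf_probs.device
    prob_sum = conf_probs.sum()
    # Comparison tensor shares prob_sum's dtype, so torch.isclose never sees
    # mismatched dtypes (e.g. a float16 sum vs. a default float32 constant).
    target = torch.tensor(1.0, device=device, dtype=prob_sum.dtype)

    if not torch.isclose(prob_sum, target, atol=atol) and prob_sum > 0:
        # Floor the denominator so the division stays safe even for a tiny sum.
        safe_sum = torch.max(prob_sum, torch.tensor(1e-12, device=device, dtype=prob_sum.dtype))
        conf_probs = conf_probs / safe_sum

    # Only sample when the (re-normalized) distribution is actually usable.
    if (conf_probs.numel() > 0 and num_samples > 0
            and torch.all(conf_probs >= 0)
            and torch.isclose(conf_probs.sum(), target, atol=atol)):
        try:
            return torch.multinomial(conf_probs, num_samples=num_samples, replacement=False)
        except RuntimeError:
            pass  # fall through to the deterministic fallback

    # Deterministic fallback: here it simply ranks conf_probs (app.py ranks a
    # separate confidence/entropy metric instead).
    k = min(num_samples, conf_probs.numel())
    if k > 0:
        return torch.topk(conf_probs, k=k).indices
    return torch.tensor([], dtype=torch.long, device=device)

# Example (float32 for portability; with a half-precision model conf_probs would
# be float16, which is exactly where the dtype-matched comparison matters):
probs = torch.softmax(torch.randn(8), dim=-1)
print(sample_transfer_indices(probs, num_samples=3))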