multimodalart HF Staff commited on
Commit
7279ae9
·
verified ·
1 Parent(s): 418ce6a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -309,7 +309,7 @@ def generate_dream_response(
309
  if num_samples > 0:
310
  transfer_indices_relative = torch.tensor([], dtype=torch.long, device=device) # Init empty
311
  if alg_temp_val is None or alg_temp_val <= 0: # Top-k
312
- sort_metric = confidence if alg != 'entropy' else -confidence
313
  k_topk = min(num_samples, sort_metric.numel())
314
  if k_topk > 0: _, transfer_indices_relative = torch.topk(sort_metric, k=k_topk)
315
  else: # Sample based on temp
@@ -330,12 +330,12 @@ def generate_dream_response(
330
  try: transfer_indices_relative = torch.multinomial(conf_probs, num_samples=num_samples, replacement=False)
331
  except RuntimeError as e:
332
  print(f"Warning step {i}: Multinomial sampling failed ('{e}'). Falling back to top-k.")
333
- sort_metric = confidence if alg != 'entropy' else -confidence
334
  k_multinomial_fallback = min(num_samples, sort_metric.numel())
335
  if k_multinomial_fallback > 0: _, transfer_indices_relative = torch.topk(sort_metric, k=k_multinomial_fallback)
336
  else: # Fallback if probs invalid for multinomial
337
  # print(f"Warning step {i}: Invalid probabilities for multinomial sampling (sum={final_prob_sum_check:.4f}). Falling back to top-k.")
338
- sort_metric = confidence if alg != 'entropy' else -confidence
339
  k_multinomial_fallback = min(num_samples, sort_metric.numel())
340
  if k_multinomial_fallback > 0: _, transfer_indices_relative = torch.topk(sort_metric, k=k_multinomial_fallback)
341
 
 
309
  if num_samples > 0:
310
  transfer_indices_relative = torch.tensor([], dtype=torch.long, device=device) # Init empty
311
  if alg_temp_val is None or alg_temp_val <= 0: # Top-k
312
+ sort_metric = confidence
313
  k_topk = min(num_samples, sort_metric.numel())
314
  if k_topk > 0: _, transfer_indices_relative = torch.topk(sort_metric, k=k_topk)
315
  else: # Sample based on temp
 
330
  try: transfer_indices_relative = torch.multinomial(conf_probs, num_samples=num_samples, replacement=False)
331
  except RuntimeError as e:
332
  print(f"Warning step {i}: Multinomial sampling failed ('{e}'). Falling back to top-k.")
333
+ sort_metric = confidence
334
  k_multinomial_fallback = min(num_samples, sort_metric.numel())
335
  if k_multinomial_fallback > 0: _, transfer_indices_relative = torch.topk(sort_metric, k=k_multinomial_fallback)
336
  else: # Fallback if probs invalid for multinomial
337
  # print(f"Warning step {i}: Invalid probabilities for multinomial sampling (sum={final_prob_sum_check:.4f}). Falling back to top-k.")
338
+ sort_metric = confidence
339
  k_multinomial_fallback = min(num_samples, sort_metric.numel())
340
  if k_multinomial_fallback > 0: _, transfer_indices_relative = torch.topk(sort_metric, k=k_multinomial_fallback)
341