Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -309,7 +309,7 @@ def generate_dream_response(
|
|
309 |
if num_samples > 0:
|
310 |
transfer_indices_relative = torch.tensor([], dtype=torch.long, device=device) # Init empty
|
311 |
if alg_temp_val is None or alg_temp_val <= 0: # Top-k
|
312 |
-
sort_metric = confidence
|
313 |
k_topk = min(num_samples, sort_metric.numel())
|
314 |
if k_topk > 0: _, transfer_indices_relative = torch.topk(sort_metric, k=k_topk)
|
315 |
else: # Sample based on temp
|
@@ -330,12 +330,12 @@ def generate_dream_response(
|
|
330 |
try: transfer_indices_relative = torch.multinomial(conf_probs, num_samples=num_samples, replacement=False)
|
331 |
except RuntimeError as e:
|
332 |
print(f"Warning step {i}: Multinomial sampling failed ('{e}'). Falling back to top-k.")
|
333 |
-
sort_metric = confidence
|
334 |
k_multinomial_fallback = min(num_samples, sort_metric.numel())
|
335 |
if k_multinomial_fallback > 0: _, transfer_indices_relative = torch.topk(sort_metric, k=k_multinomial_fallback)
|
336 |
else: # Fallback if probs invalid for multinomial
|
337 |
# print(f"Warning step {i}: Invalid probabilities for multinomial sampling (sum={final_prob_sum_check:.4f}). Falling back to top-k.")
|
338 |
-
sort_metric = confidence
|
339 |
k_multinomial_fallback = min(num_samples, sort_metric.numel())
|
340 |
if k_multinomial_fallback > 0: _, transfer_indices_relative = torch.topk(sort_metric, k=k_multinomial_fallback)
|
341 |
|
|
|
309 |
if num_samples > 0:
|
310 |
transfer_indices_relative = torch.tensor([], dtype=torch.long, device=device) # Init empty
|
311 |
if alg_temp_val is None or alg_temp_val <= 0: # Top-k
|
312 |
+
sort_metric = confidence
|
313 |
k_topk = min(num_samples, sort_metric.numel())
|
314 |
if k_topk > 0: _, transfer_indices_relative = torch.topk(sort_metric, k=k_topk)
|
315 |
else: # Sample based on temp
|
|
|
330 |
try: transfer_indices_relative = torch.multinomial(conf_probs, num_samples=num_samples, replacement=False)
|
331 |
except RuntimeError as e:
|
332 |
print(f"Warning step {i}: Multinomial sampling failed ('{e}'). Falling back to top-k.")
|
333 |
+
sort_metric = confidence
|
334 |
k_multinomial_fallback = min(num_samples, sort_metric.numel())
|
335 |
if k_multinomial_fallback > 0: _, transfer_indices_relative = torch.topk(sort_metric, k=k_multinomial_fallback)
|
336 |
else: # Fallback if probs invalid for multinomial
|
337 |
# print(f"Warning step {i}: Invalid probabilities for multinomial sampling (sum={final_prob_sum_check:.4f}). Falling back to top-k.")
|
338 |
+
sort_metric = confidence
|
339 |
k_multinomial_fallback = min(num_samples, sort_metric.numel())
|
340 |
if k_multinomial_fallback > 0: _, transfer_indices_relative = torch.topk(sort_metric, k=k_multinomial_fallback)
|
341 |
|