pbansal commited on
Commit
96de902
·
verified ·
1 Parent(s): 27d5f63

Update generation_utils.py

Browse files
Files changed (1) hide show
  1. generation_utils.py +2 -2
generation_utils.py CHANGED
@@ -425,7 +425,7 @@ class DreamGenerationMixin:
425
  t = timesteps[i]
426
  s = timesteps[i + 1]
427
 
428
- logits[:,pad_token_id] += eos_penalty * torch.log(1-t+eps)
429
 
430
  num_mask_token = mask_index.sum() / mask_index.shape[0]
431
  number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)
@@ -454,7 +454,7 @@ class DreamGenerationMixin:
454
  hidden_states_, logits_ = self.forward_pass_drafter(x[active_indices], hidden_states[active_indices], num_transfered_)
455
  hidden_states[active_indices] = hidden_states_
456
  logits[active_indices] = logits_
457
- logits[:,pad_token_id] += eos_penalty * torch.log(1-t+eps)
458
 
459
  # this allows user-defined token control of the intermediate steps
460
  x = generation_tokens_hook_func(i, x, logits)
 
425
  t = timesteps[i]
426
  s = timesteps[i + 1]
427
 
428
+ logits[:,:,pad_token_id] += eos_penalty * torch.log(1-t+eps)
429
 
430
  num_mask_token = mask_index.sum() / mask_index.shape[0]
431
  number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)
 
454
  hidden_states_, logits_ = self.forward_pass_drafter(x[active_indices], hidden_states[active_indices], num_transfered_)
455
  hidden_states[active_indices] = hidden_states_
456
  logits[active_indices] = logits_
457
+ logits[:,:,pad_token_id] += eos_penalty * torch.log(1-t+eps)
458
 
459
  # this allows user-defined token control of the intermediate steps
460
  x = generation_tokens_hook_func(i, x, logits)