Chaitanya Sagar Gurujula commited on
Commit
0e19c73
·
1 Parent(s): 8275a34

fixed generate method

Browse files
Files changed (1) hide show
  1. src/model.py +4 -1
src/model.py CHANGED
@@ -197,7 +197,10 @@ class GPT(nn.Module):
197
  def generate(self, input_ids, max_length=50,eos_token_id=None):
198
  generated_tokens = []
199
  current_ids = input_ids
200
-
 
 
 
201
  for _ in range(max_length):
202
  # Forward pass to get logits
203
  logits = self.forward(current_ids) # Shape: (batch_size, seq_len, vocab_size)
 
197
  def generate(self, input_ids, max_length=50,eos_token_id=None):
198
  generated_tokens = []
199
  current_ids = input_ids
200
+
201
+ # 🔥 Infer device from input_ids
202
+ device = input_ids.device
203
+
204
  for _ in range(max_length):
205
  # Forward pass to get logits
206
  logits = self.forward(current_ids) # Shape: (batch_size, seq_len, vocab_size)