Chaitanya Sagar Gurujula
commited on
Commit
·
0e19c73
1
Parent(s):
8275a34
fixed generate method
Browse files- src/model.py +4 -1
src/model.py
CHANGED
@@ -197,7 +197,10 @@ class GPT(nn.Module):
|
|
197 |
def generate(self, input_ids, max_length=50,eos_token_id=None):
|
198 |
generated_tokens = []
|
199 |
current_ids = input_ids
|
200 |
-
|
|
|
|
|
|
|
201 |
for _ in range(max_length):
|
202 |
# Forward pass to get logits
|
203 |
logits = self.forward(current_ids) # Shape: (batch_size, seq_len, vocab_size)
|
|
|
197 |
def generate(self, input_ids, max_length=50,eos_token_id=None):
|
198 |
generated_tokens = []
|
199 |
current_ids = input_ids
|
200 |
+
|
201 |
+
# 🔥 Infer device from input_ids
|
202 |
+
device = input_ids.device
|
203 |
+
|
204 |
for _ in range(max_length):
|
205 |
# Forward pass to get logits
|
206 |
logits = self.forward(current_ids) # Shape: (batch_size, seq_len, vocab_size)
|