skylersterling commited on
Commit
57caaab
·
verified ·
1 Parent(s): ad4d3e1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -7
app.py CHANGED
@@ -18,7 +18,7 @@ def generate_text(prompt):
18
  input_tokens = tokenizer.encode(prompt, return_tensors='pt')
19
  input_tokens = input_tokens.to('cpu')
20
 
21
- generated_tokens = []
22
 
23
  for _ in range(80): # Adjust the range to control the number of tokens generated
24
  with torch.no_grad():
@@ -26,15 +26,11 @@ def generate_text(prompt):
26
  predictions = outputs.logits
27
  next_token = torch.multinomial(torch.softmax(predictions[:, -1, :], dim=-1), 1)
28
 
29
- generated_tokens.append(next_token.item())
30
  input_tokens = torch.cat((input_tokens, next_token), dim=1)
31
 
32
  decoded_token = tokenizer.decode(next_token.item())
33
- yield decoded_token # Yield each token as it is generated
34
-
35
- # Decode the entire generated text
36
- generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
37
- yield generated_text
38
 
39
  # Create a Gradio interface with a text input and a text output
40
  interface = gr.Interface(fn=generate_text, inputs='text', outputs='text', live=True)
 
18
  input_tokens = tokenizer.encode(prompt, return_tensors='pt')
19
  input_tokens = input_tokens.to('cpu')
20
 
21
+ generated_text = prompt # Start with the initial prompt
22
 
23
  for _ in range(80): # Adjust the range to control the number of tokens generated
24
  with torch.no_grad():
 
26
  predictions = outputs.logits
27
  next_token = torch.multinomial(torch.softmax(predictions[:, -1, :], dim=-1), 1)
28
 
 
29
  input_tokens = torch.cat((input_tokens, next_token), dim=1)
30
 
31
  decoded_token = tokenizer.decode(next_token.item())
32
+ generated_text += decoded_token # Append the new token to the generated text
33
+ yield generated_text # Yield the entire generated text so far
 
 
 
34
 
35
  # Create a Gradio interface with a text input and a text output
36
  interface = gr.Interface(fn=generate_text, inputs='text', outputs='text', live=True)