mkthoma committed
Commit f442c21 · 1 Parent(s): 484eb5f

app update

Files changed (1)
  1. app.py +35 -11
app.py CHANGED
@@ -185,27 +185,51 @@ class BigramLanguageModel(nn.Module):
         return idx
 
 
-# Load the model
-loaded_model = BigramLanguageModel().to(device)  # Initialize an instance of your model
-loaded_model.load_state_dict(torch.load('bigram_language_model.pth', map_location=torch.device('cpu')))
-loaded_model.eval()  # Set the model to evaluation mode
+# Load the Shakespeare model
+shakespeare_model = BigramLanguageModel().to(device)  # Initialize an instance of your model
+shakespeare_model.load_state_dict(torch.load('shakespeaere_language_model.pth', map_location=torch.device('cpu')))
+shakespeare_model.eval()  # Set the model to evaluation mode
 
-def generate_gpt_outputs(prompt=None, max_new_tokens=2000):
+# Load the Wikipedia model
+wikipedia_model = BigramLanguageModel().to(device)  # Initialize an instance of your model
+wikipedia_model.load_state_dict(torch.load('wikipedia_language_model.pth', map_location=torch.device('cpu')))
+wikipedia_model.eval()  # Set the model to evaluation mode
+
+
+def generate_shakespeare_outputs(prompt=None, max_new_tokens=2000):
     if prompt:
         context = torch.tensor(encode(prompt), dtype=torch.long, device=device).view(1, -1)
     else:
         context = torch.zeros((1, 1), dtype=torch.long, device=device)
-    text_output = decode(loaded_model.generate(context, max_new_tokens=max_new_tokens)[0].tolist())
+    text_output = decode(shakespeare_model.generate(context, max_new_tokens=max_new_tokens)[0].tolist())
     return text_output
 
+
+def generate_wikipedia_outputs(prompt=None, max_new_tokens=2000):
+    if prompt:
+        context = torch.tensor(encode(prompt), dtype=torch.long, device=device).view(1, -1)
+    else:
+        context = torch.zeros((1, 1), dtype=torch.long, device=device)
+    text_output = decode(wikipedia_model.generate(context, max_new_tokens=max_new_tokens)[0].tolist())
+    return text_output
+
 import gradio as gr
 
-title = "Mini GPT on Shakespeare text"
-description = "Generate an image with a prompt and apply vibrance loss if you wish to"
+title = "Nano GPT"
+description = "Nano GPT trained on Shakespeare and Wikipedia datasets. It is trained on a very small amount of data to illustrate how GPTs are trained and built. <a href='https://github.com/karpathy/nanoGPT'>The implementation can be found here</a>."
+
+shakespeare_interface = gr.Interface(generate_shakespeare_outputs,
+                                     inputs=[gr.Textbox(label="Enter any prompt ", type="text", value="Once upon a time,"),
+                                             gr.Slider(minimum=100, maximum=5000, step=100, value=2000, label="Max new tokens")],
+                                     outputs=gr.Textbox(label="Output generated", type="text"))
 
-demo = gr.Interface(generate_gpt_outputs,
+wiki_interface = gr.Interface(generate_wikipedia_outputs,
                     inputs=[gr.Textbox(label="Enter any prompt ", type="text", value="Once upon a time,"),
                             gr.Slider(minimum=100, maximum=5000, step=100, value=2000, label="Max new tokens")],
-                    outputs=gr.Textbox(label="Output generated", type="text"),
-                    title=title, description=description)
+                    outputs=gr.Textbox(label="Output generated", type="text"))
+
+demo = gr.TabbedInterface([shakespeare_interface, wiki_interface], tab_names=["Shakespeare Data", "Wikipedia Data"],
+                          title=title, description=description)
+
+
 demo.launch()
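
Note on the new gr.TabbedInterface call: some Gradio releases document only interface_list, tab_names, and title for TabbedInterface, not a description keyword, so the call above may need the description moved onto the individual gr.Interface objects. The following is a minimal, self-contained sketch of the same tabbed layout under that assumption, with hypothetical placeholder generators standing in for generate_shakespeare_outputs and generate_wikipedia_outputs from app.py:

import gradio as gr

# Placeholder generators (assumption: same signature as the app.py functions).
def fake_shakespeare(prompt=None, max_new_tokens=2000):
    return f"[Shakespeare sample for prompt={prompt!r}, max_new_tokens={max_new_tokens}]"

def fake_wikipedia(prompt=None, max_new_tokens=2000):
    return f"[Wikipedia sample for prompt={prompt!r}, max_new_tokens={max_new_tokens}]"

title = "Nano GPT"
description = "Nano GPT trained on Shakespeare and Wikipedia datasets."

shakespeare_tab = gr.Interface(
    fake_shakespeare,
    inputs=[gr.Textbox(label="Enter any prompt", value="Once upon a time,"),
            gr.Slider(minimum=100, maximum=5000, step=100, value=2000, label="Max new tokens")],
    outputs=gr.Textbox(label="Output generated"),
    description=description,  # gr.Interface accepts a description
)

wikipedia_tab = gr.Interface(
    fake_wikipedia,
    inputs=[gr.Textbox(label="Enter any prompt", value="Once upon a time,"),
            gr.Slider(minimum=100, maximum=5000, step=100, value=2000, label="Max new tokens")],
    outputs=gr.Textbox(label="Output generated"),
    description=description,
)

# Only tab_names and title are passed to TabbedInterface here.
demo = gr.TabbedInterface([shakespeare_tab, wikipedia_tab],
                          tab_names=["Shakespeare Data", "Wikipedia Data"],
                          title=title)

if __name__ == "__main__":
    demo.launch()

Swapping the placeholders for the real generate_shakespeare_outputs and generate_wikipedia_outputs from app.py recovers the behaviour intended by this commit.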