Sartc committed on
Commit
83df70c
·
verified ·
1 Parent(s): 2b7ede3

gradio application

Browse files
Files changed (1) hide show
  1. app.py +94 -0
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+ import torch
3
+ from safetensors import safe_open
4
+ from huggingface_hub import hf_hub_download
5
+ from transformers import GPT2TokenizerFast
6
+ from model import Config, GPT
7
+ import torch.nn as nn
8
+ import gradio as gr
9
+
10
# Shared model hyperparameters (context_len, etc.) used by load_model() and generate().
config = Config()
11
+
12
def load_safetensors(path):
    """Read every tensor from a .safetensors file into a plain state dict.

    Args:
        path: Filesystem path to the checkpoint.

    Returns:
        dict mapping parameter names to torch tensors.
    """
    with safe_open(path, framework="pt") as checkpoint:
        return {name: checkpoint.get_tensor(name) for name in checkpoint.keys()}
18
+
19
def load_local(path):
    """Load model weights from a local .safetensors checkpoint."""
    return load_safetensors(path)
21
+
22
def load_from_hf(repo_id):
    """Download the storyGPT checkpoint from the Hugging Face Hub and load it.

    Args:
        repo_id: Hub repository id, e.g. "sartc/storyGPT".

    Returns:
        State dict read from the downloaded .safetensors file.
    """
    checkpoint_path = hf_hub_download(repo_id=repo_id, filename="storyGPT.safetensors")
    return load_safetensors(checkpoint_path)
28
+
29
def load_model(repo_id=None, local_file=None):
    """Build a GPT model and populate it with pretrained weights.

    Callers previously passed ``False`` as a null sentinel; ``None``
    defaults make the contract explicit while staying backward-compatible
    (both are falsy, and positional calls still work).

    Args:
        repo_id: Hugging Face repo to download the checkpoint from.
            Takes precedence when both sources are given.
        local_file: Path to a local .safetensors checkpoint.

    Returns:
        A GPT instance in eval mode.

    Raises:
        ValueError: If neither source is provided.
    """
    if repo_id:
        state_dict = load_from_hf(repo_id)
    elif local_file:
        state_dict = load_local(local_file)
    else:
        raise ValueError("Must provide either repo_id or local_file")

    model = GPT(config)
    model.load_state_dict(state_dict)
    model.eval()  # inference only: disable dropout etc.
    return model
41
+
42
def generate(model, prompt, max_tokens, temperature=0.7, context_len=None):
    """Autoregressively sample up to ``max_tokens`` new tokens.

    Bug fix: the original truncated the sequence to its FIRST
    ``context_len`` tokens each step (``prompt[:, :context_len]``), so
    once the window filled up, every newly sampled token was discarded
    on the next iteration. We now keep the full sequence and feed the
    model only the trailing window.

    Args:
        model: Callable mapping (batch, seq) token ids to
            (batch, seq, vocab) logits.
        prompt: LongTensor of shape (batch, seq) with the seed tokens.
        max_tokens: Number of tokens to append.
        temperature: Softmax temperature; must be > 0.
        context_len: Model context window; defaults to config.context_len.

    Returns:
        LongTensor of shape (batch, seq + max_tokens): the prompt with
        the sampled tokens appended.
    """
    if context_len is None:
        context_len = config.context_len
    for _ in range(max_tokens):
        window = prompt[:, -context_len:]  # model sees at most context_len tokens
        logits = model(window)[:, -1, :] / temperature
        probs = nn.functional.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        prompt = torch.cat((prompt, next_token), dim=1)
    return prompt
51
+
52
def run(prompt):
    """Gradio handler: generate a story continuation for ``prompt``.

    Relies on the module-level ``gpt`` model initialised in the
    ``__main__`` block.

    Args:
        prompt: User-entered seed text; the literal "bye" (any case)
            short-circuits generation.

    Returns:
        The decoded generated text (or a farewell string).
    """
    if prompt.lower() == "bye":
        # Return (not print) so the farewell appears in the Gradio output
        # box instead of only on the server console.
        return "Bye!"

    # NOTE(review): loading the tokenizer on every request is slow but
    # preserved here; consider hoisting to module scope.
    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    inputs = tokenizer.encode(prompt, return_tensors='pt')

    with torch.no_grad():  # inference only; skip autograd bookkeeping
        generated = generate(gpt, inputs,
                             max_tokens=config.context_len,
                             temperature=0.7)

    return tokenizer.decode(generated[0].cpu().numpy())
69
+
70
def create_interface():
    """Build and return the Gradio UI wrapping ``run``."""
    return gr.Interface(
        fn=run,
        inputs=gr.Textbox(label="Enter your prompt"),
        outputs=gr.Textbox(label="Generated Text"),
        title="GPT Text Generator",
        description="Generate text using the trained GPT model",
    )
79
+
80
if __name__ == "__main__":

    file_path = "storyGPT.safetensors"

    # Prefer a local checkpoint; fall back to downloading from the Hub.
    # None (not False) is the idiomatic null sentinel for load_model.
    if os.path.exists(file_path):
        gpt = load_model(None, file_path)
    else:
        gpt = load_model("sartc/storyGPT", None)

    interface = create_interface()
    interface.launch()
94
+