Gbssreejith committed
Commit b3a3658 · verified · 1 Parent(s): 28631a4

Delete app.py.py

Files changed (1)
  1. app.py.py +0 -484
app.py.py DELETED
@@ -1,484 +0,0 @@
import copy
import datetime
import math
import os
import random
import re
import time

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

# Several of the imports below (datasets, optimiser, scheduler, nltk) are
# leftovers from the training script and are not used by this inference app.
from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler, SequentialSampler
from sklearn.model_selection import train_test_split
from transformers import GPT2Tokenizer, GPT2TokenizerFast
from transformers import AdamW, get_linear_schedule_with_warmup
from tqdm import trange

import gradio as gr
import nltk

from load_weights import load_weight

nltk.download('punkt')
torch.manual_seed(42)
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

def gelu(x):
    # Tanh approximation of the GELU activation, as used in the original GPT-2.
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

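# Optional sanity check (a sketch, assuming a PyTorch version that supports
# F.gelu's 'tanh' variant); uncomment to compare against the built-in:
#
#   x = torch.randn(4)
#   assert torch.allclose(gelu(x), F.gelu(x, approximate='tanh'), atol=1e-6)
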
class Conv1D(nn.Module):
    """GPT-2's "Conv1D": a linear layer whose weight is stored transposed,
    applied over the last dimension of the input."""
    def __init__(self, nf, nx):
        super(Conv1D, self).__init__()
        self.nf = nf
        w = torch.empty(nx, nf)
        nn.init.normal_(w, std=0.02)
        self.weight = Parameter(w)
        self.bias = Parameter(torch.zeros(nf))

    def forward(self, x):
        size_out = x.size()[:-1] + (self.nf,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        x = x.view(*size_out)
        return x

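# Shape sketch (hypothetical sizes): Conv1D(nf, nx) maps (..., nx) -> (..., nf).
#
#   proj = Conv1D(2304, 768)
#   out = proj(torch.randn(2, 10, 768))
#   out.shape   # torch.Size([2, 10, 2304])
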
class LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root)."""
        super(LayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias

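# Hedged equivalence check: with its default affine parameters this matches
# torch.nn.LayerNorm to within numerical tolerance.
#
#   x = torch.randn(2, 10, 768)
#   assert torch.allclose(LayerNorm(768, eps=1e-5)(x),
#                         nn.LayerNorm(768, eps=1e-5)(x), atol=1e-5)
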
class Attention(nn.Module):
    def __init__(self, nx, n_ctx, config, scale=False):
        super(Attention, self).__init__()
        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        # Lower-triangular causal mask, registered as a buffer so it moves with the module.
        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale
        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)

    def _attn(self, q, k, v):
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        nd, ns = w.size(-2), w.size(-1)
        b = self.bias[:, :, ns - nd:ns, :ns]
        # Keep visible positions, push masked ones to ~-1e10 so softmax zeroes them.
        w = w * b - 1e10 * (1 - b)
        w = F.softmax(w, dim=-1)
        return torch.matmul(w, v)

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
        else:
            return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def forward(self, x, layer_past=None):
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        if layer_past is not None:
            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back cf below
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)
        present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
        a = self._attn(query, key, value)
        a = self.merge_heads(a)
        a = self.c_proj(a)
        return a, present

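# The mask trick in _attn, on a minimal 3x3 example: with b in {0, 1},
# `w * b - 1e10 * (1 - b)` leaves visible scores untouched and drives masked
# ones to about -1e10, which softmax maps to (numerically) zero.
#
#   b = torch.tril(torch.ones(3, 3))
#   masked = torch.zeros(3, 3) * b - 1e10 * (1 - b)
#   F.softmax(masked, dim=-1)   # row i is uniform over positions 0..i
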
class MLP(nn.Module):
    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
        super(MLP, self).__init__()
        nx = config.n_embd
        self.c_fc = Conv1D(n_state, nx)
        self.c_proj = Conv1D(nx, n_state)
        self.act = gelu

    def forward(self, x):
        h = self.act(self.c_fc(x))
        h2 = self.c_proj(h)
        return h2

class Block(nn.Module):
    def __init__(self, n_ctx, config, scale=False):
        super(Block, self).__init__()
        nx = config.n_embd
        self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.attn = Attention(nx, n_ctx, config, scale)
        self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * nx, config)

    def forward(self, x, layer_past=None):
        # Pre-LayerNorm residual block: x + attn(ln_1(x)), then x + mlp(ln_2(x)).
        a, present = self.attn(self.ln_1(x), layer_past=layer_past)
        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m
        return x, present

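# Commented shape check for a single block (GPT2Config is defined further
# down in this file, so run this once the whole module has loaded):
#
#   cfg = GPT2Config()
#   blk = Block(cfg.n_ctx, cfg, scale=True)
#   y, present = blk(torch.randn(1, 8, cfg.n_embd))
#   y.shape   # torch.Size([1, 8, 768])
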
class GPT2Model(nn.Module):
    def __init__(self, config):
        super(GPT2Model, self).__init__()
        self.n_layer = config.n_layer
        self.n_embd = config.n_embd
        self.n_vocab = config.vocab_size

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)   # token embeddings
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)  # position embeddings
        block = Block(config.n_ctx, config, scale=True)
        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
        self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

    def set_embeddings_weights(self, model_embeddings_weights):
        # Unused here (GPT2LMHead ties the weights itself); kept for parity
        # with the original implementation.
        embed_shape = model_embeddings_weights.shape
        self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
        self.decoder.weight = model_embeddings_weights  # Tied weights

    def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None):
        if (input_ids >= self.n_vocab).any():
            raise ValueError(f"Invalid token ID found in input_ids: {input_ids}")

        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            past_length = past[0][0].size(-2)
        if position_ids is None:
            position_ids = torch.arange(past_length, input_ids.size(-1) + past_length,
                                        dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_ids.size(-1))
        position_ids = position_ids.view(-1, position_ids.size(-1))

        inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
            token_type_embeds = self.wte(token_type_ids)
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds

        presents = []
        for block, layer_past in zip(self.h, past):
            hidden_states, present = block(hidden_states, layer_past)
            presents.append(present)
        hidden_states = self.ln_f(hidden_states)
        output_shape = input_shape + (hidden_states.size(-1),)
        return hidden_states.view(*output_shape), presents

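# Minimal shape sketch (commented; GPT2Config is defined further down, so run
# this after the full module loads):
#
#   cfg = GPT2Config()
#   core = GPT2Model(cfg)
#   hs, presents = core(torch.randint(0, cfg.vocab_size, (1, 8)))
#   hs.shape        # torch.Size([1, 8, 768])
#   len(presents)   # 12, one stacked (key, value) tensor per block
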
class GPT2LMHead(nn.Module):
    def __init__(self, model_embeddings_weights, config):
        super(GPT2LMHead, self).__init__()
        self.n_embd = config.n_embd
        self.set_embeddings_weights(model_embeddings_weights)

    def set_embeddings_weights(self, model_embeddings_weights):
        embed_shape = model_embeddings_weights.shape
        self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
        self.decoder.weight = model_embeddings_weights  # Tied weights

    def forward(self, hidden_state):
        lm_logits = self.decoder(hidden_state)
        return lm_logits

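# Weight-tying sketch: the decoder shares storage with the token embedding,
# so logits are effectively hidden_state @ wte.weight.T with no new parameters.
#
#   cfg = GPT2Config()   # defined below
#   core = GPT2Model(cfg)
#   head = GPT2LMHead(core.wte.weight, cfg)
#   head.decoder.weight.data_ptr() == core.wte.weight.data_ptr()   # True
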
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.
    Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751).
    Args:
        logits: logits distribution of shape (batch size, vocabulary size)
        top_k > 0: keep only the top k tokens with highest probability (top-k filtering).
        top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
        filter_value: value to replace filtered logits with.
    """
    assert logits.dim() == 2  # batch size x vocabulary size
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Scatter the sorted mask back to the original vocabulary order; indexing
        # with sorted_indices directly would mix up rows across the batch.
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits

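# Worked example: keep only the two highest-scoring tokens per row.
#
#   scores = torch.tensor([[1.0, 3.0, 2.0, 0.5]])
#   top_k_top_p_filtering(scores.clone(), top_k=2)
#   # tensor([[-inf, 3., 2., -inf]])
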
class GPT2LMHeadModel(nn.Module):
    def __init__(self, config):
        super(GPT2LMHeadModel, self).__init__()
        self.transformer = GPT2Model(config)
        self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)

    def set_tied(self):
        """Make sure we are sharing the embeddings."""
        self.lm_head.set_embeddings_weights(self.transformer.wte.weight)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
        hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
        lm_logits = self.lm_head(hidden_states)

        outputs = (lm_logits, presents)

        if lm_labels is not None:
            # Shift so that tokens < n predict token n.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            outputs = (loss,) + outputs
        return outputs

    def generate(self, input_ids, max_length, temperature=1.0, top_k=0, top_p=0.9,
                 repetition_penalty=1.0, device='cuda', eos_token_id=50256):
        """Autoregressive sampling loop with a key/value cache ("past").
        eos_token_id defaults to 50256, GPT-2's <|endoftext|> token."""
        self.eval()
        input_ids = input_ids.to(device)
        batch_size = input_ids.shape[0]
        past = None

        generated = input_ids
        with torch.no_grad():
            for _ in range(max_length):
                outputs = self(input_ids, past=past)
                next_token_logits = outputs[0][:, -1, :]
                past = outputs[1]

                # Repetition penalty: dividing a negative logit would *raise* its
                # probability, so the sign is taken into account (CTRL-style).
                for i in range(batch_size):
                    for token_id in set(generated[i].tolist()):
                        if next_token_logits[i, token_id] < 0:
                            next_token_logits[i, token_id] *= repetition_penalty
                        else:
                            next_token_logits[i, token_id] /= repetition_penalty

                next_token_logits = next_token_logits / temperature
                filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
                generated = torch.cat((generated, next_token), dim=1)

                if (next_token == eos_token_id).all():
                    break

                input_ids = next_token

        return generated

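# Commented usage sketch (model, tokenizer, and device are created below):
#
#   ids = tokenizer.encode("Hello", return_tensors='pt')
#   out = model.generate(ids, max_length=20, top_k=50, top_p=0.95, device=device)
#   tokenizer.decode(out[0])
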

class GPT2Config(object):
    def __init__(
        self,
        vocab_size_or_config_json_file=50257,
        n_positions=1024,
        n_ctx=1024,
        n_embd=768,
        n_layer=12,
        n_head=12,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
    ):
        # Only the integer vocab-size form is supported here; the parameter name
        # is kept for parity with the original GPT-2 config interface.
        self.vocab_size = vocab_size_or_config_json_file
        self.n_ctx = n_ctx
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range

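# The defaults correspond to the 124M-parameter GPT-2 ("small"). As a sketch,
# the 355M "medium" variant would be:
#
#   config_medium = GPT2Config(n_embd=1024, n_layer=24, n_head=16)
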
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = GPT2Config()
model = GPT2LMHeadModel(config)
# Load the fine-tuned checkpoint onto the CPU first; the model is moved to
# `device` afterwards.
state_dict = torch.load(r'C:\vision_model\gpt-2-Pytorch\test\gpt_today\weights\epoch_1.pth', map_location='cpu')
model = load_weight(model, state_dict)
model.to(device)
print(model)
model.eval()

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

# Generation function used by the Gradio app.
def generate_text(prompt_text, max_length=50, temperature=0.7, top_k=50, top_p=0.95, repetition_penalty=1.0):
    # Gradio sliders deliver floats, so cast the integer-valued parameters.
    max_length, top_k = int(max_length), int(top_k)
    prompt = f"\n[WP] {prompt_text} \n[RESPONSE]"
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)

    with torch.no_grad():
        for _ in range(max_length):
            outputs = model(input_ids)
            logits = outputs[0]
            next_token_logits = logits[:, -1, :] / temperature

            # Apply the repetition penalty per batch row (sign-aware, as above).
            for i in range(input_ids.size(0)):
                for token_id in set(input_ids[i].tolist()):
                    if next_token_logits[i, token_id] < 0:
                        next_token_logits[i, token_id] *= repetition_penalty
                    else:
                        next_token_logits[i, token_id] /= repetition_penalty

            # Filter logits using top-k and/or top-p filtering
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=-1)

    generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    wp_responses = re.split(r"\[WP\].*?\n|\[RESPONSE\]", generated_text)[1:]
    # Guard against outputs that do not split cleanly into a [RESPONSE] part.
    return wp_responses[1] if len(wp_responses) > 1 else generated_text

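# Commented example call:
#
#   generate_text("What is the classical conceptualisation of oxidation and "
#                 "reduction in redox reactions?", max_length=60)
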
# Define the Gradio interface using Blocks
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("<h1 style='text-align: center'>GPT-2 Text Generator</h1>")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(lines=2, placeholder="Enter prompt here...", label="Prompt")
            max_length = gr.Slider(minimum=10, maximum=100, step=1, value=50, label="Max Length")
            temperature = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.7, label="Temperature")
            top_k = gr.Slider(minimum=0, maximum=100, step=1, value=50, label="Top K")
            top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.95, label="Top P")
            repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, step=0.1, value=1.0, label="Repetition Penalty")
            generate_button = gr.Button("Generate")
        with gr.Column():
            output_text = gr.Textbox(lines=20, label="Generated Text")

    generate_button.click(
        fn=generate_text,
        inputs=[prompt, max_length, temperature, top_k, top_p, repetition_penalty],
        outputs=output_text,
    )

demo.launch()
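# To expose the app beyond localhost, standard Gradio options could be used
# instead, e.g. demo.launch(server_name="0.0.0.0") or demo.launch(share=True).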