Gbssreejith commited on
Commit
0dfdc45
·
verified ·
1 Parent(s): bbd227d

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -451
app.py DELETED
@@ -1,451 +0,0 @@
1
- import copy
2
- import torch
3
- import math
4
- import torch.nn as nn
5
- from torch.nn.parameter import Parameter
6
- import random
7
- import numpy as np
8
- from load_weights import load_weight
9
- from sklearn.model_selection import train_test_split
10
- from transformers import GPT2TokenizerFast
11
- import pandas as pd
12
- from torch.utils.data import Dataset, DataLoader
13
- from transformers import AdamW, get_linear_schedule_with_warmup
14
- torch.manual_seed(42)
15
- import nltk
16
- nltk.download('punkt')
17
-
18
- from transformers import GPT2Tokenizer
19
- from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler, SequentialSampler
20
- import datetime
21
- import time
22
- import os
23
- os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
24
- from tqdm import trange
25
- import gradio as gr
26
-
27
-
28
-
29
- def gelu(x):
30
- return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
31
-
32
- class Conv1D(nn.Module):
33
- def __init__(self, nf, nx):
34
- super(Conv1D, self).__init__()
35
- self.nf = nf
36
- w = torch.empty(nx, nf)
37
- nn.init.normal_(w, std=0.02)
38
- self.weight = Parameter(w)
39
- self.bias = Parameter(torch.zeros(nf))
40
-
41
- def forward(self, x):
42
- size_out = x.size()[:-1] + (self.nf,)
43
- x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
44
- x = x.view(*size_out)
45
- return x
46
-
47
- class LayerNorm(nn.Module):
48
- def __init__(self, hidden_size, eps=1e-12):
49
- """Construct a layernorm module in the TF style (epsilon inside the square root).
50
- """
51
- super(LayerNorm, self).__init__()
52
- self.weight = nn.Parameter(torch.ones(hidden_size))
53
- self.bias = nn.Parameter(torch.zeros(hidden_size))
54
- self.variance_epsilon = eps
55
-
56
- def forward(self, x):
57
- u = x.mean(-1, keepdim=True)
58
- s = (x - u).pow(2).mean(-1, keepdim=True)
59
- x = (x - u) / torch.sqrt(s + self.variance_epsilon)
60
- return self.weight * x + self.bias
61
-
62
-
63
-
64
- class Attention(nn.Module):
65
- def __init__(self, nx, n_ctx, config, scale=False):
66
- super(Attention, self).__init__()
67
- n_state = nx # in Attention: n_state=768 (nx=n_embd)
68
- # [switch nx => n_state from Block to Attention to keep identical to TF implem]
69
- assert n_state % config.n_head == 0
70
- self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
71
- self.n_head = config.n_head
72
- self.split_size = n_state
73
- self.scale = scale
74
- self.c_attn = Conv1D(n_state * 3, nx)
75
- self.c_proj = Conv1D(n_state, nx)
76
-
77
- def _attn(self, q, k, v):
78
- w = torch.matmul(q, k)
79
- if self.scale:
80
- w = w / math.sqrt(v.size(-1))
81
- nd, ns = w.size(-2), w.size(-1)
82
- b = self.bias[:, :, ns-nd:ns, :ns]
83
- w = w * b - 1e10 * (1 - b)
84
- w = nn.Softmax(dim=-1)(w)
85
- return torch.matmul(w, v)
86
-
87
- def merge_heads(self, x):
88
- x = x.permute(0, 2, 1, 3).contiguous()
89
- new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
90
- return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states
91
-
92
- def split_heads(self, x, k=False):
93
- new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
94
- x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states
95
- if k:
96
- return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length)
97
- else:
98
- return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
99
-
100
- def forward(self, x, layer_past=None):
101
- x = self.c_attn(x)
102
- query, key, value = x.split(self.split_size, dim=2)
103
- query = self.split_heads(query)
104
- key = self.split_heads(key, k=True)
105
- value = self.split_heads(value)
106
- if layer_past is not None:
107
- past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below
108
- key = torch.cat((past_key, key), dim=-1)
109
- value = torch.cat((past_value, value), dim=-2)
110
- present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
111
- a = self._attn(query, key, value)
112
- a = self.merge_heads(a)
113
- a = self.c_proj(a)
114
- return a, present
115
-
116
-
117
- class MLP(nn.Module):
118
- def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
119
- super(MLP, self).__init__()
120
- nx = config.n_embd
121
- self.c_fc = Conv1D(n_state, nx)
122
- self.c_proj = Conv1D(nx, n_state)
123
- self.act = gelu
124
-
125
- def forward(self, x):
126
- h = self.act(self.c_fc(x))
127
- h2 = self.c_proj(h)
128
- return h2
129
-
130
-
131
- class Block(nn.Module):
132
- def __init__(self, n_ctx, config, scale=False):
133
- super(Block, self).__init__()
134
- nx = config.n_embd
135
- self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
136
- self.attn = Attention(nx, n_ctx, config, scale)
137
- self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
138
- self.mlp = MLP(4 * nx, config)
139
-
140
- def forward(self, x, layer_past=None):
141
- a, present = self.attn(self.ln_1(x), layer_past=layer_past)
142
- x = x + a
143
- m = self.mlp(self.ln_2(x))
144
- x = x + m
145
- return x, present
146
-
147
-
148
-
149
- class GPT2Model(nn.Module):
150
- def __init__(self, config):
151
- super(GPT2Model, self).__init__()
152
- self.n_layer = config.n_layer
153
- self.n_embd = config.n_embd
154
- self.n_vocab = config.vocab_size
155
-
156
- self.wte = nn.Embedding(config.vocab_size, config.n_embd)
157
- self.wpe = nn.Embedding(config.n_positions, config.n_embd)
158
- block = Block(config.n_ctx, config, scale=True)
159
- self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
160
- self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
161
-
162
- def set_embeddings_weights(self, model_embeddings_weights):
163
- embed_shape = model_embeddings_weights.shape
164
- self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
165
- self.decoder.weight = model_embeddings_weights # Tied weights
166
-
167
-
168
-
169
- def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None):
170
-
171
- if (input_ids >= self.n_vocab).any():
172
- raise ValueError(f"Invalid token ID found in input_ids: {input_ids}")
173
-
174
- # print(f"input_ids: {input_ids}") # Debugging statement
175
- # print(f"Max input_id: {input_ids.max().item()}") # Debugging statement
176
- # print(f"Min input_id: {input_ids.min().item()}") # Debugging statement
177
-
178
- if past is None:
179
- past_length = 0
180
- past = [None] * len(self.h)
181
- else:
182
- past_length = past[0][0].size(-2)
183
- if position_ids is None:
184
- position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long,
185
- device=input_ids.device)
186
- position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
187
-
188
- input_shape = input_ids.size()
189
- input_ids = input_ids.view(-1, input_ids.size(-1))
190
- position_ids = position_ids.view(-1, position_ids.size(-1))
191
-
192
- inputs_embeds = self.wte(input_ids)
193
- position_embeds = self.wpe(position_ids)
194
-
195
- # print(f"inputs_embeds shape: {inputs_embeds.shape}")
196
- # print(f"position_embeds shape: {position_embeds.shape}")
197
-
198
-
199
- if token_type_ids is not None:
200
- token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
201
- token_type_embeds = self.wte(token_type_ids)
202
- else:
203
- token_type_embeds = 0
204
- hidden_states = inputs_embeds + position_embeds + token_type_embeds
205
- presents = []
206
- for block, layer_past in zip(self.h, past):
207
- hidden_states, present = block(hidden_states, layer_past)
208
- presents.append(present)
209
- hidden_states = self.ln_f(hidden_states)
210
- output_shape = input_shape + (hidden_states.size(-1),)
211
- return hidden_states.view(*output_shape), presents
212
-
213
- class GPT2LMHead(nn.Module):
214
- def __init__(self, model_embeddings_weights, config):
215
- super(GPT2LMHead, self).__init__()
216
- self.n_embd = config.n_embd
217
- self.set_embeddings_weights(model_embeddings_weights)
218
-
219
- def set_embeddings_weights(self, model_embeddings_weights):
220
- embed_shape = model_embeddings_weights.shape
221
- self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
222
- self.decoder.weight = model_embeddings_weights # Tied weights
223
-
224
- def forward(self, hidden_state):
225
- # Truncated Language modeling logits (we remove the last token)
226
- # h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)
227
- lm_logits = self.decoder(hidden_state)
228
- return lm_logits
229
-
230
- import torch.nn.functional as F
231
-
232
- def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
233
- """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
234
- Args:
235
- logits: logits distribution shape (batch size, vocabulary size)
236
- top_k > 0: keep only top k tokens with highest probability (top-k filtering).
237
- top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
238
- filter_value: value to replace filtered logits.
239
- """
240
- assert logits.dim() == 2 # batch size x vocabulary size
241
- top_k = min(top_k, logits.size(-1)) # Safety check
242
- if top_k > 0:
243
- # Remove all tokens with a probability less than the last token of the top-k
244
- indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
245
- logits[indices_to_remove] = filter_value
246
-
247
- if top_p > 0.0:
248
- sorted_logits, sorted_indices = torch.sort(logits, descending=True)
249
- cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
250
-
251
- # Remove tokens with cumulative probability above the threshold
252
- sorted_indices_to_remove = cumulative_probs > top_p
253
- # Shift the indices to the right to keep also the first token above the threshold
254
- sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
255
- sorted_indices_to_remove[..., 0] = 0
256
-
257
- indices_to_remove = sorted_indices[sorted_indices_to_remove]
258
- logits[indices_to_remove] = filter_value
259
- return logits
260
-
261
-
262
- class GPT2LMHeadModel(nn.Module):
263
- def __init__(self, config):
264
- super(GPT2LMHeadModel, self).__init__()
265
- self.transformer = GPT2Model(config)
266
- self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
267
-
268
- def set_tied(self):
269
- """ Make sure we are sharing the embeddings
270
- """
271
- self.lm_head.set_embeddings_weights(self.transformer.wte.weight)
272
-
273
- def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
274
- hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
275
- lm_logits = self.lm_head(hidden_states)
276
-
277
- outputs = (lm_logits,presents)
278
-
279
- if lm_labels is not None:
280
- shift_logits = lm_logits[..., :-1, :].contiguous()
281
- shift_labels = lm_labels[..., 1:].contiguous()
282
- loss_fct = nn.CrossEntropyLoss()
283
- loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
284
- outputs = (loss,) + outputs
285
- return outputs
286
-
287
- import torch.nn.functional as F
288
-
289
-
290
-
291
- def generate(
292
- self, input_ids, max_length, temperature=1.0, top_k=0, top_p=0.9, repetition_penalty=1.0, device='cuda'
293
- ):
294
- self.eval()
295
- input_ids = input_ids.to(device)
296
- batch_size = input_ids.shape[0]
297
- past = None
298
-
299
- generated = input_ids
300
- with torch.no_grad():
301
- for _ in range(max_length):
302
- outputs = self(input_ids, past=past)
303
- next_token_logits = outputs[0][:, -1, :]
304
- past = outputs[1]
305
-
306
- for i in range(batch_size):
307
- for token_id in set(generated[i].tolist()):
308
- next_token_logits[i, token_id] /= repetition_penalty
309
-
310
- next_token_logits = next_token_logits / temperature
311
- filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
312
- next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
313
- generated = torch.cat((generated, next_token), dim=1)
314
-
315
- if (next_token == self.config.eos_token_id).all():
316
- break
317
-
318
- input_ids = next_token
319
-
320
- return generated
321
-
322
-
323
- class GPT2Config(object):
324
- def __init__(
325
- self,
326
- vocab_size_or_config_json_file=50257,
327
- n_positions=1024,
328
- n_ctx=1024,
329
- n_embd=768,
330
- n_layer=12,
331
- n_head=12,
332
- layer_norm_epsilon=1e-5,
333
- initializer_range=0.02,
334
- ):
335
- self.vocab_size = vocab_size_or_config_json_file
336
- self.n_ctx = n_ctx
337
- self.n_positions = n_positions
338
- self.n_embd = n_embd
339
- self.n_layer = n_layer
340
- self.n_head = n_head
341
- self.layer_norm_epsilon = layer_norm_epsilon
342
- self.initializer_range = initializer_range
343
-
344
-
345
-
346
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
347
- config = GPT2Config()
348
- model = GPT2LMHeadModel(config)
349
- state_dict = torch.load(r'epoch_4.pth', map_location='cpu' if not torch.cuda.is_available() else None)
350
- model = load_weight(model, state_dict)
351
- model.to(device)
352
- print(model)
353
- model.eval()
354
-
355
- tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
356
- tokenizer.pad_token = tokenizer.eos_token
357
-
358
-
359
-
360
- def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
361
- """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
362
- Args:
363
- logits: logits distribution shape (batch size x vocabulary size)
364
- top_k > 0: keep only top k tokens with highest probability (top-k filtering).
365
- top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
366
- Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
367
- """
368
- assert logits.dim() == 2, "Expected logits dimension to be 2 (batch size x vocabulary size)"
369
- top_k = min(top_k, logits.size(-1)) # Safety check
370
- if top_k > 0:
371
- # Remove all tokens with a probability less than the last token of the top-k
372
- indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
373
- logits[indices_to_remove] = filter_value
374
-
375
- if top_p > 0.0:
376
- sorted_logits, sorted_indices = torch.sort(logits, descending=True)
377
- cumulative_probs = torch.cumsum(nn.Softmax(dim=-1)(sorted_logits), dim=-1)
378
-
379
- # Remove tokens with cumulative probability above the threshold
380
- sorted_indices_to_remove = cumulative_probs > top_p
381
- # Shift the indices to the right to keep also the first token above the threshold
382
- sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
383
- sorted_indices_to_remove[..., 0] = 0
384
-
385
- # Ensure that the dimensions match
386
- if sorted_indices_to_remove.size() != sorted_indices.size():
387
- raise ValueError(f"Size mismatch: {sorted_indices_to_remove.size()} vs {sorted_indices.size()}")
388
-
389
- indices_to_remove = sorted_indices[sorted_indices_to_remove]
390
-
391
- # Expand dimensions to match logits tensor and use scatter_
392
- for batch_idx in range(logits.size(0)):
393
- logits[batch_idx, indices_to_remove[batch_idx]] = filter_value
394
-
395
- return logits
396
-
397
- # prompt_text = "What is a nucleophile in organic chemistry?"
398
- # prompt = f"\n<|startoftext|>[WP] {prompt_text} \n[RESPONSE]"
399
- # input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
400
-
401
-
402
- max_length = 100
403
- temperature = 0.7
404
- top_k = 1
405
- top_p = 0.95
406
- repetition_penalty = 1.0
407
-
408
- with torch.no_grad():
409
- for _ in range(max_length):
410
- outputs = model(input_ids)
411
- logits = outputs[0]
412
- next_token_logits = logits[:, -1, :] / temperature
413
-
414
- # Apply repetition penalty
415
- for i in range(input_ids.size(0)):
416
- for token_id in set(input_ids[i].tolist()):
417
- next_token_logits[0, token_id] /= repetition_penalty
418
-
419
- # Filter logits using top-k and/or top-p filtering
420
- filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
421
- next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
422
- input_ids = torch.cat([input_ids, next_token], dim=-1).to(device)
423
-
424
-
425
- # import re
426
- # # generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
427
- # # wp_responses = re.split(r"\[WP\].*?\n|\[RESPONSE\]", generated_text)[1:]
428
- # print(input_ids[0])
429
-
430
- # generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
431
- # wp_responses = re.split(r"\[WP\].*?\n|\[RESPONSE\]", generated_text)[1:]
432
- # print(wp_responses)
433
-
434
- # Create a Gradio interface
435
- iface = gr.Interface(
436
- fn=generate_response,
437
- inputs="text",
438
- outputs="text",
439
- title="Custom GPT-2 Model",
440
- description="Enter a prompt to get a generated response from the custom-trained GPT-2 model.",
441
- examples=[
442
- ["What is a nucleophile in organic chemistry?"],
443
- ["Explain the concept of quantum entanglement."],
444
- ["How does photosynthesis work?"]
445
- ]
446
- )
447
-
448
- # Launch the Gradio interface
449
- iface.launch(share=True, debug=True)
450
-
451
-