IvaElen commited on
Commit
c870658
·
1 Parent(s): 3f3d64c

Delete GPT.py

Browse files
Files changed (1) hide show
  1. GPT.py +0 -60
GPT.py DELETED
@@ -1,60 +0,0 @@
1
- import streamlit as st
2
- import torch
3
- import numpy as np
4
- import transformers
5
- import random
6
-
7
def load_model():
    """Load the fine-tuned ruDialoGPT-small model and its tokenizer on CPU.

    Returns:
        tuple: ``(model, tokenizer)`` where the model has the fine-tuned
        weights from ``GPT_sonnik_only.pt`` loaded and is in eval mode.
    """
    # NOTE(review): AutoModelWithLMHead is deprecated in recent
    # transformers releases; AutoModelForCausalLM is the replacement.
    model_finetuned = transformers.AutoModelWithLMHead.from_pretrained(
        'tinkoff-ai/ruDialoGPT-small',
        output_attentions=False,
        output_hidden_states=False,
    )
    # Fine-tuned weights are mapped onto the CPU explicitly.
    model_finetuned.load_state_dict(
        torch.load('GPT_sonnik_only.pt', map_location=torch.device('cpu'))
    )
    # Bug fix: the original left the model in training mode, so dropout
    # stayed active during generation; switch to inference mode.
    model_finetuned.eval()
    tokenizer = transformers.AutoTokenizer.from_pretrained('tinkoff-ai/ruDialoGPT-small')
    return model_finetuned, tokenizer
16
-
17
def preprocess_text(text_input, tokenizer, device='cpu'):
    """Encode *text_input* into a tensor of token ids placed on *device*.

    Args:
        text_input: raw user text to tokenize.
        tokenizer: a tokenizer exposing ``encode(..., return_tensors='pt')``.
        device: target device for the prompt tensor. Defaults to CPU to
            match ``load_model``'s ``map_location``.

    Returns:
        torch.Tensor: token ids of shape (1, seq_len) on *device*.
    """
    # Bug fix: the original referenced a global `device` that is defined
    # nowhere in the file (NameError on first use); it is now an explicit,
    # backward-compatible parameter.
    prompt = tokenizer.encode(text_input, return_tensors='pt').to(device)
    return prompt
20
-
21
def predict_sentiment(model, prompt, temp, num_generate, max_length=150):
    """Generate text continuations for *prompt* with beam-sampled decoding.

    Note: despite the name, this performs text *generation*, not sentiment
    prediction; the name is kept unchanged for caller compatibility.

    Args:
        model: a causal LM exposing ``generate``.
        prompt: tensor of input token ids (``input_ids``).
        temp: sampling temperature, coerced to float (must be > 0).
        num_generate: number of sequences to return; coerced to int because
            Streamlit text inputs yield strings.
        max_length: maximum generated length in tokens. Was a hard-coded
            150; exposed as a backward-compatible parameter.

    Returns:
        numpy.ndarray: generated token-id sequences, one row per sequence.
    """
    # Inference only — skip autograd bookkeeping.
    with torch.no_grad():
        result = model.generate(
            input_ids=prompt,
            max_length=max_length,
            num_beams=5,
            do_sample=True,
            temperature=float(temp),
            top_k=50,
            top_p=0.6,
            no_repeat_ngram_size=3,
            # Bug fix: int() guards against the string value produced by
            # the script's st.text_input path.
            num_return_sequences=int(num_generate),
        )
    return result.cpu().numpy()
34
-
35
import textwrap  # bug fix: textwrap.fill was used below but never imported

st.title('Text generation with dreambook')

model, tokenizer = load_model()

text_input = st.text_input("Enter some text about movie")
max_len = st.slider('Length of sequence', 0, 500, 250)
# Bug fix: sampling temperature must be strictly positive; the original
# slider allowed (and defaulted to) 0, which breaks generate().
temp = st.slider('Temperature', 1, 30, 1)
if st.button('Generate a random number of sequences'):
    num_generate = random.randint(1, 5)
    st.write(f'Number of sequences: {num_generate}')
else:
    # Streamlit text inputs return strings (possibly empty); coerce to a
    # safe integer instead of forwarding the raw string.
    raw_count = st.text_input("Enter number of sequences")
    num_generate = int(raw_count) if raw_count.strip().isdigit() else 1

if text_input:
    prompt = preprocess_text(text_input, tokenizer)
    # Bug fix: the original passed five positional args to the
    # four-parameter predict_sentiment (TypeError). max_len is used only
    # to wrap the displayed text below.
    result = predict_sentiment(model, prompt, temp, num_generate)
    for seq in result:
        # max(..., 1): textwrap.fill raises ValueError on width <= 0 and
        # the slider minimum is 0.
        st.write(textwrap.fill(tokenizer.decode(seq), max(max_len, 1)))
    # Removed the original trailing `if st.button('Next'): try: continue`
    # block: the except clause was unreachable and `else: None` was a
    # no-op; it could never page through results as apparently intended.