marufc36 commited on
Commit
201a582
·
1 Parent(s): 908ff59
Files changed (1) hide show
  1. app.py +56 -0
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import T5TokenizerFast
3
+ from model import PoemSummaryModel
4
+
5
+
6
+ tokenizer = T5TokenizerFast.from_pretrained("t5-base")
7
+ best_model = PoemSummaryModel.load_from_checkpoint("checkpoints/best-checkpoint.ckpt")
8
+ best_model.freeze()
9
+
10
+
11
+ def encode_text(text):
12
+ encoding = tokenizer.encode_plus(
13
+ text,
14
+ max_length=512,
15
+ padding="max_length",
16
+ truncation=True,
17
+ return_attention_mask=True,
18
+ return_tensors='pt'
19
+ )
20
+ return encoding["input_ids"], encoding["attention_mask"]
21
+
22
+ def generate_summary(input_ids, attention_mask, model):
23
+ model = model.to(input_ids.device)
24
+ generated_ids = model.model.generate(
25
+ input_ids=input_ids,
26
+ attention_mask=attention_mask,
27
+ max_length=150,
28
+ num_beams=2,
29
+ repetition_penalty=2.5,
30
+ length_penalty=1.0,
31
+ early_stopping=True
32
+ )
33
+ return generated_ids
34
+
35
+ def decode_summary(generated_ids):
36
+ summary = [tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
37
+ for gen_id in generated_ids]
38
+ return "".join(summary)
39
+
40
+ def summarize(text):
41
+ input_ids, attention_mask = encode_text(text)
42
+ generated_ids = generate_summary(input_ids, attention_mask, best_model)
43
+ summary = decode_summary(generated_ids)
44
+ return summary
45
+
46
+ # Create Gradio interface
47
+ input_text = gr.Textbox(lines=10, label="Input Text")
48
+ output_text = gr.Textbox(label="Summary")
49
+
50
+ gr.Interface(
51
+ fn=summarize,
52
+ inputs=input_text,
53
+ outputs=output_text,
54
+ title="Poem Pulse",
55
+ description="Enter a Poem and get its Jist."
56
+ ).launch()