sairamn commited on
Commit
34f0c1d
·
1 Parent(s): c18799d

Updated version 1.0

Browse files
Files changed (1) hide show
  1. app.py +42 -5
app.py CHANGED
@@ -1,21 +1,58 @@
 
1
  import streamlit as st
2
  from transformers import BartTokenizer, TFBartForConditionalGeneration
3
 
4
- model_name = 'facebook/bart-large-cnn'
 
 
 
 
5
  tokenizer = BartTokenizer.from_pretrained(model_name)
6
  model = TFBartForConditionalGeneration.from_pretrained(model_name)
7
 
8
- def summarize(text):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  inputs = tokenizer.encode(text, return_tensors='tf', max_length=1024, truncation=True)
10
- summary_ids = model.generate(inputs, max_length=150, min_length=30, length_penalty=2.0, num_beams=4, early_stopping=True)
11
- summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
 
 
 
 
 
 
 
12
  return summary
13
 
 
14
  st.title('Text Summarizer')
15
  user_input = st.text_area("Enter text to summarize:", "")
 
 
 
 
 
 
 
16
  if st.button('Summarize'):
17
  if user_input:
18
- summary = summarize(user_input)
 
19
  st.write(summary)
20
  else:
21
  st.write("Please enter some text to summarize.")
 
1
+ import os
2
  import streamlit as st
3
  from transformers import BartTokenizer, TFBartForConditionalGeneration
4
 
5
+ # Suppress TensorFlow logging for errors only
6
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
7
+
8
+ # Load the model and tokenizer
9
+ model_name = 'facebook-bart-large-cnn'
10
  tokenizer = BartTokenizer.from_pretrained(model_name)
11
  model = TFBartForConditionalGeneration.from_pretrained(model_name)
12
 
13
+ def summarize(text, style):
14
+ input_length = len(tokenizer.encode(text, return_tensors='tf', max_length=1024, truncation=True)[0])
15
+
16
+ # Calculate max_length based on the chosen style
17
+ if style == 'Accurate':
18
+ max_length = int(input_length * 0.3) # Less than one-third
19
+ min_length = int(input_length * 0.2)
20
+ length_penalty = 1.0
21
+ elif style == 'Precise':
22
+ max_length = int(input_length * 0.33) # One-third
23
+ min_length = int(input_length * 0.25)
24
+ length_penalty = 1.2
25
+ else: # Normal
26
+ max_length = int(input_length * 0.5) # Half the length
27
+ min_length = int(input_length * 0.4)
28
+ length_penalty = 1.5
29
+
30
  inputs = tokenizer.encode(text, return_tensors='tf', max_length=1024, truncation=True)
31
+ summary_ids = model.generate(
32
+ inputs,
33
+ max_length=max_length,
34
+ min_length=min_length,
35
+ length_penalty=length_penalty,
36
+ num_beams=4,
37
+ early_stopping=True
38
+ )
39
+ summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
40
  return summary
41
 
42
+ # Streamlit app
43
  st.title('Text Summarizer')
44
  user_input = st.text_area("Enter text to summarize:", "")
45
+
46
+ # Dropdown menu for summarization style
47
+ summary_style = st.selectbox(
48
+ 'Choose summarization style:',
49
+ ('Accurate', 'Precise', 'Normal')
50
+ )
51
+
52
  if st.button('Summarize'):
53
  if user_input:
54
+ summary = summarize(user_input, summary_style)
55
+ st.write("Summary:")
56
  st.write(summary)
57
  else:
58
  st.write("Please enter some text to summarize.")