File size: 7,115 Bytes
ae0deeb
 
 
 
 
 
 
 
 
fde44b6
74611a4
ae0deeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fde44b6
ae0deeb
 
 
 
 
 
 
 
0e7d7b7
 
 
 
 
 
 
 
 
 
ae0deeb
0e7d7b7
 
 
e71b6e8
ae0deeb
 
0e7d7b7
 
ae0deeb
 
 
 
 
 
 
 
 
07c92c4
 
 
0e7d7b7
 
 
ae0deeb
 
 
0e7d7b7
ae0deeb
 
 
 
 
 
0e7d7b7
ae0deeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e7d7b7
ae0deeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf8ef43
 
39cc3c6
bf8ef43
ae0deeb
 
 
 
 
 
 
 
 
 
 
 
 
e71b6e8
ae0deeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf8ef43
 
39cc3c6
bf8ef43
ae0deeb
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import numpy as np
import pandas as pd
import re
import os
import cloudpickle
from transformers import BartTokenizerFast, TFAutoModelForSeq2SeqLM
import tensorflow as tf
import spacy
import streamlit as st
import logging
import traceback
from scraper import scrape_text

# Training data https://www.kaggle.com/datasets/vladimirvorobevv/chatgpt-paraphrases

# Force legacy (Keras 2 / tf-keras) behavior in transformers' TF backend.
# NOTE(review): this is set *after* `import tensorflow` and `transformers`
# above; such flags are typically read at import time — confirm it still
# takes effect here, otherwise move it above the imports.
os.environ['TF_USE_LEGACY_KERAS'] = "1"

# Base HF checkpoint the paraphraser weights were fine-tuned from.
CHECKPOINT = "facebook/bart-base"
# Fixed token budget for encoder input and for generated output.
INPUT_N_TOKENS = 70
TARGET_N_TOKENS = 70


@st.cache_resource
def load_models():
    """Load the spaCy sentence splitter, BART tokenizer, and fine-tuned model.

    Cached by Streamlit (`st.cache_resource`) so the heavyweight resources
    are loaded only once per server process.

    Returns:
        (nlp, tokenizer, model): spaCy pipeline, BartTokenizerFast, and the
        TF seq2seq model with the local fine-tuned weights applied.
    """
    sentence_splitter = spacy.load(os.path.join('.', 'en_core_web_sm-3.6.0'))
    bart_tokenizer = BartTokenizerFast.from_pretrained(CHECKPOINT)
    paraphraser = TFAutoModelForSeq2SeqLM.from_pretrained(CHECKPOINT)
    # Overlay the locally fine-tuned paraphrasing weights onto the base model.
    weights_path = os.path.join("models", "bart_en_paraphraser.h5")
    paraphraser.load_weights(weights_path, by_name=True)
    logging.warning('Loaded models')
    return sentence_splitter, bart_tokenizer, paraphraser

nlp, tokenizer, model = load_models()

def inference_tokenize(input_: list, n_tokens: int):
    """Encode *input_* into fixed-length TF tensors for generation.

    Every sequence is truncated/padded to exactly *n_tokens* tokens.

    Returns:
        (tokenizer, encoded): the module-level tokenizer (so callers can
        decode with the same vocabulary) and the encoded batch.
    """
    encoded = tokenizer(
        text=input_,
        max_length=n_tokens,
        truncation=True,
        padding="max_length",
        return_tensors="tf",
    )
    return tokenizer, encoded


def process_result(result: str):
    """Clean one decoded model output down to a single paraphrase.

    The training targets were formatted as a quoted list of candidates
    (e.g. '"first one", "second one"'), so keep only the first candidate,
    strip residual special tokens such as <pad>/<s>/</s>, and drop the
    leading double quote.

    Args:
        result: raw text decoded from the model.

    Returns:
        The cleaned first paraphrase candidate.
    """
    # str.split always returns a list, so [0] is safe even when the
    # separator is absent. (The original `isinstance(result, str)` guard
    # after splitting was dead code — split never yields a plain str.)
    result = result.split('", "')[0]
    # Remove leftover tag-like special tokens, e.g. <pad> or </s>.
    result = re.sub(r'(<.*?>)', "", result).strip()
    # Drop the leading quote from the quoted-list training format.
    result = re.sub(r'^\"', "", result).strip()
    return result


def inference(txt):
    """Paraphrase *txt* (a string or list of strings) with the BART model.

    Args:
        txt: text (or list of texts) to paraphrase.

    Returns:
        A list with one cleaned paraphrase per input sequence.

    Raises:
        Exception: any failure is logged with a full traceback and
        re-raised so the UI layer can surface it to the user.
    """
    try:
        inference_tokenizer, tokenized_data = inference_tokenize(input_=txt, n_tokens=INPUT_N_TOKENS)
        pred = model.generate(**tokenized_data, max_new_tokens=TARGET_N_TOKENS, num_beams=4)
        result = [process_result(inference_tokenizer.decode(p, skip_special_tokens=True)) for p in pred]
        logging.warning(f'paraphrased_result: {result}')
        return result
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # propagate untouched; logging.exception records the traceback.
        logging.exception('Paraphrasing failed')
        raise

def inference_long_text(txt, n_sents):
    """Split *txt* into sentences and paraphrase the first *n_sents* of them.

    Sentences with fewer than three words are skipped and do not count
    toward the *n_sents* budget.

    Returns:
        (selected, rewritten): the chosen input sentences and their
        paraphrases, as two parallel lists.
    """
    selected = []
    for sent in nlp(txt).sents:
        if len(selected) >= n_sents:
            break
        if len(sent.text.split()) >= 3:
            selected.append(sent.text)

    with st.spinner('Rewriting...'):
        rewritten = inference(selected)
    return selected, rewritten

    
    
############## ENTRY POINT START #######################
def _paraphrase(input_txt, n_sents):
    """Normalize whitespace, paraphrase *input_txt*, and report progress in the UI.

    Returns:
        (input_sentences, paraphrased_sentences, error_message) where
        error_message is None on success, and paraphrased_sentences is
        None on failure.
    """
    paraphrase_error = None
    paraphrased_txt = None
    # Collapse newlines so spaCy sees one continuous text.
    input_txt = re.sub(r'\n+', ' ', input_txt)
    try:
        st.info("Rewriting the text. This takes time.", icon="ℹ️")
        input_txt, paraphrased_txt = inference_long_text(input_txt, n_sents)
    except Exception as e:
        paraphrased_txt = None
        paraphrase_error = str(e)
    if paraphrased_txt is not None:
        st.success("Successfully rewrote the text.", icon="βœ…")
    else:
        st.error("Encountered an error while rewriting the text.", icon="🚨")
    return input_txt, paraphrased_txt, paraphrase_error


def _finish_status(status, ok):
    """Collapse the status widget with the final Done/Error label."""
    if ok:
        status.update(label="Done", state="complete", expanded=False)
    else:
        status.update(label="Error", state="error", expanded=False)


def _render_results(input_txt, paraphrased_txt):
    """Render original/rewritten sentence pairs as HTML in the app."""
    result = [
        f"<b>Scraped Sentence:</b> {scraped}<br><b>Rewritten Sentence:</b> {paraphrased}"
        for scraped, paraphrased in zip(input_txt, paraphrased_txt)
    ]
    result = "<br><br>".join(result)
    # Escape dollar signs so Streamlit does not interpret them as LaTeX.
    result = result.replace("$", "&#36;")
    st.markdown(f"{result}", unsafe_allow_html=True)


def main():
    """Streamlit entry point: scrape or accept text, then paraphrase it.

    The previously duplicated paraphrase/report/render logic of the URL
    and paste-text branches now lives in the private helpers above.
    """
    st.markdown('''<h3>Text Rewriter</h3>''', unsafe_allow_html=True)
    input_type = st.radio('Select an option:', ['Paste URL of the text', 'Paste the text'],
                      horizontal=True)

    n_sents = st.slider('Select the number of sentences to process', 5, 30, 10)

    scrape_error = None
    paraphrase_error = None
    paraphrased_txt = None
    input_txt = None

    if input_type == 'Paste URL of the text':
        input_url = st.text_input("Paste URL of the text", "")

        if (st.button("Submit")) or (input_url):
            with st.status("Processing...", expanded=True) as status:
                status.empty()
                # Scraping data Start
                try:
                    st.info("Scraping data from the URL.", icon="ℹ️")
                    input_txt = scrape_text(input_url)
                    st.success("Successfully scraped the data.", icon="βœ…")
                except Exception as e:
                    input_txt = None
                    scrape_error = str(e)
                # Scraping data End

                if input_txt is not None:
                    input_txt, paraphrased_txt, paraphrase_error = _paraphrase(input_txt, n_sents)
                else:
                    st.error("Encountered an error while scraping the data.", icon="🚨")

                _finish_status(status, (scrape_error is None) and (paraphrase_error is None))

            if scrape_error is not None:
                st.error(f"Scrape Error:  \n{scrape_error}", icon="🚨")
            elif paraphrase_error is not None:
                st.error(f"Paraphrasing Error:  \n{paraphrase_error}", icon="🚨")
            else:
                _render_results(input_txt, paraphrased_txt)

    else:
        input_txt = st.text_area("Enter the text. (Ensure the text is grammatically correct and has punctuations at the right places):", "", height=150)

        if (st.button("Submit")) or (input_txt):
            with st.status("Processing...", expanded=True) as status:
                input_txt, paraphrased_txt, paraphrase_error = _paraphrase(input_txt, n_sents)
                _finish_status(status, paraphrase_error is None)

            if paraphrase_error is not None:
                st.error(f"Paraphrasing Error:  \n{paraphrase_error}", icon="🚨")
            else:
                _render_results(input_txt, paraphrased_txt)
        

############## ENTRY POINT END #######################

# Run the app when executed directly (e.g. `streamlit run app.py`).
if __name__ == "__main__":
    main()