import numpy as np
import pandas as pd
import re
import os
import cloudpickle

# Must be set before TensorFlow / transformers are imported so that the
# legacy Keras 2 (tf-keras) code path is used.
os.environ['TF_USE_LEGACY_KERAS'] = "1"

from transformers import BartTokenizerFast, TFAutoModelForSeq2SeqLM
import tensorflow as tf
import spacy
import streamlit as st
import logging
import traceback
from scraper import scrape_text

# Training data: https://www.kaggle.com/datasets/vladimirvorobevv/chatgpt-paraphrases
CHECKPOINT = "facebook/bart-base"
INPUT_N_TOKENS = 70
TARGET_N_TOKENS = 70


@st.cache_resource
def load_models():
    """Load the spaCy sentence splitter, the BART tokenizer, and the fine-tuned paraphraser weights."""
    nlp = spacy.load(os.path.join('.', 'en_core_web_sm-3.6.0'))
    tokenizer = BartTokenizerFast.from_pretrained(CHECKPOINT)
    model = TFAutoModelForSeq2SeqLM.from_pretrained(CHECKPOINT)
    model.load_weights(os.path.join("models", "bart_en_paraphraser.h5"), by_name=True)
    logging.warning('Loaded models')
    return nlp, tokenizer, model


nlp, tokenizer, model = load_models()


def inference_tokenize(input_: list, n_tokens: int):
    """Tokenize a batch of sentences for the encoder and return the tokenizer with the batch."""
    tokenized_data = tokenizer(text=input_, max_length=n_tokens, truncation=True,
                               padding="max_length", return_tensors="tf")
    return tokenizer, tokenized_data


def process_result(result: str):
    """Clean one decoded prediction: keep the first paraphrase and strip tags and stray quotes."""
    result = result.split('", "')
    result = result if isinstance(result, str) else result[0]
    result = re.sub('(<.*?>)', "", result).strip()
    result = re.sub('^\"', "", result).strip()
    return result


def inference(txt):
    """Generate paraphrases for a list of sentences with beam search."""
    try:
        inference_tokenizer, tokenized_data = inference_tokenize(input_=txt, n_tokens=INPUT_N_TOKENS)
        pred = model.generate(**tokenized_data, max_new_tokens=TARGET_N_TOKENS, num_beams=4)
        result = [process_result(inference_tokenizer.decode(p, skip_special_tokens=True)) for p in pred]
        logging.warning(f'paraphrased_result: {result}')
        return result
    except Exception:
        logging.warning(traceback.format_exc())
        raise


def inference_long_text(txt, n_sents):
    """Split the text into sentences and paraphrase the first `n_sents` that have at least three words."""
    paraphrased_txt = []
    input_txt = []
    doc = nlp(txt)
    n = 0
    for sent in doc.sents:
        if n >= n_sents:
            break
        if len(sent.text.split()) >= 3:
            input_txt.append(sent.text)
            n += 1
    with st.spinner('Rewriting...'):
        paraphrased_txt = inference(input_txt)
    return input_txt, paraphrased_txt


############## ENTRY POINT START #######################
def main():
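    """Streamlit entry point: take a URL or pasted text, paraphrase up to the
    selected number of sentences with the fine-tuned BART model, and display
    each original sentence next to its rewritten version."""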

    st.markdown('''
    <h1>Text Rewriter</h1>
    ''', unsafe_allow_html=True)

    input_type = st.radio('Select an option:', ['Paste URL of the text', 'Paste the text'], horizontal=True)
    n_sents = st.slider('Select the number of sentences to process', 5, 30, 10)

    scrape_error = None
    paraphrase_error = None
    paraphrased_txt = None
    input_txt = None

    if input_type == 'Paste URL of the text':
        input_url = st.text_input("Paste URL of the text", "")
        if (st.button("Submit")) or (input_url):
            with st.status("Processing...", expanded=True) as status:
                status.empty()
                # Scraping data Start
                try:
                    st.info("Scraping data from the URL.", icon="ℹ️")
                    input_txt = scrape_text(input_url)
                    st.success("Successfully scraped the data.", icon="✅")
                except Exception as e:
                    input_txt = None
                    scrape_error = str(e)
                # Scraping data End

                if input_txt is not None:
                    input_txt = re.sub(r'\n+', ' ', input_txt)
                    # Paraphrasing start
                    try:
                        st.info("Rewriting the text. This takes time.", icon="ℹ️")
                        input_txt, paraphrased_txt = inference_long_text(input_txt, n_sents)
                    except Exception as e:
                        paraphrased_txt = None
                        paraphrase_error = str(e)
                    if paraphrased_txt is not None:
                        st.success("Successfully rewrote the text.", icon="✅")
                    else:
                        st.error("Encountered an error while rewriting the text.", icon="🚨")
                    # Paraphrasing end
                else:
                    st.error("Encountered an error while scraping the data.", icon="🚨")

                if (scrape_error is None) and (paraphrase_error is None):
                    status.update(label="Done", state="complete", expanded=False)
                else:
                    status.update(label="Error", state="error", expanded=False)

            if scrape_error is not None:
                st.error(f"Scrape Error: \n{scrape_error}", icon="🚨")
            elif paraphrase_error is not None:
                st.error(f"Paraphrasing Error: \n{paraphrase_error}", icon="🚨")
            else:
                # <br> line breaks render because st.markdown is called with unsafe_allow_html=True.
                result = [f"Scraped Sentence: {scraped}<br>Rewritten Sentence: {paraphrased}"
                          for scraped, paraphrased in zip(input_txt, paraphrased_txt)]
                result = "\n\n".join(result)
                result = result.replace("$", "\\$")  # escape $ so Streamlit's Markdown doesn't treat it as LaTeX
                st.markdown(f"{result}", unsafe_allow_html=True)
    else:
        input_txt = st.text_area("Enter the text. (Ensure the text is grammatically correct "
                                 "and has punctuation in the right places):", "", height=150)
        if (st.button("Submit")) or (input_txt):
            with st.status("Processing...", expanded=True) as status:
                input_txt = re.sub(r'\n+', ' ', input_txt)
                # Paraphrasing start
                try:
                    st.info("Rewriting the text. This takes time.", icon="ℹ️")
                    input_txt, paraphrased_txt = inference_long_text(input_txt, n_sents)
                except Exception as e:
                    paraphrased_txt = None
                    paraphrase_error = str(e)
                if paraphrased_txt is not None:
                    st.success("Successfully rewrote the text.", icon="✅")
                else:
                    st.error("Encountered an error while rewriting the text.", icon="🚨")
                # Paraphrasing end
                if paraphrase_error is None:
                    status.update(label="Done", state="complete", expanded=False)
                else:
                    status.update(label="Error", state="error", expanded=False)

            if paraphrase_error is not None:
                st.error(f"Paraphrasing Error: \n{paraphrase_error}", icon="🚨")
            else:
                # Same rendering as the URL branch: HTML line break within a pair, blank line between pairs.
                result = [f"Scraped Sentence: {scraped}<br>Rewritten Sentence: {paraphrased}"
                          for scraped, paraphrased in zip(input_txt, paraphrased_txt)]
                result = "\n\n".join(result)
                result = result.replace("$", "\\$")  # escape $ so Streamlit's Markdown doesn't treat it as LaTeX
                st.markdown(f"{result}", unsafe_allow_html=True)


############## ENTRY POINT END #######################
if __name__ == "__main__":
    main()