text_rewriter / app.py
ksvmuralidhar's picture
Update app.py
e71b6e8 verified
import numpy as np
import pandas as pd
import re
import os
import cloudpickle
from transformers import BartTokenizerFast, TFAutoModelForSeq2SeqLM
import tensorflow as tf
import spacy
import streamlit as st
import logging
import traceback
from scraper import scrape_text
# Training data https://www.kaggle.com/datasets/vladimirvorobevv/chatgpt-paraphrases
# NOTE(review): TF_USE_LEGACY_KERAS is generally only honoured if it is set
# before TensorFlow/Keras initialise; here it is assigned after `tensorflow`
# and `transformers` have already been imported above, so it may have no
# effect — confirm, and consider moving it to the very top of the file.
os.environ["TF_USE_LEGACY_KERAS"] = "1"

# Base Hugging Face checkpoint; load_models() overlays the fine-tuned weights on it.
CHECKPOINT: str = "facebook/bart-base"
# Each input sentence is truncated/padded to exactly this many tokens.
INPUT_N_TOKENS: int = 70
# Upper bound on the number of new tokens generated per sentence.
TARGET_N_TOKENS: int = 70
@st.cache_resource
def load_models():
    """Build the spaCy pipeline, BART tokenizer and fine-tuned BART model.

    Wrapped in ``st.cache_resource`` so Streamlit constructs these objects once
    and reuses them across script reruns.

    Returns:
        tuple: ``(nlp, tokenizer, model)`` — the spaCy pipeline loaded from the
        local ``en_core_web_sm-3.6.0`` directory, the tokenizer for
        ``CHECKPOINT``, and the seq2seq model with the paraphraser weights.
    """
    spacy_dir = os.path.join('.', 'en_core_web_sm-3.6.0')
    weights_path = os.path.join("models", "bart_en_paraphraser.h5")

    sentence_nlp = spacy.load(spacy_dir)
    bart_tokenizer = BartTokenizerFast.from_pretrained(CHECKPOINT)
    bart_model = TFAutoModelForSeq2SeqLM.from_pretrained(CHECKPOINT)
    # Overlay the fine-tuned paraphrasing weights onto the base checkpoint,
    # matching layers by name.
    bart_model.load_weights(weights_path, by_name=True)

    logging.warning('Loaded models')
    return sentence_nlp, bart_tokenizer, bart_model


# Module-level handles shared by the inference helpers below.
nlp, tokenizer, model = load_models()
def inference_tokenize(input_: list, n_tokens: int):
    """Encode a batch of sentences for generation.

    Each sentence is truncated and padded to exactly ``n_tokens`` tokens, and
    the encoding is returned as TensorFlow tensors.

    Args:
        input_: sentences to encode.
        n_tokens: fixed sequence length after truncation/padding.

    Returns:
        tuple: ``(tokenizer, encoded_batch)`` — the module-level tokenizer (so
        the caller can decode with the same object) and its encoded output.
    """
    encoded_batch = tokenizer(
        text=input_,
        max_length=n_tokens,
        truncation=True,
        padding="max_length",
        return_tensors="tf",
    )
    return tokenizer, encoded_batch
def process_result(result: str) -> str:
    """Reduce one decoded model output to a single clean paraphrase.

    The decoded text may hold several candidates separated by ``", "``
    (presumably the fine-tuning targets were quoted, comma-separated
    paraphrase lists — confirm against the training data). Only the first
    candidate is kept. From it, ``<...>`` tag-like fragments, a leading
    double quote and surrounding whitespace are removed.

    Args:
        result: decoded text for one generated sequence.

    Returns:
        The cleaned first candidate (possibly an empty string).
    """
    # str.split always returns a non-empty list, so index 0 always exists;
    # the old "is it still a str?" check could never be true.
    first = result.split('", "')[0]
    first = re.sub(r'<.*?>', "", first).strip()
    return re.sub(r'^"', "", first).strip()
def inference(txt):
    """Paraphrase a batch of sentences with the fine-tuned BART model.

    Args:
        txt: list of input sentences; each one is a separate model input.

    Returns:
        list[str]: one cleaned paraphrase per generated sequence, in input
        order. An empty batch returns ``[]`` without touching the model.

    Raises:
        Any exception raised by tokenization, generation or decoding. It is
        logged with its traceback and then re-raised unchanged.
    """
    # An empty batch (e.g. no sentence passed inference_long_text's >=3-word
    # filter) has nothing to generate; don't hand it to the tokenizer/model.
    if not txt:
        return []
    try:
        inference_tokenizer, tokenized_data = inference_tokenize(input_=txt, n_tokens=INPUT_N_TOKENS)
        # Beam search with 4 beams, at most TARGET_N_TOKENS new tokens each.
        pred = model.generate(**tokenized_data, max_new_tokens=TARGET_N_TOKENS, num_beams=4)
        result = [process_result(inference_tokenizer.decode(p, skip_special_tokens=True)) for p in pred]
        logging.warning(f'paraphrased_result: {result}')
        return result
    except Exception:
        # Log the full traceback, then let the caller decide how to report it.
        logging.warning(traceback.format_exc())
        raise
def inference_long_text(txt, n_sents):
    """Split ``txt`` into sentences and paraphrase the first ``n_sents`` usable ones.

    Sentences come from the module-level spaCy pipeline. A sentence is kept
    only if it has at least three whitespace-separated words. Collection stops
    as soon as ``n_sents`` sentences have been kept.

    Args:
        txt: text to rewrite.
        n_sents: maximum number of sentences to send to the model.

    Returns:
        tuple: ``(kept_sentences, paraphrases)`` — the sentences that were sent
        to :func:`inference` and the list it returned.
    """
    doc = nlp(txt)
    kept = []
    for sentence in doc.sents:
        if len(kept) >= n_sents:
            break
        # Very short fragments (fewer than 3 words) are skipped.
        if len(sentence.text.split()) >= 3:
            kept.append(sentence.text)
    with st.spinner('Rewriting...'):
        paraphrases = inference(kept)
    return kept, paraphrases
############## ENTRY POINT START #######################
def _paraphrase(input_txt, n_sents):
    """Rewrite ``input_txt`` inside the open ``st.status`` panel and report the outcome there.

    Args:
        input_txt: raw scraped or pasted text.
        n_sents: maximum number of sentences to rewrite.

    Returns:
        tuple: ``(sentences, paraphrases, error)``. On success ``error`` is
        ``None``. On failure the first two items are ``None`` and ``error`` is
        the exception message.
    """
    # Treat line breaks as spaces before sentence splitting.
    input_txt = re.sub(r'\n+', ' ', input_txt)
    try:
        st.info("Rewriting the text. This takes time.", icon="ℹ️")
        sentences, paraphrases = inference_long_text(input_txt, n_sents)
    except Exception as e:
        st.error("Encountered an error while rewriting the text.", icon="🚨")
        return None, None, str(e)
    st.success("Successfully rewrote the text.", icon="✅")
    return sentences, paraphrases, None


def _finish_status(status, succeeded):
    """Collapse the ``st.status`` panel and mark it as done or failed."""
    if succeeded:
        status.update(label="Done", state="complete", expanded=False)
    else:
        status.update(label="Error", state="error", expanded=False)


def _render_results(sentences, paraphrases, source_label):
    """Show each input sentence above its rewrite as HTML-formatted markdown.

    Both texts come from arbitrary web pages or user input and are rendered
    with ``unsafe_allow_html=True``, so they are HTML-escaped first. ``$`` is
    replaced by its HTML entity, presumably to stop Streamlit's markdown from
    treating it as a LaTeX delimiter.
    """
    import html

    pairs = [
        f"<b>{source_label}:</b> {html.escape(src, quote=False)}<br>"
        f"<b>Rewritten Sentence:</b> {html.escape(out, quote=False)}"
        for src, out in zip(sentences, paraphrases)
    ]
    markup = "<br><br>".join(pairs).replace("$", "&#36;")
    st.markdown(markup, unsafe_allow_html=True)


def main():
    """Streamlit entry point for the Text Rewriter page.

    The user either pastes a URL, whose text is fetched with ``scrape_text``,
    or pastes text directly. Up to ``n_sents`` sentences are then rewritten
    and shown next to the originals. Progress is reported inside an
    ``st.status`` panel, and any scrape or paraphrase error is shown below it.
    """
    st.markdown('''<h3>Text Rewriter</h3>''', unsafe_allow_html=True)
    input_type = st.radio('Select an option:', ['Paste URL of the text', 'Paste the text'],
                          horizontal=True)
    n_sents = st.slider('Select the number of sentences to process', 5, 30, 10)
    scrape_error = None
    paraphrase_error = None
    paraphrased_txt = None
    input_txt = None

    if input_type == 'Paste URL of the text':
        input_url = st.text_input("Paste URL of the text", "")
        # Run on an explicit click, or automatically whenever a URL is present.
        if (st.button("Submit")) or (input_url):
            with st.status("Processing...", expanded=True) as status:
                status.empty()
                # Scraping data Start
                try:
                    st.info("Scraping data from the URL.", icon="ℹ️")
                    input_txt = scrape_text(input_url)
                    st.success("Successfully scraped the data.", icon="✅")
                except Exception as e:
                    input_txt = None
                    scrape_error = str(e)
                # Scraping data End
                if input_txt is None and scrape_error is None:
                    # NOTE(review): guards against scrape_text returning None
                    # without raising. Otherwise the status would say "Done" and
                    # rendering would fail on zip(None, None).
                    scrape_error = "No text could be extracted from the URL."
                if input_txt is not None:
                    input_txt, paraphrased_txt, paraphrase_error = _paraphrase(input_txt, n_sents)
                else:
                    st.error("Encountered an error while scraping the data.", icon="🚨")
                _finish_status(status, scrape_error is None and paraphrase_error is None)

            if scrape_error is not None:
                st.error(f"Scrape Error: \n{scrape_error}", icon="🚨")
            elif paraphrase_error is not None:
                st.error(f"Paraphrasing Error: \n{paraphrase_error}", icon="🚨")
            else:
                _render_results(input_txt, paraphrased_txt, "Scraped Sentence")
    else:
        input_txt = st.text_area("Enter the text. (Ensure the text is grammatically correct and has punctuations at the right places):", "", height=150)
        if (st.button("Submit")) or (input_txt):
            with st.status("Processing...", expanded=True) as status:
                input_txt, paraphrased_txt, paraphrase_error = _paraphrase(input_txt, n_sents)
                _finish_status(status, paraphrase_error is None)

            if paraphrase_error is not None:
                st.error(f"Paraphrasing Error: \n{paraphrase_error}", icon="🚨")
            else:
                # Pasted text was never scraped, so label it as plain input.
                _render_results(input_txt, paraphrased_txt, "Input Sentence")
############## ENTRY POINT END #######################
# Build the page only when this file is executed as the main script.
if __name__ == "__main__":
    main()