# Streamlit text-paraphrasing app (deployed as a Hugging Face Space).
import numpy as np | |
import pandas as pd | |
import re | |
import os | |
import cloudpickle | |
from transformers import BartTokenizerFast, TFAutoModelForSeq2SeqLM | |
import tensorflow as tf | |
import spacy | |
import streamlit as st | |
import logging | |
import traceback | |
from scraper import scrape_text | |
# Training data https://www.kaggle.com/datasets/vladimirvorobevv/chatgpt-paraphrases | |
# Force tf-keras (legacy Keras 2) so the .h5 weights load under recent TF/transformers.
os.environ['TF_USE_LEGACY_KERAS'] = "1"
# Base checkpoint the fine-tuned paraphraser weights are layered onto.
CHECKPOINT = "facebook/bart-base"
# Fixed token budget for encoder input and generated output respectively.
INPUT_N_TOKENS = 70
TARGET_N_TOKENS = 70
@st.cache_resource
def load_models():
    """Load the spaCy sentence splitter, BART tokenizer, and paraphraser model.

    Decorated with ``st.cache_resource`` because Streamlit re-executes the whole
    script on every widget interaction; without caching the heavy models would
    be reloaded from disk on each rerun.

    Returns:
        tuple: (nlp, tokenizer, model)
    """
    # Pinned local spaCy model directory — used only for sentence segmentation.
    nlp = spacy.load(os.path.join('.', 'en_core_web_sm-3.6.0'))
    tokenizer = BartTokenizerFast.from_pretrained(CHECKPOINT)
    model = TFAutoModelForSeq2SeqLM.from_pretrained(CHECKPOINT)
    # Overlay the fine-tuned paraphraser weights on the base checkpoint.
    model.load_weights(os.path.join("models", "bart_en_paraphraser.h5"), by_name=True)
    logging.warning('Loaded models')
    return nlp, tokenizer, model
# Load the heavyweight models once at import time; shared as module-level globals.
nlp, tokenizer, model = load_models()
def inference_tokenize(input_: list, n_tokens: int):
    """Encode `input_` sentences as fixed-length TF tensors.

    Pads/truncates every sentence to exactly `n_tokens` tokens and returns the
    module-level tokenizer alongside the encoded batch (the tokenizer is needed
    later to decode the generated ids).
    """
    encoded = tokenizer(
        text=input_,
        max_length=n_tokens,
        truncation=True,
        padding="max_length",
        return_tensors="tf",
    )
    return tokenizer, encoded
def process_result(result: str):
    """Clean one decoded model output into a plain sentence.

    The training targets were serialized lists like ``"p1", "p2", ...``, so
    keep only the first candidate, then strip residual angle-bracket special
    tokens (``<s>``, ``</s>``, ``<pad>`` ...) and a leading double quote.

    Args:
        result: raw text decoded from the model.

    Returns:
        The cleaned first paraphrase candidate.
    """
    # str.split always returns a non-empty list, so [0] is safe even when the
    # separator is absent (the original isinstance check here was dead code).
    result = result.split('", "')[0]
    # Drop any angle-bracket special tokens left over from decoding.
    result = re.sub(r'<.*?>', '', result).strip()
    # Drop the serialized list's leading double quote, if present.
    result = re.sub(r'^"', '', result).strip()
    return result
def inference(txt):
    """Paraphrase a batch of sentences with beam search.

    Args:
        txt: list of input sentences (a single string also works — the
            tokenizer accepts both).

    Returns:
        list[str]: one cleaned paraphrase per input sentence.

    Raises:
        Re-raises any tokenization/generation error after logging the
        traceback.
    """
    try:
        inference_tokenizer, tokenized_data = inference_tokenize(input_=txt, n_tokens=INPUT_N_TOKENS)
        pred = model.generate(**tokenized_data, max_new_tokens=TARGET_N_TOKENS, num_beams=4)
        result = [process_result(inference_tokenizer.decode(p, skip_special_tokens=True)) for p in pred]
        logging.warning(f'paraphrased_result: {result}')
        return result
    except Exception:  # was a bare except: — narrow to real errors, log, re-raise
        logging.warning(traceback.format_exc())
        raise
def inference_long_text(txt, n_sents):
    """Select up to `n_sents` sentences from `txt` and paraphrase them.

    Splits the text with spaCy, keeps only sentences of three or more words
    (skipping headings and stray fragments), and runs them through the model
    as a single batch behind a Streamlit spinner.

    Returns:
        tuple: (selected_sentences, paraphrased_sentences)
    """
    selected = []
    for sentence in nlp(txt).sents:
        if len(selected) >= n_sents:
            break
        # Fragments under three words are not worth paraphrasing.
        if len(sentence.text.split()) >= 3:
            selected.append(sentence.text)
    with st.spinner('Rewriting...'):
        rewritten = inference(selected)
    return selected, rewritten
############## ENTRY POINT START ####################### | |
def _paraphrase_step(input_txt, n_sents):
    """Normalize newlines and paraphrase; returns (input_sents, paraphrased, error_or_None)."""
    input_txt = re.sub(r'\n+', ' ', input_txt)
    try:
        st.info("Rewriting the text. This takes time.", icon="ℹ️")
        input_sents, paraphrased = inference_long_text(input_txt, n_sents)
    except Exception as e:
        st.error("Encountered an error while rewriting the text.", icon="🚨")
        return None, None, str(e)
    st.success("Successfully rewrote the text.", icon="✅")
    return input_sents, paraphrased, None


def _finish_status(status, error):
    """Collapse the status widget with a terminal state reflecting `error`."""
    if error is None:
        status.update(label="Done", state="complete", expanded=False)
    else:
        status.update(label="Error", state="error", expanded=False)


def _render_results(input_sents, paraphrased_sents):
    """Render original/rewritten sentence pairs as HTML in the page body."""
    rows = [
        f"<b>Scraped Sentence:</b> {scraped}<br><b>Rewritten Sentence:</b> {rewritten}"
        for scraped, rewritten in zip(input_sents, paraphrased_sents)
    ]
    html = "<br><br>".join(rows)
    # Escape `$` so st.markdown does not interpret it as inline LaTeX.
    # NOTE(review): the extracted source showed a no-op replace("$", "$");
    # the escaping backslash was presumably lost in extraction — confirm.
    html = html.replace("$", "\\$")
    st.markdown(f"{html}", unsafe_allow_html=True)


def main():
    """Streamlit entry point: take text (pasted or scraped from a URL),
    paraphrase up to the selected number of sentences, and show the results."""
    st.markdown('''<h3>Text Rewriter</h3>''', unsafe_allow_html=True)
    input_type = st.radio('Select an option:', ['Paste URL of the text', 'Paste the text'],
                          horizontal=True)
    n_sents = st.slider('Select the number of sentences to process', 5, 30, 10)

    if input_type == 'Paste URL of the text':
        input_url = st.text_input("Paste URL of the text", "")
        # Pressing Enter in the text box also triggers processing.
        if st.button("Submit") or input_url:
            scrape_error = None
            paraphrase_error = None
            input_sents = paraphrased = None
            with st.status("Processing...", expanded=True) as status:
                status.empty()
                # --- Scraping step ---
                try:
                    st.info("Scraping data from the URL.", icon="ℹ️")
                    input_txt = scrape_text(input_url)
                    st.success("Successfully scraped the data.", icon="✅")
                except Exception as e:
                    input_txt = None
                    scrape_error = str(e)
                # --- Paraphrasing step ---
                if input_txt is not None:
                    input_sents, paraphrased, paraphrase_error = _paraphrase_step(input_txt, n_sents)
                else:
                    st.error("Encountered an error while scraping the data.", icon="🚨")
                _finish_status(status, scrape_error or paraphrase_error)
            if scrape_error is not None:
                st.error(f"Scrape Error: \n{scrape_error}", icon="🚨")
            elif paraphrase_error is not None:
                st.error(f"Paraphrasing Error: \n{paraphrase_error}", icon="🚨")
            else:
                _render_results(input_sents, paraphrased)
    else:
        input_txt = st.text_area("Enter the text. (Ensure the text is grammatically correct and has punctuations at the right places):", "", height=150)
        if st.button("Submit") or input_txt:
            with st.status("Processing...", expanded=True) as status:
                input_sents, paraphrased, paraphrase_error = _paraphrase_step(input_txt, n_sents)
                _finish_status(status, paraphrase_error)
            if paraphrase_error is not None:
                st.error(f"Paraphrasing Error: \n{paraphrase_error}", icon="🚨")
            else:
                _render_results(input_sents, paraphrased)
############## ENTRY POINT END ####################### | |
# Run the Streamlit app when executed as a script.
if __name__ == "__main__":
    main()