Spaces:
Sleeping
Sleeping
File size: 7,115 Bytes
ae0deeb fde44b6 74611a4 ae0deeb fde44b6 ae0deeb 0e7d7b7 ae0deeb 0e7d7b7 e71b6e8 ae0deeb 0e7d7b7 ae0deeb 07c92c4 0e7d7b7 ae0deeb 0e7d7b7 ae0deeb 0e7d7b7 ae0deeb 0e7d7b7 ae0deeb bf8ef43 39cc3c6 bf8ef43 ae0deeb e71b6e8 ae0deeb bf8ef43 39cc3c6 bf8ef43 ae0deeb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
import numpy as np
import pandas as pd
import re
import os
import cloudpickle
from transformers import BartTokenizerFast, TFAutoModelForSeq2SeqLM
import tensorflow as tf
import spacy
import streamlit as st
import logging
import traceback
from scraper import scrape_text
# Training data https://www.kaggle.com/datasets/vladimirvorobevv/chatgpt-paraphrases
# NOTE(review): this is set AFTER `import tensorflow`/`import transformers` above;
# such env switches are often read at import time, so it may have no effect here —
# confirm, and consider moving it above the imports.
os.environ['TF_USE_LEGACY_KERAS'] = "1"
# Base HuggingFace checkpoint the paraphraser weights were fine-tuned from.
CHECKPOINT = "facebook/bart-base"
# Maximum token length for encoder input and for generated output, respectively.
INPUT_N_TOKENS = 70
TARGET_N_TOKENS = 70
@st.cache_resource
def load_models():
    """Load and cache the NLP models used by this app.

    Returns:
        tuple: (spaCy pipeline for sentence splitting,
                BART fast tokenizer,
                TF seq2seq model with the fine-tuned paraphraser weights).
    """
    sentence_splitter = spacy.load(os.path.join('.', 'en_core_web_sm-3.6.0'))
    bart_tokenizer = BartTokenizerFast.from_pretrained(CHECKPOINT)
    paraphraser = TFAutoModelForSeq2SeqLM.from_pretrained(CHECKPOINT)
    # Overlay the fine-tuned paraphrasing weights on top of the base checkpoint.
    paraphraser.load_weights(os.path.join("models", "bart_en_paraphraser.h5"), by_name=True)
    logging.warning('Loaded models')
    return sentence_splitter, bart_tokenizer, paraphraser
# Module-level singletons shared by the helpers below (cached by st.cache_resource).
nlp, tokenizer, model = load_models()
def inference_tokenize(input_: list, n_tokens: int):
    """Tokenize a batch of sentences for the paraphraser.

    Args:
        input_: sentences to encode (each truncated/padded to `n_tokens`).
        n_tokens: fixed sequence length for the encoder.

    Returns:
        tuple: (the module-level tokenizer, a TF tensor batch ready for
        `model.generate`).
    """
    batch = tokenizer(
        text=input_,
        max_length=n_tokens,
        truncation=True,
        padding="max_length",
        return_tensors="tf",
    )
    return tokenizer, batch
def process_result(result: str):
    """Clean one decoded model output into a single paraphrase.

    The model was trained on JSON-style lists of paraphrases, so a decoded
    string can contain several alternatives separated by '", "' and stray
    quotes/special tokens. This keeps only the first alternative.

    Args:
        result: raw string produced by `tokenizer.decode`.

    Returns:
        str: the first paraphrase, stripped of residual markup.
    """
    # str.split always returns a list, so take the first candidate directly.
    # (The original `isinstance(result, str)` check after split was dead code.)
    result = result.split('", "')[0]
    # Remove residual angle-bracket tokens such as "<pad>" or "</s>".
    result = re.sub(r'(<.*?>)', "", result).strip()
    # Drop a leading double quote left over from the JSON-list training format.
    result = re.sub(r'^\"', "", result).strip()
    return result
def inference(txt):
    """Paraphrase a batch of sentences with the fine-tuned BART model.

    Args:
        txt: list of input sentences to rewrite.

    Returns:
        list[str]: one cleaned paraphrase per input sentence.

    Raises:
        Exception: anything raised during tokenization/generation is logged
            with its traceback and re-raised for the UI layer to report.
    """
    try:
        inference_tokenizer, tokenized_data = inference_tokenize(input_=txt, n_tokens=INPUT_N_TOKENS)
        # Beam search keeps outputs fluent; max_new_tokens caps paraphrase length.
        pred = model.generate(**tokenized_data, max_new_tokens=TARGET_N_TOKENS, num_beams=4)
        result = [process_result(inference_tokenizer.decode(p, skip_special_tokens=True)) for p in pred]
        logging.warning(f'paraphrased_result: {result}')
        return result
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are
        # not intercepted; the error is logged here and propagated to the caller.
        logging.warning(traceback.format_exc())
        raise
def inference_long_text(txt, n_sents):
    """Split `txt` into sentences and paraphrase up to `n_sents` of them.

    Args:
        txt: free-form input text.
        n_sents: maximum number of sentences to paraphrase.

    Returns:
        tuple: (list of selected source sentences, list of their paraphrases).
        Both lists are empty when no usable sentence is found.
    """
    doc = nlp(txt)
    input_txt = []
    for sent in doc.sents:
        if len(input_txt) >= n_sents:
            break
        # Skip fragments shorter than 3 words — too little context to rewrite.
        if len(sent.text.split()) >= 3:
            input_txt.append(sent.text)
    # Guard: nothing usable to paraphrase (empty text or only short fragments).
    # Avoids feeding an empty batch to the tokenizer/model, which would error.
    if not input_txt:
        return [], []
    with st.spinner('Rewriting...'):
        paraphrased_txt = inference(input_txt)
    return input_txt, paraphrased_txt
############## ENTRY POINT START #######################
def _run_paraphrase(input_txt, n_sents):
    """Run the rewriting pipeline with status messages; returns (sentences, paraphrases, error_or_None)."""
    paraphrased_txt = None
    paraphrase_error = None
    # Collapse newlines so spaCy sees continuous prose.
    input_txt = re.sub(r'\n+', ' ', input_txt)
    try:
        # NOTE(review): original icons were mojibake ("βΉοΈ", "β"+NEL, "π¨")
        # of ℹ️ / ✅ / 🚨 — restored here.
        st.info("Rewriting the text. This takes time.", icon="ℹ️")
        input_txt, paraphrased_txt = inference_long_text(input_txt, n_sents)
    except Exception as e:
        paraphrased_txt = None
        paraphrase_error = str(e)
    if paraphrased_txt is not None:
        st.success("Successfully rewrote the text.", icon="✅")
    else:
        st.error("Encountered an error while rewriting the text.", icon="🚨")
    return input_txt, paraphrased_txt, paraphrase_error


def _render_pairs(input_txt, paraphrased_txt):
    """Render source/rewritten sentence pairs as HTML in the Streamlit page."""
    result = [
        f"<b>Scraped Sentence:</b> {scraped}<br><b>Rewritten Sentence:</b> {paraphrased}"
        for scraped, paraphrased in zip(input_txt, paraphrased_txt)
    ]
    result = "<br><br>".join(result)
    # Escape dollar signs so Streamlit markdown does not treat them as LaTeX
    # delimiters. (The original `replace("$", "$")` was a no-op, presumably a
    # mojibake'd "\$" — confirm intended escaping.)
    result = result.replace("$", "\\$")
    st.markdown(f"{result}", unsafe_allow_html=True)


def main():
    """Streamlit entry point: accept a URL or pasted text, rewrite it, show results."""
    st.markdown('''<h3>Text Rewriter</h3>''', unsafe_allow_html=True)
    input_type = st.radio('Select an option:', ['Paste URL of the text', 'Paste the text'],
                          horizontal=True)
    n_sents = st.slider('Select the number of sentences to process', 5, 30, 10)
    scrape_error = None
    paraphrase_error = None
    paraphrased_txt = None
    input_txt = None
    if input_type == 'Paste URL of the text':
        input_url = st.text_input("Paste URL of the text", "")
        if (st.button("Submit")) or (input_url):
            with st.status("Processing...", expanded=True) as status:
                status.empty()
                # Scraping data Start
                try:
                    st.info("Scraping data from the URL.", icon="ℹ️")
                    input_txt = scrape_text(input_url)
                    st.success("Successfully scraped the data.", icon="✅")
                except Exception as e:
                    input_txt = None
                    scrape_error = str(e)
                # Scraping data End
                if input_txt is not None:
                    input_txt, paraphrased_txt, paraphrase_error = _run_paraphrase(input_txt, n_sents)
                else:
                    st.error("Encountered an error while scraping the data.", icon="🚨")
                if (scrape_error is None) and (paraphrase_error is None):
                    status.update(label="Done", state="complete", expanded=False)
                else:
                    status.update(label="Error", state="error", expanded=False)
            if scrape_error is not None:
                st.error(f"Scrape Error: \n{scrape_error}", icon="🚨")
            elif paraphrase_error is not None:
                st.error(f"Paraphrasing Error: \n{paraphrase_error}", icon="🚨")
            else:
                _render_pairs(input_txt, paraphrased_txt)
    else:
        input_txt = st.text_area("Enter the text. (Ensure the text is grammatically correct and has punctuations at the right places):", "", height=150)
        if (st.button("Submit")) or (input_txt):
            with st.status("Processing...", expanded=True) as status:
                input_txt, paraphrased_txt, paraphrase_error = _run_paraphrase(input_txt, n_sents)
                if paraphrase_error is None:
                    status.update(label="Done", state="complete", expanded=False)
                else:
                    status.update(label="Error", state="error", expanded=False)
            if paraphrase_error is not None:
                st.error(f"Paraphrasing Error: \n{paraphrase_error}", icon="🚨")
            else:
                _render_pairs(input_txt, paraphrased_txt)
############## ENTRY POINT END #######################
if __name__ == "__main__":
main() |