Spaces:
Sleeping
Sleeping
File size: 5,065 Bytes
6725d4c d26463a 6725d4c d26463a 6725d4c d26463a 4272847 d26463a 6725d4c 4272847 d26463a 4272847 6725d4c d26463a 6725d4c d26463a 6725d4c d26463a 6725d4c d26463a 6725d4c d26463a 6725d4c d26463a 6725d4c d26463a 6725d4c 4272847 6725d4c d26463a 6725d4c d26463a 6725d4c d26463a 6725d4c d26463a 6725d4c d26463a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import gradio as gr
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration, pipeline
from sentence_transformers import SentenceTransformer, util
import requests
import os
import warnings
from transformers import logging
# Quiet third-party noise. The two category-specific filters are installed
# first and the catch-all last, so the resulting filter order is preserved.
warnings.simplefilter("ignore", FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.simplefilter("ignore")
# `logging` is transformers.logging (see the imports), not the stdlib module.
logging.set_verbosity_error()

# Process-environment configuration.
# GROQ_API_KEY is None when the variable is not set (e.g. a missing secret
# in Hugging Face Spaces).
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
# TensorFlow's native log-level switch.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
# Groq API sentence segmentation
def segment_into_sentences_groq(passage, timeout=30):
    """Split *passage* into sentences by asking the Groq chat-completions API.

    The model is instructed to append the marker ``1!2@3#`` after every
    sentence; the reply is then split on that marker.

    Args:
        passage: Text to segment. ``None``, empty, or whitespace-only input
            yields an empty list without contacting the API.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The non-empty, whitespace-stripped pieces of the model's reply, in
        the order they appeared.

    Raises:
        ValueError: If ``GROQ_API_KEY`` is not set, the request fails or
            times out, or the API answers with a non-200 status.
    """
    # Nothing to segment: skip the network round trip entirely.
    if not passage or not passage.strip():
        return []
    # Fail early instead of sending the literal header "Bearer None".
    if not GROQ_API_KEY:
        raise ValueError("Groq API error: GROQ_API_KEY is not set")

    delimiter = "1!2@3#"
    headers = {
        "Authorization": f"Bearer {GROQ_API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": "llama3-8b-8192",
        "messages": [
            {
                "role": "system",
                "content": "Segment sentences by adding '1!2@3#' at the end of each sentence.",
            },
            {
                "role": "user",
                "content": f"Segment the passage: {passage}",
            },
        ],
        # NOTE(review): a sampling temperature of 1.0 lets the model reword
        # the passage while segmenting it; consider lowering it.
        "temperature": 1.0,
        "max_tokens": 8192,
    }

    try:
        # Without a timeout a stalled connection would block the caller
        # forever.
        response = requests.post(
            "https://api.groq.com/openai/v1/chat/completions",
            json=payload,
            headers=headers,
            timeout=timeout,
        )
    except requests.RequestException as err:
        # Keep a single exception type for every API failure.
        raise ValueError(f"Groq API error: {err}") from err

    if response.status_code != 200:
        raise ValueError(f"Groq API error: {response.text}")

    data = response.json()
    # Tolerate an empty "choices" list or a null "message"/"content" instead
    # of raising IndexError/AttributeError.
    choices = data.get("choices") or [{}]
    segmented_text = (choices[0].get("message") or {}).get("content") or ""
    return [piece.strip() for piece in segmented_text.split(delimiter) if piece.strip()]
# Text enhancement class
class TextEnhancer:
    """Paraphrase text sentence by sentence and grammar-correct the result.

    Three models are loaded once in ``__init__``:

    * a T5 paraphraser (``prithivida/parrot_paraphraser_on_T5``),
    * a text2text grammar pipeline (``Grammarly/coedit-large``),
    * a sentence-embedding model (``paraphrase-MiniLM-L6-v2``) used to score
      how close each paraphrase stays to its source sentence.
    """

    # Token budget shared by tokenisation, paraphrase generation and grammar
    # correction.
    _MAX_LENGTH = 150
    # Characters that already terminate a sentence ("\u2026" is the ellipsis).
    _TERMINATORS = (".", "!", "?", "\u2026")
    # Closing quotes/brackets that may legitimately follow a terminator.
    _CLOSERS = "\"')]}\u201d\u2019"

    def __init__(self):
        """Load all models, placing them on CUDA when it is available."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.paraphrase_tokenizer = AutoTokenizer.from_pretrained(
            "prithivida/parrot_paraphraser_on_T5"
        )
        self.paraphrase_model = T5ForConditionalGeneration.from_pretrained(
            "prithivida/parrot_paraphraser_on_T5"
        ).to(self.device)
        self.grammar_pipeline = pipeline(
            "text2text-generation",
            model="Grammarly/coedit-large",
            device=0 if self.device == "cuda" else -1,
        )
        self.similarity_model = SentenceTransformer(
            "paraphrase-MiniLM-L6-v2"
        ).to(self.device)

    def enhance_text(self, text, min_similarity=0.8, max_variations=2):
        """Return an enhanced version of *text*.

        The text is segmented with ``segment_into_sentences_groq``; each
        sentence is paraphrased, and the first paraphrase whose cosine
        similarity to the original reaches *min_similarity* is
        grammar-corrected and used. Sentences without an acceptable
        paraphrase are kept unchanged.

        Args:
            text: Input text. ``None``, empty, or whitespace-only input
                returns ``""`` without calling any model or API.
            min_similarity: Minimum cosine similarity a paraphrase must
                reach to replace the original sentence.
            max_variations: Number of paraphrase candidates (and beams) per
                sentence; values below 1 are treated as 1.

        Returns:
            The enhanced sentences joined by single spaces. A period is
            appended only to sentences lacking terminal punctuation.
        """
        if not text or not text.strip():
            return ""
        # generate() needs at least one beam / returned sequence.
        max_variations = max(1, int(max_variations))

        enhanced_sentences = [
            self._enhance_sentence(sentence, min_similarity, max_variations)
            for sentence in segment_into_sentences_groq(text)
            if sentence.strip()
        ]
        return self._join_sentences(enhanced_sentences)

    def _paraphrase(self, sentence, max_variations):
        """Generate *max_variations* beam-search paraphrases of *sentence*."""
        inputs = self.paraphrase_tokenizer(
            f"paraphrase: {sentence}",
            return_tensors="pt",
            padding=True,
            max_length=self._MAX_LENGTH,
            truncation=True,
        ).to(self.device)
        outputs = self.paraphrase_model.generate(
            **inputs,
            max_length=self._MAX_LENGTH,
            num_return_sequences=max_variations,
            num_beams=max_variations,
        )
        return [
            self.paraphrase_tokenizer.decode(output, skip_special_tokens=True)
            for output in outputs
        ]

    def _enhance_sentence(self, sentence, min_similarity, max_variations):
        """Return a grammar-corrected paraphrase of *sentence*, or *sentence*."""
        paraphrases = self._paraphrase(sentence, max_variations)

        # Score every candidate against the original sentence.
        sentence_embedding = self.similarity_model.encode(sentence)
        paraphrase_embeddings = self.similarity_model.encode(paraphrases)
        similarities = util.cos_sim(sentence_embedding, paraphrase_embeddings)[0]

        # Take the first candidate, in generation order, that clears the
        # threshold. This is not necessarily the most similar one.
        chosen = next(
            (
                candidate
                for candidate, score in zip(paraphrases, similarities)
                if score >= min_similarity
            ),
            None,
        )
        if chosen is None:
            return sentence

        # NOTE(review): the candidate is passed to the grammar model without
        # any instruction prefix - confirm that is what the model expects.
        return self.grammar_pipeline(
            chosen,
            max_length=self._MAX_LENGTH,
            num_return_sequences=1,
        )[0]["generated_text"]

    @classmethod
    def _join_sentences(cls, sentences):
        """Join *sentences* with spaces, adding a period only where missing."""
        finished = []
        for sentence in sentences:
            sentence = sentence.strip()
            if not sentence:
                continue
            # Look past closing quotes/brackets so 'He said "go."' is not
            # turned into 'He said "go.".'.
            if not sentence.rstrip(cls._CLOSERS).endswith(cls._TERMINATORS):
                sentence += "."
            finished.append(sentence)
        return " ".join(finished)
# Gradio interface
def create_interface():
    """Build the Gradio interface around one shared ``TextEnhancer``.

    The enhancer is constructed once, when this function is called, and is
    reused by every request handled by the returned interface.

    Returns:
        A ``gr.Interface`` taking the input text plus a similarity
        percentage and producing the enhanced text.
    """
    enhancer = TextEnhancer()

    def process_text(text, similarity_threshold):
        """Run the enhancer; report any failure as text instead of raising."""
        try:
            # The slider works in percent, the enhancer in the 0-1 range.
            result = enhancer.enhance_text(
                text,
                min_similarity=similarity_threshold / 100,
            )
        except Exception as e:
            result = f"Error: {str(e)}"
        return result

    text_input = gr.Textbox(
        lines=10,
        placeholder="Enter text to enhance...",
        label="Input Text",
    )
    similarity_slider = gr.Slider(
        minimum=50,
        maximum=100,
        value=80,
        label="Minimum Semantic Similarity (%)",
    )
    text_output = gr.Textbox(lines=10, label="Enhanced Text")

    return gr.Interface(
        fn=process_text,
        inputs=[text_input, similarity_slider],
        outputs=text_output,
        title="Text Enhancement System",
        description="Enhance text quality with semantic preservation.",
    )
if __name__ == "__main__":
    # Script entry point: build the UI and serve it on all network
    # interfaces, port 7860.
    create_interface().launch(server_name="0.0.0.0", server_port=7860)
|