File size: 5,065 Bytes
6725d4c
 
 
 
 
 
d26463a
 
6725d4c
d26463a
 
 
 
 
6725d4c
d26463a
 
 
4272847
d26463a
6725d4c
 
 
 
 
 
 
 
 
4272847
d26463a
4272847
6725d4c
 
d26463a
6725d4c
 
d26463a
 
6725d4c
 
 
d26463a
 
 
 
6725d4c
 
 
d26463a
6725d4c
 
 
 
 
 
 
 
 
 
 
 
d26463a
6725d4c
d26463a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6725d4c
 
d26463a
6725d4c
 
4272847
6725d4c
 
d26463a
 
6725d4c
 
d26463a
 
6725d4c
d26463a
6725d4c
d26463a
6725d4c
 
 
 
d26463a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import gradio as gr
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration, pipeline
from sentence_transformers import SentenceTransformer, util
import requests
import os
import warnings
from transformers import logging

# Suppress warnings. A single blanket "ignore" filter already covers every
# category (FutureWarning, UserWarning, ...), so per-category filters would
# be redundant.
warnings.filterwarnings("ignore")
# transformers' own logger: only report errors.
logging.set_verbosity_error()

# Set API keys and environment variables.
# Ensure GROQ_API_KEY is set in Hugging Face Spaces; os.getenv yields None if it is missing.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
# NOTE(review): TensorFlow presumably reads this only when it is first
# imported; if a library above already imported TF, this has no effect —
# consider moving it above the imports.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Groq API sentence segmentation
def segment_into_sentences_groq(passage, *, timeout=30, temperature=0.0):
    """Split *passage* into sentences using the Groq chat-completions API.

    The model is asked to append the marker ``1!2@3#`` after every sentence.
    The reply is split on that marker, and blank fragments are dropped.

    Args:
        passage: Text to segment. Blank or ``None`` input returns ``[]``
            without contacting the API.
        timeout: Seconds to wait for the HTTP request before ``requests``
            gives up and raises. Without it, a stalled request hangs forever.
        temperature: Sampling temperature sent to the model. Segmentation
            should reproduce the passage verbatim, so the default (0.0) is
            deterministic rather than sampled.

    Returns:
        List of stripped, non-empty sentence strings.

    Raises:
        ValueError: If ``GROQ_API_KEY`` is not set, or if the API responds
            with a status other than 200.
    """
    marker = "1!2@3#"
    if not passage or not passage.strip():
        # Nothing to segment; skip the network round-trip entirely.
        return []
    if not GROQ_API_KEY:
        # Otherwise the request would be sent with "Bearer None".
        raise ValueError("GROQ_API_KEY is not set; cannot call the Groq API.")

    headers = {
        "Authorization": f"Bearer {GROQ_API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": "llama3-8b-8192",
        "messages": [
            {
                "role": "system",
                "content": "Segment sentences by adding '1!2@3#' at the end of each sentence."
            },
            {
                "role": "user",
                "content": f"Segment the passage: {passage}"
            }
        ],
        "temperature": temperature,
        "max_tokens": 8192,
    }
    response = requests.post(
        "https://api.groq.com/openai/v1/chat/completions",
        json=payload,
        headers=headers,
        timeout=timeout,
    )
    if response.status_code != 200:
        raise ValueError(f"Groq API error: {response.text}")

    data = response.json()
    # Tolerate an empty "choices" list or null message/content instead of
    # crashing with IndexError / AttributeError.
    choices = data.get("choices") or [{}]
    message = choices[0].get("message") or {}
    segmented_text = message.get("content") or ""
    # NOTE(review): the model may prepend chatter (e.g. "Here is the
    # segmented passage:") which would end up in the first sentence.
    return [piece.strip() for piece in segmented_text.split(marker) if piece.strip()]

# Text enhancement class
class TextEnhancer:
    """Paraphrase each sentence of a passage while preserving its meaning.

    For every sentence the pipeline:
    1. generates beam-search paraphrases with a T5 paraphraser;
    2. scores them against the original with a sentence-embedding model;
    3. keeps the most similar paraphrase that clears the threshold;
    4. runs that paraphrase through a grammar-correction model.

    A sentence with no acceptable paraphrase is kept unchanged.
    """

    # CoEdIT is instruction-tuned: its model card prefixes every input with an
    # edit instruction. Without one, the model is not told which edit to make.
    GRAMMAR_INSTRUCTION = "Fix grammatical errors in this sentence: "
    # Punctuation that already ends a sentence.
    _TERMINAL_PUNCTUATION = (".", "!", "?", "…")
    # Closing quotes/brackets that may legitimately follow terminal punctuation.
    _TRAILING_CLOSERS = "\"')]}»”’"

    def __init__(self):
        """Load the paraphrase, grammar and similarity models onto GPU if available, else CPU."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.paraphrase_tokenizer = AutoTokenizer.from_pretrained("prithivida/parrot_paraphraser_on_T5")
        self.paraphrase_model = T5ForConditionalGeneration.from_pretrained("prithivida/parrot_paraphraser_on_T5").to(self.device)
        self.grammar_pipeline = pipeline(
            "text2text-generation",
            model="Grammarly/coedit-large",
            device=0 if self.device == "cuda" else -1
        )
        self.similarity_model = SentenceTransformer('paraphrase-MiniLM-L6-v2').to(self.device)

    def enhance_text(self, text, min_similarity=0.8, max_variations=2):
        """Return *text* with each sentence replaced by a close, grammar-corrected paraphrase.

        Args:
            text: Passage to enhance. It is split into sentences by
                ``segment_into_sentences_groq``.
            min_similarity: Minimum cosine similarity between a sentence and a
                paraphrase for the paraphrase to be used.
            max_variations: Number of beam-search paraphrases generated per
                sentence.

        Returns:
            The resulting sentences joined by single spaces, each ending in
            terminal punctuation. Blank input returns ``""``.

        Raises:
            ValueError: If ``max_variations`` < 1. Errors from sentence
                segmentation also propagate.
        """
        if not text or not text.strip():
            return ""
        if max_variations < 1:
            raise ValueError("max_variations must be at least 1")

        enhanced_sentences = []
        for sentence in segment_into_sentences_groq(text):
            if not sentence.strip():
                continue
            paraphrases = self._paraphrase(sentence, max_variations)
            similarities = self._similarities(sentence, paraphrases)
            best = self._select_paraphrase(paraphrases, similarities, min_similarity)
            if best is None:
                # No paraphrase is close enough; keep the original wording.
                enhanced_sentences.append(sentence)
            else:
                enhanced_sentences.append(self._correct_grammar(best))

        return self._join_sentences(enhanced_sentences)

    def _paraphrase(self, sentence, count):
        """Return *count* beam-search paraphrases of *sentence*."""
        inputs = self.paraphrase_tokenizer(
            f"paraphrase: {sentence}",
            return_tensors="pt",
            padding=True,
            max_length=150,
            truncation=True,
        ).to(self.device)
        outputs = self.paraphrase_model.generate(
            **inputs,
            max_length=150,
            num_return_sequences=count,
            num_beams=count,
        )
        return self.paraphrase_tokenizer.batch_decode(outputs, skip_special_tokens=True)

    def _similarities(self, sentence, paraphrases):
        """Return the cosine similarity of each paraphrase to *sentence*, as Python floats."""
        sentence_embedding = self.similarity_model.encode(sentence)
        paraphrase_embeddings = self.similarity_model.encode(paraphrases)
        # cos_sim returns a (1, len(paraphrases)) matrix for a single sentence.
        return util.cos_sim(sentence_embedding, paraphrase_embeddings)[0].tolist()

    @staticmethod
    def _select_paraphrase(paraphrases, similarities, min_similarity):
        """Return the most similar paraphrase scoring >= *min_similarity*, or None.

        Ties go to the earlier paraphrase, i.e. the higher-ranked beam.
        """
        candidates = [
            (score, para)
            for para, score in zip(paraphrases, similarities)
            if score >= min_similarity
        ]
        if not candidates:
            return None
        return max(candidates, key=lambda candidate: candidate[0])[1]

    def _correct_grammar(self, sentence):
        """Run *sentence* through the CoEdIT grammar-correction pipeline."""
        result = self.grammar_pipeline(
            self.GRAMMAR_INSTRUCTION + sentence,
            max_length=150,
            num_return_sequences=1,
        )
        return result[0]["generated_text"]

    @classmethod
    def _join_sentences(cls, sentences):
        """Join *sentences* with single spaces, adding a period only where one is missing.

        Segmented and paraphrased sentences usually keep their own
        punctuation. Joining with ". " would therefore produce "..".
        """
        finished = []
        for sentence in sentences:
            sentence = sentence.strip()
            if not sentence:
                continue
            if not sentence.rstrip(cls._TRAILING_CLOSERS).endswith(cls._TERMINAL_PUNCTUATION):
                sentence += "."
            finished.append(sentence)
        return " ".join(finished)

# Gradio interface
def create_interface():
    """Build the Gradio UI around one shared ``TextEnhancer``.

    The enhancer is constructed once, when this function runs, and reused for
    every request. Any exception raised while handling a request is returned
    to the user as ``"Error: <message>"`` instead of propagating.
    """
    text_enhancer = TextEnhancer()

    def process_text(text, similarity_threshold):
        # The slider is in percent; enhance_text takes a 0-1 fraction.
        try:
            fraction = similarity_threshold / 100
            return text_enhancer.enhance_text(text, min_similarity=fraction)
        except Exception as err:
            return f"Error: {str(err)}"

    input_box = gr.Textbox(lines=10, placeholder="Enter text to enhance...", label="Input Text")
    threshold_slider = gr.Slider(minimum=50, maximum=100, value=80, label="Minimum Semantic Similarity (%)")
    output_box = gr.Textbox(lines=10, label="Enhanced Text")

    return gr.Interface(
        fn=process_text,
        inputs=[input_box, threshold_slider],
        outputs=output_box,
        title="Text Enhancement System",
        description="Enhance text quality with semantic preservation.",
    )

if __name__ == "__main__":
    # Build the UI and serve it on all network interfaces, port 7860.
    create_interface().launch(server_name="0.0.0.0", server_port=7860)