File size: 7,933 Bytes
3cb2a3f
b16db73
 
 
 
 
 
 
 
 
 
aa47a7c
42cb48e
b16db73
 
42cb48e
b16db73
 
aa47a7c
 
 
 
 
 
b16db73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0efdb28
 
 
 
 
 
 
42cb48e
aa47a7c
42cb48e
 
aa47a7c
 
 
42cb48e
 
 
1a5d3d0
0efdb28
 
b16db73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa47a7c
b16db73
 
 
 
 
55c903d
7da5361
b16db73
 
 
 
 
 
 
 
aa47a7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b16db73
0efdb28
 
b16db73
55c903d
aa47a7c
 
 
b16db73
0efdb28
b16db73
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import streamlit as st
from transformers import AutoModel, AutoTokenizer, Qwen2VLForConditionalGeneration, AutoProcessor
from surya.ocr import run_ocr
from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
from surya.model.recognition.model import load_model as load_rec_model
from surya.model.recognition.processor import load_processor as load_rec_processor
from PIL import Image
import torch
import tempfile
import os
import re
import json
from groq import Groq

# Page configuration
# The icon is written as an escape sequence (U+1F50D, magnifying glass) so a
# wrong file/transport encoding cannot garble it into mojibake again.
st.set_page_config(page_title="DualTextOCRFusion", page_icon="\U0001F50D", layout="wide")
device = "cuda" if torch.cuda.is_available() else "cpu"

# Directories for uploaded images and cached OCR results
IMAGES_DIR = "images"
RESULTS_DIR = "results"
os.makedirs(IMAGES_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)


# Load Surya OCR Models (English + Hindi)
@st.cache_resource
def init_surya_models():
    """Load the Surya detection and recognition models and cache them.

    Returns:
        ``(det_processor, det_model, rec_model, rec_processor)`` with both
        models moved to the module-level ``device``.
    """
    det_proc, det_mdl = load_det_processor(), load_det_model()
    det_mdl.to(device)
    rec_mdl, rec_proc = load_rec_model(), load_rec_processor()
    rec_mdl.to(device)
    return det_proc, det_mdl, rec_mdl, rec_proc


# Streamlit re-executes this whole script on every widget interaction; going
# through the cached loader keeps those reruns from reloading the weights.
det_processor, det_model, rec_model, rec_processor = init_surya_models()

# Load GOT Models
@st.cache_resource
def init_got_model():
    """Load the CPU build of GOT-OCR (``srimanth-d/GOT_CPU``) once and cache it.

    Returns:
        ``(model, tokenizer)`` with the model switched to eval mode.
    """
    repo_id = 'srimanth-d/GOT_CPU'
    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModel.from_pretrained(
        repo_id,
        trust_remote_code=True,
        use_safetensors=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    model.eval()
    return model, tokenizer

@st.cache_resource
def init_got_gpu_model():
    """Load the GPU build of GOT-OCR 2.0 (``ucaslcl/GOT-OCR2_0``) once and cache it.

    Returns:
        ``(model, tokenizer)`` with the model on CUDA and in eval mode.

    Raises:
        RuntimeError: if no CUDA device is available. The sidebar offers this
            model unconditionally, and without this check the failure comes
            from deep inside ``from_pretrained(device_map='cuda')``.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("GOT_GPU requires a CUDA-capable GPU; choose GOT_CPU on this machine.")
    tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
    model = AutoModel.from_pretrained(
        'ucaslcl/GOT-OCR2_0',
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        device_map='cuda',
        use_safetensors=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    return model.eval().cuda(), tokenizer

# Load Qwen Model
@st.cache_resource
def init_qwen_model():
    """Load Qwen2-VL-2B-Instruct and its processor on CPU once and cache them.

    Returns:
        ``(model, processor)`` with the model in eval mode.
    """
    checkpoint = "Qwen/Qwen2-VL-2B-Instruct"
    # NOTE(review): float16 weights on a CPU device map; some torch builds lack
    # Half-precision CPU kernels for certain ops — confirm on the deployed torch.
    qwen = Qwen2VLForConditionalGeneration.from_pretrained(
        checkpoint,
        device_map="cpu",
        torch_dtype=torch.float16,
    )
    qwen_processor = AutoProcessor.from_pretrained(checkpoint)
    qwen.eval()
    return qwen, qwen_processor

# Text cleaning - normalise whitespace; the rules are script-agnostic, so they
# apply the same way to Hindi and English text.
def clean_extracted_text(text):
    """Collapse whitespace runs and drop the space before ``? . ! ,``.

    Args:
        text: Raw OCR output.

    Returns:
        The text with every whitespace run reduced to a single space, leading
        and trailing whitespace removed, and a whitespace character directly
        preceding one of ``?.!,`` removed.
    """
    collapsed = re.sub(r'\s+', ' ', text)
    trimmed = collapsed.strip()
    return re.sub(r'\s([?.!,])', r'\1', trimmed)

# Polish the text using a model
def polish_text_with_ai(cleaned_text):
    """Ask a Groq-hosted LLM to repair word spacing in OCR output.

    The API key is read from the ``GROQ_API_KEY`` environment variable; it must
    never be embedded in source. Polishing is best-effort: when there is no
    text, no key, or the model returns no content, *cleaned_text* is returned
    unchanged.

    Args:
        cleaned_text: Whitespace-normalised OCR text (Hindi, English or Hinglish).

    Returns:
        The corrected text, or *cleaned_text* when polishing is skipped.
    """
    if not cleaned_text or not cleaned_text.strip():
        return cleaned_text
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        return cleaned_text
    # The extracted text must actually be part of the prompt; previously the
    # model only received the instruction and its reply replaced the OCR text.
    prompt = (
        "Remove unwanted spaces between and inside words to join incomplete words, "
        "creating a meaningful sentence in either Hindi, English, or Hinglish without "
        "altering any words from the given extracted text. Then, return the corrected "
        "text with adjusted spaces. Reply with the corrected text only.\n\n"
        f"Extracted text:\n{cleaned_text}"
    )
    client = Groq(api_key=api_key)
    chat_completion = client.chat.completions.create(
        messages=[
            {"role": "system", "content": "You are a pedantic sentence corrector."},
            {"role": "user", "content": prompt},
        ],
        model="gemma2-9b-it",
    )
    polished_text = chat_completion.choices[0].message.content
    return polished_text or cleaned_text

# Extract text using GOT
def extract_text_got(image_file, model, tokenizer):
    """Transcribe *image_file* with a GOT-OCR model in plain-text OCR mode.

    Args:
        image_file: Image path handed straight to ``model.chat``.
        model: A GOT model exposing the remote-code ``chat`` method.
        tokenizer: The tokenizer loaded alongside *model*.

    Returns:
        Whatever ``model.chat`` returns for ``ocr_type='ocr'``.
    """
    transcription = model.chat(tokenizer, image_file, ocr_type='ocr')
    return transcription

# Extract text using Qwen
def extract_text_qwen(image_file, model, processor, max_new_tokens=1024):
    """Run Qwen2-VL on one image and return only the newly generated text.

    Args:
        image_file: Path (or file object) accepted by ``PIL.Image.open``.
        model: A ``Qwen2VLForConditionalGeneration`` instance.
        processor: The matching ``AutoProcessor``.
        max_new_tokens: Upper bound on generated tokens. Without it ``generate``
            falls back to the generation config's length limit, which can stop
            decoding long before the image's text is transcribed.

    Returns:
        The decoded transcription, ``"No text extracted from the image."`` when
        nothing was generated, or ``"An error occurred: ..."`` if any step raised.
    """
    try:
        # Context manager closes the underlying file once the RGB copy is made.
        with Image.open(image_file) as img:
            image = img.convert('RGB')
        conversation = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Extract text from this image."}]}]
        text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        inputs = processor(text=[text_prompt], images=[image], return_tensors="pt").to(model.device)
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
        # generate() returns prompt + continuation; keep only the continuation so
        # the chat template and instruction are not echoed into the result.
        new_ids = [out[len(inp):] for inp, out in zip(inputs["input_ids"], output_ids)]
        output_text = processor.batch_decode(new_ids, skip_special_tokens=True)
        if output_text and output_text[0].strip():
            return output_text[0].strip()
        return "No text extracted from the image."
    except Exception as e:
        # Broad catch is deliberate: the UI shows this string instead of crashing.
        return f"An error occurred: {str(e)}"

# Highlight keyword search
def highlight_text(text, search_term):
    """Return *text* as HTML with case-insensitive matches of *search_term* highlighted.

    The result is meant to be rendered with ``unsafe_allow_html=True``, and
    *text* is OCR output (untrusted), so every segment — matched or not — is
    HTML-escaped; only the highlight ``<span>`` wrappers are raw markup.

    Args:
        text: Text to search.
        search_term: Literal (non-regex) term; falsy means no highlighting.

    Returns:
        The escaped text with each match wrapped in a yellow-background span.
    """
    import html

    if not search_term:
        return html.escape(text, quote=False)
    # Match against the raw text, then escape pieces, so searching for
    # characters like "&" or "<" still works.
    pattern = re.compile(re.escape(search_term), re.IGNORECASE)
    pieces = []
    cursor = 0
    for match in pattern.finditer(text):
        pieces.append(html.escape(text[cursor:match.start()], quote=False))
        pieces.append(
            f'<span style="background-color: yellow;">{html.escape(match.group(), quote=False)}</span>'
        )
        cursor = match.end()
    pieces.append(html.escape(text[cursor:], quote=False))
    return "".join(pieces)

# Helpers for the main page flow
def _resolve_image_path(uploaded, pasted_path):
    """Return a local path for the chosen image, or None if nothing usable was given.

    An upload takes precedence over a pasted path. Uploads are saved into
    IMAGES_DIR under their base name only, so a client-supplied name such as
    "../x.png" cannot write outside that directory.
    """
    if uploaded:
        path = os.path.join(IMAGES_DIR, os.path.basename(uploaded.name))
        with open(path, "wb") as f:
            f.write(uploaded.getvalue())
        return path
    pasted = pasted_path.strip() if pasted_path else ""
    return pasted or None


def _result_json_path(path, model):
    """Return the cache file for the image at *path* processed by *model*.

    The model name is part of the key so switching models does not return
    text that a different model produced.
    """
    safe_model = re.sub(r"[^A-Za-z0-9_-]+", "_", model)
    return os.path.join(RESULTS_DIR, f"{os.path.basename(path)}_{safe_model}_result.json")


def _run_selected_model(path, model):
    """Run the OCR backend named *model* on the image at *path*; return the final text."""
    if model == "GOT_CPU":
        got_model, tokenizer = init_got_model()
        extracted = extract_text_got(path, got_model, tokenizer)
    elif model == "GOT_GPU":
        got_gpu_model, tokenizer = init_got_gpu_model()
        extracted = extract_text_got(path, got_gpu_model, tokenizer)
    elif model == "Qwen":
        qwen_model, qwen_processor = init_qwen_model()
        extracted = extract_text_qwen(path, qwen_model, qwen_processor)
    elif model == "Surya (English+Hindi)":
        # Only Surya consumes a PIL image; the other backends take the path.
        with Image.open(path) as img:
            image = img.convert("RGB")
        predictions = run_ocr([image], [["en", "hi"]], det_model, det_processor, rec_model, rec_processor)
        # Read recognised lines directly instead of regex-parsing str(result):
        # Python reprs a string containing an apostrophe with double quotes, so
        # a text='...' regex silently dropped such lines.
        # Assumes surya's OCRResult exposes text_lines[].text — confirm against
        # the pinned surya version.
        extracted = ' '.join(line.text for line in predictions[0].text_lines)
    else:
        raise ValueError(f"Unknown model choice: {model}")

    cleaned = clean_extracted_text(extracted)
    # Only GOT output is sent for AI polishing.
    return polish_text_with_ai(cleaned) if model in ("GOT_CPU", "GOT_GPU") else cleaned


# Title and UI
st.title("DualTextOCRFusion - \U0001F50D")
st.header("OCR Application - Multimodel Support")
st.write("Upload an image for OCR using various models, with support for English, Hindi, and Hinglish.")

# Sidebar Configuration
st.sidebar.header("Configuration")
model_choice = st.sidebar.selectbox("Select OCR Model:", ("GOT_CPU", "GOT_GPU", "Qwen", "Surya (English+Hindi)"))

# Upload Section
uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])
clipboard_text = st.sidebar.text_area("Paste image path from clipboard:")

image_path = _resolve_image_path(uploaded_file, clipboard_text)

if image_path:
    # Predict button
    predict_button = st.sidebar.button("Predict")

    # Main columns
    col1, _ = st.columns([2, 1])

    result_json_path = _result_json_path(image_path, model_choice)
    result_key = (image_path, model_choice)

    if predict_button:
        if not os.path.isfile(image_path):
            st.error(f"Image not found: {image_path}")
        elif os.path.exists(result_json_path):
            with open(result_json_path, "r", encoding="utf-8") as json_file:
                cached = json.load(json_file)
            st.session_state["ocr_result"] = (result_key, cached.get("polished_text", ""))
        else:
            with st.spinner("Processing..."):
                polished = _run_selected_model(image_path, model_choice)
            with open(result_json_path, "w", encoding="utf-8") as json_file:
                json.dump({"polished_text": polished}, json_file, ensure_ascii=False)
            st.session_state["ocr_result"] = (result_key, polished)

    # st.button is True only on the rerun triggered by the click. Keeping the
    # result in session_state stops it from vanishing when the search box
    # (which also reruns the script) is used.
    stored = st.session_state.get("ocr_result")
    if stored is not None and stored[0] == result_key:
        polished_text = stored[1]

        # Display image preview and text
        with col1:
            st.image(image_path, caption='Uploaded Image', width=300)

        st.subheader("Extracted Text (Cleaned & Polished)")
        # OCR output is untrusted, so it is rendered without raw HTML.
        st.markdown(polished_text)

        # Input box for real-time search; the key keeps its value in session_state.
        search_query = st.text_input("Search in extracted text:", key="search_query", placeholder="Type to search...")

        # Highlight the search term in the text
        if search_query:
            highlighted_text = highlight_text(polished_text, search_query)
            st.markdown("### Highlighted Search Results:")
            st.markdown(highlighted_text, unsafe_allow_html=True)