File size: 6,116 Bytes
3cb2a3f
b16db73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a0652de
b16db73
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import html
import os
import re
import tempfile

import streamlit as st
import torch
from PIL import Image
from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
from surya.model.recognition.model import load_model as load_rec_model
from surya.model.recognition.processor import load_processor as load_rec_processor
from surya.ocr import run_ocr
from transformers import AutoModel, AutoTokenizer, Qwen2VLForConditionalGeneration, AutoProcessor

# Page configuration
st.set_page_config(page_title="OCR Application", page_icon="🖼️", layout="wide")
device = "cuda" if torch.cuda.is_available() else "cpu"


# Load Surya OCR Models (English + Hindi)
@st.cache_resource
def init_surya_models():
    """Load the Surya detection and recognition models once per server process.

    Streamlit re-executes this whole script on every widget interaction, so
    loading the models at module level would reload them on every rerun.
    ``st.cache_resource`` keeps one shared instance instead.

    Returns:
        Tuple ``(det_processor, det_model, rec_model, rec_processor)`` with both
        models already moved to ``device``.
    """
    detection_processor, detection_model = load_det_processor(), load_det_model()
    detection_model.to(device)
    recognition_model, recognition_processor = load_rec_model(), load_rec_processor()
    recognition_model.to(device)
    return detection_processor, detection_model, recognition_model, recognition_processor


# Module-level names kept for the rest of the script (used by the Surya OCR path).
det_processor, det_model, rec_model, rec_processor = init_surya_models()

# Load GOT Models
@st.cache_resource
def init_got_model():
    """Load the CPU variant of the GOT model together with its tokenizer.

    The result is cached by Streamlit, so the download/initialisation happens
    once per server process.

    Returns:
        Tuple ``(model, tokenizer)``; the model has been switched to eval mode.
    """
    repo_id = 'srimanth-d/GOT_CPU'
    got_tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    got_model = AutoModel.from_pretrained(
        repo_id,
        trust_remote_code=True,
        use_safetensors=True,
        # Reuse EOS as padding so generation has a defined pad token.
        pad_token_id=got_tokenizer.eos_token_id,
    )
    got_model = got_model.eval()
    return got_model, got_tokenizer

@st.cache_resource
def init_got_gpu_model():
    """Load GOT-OCR2.0 onto the GPU together with its tokenizer (cached per process).

    Returns:
        Tuple ``(model, tokenizer)``; the model is in eval mode on CUDA.

    Raises:
        RuntimeError: if no CUDA device is available. The model is mapped
            straight onto CUDA, so on CPU-only hosts this fails early with a
            readable message instead of a low-level torch/accelerate error.
            Streamlit does not cache exceptions, so a later retry is possible.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("GOT_GPU requires a CUDA-capable GPU; select GOT_CPU instead.")
    tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
    model = AutoModel.from_pretrained(
        'ucaslcl/GOT-OCR2_0',
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        device_map='cuda',
        use_safetensors=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    return model.eval().cuda(), tokenizer

# Load Qwen Model
@st.cache_resource
def init_qwen_model():
    """Load Qwen2-VL-2B-Instruct and its processor on the CPU (cached per process).

    The model is pinned to the CPU, so it is loaded in float32. Half precision
    on CPU either lacks kernels (older torch raises "... not implemented for
    'Half'") or runs through slow emulation. float32 doubles the weight memory
    compared with float16 — NOTE(review): confirm the host has enough RAM.

    Returns:
        Tuple ``(model, processor)``; the model is in eval mode.
    """
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct", device_map="cpu", torch_dtype=torch.float32
    )
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
    return model.eval(), processor

# Extract text using GOT
def extract_text_got(image_file, model, tokenizer):
    """Ask a GOT model to OCR one image and hand back whatever ``chat`` returns.

    Args:
        image_file: forwarded unchanged to ``model.chat`` (callers in this
            file pass a filesystem path).
        model: a loaded GOT model exposing ``chat``.
        tokenizer: the tokenizer matching ``model``.
    """
    mode = 'ocr'
    result = model.chat(tokenizer, image_file, ocr_type=mode)
    return result

# Extract text using Qwen
def extract_text_qwen(image_file, model, processor, max_new_tokens=1024):
    """Run Qwen2-VL with a fixed "extract text" prompt on one image.

    Args:
        image_file: path or file object accepted by ``PIL.Image.open``.
        model: a loaded ``Qwen2VLForConditionalGeneration``.
        processor: the matching ``AutoProcessor``.
        max_new_tokens: upper bound on generated tokens. Without it generation
            falls back to the model's default generation config, which can be
            far too short for a page of text.

    Returns:
        The generated text only (the prompt is stripped). Returns
        "No text extracted from the image." when the model produced nothing,
        or "An error occurred: ..." if any step raised (best-effort, so the UI
        always has something to show).
    """
    try:
        # Context manager closes the underlying file handle once pixels are loaded.
        with Image.open(image_file) as img:
            image = img.convert('RGB')
        conversation = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Extract text from this image."}]}]
        text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        inputs = processor(text=[text_prompt], images=[image], return_tensors="pt").to(model.device)
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
        # generate() returns prompt + completion; keep only the newly generated tokens
        # so the chat template is not echoed into the OCR result.
        generated_ids = [out[len(prompt):] for prompt, out in zip(inputs.input_ids, output_ids)]
        output_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
        if output_text and output_text[0].strip():
            return output_text[0]
        return "No text extracted from the image."
    except Exception as e:
        return f"An error occurred: {str(e)}"

# Text cleanup: collapse whitespace and remove stray spaces before punctuation.
_WHITESPACE_RUN = re.compile(r'\s+')
# Punctuation that should attach to the preceding word. Includes the Devanagari
# danda (।) and double danda (॥), which end sentences in Hindi text.
_SPACE_BEFORE_PUNCT = re.compile(r'\s([?.!,;:।॥])')


def clean_extracted_text(text):
    """Normalise OCR output for display.

    Every run of whitespace (including newlines) becomes a single space, the
    result is stripped, and a space directly before ``? . ! , ; :`` or a
    Devanagari danda is removed.

    Args:
        text: raw OCR text; ``None`` or ``""`` is treated as empty.

    Returns:
        The cleaned string (``""`` for empty input).
    """
    if not text:
        return ""
    collapsed = _WHITESPACE_RUN.sub(' ', text).strip()
    return _SPACE_BEFORE_PUNCT.sub(r'\1', collapsed)

# Highlight keyword search
def highlight_text(text, search_term):
    """Return *text* as HTML-safe markup with every match of *search_term* highlighted.

    The caller renders the result with ``unsafe_allow_html=True``. OCR output
    and the user's search term are untrusted, so every text segment is
    HTML-escaped. Only the highlight ``<span>`` wrapper is emitted as markup.
    Matching is case-insensitive and treats *search_term* literally (no regex
    syntax).

    Args:
        text: the text to search.
        search_term: literal substring to highlight; empty means no highlighting.

    Returns:
        Escaped HTML string.
    """
    if not search_term:
        return html.escape(text)
    pattern = re.compile(re.escape(search_term), re.IGNORECASE)
    pieces = []
    last_end = 0
    for match in pattern.finditer(text):
        pieces.append(html.escape(text[last_end:match.start()]))
        pieces.append(f'<span style="background-color: yellow;">{html.escape(match.group())}</span>')
        last_end = match.end()
    pieces.append(html.escape(text[last_end:]))
    return "".join(pieces)

# Title and UI
st.title("OCR Application - Multimodel Support")
st.write("Upload an image for OCR using various models, with support for English, Hindi, and Hinglish.")

# Sidebar Configuration
st.sidebar.header("Configuration")
model_choice = st.sidebar.selectbox("Select OCR Model:", ("GOT_CPU", "GOT_GPU", "Qwen", "Surya (English+Hindi)"))

# Upload Section
uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])

# Predict button
predict_button = st.sidebar.button("Predict")

# Main columns
col1, col2 = st.columns([2, 1])


def _run_selected_model(choice, image_path):
    """Run the OCR model named *choice* on the image at *image_path*; return raw text."""
    if choice == "GOT_CPU":
        got_model, tokenizer = init_got_model()
        return extract_text_got(image_path, got_model, tokenizer)
    if choice == "GOT_GPU":
        got_gpu_model, tokenizer = init_got_gpu_model()
        return extract_text_got(image_path, got_gpu_model, tokenizer)
    if choice == "Qwen":
        qwen_model, qwen_processor = init_qwen_model()
        return extract_text_qwen(image_path, qwen_model, qwen_processor)
    if choice == "Surya (English+Hindi)":
        langs = ["en", "hi"]
        with Image.open(image_path) as img:
            rgb_image = img.convert("RGB")
        predictions = run_ocr([rgb_image], [langs], det_model, det_processor, rec_model, rec_processor)
        # Read the structured result rather than regex-scraping its repr, which
        # missed any line whose text contains an apostrophe.
        return " ".join(line.text for line in predictions[0].text_lines)
    raise ValueError(f"Unknown model choice: {choice}")


def _ocr_uploaded_file(file, choice):
    """Save *file* to a temp path, OCR it with *choice*, and return the cleaned text."""
    suffix = os.path.splitext(file.name)[1] or ".png"
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        temp_file.write(file.getvalue())
        temp_file_path = temp_file.name
    try:
        return clean_extracted_text(_run_selected_model(choice, temp_file_path))
    finally:
        # Remove the temp file even when a model raises.
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)


# Display image preview
if uploaded_file:
    with col1:
        st.image(Image.open(uploaded_file), caption='Uploaded Image', width=300)
else:
    # Drop any result that belonged to a previously uploaded image.
    st.session_state.pop("ocr_text", None)

# Handle predictions
if predict_button:
    if not uploaded_file:
        st.sidebar.warning("Please upload an image first.")
    else:
        with st.spinner("Processing..."):
            try:
                st.session_state["ocr_text"] = _ocr_uploaded_file(uploaded_file, model_choice)
            except Exception as exc:  # UI boundary: report instead of crashing the page
                st.session_state.pop("ocr_text", None)
                st.error(f"OCR failed: {exc}")

# Display extracted text and search functionality.
# The result lives in session_state because typing in the search box triggers a
# rerun in which predict_button is False; rendering only inside the predict
# branch would make the text disappear on the first keystroke.
cleaned_text = st.session_state.get("ocr_text")
if cleaned_text is not None:
    st.subheader("Extracted Text (Cleaned)")
    # OCR output is untrusted: render it without raw HTML.
    st.markdown(cleaned_text)

    search_query = st.text_input("Search in extracted text:", key="search_query", placeholder="Type to search...")
    if search_query:
        highlighted_text = highlight_text(cleaned_text, search_query)
        st.markdown("### Highlighted Search Results:")
        st.markdown(highlighted_text, unsafe_allow_html=True)