import streamlit as st
from transformers import AutoModel, AutoTokenizer, Qwen2VLForConditionalGeneration, AutoProcessor
from surya.ocr import run_ocr
from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
from surya.model.recognition.model import load_model as load_rec_model
from surya.model.recognition.processor import load_processor as load_rec_processor
from PIL import Image
import torch
import tempfile
import os
import re
import json
from groq import Groq
# Page configuration
st.set_page_config(page_title="DualTextOCRFusion", page_icon="π", layout="wide")
device = "cuda" if torch.cuda.is_available() else "cpu"
# Directories for images and results
IMAGES_DIR = "images"
RESULTS_DIR = "results"
os.makedirs(IMAGES_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)
# Load Surya OCR Models (English + Hindi)
det_processor, det_model = load_det_processor(), load_det_model()
det_model.to(device)
rec_model, rec_processor = load_rec_model(), load_rec_processor()
rec_model.to(device)
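# Note: unlike the GOT/Qwen loaders below, these Surya models load at module
# level, so Streamlit re-executes these lines on every interaction. Wrapping
# them in a @st.cache_resource loader (as done below) would avoid reloading.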
# Load GOT Models
@st.cache_resource
def init_got_model():
    tokenizer = AutoTokenizer.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True)
    model = AutoModel.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True, use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
    return model.eval(), tokenizer
@st.cache_resource
def init_got_gpu_model():
    tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
    model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
    return model.eval().cuda(), tokenizer
# Load Qwen Model
@st.cache_resource
def init_qwen_model():
    model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", device_map="cpu", torch_dtype=torch.float16)
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
    return model.eval(), processor
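# Note: float16 on CPU is slow and some ops lack CPU float16 kernels; if
# generation errors on CPU, torch_dtype=torch.float32 is the safer default.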
# Text Cleaning AI - Clean spaces, handle dual languages
def clean_extracted_text(text):
    cleaned_text = re.sub(r'\s+', ' ', text).strip()
    cleaned_text = re.sub(r'\s([?.!,])', r'\1', cleaned_text)
    return cleaned_text
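# Illustrative example: clean_extracted_text("नमस्ते   दुनिया , hello !")
# returns "नमस्ते दुनिया, hello!" - whitespace runs collapse to single spaces
# and spaces before ?.!, are dropped.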
# Polish the text using a model
def polish_text_with_ai(cleaned_text):
    # The original prompt never included the extracted text, so the model had
    # nothing to correct; append it explicitly.
    prompt = (
        "Remove unwanted spaces between and inside words to join incomplete words, "
        "creating a meaningful sentence in either Hindi, English, or Hinglish without "
        "altering any words from the given extracted text. Then, return the corrected "
        f"text with adjusted spaces.\n\nExtracted text: {cleaned_text}"
    )
    # Read the API key from the environment rather than hardcoding a secret in source.
    client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
    chat_completion = client.chat.completions.create(
        messages=[
            {"role": "system", "content": "You are a pedantic sentence corrector."},
            {"role": "user", "content": prompt},
        ],
        model="gemma2-9b-it",
    )
    return chat_completion.choices[0].message.content
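# Usage sketch (model output will vary from run to run):
#   polish_text_with_ai("mea ning ful sent ence")  # -> e.g. "meaningful sentence"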
# Extract text using GOT
def extract_text_got(image_file, model, tokenizer):
    return model.chat(tokenizer, image_file, ocr_type='ocr')
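# Note: GOT's chat() helper (custom code shipped with the checkpoint) takes an
# image *path* rather than a PIL image; per the GOT-OCR2.0 model card,
# ocr_type='format' is also available for formatted (LaTeX/markdown) output.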
# Extract text using Qwen
def extract_text_qwen(image_file, model, processor):
    try:
        image = Image.open(image_file).convert('RGB')
        conversation = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Extract text from this image."}]}]
        text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        inputs = processor(text=[text_prompt], images=[image], return_tensors="pt")
        output_ids = model.generate(**inputs, max_new_tokens=512)  # cap output length
        # Trim the prompt tokens so only the newly generated text is decoded.
        trimmed_ids = [out[len(inp):] for inp, out in zip(inputs.input_ids, output_ids)]
        output_text = processor.batch_decode(trimmed_ids, skip_special_tokens=True)
        return output_text[0] if output_text else "No text extracted from the image."
    except Exception as e:
        return f"An error occurred: {str(e)}"
# Highlight keyword search
def highlight_text(text, search_term):
    if not search_term:
        return text
    pattern = re.compile(re.escape(search_term), re.IGNORECASE)
    return pattern.sub(lambda m: f'<span style="background-color: yellow;">{m.group()}</span>', text)
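# For example: highlight_text("Hello World", "world") returns
#   'Hello <span style="background-color: yellow;">World</span>'
# (matching is case-insensitive, but the original casing is preserved).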
# Title and UI
st.title("DualTextOCRFusion - π")
st.header("OCR Application - Multi-Model Support")
st.write("Upload an image for OCR using various models, with support for English, Hindi, and Hinglish.")
# Sidebar Configuration
st.sidebar.header("Configuration")
model_choice = st.sidebar.selectbox("Select OCR Model:", ("GOT_CPU", "GOT_GPU", "Qwen", "Surya (English+Hindi)"))
# Upload Section
uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])
clipboard_text = st.sidebar.text_area("Paste image path from clipboard:")
if uploaded_file or clipboard_text:
    image_path = None
    if uploaded_file:
        image_path = os.path.join(IMAGES_DIR, uploaded_file.name)
        with open(image_path, "wb") as f:
            f.write(uploaded_file.getvalue())
    elif clipboard_text:
        image_path = clipboard_text.strip()
    # Predict button
    predict_button = st.sidebar.button("Predict")

    # Main columns
    col1, col2 = st.columns([2, 1])

    # Check if result JSON already exists
    result_json_path = os.path.join(RESULTS_DIR, f"{os.path.basename(image_path)}_result.json") if image_path else None
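    # Caveat: results are cached by image filename only, so re-running the same
    # image with a different model will return the previously cached result.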
    polished_text = ""  # default so the display block below never hits a NameError
    if predict_button and image_path:
        if os.path.exists(result_json_path):
            with open(result_json_path, "r") as json_file:
                result_data = json.load(json_file)
                polished_text = result_data.get("polished_text", "")
        else:
            with st.spinner("Processing..."):
                image = Image.open(image_path).convert("RGB")
                if model_choice == "GOT_CPU":
                    got_model, tokenizer = init_got_model()
                    extracted_text = extract_text_got(image_path, got_model, tokenizer)
                elif model_choice == "GOT_GPU":
                    got_gpu_model, tokenizer = init_got_gpu_model()
                    extracted_text = extract_text_got(image_path, got_gpu_model, tokenizer)
                elif model_choice == "Qwen":
                    qwen_model, qwen_processor = init_qwen_model()
                    extracted_text = extract_text_qwen(image_path, qwen_model, qwen_processor)
                elif model_choice == "Surya (English+Hindi)":
                    langs = ["en", "hi"]
                    predictions = run_ocr([image], [langs], det_model, det_processor, rec_model, rec_processor)
                    # Read line text from the OCRResult objects directly instead of
                    # regex-parsing their string repr.
                    extracted_text = ' '.join(line.text for line in predictions[0].text_lines)
                cleaned_text = clean_extracted_text(extracted_text)
                polished_text = polish_text_with_ai(cleaned_text) if model_choice in ["GOT_CPU", "GOT_GPU"] else cleaned_text
                # Save result to JSON
                with open(result_json_path, "w") as json_file:
                    json.dump({"polished_text": polished_text}, json_file)
    # Display image preview and text
    if image_path:
        with col1:
            col1.image(image_path, caption='Uploaded Image', use_column_width=False, width=300)

        st.subheader("Extracted Text (Cleaned & Polished)")
        st.markdown(polished_text, unsafe_allow_html=True)

        # Input box for real-time search (the original line had a syntax error:
        # a stray on_change lambda and a missing comma before `disabled`).
        search_query = st.text_input("Search in extracted text:", key="search_query", placeholder="Type to search...", disabled=not polished_text)

        # Highlight the search term in the text
        if search_query:
            highlighted_text = highlight_text(polished_text, search_query)
            st.markdown("### Highlighted Search Results:")
            st.markdown(highlighted_text, unsafe_allow_html=True)