Spaces:
Running
Running
import streamlit as st | |
from transformers import AutoModel, AutoTokenizer, Qwen2VLForConditionalGeneration, AutoProcessor | |
from surya.ocr import run_ocr | |
from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor | |
from surya.model.recognition.model import load_model as load_rec_model | |
from surya.model.recognition.processor import load_processor as load_rec_processor | |
from PIL import Image | |
import torch | |
import tempfile | |
import os | |
import re | |
# Page configuration (must remain the first Streamlit command in the script)
st.set_page_config(page_title="OCR Application", page_icon="🖼️", layout="wide")
device = "cuda" if torch.cuda.is_available() else "cpu"


@st.cache_resource
def _load_surya_models(target_device):
    """Load the Surya detection and recognition models and move them to *target_device*.

    Wrapped in ``st.cache_resource`` so the weights are loaded once per server
    process instead of on every script rerun (Streamlit re-executes the whole
    script on each widget interaction).

    Args:
        target_device: Torch device string, e.g. ``"cuda"`` or ``"cpu"``.

    Returns:
        tuple: ``(det_processor, det_model, rec_model, rec_processor)``.
    """
    det_proc = load_det_processor()
    det_net = load_det_model()
    det_net.to(target_device)
    rec_net = load_rec_model()
    rec_proc = load_rec_processor()
    rec_net.to(target_device)
    return det_proc, det_net, rec_net, rec_proc


# Load Surya OCR Models (English + Hindi)
det_processor, det_model, rec_model, rec_processor = _load_surya_models(device)
# Load GOT Models
@st.cache_resource
def init_got_model():
    """Load the CPU build of GOT-OCR2.0 (``srimanth-d/GOT_CPU``) and its tokenizer.

    Cached with ``st.cache_resource`` so the weights are loaded once per server
    process rather than on every "Predict" click. Note that the cached model then
    stays resident in memory for the lifetime of the process.

    Returns:
        tuple: ``(model, tokenizer)``, with the model switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True)
    model = AutoModel.from_pretrained(
        'srimanth-d/GOT_CPU',
        trust_remote_code=True,
        use_safetensors=True,
        pad_token_id=tokenizer.eos_token_id,  # the tokenizer's EOS id doubles as the padding id
    )
    return model.eval(), tokenizer
@st.cache_resource
def init_got_gpu_model():
    """Load GOT-OCR2.0 (``ucaslcl/GOT-OCR2_0``) and its tokenizer onto the GPU.

    Cached with ``st.cache_resource`` so the weights are loaded once per server
    process rather than on every "Predict" click.

    Returns:
        tuple: ``(model, tokenizer)``, with the model in eval mode on CUDA.

    Raises:
        RuntimeError: If no CUDA device is available. The GOT_GPU option is always
            offered in the sidebar, so fail with an actionable message instead of
            a low-level device error.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("GOT_GPU requires a CUDA-capable GPU; select GOT_CPU instead.")
    tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
    model = AutoModel.from_pretrained(
        'ucaslcl/GOT-OCR2_0',
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        device_map='cuda',
        use_safetensors=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    return model.eval().cuda(), tokenizer
# Load Qwen Model
@st.cache_resource
def init_qwen_model():
    """Load Qwen2-VL-2B-Instruct and its processor on the CPU.

    Cached with ``st.cache_resource`` so the weights are loaded once per server
    process rather than on every "Predict" click.

    Returns:
        tuple: ``(model, processor)``, with the model in eval mode.
    """
    # NOTE(review): float16 halves memory, but some torch CPU kernels are slow or
    # unsupported in half precision — confirm against the deployed torch version.
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct", device_map="cpu", torch_dtype=torch.float16
    )
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
    return model.eval(), processor
# Extract text using GOT
def extract_text_got(image_file, model, tokenizer):
    """Run a GOT model over one image and hand back its answer.

    Args:
        image_file: Image path, passed unchanged to ``model.chat``.
        model: Loaded GOT model exposing ``chat(tokenizer, image, ocr_type=...)``.
        tokenizer: Tokenizer loaded alongside ``model``.

    Returns:
        Whatever ``model.chat`` produces with ``ocr_type='ocr'``.
    """
    chat_options = {"ocr_type": "ocr"}
    answer = model.chat(tokenizer, image_file, **chat_options)
    return answer
# Extract text using Qwen
def extract_text_qwen(image_file, model, processor, max_new_tokens=512):
    """Ask a Qwen2-VL model to transcribe the text in an image.

    Args:
        image_file: Path (or file-like object) of the image to read.
        model: Loaded ``Qwen2VLForConditionalGeneration`` model.
        processor: Matching ``AutoProcessor`` (chat template + image preprocessing).
        max_new_tokens: Upper bound on generated tokens. Without it, ``generate``
            falls back to the library's short default length and truncates the
            transcription.

    Returns:
        str: The model's transcription. If nothing was decoded, returns
        "No text extracted from the image.". On any failure, returns a string
        starting with "An error occurred:" (errors are reported, not raised).
    """
    try:
        # Close the underlying file as soon as the pixels are decoded.
        with Image.open(image_file) as raw_image:
            image = raw_image.convert('RGB')
        conversation = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Extract text from this image."}]}]
        text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        inputs = processor(text=[text_prompt], images=[image], return_tensors="pt")
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
        # generate() returns prompt + completion; keep only the newly generated tokens
        # so the chat prompt does not leak into the extracted text (batch of one, no padding).
        prompt_length = inputs["input_ids"].shape[1]
        generated_ids = output_ids[:, prompt_length:]
        output_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
        transcription = output_text[0].strip() if output_text else ""
        return transcription or "No text extracted from the image."
    except Exception as e:
        return f"An error occurred: {str(e)}"
# Text cleaning - collapse whitespace and tidy punctuation spacing (English + Hindi)
def clean_extracted_text(text):
    """Normalize whitespace in extracted OCR text.

    Collapses every run of whitespace (spaces, tabs, newlines) into a single
    space and strips both ends. It then removes whitespace immediately before
    punctuation: ``? . ! ,`` plus the Devanagari danda ``।`` (U+0964) and double
    danda ``॥`` (U+0965), since the app targets Hindi as well as English.

    Args:
        text: Extracted text; ``None`` or empty input yields ``""``.

    Returns:
        str: The cleaned text.
    """
    if not text:
        return ""
    cleaned_text = re.sub(r'\s+', ' ', text).strip()
    cleaned_text = re.sub(r'\s([?.!,\u0964\u0965])', r'\1', cleaned_text)
    return cleaned_text
# Highlight keyword search
def highlight_text(text, search_term):
    """Return *text* as HTML with every case-insensitive match of *search_term* highlighted.

    The result is rendered with ``unsafe_allow_html=True`` and the text comes from
    OCR of a user-uploaded image, so every piece of the text is HTML-escaped.
    Matching runs on the raw text, so a search term can never match inside an
    escape sequence such as ``&amp;``.

    Args:
        text: Plain text to render.
        search_term: Literal (non-regex) string to highlight; empty or ``None``
            disables highlighting.

    Returns:
        str: HTML-safe markup.
    """
    from html import escape

    if not search_term:
        return escape(text)
    pattern = re.compile(re.escape(search_term), re.IGNORECASE)
    pieces = []
    last_end = 0
    for match in pattern.finditer(text):
        pieces.append(escape(text[last_end:match.start()]))
        pieces.append(f'<span style="background-color: yellow;">{escape(match.group())}</span>')
        last_end = match.end()
    pieces.append(escape(text[last_end:]))
    return "".join(pieces)
# Title and UI
st.title("OCR Application - Multimodel Support")
st.write("Upload an image for OCR using various models, with support for English, Hindi, and Hinglish.")
# Sidebar Configuration
st.sidebar.header("Configuration")
model_choice = st.sidebar.selectbox("Select OCR Model:", ("GOT_CPU", "GOT_GPU", "Qwen", "Surya (English+Hindi)"))
# Upload Section
uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])
# Predict button
predict_button = st.sidebar.button("Predict")
# Main columns: the wider left column holds the image preview
col1, col2 = st.columns([2, 1])
# Display image preview
if uploaded_file:
    image = Image.open(uploaded_file)
    # An explicit width already fixes the preview size. `use_column_width` (deprecated
    # in recent Streamlit) had no effect here when False, so it is omitted.
    col1.image(image, caption='Uploaded Image', width=300)
# Handle predictions
if not uploaded_file:
    # Drop results that belong to a file the user has since removed.
    st.session_state.pop("ocr_text", None)

if predict_button and uploaded_file:
    with st.spinner("Processing..."):
        # Save uploaded image to disk: the GOT and Qwen helpers receive a file path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            temp_file.write(uploaded_file.getvalue())
            temp_file_path = temp_file.name
        try:
            extracted_text = ""
            if model_choice == "GOT_CPU":
                got_model, tokenizer = init_got_model()
                extracted_text = extract_text_got(temp_file_path, got_model, tokenizer)
            elif model_choice == "GOT_GPU":
                got_gpu_model, tokenizer = init_got_gpu_model()
                extracted_text = extract_text_got(temp_file_path, got_gpu_model, tokenizer)
            elif model_choice == "Qwen":
                qwen_model, qwen_processor = init_qwen_model()
                extracted_text = extract_text_qwen(temp_file_path, qwen_model, qwen_processor)
            elif model_choice == "Surya (English+Hindi)":
                # Close the handle right away so the temp file can be deleted below.
                with Image.open(temp_file_path) as raw_image:
                    image = raw_image.convert("RGB")
                langs = ["en", "hi"]
                predictions = run_ocr([image], [langs], det_model, det_processor, rec_model, rec_processor)
                # Read the recognised lines directly. Regex-parsing the result's repr
                # missed or truncated any line whose text contained quote characters.
                extracted_text = " ".join(line.text for line in predictions[0].text_lines)
            # Persist the result: typing in the search box triggers a rerun in which
            # `predict_button` is False, which previously made the results vanish.
            st.session_state["ocr_text"] = clean_extracted_text(extracted_text)
        except Exception as exc:  # UI boundary: report the failure instead of crashing the page
            st.session_state.pop("ocr_text", None)
            st.error(f"OCR failed: {exc}")
        finally:
            # Delete temp file even when a model raised.
            if os.path.exists(temp_file_path):
                os.remove(temp_file_path)

# Display extracted text and search functionality
cleaned_text = st.session_state.get("ocr_text")
if cleaned_text is not None:
    st.subheader("Extracted Text (Cleaned)")
    # OCR output is untrusted, so the plain text is rendered without unsafe_allow_html.
    st.markdown(cleaned_text)
    search_query = st.text_input("Search in extracted text:", key="search_query", placeholder="Type to search...")
    if search_query:
        highlighted_text = highlight_text(cleaned_text, search_query)
        st.markdown("### Highlighted Search Results:")
        st.markdown(highlighted_text, unsafe_allow_html=True)