# Spaces: Sleeping (hosting-page status banner captured with the source; not part of the program)
import base64
import html
import io
import os
import re
import tempfile

import streamlit as st
import torch
from groq import Groq
from PIL import Image
from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
from surya.model.recognition.model import load_model as load_rec_model
from surya.model.recognition.processor import load_processor as load_rec_processor
from surya.ocr import run_ocr
from transformers import AutoModel, AutoTokenizer, Qwen2VLForConditionalGeneration, AutoProcessor
# Page configuration (kept as the first Streamlit call in the script).
st.set_page_config(page_title="DualTextOCRFusion", page_icon="π", layout="wide")

# Device used for the Surya models loaded below.
device = "cuda" if torch.cuda.is_available() else "cpu"


@st.cache_resource
def _load_surya_models(target_device):
    """Load the Surya detection and recognition models once per server process.

    Streamlit re-executes this script from the top on every widget
    interaction; without caching, all four objects would be rebuilt on each
    rerun.

    Args:
        target_device: Device string the two models are moved to.

    Returns:
        tuple: ``(det_processor, det_model, rec_model, rec_processor)``.
    """
    detection_processor, detection_model = load_det_processor(), load_det_model()
    detection_model.to(target_device)
    recognition_model, recognition_processor = load_rec_model(), load_rec_processor()
    recognition_model.to(target_device)
    return detection_processor, detection_model, recognition_model, recognition_processor


# Load Surya OCR Models (English + Hindi)
det_processor, det_model, rec_model, rec_processor = _load_surya_models(device)
# Load GOT Models
@st.cache_resource
def init_got_model():
    """Load the CPU build of the GOT OCR model together with its tokenizer.

    The result is cached with ``st.cache_resource`` so the weights are loaded
    once per server process instead of on every "Predict" click.

    Returns:
        tuple: ``(model, tokenizer)`` with the model switched to eval mode.
    """
    # NOTE: trust_remote_code=True executes Python code shipped in the model
    # repository; only keep it for repositories you trust.
    tokenizer = AutoTokenizer.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True)
    model = AutoModel.from_pretrained(
        'srimanth-d/GOT_CPU',
        trust_remote_code=True,
        use_safetensors=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    return model.eval(), tokenizer
@st.cache_resource
def init_got_gpu_model():
    """Load the CUDA build of the GOT OCR model together with its tokenizer.

    The result is cached with ``st.cache_resource`` so the weights are loaded
    once per server process instead of on every "Predict" click.

    Returns:
        tuple: ``(model, tokenizer)`` with the model in eval mode on CUDA.
    """
    # NOTE: trust_remote_code=True executes Python code shipped in the model
    # repository; only keep it for repositories you trust.
    tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
    model = AutoModel.from_pretrained(
        'ucaslcl/GOT-OCR2_0',
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        device_map='cuda',
        use_safetensors=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    return model.eval().cuda(), tokenizer
# Load Qwen Model
@st.cache_resource
def init_qwen_model():
    """Load the Qwen2-VL 2B instruct model and its processor on the CPU.

    The result is cached with ``st.cache_resource`` so the weights are loaded
    once per server process instead of on every "Predict" click.

    Returns:
        tuple: ``(model, processor)`` with the model switched to eval mode.
    """
    # NOTE(review): float16 weights on a CPU device may be slow or hit
    # unsupported half-precision ops depending on the torch build - confirm.
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct",
        device_map="cpu",
        torch_dtype=torch.float16,
    )
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
    return model.eval(), processor
# Text cleaning: normalise whitespace in OCR output (language-agnostic).
def clean_extracted_text(text):
    """Collapse whitespace runs and drop the space before ``? . ! ,``.

    Args:
        text: Raw OCR output.

    Returns:
        str: The text with every whitespace run reduced to one space, leading
        and trailing whitespace removed, and no whitespace character directly
        before ``?``, ``.``, ``!`` or ``,``.
    """
    single_spaced = re.sub(r'\s+', ' ', text)
    trimmed = single_spaced.strip()
    # After collapsing, at most one whitespace character can precede a mark.
    return re.sub(r'\s([?.!,])', r'\1', trimmed)
# Polish the text using a model
def polish_text_with_ai(cleaned_text):
    """Ask a Groq-hosted chat model to repair spacing in OCR text.

    The API key is read from the ``GROQ_API_KEY`` environment variable; it
    must never be committed to source control.

    Args:
        cleaned_text: Whitespace-normalised OCR text.

    Returns:
        The model's reply (corrected text plus whatever commentary the model
        adds). Empty or whitespace-only input is returned unchanged without
        calling the API, since there is nothing to polish.

    Raises:
        RuntimeError: If ``GROQ_API_KEY`` is not set.
    """
    if not cleaned_text or not cleaned_text.strip():
        return cleaned_text

    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        raise RuntimeError(
            "GROQ_API_KEY environment variable is not set; it is required to polish text."
        )

    prompt = f"Remove unwanted spaces between and inside words to join incomplete words, creating a meaningful sentence in either Hindi, English, or Hinglish without altering any words from the given extracted text. Then, return the corrected text with adjusted spaces, keeping it as close to the original as possible, along with relevant details or insights that an AI can provide about the extracted text. Extracted Text : {cleaned_text}"
    client = Groq(api_key=api_key)
    chat_completion = client.chat.completions.create(
        messages=[
            {
                "role": "system",
                "content": "You are a pedantic sentence corrector. Remove extra spaces between and within words to make the sentence meaningful in English, Hindi, or Hinglish, according to the context of the sentence, without changing any words."
            },
            {
                "role": "user",
                "content": prompt,
            }
        ],
        model="gemma2-9b-it",
    )
    return chat_completion.choices[0].message.content
# Extract text using GOT
def extract_text_got(image_file, model, tokenizer):
    """Run the GOT model's ``chat`` method on an image in plain OCR mode.

    Args:
        image_file: Image reference handed straight to ``model.chat``.
        model: Object exposing ``chat(tokenizer, image_file, ocr_type=...)``.
        tokenizer: Tokenizer passed as the first argument to ``model.chat``.

    Returns:
        Whatever ``model.chat`` returns for ``ocr_type='ocr'``.
    """
    ocr_result = model.chat(tokenizer, image_file, ocr_type='ocr')
    return ocr_result
# Extract text using Qwen
def extract_text_qwen(image_file, model, processor, max_new_tokens=1024):
    """Extract text from an image with a Qwen2-VL model.

    Args:
        image_file: Path or file object accepted by ``PIL.Image.open``.
        model: Qwen2-VL generation model.
        processor: Matching processor (chat template, tokenizer, image prep).
        max_new_tokens: Upper bound on generated tokens. Set explicitly so the
            answer is not cut off by the library's default generation length.

    Returns:
        str: The decoded model answer, ``"No text extracted from the image."``
        when nothing is decoded, or ``"An error occurred: ..."`` when any step
        raises (errors are reported as text rather than propagated).
    """
    try:
        # Close the underlying file handle as soon as the RGB copy exists.
        with Image.open(image_file) as source_image:
            image = source_image.convert('RGB')
        conversation = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Extract text from this image."}]}]
        text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        inputs = processor(text=[text_prompt], images=[image], return_tensors="pt")
        # Keep the inputs on the same device as the model weights.
        inputs = inputs.to(model.device)
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
        # generate() returns prompt + continuation; drop the prompt tokens so
        # the chat template and instruction are not echoed into the result.
        generated_ids = [
            sequence[len(prompt_ids):]
            for prompt_ids, sequence in zip(inputs.input_ids, output_ids)
        ]
        output_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
        return output_text[0] if output_text else "No text extracted from the image."
    except Exception as e:
        # Deliberate best-effort boundary: the UI shows the message as text.
        return f"An error occurred: {str(e)}"
# Highlight keyword search
def highlight_text(text, search_term):
    """Wrap case-insensitive matches of ``search_term`` in a yellow ``<span>``.

    When a search term is given, the result is an HTML fragment: ``&``, ``<``
    and ``>`` in the text are escaped so markup-like characters coming from
    OCR output are shown literally instead of being interpreted when the
    caller renders the fragment as HTML. Matching is done on the raw text, so
    escaping never creates or hides matches.

    Args:
        text: Text to search.
        search_term: Literal string to highlight (not a regex).

    Returns:
        str: ``text`` unchanged when ``search_term`` is empty or ``None``;
        otherwise the escaped text with every match wrapped in a highlight span.
    """
    if not search_term:  # If no search term is provided, return the original text
        return text

    # The term is escaped so it is matched literally, ignoring case.
    pattern = re.compile(re.escape(search_term), re.IGNORECASE)

    pieces = []
    cursor = 0
    for match in pattern.finditer(text):
        pieces.append(html.escape(text[cursor:match.start()], quote=False))
        matched = html.escape(match.group(), quote=False)
        pieces.append(f'<span style="background-color: yellow;">{matched}</span>')
        cursor = match.end()
    pieces.append(html.escape(text[cursor:], quote=False))
    return "".join(pieces)
# Title and UI
st.title("DualTextOCRFusion - π")
st.header("OCR Application - Multimodel Support")
st.write("Upload an image for OCR using various models, with support for English, Hindi, and Hinglish.")

# Sidebar Configuration
st.sidebar.header("Configuration")
model_choice = st.sidebar.selectbox("Select OCR Model:", ("GOT_CPU", "GOT_GPU", "Qwen", "Surya (English+Hindi)"))

# Upload Section
uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])

# Input from clipboard
if st.sidebar.button("Paste from Clipboard"):
    # NOTE(review): `st.experimental_get_clipboard` is not a documented
    # Streamlit API; look it up defensively instead of hiding the resulting
    # AttributeError behind a bare `except`.
    read_clipboard = getattr(st, "experimental_get_clipboard", None)
    if read_clipboard is None:
        st.sidebar.warning("Clipboard access is not available in this Streamlit version.")
    else:
        try:
            clipboard_data = read_clipboard()
            if clipboard_data:
                # Assuming clipboard data is base64 encoded image.
                # An in-memory buffer offers getvalue() like an uploaded file
                # and leaves no temp file behind.
                uploaded_file = io.BytesIO(base64.b64decode(clipboard_data))
        except Exception:
            st.sidebar.warning("Clipboard data is not an image.")

# Input from camera (takes precedence over the other sources)
camera_file = st.sidebar.camera_input("Capture from Camera")
if camera_file:
    uploaded_file = camera_file

# Predict button
predict_button = st.sidebar.button("Predict")

# Main columns
col1, col2 = st.columns([2, 1])

# Display image preview
if uploaded_file:
    image = Image.open(uploaded_file)
    col1.image(image, caption='Uploaded Image', use_column_width=False, width=300)

# Handle predictions
if predict_button and uploaded_file:
    with st.spinner("Processing..."):
        # Save uploaded image to disk because the GOT/Qwen helpers take a path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            temp_file.write(uploaded_file.getvalue())
            temp_file_path = temp_file.name
        try:
            # Load an RGB copy and release the file handle immediately.
            with Image.open(temp_file_path) as saved_image:
                image = saved_image.convert("RGB")

            if model_choice == "GOT_CPU":
                got_model, tokenizer = init_got_model()
                extracted_text = extract_text_got(temp_file_path, got_model, tokenizer)
            elif model_choice == "GOT_GPU":
                got_gpu_model, tokenizer = init_got_gpu_model()
                extracted_text = extract_text_got(temp_file_path, got_gpu_model, tokenizer)
            elif model_choice == "Qwen":
                qwen_model, qwen_processor = init_qwen_model()
                extracted_text = extract_text_qwen(temp_file_path, qwen_model, qwen_processor)
            else:  # "Surya (English+Hindi)"
                langs = ["en", "hi"]
                predictions = run_ocr([image], [langs], det_model, det_processor, rec_model, rec_processor)
                # Read the recognised lines from the result object; parsing
                # its repr with a regex breaks on text containing quotes.
                extracted_text = ' '.join(line.text for line in predictions[0].text_lines)

            # Clean extracted text
            cleaned_text = clean_extracted_text(extracted_text)
            # Optionally, polish text with AI model for better language flow
            polished_text = polish_text_with_ai(cleaned_text) if model_choice in ["GOT_CPU", "GOT_GPU"] else cleaned_text
        finally:
            # Delete temp file even when OCR or polishing raised.
            if os.path.exists(temp_file_path):
                os.remove(temp_file_path)

    # Typing in the search box reruns the script with predict_button False,
    # so the result must live in session state to survive that rerun.
    st.session_state["polished_text"] = polished_text

# Display extracted text and search functionality
polished_text = st.session_state.get("polished_text")
if polished_text is not None:
    st.subheader("Extracted Text (Cleaned & Polished)")
    # OCR/LLM output is untrusted, so it is not rendered as raw HTML.
    st.markdown(polished_text)

    # Input box for real-time search
    search_query = st.text_input("Search in extracted text:", key="search_query", placeholder="Type to search...")

    # Update results dynamically based on the search term
    if search_query:
        # Highlight the search term in the text
        highlighted_text = highlight_text(polished_text, search_query)
        st.markdown("### Highlighted Search Results:")
        # Raw HTML is needed here for the highlight <span> elements.
        st.markdown(highlighted_text, unsafe_allow_html=True)
    else:
        st.markdown("### Extracted Text:")
        st.markdown(polished_text)