# DualTextOCRFusion / ocr_cpu.py
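"""GOT-OCR2.0 text extraction on CPU.

Loads the CPU checkpoint of the GOT-OCR2.0 model and exposes a single helper,
extract_text_got(), which tries several OCR modes (plain, formatted,
fine-grained with box/color, multi-crop) and returns the first non-empty result.
"""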
import os
from transformers import AutoModel, AutoTokenizer
import torch
# Load model and tokenizer
# model_name = "ucaslcl/GOT-OCR2_0"  # original GPU checkpoint
model_name = "srimanth-d/GOT_CPU"

tokenizer = AutoTokenizer.from_pretrained(
    model_name, trust_remote_code=True, return_tensors='pt'
)
# Load the model with reduced memory usage
model = AutoModel.from_pretrained(
    model_name,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    pad_token_id=tokenizer.eos_token_id,
)
# Ensure the model is in evaluation mode and loaded on CPU
device = torch.device("cpu")
dtype = torch.float32  # Use float32 on CPU
model = model.eval().to(device=device, dtype=dtype)
# OCR function
def extract_text_got(uploaded_file):
    """Use GOT-OCR2.0 model to extract text from the uploaded image."""
    temp_file_path = 'temp_image.jpg'
    try:
        # Save the uploaded file temporarily
        with open(temp_file_path, 'wb') as temp_file:
            temp_file.write(uploaded_file.read())

        print(f"Processing image from path: {temp_file_path}")  # Debug info

        # OCR strategies to try, in order
        ocr_types = ['ocr', 'format']
        fine_grained_options = ['ocr', 'format']
        color_options = ['red', 'green', 'blue']
        box = [10, 10, 100, 100]  # Example bounding box for fine-grained OCR
        multi_crop_types = ['ocr', 'format']

        results = []
        # Run basic OCR types
        for ocr_type in ocr_types:
            with torch.no_grad():
                print(f"Running basic OCR with type: {ocr_type}")  # Debug info
                outputs = model.chat(tokenizer, temp_file_path, ocr_type=ocr_type)
            print(f"Outputs for {ocr_type}: {outputs}")  # Debug outputs
            # model.chat may return a list or a plain string; normalise to a string
            text = outputs[0] if isinstance(outputs, list) and outputs else outputs
            text = text.strip() if isinstance(text, str) else ""
            if text:
                return text  # Return if successful
            results.append(text or "No result")
        # Try FINE-GRAINED OCR with box options
        for ocr_type in fine_grained_options:
            with torch.no_grad():
                print(f"Running fine-grained OCR with box, type: {ocr_type}")  # Debug info
                outputs = model.chat(tokenizer, temp_file_path, ocr_type=ocr_type, ocr_box=box)
            print(f"Outputs for {ocr_type} with box: {outputs}")
            text = outputs[0] if isinstance(outputs, list) and outputs else outputs
            text = text.strip() if isinstance(text, str) else ""
            if text:
                return text  # Return if successful
            results.append(text or "No result")
        # Try FINE-GRAINED OCR with color options
        for ocr_type in fine_grained_options:
            for color in color_options:
                with torch.no_grad():
                    print(f"Running fine-grained OCR with color {color}, type: {ocr_type}")  # Debug info
                    outputs = model.chat(tokenizer, temp_file_path, ocr_type=ocr_type, ocr_color=color)
                print(f"Outputs for {ocr_type} with color {color}: {outputs}")
                text = outputs[0] if isinstance(outputs, list) and outputs else outputs
                text = text.strip() if isinstance(text, str) else ""
                if text:
                    return text  # Return if successful
                results.append(text or "No result")
        # Try MULTI-CROP OCR
        for ocr_type in multi_crop_types:
            with torch.no_grad():
                print(f"Running multi-crop OCR with type: {ocr_type}")  # Debug info
                outputs = model.chat_crop(tokenizer, temp_file_path, ocr_type=ocr_type)
            print(f"Outputs for multi-crop {ocr_type}: {outputs}")
            text = outputs[0] if isinstance(outputs, list) and outputs else outputs
            text = text.strip() if isinstance(text, str) else ""
            if text:
                return text  # Return if successful
            results.append(text or "No result")
        # Return combined results, or a "no text" message if every attempt failed
        extracted = [text for text in results if text and text != "No result"]
        if not extracted:
            return "No text extracted."
        return "\n".join(extracted)

    except Exception as e:
        return f"Error during text extraction: {str(e)}"

    finally:
        # Clean up temporary file
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)
            print(f"Temporary file {temp_file_path} removed.")  # Debug info