# ocr_cpu.py
import os
import torch
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
# -----------------------------
# OCR Model Initialization
# -----------------------------
# Load OCR model and tokenizer
ocr_model_name = "srimanth-d/GOT_CPU" # Using GOT model on CPU
ocr_tokenizer = AutoTokenizer.from_pretrained(
    ocr_model_name, trust_remote_code=True
)
# Load the OCR model
ocr_model = AutoModel.from_pretrained(
    ocr_model_name,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    pad_token_id=ocr_tokenizer.eos_token_id,
)
# Ensure the OCR model is in evaluation mode and loaded on CPU
ocr_device = torch.device("cpu")
ocr_model = ocr_model.eval().to(ocr_device)
# -----------------------------
# Text Cleaning Model Initialization
# -----------------------------
# Load Text Cleaning model and tokenizer
clean_model_name = "gpt2"  # Lightweight but not instruction-tuned; an instruction-following model will clean more reliably
clean_tokenizer = AutoTokenizer.from_pretrained(clean_model_name)
clean_model = AutoModelForCausalLM.from_pretrained(clean_model_name)
# Ensure the Text Cleaning model is in evaluation mode and loaded on CPU
clean_device = torch.device("cpu")
clean_model = clean_model.eval().to(clean_device)
# -----------------------------
# OCR Function
# -----------------------------
def extract_text_got(uploaded_file):
    """
    Use the GOT-OCR2.0 model to extract text from the uploaded image.

    Tries plain 'ocr' mode first, then falls back to 'format' mode if the
    first pass returns no text.
    """
    temp_file_path = 'temp_image.jpg'
    try:
        # Save the uploaded file temporarily
        with open(temp_file_path, 'wb') as temp_file:
            temp_file.write(uploaded_file.read())
        print(f"Processing image from path: {temp_file_path}")

        # Run OCR on the image, trying each mode in turn
        for ocr_type in ['ocr', 'format']:
            with torch.no_grad():
                print(f"Running OCR with type: {ocr_type}")
                outputs = ocr_model.chat(ocr_tokenizer, temp_file_path, ocr_type=ocr_type)
            # The model may return a plain string or a list of strings
            if isinstance(outputs, list):
                outputs = outputs[0] if outputs else ""
            if outputs and outputs.strip():
                return outputs.strip()
        return "No text extracted."
    except Exception as e:
        return f"Error during text extraction: {str(e)}"
    finally:
        # Clean up the temporary file
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)
            print(f"Temporary file {temp_file_path} removed.")
# -----------------------------
# Text Cleaning Function
# -----------------------------
def clean_text_with_ai(extracted_text):
    """
    Clean extracted text by prompting a language model to remove extra
    spaces and correct formatting.
    """
    try:
        # Define the prompt for cleaning
        prompt = f"Please clean the following text by removing extra spaces and ensuring proper formatting:\n\n{extracted_text}\n\nCleaned Text:"
        # Tokenize the prompt, truncating so that prompt plus generated
        # tokens fit within GPT-2's 1024-token context window
        inputs = clean_tokenizer.encode(
            prompt, return_tensors="pt", truncation=True, max_length=768
        ).to(clean_device)
        # Generate the cleaned text
        with torch.no_grad():
            outputs = clean_model.generate(
                inputs,
                max_new_tokens=256,  # cap generated tokens regardless of prompt length
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                eos_token_id=clean_tokenizer.eos_token_id,
                pad_token_id=clean_tokenizer.eos_token_id,
            )
        # Decode the generated text
        cleaned_text = clean_tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Keep only the portion after the "Cleaned Text:" marker
        cleaned_text = cleaned_text.split("Cleaned Text:")[-1].strip()
        return cleaned_text
    except Exception as e:
        return f"Error during AI text cleaning: {str(e)}"
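# -----------------------------
# Example Usage
# -----------------------------
# A minimal sketch of how the two functions above compose, assuming an image
# file named 'sample.jpg' sits next to this script (the filename is purely
# illustrative, not part of the original module).
if __name__ == "__main__":
    with open("sample.jpg", "rb") as image_file:
        raw_text = extract_text_got(image_file)
    print("Raw OCR output:\n", raw_text)
    cleaned = clean_text_with_ai(raw_text)
    print("Cleaned text:\n", cleaned)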