# ocr_cpu.py

import os
import torch
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
import re

# -----------------------------
# OCR Model Initialization
# -----------------------------

# Load OCR model and tokenizer
ocr_model_name = "srimanth-d/GOT_CPU"  # Using GOT model on CPU
ocr_tokenizer = AutoTokenizer.from_pretrained(
    ocr_model_name, trust_remote_code=True
)

# Load the OCR model
ocr_model = AutoModel.from_pretrained(
    ocr_model_name,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    pad_token_id=ocr_tokenizer.eos_token_id,
)

# Ensure the OCR model is in evaluation mode and loaded on CPU
ocr_device = torch.device("cpu")
ocr_model = ocr_model.eval().to(ocr_device)

# -----------------------------
# Text Cleaning Model Initialization
# -----------------------------

# Load Text Cleaning model and tokenizer
clean_model_name = "gpt2"  # You can choose a different model if preferred
clean_tokenizer = AutoTokenizer.from_pretrained(clean_model_name)
clean_model = AutoModelForCausalLM.from_pretrained(clean_model_name)

# Ensure the Text Cleaning model is in evaluation mode and loaded on CPU
clean_device = torch.device("cpu")
clean_model = clean_model.eval().to(clean_device)

# -----------------------------
# OCR Function
# -----------------------------

def extract_text_got(uploaded_file):
    """
    Use GOT-OCR2.0 model to extract text from the uploaded image.
    """
    temp_file_path = 'temp_image.jpg'

    try:
        # Save the uploaded file temporarily
        with open(temp_file_path, 'wb') as temp_file:
            temp_file.write(uploaded_file.read())

        print(f"Processing image from path: {temp_file_path}")

        ocr_types = ['ocr', 'format']
        results = []

        # Run OCR on the image
        for ocr_type in ocr_types:
            with torch.no_grad():
                print(f"Running OCR with type: {ocr_type}")
                outputs = ocr_model.chat(ocr_tokenizer, temp_file_path, ocr_type=ocr_type)

                # chat() may return a plain string or a list of strings;
                # normalize before checking for a usable result
                text = outputs[0] if isinstance(outputs, list) else outputs
                if text and text.strip():
                    return text.strip()  # Return the first non-empty result
                results.append(text.strip() if text else "No result")

        # Fall back to whatever was collected, or report that nothing was found
        return "\n".join(results) if results else "No text extracted."

    except Exception as e:
        return f"Error during text extraction: {str(e)}"

    finally:
        # Clean up temporary file
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)
            print(f"Temporary file {temp_file_path} removed.")

# -----------------------------
# Text Cleaning Function
# -----------------------------

def clean_text_with_ai(extracted_text):
    """
    Cleans extracted text by leveraging a language model to intelligently remove extra spaces and correct formatting.
    """
    try:
        # Define the prompt for cleaning
        prompt = f"Please clean the following text by removing extra spaces and ensuring proper formatting:\n\n{extracted_text}\n\nCleaned Text:"

        # Tokenize the prompt, truncating so prompt + generation fit within
        # GPT-2's 1024-token context window
        inputs = clean_tokenizer.encode(
            prompt, return_tensors="pt", truncation=True, max_length=768
        ).to(clean_device)

        # Generate the cleaned text
        with torch.no_grad():
            outputs = clean_model.generate(
                inputs,
                max_new_tokens=256,  # Bound the generated tokens, not the total sequence length
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                eos_token_id=clean_tokenizer.eos_token_id,
                pad_token_id=clean_tokenizer.eos_token_id
            )

        # Decode the generated text
        cleaned_text = clean_tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Extract the cleaned text after the prompt
        cleaned_text = cleaned_text.split("Cleaned Text:")[-1].strip()

        return cleaned_text

    except Exception as e:
        return f"Error during AI text cleaning: {str(e)}"