import torch
import streamlit as st
from PIL import Image
from deepseek_vl2.serve.inference import load_model, deepseek_generate, convert_conversation_to_prompts
from deepseek_vl2.serve.app_modules.utils import configure_logger, strip_stop_words, pil_to_base64

# Set up logging
logger = configure_logger()

# Available models and a cache of loaded model components
MODELS = ["deepseek-ai/deepseek-vl2-tiny"]
DEPLOY_MODELS = {}
IMAGE_TOKEN = "<image>"
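# IMAGE_TOKEN is the placeholder the chat processor expects where images
# attach; generate_prompt_with_history prepends it when images are supplied.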

# Load a model on first use and cache it in DEPLOY_MODELS
def fetch_model(model_name: str, dtype=torch.bfloat16):
    global DEPLOY_MODELS
    if model_name not in DEPLOY_MODELS:
        logger.info(f"Loading {model_name}...")
        model_info = load_model(model_name, dtype=dtype)
        tokenizer, model, vl_chat_processor = model_info
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        try:
            model = model.to(device)
        except RuntimeError as e:
            logger.warning(f"Could not move model to {device}: {e}")
            device = torch.device('cpu')
            model = model.to(device)
            logger.warning("Model fallback to CPU. Inference might be slow.")
        DEPLOY_MODELS[model_name] = (tokenizer, model, vl_chat_processor)
        logger.info(f"Loaded {model_name} on {device}")
    return DEPLOY_MODELS[model_name]
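
# Usage sketch (hypothetical): warm the cache at startup so the first request
# does not pay the model-load cost; later calls return the cached components.
#
#   tokenizer, model, processor = fetch_model("deepseek-ai/deepseek-vl2-tiny")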

    
# Generate prompt with history
def generate_prompt_with_history(text, images, history, vl_chat_processor, tokenizer, max_length=2048):
    conversation = vl_chat_processor.new_chat_template()
    if history:
        conversation.messages = history
    if images:
        text = f"{IMAGE_TOKEN}\n{text}"
        text = (text, images)
    conversation.append_message(conversation.roles[0], text)
    conversation.append_message(conversation.roles[1], "")
    return conversation
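
# Sketch of the resulting message layout (role names come from the processor's
# chat template and are shown only as an assumption):
#
#   conversation.messages == [
#       *history,
#       [<user role>, ("<image>\n<question>", [PIL.Image, ...])],
#       [<assistant role>, ""],   # empty turn the model will fill in
#   ]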

# Convert the conversation to Gradio-style [user, assistant] pairs (reused here for the Streamlit flow)
def to_gradio_chatbot(conv):
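    # Messages alternate user/assistant turns; user turns may be
    # (text, [images]) tuples, in which case each IMAGE_TOKEN in the text is
    # replaced by an inline base64 <img> tag for display.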
    ret = []
    for i, (role, msg) in enumerate(conv.messages[conv.offset:]):
        if i % 2 == 0:
            if isinstance(msg, tuple):
                msg, images = msg
                for image in images:
                    img_b64 = pil_to_base64(image, "user upload", max_size=800, min_size=400)
                    msg = msg.replace(IMAGE_TOKEN, img_b64, 1)
            ret.append([msg, None])
        else:
            ret[-1][-1] = msg
    return ret

# Stream a model response; yields (chatbot, history, status) tuples
def predict(text, images, chatbot, history, model_name="deepseek-ai/deepseek-vl2-tiny"):
    logger.info("Starting predict function...")
    tokenizer, vl_gpt, vl_chat_processor = fetch_model(model_name)
    if not text:
        logger.warning("Empty text input detected.")
        # Yield (not return) so callers iterating the generator still get a status.
        yield chatbot, history, "Empty context."
        return

    logger.info("Processing images...")
    pil_images = [Image.open(img).convert("RGB") for img in images] if images else []
    conversation = generate_prompt_with_history(
        text, pil_images, history, vl_chat_processor, tokenizer
    )
    all_conv, _ = convert_conversation_to_prompts(conversation)
    stop_words = conversation.stop_str
    gradio_chatbot_output = to_gradio_chatbot(conversation)

    full_response = ""
    logger.info("Generating response...")
    try:
        with torch.no_grad():
            for x in deepseek_generate(
                conversations=all_conv,
                vl_gpt=vl_gpt,
                vl_chat_processor=vl_chat_processor,
                tokenizer=tokenizer,
                stop_words=stop_words,
                max_length=2048,
                temperature=0.1,
                top_p=0.9,
                repetition_penalty=1.1
            ):
                full_response += x
                response = strip_stop_words(full_response, stop_words)
                conversation.update_last_message(response)
                gradio_chatbot_output[-1][1] = response
                logger.info(f"Yielding partial response: {response[:50]}...")
                yield gradio_chatbot_output, conversation.messages, "Generating..."

        logger.info("Generation complete.")
        torch.cuda.empty_cache()
        yield gradio_chatbot_output, conversation.messages, "Success"
    except Exception as e:
        logger.error(f"Error in generation: {str(e)}")
        yield gradio_chatbot_output, conversation.messages, f"Error: {str(e)}"
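
# Usage sketch (hypothetical): predict is a generator, so callers iterate it to
# receive partial transcripts as they stream. `img` stands in for any path or
# file-like object that Image.open accepts.
#
#   for chat, history, status in predict("Describe this image", [img], [], []):
#       print(status)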

# OCR pipeline: run the fixed extraction prompt over a single uploaded image
def upload_and_process(image):
    if image is None:
        return "Please upload an image.", []
    prompt = "Extract all text from this image exactly as it appears, ensuring the output is in English only. Preserve spaces, bullets, numbers, and all formatting. Do not translate, generate, or include text in any other language. Stop at the last character of the image text."
    chatbot = []
    history = []
    logger.info("Starting upload_and_process...")
    for chatbot_output, history_output, status in predict(prompt, [image], chatbot, history):
        logger.info(f"Status: {status}")
        if status == "Success":
            return chatbot_output[-1][1], history_output
    return "Processing failed.", []

# Streamlit UI
st.title("OCR Extraction Application")
image_input = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])
output_text = ""
if image_input:
    output_text, _ = upload_and_process(image_input)
st.text_area("Extracted Text", value=output_text, height=300)
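
# To launch the app (assuming this file is saved as app.py):
#   streamlit run app.py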