Spaces:
Sleeping
Sleeping
File size: 4,427 Bytes
d7f5bb8 f4ed285 a4d3440 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
#!pip install gradio transformers torch
import gradio as gr
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, AutoModelForCausalLM, AutoTokenizer
import torch
# Vision-language model (and its matching processor) used for OCR and
# free-form questions about an uploaded image.
OCR_MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct"
ocr_model = Qwen2VLForConditionalGeneration.from_pretrained(
    OCR_MODEL_ID,
    device_map="auto",  # let transformers/accelerate decide weight placement
    torch_dtype="auto",  # use the dtype stored in the checkpoint config
)
ocr_processor = AutoProcessor.from_pretrained(OCR_MODEL_ID)
# Text-only math model (and its tokenizer) used to solve problems with
# step-by-step reasoning.
MATH_MODEL_ID = "Qwen/Qwen2.5-Math-72B-Instruct"
math_model = AutoModelForCausalLM.from_pretrained(
    MATH_MODEL_ID,
    device_map="auto",  # let transformers/accelerate decide weight placement
    torch_dtype="auto",  # use the dtype stored in the checkpoint config
)
math_tokenizer = AutoTokenizer.from_pretrained(MATH_MODEL_ID)
# OCR extraction / visual question answering
def ocr_and_query(image, question, max_new_tokens=1024):
    """Ask the Qwen2-VL model a question about an image and return its reply.

    Args:
        image: The image to analyse (a PIL image when called from the Gradio
            UI, whose image input uses ``type="pil"``), or ``None`` to send a
            text-only prompt.
        question: Instruction or question sent alongside the image.
        max_new_tokens: Upper bound on the number of tokens generated.

    Returns:
        The decoded model output for the single prompt, with special tokens
        removed.
    """
    # Only declare an image slot in the chat message when an image is present;
    # otherwise the processor would be asked to fill a placeholder with nothing.
    content = []
    if image is not None:
        content.append({"type": "image"})
    content.append({"type": "text", "text": question})
    messages = [{"role": "user", "content": content}]

    # Render the chat template to a prompt string, then tokenize it together
    # with the image pixels.
    text_prompt = ocr_processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = ocr_processor(
        text=[text_prompt],
        images=[image] if image is not None else None,
        padding=True,
        return_tensors="pt",
    )

    # Move inputs to the model's device rather than assuming "cuda": the model
    # was loaded with device_map="auto" and may live on a specific GPU or CPU.
    inputs = inputs.to(ocr_model.device)
    output_ids = ocr_model.generate(**inputs, max_new_tokens=max_new_tokens)

    # generate() returns prompt + continuation for each row; drop the prompt
    # tokens so only the newly generated answer is decoded.
    generated_ids = [
        full_ids[len(prompt_ids):]
        for prompt_ids, full_ids in zip(inputs.input_ids, output_ids)
    ]
    return ocr_processor.batch_decode(
        generated_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )[0]
# Math problem solving
def solve_math_problem(prompt, max_new_tokens=512):
    """Solve a math problem with the Qwen2.5-Math model using chain of thought.

    Args:
        prompt: The problem statement, as plain text.
        max_new_tokens: Upper bound on the number of tokens generated. Long
            step-by-step derivations may need more than the default.

    Returns:
        The decoded model response. The system prompt asks the model to put
        its final answer inside ``\\boxed{}``.
    """
    # Chain-of-thought system prompt followed by the user's problem.
    messages = [
        {"role": "system", "content": "Please reason step by step, and put your final answer within \\boxed{}."},
        {"role": "user", "content": prompt},
    ]
    text = math_tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Place inputs on the model's device instead of hard-coding "cuda"; with
    # device_map="auto" the model is not guaranteed to be on cuda:0 (or a GPU).
    model_inputs = math_tokenizer([text], return_tensors="pt").to(math_model.device)
    output_ids = math_model.generate(**model_inputs, max_new_tokens=max_new_tokens)

    # Keep only the continuation: strip each row's prompt tokens.
    generated_ids = [
        full_ids[len(prompt_ids):]
        for prompt_ids, full_ids in zip(model_inputs.input_ids, output_ids)
    ]
    return math_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
# Reset handler for the "Clear" button
def clear_inputs():
    """Return blank values for the image, input text and output components.

    The tuple order matches the outputs it is wired to:
    ``(image_input, text_input, output)``.
    """
    no_image = None
    blank_text = ""
    return (no_image, blank_text, blank_text)
# Request dispatcher for the "Complete" button
def gradio_app(image, question, task):
    """Run the selected task and return updated values for the UI.

    Args:
        image: Uploaded image, or ``None`` if nothing was uploaded.
        question: Contents of the text box (may be empty or ``None``).
        task: Selected radio option, or ``None`` if none was chosen.

    Returns:
        A 3-tuple ``(image, text, output)`` matching the Gradio outputs
        ``[image_input, text_input, output]``. For the image-to-math task the
        text slot is replaced with the text extracted from the image.
        Validation problems are reported as a message in the output slot
        instead of raising.
    """
    # An explicit instruction for the image->math path. An empty prompt gives
    # the vision model no hint that it should transcribe the problem.
    transcription_prompt = (
        "Extract the math problem shown in this image. "
        "Transcribe all text and formulas exactly, without solving it."
    )
    # Tolerate None as well as whitespace-only text.
    has_question = bool(question and question.strip())

    if task == "OCR and Query":
        # Without an image there is nothing to OCR; fail gracefully like the
        # image->math branch instead of crashing inside the processor.
        if image is None:
            return image, question, "Please upload an image."
        return image, question, ocr_and_query(image, question)
    if task == "Solve Math Problem from Image":
        if image is None:
            return image, question, "Please upload an image."
        extracted_text = ocr_and_query(image, transcription_prompt)
        math_solution = solve_math_problem(extracted_text)
        return image, extracted_text, math_solution
    if task == "Solve Math Problem from Text":
        if not has_question:
            return image, question, "Please enter a math problem."
        return image, question, solve_math_problem(question)
    return image, question, "Please select a task."
# Gradio UI: image + text inputs, task selector, action buttons, output area.
TASK_CHOICES = [
    "OCR and Query",
    "Solve Math Problem from Image",
    "Solve Math Problem from Text",
]

with gr.Blocks() as app:
    gr.Markdown("# Image OCR and Math Solver")
    gr.Markdown("Upload an image, enter your question or math problem, and select the appropriate task.")

    with gr.Row():
        image_input = gr.Image(label="Upload Image", type="pil")
        text_input = gr.Textbox(
            label="Input",
            lines=2,
            placeholder="Enter your question or math problem here...",
        )
    with gr.Row():
        task_radio = gr.Radio(TASK_CHOICES, label="Task")
    with gr.Row():
        complete_button = gr.Button("Complete")
        clear_button = gr.Button("Clear")
    output = gr.Markdown(label="Output")

    # Both handlers return an (image, text, output) triple, so they update
    # the same three components.
    ui_fields = [image_input, text_input, output]
    complete_button.click(
        fn=gradio_app,
        inputs=[image_input, text_input, task_radio],
        outputs=ui_fields,
    )
    clear_button.click(fn=clear_inputs, outputs=ui_fields)

# Start the server; share=True asks Gradio to create a public share link.
app.launch(share=True)
|