import gradio as gr
import easyocr
import numpy as np
from PIL import Image
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch
import logging
# Set up logging for debugging.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.info("Initializing EasyOCR...")
# Initialize the EasyOCR reader for English.
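# Note: the first Reader() call downloads EasyOCR's detection and recognition
# weights and caches them locally, so a cold start can take a while.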
reader = easyocr.Reader(['en'], gpu=False)
logger.info("EasyOCR initialized.")
logger.info("Loading nutrition extraction model...")
# Load the tokenizer and token-classification model directly from the Hugging Face Hub.
# No device is specified, so inference runs on the CPU.
tokenizer = AutoTokenizer.from_pretrained("openfoodfacts/nutrition-extractor")
model = AutoModelForTokenClassification.from_pretrained("openfoodfacts/nutrition-extractor")
logger.info("Model loaded successfully.")

def ocr_extract(image: Image.Image):
    """
    Uses EasyOCR to extract text tokens and their bounding boxes from an image.
    Returns a list of tokens and corresponding boxes in [x0, y0, x1, y1] format,
    normalized to the 0-1000 scale that layout-aware tokenizers expect.
    """
    # Convert the PIL image to a numpy array for EasyOCR.
    np_image = np.array(image)
    img_height, img_width = np_image.shape[:2]
    results = reader.readtext(np_image)
    tokens = []
    boxes = []
    for bbox, text, confidence in results:
        if text.strip():
            tokens.append(text)
            # EasyOCR returns each bounding box as four (x, y) corner points;
            # reduce them to the enclosing rectangle, normalize to 0-1000, and
            # clamp so the values stay within the tokenizer's expected range.
            xs = [point[0] for point in bbox]
            ys = [point[1] for point in bbox]
            x0 = min(max(int(min(xs) * 1000 / img_width), 0), 1000)
            y0 = min(max(int(min(ys) * 1000 / img_height), 0), 1000)
            x1 = min(max(int(max(xs) * 1000 / img_width), 0), 1000)
            y1 = min(max(int(max(ys) * 1000 / img_height), 0), 1000)
            boxes.append([x0, y0, x1, y1])
    logger.info(f"OCR extracted {len(tokens)} tokens.")
    return tokens, boxes
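
# Quick local sanity check for the OCR step (a minimal sketch; assumes a
# hypothetical local test image named "nutrition_label.jpg"):
#
#   tokens, boxes = ocr_extract(Image.open("nutrition_label.jpg"))
#   print(list(zip(tokens, boxes))[:5])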

def predict(image: Image.Image):
    """
    Runs OCR with EasyOCR to extract tokens and bounding boxes,
    then uses the nutrition extraction model to classify the tokens
    and aggregate the nutritional values.
    """
    tokens, boxes = ocr_extract(image)
    if len(tokens) == 0:
        logger.error("No text detected in the image.")
        return {"error": "No text detected in the image."}
    # Prepare inputs: pass the words and their boxes to the layout-aware tokenizer.
    encoding = tokenizer(tokens, boxes=boxes, return_tensors="pt", truncation=True, padding=True)
    try:
        with torch.no_grad():
            outputs = model(**encoding)
    except Exception as e:
        logger.error(f"Error during model inference: {e}")
        return {"error": f"Model inference error: {e}"}
    # Get the predicted label id for each sub-token.
    predictions = torch.argmax(outputs.logits, dim=2)[0].tolist()
    # The tokenizer may split one word into several sub-tokens (and adds special
    # tokens), so use word_ids() to map predictions back to the original words,
    # keeping the first sub-token's prediction for each word. This requires a
    # fast tokenizer, which AutoTokenizer loads by default when available.
    word_ids = encoding.word_ids(batch_index=0)
    extracted_data = {}
    seen_words = set()
    for idx, word_id in enumerate(word_ids):
        if word_id is None or word_id in seen_words:
            continue
        seen_words.add(word_id)
        label = model.config.id2label.get(predictions[idx], "O").lower()
        if label == "o":
            continue
        # Keep only the numeric part of the word (e.g. "12.5g" -> "12.5").
        num_str = "".join(filter(lambda c: c.isdigit() or c == '.', tokens[word_id]))
        try:
            value = float(num_str)
            extracted_data[label] = extracted_data.get(label, 0) + value
        except ValueError:
            continue
    if not extracted_data:
        logger.warning("No nutritional information extracted.")
        return {"error": "No nutritional information extracted."}
    logger.info(f"Extracted data: {extracted_data}")
    return extracted_data

# Create a Gradio interface that exposes the API.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="json",
    title="Nutrition Extractor API with EasyOCR",
    description=(
        "Upload an image of a nutrition table to extract nutritional values. "
        "The pipeline uses EasyOCR to extract tokens and bounding boxes, then "
        "processes them with the openfoodfacts/nutrition-extractor model."
    ),
)

if __name__ == "__main__":
    demo.launch(share=True)
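
# Example of calling the deployed API programmatically (a minimal sketch,
# assuming a recent gradio_client is installed and a hypothetical local image
# "nutrition_label.jpg"; replace the URL with your Space's URL):
#
#   from gradio_client import Client, handle_file
#   client = Client("http://127.0.0.1:7860/")
#   result = client.predict(handle_file("nutrition_label.jpg"), api_name="/predict")
#   print(result)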