import gradio as gr
import easyocr
import numpy as np
from PIL import Image
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch
import logging
# Set up logging for debugging.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

logger.info("Initializing EasyOCR...")
# Initialize the EasyOCR reader for English.
reader = easyocr.Reader(['en'], gpu=False)
logger.info("EasyOCR initialized.")
logger.info("Loading nutrition extraction model...") | |
# Load the model using the Hugging Face Transformers pipeline. | |
# Force CPU inference with device=-1. | |
tokenizer = AutoTokenizer.from_pretrained("openfoodfacts/nutrition-extractor") | |
model = AutoModelForTokenClassification.from_pretrained("openfoodfacts/nutrition-extractor") | |
logger.info("Model loaded successfully.") | |

def ocr_extract(image: Image.Image):
    """
    Uses EasyOCR to extract text tokens and their bounding boxes from an image.
    Returns a list of tokens and corresponding boxes in [x_min, y_min, x_max, y_max]
    format, normalized to the 0-1000 range that LayoutLM-family tokenizers
    (the ones that accept a `boxes` argument) expect. Coordinates are cast to int.
    """
    # Convert the PIL image to a numpy array for EasyOCR.
    np_image = np.array(image)
    img_width, img_height = image.size
    results = reader.readtext(np_image)
    tokens = []
    boxes = []
    for bbox, text, confidence in results:
        if text.strip():
            tokens.append(text)
            # EasyOCR returns a quadrilateral as four [x, y] corner points;
            # reduce it to an axis-aligned box and scale it to 0-1000
            # relative to the image size, clamping to the valid range.
            xs = [point[0] for point in bbox]
            ys = [point[1] for point in bbox]
            x_min = max(0, min(1000, int(min(xs) * 1000 / img_width)))
            y_min = max(0, min(1000, int(min(ys) * 1000 / img_height)))
            x_max = max(0, min(1000, int(max(xs) * 1000 / img_width)))
            y_max = max(0, min(1000, int(max(ys) * 1000 / img_height)))
            boxes.append([x_min, y_min, x_max, y_max])
    logger.info(f"OCR extracted {len(tokens)} tokens.")
    return tokens, boxes
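
# A minimal local sanity check for ocr_extract, assuming a test image at
# "nutrition_label.jpg" (hypothetical path, for illustration only):
#
#   img = Image.open("nutrition_label.jpg").convert("RGB")
#   tokens, boxes = ocr_extract(img)
#   for tok, box in list(zip(tokens, boxes))[:5]:
#       print(tok, box)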

def predict(image: Image.Image):
    """
    Runs OCR with EasyOCR to extract tokens and bounding boxes,
    then uses the nutrition extraction model to classify tokens
    and aggregate nutritional values.
    """
    tokens, boxes = ocr_extract(image)
    if len(tokens) == 0:
        logger.error("No text detected in the image.")
        return {"error": "No text detected in the image."}
    # Prepare inputs: pass the word-level tokens and their normalized
    # bounding boxes to the tokenizer.
    encoding = tokenizer(tokens, boxes=boxes, return_tensors="pt", truncation=True, padding=True)
    try:
        # Inference only, so skip gradient tracking.
        with torch.no_grad():
            outputs = model(**encoding)
    except Exception as e:
        logger.error(f"Error during model inference: {e}")
        return {"error": f"Model inference error: {e}"}
    # Get the predicted label id for each subword token.
    predictions = torch.argmax(outputs.logits, dim=2)[0].tolist()
    # Subword predictions must be mapped back to the original words:
    # word_ids() (available on fast tokenizers, which AutoTokenizer loads
    # by default) gives the word index for each subword and None for
    # special tokens such as [CLS] and [SEP]. Keep only the prediction
    # of the first subword of each word.
    word_ids = encoding.word_ids(batch_index=0)
    extracted_data = {}
    previous_word_id = None
    for word_id, pred in zip(word_ids, predictions):
        if word_id is None or word_id == previous_word_id:
            continue
        previous_word_id = word_id
        label = model.config.id2label.get(pred, "O").lower()
        if label == "o":
            continue
        # Extract a numeric value (digits and decimal point) from the word.
        token = tokens[word_id]
        num_str = "".join(filter(lambda c: c.isdigit() or c == '.', token))
        try:
            value = float(num_str)
        except ValueError:
            continue
        # Aggregate values that share the same label.
        extracted_data[label] = extracted_data.get(label, 0) + value
    if not extracted_data:
        logger.warning("No nutritional information extracted.")
        return {"error": "No nutritional information extracted."}
    logger.info(f"Extracted data: {extracted_data}")
    return extracted_data
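
# predict() returns a flat {label: value} mapping. The keys come from the
# model's id2label config; the example below is hypothetical and only
# illustrates the shape of the output:
#
#   {"energy_kcal": 250.0, "fat": 12.0, "proteins": 5.0}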

# Create a Gradio interface that exposes the API.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="json",
    title="Nutrition Extractor API with EasyOCR",
    description=(
        "Upload an image of a nutrition table to extract nutritional values. "
        "The pipeline uses EasyOCR to extract tokens and bounding boxes, then "
        "processes them with the openfoodfacts/nutrition-extractor model."
    ),
)

if __name__ == "__main__":
    demo.launch(share=True)
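
# The running Space can also be queried programmatically. A minimal sketch
# using the gradio_client package, with a hypothetical Space id and image
# path (adjust both for your deployment):
#
#   from gradio_client import Client, handle_file
#
#   client = Client("your-username/nutrition-extractor-easyocr")
#   result = client.predict(handle_file("nutrition_label.jpg"), api_name="/predict")
#   print(result)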