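"""Gradio app: extract nutrition values from a photo of a nutrition table.

Pipeline: EasyOCR pulls text tokens and bounding boxes from the image, then the
openfoodfacts/nutrition-extractor token-classification model labels each token,
and the labeled numeric values are aggregated into a JSON response.
"""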
import gradio as gr
import easyocr
import numpy as np
from PIL import Image
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch
import logging

# Set up logging for debugging.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

logger.info("Initializing EasyOCR...")
# Initialize the EasyOCR reader for English.
reader = easyocr.Reader(['en'], gpu=False)
logger.info("EasyOCR initialized.")

logger.info("Loading nutrition extraction model...")
# Load the tokenizer and token-classification model from the Hugging Face Hub.
# No device is specified, so inference runs on the CPU by default.
tokenizer = AutoTokenizer.from_pretrained("openfoodfacts/nutrition-extractor")
model = AutoModelForTokenClassification.from_pretrained("openfoodfacts/nutrition-extractor")
logger.info("Model loaded successfully.")

def ocr_extract(image: Image.Image):
    """
    Uses EasyOCR to extract text tokens and their bounding boxes from an image.
    Returns a list of tokens and corresponding boxes in [x0, y0, x1, y1] format,
    normalized to the 0-1000 scale expected by LayoutLM-family models such as
    the nutrition extractor.
    """
    # Convert PIL image to numpy array.
    np_image = np.array(image)
    width, height = image.size
    results = reader.readtext(np_image)

    tokens = []
    boxes = []
    for bbox, text, confidence in results:
        if text.strip():
            tokens.append(text)
            # EasyOCR returns a quadrilateral (four corner points); reduce it
            # to an axis-aligned box and normalize coordinates to 0-1000.
            xs = [point[0] for point in bbox]
            ys = [point[1] for point in bbox]
            box = [
                int(1000 * min(xs) / width),
                int(1000 * min(ys) / height),
                int(1000 * max(xs) / width),
                int(1000 * max(ys) / height),
            ]
            # Clamp in case OCR corner points fall slightly outside the image.
            boxes.append([min(max(coord, 0), 1000) for coord in box])
    logger.info(f"OCR extracted {len(tokens)} tokens.")
    return tokens, boxes
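
# Example output (hypothetical values), illustrating the shapes consumed below:
#   ocr_extract(img) -> (["Energy", "250", "kcal"],
#                        [[37, 112, 180, 131], [512, 110, 585, 133], [601, 111, 668, 130]])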

def predict(image: Image.Image):
    """
    Runs OCR with EasyOCR to extract tokens and bounding boxes,
    then uses the nutrition extraction model to classify tokens and aggregate nutritional values.
    """
    tokens, boxes = ocr_extract(image)
    if len(tokens) == 0:
        logger.error("No text detected in the image.")
        return {"error": "No text detected in the image."}
    
    # Prepare inputs: the tokenizer accepts word-level tokens plus their boxes.
    encoding = tokenizer(tokens, boxes=boxes, return_tensors="pt", truncation=True, padding=True)
    
    try:
        # Inference only; disable gradient tracking.
        with torch.no_grad():
            outputs = model(**encoding)
    except Exception as e:
        logger.error(f"Error during model inference: {e}")
        return {"error": f"Model inference error: {e}"}
    
    # The model emits one prediction per sub-token, not per word, so zipping
    # predictions directly against `tokens` would misalign them. Use word_ids()
    # to map each sub-token back to its source word, keeping the first
    # sub-token's label for each word.
    predictions = torch.argmax(outputs.logits, dim=2)[0].tolist()
    word_ids = encoding.word_ids(batch_index=0)
    word_preds = {}
    for idx, word_id in enumerate(word_ids):
        if word_id is not None and word_id not in word_preds:
            word_preds[word_id] = predictions[idx]

    extracted_data = {}
    for word_id, pred in word_preds.items():
        token = tokens[word_id]
        label = model.config.id2label.get(pred, "O").lower()
        if label == "o":
            continue
        # Extract the numeric value from the token (e.g. "12.5g" -> 12.5).
        num_str = "".join(filter(lambda c: c.isdigit() or c == '.', token))
        try:
            value = float(num_str)
            extracted_data[label] = extracted_data.get(label, 0) + value
        except ValueError:
            # Token carried no parseable number (a unit or field name); skip it.
            continue

    if not extracted_data:
        logger.warning("No nutritional information extracted.")
        return {"error": "No nutritional information extracted."}
    
    logger.info(f"Extracted data: {extracted_data}")
    return extracted_data

# Create a Gradio interface that exposes the API.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="json",
    title="Nutrition Extractor API with EasyOCR",
    description="Upload an image of a nutrition table to extract nutritional values. The pipeline uses EasyOCR to extract tokens and bounding boxes, then processes them with the openfoodfacts/nutrition-extractor model."
)

if __name__ == "__main__":
    demo.launch(share=True)
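
# Example client call (a sketch: "nutrition_label.jpg" is a placeholder path, and
# "/predict" is the default endpoint name Gradio assigns to a single-function
# Interface; older gradio_client versions take a plain filepath instead of
# handle_file):
#
#   from gradio_client import Client, handle_file
#   client = Client("http://127.0.0.1:7860/")
#   result = client.predict(handle_file("nutrition_label.jpg"), api_name="/predict")
#   print(result)  # e.g. {"<label>": <value>, ...} or {"error": "..."}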