import gradio as gr
from transformers import pipeline
from PIL import Image
import torch
# Load the model using Transformers' pipeline.
print("Loading model...")
model_pipeline = pipeline('token-classification', 'openfoodfacts/nutrition-extractor')
print("Model loaded successfully.")

def predict(image: Image.Image):
    """
    Receives an image, passes it directly to the nutrition extraction model,
    and processes the token-classification output to aggregate nutritional values.
    Assumes the model performs OCR internally.
    """
    # Directly pass the image to the model pipeline.
    results = model_pipeline(image)

    # Process the output: aggregate numeric values for each entity label.
    extracted_data = {}
    for item in results:
        # Expected structure: {'word': '100', 'entity': 'CALORIES', 'score': 0.98, ...}
        label = item.get('entity', 'O').lower()
        if label != 'o':  # Skip non-entity tokens.
            token_text = item.get('word', '')
            # Extract digits and decimal point.
            num_str = "".join(filter(lambda c: c.isdigit() or c == '.', token_text))
            try:
                value = float(num_str)
                extracted_data[label] = extracted_data.get(label, 0) + value
            except ValueError:
                continue

    if not extracted_data:
        return {"error": "No nutritional information extracted."}
    return extracted_data

# Create a Gradio interface that exposes the API.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="json",
    title="Nutrition Extractor API",
    description="Upload an image of a nutrition table to extract nutritional values. The model performs OCR internally."
)

if __name__ == "__main__":
    demo.launch()
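
# A minimal client-side usage sketch (not part of the original file), assuming the
# Space is deployed and gradio_client is installed; the Space ID and image filename
# below are hypothetical placeholders.
#
#   from gradio_client import Client, handle_file
#
#   client = Client("your-username/nutrition-extractor")  # hypothetical Space ID
#   result = client.predict(handle_file("nutrition_label.jpg"), api_name="/predict")
#   print(result)  # e.g. {"energy": 250.0, "fat": 12.0, ...} depending on the model's labels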