import gradio as gr
from transformers import pipeline
from PIL import Image
import torch

# Load the model using Transformers' pipeline.
print("Loading model...")
model_pipeline = pipeline("token-classification", model="openfoodfacts/nutrition-extractor")
print("Model loaded successfully.")

def predict(image: Image.Image):
    """
    Receives an image, passes it directly to the nutrition extraction model,
    and processes the token-classification output to aggregate nutritional values.
    Assumes the model performs OCR internally.
    """
    # Directly pass the image to the model pipeline.
    results = model_pipeline(image)

    # Process the output: aggregate numeric values for each entity label.
    extracted_data = {}
    for item in results:
        # Expected structure: {'word': '100', 'entity': 'CALORIES', 'score': 0.98, ...}
        label = item.get('entity', 'O').lower()
        if label != 'o':  # Skip non-entity tokens.
            token_text = item.get('word', '')
            # Extract digits and decimal point.
            num_str = "".join(filter(lambda c: c.isdigit() or c == '.', token_text))
            try:
                value = float(num_str)
                extracted_data[label] = extracted_data.get(label, 0) + value
            except ValueError:
                continue

    if not extracted_data:
        return {"error": "No nutritional information extracted."}
    return extracted_data
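
# Optional local smoke test of predict() outside Gradio (a sketch, left commented out
# so the Space runs unchanged; "nutrition_label.jpg" is a hypothetical sample image path).
#
# sample = Image.open("nutrition_label.jpg")
# print(predict(sample))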

# Create a Gradio interface that exposes the API.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="json",
    title="Nutrition Extractor API",
    description="Upload an image of a nutrition table to extract nutritional values. The model performs OCR internally.",
)

if __name__ == "__main__":
    demo.launch()
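
# Example of calling this Space's API from a client (a minimal sketch, not part of the
# app). It assumes the app is running locally on Gradio's default port and that a recent
# `gradio_client` version providing `handle_file` is installed; "nutrition_label.jpg"
# is a placeholder path. "/predict" is the default api_name for a gr.Interface.
#
# from gradio_client import Client, handle_file
#
# client = Client("http://127.0.0.1:7860/")
# result = client.predict(handle_file("nutrition_label.jpg"), api_name="/predict")
# print(result)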