|
import gradio as gr |
|
from transformers import pipeline |
|
from PIL import Image |
|
import torch |
|
|
|
|
|
print("Loading model...")

# Load the Open Food Facts nutrition extractor once at module import so a
# single pipeline instance is shared across all requests.  The first run
# downloads the model weights if they are not already cached locally.
# NOTE(review): the task string selects Hugging Face's token-classification
# pipeline; OCR is presumably performed inside the model — confirm against
# the openfoodfacts/nutrition-extractor model card.
model_pipeline = pipeline('token-classification', 'openfoodfacts/nutrition-extractor')

print("Model loaded successfully.")
|
|
|
def predict(image: Image.Image):
    """
    Extract aggregated nutritional values from an image of a nutrition table.

    The image is passed directly to the token-classification pipeline
    (the model is assumed to perform OCR internally — TODO confirm against
    the model card).  Every token whose entity label is not 'O' is parsed
    for a numeric value, and values sharing the same lower-cased label are
    summed.

    :param image: uploaded image (PIL), forwarded untouched to the pipeline.
    :return: dict mapping lower-cased entity label -> summed numeric value,
             or ``{"error": ...}`` when nothing numeric could be extracted.
    """
    results = model_pipeline(image)

    extracted_data = {}
    for item in results:
        label = item.get('entity', 'O').lower()
        if label == 'o':
            continue  # 'O' marks tokens outside any entity; skip them

        value = _parse_number(item.get('word', ''))
        if value is None:
            continue  # token carried no usable number (e.g. a unit like "g")

        extracted_data[label] = extracted_data.get(label, 0) + value

    if not extracted_data:
        return {"error": "No nutritional information extracted."}
    return extracted_data


def _parse_number(token_text: str):
    """
    Pull a float out of a raw OCR token, or return None if there is none.

    Keeps digits, dots and commas, then treats a comma as a decimal
    separator — common on European nutrition labels.  The previous
    digit/dot-only filter silently dropped the comma, turning "12,5"
    into 125.0 (a 10x error); now it correctly yields 12.5.
    """
    cleaned = "".join(c for c in token_text if c.isdigit() or c in '.,')
    cleaned = cleaned.replace(',', '.')
    try:
        return float(cleaned)
    except ValueError:
        # Empty string or malformed number (e.g. "1.2.3") — no value here.
        return None
|
|
|
|
|
# Declarative Gradio UI: one image upload wired to predict(), with the
# returned dict rendered as JSON.
demo = gr.Interface(

    fn=predict,

    inputs=gr.Image(type="pil"),  # deliver uploads as PIL images, matching predict()'s signature

    outputs="json",

    title="Nutrition Extractor API",

    description="Upload an image of a nutrition table to extract nutritional values. The model performs OCR internally."

)
|
|
|
if __name__ == "__main__":

    # Start the Gradio server only when executed as a script, so importing
    # this module (e.g. from a test or another app) does not launch it.
    demo.launch()
|
|