hiandrewfisher commited on
Commit
fe9f7ed
·
verified ·
1 Parent(s): 85374e6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -0
app.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+ from PIL import Image
4
+ import torch
5
+
6
+ # Load the model using Transformers' pipeline.
7
+ print("Loading model...")
8
+ model_pipeline = pipeline('token-classification', 'openfoodfacts/nutrition-extractor')
9
+ print("Model loaded successfully.")
10
+
11
+ def predict(image: Image.Image):
12
+ """
13
+ Receives an image, passes it directly to the nutrition extraction model,
14
+ and processes the token-classification output to aggregate nutritional values.
15
+ Assumes the model performs OCR internally.
16
+ """
17
+ # Directly pass the image to the model pipeline.
18
+ results = model_pipeline(image)
19
+
20
+ # Process the output: aggregate numeric values for each entity label.
21
+ extracted_data = {}
22
+ for item in results:
23
+ # Expected structure: {'word': '100', 'entity': 'CALORIES', 'score': 0.98, ...}
24
+ label = item.get('entity', 'O').lower()
25
+ if label != 'o': # Skip non-entity tokens.
26
+ token_text = item.get('word', '')
27
+ # Extract digits and decimal point.
28
+ num_str = "".join(filter(lambda c: c.isdigit() or c == '.', token_text))
29
+ try:
30
+ value = float(num_str)
31
+ extracted_data[label] = extracted_data.get(label, 0) + value
32
+ except ValueError:
33
+ continue
34
+
35
+ if not extracted_data:
36
+ return {"error": "No nutritional information extracted."}
37
+ return extracted_data
38
+
39
+ # Create a Gradio interface that exposes the API.
40
+ demo = gr.Interface(
41
+ fn=predict,
42
+ inputs=gr.Image(type="pil"),
43
+ outputs="json",
44
+ title="Nutrition Extractor API",
45
+ description="Upload an image of a nutrition table to extract nutritional values. The model performs OCR internally."
46
+ )
47
+
48
+ if __name__ == "__main__":
49
+ demo.launch()