hiandrewfisher committed
Commit 3df9c0b · verified · 1 Parent(s): 46c29ce

Update app.py

Files changed (1)
app.py  +78 −26
app.py CHANGED
@@ -1,39 +1,91 @@
  import gradio as gr
- from transformers import pipeline
+ import easyocr
+ import numpy as np
  from PIL import Image
+ from transformers import AutoTokenizer, AutoModelForTokenClassification
  import torch
+ import logging

- # Load the model using Transformers' pipeline.
- print("Loading model...")
- model_pipeline = pipeline('token-classification', 'openfoodfacts/nutrition-extractor')
- print("Model loaded successfully.")
+ # Set up logging for debugging.
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ logger.info("Initializing EasyOCR...")
+ # Initialize the EasyOCR reader for English.
+ reader = easyocr.Reader(['en'], gpu=False)
+ logger.info("EasyOCR initialized.")
+
+ logger.info("Loading nutrition extraction model...")
+ # Load the tokenizer and token-classification model from the Hugging Face Hub.
+ # No device is requested, so inference runs on the CPU by default.
+ tokenizer = AutoTokenizer.from_pretrained("openfoodfacts/nutrition-extractor")
+ model = AutoModelForTokenClassification.from_pretrained("openfoodfacts/nutrition-extractor")
+ logger.info("Model loaded successfully.")
+
+ def ocr_extract(image: Image.Image):
+     """
+     Uses EasyOCR to extract text tokens and their bounding boxes from an image.
+     Returns a list of tokens and corresponding boxes in [left, top, width, height] format.
+     """
+     # Convert the PIL image to a numpy array.
+     np_image = np.array(image)
+     results = reader.readtext(np_image)
+
+     tokens = []
+     boxes = []
+     for bbox, text, confidence in results:
+         if text.strip():
+             tokens.append(text)
+             # Convert the bounding box (a list of 4 corner points) to [left, top, width, height].
+             xs = [point[0] for point in bbox]
+             ys = [point[1] for point in bbox]
+             left = min(xs)
+             top = min(ys)
+             width = max(xs) - left
+             height = max(ys) - top
+             boxes.append([left, top, width, height])
+     logger.info(f"OCR extracted {len(tokens)} tokens.")
+     return tokens, boxes

  def predict(image: Image.Image):
      """
-     Receives an image, passes it directly to the nutrition extraction model,
-     and processes the token-classification output to aggregate nutritional values.
-     Assumes the model performs OCR internally.
+     Runs OCR with EasyOCR to extract tokens and bounding boxes,
+     then uses the nutrition extraction model to classify tokens and aggregate nutritional values.
      """
-     # Directly pass the image to the model pipeline.
-     results = model_pipeline(image)
+     tokens, boxes = ocr_extract(image)
+     if len(tokens) == 0:
+         logger.error("No text detected in the image.")
+         return {"error": "No text detected in the image."}
+
+     # Prepare inputs: pass the tokens and boxes to the tokenizer.
+     encoding = tokenizer(tokens, boxes=boxes, return_tensors="pt", truncation=True, padding=True)

-     # Process the output: aggregate numeric values for each entity label.
+     try:
+         outputs = model(**encoding)
+     except Exception as e:
+         logger.error(f"Error during model inference: {e}")
+         return {"error": f"Model inference error: {e}"}
+
+     # Get the predicted label for each token.
+     predictions = torch.argmax(outputs.logits, dim=2)
      extracted_data = {}
-     for item in results:
-         # Expected structure: {'word': '100', 'entity': 'CALORIES', 'score': 0.98, ...}
-         label = item.get('entity', 'O').lower()
-         if label != 'o':  # Skip non-entity tokens.
-             token_text = item.get('word', '')
-             # Extract digits and decimal point.
-             num_str = "".join(filter(lambda c: c.isdigit() or c == '.', token_text))
-             try:
-                 value = float(num_str)
-                 extracted_data[label] = extracted_data.get(label, 0) + value
-             except ValueError:
-                 continue
+     for token, pred in zip(tokens, predictions[0].tolist()):
+         label = model.config.id2label.get(pred, "O").lower()
+         if label == "o":
+             continue
+         # Extract numeric value from token.
+         num_str = "".join(filter(lambda c: c.isdigit() or c == '.', token))
+         try:
+             value = float(num_str)
+             extracted_data[label] = extracted_data.get(label, 0) + value
+         except ValueError:
+             continue

      if not extracted_data:
+         logger.warning("No nutritional information extracted.")
          return {"error": "No nutritional information extracted."}
+
+     logger.info(f"Extracted data: {extracted_data}")
      return extracted_data

  # Create a Gradio interface that exposes the API.
@@ -41,9 +93,9 @@ demo = gr.Interface(
      fn=predict,
      inputs=gr.Image(type="pil"),
      outputs="json",
-     title="Nutrition Extractor API",
-     description="Upload an image of a nutrition table to extract nutritional values. The model performs OCR internally."
+     title="Nutrition Extractor API with EasyOCR",
+     description="Upload an image of a nutrition table to extract nutritional values. The pipeline uses EasyOCR to extract tokens and bounding boxes, then processes them with the openfoodfacts/nutrition-extractor model."
  )

  if __name__ == "__main__":
-     demo.launch()
+     demo.launch(share=True)
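
Notes on the new code:

A note on the box format: ocr_extract() returns pixel-space [left, top, width, height] boxes, but LayoutLM-family tokenizers (which is what a boxes= tokenizer argument usually indicates) expect (x0, y0, x1, y1) coordinates normalized to a 0-1000 grid. A minimal conversion sketch under that assumption; normalize_box is a hypothetical helper, not part of this commit:

def normalize_box(box, img_width, img_height):
    # Convert a pixel-space [left, top, width, height] box to the
    # (x0, y0, x1, y1) 0-1000 grid used by LayoutLM-family tokenizers.
    left, top, w, h = box
    return [
        int(1000 * left / img_width),
        int(1000 * top / img_height),
        int(1000 * (left + w) / img_width),
        int(1000 * (top + h) / img_height),
    ]

# Possible usage inside predict(), before calling the tokenizer:
# boxes = [normalize_box(b, image.width, image.height) for b in boxes]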
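
Since the app only runs inference, calling model.eval() once after loading and wrapping the forward pass in torch.no_grad() disables dropout and avoids building autograd graphs. A small sketch of how the inference block could look:

model.eval()  # Put the model in inference mode once, right after loading.

with torch.no_grad():  # Skip gradient tracking; nothing is backpropagated here.
    outputs = model(**encoding)
predictions = torch.argmax(outputs.logits, dim=2)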
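
One caveat in the new predict(): zip(tokens, predictions[0].tolist()) assumes the model emits exactly one token per OCR word, but subword tokenizers add special tokens and split words into several pieces, so labels can drift out of alignment. A sketch of word-level realignment, assuming the checkpoint ships a fast tokenizer (word_ids() is only available on fast tokenizers); align_word_labels is a hypothetical helper:

def align_word_labels(encoding, predictions, tokens, id2label):
    # Map token-level predictions back to the original OCR words,
    # keeping the label of each word's first subword.
    word_ids = encoding.word_ids(batch_index=0)
    seen = set()
    aligned = []
    for idx, word_id in enumerate(word_ids):
        if word_id is None or word_id in seen:  # Skip specials/padding and later subwords.
            continue
        seen.add(word_id)
        aligned.append((tokens[word_id], id2label[predictions[0, idx].item()]))
    return aligned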
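
Because gr.Interface exposes a /predict endpoint, the deployed app can also be called programmatically. A usage sketch with gradio_client (recent versions provide handle_file; the Space id and image path below are placeholders):

from gradio_client import Client, handle_file

client = Client("your-username/your-space-name")  # Placeholder Space id.
result = client.predict(
    handle_file("nutrition_label.jpg"),  # Placeholder local image path.
    api_name="/predict",
)
print(result)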