nooneshouldtouch commited on
Commit
688a486
·
verified ·
1 Parent(s): cb4b9e1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -0
app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ import numpy as np
4
+ from tensorflow.keras.models import load_model
5
+ from tensorflow.keras.preprocessing.image import load_img, img_to_array
6
+ from io import BytesIO
7
+
8
+ app = FastAPI()
9
+
10
+ # Enable CORS to allow requests from frontend (React)
11
+ app.add_middleware(
12
+ CORSMiddleware,
13
+ allow_origins=["*"], # Change ["http://localhost:5173"] for better security
14
+ allow_credentials=True,
15
+ allow_methods=["*"],
16
+ allow_headers=["*"],
17
+ )
18
+
19
+ # Load your model
20
+ model = load_model("densenet201_food_classification.h5")
21
+
22
+ # Define class indices
23
+ class_indices = {
24
+ 0: "burger",
25
+ 1: "butter_naan",
26
+ 2: "chai",
27
+ 3: "chapati",
28
+ 4: "chole_bhature",
29
+ 5: "dal_makhani",
30
+ 6: "dhokla",
31
+ 7: "fried_rice",
32
+ 8: "idli",
33
+ 9: "jalebi",
34
+ 10: "kaathi_rolls",
35
+ 11: "kadai_paneer",
36
+ 12: "kulfi",
37
+ 13: "masala_dosa",
38
+ 14: "momos",
39
+ 15: "paani_puri",
40
+ 16: "pakode",
41
+ 17: "pav_bhaji",
42
+ 18: "pizza",
43
+ 19: "samosa"
44
+ }
45
+
46
+ def predict_image(image, model):
47
+ try:
48
+ img = load_img(image, target_size=(224, 224))
49
+ image_array = img_to_array(img) / 255.0
50
+ image_array = np.expand_dims(image_array, axis=0)
51
+
52
+ predictions = model.predict(image_array)
53
+ class_idx = np.argmax(predictions)
54
+ class_label = class_indices.get(class_idx, "Unknown")
55
+ confidence = float(predictions[0][class_idx])
56
+
57
+ return class_label, confidence
58
+ except Exception as e:
59
+ return None, None
60
+
61
+ @app.post("/predict/")
62
+ async def predict(file: UploadFile = File(...)):
63
+ try:
64
+ image_data = await file.read()
65
+ image = BytesIO(image_data)
66
+
67
+ class_label, confidence = predict_image(image, model)
68
+
69
+ if class_label is None:
70
+ return {"error": "Prediction failed"}
71
+
72
+ return {"predicted_class": class_label, "confidence": f"{confidence:.2f}"}
73
+ except Exception as e:
74
+ return {"error": f"Internal Server Error: {str(e)}"}