sharktide commited on
Commit
f02393e
·
verified ·
1 Parent(s): 1eab4af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -57
app.py CHANGED
@@ -5,9 +5,6 @@ from PIL import Image
5
  import cv2
6
  from fastapi.responses import JSONResponse
7
  from fastapi.middleware.cors import CORSMiddleware
8
- from torchvision import transforms
9
- from transformers import AutoModelForImageSegmentation
10
- import torch # Make sure torch is imported
11
  import logging
12
 
13
  # Set up logging
@@ -17,17 +14,8 @@ logger = logging.getLogger(__name__)
17
  # Load your trained model
18
  model = tf.keras.models.load_model('recyclebot.keras')
19
 
20
- # Load background removal model
21
- birefnet = AutoModelForImageSegmentation.from_pretrained(
22
- "ZhengPeng7/BiRefNet", trust_remote_code=True
23
- )
24
 
25
- # Transform for the background removal model
26
- transform_image = transforms.Compose([
27
- transforms.Resize((1024, 1024)),
28
- transforms.ToTensor(),
29
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
30
- ])
31
 
32
  # Define class names for predictions (this should be the same as in your local code)
33
  CLASSES = ['Glass', 'Metal', 'Paperboard', 'Plastic-Polystyrene', 'Plastic-Regular']
@@ -64,20 +52,6 @@ def preprocess_image(image_file):
64
  raise
65
 
66
  # Background removal function
67
- def remove_background(image):
68
- try:
69
- image_size = image.size
70
- input_images = transform_image(image).unsqueeze(0)
71
- with torch.no_grad():
72
- preds = birefnet(input_images)[-1].sigmoid()
73
- pred = preds[0].squeeze()
74
- pred_pil = transforms.ToPILImage()(pred)
75
- mask = pred_pil.resize(image_size)
76
- image.putalpha(mask)
77
- return image
78
- except Exception as e:
79
- logger.error(f"Error in remove_background: {str(e)}")
80
- raise
81
 
82
  @app.post("/predict")
83
  async def predict(file: UploadFile = File(...)):
@@ -95,36 +69,6 @@ async def predict(file: UploadFile = File(...)):
95
  logger.error(f"Error in /predict: {str(e)}")
96
  return JSONResponse(content={"error": str(e)}, status_code=400)
97
 
98
- @app.post("/predict/recyclebot0accuracy")
99
- async def predict_recyclebot0accuracy(file: UploadFile = File(...)):
100
- try:
101
- logger.info("Received request for /predict/recyclebot0accuracy")
102
- # Load and remove background from image
103
- image = Image.open(file.file).convert("RGB")
104
- image = remove_background(image)
105
-
106
- # Convert the image to RGB mode before saving as JPEG
107
- if image.mode == 'RGBA':
108
- image = image.convert('RGB')
109
-
110
- # Save the image as JPEG (to use in further processing)
111
- image_path = "processed_image.jpg"
112
- image.save(image_path, "JPEG")
113
-
114
- # Preprocess the image with the background removed
115
- img_array = preprocess_image(image_path)
116
-
117
- # Get predictions
118
- prediction1 = model.predict(img_array)
119
-
120
- predicted_class_idx = np.argmax(prediction1, axis=1)[0] # Get predicted class index
121
- predicted_class = CLASSES[predicted_class_idx] # Convert to class name
122
-
123
- return JSONResponse(content={"prediction": predicted_class})
124
-
125
- except Exception as e:
126
- logger.error(f"Error in /predict/recyclebot0accuracy: {str(e)}")
127
- return JSONResponse(content={"error": str(e)}, status_code=400)
128
 
129
  @app.get("/working")
130
  async def working():
 
5
  import cv2
6
  from fastapi.responses import JSONResponse
7
  from fastapi.middleware.cors import CORSMiddleware
 
 
 
8
  import logging
9
 
10
  # Set up logging
 
14
  # Load your trained model
15
  model = tf.keras.models.load_model('recyclebot.keras')
16
 
 
 
 
 
17
 
18
+
 
 
 
 
 
19
 
20
  # Define class names for predictions (this should be the same as in your local code)
21
  CLASSES = ['Glass', 'Metal', 'Paperboard', 'Plastic-Polystyrene', 'Plastic-Regular']
 
52
  raise
53
 
54
  # Background removal function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  @app.post("/predict")
57
  async def predict(file: UploadFile = File(...)):
 
69
  logger.error(f"Error in /predict: {str(e)}")
70
  return JSONResponse(content={"error": str(e)}, status_code=400)
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
  @app.get("/working")
74
  async def working():