File size: 4,603 Bytes
323d1c7
e10a731
 
 
5e08f68
e10a731
fdd7026
5e08f68
 
76cc3b5
 
 
 
 
5158177
3e17761
a80f51a
e10a731
5e08f68
 
 
 
 
 
 
 
 
 
 
 
fdd7026
 
e10a731
 
5158177
 
fdd7026
 
 
 
 
 
 
 
5e08f68
e10a731
76cc3b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e10a731
5e08f68
 
76cc3b5
 
 
 
 
 
 
 
 
 
 
 
 
fdd7026
6172b16
5e08f68
e10a731
76cc3b5
3e17761
eb2590f
fdd7026
edecd42
fdd7026
 
e10a731
fdd7026
e10a731
76cc3b5
e10a731
 
5e08f68
 
 
76cc3b5
5e08f68
 
 
fdd7026
5e08f68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76cc3b5
5e08f68
fdd7026
df718f8
 
eb2590f
fdd7026
5e08f68
e10a731
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import io
import logging

from fastapi import FastAPI, File, UploadFile, Request
import tensorflow as tf
import numpy as np
from PIL import Image
import cv2
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

# Configure the root logger at INFO; this module logs through `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Keras classifier that both prediction endpoints call.
model = tf.keras.models.load_model("recyclebot.keras")

# BiRefNet segmentation network used by remove_background().
# trust_remote_code=True lets transformers execute the model code that is
# published in the Hub repository itself.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)

# Input pipeline for BiRefNet: fixed 1024x1024 resize, conversion to a
# tensor, then per-channel normalization with the ImageNet mean/std.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Label for each output index of `model`; the endpoints index this list with
# the argmax of the prediction.
# NOTE(review): order presumably mirrors the label order used in training --
# confirm before editing.
CLASSES = [
    "Glass",
    "Metal",
    "Paperboard",
    "Plastic-Polystyrene",
    "Plastic-Regular",
]

# FastAPI application; every route below is registered on this object.
app = FastAPI()

# CORS is fully open: any origin, any HTTP method, any request header.
# NOTE(review): the CORS spec forbids a literal "*" origin together with
# credentials, so the middleware has to echo the caller's origin instead.
# In effect any site may make credentialed requests. No auth is visible in
# this service, so allow_credentials may be unnecessary -- confirm.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Preprocess the image (resize, reshape without normalization)
def preprocess_image(image_file):
    """Load an image and shape it into a single-item batch for ``model``.

    Args:
        image_file: Anything ``PIL.Image.open`` accepts, such as a filesystem
            path or a readable binary file object (e.g. ``UploadFile.file``).

    Returns:
        ``numpy.ndarray`` of shape ``(1, 240, 240, 3)`` holding the raw RGB
        pixel values. No scaling or normalization is applied here.

    Raises:
        Re-raises whatever opening, decoding or resizing raised, after
        logging it with its traceback.
    """
    try:
        # The context manager closes the file when Pillow opened it from a
        # path. A caller-supplied file object is left open for its owner.
        with Image.open(image_file) as img:
            # Force exactly 3 channels. Grayscale, palette, RGBA and CMYK
            # inputs would otherwise fail to reshape to (..., 240, 240, 3).
            pixels = np.array(img.convert("RGB"))

        # Resize to the spatial size the model expects.
        pixels = cv2.resize(pixels, (240, 240))

        # Add the batch dimension for inference.
        return pixels.reshape(1, 240, 240, 3)
    except Exception as e:
        logger.exception("Error in preprocess_image: %s", e)
        raise

# Background removal function
def remove_background(image):
    """Return a copy of *image* with BiRefNet's foreground mask as its alpha.

    Args:
        image: PIL image. It is converted to RGB before segmentation because
            ``transform_image`` normalizes exactly three channels.

    Returns:
        A new RGBA PIL image with the same size as *image*. Its alpha channel
        is the sigmoid of the model's last output, turned into an 8-bit mask
        by ``ToPILImage`` and resized back to the original size. *image*
        itself is not modified.

    Raises:
        Re-raises any exception from conversion or inference after logging
        it with its traceback.
    """
    try:
        # convert() always returns a new image, so putalpha() below cannot
        # mutate the caller's object.
        rgb = image.convert("RGB")
        batch = transform_image(rgb).unsqueeze(0)  # add batch dimension
        with torch.no_grad():  # inference only; no autograd graph needed
            # NOTE(review): [-1] presumably selects the final prediction from
            # BiRefNet's list of outputs -- confirm against the model code.
            preds = birefnet(batch)[-1].sigmoid()
        # assumes preds is (1, 1, H, W), so squeeze() leaves (H, W) -- TODO confirm
        pred = preds[0].squeeze()
        mask = transforms.ToPILImage()(pred).resize(rgb.size)
        rgb.putalpha(mask)
        return rgb
    except Exception as e:
        logger.exception("Error in remove_background: %s", e)
        raise

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image with the Keras model.

    Returns:
        ``{"prediction": <name from CLASSES>}`` on success, or
        ``{"error": <message>}`` with HTTP 400 if preprocessing or inference
        raises.
    """
    # NOTE(review): preprocessing and model.predict are synchronous calls made
    # inside an async handler, so they block the event loop while they run.
    try:
        logger.info("Received request for /predict")
        img_array = preprocess_image(file.file)  # (1, 240, 240, 3) batch

        # verbose=0 keeps Keras' per-call progress bar out of the server log.
        prediction1 = model.predict(img_array, verbose=0)

        # Row 0 of the single-item batch -> index of the highest score.
        predicted_class_idx = int(np.argmax(prediction1, axis=1)[0])
        predicted_class = CLASSES[predicted_class_idx]

        return JSONResponse(content={"prediction": predicted_class})

    except Exception as e:
        # logger.exception records the traceback; the client only sees the message.
        logger.exception("Error in /predict: %s", e)
        return JSONResponse(content={"error": str(e)}, status_code=400)

@app.post("/predict/recyclebot0accuracy")
async def predict_recyclebot0accuracy(file: UploadFile = File(...)):
    """Classify an uploaded image after stripping its background with BiRefNet.

    Pipeline:
        1. Decode the upload as RGB.
        2. ``remove_background`` adds the predicted foreground mask as an
           alpha channel.
        3. Flatten the result onto a solid background.
        4. ``preprocess_image``.
        5. ``model.predict``.

    Returns:
        ``{"prediction": <name from CLASSES>}`` on success, or
        ``{"error": <message>}`` with HTTP 400 if any step raises.
    """
    # NOTE(review): segmentation and classification are synchronous calls made
    # inside an async handler, so they block the event loop while they run.
    try:
        logger.info("Received request for /predict/recyclebot0accuracy")
        # Load and remove background from image
        image = Image.open(file.file).convert("RGB")
        image = remove_background(image)

        # The classifier takes 3 channels, and JPEG cannot store alpha at all.
        # Pillow raises "cannot write mode RGBA as JPEG", which made this
        # endpoint always fail. Instead, composite the foreground onto a
        # solid fill, using the predicted mask as the paste mask.
        # NOTE(review): the white fill is an assumption -- confirm it matches
        # the backgrounds the classifier was trained on.
        flattened = Image.new("RGB", image.size, (255, 255, 255))
        flattened.paste(image, mask=image.getchannel("A"))

        # Hand the image to preprocess_image through an in-memory lossless PNG.
        # A fixed "processed_image.jpg" in the working directory could be
        # overwritten by concurrent requests or workers.
        buffer = io.BytesIO()
        flattened.save(buffer, format="PNG")
        buffer.seek(0)
        img_array = preprocess_image(buffer)

        # verbose=0 keeps Keras' per-call progress bar out of the server log.
        prediction1 = model.predict(img_array, verbose=0)

        # Row 0 of the single-item batch -> index of the highest score.
        predicted_class_idx = int(np.argmax(prediction1, axis=1)[0])
        predicted_class = CLASSES[predicted_class_idx]

        return JSONResponse(content={"prediction": predicted_class})

    except Exception as e:
        # logger.exception records the traceback; the client only sees the message.
        logger.exception("Error in /predict/recyclebot0accuracy: %s", e)
        return JSONResponse(content={"error": str(e)}, status_code=400)

@app.get("/working")
async def working():
    """Health check: respond with a fixed payload showing the app is serving."""
    status_payload = {"Status": "Working"}
    return JSONResponse(content=status_payload)

# Launching this file directly starts uvicorn on all interfaces, port 7860.
if __name__ == "__main__":
    from uvicorn import run as serve_app

    serve_app(app, port=7860, host="0.0.0.0")