# recycleai-api / app.py
# sharktide's picture
# Update app.py
# f02393e verified
from fastapi import FastAPI, File, UploadFile, Request
import tensorflow as tf
import numpy as np
from PIL import Image
import cv2
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import logging
# Configure the root logger to emit INFO and above, then obtain this
# module's own named logger for use by the endpoints below.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
# Trained Keras classifier, loaded once when this module is imported and
# reused by every /predict request.
model = tf.keras.models.load_model(filepath="recyclebot.keras")
# Human-readable labels indexed by model output position: entry i of the
# prediction vector maps to CLASSES[i], so this order must match the label
# order the model was trained with.
CLASSES = [
    "Glass",
    "Metal",
    "Paperboard",
    "Plastic-Polystyrene",
    "Plastic-Regular",
]
# FastAPI application serving the /predict and /working endpoints.
app = FastAPI()

# CORS policy: accept cross-origin requests from any origin, with any HTTP
# method and any request header.
# NOTE(review): browsers reject a literal "*" origin on credentialed requests,
# and none of the endpoints in this file read cookies or auth headers, so
# allow_credentials may be unnecessary here -- confirm no client relies on it
# before changing.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# Preprocess the image (force RGB, resize, add batch axis; no normalization)
def preprocess_image(image_file, target_size=(240, 240)):
    """Decode an uploaded image into a single-sample batch for the model.

    Pixel values are NOT normalized: the array keeps the dtype and value
    range produced by PIL's RGB conversion.

    Args:
        image_file: A path or binary file-like object that ``PIL.Image.open``
            can read.
        target_size: ``(width, height)`` passed to ``cv2.resize``. Defaults
            to the 240x240 input the model expects.

    Returns:
        numpy array of shape ``(1, height, width, 3)``.

    Raises:
        Whatever PIL or OpenCV raises for an unreadable or corrupt image.
        The error is logged with its traceback and re-raised.
    """
    try:
        with Image.open(image_file) as img:
            # Force exactly 3 channels. RGBA/LA PNGs, grayscale ("L"),
            # palette ("P") and CMYK images would otherwise give 2-D or
            # 4-channel arrays, and those cannot become (..., 3).
            # NOTE(review): PIL yields RGB channel order. If the model was
            # trained on cv2.imread output (BGR), a channel swap is
            # needed -- confirm against the training code.
            rgb = np.asarray(img.convert("RGB"))
        resized = cv2.resize(rgb, target_size)
        # Prepend the batch dimension expected by model.predict.
        return np.expand_dims(resized, axis=0)
    except Exception as e:
        logger.exception("Error in preprocess_image: %s", e)
        raise
# Classification endpoint: uploaded image -> predicted material class name.
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image into one of ``CLASSES``.

    Returns:
        On success, ``{"prediction": <class name>}``. If decoding,
        preprocessing or inference raises, ``{"error": <message>}`` with
        HTTP status 400.
    """
    # NOTE(review): model.predict is synchronous, so inside this async
    # handler it blocks the event loop for the duration of inference.
    # Consider offloading it (e.g. run_in_threadpool) if concurrency matters.
    try:
        logger.info("Received request for /predict")
        img_array = preprocess_image(file.file)
        # verbose=0 suppresses the Keras progress bar that predict would
        # otherwise print to stdout on every request.
        probabilities = model.predict(img_array, verbose=0)
        # The batch holds a single sample: take the arg-max over the class
        # axis of row 0, as a plain int for list indexing.
        predicted_class_idx = int(np.argmax(probabilities, axis=1)[0])
        predicted_class = CLASSES[predicted_class_idx]
        return JSONResponse(content={"prediction": predicted_class})
    except Exception as e:
        # logger.exception keeps the traceback (logger.error dropped it).
        logger.exception("Error in /predict: %s", e)
        return JSONResponse(content={"error": str(e)}, status_code=400)
@app.get("/working")
async def working():
    """Health check: report that the service is up and answering requests."""
    payload = {"Status": "Working"}
    return JSONResponse(content=payload)
# Direct-run entry point (``python app.py``): serve the app with uvicorn on
# all interfaces, port 7860.
if __name__ == "__main__":
    import uvicorn as _uvicorn

    _uvicorn.run(app, port=7860, host="0.0.0.0")