from fastapi import FastAPI, File, UploadFile, Request
import tensorflow as tf
import numpy as np
from PIL import Image
import cv2
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import logging
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Load your trained model
model = tf.keras.models.load_model('recyclebot.keras')
# Define class names for predictions (order must match the labels used during training)
CLASSES = ['Glass', 'Metal', 'Paperboard', 'Plastic-Polystyrene', 'Plastic-Regular']
# Create FastAPI app
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins (or specify particular origins)
    allow_credentials=True,
    allow_methods=["*"],  # Allow all HTTP methods
    allow_headers=["*"],  # Allow all headers
)
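# Note: allow_origins=["*"] is convenient for demos; a production deployment would
# typically list specific frontend origins instead (the exact domains depend on where
# the client is hosted).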
# Preprocess the image (resize and reshape, without normalization)
def preprocess_image(image_file):
    try:
        # Load image using PIL and force 3-channel RGB so the reshape below is valid
        image = Image.open(image_file).convert("RGB")
        # Convert image to a numpy array
        image = np.array(image)
        # Resize to the input shape expected by the model
        image = cv2.resize(image, (240, 240))
        # Add the batch dimension for inference
        image = image.reshape(-1, 240, 240, 3)
        return image
    except Exception as e:
        logger.error(f"Error in preprocess_image: {str(e)}")
        raise
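# Usage sketch (a minimal example; "sample.jpg" is a hypothetical local file path):
#     with open("sample.jpg", "rb") as f:
#         batch = preprocess_image(f)  # batch.shape == (1, 240, 240, 3)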
# Prediction endpoint
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    try:
        logger.info("Received request for /predict")
        img_array = preprocess_image(file.file)  # Preprocess the uploaded image
        prediction = model.predict(img_array)  # Get class probabilities
        predicted_class_idx = np.argmax(prediction, axis=1)[0]  # Index of the most likely class
        predicted_class = CLASSES[predicted_class_idx]  # Map index to class name
        return JSONResponse(content={"prediction": predicted_class})
    except Exception as e:
        logger.error(f"Error in /predict: {str(e)}")
        return JSONResponse(content={"error": str(e)}, status_code=400)
# Health-check endpoint
@app.get("/working")
async def working():
    return JSONResponse(content={"Status": "Working"})
# To manually run FastAPI
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
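# Example client call (a minimal sketch, assuming the server is running locally on
# port 7860; "sample.jpg" is a placeholder path):
#
#     import requests
#     with open("sample.jpg", "rb") as f:
#         resp = requests.post("http://localhost:7860/predict", files={"file": f})
#     print(resp.json())  # e.g. {"prediction": "Plastic-Regular"}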