# nao_api_flask / app.py
import torch
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
from PIL import Image
import io
import base64
from flask import Flask, request, jsonify


class HuggingFaceClassifier:
    def __init__(self, model_name="microsoft/resnet-50"):
        """
        Initialize the Hugging Face model and feature extractor.

        Args:
            model_name (str): Hugging Face model identifier
        """
        try:
            # Load the pre-trained model and its feature extractor
            self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
            self.model = AutoModelForImageClassification.from_pretrained(model_name)

            # Move the model to the GPU if available and switch to inference mode
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.model.to(self.device)
            self.model.eval()
        except Exception as e:
            raise ValueError(f"Model loading error: {e}")

    def preprocess_image(self, image):
        """
        Preprocess an image for model input.

        Args:
            image (PIL.Image): Input image

        Returns:
            torch.Tensor: Preprocessed pixel values on the target device
        """
        # The feature extractor handles resizing, normalization, and batching
        inputs = self.feature_extractor(images=image, return_tensors="pt")
        return inputs.pixel_values.to(self.device)

    def predict(self, image):
        """
        Run image classification on a single image.

        Args:
            image (PIL.Image): Input image

        Returns:
            list: Top-5 predictions as {"label", "score"} dicts
        """
        try:
            # Preprocess the image
            inputs = self.preprocess_image(image)

            # Run inference without tracking gradients
            with torch.no_grad():
                outputs = self.model(inputs)
                logits = outputs.logits
                probabilities = torch.softmax(logits, dim=-1)
                top_k = torch.topk(probabilities, k=5)

            # Map class indices to human-readable labels
            predicted_classes = [
                {
                    "label": self.model.config.id2label[idx.item()],
                    "score": prob.item()
                }
                for idx, prob in zip(top_k.indices[0], top_k.values[0])
            ]
            return predicted_classes
        except Exception as e:
            raise RuntimeError(f"Prediction error: {e}")
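
# Example (illustrative): the classifier can also be used directly, without the
# Flask API. The file name below is a placeholder.
#
#   classifier = HuggingFaceClassifier(model_name="microsoft/resnet-50")
#   image = Image.open("example.jpg").convert("RGB")
#   for prediction in classifier.predict(image):
#       print(prediction["label"], round(prediction["score"], 4))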

# Flask API setup
app = Flask(__name__)

# Initialize the classifier (can be swapped for any image-classification model)
classifier = HuggingFaceClassifier(
    model_name="microsoft/resnet-50"
)


@app.route('/predict', methods=['POST'])
def predict_image():
    """
    Image classification endpoint.

    Accepts either a base64-encoded image (JSON field "image") or a
    multipart file upload (form field "file").
    """
    try:
        # Base64-encoded image sent as JSON: {"image": "<base64>"}
        # Use get_json(silent=True) so multipart uploads don't trigger a
        # content-type error when the body is not JSON
        data = request.get_json(silent=True)
        if data and 'image' in data:
            image_data = base64.b64decode(data['image'])
            image = Image.open(io.BytesIO(image_data))
        # Multipart file upload (form field "file")
        elif 'file' in request.files:
            image = Image.open(request.files['file'])
        else:
            return jsonify({
                'error': 'No image provided',
                'status': 'failed'
            }), 400

        # The model expects 3-channel input, so normalize the color mode
        image = image.convert('RGB')

        # Perform prediction
        predictions = classifier.predict(image)
        return jsonify({
            'predictions': predictions,
            'status': 'success'
        })
    except Exception as e:
        return jsonify({
            'error': str(e),
            'status': 'failed'
        }), 500


@app.route('/models', methods=['GET'])
def available_models():
    """
    List available pre-trained models.
    """
    models = [
        "microsoft/resnet-50",
        "google/vit-base-patch16-224",
        "facebook/vit-mae-base",
        "microsoft/beit-base-patch16-224"
    ]
    return jsonify({
        'models': models,
        'total_models': len(models)
    })
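
# Example (illustrative): curl http://localhost:5000/models
# -> {"models": ["microsoft/resnet-50", ...], "total_models": 4}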


@app.route('/health', methods=['GET'])
def health_check():
    """
    API health check endpoint.
    """
    return jsonify({
        'status': 'healthy',
        'model': classifier.model.config.model_type,
        'device': str(classifier.device)
    })
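
# Example (illustrative): curl http://localhost:5000/health reports the loaded
# model type and the device ("cpu" or "cuda") the classifier is running on.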


if __name__ == '__main__':
    # Flask's built-in development server; debug=True is intended for local testing
    app.run(host='0.0.0.0', port=5000, debug=True)