Spaces:
Sleeping
Sleeping
File size: 830 Bytes
db2db2a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
import torch
from transformers import pipeline, AutoImageProcessor
# Prefer CUDA when torch reports it as available; otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_model(local_path, is_image_model=False):
    """
    Build a classification pipeline from a model stored at a local path.

    Text models are wrapped in a "text-classification" pipeline. Image models
    are wrapped in an "image-classification" pipeline, with an image processor
    loaded from the same path (requesting the fast variant).

    Args:
        local_path (str): The local path to the model.
        is_image_model (bool): When truthy, load as an image classifier;
            otherwise load as a text classifier.

    Returns:
        pipeline: The pipeline placed on the module-level ``device``.
    """
    # Text path first: no extra preprocessing component is loaded for it.
    if not is_image_model:
        return pipeline("text-classification", model=local_path, device=device)

    # Image path: load the processor explicitly so use_fast=True is applied.
    processor = AutoImageProcessor.from_pretrained(local_path, use_fast=True)
    return pipeline(
        "image-classification",
        model=local_path,
        device=device,
        image_processor=processor,
    )