File size: 2,832 Bytes
db2db2a
 
583e7cf
db2db2a
7312439
 
 
8973310
 
db2db2a
 
 
 
 
 
8973310
c034083
 
 
 
 
 
 
 
 
 
db2db2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d089521
db2db2a
067c765
db2db2a
 
 
 
 
 
 
c034083
 
 
 
 
 
 
 
 
 
 
 
db2db2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
067c765
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
import boto3
from utils.log import logger

# AWS credentials and target bucket are read from the environment at import
# time; any of these may be None if the variable is unset (boto3 then falls
# back to its default credential chain for the key pair).
aws_access_key = os.getenv("AWS_ACCESS_KEY_ID") 
aws_key_pw =  os.getenv("AWS_SECRET_ACCESS_KEY") 
BUCKET_NAME =  os.getenv("BUCKET_NAME") 


# Module-level S3 client shared by all helpers in this file.
# NOTE(review): region is not pinned here — presumably taken from the
# environment/config chain; confirm if cross-region access is needed.
s3 = boto3.client(
    "s3",
    aws_access_key_id=aws_access_key,
    aws_secret_access_key=aws_key_pw,    
)

def download_model_from_s3(local_path: str, s3_prefix: str):
    """
    Downloads a model from S3 to the specified local path.

    All objects under ``s3_prefix`` in ``BUCKET_NAME`` are mirrored into
    ``local_path``, preserving their relative key structure. If the local
    directory already exists and is non-empty, the download is skipped.

    Args:
        local_path (str): The local path to download the model to.
        s3_prefix (str): The S3 prefix of the model to download.

    Raises:
        RuntimeError: If there is an error downloading the model from S3.
    """
    try:
        # Idempotency guard: a non-empty target directory means the model
        # was already downloaded.
        if os.path.exists(local_path) and os.listdir(local_path):
            logger.info(f"Model {local_path} already exists. Skipping download.")
            return

        logger.info(f"Downloading model from S3: {s3_prefix} to {local_path}")
        os.makedirs(local_path, exist_ok=True)

        # Paginate so prefixes with more than 1000 objects are fully listed.
        paginator = s3.get_paginator("list_objects_v2")
        for result in paginator.paginate(Bucket=BUCKET_NAME, Prefix=s3_prefix):
            # "Contents" is absent for empty pages; default to an empty list.
            for obj in result.get("Contents", []):
                s3_key = obj["Key"]
                # S3 "directory placeholder" keys end with "/" and cannot be
                # downloaded as files; skip them.
                if s3_key.endswith("/"):
                    continue
                local_file = os.path.join(local_path, os.path.relpath(s3_key, s3_prefix))
                os.makedirs(os.path.dirname(local_file), exist_ok=True)
                s3.download_file(BUCKET_NAME, s3_key, local_file)
                logger.info(f"Completed download {s3_key} to {local_file}")
    except Exception as e:
        # Log failures at error level (was logger.info) and chain the cause
        # so the original boto3 traceback is preserved for debugging.
        logger.error(f"Failed to download model from S3: {e}")
        raise RuntimeError(f"Error downloading model from S3: {e}") from e

def upload_image_to_s3(
        file_name, 
        s3_prefix="ml-images", 
        object_name=None
    ):
    """
    Uploads an image to S3 and returns a presigned URL for the object.

    Args:
        file_name (str): The file name of the image to upload.
        s3_prefix (str): The S3 prefix to use for the object name.
        object_name (str, optional): The object name to use for the S3 key.
            If not provided, the object name will be the same as the file name.

    Returns:
        str: The presigned URL for the S3 object.
    """
    # Default the object name to the bare file name when none was given.
    object_name = os.path.basename(file_name) if object_name is None else object_name

    # Final S3 key: "<prefix>/<object name>".
    object_name = f"{s3_prefix}/{object_name}"

    s3.upload_file(file_name, BUCKET_NAME, object_name)
    logger.info(f"Uploaded {file_name} to s3://{BUCKET_NAME}/{object_name}")

    # Presigned GET URL, valid for one hour.
    presign_params = {
        "Bucket": BUCKET_NAME,
        "Key": object_name
    }
    response = s3.generate_presigned_url(
        'get_object',
        Params=presign_params,
        ExpiresIn=3600
    )
    logger.info(f"Generated presigned URL for {object_name}: {response}")
    return response