from sklearn.decomposition import PCA
import pickle as pk
import numpy as np
import pandas as pd
import os
from huggingface_hub import snapshot_download
import requests


# Load PCA models previously fitted on the fossil and leaf embeddings.
with open('pca_fossils_170_finer.pkl', 'rb') as f:
    pca_fossils = pk.load(f)
with open('pca_leaves_170_finer.pkl', 'rb') as f:
    pca_leaves = pk.load(f)

# Download the fossils dataset snapshot from the Hugging Face Hub on first run.
if not os.path.exists('dataset'):
    REPO_ID = 'Serrelab/Fossils'
    token = os.environ.get('READ_TOKEN')
    if token is None:
        print("Warning: a READ_TOKEN environment variable is needed to authenticate with the Hub.")
    snapshot_download(repo_id=REPO_ID, token=token, repo_type='dataset', local_dir='dataset')

embedding_fossils = np.load('dataset/embedding_fossils_170_finer.npy')
#embedding_leaves = np.load('embedding_leaves.npy')

fossils_pd = pd.read_csv('fossils_paths.csv')
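# Note: the retrieval code below indexes `embedding_fossils[:, -1]`, which
# assumes a 3-D array of shape (num_samples, num_slices, dim) whose last
# slice is the vector fed to the PCA. That shape is inferred from usage
# here, not from a documented contract.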

def pca_distance(pca, sample, embedding):
    """
    Args:
        pca: fitted PCA model
        sample: sample for which to find the closest embeddings
        embedding: embeddings of the dataset
    Returns:
        The indices of the five closest embeddings to the sample.
    """
    s = pca.transform(sample.reshape(1, -1))
    # Project the last slice of every stored embedding into PCA space.
    projected = pca.transform(embedding[:, -1])
    distances = np.linalg.norm(projected - s, axis=1)
    return np.argsort(distances)[:5]
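
# Illustrative usage (a sketch, not from the original file): query with one
# of the stored embeddings; its own index comes back first, at distance zero.
#   nearest = pca_distance(pca_fossils, embedding_fossils[0, -1], embedding_fossils)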

def return_paths(argsorted, files):
    """Map an array of indices back to their file paths."""
    return [files[i] for i in argsorted]

def download_public_image(url, destination_path):
    """Fetch an image from a public URL and write it to destination_path."""
    response = requests.get(url, timeout=30)
    if response.status_code == 200:
        with open(destination_path, 'wb') as f:
            f.write(response.content)
        print(f"Downloaded image to {destination_path}")
    else:
        print(f"Failed to download image from bucket. Status code: {response.status_code}")

def get_images(embedding):
    """Retrieve the five fossil images closest to `embedding` in PCA space.

    Returns the class names and local file paths of the downloaded images.
    """
    #pca_embedding_fossils = pca_fossils.transform(embedding_fossils[:,-1])

    pca_d = pca_distance(pca_fossils, embedding, embedding_fossils)

    fossils_paths = fossils_pd['file_name'].values

    paths = return_paths(pca_d, fossils_paths)
    print(paths)

    folder_florissant = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/Florissant_Fossil_v2.0/'
    folder_general = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/General_Fossil_v2.0/'

    local_paths = []
    classes = []
    for i, path in enumerate(paths):
        local_file_path = f'image_{i}.jpg'
        # Map the cluster-local dataset path to its public bucket URL.
        if 'Florissant_Fossil/512/full/jpg/' in path:
            public_path = path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/512/full/jpg/', folder_florissant)
        elif 'General_Fossil/512/full/jpg/' in path:
            public_path = path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/General_Fossil/512/full/jpg/', folder_general)
        else:
            print(f"No public URL mapping found for {path}; skipping.")
            continue
        print(public_path)
        download_public_image(public_path, local_file_path)
        # The class name is the second-to-last component of the public URL.
        parts = [part for part in public_path.split('/') if part]
        classes.append(parts[-2])
        local_paths.append(local_file_path)
    #paths= [path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/512/full/jpg/',
    #                     '/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/original/full/jpg/') for path in paths]

    return classes, local_paths
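

if __name__ == '__main__':
    # Minimal end-to-end sketch (illustrative, not part of the original
    # pipeline): query retrieval with one stored fossil embedding, assuming
    # the (num_samples, num_slices, dim) shape noted above.
    query = embedding_fossils[0, -1]
    retrieved_classes, retrieved_paths = get_images(query)
    for cls, local_path in zip(retrieved_classes, retrieved_paths):
        print(f"{cls}: {local_path}")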