# PhyloLM / loading.py
# Last commit by Daetheys: "Fixed bug with load_data(force_clone=True)" (15c54ca)
import os
import ujson as json
import pygit2
import shutil
from pygit2.enums import MergeFavor
from phylogeny import compute_all_P, compute_sim_matrix
from plotting import get_color, UNKNOWN_COLOR, DEFAULT_COLOR
# ------------------------------------------------------------------------------------------------
#
# Loading data
#
# ------------------------------------------------------------------------------------------------
def load_data(force_clone=False):
    """Load the model data and derive everything the caller needs to plot it.

    The raw data comes from ``load_git``. Family labels are then overridden
    with the entries of ``family_table.json`` (keyed by model name), the
    similarity matrix is computed, and one color is assigned per family.

    Args:
        force_clone: Forwarded unchanged to ``load_git``.

    Returns:
        ``None`` when ``load_git`` reports that there is no data. Otherwise the
        tuple ``(data, model_names, families, sim_matrix, colors)``, where
        ``colors`` maps every family (plus the ``'?'`` family) to a color.
    """
    data, model_names, families = load_git(force_clone=force_clone)
    if data is None:
        return None

    # Rename families if needed: models missing from the table keep the
    # family stored in their data file.
    with open('family_table.json', 'r', encoding='utf-8') as f:
        rename_table = json.load(f)
    for i, name in enumerate(model_names):
        families[i] = rename_table.get(name, families[i])

    all_P = compute_all_P(data, model_names)
    sim_matrix = compute_sim_matrix(model_names, all_P)

    colors = {}
    palette_idx = 0
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # family -> color mapping does not depend on set iteration order.
    for family in dict.fromkeys(families):
        if family == '?':
            # The unknown family gets its dedicated color below; do not spend
            # a palette entry on it.
            continue
        color = get_color(palette_idx)
        palette_idx += 1
        while color == UNKNOWN_COLOR:  # Avoid using the unknown color for a family
            color = get_color(palette_idx)
            palette_idx += 1
        colors[family] = color
    colors['?'] = UNKNOWN_COLOR  # Assign the unknown color to the unknown family

    return data, model_names, families, sim_matrix, colors
def load_git(force_clone=False):
    """Make sure the local ``Data`` checkout is current and read its contents.

    When ``Data`` does not exist, or ``force_clone`` is true, the repository is
    cloned from scratch (an existing ``Data`` directory is deleted first).
    Otherwise the ``origin`` remote is fetched and the current branch is
    brought up to date with its remote counterpart; if that fails the function
    falls back to a fresh clone.

    Args:
        force_clone: Delete any existing checkout and clone again.

    Returns:
        ``(data_array, model_names, families)`` as three parallel lists built
        from the JSON files under ``Data/math/<folder>/``. When no data file is
        found the result is ``(None, [], [])``.

    Raises:
        KeyError: If ``GITHUB_USERNAME`` or ``GITHUB_TOKEN`` is not set in the
            environment.
    """
    # Built up front so that missing credentials fail on every path, and so
    # that the clone and the fetch authenticate the same way.
    callbacks = GitHubRemoteCallbacks(os.environ['GITHUB_USERNAME'], os.environ['GITHUB_TOKEN'])

    if force_clone or not os.path.exists('Data'):
        # Remove the existing directory if it exists
        if os.path.exists('Data'):
            shutil.rmtree('Data')
        pygit2.clone_repository('https://github.com/PhyloLM/Data', './Data', bare=False, callbacks=callbacks)
    else:
        repo = pygit2.Repository('Data')
        remote = repo.remotes['origin']  # Use named reference instead of index
        fetch_results = remote.fetch(callbacks=callbacks)
        print(fetch_results)
        # Find the remote-tracking reference matching the current branch
        branch_name = repo.head.shorthand
        remote_ref_name = f'refs/remotes/origin/{branch_name}'
        remote_commit = repo.lookup_reference(remote_ref_name).target
        try:
            _integrate_remote_commit(repo, remote_commit)
        except Exception as e:
            print(f"Merge error: {e}")
            # Redownload the repository if the update fails
            return load_git(force_clone=True)

    return _read_math_data()


def _integrate_remote_commit(repo, remote_commit):
    """Bring the current branch of ``repo`` up to date with ``remote_commit``."""
    analysis, _ = repo.merge_analysis(remote_commit)
    flags = pygit2.enums.MergeAnalysis
    if analysis & flags.UP_TO_DATE:
        # Nothing new on the remote: leave the repository untouched.
        return
    if analysis & flags.FASTFORWARD:
        # Update the working tree first, then move the branch to the new commit.
        repo.checkout_tree(repo.get(remote_commit))
        repo.head.set_target(remote_commit)
        return
    # Diverged histories: merge into the index and working tree.
    # NOTE(review): the original comment asked for a "theirs" strategy, but no
    # merge favor is passed and the merge is never committed - confirm intent.
    repo.merge(remote_commit)


def _read_math_data():
    """Read every ``Data/math/<folder>/<name>.json`` file into parallel lists."""
    math_dir = os.path.join('Data', 'math')
    data_array = []
    model_names = []
    families = []
    if not os.path.isdir(math_dir):
        return None, [], []
    # Sorted so that the order of the models is stable across runs.
    for foname in sorted(os.listdir(math_dir)):
        folder = os.path.join(math_dir, foname)
        # Only sub-directories hold model files
        if not os.path.isdir(folder):
            continue
        for fname in sorted(os.listdir(folder)):
            if not fname.endswith('.json'):
                continue
            with open(os.path.join(folder, fname), 'r', encoding='utf-8') as f:
                d = json.load(f)
            # Read both keys before appending so the lists stay aligned.
            family, alleles = d['family'], d['alleles']
            families.append(family)
            model_names.append(foname + '/' + fname[:-5])  # strip ".json"
            data_array.append(alleles)
    if not data_array:
        return None, [], []
    return data_array, model_names, families
# ------------------------------------------------------------------------------------------------
#
# Git functions
#
# ------------------------------------------------------------------------------------------------
class GitHubRemoteCallbacks(pygit2.RemoteCallbacks):
    """Remote callbacks that answer credential requests with a username/token pair."""

    def __init__(self, username, token):
        """Remember the username and token, then run the base initialiser."""
        self.username, self.token = username, token
        super().__init__()

    def credentials(self, url, username_from_url, allowed_types):
        """Return a ``pygit2.UserPass`` built from the stored username and token.

        The ``url``, ``username_from_url`` and ``allowed_types`` arguments are
        accepted but not inspected.
        """
        user_pass = pygit2.UserPass(self.username, self.token)
        return user_pass
# ------------------------------------------------------------------------------------------------
#
# Saving data
#
# ------------------------------------------------------------------------------------------------
def save_git(alleles, genes, model, family):
    """Write one model's data file into the ``Data`` checkout, commit and push it.

    The file ``Data/math/<model>.json`` is (over)written with the family and
    alleles, staged, committed on top of the current ``HEAD`` and pushed to the
    ``origin`` remote.

    Args:
        alleles: Stored under the ``'alleles'`` key of the JSON file.
        genes: Accepted for interface compatibility; not used here.
        model: Model name; used in the file path and in the commit message.
        family: Stored under the ``'family'`` key of the JSON file.

    Raises:
        KeyError: If ``GITHUB_USERNAME``, ``GITHUB_MAIL`` or ``GITHUB_TOKEN`` is
            not set in the environment.
    """
    repo = pygit2.Repository('Data')
    remote = repo.remotes['origin']

    # Path relative to the repository root (this is what the index expects).
    # NOTE(review): ``model`` is interpolated into a file path as-is - confirm
    # that callers validate it.
    data_path = f'math/{model}.json'
    path = os.path.join('Data', data_path)

    # Create the file folder path
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, 'w', encoding='utf-8') as f:
        json.dump({'family': family, 'alleles': alleles}, f)

    # Stage the file and snapshot the index as a tree
    repo.index.add(data_path)
    repo.index.write()
    tree = repo.index.write_tree()

    # The same identity is used as author and committer.
    signature = pygit2.Signature(os.environ['GITHUB_USERNAME'], os.environ['GITHUB_MAIL'])
    repo.create_commit('HEAD', signature, signature, f'Add data for model {model}', tree, [repo.head.target])

    # Push the branch that just received the commit rather than a fixed name.
    remote.push(
        [repo.head.name],
        callbacks=GitHubRemoteCallbacks(os.environ['GITHUB_USERNAME'], os.environ['GITHUB_TOKEN']),
    )