File size: 1,996 Bytes
3d6ba31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c80c514
 
3d6ba31
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import transformers
from huggingface_hub import snapshot_download,constants

def download_llm_to_cache(model_name, revision="main", cache_dir=None, local_files_only=False):
    """
    Download an LLM snapshot from the Hugging Face Hub into the cache without loading it into memory.

    Args:
        model_name (str): The repository id on the Hugging Face Hub (e.g., "meta-llama/Llama-2-7b-hf").
        revision (str, optional): The specific model version to use. Defaults to "main".
        cache_dir (str, optional): The cache directory to use. If None, uses
            ``huggingface_hub.constants.HUGGINGFACE_HUB_CACHE``.
        local_files_only (bool, optional): Forwarded to ``snapshot_download``; when True,
            only the local cache is consulted. Defaults to False.

    Returns:
        str | None: Path to the model snapshot in the cache, or None if the
        download failed. Failures are printed, not raised.
    """
    # Get default cache dir if not specified
    if cache_dir is None:
        cache_dir = constants.HUGGINGFACE_HUB_CACHE

    try:
        # Download model files to the cache only; nothing is loaded into memory here.
        cached_path = snapshot_download(
            repo_id=model_name,
            revision=revision,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
        )
    except Exception as e:
        # Deliberately best-effort: report the failure and signal it with None
        # instead of propagating, so callers can check the return value.
        print(f"Error downloading model '{model_name}': {e}")
        return None

    print(f"Model '{model_name}' is available in cache at: {cached_path}")
    return cached_path

def load_model(path, cache_dir=None):
    """
    Load a causal language model and its tokenizer with ``transformers``.

    Args:
        path (str): Hub repository id or local directory accepted by ``from_pretrained``.
        cache_dir (str, optional): Cache directory forwarded to both ``from_pretrained`` calls.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    # device_map='auto' asks transformers to place the model weights on the
    # available device(s); remote code execution is explicitly disabled.
    model = transformers.AutoModelForCausalLM.from_pretrained(
        path, cache_dir=cache_dir, device_map='auto', trust_remote_code=False
    )
    # Tokenizers are not placed on a device, so device_map is not passed here
    # (it would only be forwarded as an unused tokenizer init kwarg).
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        path, cache_dir=cache_dir, trust_remote_code=False
    )
    return model, tokenizer

def llm_run(model, tokenizer, genes, N, new_tokens=4):
    """
    Generate sampled continuations for each prompt in ``genes``, streaming progress.

    Args:
        model: A loaded causal LM (e.g. the first element returned by ``load_model``).
        tokenizer: The tokenizer matching ``model``.
        genes (Iterable[str]): Prompts to continue, processed one at a time
            (presumably gene sequences as text — confirm with callers).
        N (int): Number of sampled sequences requested per prompt
            (``num_return_sequences``).
        new_tokens (int, optional): Exact number of new tokens per sequence; used
            as both ``min_new_tokens`` and ``max_new_tokens``. Defaults to 4.

    Yields:
        list: The accumulated results so far. The SAME list object is yielded on
        every step, growing by one entry per prompt. Copy it if you need a snapshot.

    Returns:
        list: The final accumulated results (available as the generator's return value).
    """
    generate = transformers.pipeline('text-generation', model=model, tokenizer=tokenizer, device_map='auto')
    output = []
    for gene in genes:
        # Calling with a one-element batch returns a one-element outer list;
        # out[0] holds the N sampled sequences for this prompt.
        out = generate(
            [gene],
            min_new_tokens=new_tokens,
            max_new_tokens=new_tokens,
            do_sample=True,
            num_return_sequences=N,
        )
        output.append(out[0])
        yield output
    return output