|
import os |
|
import re |
|
|
|
import timm.models.hub as timm_hub |
|
import torch |
|
import torch.distributed as dist |
|
from mmengine.dist import is_distributed, is_main_process |
|
from transformers import StoppingCriteria |
|
|
|
|
|
class StoppingCriteriaSub(StoppingCriteria):
    """Stopping criterion that fires when the output ends with a stop sequence.

    Only the first sequence in the batch (``input_ids[0]``) is inspected.

    Args:
        stops: Collection of stop sequences. Each one is compared
            elementwise against the tail of ``input_ids[0]``, so it should
            be a 1-D token-id tensor (or something that compares
            elementwise with one). Defaults to no stop sequences.
        encounters: Accepted for backward compatibility; it is not used.
    """

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # A ``[]`` default would be one list shared by every instance, so
        # build a fresh list when the caller passes nothing.
        self.stops = [] if stops is None else stops

    def __call__(self, input_ids: torch.LongTensor,
                 scores: torch.FloatTensor, **kwargs):
        """Return True if ``input_ids[0]`` ends with any stop sequence.

        Args:
            input_ids: Token ids; indexed as ``input_ids[0]`` to get the
                sequence that is checked.
            scores: Unused; present to satisfy the criterion call signature.
            **kwargs: Ignored. Accepted so callers that forward extra
                keyword arguments do not fail.

        Returns:
            bool: True when a stop sequence matches the tail, else False.
        """
        sequence = input_ids[0]
        seq_len = sequence.shape[0]

        for stop in self.stops:
            stop_len = len(stop)
            # An empty stop would make ``[-0:]`` select the whole sequence,
            # and a stop longer than the sequence cannot match; in both
            # cases the elementwise comparison below would see mismatched
            # sizes, so skip them instead of raising.
            if stop_len == 0 or stop_len > seq_len:
                continue
            if torch.all(stop == sequence[-stop_len:]).item():
                return True

        return False
|
|
|
|
|
def download_cached_file(url, check_hash=True, progress=False):
    """Fetch ``url`` into timm's cache directory and return the local path.

    Only the main process calls ``timm_hub.download_cached_file``. When
    running distributed, every process then waits at ``dist.barrier()``, so
    no process returns before the main process has finished that call.

    The returned path is rebuilt locally from the URL's basename and
    ``timm_hub.get_cache_dir()`` instead of being taken from timm's return
    value, so every process computes the same path on its own.
    NOTE(review): this assumes timm stores the file under exactly that
    name in that directory -- confirm against the installed timm version.

    Args:
        url: Location of the file to download.
        check_hash: Forwarded to ``timm_hub.download_cached_file``.
        progress: Forwarded to ``timm_hub.download_cached_file``.

    Returns:
        str: Path of the cached file inside timm's cache directory.
    """
    # Single downloader: the main process is the only one that fetches.
    if is_main_process():
        timm_hub.download_cached_file(url, check_hash, progress)

    # Everyone (including the main process) synchronizes here.
    if is_distributed():
        dist.barrier()

    # Derive the cache path from the URL rather than passing it between
    # processes.
    url_path = torch.hub.urlparse(url).path
    return os.path.join(timm_hub.get_cache_dir(), os.path.basename(url_path))
|
|
|
|
|
def is_url(input_url):
    """Tell whether ``input_url`` starts with an HTTP(S) scheme.

    The check is case-insensitive and only looks at the beginning of the
    string for ``http://`` or ``https://``; nothing after the scheme is
    validated.

    Returns:
        bool: True if the string begins with ``http://`` or ``https://``.
    """
    scheme_match = re.match(
        r'^(?:http)s?://',
        input_url,
        flags=re.IGNORECASE,
    )
    if scheme_match is None:
        return False
    return True
|
|