TwT-6's picture
Upload 2667 files
256a159 verified
import os
import re
import timm.models.hub as timm_hub
import torch
import torch.distributed as dist
from mmengine.dist import is_distributed, is_main_process
from transformers import StoppingCriteria
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation when the first sequence in the batch ends with a stop sequence.

    Args:
        stops: Iterable of 1-D token-id tensors. Generation stops as soon as the
            tail of ``input_ids[0]`` equals any one of them. Defaults to no stop
            sequences, i.e. this criterion never fires.
        encounters: Accepted for backward compatibility and stored on the
            instance; it is not used by the matching logic.
    """

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # ``None`` sentinel instead of ``[]`` so instances never share one list.
        self.stops = [] if stops is None else stops
        self.encounters = encounters

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if ``input_ids[0]`` currently ends with any stop sequence.

        ``scores`` and any extra keyword arguments passed by newer transformers
        versions are accepted but ignored.
        """
        # Only the first sequence of the batch is inspected.
        sequence = input_ids[0]
        seq_len = sequence.shape[-1]
        for stop in self.stops:
            stop_len = len(stop)
            # An empty stop, or one longer than the sequence so far, cannot be a
            # suffix. Comparing anyway would raise a broadcast error or, for a
            # length-1 sequence, broadcast and report a false match.
            if stop_len == 0 or stop_len > seq_len:
                continue
            if torch.all(stop == sequence[-stop_len:]).item():
                return True
        return False
def download_cached_file(url, check_hash=True, progress=False):
    """Download a file from a URL into the timm cache and return its local path.

    If the file already exists, it is not downloaded again. If distributed,
    only the main process downloads the file (via
    ``timm_hub.download_cached_file``), and all processes then wait at a
    barrier so none of them proceeds before the file is in place.

    Args:
        url: Remote URL of the file.
        check_hash: Forwarded unchanged to ``timm_hub.download_cached_file``.
        progress: Forwarded unchanged to ``timm_hub.download_cached_file``.

    Returns:
        str: ``<timm cache dir>/<basename of the URL path>``.
    """
    # Stdlib parser rather than torch.hub's internal re-export of it.
    from urllib.parse import urlparse

    if is_main_process():
        timm_hub.download_cached_file(url, check_hash, progress)
    if is_distributed():
        # Non-main ranks wait here until the main rank has finished downloading.
        dist.barrier()

    # Each rank computes the cached path itself instead of receiving it from
    # the main rank. This assumes timm names the cached file after the URL
    # path's basename inside get_cache_dir(); re-verify if timm is upgraded.
    filename = os.path.basename(urlparse(url).path)
    return os.path.join(timm_hub.get_cache_dir(), filename)
def is_url(input_url):
    """Check if an input is an http(s) URL string.

    Matches a leading ``http://`` or ``https://``, ignoring case. Inputs that
    are not ``str`` (e.g. ``None``, ``pathlib.Path``, bytes) are never treated
    as URLs and return False rather than raising ``TypeError``.
    """
    if not isinstance(input_url, str):
        return False
    return re.match(r'^(?:http)s?://', input_url, re.IGNORECASE) is not None