# Multi-Tagger / reorganizer_model.py
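"""Wraps a T5-based prompt-enhancement model that rewrites comma-separated
image tags into a readable English description. Weights are downloaded from
the Hugging Face Hub, with a local-cache fallback when the Hub is unreachable."""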
import requests
import traceback  # used by the error handlers below

from transformers import T5ForConditionalGeneration, T5Tokenizer

# Experimental #
LAMINI_PROMPT_LONG = "gokaygokay/Lamini-Prompt-Enchance-Long"
class reorganizer_class:
    def __init__(self, repoId: str, device: str = None, loadModel: bool = False):
        self.modelPath = self.download_model(repoId)
        if device is None:
            import torch
            self.totalVram = 0
            if torch.cuda.is_available():
                try:
                    deviceId = torch.cuda.current_device()
                    self.totalVram = torch.cuda.get_device_properties(deviceId).total_memory / (1024 * 1024 * 1024)
                except Exception as e:
                    print(traceback.format_exc())
                    print("Error detecting VRAM: " + str(e))
                # Use the GPU only if there is enough VRAM for the checkpoint
                # size (8 GB for 8B models, 4 GB otherwise).
                device = "cuda" if self.totalVram > (8 if "8B" in repoId else 4) else "cpu"
            else:
                device = "cpu"
        self.device = device
        self.system_prompt = "Reorganize and enhance the following English labels describing a single image into a readable English article:\n\n"
        if loadModel:
            self.load_model()
    def download_model(self, repoId):
        import huggingface_hub
        allowPatterns = [
            #"tf_model.h5",
            #"model.ckpt.index",
            #"flax_model.msgpack",
            #"pytorch_model.bin",
            "config.json",
            "generation_config.json",
            "model.safetensors",
            "tokenizer.json",
            "tokenizer_config.json",
            "special_tokens_map.json",
            "vocab.json",
            "added_tokens.json",
            "spiece.model",
        ]
        kwargs = {"allow_patterns": allowPatterns}
        try:
            return huggingface_hub.snapshot_download(repoId, **kwargs)
        except (huggingface_hub.utils.HfHubHTTPError, requests.exceptions.ConnectionError) as exception:
            import warnings
            warnings.warn(
                "An error occurred while synchronizing the model %s from the Hugging Face Hub:\n%s"
                % (repoId, exception)
            )
            warnings.warn(
                "Trying to load the model directly from the local cache, if it exists."
            )
            # Fall back to whatever snapshot is already in the local cache.
            kwargs["local_files_only"] = True
            return huggingface_hub.snapshot_download(repoId, **kwargs)
    def load_model(self):
        try:
            print('\n\nLoading model: %s\n\n' % self.modelPath)
            self.Tokenizer = T5Tokenizer.from_pretrained(self.modelPath)
            self.Model = T5ForConditionalGeneration.from_pretrained(self.modelPath).to(self.device)
        except Exception:
            # Free any partially loaded weights before re-raising.
            self.release_vram()
            raise
    def release_vram(self):
        try:
            import torch
            if torch.cuda.is_available():
                # Move the model back to the CPU before dropping the references
                # so the CUDA allocator can actually reclaim the memory.
                if getattr(self, "Model", None) is not None:
                    self.Model.to('cpu')
                    del self.Model
                if getattr(self, "Tokenizer", None) is not None:
                    del self.Tokenizer
                import gc
                gc.collect()
                torch.cuda.empty_cache()
            print("release_vram finished.")
        except Exception as e:
            print(traceback.format_exc())
            print("Error releasing VRAM: " + str(e))
    def reorganize(self, text: str, max_length: int = 400):
        try:
            # Prepend the instruction prompt and run beam-search generation.
            input_ids = self.Tokenizer(self.system_prompt + text, return_tensors="pt").input_ids.to(self.device)
            output = self.Model.generate(input_ids, max_length=max_length, no_repeat_ngram_size=3, num_beams=2, early_stopping=True)
            result = self.Tokenizer.decode(output[0], skip_special_tokens=True)
            return result
        except Exception as e:
            print(traceback.format_exc())
            print("Error reorganizing text: " + str(e))
            return None
reorganizer_list = [LAMINI_PROMPT_LONG]
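
# Minimal usage sketch (not part of the original module). The tag string below
# is a hypothetical tagger output used purely for illustration.
if __name__ == "__main__":
    reorganizer = reorganizer_class(LAMINI_PROMPT_LONG, loadModel=True)
    tags = "1girl, straw hat, sunflower field, blue sky, smiling, outdoors"
    print(reorganizer.reorganize(tags))
    reorganizer.release_vram()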