Werli committed on
Commit d63e2b7 · verified · 1 Parent(s): e1adac7

Upload reorganizer_model.py

Files changed (1)
  1. reorganizer_model.py +101 -0
reorganizer_model.py ADDED
@@ -0,0 +1,101 @@
+ import os
+ import io, copy, traceback
+ import requests
+ import spaces
+ import gradio as gr
+ import numpy as np
+ from transformers import T5ForConditionalGeneration, T5Tokenizer
+
+ # Experimental #
+
+ LAMINI_PROMPT_LONG = "gokaygokay/Lamini-Prompt-Enchance-Long"
+
+ class reorganizer_class:
+     def __init__(self, repoId: str, device: str = None, loadModel: bool = False):
+         self.modelPath = self.download_model(repoId)
+         if device is None:
+             import torch
+             self.totalVram = 0
+             if torch.cuda.is_available():
+                 try:
+                     deviceId = torch.cuda.current_device()
+                     # Total VRAM of the current device, in GiB.
+                     self.totalVram = torch.cuda.get_device_properties(deviceId).total_memory / (1024 * 1024 * 1024)
+                 except Exception as e:
+                     print(traceback.format_exc())
+                     print("Error detecting vram: " + str(e))
+                 # Pick the GPU only when there is enough VRAM: > 8 GiB for 8B models, > 4 GiB otherwise.
+                 device = "cuda" if self.totalVram > (8 if "8B" in repoId else 4) else "cpu"
+             else:
+                 device = "cpu"
+         self.device = device
+         self.system_prompt = "Reorganize and enhance the following English labels describing a single image into a readable English article:\n\n"
+         if loadModel:
+             self.load_model()
+
+     def download_model(self, repoId):
+         # Download the model snapshot from the Hugging Face Hub; on network errors,
+         # fall back to whatever is already in the local cache.
+         import huggingface_hub
+         allowPatterns = [
+             #"tf_model.h5",
+             #"model.ckpt.index",
+             #"flax_model.msgpack",
+             #"pytorch_model.bin",
+             "config.json",
+             "generation_config.json",
+             "model.safetensors",
+             "tokenizer.json",
+             "tokenizer_config.json",
+             "special_tokens_map.json",
+             "vocab.json",
+             "added_tokens.json",
+             "spiece.model",
+         ]
+         kwargs = {"allow_patterns": allowPatterns}
+         try:
+             return huggingface_hub.snapshot_download(repoId, **kwargs)
+         except (huggingface_hub.utils.HfHubHTTPError, requests.exceptions.ConnectionError) as exception:
+             import warnings
+             warnings.warn(
+                 "An error occurred while synchronizing the model %s from the Hugging Face Hub:\n%s"
+                 % (repoId, exception)
+             )
+             warnings.warn(
+                 "Trying to load the model directly from the local cache, if it exists."
+             )
+             kwargs["local_files_only"] = True
+             return huggingface_hub.snapshot_download(repoId, **kwargs)
+
+     def load_model(self):
+         try:
+             print('\n\nLoading model: %s\n\n' % self.modelPath)
+             self.Tokenizer = T5Tokenizer.from_pretrained(self.modelPath)
+             self.Model = T5ForConditionalGeneration.from_pretrained(self.modelPath).to(self.device)
+         except Exception as e:
+             self.release_vram()
+             raise e
+
+     def release_vram(self):
+         # Move the model back to the CPU, drop the references, and free CUDA memory.
+         try:
+             import torch
+             if torch.cuda.is_available():
+                 if getattr(self, "Model", None) is not None:
+                     self.Model.to('cpu')
+                     del self.Model
+                 if getattr(self, "Tokenizer", None) is not None:
+                     del self.Tokenizer
+                 import gc
+                 gc.collect()
+                 torch.cuda.empty_cache()
+                 print("release vram end.")
+         except Exception as e:
+             print(traceback.format_exc())
+             print("Error releasing vram: " + str(e))
+
+     def reorganize(self, text: str, max_length: int = 400):
+         # Prepend the system prompt and rewrite the labels into a readable article
+         # with beam search (2 beams, no repeated 3-grams, early stopping).
+         try:
+             input_ids = self.Tokenizer(self.system_prompt + text, return_tensors="pt").input_ids.to(self.device)
+             output = self.Model.generate(input_ids, max_length=max_length, no_repeat_ngram_size=3, num_beams=2, early_stopping=True)
+             result = self.Tokenizer.decode(output[0], skip_special_tokens=True)
+             return result
+         except Exception as e:
+             print(traceback.format_exc())
+             print("Error reorganizing text: " + str(e))
+             return None
+
+ reorganizer_list = [LAMINI_PROMPT_LONG]
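
For context, a minimal usage sketch of the class above. The import path and the label string are illustrative assumptions, not part of the commit:

    from reorganizer_model import reorganizer_class, LAMINI_PROMPT_LONG

    # Instantiate with the default prompt-enhancement checkpoint and load it immediately;
    # device selection falls back to the VRAM heuristic in __init__.
    reorganizer = reorganizer_class(LAMINI_PROMPT_LONG, loadModel=True)

    # Hypothetical comma-separated image labels.
    labels = "a cat, sitting on a windowsill, sunny day, watercolor style"
    article = reorganizer.reorganize(labels)  # returns None on failure
    print(article)

    reorganizer.release_vram()  # free GPU memory when done

Note that reorganize() swallows exceptions and returns None, so callers should check the result before using it.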