import os
import traceback
import io, copy, requests, spaces, gradio as gr, numpy as np
from transformers import T5ForConditionalGeneration, T5Tokenizer

# Experimental #

LAMINI_PROMPT_LONG = "gokaygokay/Lamini-Prompt-Enchance-Long"  # Hub repo ID of the long prompt-enhancement model

class reorganizer_class:
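    """Wraps a T5 prompt-enhancement model: downloads it from the Hugging Face Hub,
    picks a device based on available VRAM, and rewrites image labels into an article."""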
    def __init__(self, repoId: str, device: str = None, loadModel: bool = False):
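        """repoId: Hub repo to download; device: "cuda"/"cpu" (auto-detected from VRAM when None);
        loadModel: load the weights immediately."""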
        self.modelPath = self.download_model(repoId)
        if device is None:
            import torch
            self.totalVram = 0
            if torch.cuda.is_available():
                try:
                    deviceId = torch.cuda.current_device()
                    self.totalVram = torch.cuda.get_device_properties(deviceId).total_memory / (1024 * 1024 * 1024)
                except Exception as e:
                    print(traceback.format_exc())
                    print("Error detect vram: " + str(e))
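                # Use CUDA only when there is enough VRAM (8 GB for "8B" checkpoints, 4 GB otherwise).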
                device = "cuda" if self.totalVram > (8 if "8B" in repoId else 4) else "cpu"
            else:
                device = "cpu"
        self.device = device
        self.system_prompt = "Reorganize and enhance the following English labels describing a single image into a readable English article:\n\n"
        if loadModel:
            self.load_model()

    def download_model(self, repoId):
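        """Snapshot-download only the tokenizer/model files listed below; on network errors,
        fall back to whatever is already in the local Hugging Face cache."""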
        import huggingface_hub
        allowPatterns = [
            #"tf_model.h5",
            #"model.ckpt.index",
            #"flax_model.msgpack",
            #"pytorch_model.bin",
            "config.json",
            "generation_config.json",
            "model.safetensors",
            "tokenizer.json",
            "tokenizer_config.json",
            "special_tokens_map.json",
            "vocab.json",
            "added_tokens.json",
            "spiece.model"
        ]
        kwargs = {"allow_patterns": allowPatterns,}
        try:
            return huggingface_hub.snapshot_download(repoId, **kwargs)
        except (huggingface_hub.utils.HfHubHTTPError, requests.exceptions.ConnectionError) as exception:
            import warnings
            warnings.warn(
                "An error occurred while synchronizing the model %s from the Hugging Face Hub:\n%s"
                % (repoId, exception)
            )
            warnings.warn(
                "Trying to load the model directly from the local cache, if it exists."
            )
            kwargs["local_files_only"] = True
            return huggingface_hub.snapshot_download(repoId, **kwargs)

    def load_model(self):
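        """Load the tokenizer and T5 model onto the selected device, releasing VRAM on failure."""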
        try:
            print('\n\nLoading model: %s\n\n' % self.modelPath)
            self.Tokenizer = T5Tokenizer.from_pretrained(self.modelPath)
            self.Model = T5ForConditionalGeneration.from_pretrained(self.modelPath).to(self.device)
        except Exception as e:
            self.release_vram()
            raise e

    def release_vram(self):
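        """Move the model back to the CPU, drop references, and empty the CUDA cache."""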
        try:
            import torch
            if torch.cuda.is_available():
                if getattr(self, "Model", None) is not None:
                    self.Model.to('cpu')
                    del self.Model
                if getattr(self, "Tokenizer", None) is not None:
                    del self.Tokenizer
                import gc
                gc.collect()
                torch.cuda.empty_cache()
                print("release vram end.")
        except Exception as e:
            print(traceback.format_exc())
            print("Error release vram: " + str(e))

    def reorganize(self, text: str, max_length: int = 400):
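        """Prepend the system prompt to `text` and generate the reorganized description with beam search."""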
        try:
            input_ids = self.Tokenizer(self.system_prompt + text, return_tensors="pt").input_ids.to(self.device)
            output = self.Model.generate(input_ids, max_length=max_length, no_repeat_ngram_size=3, num_beams=2, early_stopping=True)
            result = self.Tokenizer.decode(output[0], skip_special_tokens=True)
            return result
        except Exception as e:
            print(traceback.format_exc())
            print("Error reorganize text: " + str(e))
            return None

reorganizer_list = [LAMINI_PROMPT_LONG]
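
# Minimal usage sketch (hypothetical, not part of the original Space): instantiate the
# reorganizer for the repo ID above and rewrite a comma-separated list of labels, e.g.
#   reorg = reorganizer_class(LAMINI_PROMPT_LONG, loadModel=True)
#   article = reorg.reorganize("a cat, sitting on a windowsill, golden hour, cozy room")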