import torch

from transformers import PreTrainedModel, PretrainedConfig

from .config import MoondreamConfig
from .moondream import MoondreamModel

# Files sometimes don't get loaded without these...
from .image_crops import *
from .vision import *
from .text import *
from .region import *
from .utils import *


def extract_question(text):
    prefix = "<image>\n\nQuestion: "
    suffix = "\n\nAnswer:"

    if text.startswith(prefix) and text.endswith(suffix):
        return text[len(prefix) : -len(suffix)]
    else:
        return None
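

# Example: extract_question("<image>\n\nQuestion: What is this?\n\nAnswer:")
# returns "What is this?", while any other prompt shape returns None.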


class HfConfig(PretrainedConfig):
    _auto_class = "AutoConfig"
    model_type = "moondream1"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.config = {}


class HfMoondream(PreTrainedModel):
    _auto_class = "AutoModelForCausalLM"
    config_class = HfConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = MoondreamModel(
            MoondreamConfig.from_dict(config.config), setup_caches=False
        )
        self._is_kv_cache_setup = False

    def _setup_caches(self):
        if not self._is_kv_cache_setup:
            self.model._setup_caches()
            self._is_kv_cache_setup = True
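
    # Each property below lazily sets up the KV cache on first access and then
    # returns the bound method from the underlying MoondreamModel, so callers
    # can invoke e.g. `model.query(...)` without an explicit setup step.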
    @property
    def encode_image(self):
        self._setup_caches()
        return self.model.encode_image

    @property
    def query(self):
        self._setup_caches()
        return self.model.query

    @property
    def caption(self):
        self._setup_caches()
        return self.model.caption

    @property
    def detect(self):
        self._setup_caches()
        return self.model.detect

    @property
    def point(self):
        self._setup_caches()
        return self.model.point

    @property
    def detect_gaze(self):
        self._setup_caches()
        return self.model.detect_gaze

    def answer_question(
        self,
        image_embeds,
        question,
        tokenizer=None,
        chat_history="",
        result_queue=None,
        max_new_tokens=256,
        **kwargs,
    ):
        answer = self.query(image_embeds, question)["answer"].strip()

        if result_queue is not None:
            result_queue.put(answer)

        return answer

    def batch_answer(self, images, prompts, tokenizer=None, **kwargs):
        answers = []
        for image, prompt in zip(images, prompts):
            answers.append(self.query(image, prompt)["answer"].strip())
        return answers
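
    # answer_question and batch_answer are thin compatibility wrappers over the
    # newer query API; the tokenizer, chat_history, max_new_tokens, and extra
    # keyword arguments they accept are ignored.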

    def _unsupported_exception(self):
        raise NotImplementedError(
            "This method is not supported in the latest version of moondream. "
            "Consider upgrading to the updated API spec, or alternatively pin "
            "to 'revision=2024-08-26'."
        )

    def generate(self, image_embeds, prompt, tokenizer, max_new_tokens=128, **kwargs):
        """
        Function definition remains unchanged for backwards compatibility.
        Be aware that tokenizer and kwargs are ignored, and max_new_tokens is
        only honored on the raw-prompt fallback path below.
        """
        prompt_extracted = extract_question(prompt)
        if prompt_extracted is not None:
            answer = self.model.query(
                image=image_embeds, question=prompt_extracted, stream=False
            )["answer"]
        else:
            image_embeds = self.encode_image(image_embeds)
            prompt_tokens = torch.tensor(
                [self.model.tokenizer.encode(prompt).ids],
                device=self.device,
            )
            answer = "".join(
                self.model._generate_text(
                    prompt_tokens,
                    image_embeds.kv_cache,
                    image_embeds.pos,
                    max_new_tokens,
                )
            )
        return [answer]

    def get_input_embeddings(self):
        return super().get_input_embeddings()

    def input_embeds(self, *args, **kwargs):
        self._unsupported_exception()
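

if __name__ == "__main__":
    # Minimal usage sketch. The checkpoint id and image path below are
    # illustrative assumptions; point them at a real Moondream checkpoint
    # (loaded with trust_remote_code=True so this wrapper class is used)
    # and a local image file.
    from PIL import Image
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "vikhyatk/moondream2", trust_remote_code=True
    )
    image = Image.open("example.jpg")
    print(model.query(image, "What is in this image?")["answer"])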