Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -14,13 +14,12 @@ from llava.model.builder import load_pretrained_model
|
|
14 |
from llava.utils import disable_torch_init
|
15 |
from llava.mm_utils import tokenizer_image_token
|
16 |
|
17 |
-
#
|
18 |
os.environ["HUGGINGFACE_HUB_CACHE"] = os.getcwd() + "/weights"
|
19 |
|
20 |
# **加载 LLaVA-1.5-13B**
|
21 |
disable_torch_init()
|
22 |
-
model_id = "Yanqing0327/LLaVA-project"
|
23 |
-
|
24 |
tokenizer, model, image_processor, context_len = load_pretrained_model(
|
25 |
model_id, model_name="llava-v1.5-13b", model_base=None, load_8bit=False, load_4bit=False
|
26 |
)
|
@@ -28,7 +27,6 @@ tokenizer, model, image_processor, context_len = load_pretrained_model(
|
|
28 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
29 |
model = model.to(device)
|
30 |
|
31 |
-
|
32 |
def load_image(image_file):
|
33 |
"""加载本地图片或 URL 图片"""
|
34 |
if isinstance(image_file, str) and (image_file.startswith('http') or image_file.startswith('https')):
|
@@ -38,7 +36,6 @@ def load_image(image_file):
|
|
38 |
image = Image.open(image_file).convert('RGB')
|
39 |
return image
|
40 |
|
41 |
-
|
42 |
def llava_infer(image, text, temperature, top_p, max_tokens):
|
43 |
"""LLaVA 模型推理"""
|
44 |
if image is None or text.strip() == "":
|
|
|
14 |
from llava.utils import disable_torch_init
|
15 |
from llava.mm_utils import tokenizer_image_token
|
16 |
|
17 |
+
# **确保 Hugging Face 缓存目录正确**
|
18 |
os.environ["HUGGINGFACE_HUB_CACHE"] = os.getcwd() + "/weights"
|
19 |
|
20 |
# **加载 LLaVA-1.5-13B**
|
21 |
disable_torch_init()
|
22 |
+
model_id = "Yanqing0327/LLaVA-project" # 替换为你的 Hugging Face 模型仓库
|
|
|
23 |
tokenizer, model, image_processor, context_len = load_pretrained_model(
|
24 |
model_id, model_name="llava-v1.5-13b", model_base=None, load_8bit=False, load_4bit=False
|
25 |
)
|
|
|
27 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
28 |
model = model.to(device)
|
29 |
|
|
|
30 |
def load_image(image_file):
|
31 |
"""加载本地图片或 URL 图片"""
|
32 |
if isinstance(image_file, str) and (image_file.startswith('http') or image_file.startswith('https')):
|
|
|
36 |
image = Image.open(image_file).convert('RGB')
|
37 |
return image
|
38 |
|
|
|
39 |
def llava_infer(image, text, temperature, top_p, max_tokens):
|
40 |
"""LLaVA 模型推理"""
|
41 |
if image is None or text.strip() == "":
|