Yanqing0327 commited on
Commit
f04a169
·
verified ·
1 Parent(s): 5afe4ef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -5
app.py CHANGED
@@ -14,13 +14,12 @@ from llava.model.builder import load_pretrained_model
14
  from llava.utils import disable_torch_init
15
  from llava.mm_utils import tokenizer_image_token
16
 
17
- # 确保 Hugging Face 缓存目录设置正确
18
  os.environ["HUGGINGFACE_HUB_CACHE"] = os.getcwd() + "/weights"
19
 
20
  # **加载 LLaVA-1.5-13B**
21
  disable_torch_init()
22
- model_id = "Yanqing0327/LLaVA-project"
23
-
24
  tokenizer, model, image_processor, context_len = load_pretrained_model(
25
  model_id, model_name="llava-v1.5-13b", model_base=None, load_8bit=False, load_4bit=False
26
  )
@@ -28,7 +27,6 @@ tokenizer, model, image_processor, context_len = load_pretrained_model(
28
  device = "cuda" if torch.cuda.is_available() else "cpu"
29
  model = model.to(device)
30
 
31
-
32
  def load_image(image_file):
33
  """加载本地图片或 URL 图片"""
34
  if isinstance(image_file, str) and (image_file.startswith('http') or image_file.startswith('https')):
@@ -38,7 +36,6 @@ def load_image(image_file):
38
  image = Image.open(image_file).convert('RGB')
39
  return image
40
 
41
-
42
  def llava_infer(image, text, temperature, top_p, max_tokens):
43
  """LLaVA 模型推理"""
44
  if image is None or text.strip() == "":
 
14
  from llava.utils import disable_torch_init
15
  from llava.mm_utils import tokenizer_image_token
16
 
17
+ # **确保 Hugging Face 缓存目录正确**
18
  os.environ["HUGGINGFACE_HUB_CACHE"] = os.getcwd() + "/weights"
19
 
20
  # **加载 LLaVA-1.5-13B**
21
  disable_torch_init()
22
+ model_id = "Yanqing0327/LLaVA-project" # 替换为你的 Hugging Face 模型仓库
 
23
  tokenizer, model, image_processor, context_len = load_pretrained_model(
24
  model_id, model_name="llava-v1.5-13b", model_base=None, load_8bit=False, load_4bit=False
25
  )
 
27
  device = "cuda" if torch.cuda.is_available() else "cpu"
28
  model = model.to(device)
29
 
 
30
  def load_image(image_file):
31
  """加载本地图片或 URL 图片"""
32
  if isinstance(image_file, str) and (image_file.startswith('http') or image_file.startswith('https')):
 
36
  image = Image.open(image_file).convert('RGB')
37
  return image
38
 
 
39
  def llava_infer(image, text, temperature, top_p, max_tokens):
40
  """LLaVA 模型推理"""
41
  if image is None or text.strip() == "":