Yanqing0327 commited on
Commit
e6b49c3
·
verified ·
1 Parent(s): 14f4774

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -28,13 +28,17 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
28
  model = model.to(device)
29
 
30
  def load_image(image_file):
31
- """加载本地图片或 URL 图片"""
32
- if isinstance(image_file, str) and (image_file.startswith('http') or image_file.startswith('https')):
 
 
 
33
  response = requests.get(image_file)
34
- image = Image.open(BytesIO(response.content)).convert('RGB')
35
- else:
36
- image = Image.open(image_file).convert('RGB')
37
- return image
 
38
 
39
  def llava_infer(image, text, temperature, top_p, max_tokens):
40
  """LLaVA 模型推理"""
 
28
  model = model.to(device)
29
 
30
  def load_image(image_file):
31
+ """确保 image 是 `PIL.Image`"""
32
+ if isinstance(image_file, Image.Image):
33
+ return image_file.convert("RGB") # 直接返回 `PIL.Image`
34
+
35
+ elif isinstance(image_file, str) and (image_file.startswith('http') or image_file.startswith('https')):
36
  response = requests.get(image_file)
37
+ return Image.open(BytesIO(response.content)).convert('RGB')
38
+
39
+ else: # 这里如果 `image_file` 是路径
40
+ return Image.open(image_file).convert("RGB")
41
+
42
 
43
  def llava_infer(image, text, temperature, top_p, max_tokens):
44
  """LLaVA 模型推理"""