Yanqing0327 commited on
Commit
dd68e80
·
verified ·
1 Parent(s): e6b49c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -1
app.py CHANGED
@@ -47,7 +47,14 @@ def llava_infer(image, text, temperature, top_p, max_tokens):
47
 
48
  # 预处理图像
49
  image_data = load_image(image)
50
- image_tensor = image_processor.preprocess(image_data, return_tensors='pt')['pixel_values'].half().to(device)
 
 
 
 
 
 
 
51
 
52
  # **处理对话**
53
  conv_mode = "llava_v1"
 
47
 
48
  # 预处理图像
49
  image_data = load_image(image)
50
+ image_tensor = image_processor.preprocess(image_data, return_tensors='pt')['pixel_values']
51
+
52
+ # **确保数据在正确的设备上**
53
+ image_tensor = image_tensor.to(device)
54
+ if torch.cuda.is_available():
55
+ image_tensor = image_tensor.half() # GPU: 用 float16
56
+ else:
57
+ image_tensor = image_tensor.float() # CPU: 用 float32
58
 
59
  # **处理对话**
60
  conv_mode = "llava_v1"