sflindrs commited on
Commit
e6a9c05
·
verified ·
1 Parent(s): 3f01084

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -7,6 +7,7 @@ model_path = "ibm-granite/granite-vision-3.1-2b-preview"
7
  processor = LlavaNextProcessor.from_pretrained(model_path, use_fast=True)
8
  model = LlavaNextForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
9
 
 
10
  def get_text_from_content(content):
11
  texts = []
12
  for item in content:
@@ -39,7 +40,7 @@ def chat_inference(image, text, temperature, top_p, top_k, max_tokens, conversat
39
  tokenize=True,
40
  return_dict=True,
41
  return_tensors="pt"
42
- ).to("cpu")
43
 
44
  torch.manual_seed(random.randint(0, 10000))
45
 
 
7
  processor = LlavaNextProcessor.from_pretrained(model_path, use_fast=True)
8
  model = LlavaNextForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto", device_map="auto")
9
 
10
+ @spaces.GPU()
11
  def get_text_from_content(content):
12
  texts = []
13
  for item in content:
 
40
  tokenize=True,
41
  return_dict=True,
42
  return_tensors="pt"
43
+ ).to(model.device)
44
 
45
  torch.manual_seed(random.randint(0, 10000))
46