hswu committed on
Commit
816c7cd
·
1 Parent(s): 2541285

feat: 增加 device 参数

Browse files

fix: set_kv_cache使用默认device问题

Files changed (2) hide show
  1. inference.py +1 -1
  2. server.py +2 -2
inference.py CHANGED
@@ -399,7 +399,7 @@ class OmniInference:
399
  model = self.model
400
 
401
  with self.fabric.init_tensor():
402
- model.set_kv_cache(batch_size=2)
403
 
404
  mel, leng = load_audio(audio_path)
405
  audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
 
399
  model = self.model
400
 
401
  with self.fabric.init_tensor():
402
+ model.set_kv_cache(batch_size=2,device=self.device)
403
 
404
  mel, leng = load_audio(audio_path)
405
  audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
server.py CHANGED
@@ -46,9 +46,9 @@ def create_app():
46
  return server.server
47
 
48
 
49
- def serve(ip='0.0.0.0', port=60808):
50
 
51
- OmniChatServer(ip, port=port, run_app=True)
52
 
53
 
54
  if __name__ == "__main__":
 
46
  return server.server
47
 
48
 
49
+ def serve(ip='0.0.0.0', port=60808, device='cuda:0'):
50
 
51
+ OmniChatServer(ip, port=port,run_app=True, device=device)
52
 
53
 
54
  if __name__ == "__main__":