Spaces:
Sleeping
Sleeping
feat:增加device 参数
Browse files
fix: set_kv_cache使用默认device问题
- inference.py +1 -1
- server.py +2 -2
inference.py
CHANGED
@@ -399,7 +399,7 @@ class OmniInference:
|
|
399 |
model = self.model
|
400 |
|
401 |
with self.fabric.init_tensor():
|
402 |
-
model.set_kv_cache(batch_size=2)
|
403 |
|
404 |
mel, leng = load_audio(audio_path)
|
405 |
audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
|
|
|
399 |
model = self.model
|
400 |
|
401 |
with self.fabric.init_tensor():
|
402 |
+
model.set_kv_cache(batch_size=2,device=self.device)
|
403 |
|
404 |
mel, leng = load_audio(audio_path)
|
405 |
audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
|
server.py
CHANGED
@@ -46,9 +46,9 @@ def create_app():
|
|
46 |
return server.server
|
47 |
|
48 |
|
49 |
-
def serve(ip='0.0.0.0', port=60808):
|
50 |
|
51 |
-
OmniChatServer(ip, port=port,
|
52 |
|
53 |
|
54 |
if __name__ == "__main__":
|
|
|
46 |
return server.server
|
47 |
|
48 |
|
49 |
+
def serve(ip='0.0.0.0', port=60808, device='cuda:0'):
|
50 |
|
51 |
+
OmniChatServer(ip, port=port,run_app=True, device=device)
|
52 |
|
53 |
|
54 |
if __name__ == "__main__":
|