Spaces:
Runtime error
Runtime error
Upload app.py
Browse files
app.py
CHANGED
@@ -6,16 +6,18 @@ from PIL import Image
|
|
6 |
import numpy as np
|
7 |
import transformers
|
8 |
from typing import Dict, Optional, Sequence, List
|
9 |
-
|
|
|
|
|
10 |
|
11 |
import sys
|
|
|
12 |
from oryx.conversation import conv_templates, SeparatorStyle
|
13 |
from oryx.model.builder import load_pretrained_model
|
14 |
from oryx.utils import disable_torch_init
|
15 |
from oryx.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, process_anyres_video_genli
|
16 |
from oryx.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
|
17 |
|
18 |
-
|
19 |
model_path = "THUdyh/Oryx-7B"
|
20 |
model_name = get_model_name_from_path(model_path)
|
21 |
overwrite_config = {}
|
@@ -78,7 +80,7 @@ def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_im
|
|
78 |
targets = torch.tensor(targets, dtype=torch.long)
|
79 |
return input_ids
|
80 |
|
81 |
-
|
82 |
def oryx_inference(video, text):
|
83 |
vr = VideoReader(video, ctx=cpu(0))
|
84 |
total_frame_num = len(vr)
|
|
|
6 |
import numpy as np
|
7 |
import transformers
|
8 |
from typing import Dict, Optional, Sequence, List
|
9 |
+
|
10 |
+
import subprocess
|
11 |
+
import os  # needed so the child process can inherit the parent environment

# Install flash-attn at startup via pip. The child environment is the current
# environment plus FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE; passing only the flag
# as `env` would replace the whole environment (PATH, HOME, ...) for the shell
# that runs pip.
# NOTE(review): the flag presumably makes flash-attn's build skip compiling its
# CUDA extension -- confirm against the flash-attn setup script.
# The return code is deliberately not checked, as in the original: a failed
# install does not stop the app from starting.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)
|
12 |
|
13 |
import sys
|
14 |
+
# sys.path.append('/mnt/lzy/oryx-demo')
|
15 |
from oryx.conversation import conv_templates, SeparatorStyle
|
16 |
from oryx.model.builder import load_pretrained_model
|
17 |
from oryx.utils import disable_torch_init
|
18 |
from oryx.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, process_anyres_video_genli
|
19 |
from oryx.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
|
20 |
|
|
|
21 |
model_path = "THUdyh/Oryx-7B"
|
22 |
model_name = get_model_name_from_path(model_path)
|
23 |
overwrite_config = {}
|
|
|
80 |
targets = torch.tensor(targets, dtype=torch.long)
|
81 |
return input_ids
|
82 |
|
83 |
+
|
84 |
def oryx_inference(video, text):
|
85 |
vr = VideoReader(video, ctx=cpu(0))
|
86 |
total_frame_num = len(vr)
|