Zhiding committed on
Commit
2c8a4a7
·
1 Parent(s): 9841aa1
Files changed (1) hide show
  1. eagle_vl/serve/inference.py +31 -1
eagle_vl/serve/inference.py CHANGED
@@ -18,7 +18,7 @@ from .chat_utils import Conversation, get_conv_template
18
  logger = logging.getLogger(__name__)
19
 
20
 
21
- def load_model(model_path: str = "nvidia/Eagle-2.5-8B"):
22
 
23
  token = os.environ.get("HF_TOKEN")
24
  # hotfix the model to use flash attention 2
@@ -41,6 +41,36 @@ def load_model(model_path: str = "nvidia/Eagle-2.5-8B"):
41
 
42
  return model, processor
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  class StoppingCriteriaSub(StoppingCriteria):
46
  def __init__(self, stops=[], encounters=1):
 
18
  logger = logging.getLogger(__name__)
19
 
20
 
21
+ def load_model_from_nv(model_path: str = "nvidia/Eagle-2.5-8B"):
22
 
23
  token = os.environ.get("HF_TOKEN")
24
  # hotfix the model to use flash attention 2
 
41
 
42
  return model, processor
43
 
44
def load_model_from_eagle(model_path: str = "NVEagle/Eagle2.5-VL-8B-Preview"):
    """Load the Eagle 2.5 VL model and its processor from the NVEagle hub repo.

    Reads the Hugging Face access token from the ``HF_TOKEN`` environment
    variable, forces flash-attention-2 on every sub-config, loads the model in
    bfloat16 onto CUDA, and builds the matching processor.

    Args:
        model_path: Hub repo id (or local path) to load from.

    Returns:
        Tuple of ``(model, processor)``.
    """
    token = os.environ.get("HF_TOKEN")
    # hotfix the model to use flash attention 2
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True, use_auth_token=token)
    config._attn_implementation = "flash_attention_2"
    config.vision_config._attn_implementation = "flash_attention_2"
    config.text_config._attn_implementation = "flash_attention_2"
    print("Successfully set the attn_implementation to flash_attention_2")

    # Bug fix: the original unconditionally sliced `token`, which raises
    # TypeError when HF_TOKEN is unset (os.environ.get returns None).
    if token:
        logger.info(f"token = {token[:4]}***{token[-2:]}")
    else:
        logger.warning("HF_TOKEN is not set; attempting anonymous download")
    model = AutoModel.from_pretrained(
        model_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        use_auth_token=token,
    )
    # NOTE(review): assumes a CUDA device is available — confirm deployment target.
    model.to("cuda")
    processor = AutoProcessor.from_pretrained(model_path, config=config, trust_remote_code=True, use_fast=True, use_auth_token=token)

    return model, processor
67
def load_model(model_path: str = "nvidia/Eagle-2.5-8B"):
    """Load ``(model, processor)``, preferring the nvidia/ repo with an NVEagle fallback.

    Tries :func:`load_model_from_nv` first; on any failure, logs the error
    (with traceback) and falls back to :func:`load_model_from_eagle`.

    Args:
        model_path: Hub repo id (or local path) passed to both loaders.

    Returns:
        Tuple of ``(model, processor)``.
    """
    try:
        model, processor = load_model_from_nv(model_path)
    except Exception as e:
        # logger.exception records the full traceback, which the original
        # logger.error(f"...") call dropped — same message, more diagnostics.
        logger.exception(f"Failed to load model from HF, trying to load from eagle: {e}")
        # NOTE(review): the fallback reuses the exact model_path that just
        # failed; presumably the NVEagle loader resolves it differently —
        # confirm, or consider falling back to load_model_from_eagle's default.
        model, processor = load_model_from_eagle(model_path)
    return model, processor
74
 
75
  class StoppingCriteriaSub(StoppingCriteria):
76
  def __init__(self, stops=[], encounters=1):