shigeki Ishida commited on
Commit
5b27d64
·
1 Parent(s): bf7bdee

feat: add auto and float32 precision options, set auto as default

Browse files
Files changed (1) hide show
  1. src/display/utils.py +7 -2
src/display/utils.py CHANGED
@@ -133,10 +133,11 @@ class Precision(Enum):
133
  float16 = ModelDetails("float16")
134
  float32 = ModelDetails("float32")
135
  bfloat16 = ModelDetails("bfloat16")
 
136
 
137
  @staticmethod
138
  def from_str(precision: str) -> "Precision":
139
- if precision == "auto":
140
  return Precision.auto
141
  if precision in ["torch.float16", "float16"]:
142
  return Precision.float16
@@ -144,7 +145,11 @@ class Precision(Enum):
144
  return Precision.float32
145
  if precision in ["torch.bfloat16", "bfloat16"]:
146
  return Precision.bfloat16
147
- raise ValueError(f"Unsupported precision type: {precision}")
 
 
 
 
148
 
149
 
150
  class AddSpecialTokens(Enum):
 
133
  float16 = ModelDetails("float16")
134
  float32 = ModelDetails("float32")
135
  bfloat16 = ModelDetails("bfloat16")
136
+ float32 = ModelDetails("float32")
137
 
138
  @staticmethod
139
  def from_str(precision: str) -> "Precision":
140
+ if precision in ["auto", "Auto"]:
141
  return Precision.auto
142
  if precision in ["torch.float16", "float16"]:
143
  return Precision.float16
 
145
  return Precision.float32
146
  if precision in ["torch.bfloat16", "bfloat16"]:
147
  return Precision.bfloat16
148
+ if precision in ["torch.float32", "float32"]:
149
+ return Precision.float32
150
+ raise ValueError(
151
+ f"Unsupported precision type: {precision}. Please use 'auto' (recommended), 'float32', 'float16', or 'bfloat16'"
152
+ )
153
 
154
 
155
  class AddSpecialTokens(Enum):