geoalgo commited on
Commit
f7abd76
·
1 Parent(s): 6f977e2

revert changes to ugly enum to have icons

Browse files
src/constants.py CHANGED
@@ -1,9 +1,11 @@
1
 
2
  class MetricNames:
 
3
  normalized_error: str = "normalized-error"
4
  fit_time_per_1K_rows: str = "fit-time-per-1K-rows"
5
  inference_time_per_1K_rows: str = "inference-time-per-1K-rows"
6
 
 
7
  class ProblemTypes:
8
  col_name: str = "problem_type"
9
  regression: str = "Regression"
 
1
 
2
  class MetricNames:
3
+ raw_error: str = "raw-error"
4
  normalized_error: str = "normalized-error"
5
  fit_time_per_1K_rows: str = "fit-time-per-1K-rows"
6
  inference_time_per_1K_rows: str = "inference-time-per-1K-rows"
7
 
8
+
9
  class ProblemTypes:
10
  col_name: str = "problem_type"
11
  regression: str = "Regression"
src/display/utils.py CHANGED
@@ -63,22 +63,37 @@ model_type_emoji = {
63
  MethodTypes.other: "❓",
64
  }
65
 
 
 
 
 
 
 
 
 
66
 
67
- @dataclass
68
- class ModelType:
69
- name: str
70
- symbol: str
71
 
72
  def to_str(self, separator=" "):
73
- return f"{self.symbol}{separator}{self.name}"
74
-
75
- @classmethod
76
- def from_str(cls, name: str):
77
- symbol = model_type_emoji.get(name, "❓")
78
- return cls(
79
- name=name,
80
- symbol=symbol,
81
- )
 
 
 
 
 
 
 
 
 
 
82
 
83
  class WeightType(Enum):
84
  Adapter = ModelDetails("Adapter")
 
63
  MethodTypes.other: "❓",
64
  }
65
 
66
+ def _make_model_details(name: str):
67
+ return ModelDetails(name=f"{model_type_emoji[name]} {name}", symbol=model_type_emoji[name])
68
+ class ModelType(Enum):
69
+ T1 = _make_model_details(MethodTypes.foundational)
70
+ T2 = _make_model_details(MethodTypes.finetuned)
71
+ T3 = _make_model_details(MethodTypes.automl)
72
+ T4 = _make_model_details(MethodTypes.boosted_tree)
73
+ T5 = _make_model_details(MethodTypes.other)
74
 
75
+ Unknown = ModelDetails(name="", symbol="?")
 
 
 
76
 
77
  def to_str(self, separator=" "):
78
+ return f"{self.value.symbol}{separator}{self.value.name}"
79
+
80
+ @staticmethod
81
+ def from_str(type):
82
+ if MethodTypes.foundational in type or model_type_emoji[MethodTypes.foundational] in type:
83
+ return ModelType.T1
84
+ if MethodTypes.finetuned in type or model_type_emoji[MethodTypes.finetuned] in type:
85
+ return ModelType.T2
86
+ if MethodTypes.automl in type or model_type_emoji[MethodTypes.automl] in type:
87
+ return ModelType.T3
88
+ if MethodTypes.boosted_tree in type or model_type_emoji[MethodTypes.boosted_tree] in type:
89
+ return ModelType.T4
90
+ if MethodTypes.other in type or model_type_emoji[MethodTypes.other] in type:
91
+ return ModelType.T5
92
+ return ModelType.T5
93
+
94
+ def to_str(self, separator=" "):
95
+ return f"{self.value.symbol}{separator}{self.value.name}"
96
+
97
 
98
  class WeightType(Enum):
99
  Adapter = ModelDetails("Adapter")
src/leaderboard/read_evals.py CHANGED
@@ -19,7 +19,7 @@ class ModelConfig:
19
  """Represents the model configuration of a model"""
20
  model: str
21
  model_link: str = ""
22
- model_type: ModelType = ModelType.from_str("?")
23
  code_link: str = ""
24
  precision: Precision = Precision.Unknown
25
  license: str = "?"
@@ -48,8 +48,8 @@ class ModelConfig:
48
  ModelInfoColumn.model.name: self.model,
49
  'model_w_link': model_hyperlink(self.model_link, self.code_link, self.model),
50
  ModelInfoColumn.precision.name: self.precision.value.name,
51
- ModelInfoColumn.model_type.name: self.model_type.name,
52
- ModelInfoColumn.model_type_symbol.name: self.model_type.symbol,
53
  # ModelInfoColumn.model.model_link: model_hyperlink(self.full_model),
54
  ModelInfoColumn.license.name: self.license,
55
  ModelInfoColumn.likes.name: self.likes,
 
19
  """Represents the model configuration of a model"""
20
  model: str
21
  model_link: str = ""
22
+ model_type: ModelType = ModelType.Unknown
23
  code_link: str = ""
24
  precision: Precision = Precision.Unknown
25
  license: str = "?"
 
48
  ModelInfoColumn.model.name: self.model,
49
  'model_w_link': model_hyperlink(self.model_link, self.code_link, self.model),
50
  ModelInfoColumn.precision.name: self.precision.value.name,
51
+ ModelInfoColumn.model_type.name: self.model_type.value.name,
52
+ ModelInfoColumn.model_type_symbol.name: self.model_type.value.symbol,
53
  # ModelInfoColumn.model.model_link: model_hyperlink(self.full_model),
54
  ModelInfoColumn.license.name: self.license,
55
  ModelInfoColumn.likes.name: self.likes,
src/utils.py CHANGED
@@ -192,6 +192,7 @@ def get_grouped_dfs(root_dir='results', ds_properties='results/dataset_propertie
192
  # grouped_results_overall = grouped_results_overall.rename(columns={'model':'Model'})
193
  # grouped_results.to_csv(f'artefacts/grouped_results_by_model.csv')
194
  grouped_dfs = {}
 
195
  for col_name in [ProblemTypes.col_name]:
196
  grouped_dfs[col_name] = group_by(df, col_name)
197
  # print(f"Grouping by {col_name}:\n {grouped_dfs.head(20)}")
 
192
  # grouped_results_overall = grouped_results_overall.rename(columns={'model':'Model'})
193
  # grouped_results.to_csv(f'artefacts/grouped_results_by_model.csv')
194
  grouped_dfs = {}
195
+ # for col_name in ["domain", 'term_length', 'frequency', 'univariate']:
196
  for col_name in [ProblemTypes.col_name]:
197
  grouped_dfs[col_name] = group_by(df, col_name)
198
  # print(f"Grouping by {col_name}:\n {grouped_dfs.head(20)}")