Transformers
English
Raiff1982 commited on
Commit
ee62854
·
verified ·
1 Parent(s): f7a6098

Create ModelInfo.ts

Browse files
Files changed (1) hide show
  1. ModelInfo.ts +255 -0
ModelInfo.ts ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import fs from 'fs';
2
+ import Ajv from 'ajv';
3
+
4
+ export class ModelInfo {
5
+ /**
6
+ * Key to config.json file.
7
+ */
8
+ key: string;
9
+ etag: string;
10
+ lastModified: Date;
11
+ size: number;
12
+ modelId: string;
13
+ author?: string;
14
+ siblings: any[];
15
+ config: any;
16
+ configTxt
17
+
18
+ ?: string; /// if flag is set when fetching.
19
+ downloads?: number; /// if flag is set when fetching.
20
+ naturalIdx: number;
21
+ cardSource?: string;
22
+ cardData?: any;
23
+
24
+ constructor(o: Partial<ModelInfo>) {
25
+ Object.assign(this, o);
26
+ this.config = this._loadConfig('ai_config.json');
27
+ }
28
+
29
+ private _loadConfig(filePath: string): any {
30
+ try {
31
+ const configData = fs.readFileSync(filePath, 'utf8');
32
+ return JSON.parse(configData);
33
+ } catch (error) {
34
+ console.error(`Failed to load config from ${filePath}:`, error);
35
+ return {
36
+ "model_name": "mistralai/Mistral-7B-Instruct-v0.2",
37
+ "max_input_length": 4096,
38
+ "safety_thresholds": {
39
+ "memory": 85,
40
+ "cpu": 90
41
+ }
42
+ };
43
+ }
44
+ }
45
+
46
+ get jsonUrl(): string {
47
+ return `https://your-bucket-url/${this.key}`;
48
+ }
49
+
50
+ get cdnJsonUrl(): string {
51
+ return `https://cdn.your-bucket-url/${this.key}`;
52
+ }
53
+
54
+ async validate(): Promise<Ajv.ErrorObject[] | undefined> {
55
+ const jsonSchema = JSON.parse(
56
+ await fs.promises.readFile('path/to/your/schema.json', 'utf8')
57
+ );
58
+ const ajv = new Ajv();
59
+ ajv.validate(jsonSchema, this
60
+
61
+ .config);
62
+ return ajv.errors ?? undefined;
63
+ }
64
+
65
+ /**
66
+ * Readme key, w. and w/o S3 prefix.
67
+ */
68
+ get readmeKey(): string {
69
+ return this.key.replace("config.json", "README.md");
70
+ }
71
+
72
+ get readmeTrimmedKey(): string {
73
+ return this.readmeKey.replace("S3_MODELS_PREFIX", "");
74
+ }
75
+
76
+ /**
77
+ * ["pytorch", "tf", ...]
78
+ */
79
+ get mlFrameworks(): string[] {
80
+ return Object.keys(FileType).filter(k => {
81
+ const filename = FileType[k];
82
+ const isExtension = filename.startsWith(".");
83
+ return isExtension
84
+ ? this.siblings.some(sibling => sibling.rfilename.endsWith(filename))
85
+ : this.siblings.some(sibling => sibling.rfilename === filename);
86
+ });
87
+ }
88
+
89
+ /**
90
+ * What to display in the code sample.
91
+ */
92
+ get autoArchitecture(): string {
93
+ const useTF = this.mlFrameworks.includes("tf") && !this.mlFrameworks.includes("pytorch");
94
+ const arch = this.autoArchType[0];
95
+ return
96
+
97
+ useTF ? `TF${arch}` : arch;
98
+ }
99
+
100
+ get autoArchType(): [string, string | undefined] {
101
+ const architectures = this.config.architectures;
102
+ if (!architectures || architectures.length === 0) {
103
+ return ["AutoModel", undefined];
104
+ }
105
+ const architecture = architectures[0].toString() as string;
106
+ if (architecture.endsWith("ForQuestionAnswering")) {
107
+ return ["AutoModelForQuestionAnswering", "question-answering"];
108
+ }
109
+ else if (architecture.endsWith("ForTokenClassification")) {
110
+ return ["AutoModelForTokenClassification", "token-classification"];
111
+ }
112
+ else if (architecture endsWith("ForSequenceClassification")) {
113
+ return ["AutoModelForSequenceClassification", "text-classification"];
114
+ }
115
+ else if (architecture endsWith("ForMultipleChoice")) {
116
+ return ["AutoModelForMultipleChoice", "multiple-choice"];
117
+ }
118
+ else if (architecture endsWith("ForPreTraining")) {
119
+ return ["AutoModelForPreTraining", "pretraining"];
120
+ }
121
+ else if (architecture endsWith("ForMaskedLM")) {
122
+ return ["AutoModelForMaskedLM", "masked-lm"];
123
+ }
124
+ else if (architecture endsWith("ForCausalLM")) {
125
+ return ["AutoModelForCausalLM", "causal-lm"];
126
+ }
127
+ else if (
128
+ architecture endsWith("ForConditionalGeneration")
129
+ || architecture endsWith("MTModel")
130
+ || architecture == "EncoderDecoderModel"
131
+ ) {
132
+ return ["AutoModelForSeq2SeqLM", "seq2seq"];
133
+ }
134
+ else if (architecture includes("LMHead")) {
135
+ return ["AutoModelWithLMHead", "lm-head"];
136
+ }
137
+ else if (architecture endsWith("Model")) {
138
+ return ["AutoModel", undefined];
139
+ }
140
+ else {
141
+ return [architecture, undefined];
142
+ }
143
+ }
144
+
145
+ /**
146
+ * All tags
147
+ */
148
+ get tags(): string[] {
149
+ const x = [
150
+ ...this.mlFrameworks,
151
+ ];
152
+ if (this.config.model_type) {
153
+ x.push(this.config.model_type);
154
+ }
155
+ const arch = this.autoArchType[1];
156
+ if (arch) {
157
+ x.push(arch);
158
+ }
159
+ if (arch === "lm-head" && this.config.model_type) {
160
+ if (
161
+ ["t5", "bart", "marian"].includes(this.config.model_type)) {
162
+
163
+
164
+ x.push("seq2seq");
165
+ }
166
+ else if (["gpt2", "ctrl", "openai-gpt", "xlnet", "transfo-xl", "reformer"].includes(this.config.model_type)) {
167
+ x.push("causal-lm");
168
+ }
169
+ else {
170
+ x.push("masked-lm");
171
+ }
172
+ }
173
+ x.push(...this.languages() ?? []);
174
+ x.push(...this.datasets().map(k => `dataset:${k}`));
175
+ for (let [k, v] of Object.entries(this.cardData ?? {})) {
176
+ if (!['tags', 'license'].includes(k)) {
177
+ /// ^^ whitelist of other accepted keys
178
+ continue;
179
+ }
180
+ if (typeof v === 'string') {
181
+ v = [ v ];
182
+ } else if (Utils.isStrArray(v)) {
183
+ /// ok
184
+ } else {
185
+ c.error(`Invalid ${k} tag type`, v);
186
+ c.debug(this.modelId);
187
+
188
+
189
+ continue;
190
+ }
191
+ if (k === 'license') {
192
+ x.push(...v.map(x => `license:${x.toLowerCase()}`));
193
+ } else {
194
+ x.push(...v);
195
+ }
196
+ }
197
+ if (this.config.task_specific_params) {
198
+ const keys = Object.keys(this.config.task_specific_params);
199
+ for (const key of keys) {
200
+ x.push(`pipeline:${key}`);
201
+ }
202
+ }
203
+ const explicit_ptag = this.cardData?.pipeline_tag;
204
+ if (explicit_ptag) {
205
+ if (typeof explicit_ptag === 'string') {
206
+ x.push(`pipeline_tag:${explicit_ptag}`);
207
+ } else {
208
+ x.push(`pipeline_tag:invalid`);
209
+ }
210
+ }
211
+ return [...new Set(x)];
212
+ }
213
+
214
+ get pipeline_tag(): (keyof typeof PipelineType) | undefined {
215
+ if (isBlacklisted(this.modelId) || this.cardData?.inference === false) {
216
+ return undefined;
217
+ }
218
+
219
+ const explicit_ptag = this.cardData?.pipeline_tag;
220
+ if (explicit_ptag) {
221
+ if (typeof explicit_ptag == 'string') {
222
+ return explicit_ptag as keyof typeof PipelineType;
223
+ } else {
224
+ c.error(`Invalid explicit pipeline_tag`, explicit_ptag);
225
+ return undefined;
226
+ }
227
+ }
228
+
229
+ const tags = this.tags;
230
+ /// Special case for translation
231
+ /// Get the first of the explicit tags that matches.
232
+ const EXPLICIT_PREFIX = "pipeline:";
233
+ const explicit_tag = tags find(x => x.startsWith(EXPLICIT_PREFIX + `translation`));
234
+ if (!!explicit_tag) {
235
+ return "translation";
236
+ }
237
+ /// Otherwise, get the first (most specific) match **from the mapping**.
238
+ for (const ptag of ALL_PIPELINE_TYPES) {
239
+ if (tags includes(ptag)) {
240
+ return ptag;
241
+ }
242
+ }
243
+ /// Extra mapping
244
+ const mapping = new Map<string, keyof typeof PipelineType>([
245
+ ["seq2seq", "text-generation"],
246
+ ["causal-lm", "text-generation"],
247
+ ["masked-lm", "fill-mask"],
248
+ ]);
249
+ for (const [tag, ptag] of mapping) {
250
+ if (tags includes(tag)) {
251
+ return ptag;
252
+ }
253
+ }
254
+ }
255
+ }