Aurelien-Morgan-Bot committed
Commit f422804 · verified · 1 Parent(s): 04055fd

source-code for model version v0.28_20250409_220919394_UTC- retrain-pipelines 0.1.1

v0.28_20250409_220919394_UTC/requirements.txt ADDED
@@ -0,0 +1,628 @@
1
+ absl-py==1.4.0
2
+ accelerate==1.5.2
3
+ aiohappyeyeballs==2.6.1
4
+ aiohttp==3.11.15
5
+ aiosignal==1.3.2
6
+ alabaster==1.0.0
7
+ albucore==0.0.23
8
+ albumentations==2.0.5
9
+ ale-py==0.10.2
10
+ altair==5.5.0
11
+ annotated-types==0.7.0
12
+ anyio==4.9.0
13
+ argon2-cffi==23.1.0
14
+ argon2-cffi-bindings==21.2.0
15
+ array_record==0.7.1
16
+ arviz==0.21.0
17
+ astropy==7.0.1
18
+ astropy-iers-data==0.2025.3.31.0.36.18
19
+ astunparse==1.6.3
20
+ atpublic==5.1
21
+ attrs==25.3.0
22
+ audioread==3.0.1
23
+ autograd==1.7.0
24
+ babel==2.17.0
25
+ backcall==0.2.0
26
+ beautifulsoup4==4.13.3
27
+ betterproto==2.0.0b6
28
+ bigframes==1.42.0
29
+ bigquery-magics==0.9.0
30
+ bitsandbytes==0.45.5
31
+ bleach==6.2.0
32
+ blinker==1.9.0
33
+ blis==1.2.1
34
+ blosc2==3.2.1
35
+ bokeh==3.6.3
36
+ boto3==1.37.30
37
+ botocore==1.37.30
38
+ Bottleneck==1.4.2
39
+ bqplot==0.12.44
40
+ branca==0.8.1
41
+ CacheControl==0.14.2
42
+ cachetools==5.5.2
43
+ catalogue==2.0.10
44
+ certifi==2025.1.31
45
+ cffi==1.17.1
46
+ chardet==5.2.0
47
+ charset-normalizer==3.4.1
48
+ chex==0.1.89
49
+ clarabel==0.10.0
50
+ click==8.1.8
51
+ cloudpathlib==0.21.0
52
+ cloudpickle==3.1.1
53
+ cmake==3.31.6
54
+ cmdstanpy==1.2.5
55
+ colorama==0.4.6
56
+ colorcet==3.1.0
57
+ colorlover==0.3.0
58
+ colour==0.1.5
59
+ comm==0.2.2
60
+ community==1.0.0b1
61
+ confection==0.1.5
62
+ cons==0.4.6
63
+ contourpy==1.3.1
64
+ cramjam==2.9.1
65
+ cryptography==43.0.3
66
+ cuda-python==12.6.0
67
+ cudf-cu12 @ https://pypi.nvidia.com/cudf-cu12/cudf_cu12-25.2.1-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
68
+ cudf-polars-cu12==24.12.0
69
+ cufflinks==0.17.3
70
+ cuml-cu12==25.2.1
71
+ cupy-cuda12x==13.3.0
72
+ cut-cross-entropy==25.1.1
73
+ cuvs-cu12==25.2.1
74
+ cvxopt==1.3.2
75
+ cvxpy==1.6.4
76
+ cycler==0.12.1
77
+ cyipopt==1.5.0
78
+ cymem==2.0.11
79
+ Cython==3.0.12
80
+ dask==2024.12.1
81
+ dask-cuda==25.2.0
82
+ dask-cudf-cu12==25.2.2
83
+ dask-expr==1.1.21
84
+ datascience==0.17.6
85
+ datasets==3.1.0
86
+ db-dtypes==1.4.2
87
+ dbus-python==1.2.18
88
+ debugpy==1.8.0
89
+ decorator==4.4.2
90
+ defusedxml==0.7.1
91
+ Deprecated==1.2.18
92
+ diffusers==0.32.2
93
+ dill==0.3.8
94
+ distributed==2024.12.1
95
+ distributed-ucxx-cu12==0.42.0
96
+ distro==1.9.0
97
+ dlib==19.24.6
98
+ dm-tree==0.1.9
99
+ docker==7.1.0
100
+ docker-pycreds==0.4.0
101
+ docstring_parser==0.16
102
+ docutils==0.21.2
103
+ dopamine_rl==4.1.2
104
+ duckdb==1.2.1
105
+ earthengine-api==1.5.9
106
+ easydict==1.13
107
+ editdistance==0.8.1
108
+ eerepr==0.1.1
109
+ einops==0.8.1
110
+ en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl#sha256=1932429db727d4bff3deed6b34cfc05df17794f4a52eeb26cf8928f7c1a0fb85
111
+ entrypoints==0.4
112
+ et_xmlfile==2.0.0
113
+ etils==1.12.2
114
+ etuples==0.3.9
115
+ Farama-Notifications==0.0.4
116
+ fastai==2.7.19
117
+ fastapi==0.115.12
118
+ fastcore==1.7.29
119
+ fastdownload==0.0.7
120
+ fastjsonschema==2.21.1
121
+ fastprogress==1.0.3
122
+ fastrlock==0.8.3
123
+ filelock==3.18.0
124
+ firebase-admin==6.7.0
125
+ Flask==3.1.0
126
+ flatbuffers==25.2.10
127
+ flax==0.10.5
128
+ folium==0.19.5
129
+ fonttools==4.57.0
130
+ frozendict==2.4.6
131
+ frozenlist==1.5.0
132
+ fsspec==2024.9.0
133
+ future==1.0.0
134
+ gast==0.6.0
135
+ GDAL==3.6.4
136
+ gdown==5.2.0
137
+ geemap==0.35.3
138
+ geocoder==1.38.1
139
+ geographiclib==2.0
140
+ geopandas==1.0.1
141
+ geopy==2.4.1
142
+ gin-config==0.5.0
143
+ gitdb==4.0.12
144
+ GitPython==3.1.44
145
+ glob2==0.7
146
+ google==2.0.3
147
+ google-ai-generativelanguage==0.6.15
148
+ google-api-core==2.24.2
149
+ google-api-python-client==2.164.0
150
+ google-auth==2.38.0
151
+ google-auth-httplib2==0.2.0
152
+ google-auth-oauthlib==1.2.1
153
+ google-cloud-aiplatform==1.87.0
154
+ google-cloud-bigquery==3.31.0
155
+ google-cloud-bigquery-connection==1.18.2
156
+ google-cloud-bigquery-storage==2.30.0
157
+ google-cloud-bigtable==2.30.0
158
+ google-cloud-core==2.4.3
159
+ google-cloud-dataproc==5.18.1
160
+ google-cloud-datastore==2.20.2
161
+ google-cloud-firestore==2.20.1
162
+ google-cloud-functions==1.20.2
163
+ google-cloud-iam==2.18.3
164
+ google-cloud-language==2.17.1
165
+ google-cloud-pubsub==2.29.0
166
+ google-cloud-resource-manager==1.14.2
167
+ google-cloud-spanner==3.53.0
168
+ google-cloud-storage==2.19.0
169
+ google-cloud-translate==3.20.2
170
+ google-colab @ file:///colabtools/dist/google_colab-1.0.0.tar.gz
171
+ google-crc32c==1.7.1
172
+ google-genai==1.9.0
173
+ google-generativeai==0.8.4
174
+ google-pasta==0.2.0
175
+ google-resumable-media==2.7.2
176
+ google-spark-connect==0.5.2
177
+ googleapis-common-protos==1.69.2
178
+ googledrivedownloader==1.1.0
179
+ graphviz==0.20.3
180
+ greenlet==3.1.1
181
+ grpc-google-iam-v1==0.14.2
182
+ grpc-interceptor==0.15.4
183
+ grpcio==1.71.0
184
+ grpcio-status==1.71.0
185
+ grpclib==0.4.7
186
+ gspread==6.2.0
187
+ gspread-dataframe==4.0.0
188
+ gym==0.25.2
189
+ gym-notices==0.0.8
190
+ gymnasium==1.1.1
191
+ h11==0.14.0
192
+ h2==4.2.0
193
+ h5netcdf==1.6.1
194
+ h5py==3.13.0
195
+ hdbscan==0.8.40
196
+ hf_transfer==0.1.9
197
+ highspy==1.9.0
198
+ holidays==0.69
199
+ holoviews==1.20.2
200
+ hpack==4.1.0
201
+ html5lib==1.1
202
+ httpcore==1.0.7
203
+ httpimport==1.4.1
204
+ httplib2==0.22.0
205
+ httptools==0.6.4
206
+ httpx==0.28.1
207
+ huggingface-hub==0.27.1
208
+ humanize==4.12.2
209
+ hyperframe==6.1.0
210
+ hyperopt==0.2.7
211
+ ibis-framework==9.5.0
212
+ idna==3.10
213
+ imageio==2.37.0
214
+ imageio-ffmpeg==0.6.0
215
+ imagesize==1.4.1
216
+ imbalanced-learn==0.13.0
217
+ immutabledict==4.2.1
218
+ importlib_metadata==8.6.1
219
+ importlib_resources==6.5.2
220
+ imutils==0.5.4
221
+ inflect==7.5.0
222
+ iniconfig==2.1.0
223
+ intel-cmplr-lib-ur==2025.1.0
224
+ intel-openmp==2025.1.0
225
+ ipyevents==2.0.2
226
+ ipyfilechooser==0.6.0
227
+ ipykernel==6.29.5
228
+ ipyleaflet==0.19.2
229
+ ipyparallel==8.8.0
230
+ ipython==7.34.0
231
+ ipython-genutils==0.2.0
232
+ ipython-sql==0.5.0
233
+ ipytree==0.2.2
234
+ ipywidgets==7.7.1
235
+ itsdangerous==2.2.0
236
+ jax==0.5.2
237
+ jax-cuda12-pjrt==0.5.1
238
+ jax-cuda12-plugin==0.5.1
239
+ jaxlib==0.5.1
240
+ jedi==0.19.2
241
+ jeepney==0.7.1
242
+ jellyfish==1.1.0
243
+ jieba==0.42.1
244
+ Jinja2==3.1.4
245
+ jiter==0.9.0
246
+ jmespath==1.0.1
247
+ joblib==1.4.2
248
+ jsonpatch==1.33
249
+ jsonpickle==4.0.5
250
+ jsonpointer==3.0.0
251
+ jsonschema==4.23.0
252
+ jsonschema-specifications==2024.10.1
253
+ jupyter-client==6.1.12
254
+ jupyter-console==6.1.0
255
+ jupyter-leaflet==0.19.2
256
+ jupyter-server==1.16.0
257
+ jupyter_core==5.7.2
258
+ jupyterlab_pygments==0.3.0
259
+ jupyterlab_widgets==3.0.13
260
+ kaggle==1.7.4.2
261
+ kagglehub==0.3.11
262
+ keras==3.8.0
263
+ keras-hub==0.18.1
264
+ keras-nlp==0.18.1
265
+ keyring==23.5.0
266
+ kiwisolver==1.4.8
267
+ langchain==0.3.22
268
+ langchain-core==0.3.50
269
+ langchain-text-splitters==0.3.7
270
+ langcodes==3.5.0
271
+ langsmith==0.3.23
272
+ language_data==1.3.0
273
+ launchpadlib==1.10.16
274
+ lazr.restfulclient==0.14.4
275
+ lazr.uri==1.0.6
276
+ lazy_loader==0.4
277
+ libclang==18.1.1
278
+ libcudf-cu12==24.12.0
279
+ libcugraph-cu12==25.2.0
280
+ libcuml-cu12==25.2.1
281
+ libcuvs-cu12==25.2.1
282
+ libkvikio-cu12==24.12.1
283
+ libraft-cu12==25.2.0
284
+ librosa==0.11.0
285
+ libucx-cu12==1.18.0
286
+ libucxx-cu12==0.42.0
287
+ lightgbm==4.5.0
288
+ linkify-it-py==2.0.3
289
+ litserve==0.2.6
290
+ llvmlite==0.43.0
291
+ locket==1.0.0
292
+ logical-unification==0.4.6
293
+ lxml==5.3.0
294
+ Mako==1.1.3
295
+ marisa-trie==1.2.1
296
+ Markdown==3.7
297
+ markdown-it-py==3.0.0
298
+ MarkupSafe==3.0.2
299
+ matplotlib==3.9.2
300
+ matplotlib-inline==0.1.7
301
+ matplotlib-venn==1.1.2
302
+ mdit-py-plugins==0.4.2
303
+ mdurl==0.1.2
304
+ metaflow==2.10.0
305
+ metaflow-card-html==1.0.2
306
+ miniKanren==1.0.3
307
+ missingno==0.5.2
308
+ mistune==3.1.3
309
+ mizani==0.13.2
310
+ mkl==2025.0.1
311
+ ml-dtypes==0.4.1
312
+ mlxtend==0.23.4
313
+ more-itertools==10.6.0
314
+ moviepy==1.0.3
315
+ mpmath==1.3.0
316
+ msgpack==1.1.0
317
+ multidict==6.2.0
318
+ multipledispatch==1.0.0
319
+ multiprocess==0.70.16
320
+ multitasking==0.0.11
321
+ murmurhash==1.0.12
322
+ music21==9.3.0
323
+ namex==0.0.8
324
+ narwhals==1.33.0
325
+ natsort==8.4.0
326
+ nbclassic==1.2.0
327
+ nbclient==0.10.2
328
+ nbconvert==7.16.6
329
+ nbformat==5.10.4
330
+ ndindex==1.9.2
331
+ nest-asyncio==1.6.0
332
+ networkx==3.2.1
333
+ nibabel==5.3.2
334
+ nltk==3.9.1
335
+ notebook==6.5.7
336
+ notebook_shim==0.2.4
337
+ numba==0.60.0
338
+ numba-cuda==0.2.0
339
+ numexpr==2.10.2
340
+ numpy==1.26.4
341
+ nvidia-cublas-cu12==12.4.5.8
342
+ nvidia-cuda-cupti-cu12==12.4.127
343
+ nvidia-cuda-nvcc-cu12==12.5.82
344
+ nvidia-cuda-nvrtc-cu12==12.4.127
345
+ nvidia-cuda-runtime-cu12==12.4.127
346
+ nvidia-cudnn-cu12==9.1.0.70
347
+ nvidia-cufft-cu12==11.2.1.3
348
+ nvidia-curand-cu12==10.3.5.147
349
+ nvidia-cusolver-cu12==11.6.1.9
350
+ nvidia-cusparse-cu12==12.3.1.170
351
+ nvidia-cusparselt-cu12==0.6.2
352
+ nvidia-ml-py==12.570.86
353
+ nvidia-nccl-cu12==2.21.5
354
+ nvidia-nvcomp-cu12==4.1.0.6
355
+ nvidia-nvjitlink-cu12==12.4.127
356
+ nvidia-nvtx-cu12==12.4.127
357
+ nvtx==0.2.11
358
+ nx-cugraph-cu12 @ https://pypi.nvidia.com/nx-cugraph-cu12/nx_cugraph_cu12-25.2.0-py3-none-any.whl
359
+ oauth2client==4.1.3
360
+ oauthlib==3.2.2
361
+ openai==1.70.0
362
+ opencv-contrib-python==4.11.0.86
363
+ opencv-python==4.11.0.86
364
+ opencv-python-headless==4.11.0.86
365
+ openpyxl==3.1.5
366
+ opentelemetry-api==1.31.1
367
+ opentelemetry-sdk==1.31.1
368
+ opentelemetry-semantic-conventions==0.52b1
369
+ opt_einsum==3.4.0
370
+ optax==0.2.4
371
+ optree==0.14.1
372
+ orbax-checkpoint==0.11.10
373
+ orjson==3.10.16
374
+ osqp==1.0.3
375
+ packaging==24.2
376
+ pandas==2.2.2
377
+ pandas-datareader==0.10.0
378
+ pandas-gbq==0.28.0
379
+ pandas-stubs==2.2.2.240909
380
+ pandocfilters==1.5.1
381
+ panel==1.6.2
382
+ param==2.2.0
383
+ parso==0.8.4
384
+ parsy==2.1
385
+ partd==1.4.2
386
+ pathlib==1.0.1
387
+ patsy==1.0.1
388
+ peewee==3.17.9
389
+ peft==0.14.0
390
+ pexpect==4.9.0
391
+ pickleshare==0.7.5
392
+ pillow==11.1.0
393
+ platformdirs==4.3.7
394
+ plotly==5.24.1
395
+ plotnine==0.14.5
396
+ pluggy==1.5.0
397
+ ply==3.11
398
+ polars==1.11.0
399
+ pooch==1.8.2
400
+ portpicker==1.5.2
401
+ preshed==3.0.9
402
+ prettytable==3.16.0
403
+ proglog==0.1.11
404
+ progressbar2==4.5.0
405
+ prometheus_client==0.21.1
406
+ promise==2.3
407
+ prompt_toolkit==3.0.50
408
+ propcache==0.3.1
409
+ prophet==1.1.6
410
+ proto-plus==1.26.1
411
+ protobuf==3.20.3
412
+ psutil==5.9.5
413
+ psycopg2==2.9.10
414
+ ptyprocess==0.7.0
415
+ py-cpuinfo==9.0.0
416
+ py4j==0.10.9.7
417
+ pyarrow==17.0.0
418
+ pyasn1==0.6.1
419
+ pyasn1_modules==0.4.2
420
+ pycairo==1.27.0
421
+ pycocotools==2.0.8
422
+ pycparser==2.22
423
+ pydantic==2.9.2
424
+ pydantic_core==2.23.4
425
+ pydata-google-auth==1.9.1
426
+ pydot==1.4.2
427
+ pydotplus==2.0.2
428
+ PyDrive==1.3.1
429
+ PyDrive2==1.21.3
430
+ pyerfa==2.0.1.5
431
+ pygame==2.6.1
432
+ pygit2==1.17.0
433
+ Pygments==2.18.0
434
+ PyGObject==3.42.0
435
+ PyJWT==2.10.1
436
+ pylibcudf-cu12==24.12.0
437
+ pylibcugraph-cu12==25.2.0
438
+ pylibraft-cu12==25.2.0
439
+ pymc==5.21.2
440
+ pymystem3==0.2.0
441
+ pynndescent==0.5.13
442
+ pynvjitlink-cu12==0.5.2
443
+ pynvml==12.0.0
444
+ pyogrio==0.10.0
445
+ Pyomo==6.8.2
446
+ PyOpenGL==3.1.9
447
+ pyOpenSSL==24.2.1
448
+ pyparsing==3.2.3
449
+ pyperclip==1.9.0
450
+ pyproj==3.7.1
451
+ pyshp==2.3.1
452
+ PySocks==1.7.1
453
+ pyspark==3.5.5
454
+ pytensor==2.30.2
455
+ pytest==8.3.3
456
+ python-apt==0.0.0
457
+ python-box==7.3.2
458
+ python-dateutil==2.8.2
459
+ python-dotenv==1.0.1
460
+ python-louvain==0.16
461
+ python-multipart==0.0.20
462
+ python-slugify==8.0.4
463
+ python-snappy==0.7.3
464
+ python-utils==3.9.1
465
+ pytz==2025.2
466
+ pyviz_comms==3.0.4
467
+ PyYAML==6.0.2
468
+ pyzmq==24.0.1
469
+ raft-dask-cu12==25.2.0
470
+ rapids-dask-dependency==25.2.0
471
+ ratelim==0.1.6
472
+ referencing==0.36.2
473
+ regex==2024.11.6
474
+ requests==2.32.3
475
+ requests-oauthlib==2.0.0
476
+ requests-toolbelt==1.0.0
477
+ requirements-parser==0.9.0
478
+ retrain_pipelines @ git+https://github.com/aurelienmorgan/retrain-pipelines.git@8c1868b9297143ee2e02c7b5fcfac55b67949743#subdirectory=pkg_src
479
+ rich==13.9.4
480
+ rmm-cu12==24.12.0
481
+ roman-numerals-py==3.1.0
482
+ rpds-py==0.24.0
483
+ rpy2==3.5.17
484
+ rsa==4.9
485
+ s3transfer==0.11.4
486
+ safetensors==0.5.3
487
+ scikit-image==0.25.2
488
+ scikit-learn==1.6.1
489
+ scipy==1.14.1
490
+ scooby==0.10.0
491
+ scs==3.2.7.post2
492
+ seaborn==0.13.2
493
+ SecretStorage==3.3.1
494
+ Send2Trash==1.8.3
495
+ sentence-transformers==3.4.1
496
+ sentencepiece==0.2.0
497
+ sentry-sdk==2.25.1
498
+ setproctitle==1.3.5
499
+ shap==0.47.1
500
+ shapely==2.1.0
501
+ shellingham==1.5.4
502
+ shtab==1.7.1
503
+ simple-parsing==0.1.7
504
+ simplejson==3.20.1
505
+ simsimd==6.2.1
506
+ six==1.17.0
507
+ sklearn-compat==0.1.3
508
+ sklearn-pandas==2.2.0
509
+ slicer==0.0.8
510
+ smart-open==7.1.0
511
+ smmap==5.0.2
512
+ sniffio==1.3.1
513
+ snowballstemmer==2.2.0
514
+ sortedcontainers==2.4.0
515
+ soundfile==0.13.1
516
+ soupsieve==2.6
517
+ soxr==0.5.0.post1
518
+ spacy==3.8.5
519
+ spacy-legacy==3.0.12
520
+ spacy-loggers==1.0.5
521
+ spanner-graph-notebook==1.1.6
522
+ Sphinx==8.2.3
523
+ sphinxcontrib-applehelp==2.0.0
524
+ sphinxcontrib-devhelp==2.0.0
525
+ sphinxcontrib-htmlhelp==2.1.0
526
+ sphinxcontrib-jsmath==1.0.1
527
+ sphinxcontrib-qthelp==2.0.0
528
+ sphinxcontrib-serializinghtml==2.0.0
529
+ SQLAlchemy==2.0.40
530
+ sqlglot==25.20.2
531
+ sqlparse==0.5.3
532
+ srsly==2.5.1
533
+ stanio==0.5.1
534
+ starlette==0.46.1
535
+ statsmodels==0.14.4
536
+ stringzilla==3.12.3
537
+ sympy==1.13.1
538
+ tables==3.10.2
539
+ tabulate==0.9.0
540
+ tbb==2022.1.0
541
+ tblib==3.1.0
542
+ tcmlib==1.3.0
543
+ tenacity==9.1.2
544
+ tensorboard==2.18.0
545
+ tensorboard-data-server==0.7.2
546
+ tensorflow==2.18.0
547
+ tensorflow-datasets==4.9.8
548
+ tensorflow-hub==0.16.1
549
+ tensorflow-io-gcs-filesystem==0.37.1
550
+ tensorflow-metadata==1.17.0
551
+ tensorflow-probability==0.25.0
552
+ tensorflow-text==2.18.1
553
+ tensorstore==0.1.73
554
+ termcolor==3.0.1
555
+ terminado==0.18.1
556
+ text-unidecode==1.3
557
+ textblob==0.19.0
558
+ tf-slim==1.1.0
559
+ tf_keras==2.18.0
560
+ thinc==8.3.4
561
+ threadpoolctl==3.6.0
562
+ tifffile==2025.3.30
563
+ timm==1.0.15
564
+ tinycss2==1.4.0
565
+ tokenizers==0.20.3
566
+ toml==0.10.2
567
+ toolz==0.12.1
568
+ torch==2.5.0
569
+ torchsummary==1.5.1
570
+ torchvision==0.20.0
571
+ tornado==6.4.2
572
+ tqdm==4.67.1
573
+ traitlets==5.7.1
574
+ traittypes==0.2.1
575
+ transformers==4.46.2
576
+ treelite==4.4.1
577
+ treescope==0.1.9
578
+ triton==3.1.0
579
+ trl==0.12.0
580
+ tweepy==4.15.0
581
+ typeguard==4.4.2
582
+ typer==0.15.2
583
+ types-pytz==2025.2.0.20250326
584
+ types-setuptools==78.1.0.20250329
585
+ typing-inspection==0.4.0
586
+ typing_extensions==4.13.1
587
+ tyro==0.9.18
588
+ tzdata==2025.2
589
+ tzlocal==5.3.1
590
+ uc-micro-py==1.0.3
591
+ ucx-py-cu12==0.42.0
592
+ ucxx-cu12==0.42.0
593
+ umap-learn==0.5.7
594
+ umf==0.10.0
595
+ unsloth @ git+https://github.com/unslothai/unsloth.git@0c8c5ed81e423658ab9ae81eac5aab8d18f5d7af
596
+ unsloth_zoo==2024.11.4
597
+ uritemplate==4.1.1
598
+ urllib3==2.3.0
599
+ uvicorn==0.34.0
600
+ uvloop==0.21.0
601
+ vega-datasets==0.9.0
602
+ wadllib==1.3.6
603
+ wandb==0.19.9
604
+ wasabi==1.1.3
605
+ watchfiles==1.0.5
606
+ wcwidth==0.2.13
607
+ weasel==0.4.1
608
+ webcolors==24.11.1
609
+ webencodings==0.5.1
610
+ websocket-client==1.8.0
611
+ websockets==15.0.1
612
+ Werkzeug==3.1.3
613
+ widgetsnbextension==3.6.10
614
+ wordcloud==1.9.4
615
+ wrapt==1.17.2
616
+ xarray==2025.1.2
617
+ xarray-einstats==0.8.0
618
+ xformers==0.0.28.post2
619
+ xgboost==2.1.4
620
+ xlrd==2.0.1
621
+ xxhash==3.5.0
622
+ xyzservices==2025.1.0
623
+ yarl==1.18.3
624
+ yellowbrick==1.5
625
+ yfinance==0.2.55
626
+ zict==3.0.0
627
+ zipp==3.21.0
628
+ zstandard==0.23.0
v0.28_20250409_220919394_UTC/retraining_pipeline.py ADDED
@@ -0,0 +1,2212 @@
1
+
2
+ from unsloth import FastLanguageModel, \
3
+ is_bfloat16_supported, UnslothTrainer, \
4
+ UnslothTrainingArguments
5
+
6
+ import torch
7
+
8
+ import os
9
+ import sys
10
+
11
+ import gc
12
+ import json
13
+ import time
14
+ import shutil
15
+ import logging
16
+ import traceback
17
+ import subprocess
18
+ import importlib.util
19
+ from enum import Enum
20
+ from io import StringIO
21
+ from textwrap import dedent
22
+ from datetime import datetime
23
+ from contextlib import redirect_stdout
24
+
25
+ import numpy as np
26
+ import pandas as pd
27
+
28
+ import polars as pl
29
+ from polars.exceptions import ComputeError
30
+
31
+ import matplotlib
32
+ import matplotlib.pyplot as plt
33
+
34
+ from jinja2 import Environment, FileSystemLoader
35
+
36
+ from metaflow import FlowSpec, step, Parameter, JSONType, \
37
+ IncludeFile, current, metaflow_config as mf_config, \
38
+ resources, Flow, Task, card
39
+ from metaflow.current import Current
40
+ from metaflow.cards import Image, Table, Markdown, \
41
+ Artifact, get_cards
42
+
43
+ from datasets import load_dataset, Dataset, DatasetDict
44
+ from datasets.config import HF_DATASETS_CACHE, HF_CACHE_HOME
45
+ from huggingface_hub import list_repo_commits
46
+ from transformers import AutoTokenizer
47
+
48
+ from retrain_pipelines import __version__
49
+ from retrain_pipelines.dataset.hf_utils import get_lazy_df, \
50
+ get_column_info, iterable_dataset_multi_buffer_sampler, \
51
+ push_dataset_version_to_hub
52
+ from retrain_pipelines.dataset.tool_calls import \
53
+ get_unique_tools, count_tool_occurrences, \
54
+ plot_tools_occurences, column_words_stats, \
55
+ plot_words_count
56
+ from retrain_pipelines.utils.hf_utils import \
57
+ get_new_repo_minor_version, push_files_to_hub_repo_branch
58
+ from retrain_pipelines.utils import create_requirements
59
+
60
+
61
+ class LocalServeReadinessEnum(Enum):
62
+ """
63
+ tracking local-serve (infra-validation)
64
+ status using a "3+"-states enum :
65
+ - "-1" for "not applicable"
66
+ (i.e. "model version not blessed"),
67
+ - "0/1" bool for failure/success.
68
+ """
69
+ NOT_APPLICABLE = -1
70
+ FAILURE = 0
71
+ FAILURE_NO_DOCKER = 2
72
+ SUCCESS = 1
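+
+ # Illustrative usage (a hedged sketch, not part of the original flow) :
+ # the member is persisted through its integer value, which round-trips, e.g.
+ # LocalServeReadinessEnum(-1) is LocalServeReadinessEnum.NOT_APPLICABLE # True
+ # LocalServeReadinessEnum(2).name == "FAILURE_NO_DOCKER" # True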
73
+
74
+
75
+ class UnslothFuncCallFlow(FlowSpec):
76
+ """
77
+ Training pipeline
78
+ """
79
+ # @see https://github.com/unslothai/unsloth/wiki
80
+
81
+ #--- flow parameters -------------------------------------------------------
82
+
83
+ RETRAIN_PIPELINE_TYPE = "mf_unsloth_func_call_litserve"
84
+ # in order to share the config across subprocesses
85
+ os.environ["retrain_pipeline_type"] = RETRAIN_PIPELINE_TYPE
86
+
87
+ hf_dataset = Parameter(
88
+ "hf_dataset",
89
+ help="dict with 'repo_id' and 'commit_hash' keys. " + \
90
+ "if 'commit_hash is None, falls back to latest version " +\
91
+ "of the dataset available in parquet format.\n" +
92
+ "Note that there are 3 required 'attributes' of type " + \
93
+ "str, list[str], list[str]",
94
+ type=JSONType,
95
+ default=dedent("""{
96
+ "repo_id": "Salesforce/xlam-function-calling-60k",
97
+ "config_name": "",
98
+ "commit_hash": "",
99
+ "attributes": {
100
+ "query_attr": "query",
101
+ "answers_attr": "answers",
102
+ "tools_attr": "tools"
103
+ }
104
+ }""").replace("'", '"').strip('"')
105
+ )
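+ # Example invocation (illustrative only, a minimal sketch) :
+ # Metaflow JSONType parameters such as 'hf_dataset' are passed
+ # as JSON strings on the command line, e.g.
+ # python retraining_pipeline.py run \
+ # --hf_dataset '{"repo_id": "Salesforce/xlam-function-calling-60k",
+ # "config_name": "", "commit_hash": "",
+ # "attributes": {"query_attr": "query",
+ # "answers_attr": "answers",
+ # "tools_attr": "tools"}}'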
106
+
107
+ augmentation_rate = Parameter(
108
+ "augmentation_rate",
109
+ type=float,
110
+ default=.05,
111
+ help="proportion of records to be augmented "+\
112
+ "(x% of original dataset is created"+\
113
+ " as additional augmented datapoints), i.e. "+\
114
+ "truncated queries to serve as negative examples, "+\
115
+ "meaning they trigger no tool call "+\
116
+ "due to info incompleteness."
117
+ )
118
+
119
+ hf_enrich_dataset = Parameter(
120
+ "hf_enrich_dataset",
121
+ help="dict with 'repo_id', 'config_name' and 'commit_hash', "+\
122
+ "query_attribute' and 'query_attribute_handler' keys. "+\
123
+ "if 'commit_hash is None, falls back to latest version "+\
124
+ "of the dataset available in parquet format."+\
125
+ "'query_attribute' depicts the dataset attribute "+\
126
+ "from which 'queries' are to be sampled."+\
127
+ "'query_attribute_handler' serves for attributes "+\
128
+ "that have complex structure, "+\
129
+ "other than 'string' datatype.",
130
+ type=JSONType,
131
+ # @see https://huggingface.co/datasets/google-research-datasets/natural_questions
132
+ default=dedent("""{
133
+ "repo_id": "lighteval/natural_questions_clean",
134
+ "config_name": "",
135
+ "commit_hash": "",
136
+ "query_attribute": "question",
137
+ "query_attribute_handler": "lambda x: x"
138
+ }""").replace("'", '"').strip('"')
139
+ )
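+ # Illustrative 'query_attribute_handler' values (hypothetical examples) :
+ # "lambda x: x" -> the attribute already is a plain string
+ # "lambda x: x['text']" -> the attribute is a struct holding the text
+ # the string is eval'ed downstream before being applied to each record.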
140
+
141
+ enrichment_rate = Parameter(
142
+ "enrichment_rate",
143
+ type=float,
144
+ default=.1,
145
+ help="proportion of records "+\
146
+ "to be added from the 'hf_enrich_dataset'"+\
147
+ "(x% of original dataset is sampled and"+\
148
+ " added as enriching datapoints), i.e. "+\
149
+ "queries to serve as negative examples, "+\
150
+ "due to their complete disconnexion "+\
151
+ "to tool calling situations."
152
+ )
153
+
154
+ dataset_repo_id = Parameter(
155
+ "dataset_repo_id",
156
+ type=str,
157
+ default="retrain-pipelines/func_calls",
158
+ help="The 'repo_id' to be used " + \
159
+ "for the Hugging Face dataset version push " + \
160
+ "(will be created at runtime" + \
161
+ " if doesn't already exist)."
162
+ )
163
+
164
+ hf_base_model = Parameter(
165
+ "hf_base_model",
166
+ help="dict with 'repo_id' and 'commit_hash' keys."+\
167
+ "if 'commit_hash is None, falls back "+\
168
+ "to latest available version of the model.",
169
+ type=JSONType,
170
+ default=dedent("""{
171
+ "repo_id": "unsloth/Qwen2.5-1.5B",
172
+ "commit_hash": ""
173
+ }""").replace("'", '"').strip('"')
174
+ )
175
+
176
+ cpt_training_args = Parameter(
177
+ "cpt_training_args",
178
+ help="dict with `TrainingArguments` params "+\
179
+ "for the CPT job.",
180
+ type=JSONType,
181
+ default=dedent("""{
182
+ "warmup_ratio": 0.1,
183
+ "num_train_epochs": 1
184
+ }""").replace("'", '"').strip('"')
185
+ )
186
+
187
+ sft_training_args = Parameter(
188
+ "sft_training_args",
189
+ help="dict with `TrainingArguments` params "+\
190
+ "for the SFT job.",
191
+ type=JSONType,
192
+ default=dedent("""{
193
+ "warmup_ratio": 0.1,
194
+ "num_train_epochs": 1
195
+ }""").replace("'", '"').strip('"')
196
+ )
197
+
198
+ model_repo_id = Parameter(
199
+ "model_repo_id",
200
+ type=str,
201
+ default="retrain-pipelines/function_caller",
202
+ help="The 'repo_id' to be used " + \
203
+ "for the Hugging Face model version push " + \
204
+ "(will be created at runtime" + \
205
+ " if doesn't already exist)."
206
+ )
207
+
208
+ default_pipeline_card_module_dir = \
209
+ os.path.dirname(
210
+ importlib.util.find_spec(
211
+ f"retrain_pipelines.pipeline_card."+
212
+ f"{RETRAIN_PIPELINE_TYPE}"
213
+ ).origin)
214
+ pipeline_card_artifacts_path = Parameter(
215
+ "pipeline_card_artifacts_path",
216
+ type=str,
217
+ default=default_pipeline_card_module_dir,
218
+ help="pipeline_card artifacts location "+\
219
+ "(i.e. dir hosting your optional " + \
220
+ " custom documentation files :" + \
221
+ " 'pipeline_card.py' and/or 'template.html'"+\
222
+ " and/or 'model_readme.py'"+\
223
+ " and/or 'model_readme_template.md'," +\
224
+ " and/or 'dataset_readme.py'"+\
225
+ " and/or 'dataset_readme_template.md' file), " +\
226
+ "if different from default."
227
+ )
228
+ @staticmethod
229
+ def copy_default_dataset_readme_module(
230
+ target_dir: str,
231
+ exists_ok: bool = False
232
+ ) -> None:
233
+ os.makedirs(target_dir, exist_ok=True)
234
+ if (
235
+ not exists_ok and
236
+ os.path.exists(os.path.join(target_dir, "dataset_readme.py"))
237
+ ):
238
+ print("File already exists. Skipping copy.")
239
+ else:
240
+ filefullname = os.path.join(
241
+ UnslothFuncCallFlow.default_pipeline_card_module_dir,
242
+ "dataset_readme.py"
243
+ )
244
+ shutil.copy(filefullname, target_dir)
245
+ print(filefullname)
246
+ @staticmethod
247
+ def copy_default_dataset_readme_template(
248
+ target_dir: str,
249
+ exists_ok: bool = False
250
+ ) -> None:
251
+ os.makedirs(target_dir, exist_ok=True)
252
+ if (
253
+ not exists_ok and
254
+ os.path.exists(os.path.join(target_dir,
255
+ "dataset_readme_template.md"))
256
+ ):
257
+ print("File already exists. Skipping copy.")
258
+ else:
259
+ filefullname = os.path.join(
260
+ UnslothFuncCallFlow.default_pipeline_card_module_dir,
261
+ "dataset_readme_template.md")
262
+ shutil.copy(filefullname, target_dir)
263
+ print(filefullname)
264
+ @staticmethod
265
+ def copy_default_model_readme_module(
266
+ target_dir: str,
267
+ exists_ok: bool = False
268
+ ) -> None:
269
+ os.makedirs(target_dir, exist_ok=True)
270
+ if (
271
+ not exists_ok and
272
+ os.path.exists(os.path.join(target_dir, "model_readme.py"))
273
+ ):
274
+ print("File already exists. Skipping copy.")
275
+ else:
276
+ filefullname = os.path.join(
277
+ UnslothFuncCallFlow.default_pipeline_card_module_dir,
278
+ "model_readme.py"
279
+ )
280
+ shutil.copy(filefullname, target_dir)
281
+ print(filefullname)
282
+ @staticmethod
283
+ def copy_default_model_readme_template(
284
+ target_dir: str,
285
+ exists_ok: bool = False
286
+ ) -> None:
287
+ os.makedirs(target_dir, exist_ok=True)
288
+ if (
289
+ not exists_ok and
290
+ os.path.exists(os.path.join(target_dir,
291
+ "model_readme_template.md"))
292
+ ):
293
+ print("File already exists. Skipping copy.")
294
+ else:
295
+ filefullname = os.path.join(
296
+ UnslothFuncCallFlow.default_pipeline_card_module_dir,
297
+ "model_readme_template.md")
298
+ shutil.copy(filefullname, target_dir)
299
+ print(filefullname)
300
+ @staticmethod
301
+ def copy_default_pipeline_card_module(
302
+ target_dir: str,
303
+ exists_ok: bool = False
304
+ ) -> None:
305
+ os.makedirs(target_dir, exist_ok=True)
306
+ if (
307
+ not exists_ok and
308
+ os.path.exists(os.path.join(target_dir, "pipeline_card.py"))
309
+ ):
310
+ print("File already exists. Skipping copy.")
311
+ else:
312
+ filefullname = os.path.join(
313
+ UnslothFuncCallFlow.default_pipeline_card_module_dir,
314
+ "pipeline_card.py"
315
+ )
316
+ shutil.copy(filefullname, target_dir)
317
+ print(filefullname)
318
+ @staticmethod
319
+ def copy_default_pipeline_card_html_template(
320
+ target_dir: str,
321
+ exists_ok: bool = False
322
+ ) -> None:
323
+ os.makedirs(target_dir, exist_ok=True)
324
+ if (
325
+ not exists_ok and
326
+ os.path.exists(os.path.join(target_dir, "template.html"))
327
+ ):
328
+ print("File already exists. Skipping copy.")
329
+ else:
330
+ filefullname = os.path.join(
331
+ UnslothFuncCallFlow.default_pipeline_card_module_dir,
332
+ "template.html")
333
+ shutil.copy(filefullname, target_dir)
334
+ print(filefullname)
335
+
336
+ del RETRAIN_PIPELINE_TYPE
337
+
338
+ #---------------------------------------------------------------------------
339
+
340
+ @step
341
+ def start(self):
342
+ print(f"{current.flow_name} - {current.run_id}")
343
+
344
+ # GPU availability
345
+ print(torch.cuda.get_device_name(0))
346
+ print(torch.__version__)
347
+ self.engine = "gpu" if torch.cuda.is_available() else "cpu"
348
+
349
+ # hf_dataset
350
+ hf_dataset_dict = \
351
+ get_lazy_df(
352
+ repo_id=self.hf_dataset["repo_id"],
353
+ commit_hash=self.hf_dataset["commit_hash"],
354
+ files_filter=(
355
+ self.hf_dataset['config_name']+"/.*\\.parquet"
356
+ if (
357
+ self.hf_dataset["config_name"] and
358
+ "" < self.hf_dataset["config_name"]
359
+ ) else ".*\\.parquet"
360
+ ),
361
+ hf_token=os.getenv("HF_TOKEN", None)
362
+ )
363
+ try:
364
+ print(hf_dataset_dict["repo_id"], ", ",
365
+ hf_dataset_dict["commit_hash"], " - ",
366
+ hf_dataset_dict["commit_datetime"], "\n",
367
+ hf_dataset_dict["lazy_df"].explain())
368
+ except ComputeError as ex:
369
+ if "HF_TOKEN" not in os.environ:
370
+ print("Does the Hugging Face-hosted dataset " +
371
+ "require authentication ?",
372
+ file=sys.stderr, flush=True)
373
+ raise ex
374
+ self.hf_dataset_dict = hf_dataset_dict
375
+
376
+ # hf_enrich_dataset
377
+ print(self.hf_enrich_dataset)
378
+ hf_enrich_dataset_dict = \
379
+ get_lazy_df(
380
+ repo_id=self.hf_enrich_dataset["repo_id"],
381
+ commit_hash=self.hf_enrich_dataset["commit_hash"],
382
+ files_filter=(
383
+ self.hf_enrich_dataset['config_name']+"/.*\\.parquet"
384
+ if (
385
+ self.hf_enrich_dataset["config_name"] and
386
+ "" < self.hf_enrich_dataset["config_name"]
387
+ ) else ".*\\.parquet"
388
+ ),
389
+ hf_token=os.getenv("HF_TOKEN", None)
390
+ )
391
+ print(' ; '.join(f"{k}: {hf_enrich_dataset_dict[k]}"
392
+ for k in ['commit_hash',
393
+ 'commit_datetime']))
394
+ self.hf_enrich_dataset_dict = hf_enrich_dataset_dict
395
+
396
+ # hf_base_model
397
+ hf_base_model_commits = list_repo_commits(
398
+ repo_id=self.hf_base_model["repo_id"],
399
+ revision=(
400
+ None if (rev_commit_hash:=self.hf_base_model["commit_hash"]) == ""
401
+ else rev_commit_hash
402
+ ),
403
+ repo_type="model",
404
+ token=os.getenv("HF_TOKEN", None))
405
+ self.hf_base_model_dict = {
406
+ "repo_id": self.hf_base_model["repo_id"],
407
+ "commit_hash": hf_base_model_commits[0].commit_id,
408
+ "commit_datetime": \
409
+ hf_base_model_commits[0].created_at
410
+ }
411
+
412
+ self.model_version_blessed = False
413
+ self.current_blessed_run = None
414
+ self.current_blessed_version_dict = None
415
+ current.run.remove_tag("model_version_blessed")
416
+
417
+ self.retrain_pipelines = f"retrain-pipelines {__version__}"
418
+ self.retrain_pipeline_type = os.environ["retrain_pipeline_type"]
419
+
420
+ self.serving_artifacts_local_folder = \
421
+ os.path.realpath(os.path.join(
422
+ os.path.dirname(__file__),
423
+ '..', '..', 'serving_artifacts',
424
+ os.path.sep.join(current.run.path_components)
425
+ ))
426
+
427
+ if not os.path.exists(self.serving_artifacts_local_folder):
428
+ os.makedirs(self.serving_artifacts_local_folder)
429
+
430
+ self.unsloth_dir = os.path.join(
431
+ self.serving_artifacts_local_folder,
432
+ "Unsloth"
433
+ )
434
+ print(f"unsloth_dir : {self.unsloth_dir}")
435
+ self.cpt_model_dir = os.path.join(
436
+ self.unsloth_dir, "cpt_model")
437
+ self.sft_model_dir = os.path.join(
438
+ self.unsloth_dir, "sft_model")
439
+
440
+ self.next(self.eda)
441
+
442
+
443
+ @step
444
+ def eda(self):
445
+ """
446
+ exploratory data analysis.
447
+ """
448
+
449
+ ############################
450
+ # features and label #
451
+ # basic counts #
452
+ ############################
453
+ self.records_count = self.hf_dataset_dict["lazy_df"] \
454
+ .select(pl.len()).collect(engine=self.engine).item()
455
+ self.data_schema = get_column_info(
456
+ self.hf_dataset_dict["lazy_df"], engine=self.engine)
457
+ ############################
458
+
459
+ ############################
460
+ # Answers #
461
+ # tools count #
462
+ ############################
463
+ struct_schema = pl.Struct([
464
+ pl.Field("name",
465
+ pl.String
466
+ ),
467
+ pl.Field("arguments",
468
+ pl.List(pl.String) # we retrieve list of args names
469
+ # (without assigned values)
470
+ )
471
+ ])
472
+ tool_answer_occurrences_df = \
473
+ count_tool_occurrences(
474
+ self.hf_dataset_dict["lazy_df"],
475
+ self.hf_dataset["attributes"]["answers_attr"],
476
+ struct_schema) \
477
+ .collect(engine=self.engine)
478
+ print(f"{tool_answer_occurrences_df['occurrences'].sum():,} " +
479
+ f"query/tool-calls pairs")
480
+ fig = plot_tools_occurences(tool_answer_occurrences_df,
481
+ title_prefix="Dataset answers - ")
482
+ self.answers_tools_count_fig = fig
483
+ ############################
484
+
485
+ ############################
486
+ # Query #
487
+ # words count #
488
+ ############################
489
+ queries_max_length = self.hf_dataset_dict["lazy_df"].select(
490
+ pl.col(
491
+ self.hf_dataset["attributes"]["query_attr"]
492
+ ).str.len_chars().max().alias("max_query_length")
493
+ ).collect(engine=self.engine)
494
+ print(f"longuest query counts " +
495
+ f"{queries_max_length['max_query_length'][0]:,} characters")
496
+
497
+ # queries length quartiles
498
+ self.query_words_stats = \
499
+ column_words_stats(
500
+ self.hf_dataset_dict["lazy_df"],
501
+ self.hf_dataset["attributes"]["query_attr"]
502
+ ).collect(engine=self.engine)
503
+ print(self.query_words_stats.to_pandas().to_string(index=False))
504
+ print("Two thirds of the records have a query with less than " +
505
+ f"{self.query_words_stats['q3'][0]} words.")
506
+
507
+ fig = plot_words_count(
508
+ self.hf_dataset_dict["lazy_df"],
509
+ column_name=self.hf_dataset["attributes"]["query_attr"],
510
+ engine=self.engine)
511
+ self.words_count_fig = fig
512
+ ############################
513
+
514
+ ############################
515
+ # hf_enrich_dataset #
516
+ # Query words count #
517
+ ############################
518
+ enrich_question_words_stats = \
519
+ column_words_stats(
520
+ self.hf_enrich_dataset_dict['lazy_df'],
521
+ self.hf_enrich_dataset["query_attribute"],
522
+ column_attr_handler=eval(
523
+ self.hf_enrich_dataset["query_attribute_handler"])
524
+ ).collect(engine=self.engine)
525
+ print(enrich_question_words_stats.to_pandas()
526
+ .to_string(index=False))
527
+ del enrich_question_words_stats
528
+ ############################
529
+
530
+ self.next(self.augment_data)
531
+
532
+
533
+ @step
534
+ def augment_data(self):
535
+ """
536
+ Add 'negative' examples, where
537
+ queries do not trigger any tool call.
538
+ To achieve that, we sample long user queries,
539
+ truncate them at half their word count, and
540
+ associate the result with an empty list of tool-calls.
541
+ """
542
+ """
543
+ We only consider :
544
+ - records with the longest queries,
545
+ i.e. queries in the last quartile
546
+ of "queries with most word-counts"
547
+ (this is to avoid that 'truncated' queries
548
+ get really short)
549
+ - records with answers consisting
550
+ in a single tool-call
551
+ (in order to minimize the risk
552
+ that truncating actually gives
553
+ a valid answer with
554
+ one tool-call [or more])
555
+
556
+ Note on flow 'augmentation_rate' :
557
+ we add that many records (at most),
558
+ as the quartile size permits.
559
+ """
560
+
561
+ print("Sampling within the population with more than " +
562
+ str(self.query_words_stats['q3'][0]) +
563
+ " words (longest queries quartile) =>")
564
+
565
+ samples_count = \
566
+ int(self.records_count * self.augmentation_rate)
567
+ print(f"would represent {samples_count:,.0f} " +
568
+ f"records to be sampled")
569
+
570
+ eligible_records_df = \
571
+ self.hf_dataset_dict["lazy_df"].filter(
572
+ pl.col(
573
+ self.hf_dataset["attributes"]["query_attr"]
574
+ )
575
+ .str.extract_all(r"\w+")
576
+ .map_elements(
577
+ lambda arr: len(arr),
578
+ return_dtype=pl.Int16)
579
+ .gt(self.query_words_stats['q3'][0])
580
+ & pl.col("answers")
581
+ .map_elements(
582
+ lambda x: len(json.loads(x)) == 1
583
+ if isinstance(x, str)
584
+ else False,
585
+ return_dtype=pl.Boolean)
586
+ ) \
587
+ .collect(engine=self.engine)
588
+ eligible_records_count = \
589
+ eligible_records_df.select(pl.len())["len"][0]
590
+ print(f"eligible_records_count : " +
591
+ f"{eligible_records_count:,.0f}")
592
+ samples_count = min(samples_count, eligible_records_count)
593
+ self.actual_augmentation_rate = \
594
+ samples_count / self.records_count
595
+ print("actual augmentation rate : " +
596
+ f"{self.actual_augmentation_rate:.1%}")
597
+ sampled_records_df = eligible_records_df.sample(
598
+ n=samples_count
599
+ )
600
+
601
+ self.augmented_records_df = \
602
+ sampled_records_df.with_columns(
603
+ pl.col("query")
604
+ .map_elements(
605
+ lambda query:
606
+ " ".join(
607
+ query.split()[
608
+ :len(query.split()) // 2]),
609
+ return_dtype=pl.Utf8)
610
+ .alias("truncated_query")
611
+ ).select([
612
+ pl.col("truncated_query").alias("query"),
613
+ pl.lit("[]").alias("answers")
614
+ ])
615
+ print(self.augmented_records_df.height,
616
+ self.augmented_records_df.columns)
617
+
618
+ self.next(self.enrich_data)
619
+
620
+
621
+ @step
622
+ def enrich_data(self):
623
+ """
624
+ Further enrich our dataset with 'negative' records from
625
+ another dataset (can be a general-purpose text dataset)
626
+ as specified by the flow 'hf_enrich_dataset' argument.
627
+ """
628
+ """
629
+ Note : we here use the Hugging Face `datasets` library
630
+ in 'streaming' mode for records sampling.
631
+ """
632
+
633
+ hf_enrich_ds = load_dataset(
634
+ path=self.hf_enrich_dataset["repo_id"],
635
+ name=self.hf_enrich_dataset["config_name"],
636
+ revision=self.hf_enrich_dataset_dict["commit_hash"],
637
+ streaming=True)
638
+ print(hf_enrich_ds["train"])
639
+
640
+ samples_count = \
641
+ int(self.records_count * self.enrichment_rate)
642
+ print(f"Samplig {samples_count:,.0f} records")
643
+
644
+ query_attribute_handler = \
645
+ eval(self.hf_enrich_dataset["query_attribute_handler"])
646
+ samples_iterator = iterable_dataset_multi_buffer_sampler(
647
+ hf_enrich_ds["train"],
648
+ total_samples=samples_count,
649
+ attributes_selector=\
650
+ (lambda x:query_attribute_handler(
651
+ x[self.hf_enrich_dataset["query_attribute"]])),
652
+ buffer_size=3_000,
653
+ num_passes=3,
654
+ seed=None
655
+ )
656
+ # Capitalize and add end punctuation if missing
657
+ start_time = time.time()
658
+ print("Starting sample enriching records, " +
659
+ "this may take some time if the source dataset " +
660
+ "has a complex structure..")
661
+ samples_list = [
662
+ s.capitalize() + ("" if s[-1] in ".!?" else "?")
663
+ for s in samples_iterator]
664
+ elapsed_time = time.time() - start_time
665
+ print(f".. sampling completed " +
666
+ f"({int(elapsed_time // 3_600)}h:" +
667
+ f"{int((elapsed_time % 3_600) // 60)}m:" +
668
+ f"{int(elapsed_time % 60)}s).")
669
+ enriched_records_df = pl.DataFrame(
670
+ {"query": samples_list,
671
+ "answers": \
672
+ ["[]"] * \
673
+ len(samples_list)}
674
+ )
675
+ self.enriched_records_df = enriched_records_df
676
+
677
+ self.next(self.dataset_to_hub)
678
+
679
+
680
+ @step
681
+ def dataset_to_hub(self):
682
+ """
683
+ Push the dataset version to the hub :
684
+ - continued pre-training dataset
685
+ - training and validation splits of the
686
+ augmented and enriched
687
+ supervised finetuning dataset
688
+ - readme with versioning info
689
+ """
690
+
691
+ #############################
692
+ # case of user-provided #
693
+ # documentation artifact(s) #
694
+ #############################
695
+ # note that user can provide either
696
+ # 'pipeline_card.py' or 'template.html'
697
+ # or 'dataset_readme.py'
698
+ # or 'dataset_readme_template.md'
699
+ # or 'model_readme.py'
700
+ # or 'model_readme_template.md'
701
+ # or any combination of those
702
+ # when specifying custom
703
+ # 'pipeline_card_artifacts_path'
704
+ if (
705
+ "dataset_readme_template.md" in
706
+ os.listdir(self.pipeline_card_artifacts_path)
707
+ ):
708
+ template_dir = self.pipeline_card_artifacts_path
709
+ else:
710
+ template_dir = os.path.dirname(
711
+ importlib.util.find_spec(
712
+ f"retrain_pipelines.pipeline_card."+
713
+ f"{os.getenv('retrain_pipeline_type')}"
714
+ ).origin)
715
+ print(f"template_dir : '{template_dir}'")
716
+ #############################
717
+ if "dataset_readme.py" in os.listdir(
718
+ self.pipeline_card_artifacts_path):
719
+ from retrain_pipelines.utils import \
720
+ get_get_dataset_readme_content
721
+ get_dataset_readme_content = \
722
+ get_get_dataset_readme_content(
723
+ self.pipeline_card_artifacts_path)
724
+ else:
725
+ from retrain_pipelines.pipeline_card import \
726
+ get_dataset_readme_content
727
+ #############################
728
+
729
+
730
+ #############################
731
+ # augmented & enriched #
732
+ # finetuning dataset #
733
+ #############################
734
+ merged_df = pl.concat([
735
+ # dataset
736
+ self.hf_dataset_dict["lazy_df"].select([
737
+ self.hf_dataset["attributes"]["query_attr"],
738
+ self.hf_dataset["attributes"]["answers_attr"]
739
+ ]).collect(engine=self.engine),
740
+ # truncated queries augmentation
741
+ self.augmented_records_df,
742
+ # enriching dataset
743
+ self.enriched_records_df
744
+ ]).sample(
745
+ # shuffling
746
+ fraction=1,
747
+ shuffle=True,
748
+ with_replacement=False
749
+ )
750
+ merged_df = merged_df.sample(fraction=1, shuffle=True)
751
+ merged_df.rechunk()
752
+ print(("merged_df", f"{merged_df.shape[0]:,.0F}",
753
+ merged_df.columns))
754
+
755
+ pandas_df = merged_df.to_pandas()
756
+ train_size = int(0.8 * len(pandas_df))
757
+ print(f"validation : {len(pandas_df) - train_size}")
758
+ sft_dataset = DatasetDict({
759
+ "train": Dataset.from_pandas(pandas_df[:train_size]),
760
+ "validation": Dataset.from_pandas(pandas_df[train_size:])
761
+ })
762
+ #############################
763
+
764
+ #############################
765
+ # continued pre-training #
766
+ # dataset #
767
+ #############################
768
+ struct_schema = pl.Struct([
769
+ pl.Field("name", pl.String),
770
+ pl.Field("description", pl.String),
771
+ pl.Field(
772
+ "parameters",
773
+ pl.String # Use String to allow
774
+ # for varying structures
775
+ # (different tools indeed having
776
+ # different sets of parameters
777
+ # i.e. different parameters counts,
778
+ # datatypes and names)
779
+ # so parsing must be tolerant.
780
+ )
781
+ ])
782
+ unique_tools_df = get_unique_tools(
783
+ self.hf_dataset_dict["lazy_df"],
784
+ tools_attr_name=\
785
+ self.hf_dataset["attributes"]["tools_attr"],
786
+ struct_schema=struct_schema
787
+ ).collect(engine=self.engine)
788
+ unique_tools_arrow_table = unique_tools_df.to_arrow()
789
+ self.unique_tools_dataset = \
790
+ Dataset(unique_tools_arrow_table)
791
+ print(self.unique_tools_dataset)
792
+ #############################
793
+
794
+ #############################
795
+ # DatasetDict #
796
+ # with multiple tables #
797
+ #############################
798
+ dataset_dict = DatasetDict({
799
+ "continued_pre_training": \
800
+ self.unique_tools_dataset,
801
+ "supervised_finetuning": sft_dataset
802
+ })
803
+ print(dataset_dict, flush=True)
804
+ #############################
805
+
806
+ #############################
807
+ # dataset README #
808
+ # from template #
809
+ #############################
810
+ commit_datetime = datetime.utcnow()
811
+ new_dataset_version_label = get_new_repo_minor_version(
812
+ repo_id=self.dataset_repo_id,
813
+ repo_type="dataset",
814
+ hf_token=os.getenv("HF_TOKEN", None))
815
+ readme_content = get_dataset_readme_content(
816
+ template_folder=template_dir,
817
+
818
+ hf_dataset_dict=self.hf_dataset_dict,
819
+ hf_enrich_dataset_dict=self.hf_enrich_dataset_dict,
820
+ dataset_dict=dataset_dict,
821
+
822
+ augmentation_rate=self.actual_augmentation_rate,
823
+ enrichment_rate=self.enrichment_rate,
824
+
825
+ version_label=new_dataset_version_label,
826
+ commit_datetime=commit_datetime,
827
+
828
+ mf_flow_name=current.flow_name,
829
+ mf_run_id=current.run.id,
830
+ engine=self.engine
831
+ )
832
+ #############################
833
+
834
+ dataset_commit_hash = push_dataset_version_to_hub(
835
+ repo_id=self.dataset_repo_id,
836
+ version_label=new_dataset_version_label,
837
+ timestamp_str=commit_datetime.strftime(
838
+ "%Y-%m-%d %H:%M:%S UTC"),
839
+ dataset_dict=dataset_dict,
840
+ dataset_readme_content=readme_content,
841
+ hf_token=os.getenv("HF_TOKEN", None)
842
+ )
843
+ if not dataset_commit_hash:
844
+ raise Exception(
845
+ "Failed to publish dataset version.")
846
+ print(f"https://huggingface.co/datasets/{self.dataset_repo_id}" +
847
+ f"/blob/{dataset_commit_hash}/README.md")
848
+ self.dataset_commit_dict = {
849
+ "repo_id": self.dataset_repo_id,
850
+ "commit_hash": dataset_commit_hash,
851
+ "version_label": new_dataset_version_label,
852
+ "commit_datetime": commit_datetime,
853
+ }
854
+
855
+ self.next(self.continued_pre_training)
856
+
857
+
858
+ @step
859
+ def continued_pre_training(self):
860
+ """
861
+ Gives the base model some additional intrinsic knowledge
862
+ through continued pre-training.
863
+ See unsloth.ai/blog/contpretraining
864
+ """
865
+ from retrain_pipelines.model.hf_utils import \
866
+ plot_log_history
867
+
868
+ #######################################
869
+ # base-model and associated tokenizer #
870
+ # from Hub (or local cache) #
871
+ #######################################
872
+ self.max_seq_length = 2048
873
+ model, tokenizer = FastLanguageModel.from_pretrained(
874
+ model_name=self.hf_base_model_dict["repo_id"],
875
+ revision=self.hf_base_model_dict["commit_hash"],
876
+ max_seq_length=self.max_seq_length,
877
+ dtype=None,
878
+ load_in_4bit=False,
879
+ # case of a gated or private base-model
880
+ token=os.getenv("HF_TOKEN", None)
881
+ )
882
+ #######################################
883
+
884
+ #######################################
885
+ # dataset prompt_template mapping #
886
+ #######################################
887
+ tools_dataset = DatasetDict(
888
+ {"train": self.unique_tools_dataset})
889
+ print(tools_dataset)
890
+ tool_prompt_template = "tool: {}"
891
+ def formatting_prompts_func(tools_batch):
892
+ tools_batch = tools_batch["tool"]
893
+ outputs = []
894
+ for tool in tools_batch:
895
+ # Must add EOS_TOKEN,
896
+ # otherwise generation will go on forever!
897
+ text = tool_prompt_template.format(tool) + \
898
+ tokenizer.eos_token
899
+ outputs.append(text)
900
+ return { "tools" : outputs, }
901
+ cpt_dataset = tools_dataset["train"].map(
902
+ formatting_prompts_func, batched=True,)
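+ # Illustrative shape of one formatted CPT record (hypothetical tool,
+ # real entries come from the 'continued_pre_training' split) :
+ # 'tool: {"name": "get_weather", "description": "...", "parameters": "..."}'
+ # followed by the tokenizer EOS token.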
903
+ #######################################
904
+
905
+ #######################################
906
+ # PEFT adapter #
907
+ # for continued pre-training #
908
+ #######################################
909
+ model = FastLanguageModel.get_peft_model(
910
+ model,
911
+ r = 128, # any number >0 ; 8, 16, 32, 64, 128, 256
912
+ target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
913
+ "gate_proj", "up_proj", "down_proj",
914
+ # Add for continued pretraining
915
+ "embed_tokens", "lm_head",],
916
+ lora_alpha = 32,
917
+ lora_dropout = 0, # Supports any, 0 is optimized
918
+ bias = "none", # Supports any, "none" is optimized
919
+ # True or "unsloth" for very long context
920
+ use_gradient_checkpointing = "unsloth",
921
+ use_rslora = True, # rank-stabilized LoRA
922
+ loftq_config = None, # LoftQ
923
+ #random_state = 3407,
924
+ )
925
+ #######################################
926
+
927
+ #######################################
928
+ # cpt_trainer #
929
+ #######################################
930
+ if (
931
+ "records_cap" in self.cpt_training_args and
932
+ self.cpt_training_args["records_cap"] is not None and
933
+ isinstance(self.cpt_training_args["records_cap"], int)
934
+ ):
935
+ cpt_dataset = cpt_dataset.take(
936
+ self.cpt_training_args["records_cap"])
937
+ print(f"cpt_dataset : {cpt_dataset}")
938
+
939
+ train_args = UnslothTrainingArguments(
940
+ # https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.save_strategy
941
+ per_device_train_batch_size=2,
942
+ gradient_accumulation_steps=8,
943
+
944
+ **{k: v for k, v in self.cpt_training_args.items()
945
+ if k != "records_cap"},
946
+
947
+ # 2 to 10x smaller learning rate
948
+ # for the embedding matrices
949
+ learning_rate=5e-5,
950
+ embedding_learning_rate=1e-5,
951
+
952
+ fp16=not is_bfloat16_supported(),
953
+ bf16=is_bfloat16_supported(),
954
+ logging_steps=1,
955
+ optim="adamw_8bit",
956
+ weight_decay=0.01,
957
+ lr_scheduler_type="linear",
958
+ #seed=3407,
959
+
960
+ output_dir=os.path.join(
961
+ self.unsloth_dir, "outputs", "cpt"),
962
+ save_total_limit = 2,
963
+
964
+ report_to="tensorboard",
965
+ logging_dir=os.path.join(
966
+ self.sft_model_dir,
967
+ "runs", "cpt")
968
+ )
969
+
970
+ self.cpt_traces_file_fullname = os.path.join(
971
+ self.unsloth_dir, "cpt_trainer_traces.txt")
972
+ print("Training started. " +
973
+ f"Check {self.cpt_traces_file_fullname} for live traces.",
974
+ flush=True)
975
+
976
+ trainer = UnslothTrainer(
977
+ model=model, tokenizer=tokenizer,
978
+ train_dataset=cpt_dataset,
979
+ dataset_text_field="tools",
980
+ max_seq_length=self.max_seq_length,
981
+ dataset_num_proc=2,
982
+ args=train_args,
983
+ )
984
+ #######################################
985
+
986
+ #######################################
987
+ # Show current memory stats #
988
+ #######################################
989
+ torch.cuda.ipc_collect()
990
+ torch.cuda.empty_cache()
991
+ _ = gc.collect()
992
+
993
+ gpu_stats = torch.cuda.get_device_properties(0)
994
+ self.start_gpu_memory = \
995
+ round(torch.cuda.max_memory_reserved()
996
+ / 1024 / 1024 / 1024, 3)
997
+ self.max_memory = \
998
+ round(gpu_stats.total_memory
999
+ / 1024 / 1024 / 1024, 3)
1000
+ print(f"GPU = {gpu_stats.name}. " +
1001
+ f"Max memory = {self.max_memory} GB.")
1002
+ print(f"{self.start_gpu_memory} GB of memory reserved.")
1003
+ #######################################
1004
+
1005
+ with open(self.cpt_traces_file_fullname, 'w') as f:
1006
+ with redirect_stdout(f):
1007
+ trainer_stats = trainer.train()
1008
+ print(f"{trainer_stats.metrics['train_runtime']} " +
1009
+ f"seconds used for training " +
1010
+ f"({round(trainer_stats.metrics['train_runtime']/60, 2)}" +
1011
+ f" minutes).")
1012
+
1013
+ self.cpt_log_history = trainer.state.log_history
1014
+ # print(self.cpt_log_history)
1015
+ self.cpt_log_history_fig = \
1016
+ plot_log_history(
1017
+ self.cpt_log_history,
1018
+ title="Continued pretraining loss"
1019
+ )
1020
+
1021
+ model.save_pretrained_merged(
1022
+ save_directory=self.cpt_model_dir,
1023
+ tokenizer=tokenizer,
1024
+ save_method="lora"
1025
+ )
1026
+ print(f"cpt_model_dir : {self.cpt_model_dir}\n")
1027
+
1028
+ self.next(self.supervised_finetuning)
1029
+
1030
+
1031
+ @step
1032
+ def supervised_finetuning(self):
1033
+ """
1034
+ Trains the model on tool-calling
1035
+ task specialization.
1036
+ """
1037
+ from retrain_pipelines.model.hf_utils import \
1038
+ plot_log_history
1039
+
1040
+ torch.cuda.ipc_collect()
1041
+ torch.cuda.empty_cache()
1042
+ _ = gc.collect()
1043
+
1044
+ model, tokenizer = FastLanguageModel.from_pretrained(
1045
+ model_name=self.cpt_model_dir,
1046
+ max_seq_length=self.max_seq_length,
1047
+ dtype=None,
1048
+ load_in_4bit=False,
1049
+ )
1050
+ # !!!! bug fix BEGIN !!!!
1051
+ # otherwise, 'embed_tokens' and 'lm_head'
1052
+ # trained during CPT are "ignored",
1053
+ # i.e. not saved after SFT
1054
+ # (note that, alternatively, we could also
1055
+ # do this fix after sft-training and
1056
+ # just before saving ;
1057
+ # which would be equivalent to
1058
+ # freezing embeddings during finetuning
1059
+ # for better pretrained knowledge retention)
1060
+ # @see https://www.reddit.com/r/unsloth/comments/1dtzcd6/fastlanguagemodelpatch_peft_model_changing/
1061
+ model.model.model.embed_tokens.modules_to_save.default.to(
1062
+ device="cuda:0",
1063
+ dtype=torch.float32,
1064
+ non_blocking=True)
1065
+ model.model.model.embed_tokens.modules_to_save.default \
1066
+ .requires_grad_(True)
1067
+ model.model.lm_head.modules_to_save.default.to(
1068
+ device="cuda:0",
1069
+ dtype=torch.float32,
1070
+ non_blocking=True)
1071
+ model.model.lm_head.modules_to_save.default \
1072
+ .requires_grad_(True)
1073
+ # !!!! bug fix END !!!!
1074
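+ # (hedged sketch of the alternative described above, not what this
+ #  flow does: keep those modules frozen during SFT, e.g.
+ #  model.model.model.embed_tokens.modules_to_save.default \
+ #      .requires_grad_(False)
+ #  model.model.lm_head.modules_to_save.default \
+ #      .requires_grad_(False)
+ #  then re-apply the dtype/device fix above right before
+ #  model.save_pretrained_merged(...) so the CPT-trained
+ #  embeddings still get persisted)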
+
1075
+ #######################################
1076
+ # dataset prompt_template mapping #
1077
+ #######################################
1078
+ # download from Hub (or get from local cache)
1079
+ queries_dataset = load_dataset(
1080
+ path=self.dataset_commit_dict["repo_id"],
1081
+ name="supervised_finetuning",
1082
+ revision=self.dataset_commit_dict["commit_hash"],
1083
+ token=os.getenv("HF_TOKEN", None))
1084
+ print(f"HF_DATASETS_CACHE : {HF_DATASETS_CACHE}") # HF_CACHE_HOME
1085
+ self.sft_prompt_template = dedent("""
1086
+ You specialize in generating tool calls. Given a query, your task is to return a list of tool calls based on your knowledge of known tools.
1087
+
1088
+ Rules:
1089
+ 1. You can only use tools you know. Do not create new tools under any circumstances.
1090
+ 2. If a query does not match any known tool, return an empty list ([]).
1091
+ 3. If information is missing to use a known tool, do not attempt to use it.
1092
+ 4. Your response must always be a valid JSON array, and nothing else.
1093
+
1094
+ Be precise and do not guess.
1095
+
1096
+ # query:
1097
+ {}
1098
+ # response:
1099
+ {}
1100
+ """).strip()
1101
+ tokenizer.chat_template = self.sft_prompt_template
1102
+
1103
+ EOS_TOKEN = tokenizer.eos_token
1104
+ def formatting_prompts_func(records):
1105
+ query = records["query"]
1106
+ tools = records["answers"]
1107
+ outputs = []
1108
+ for query, tools in zip(query, tools):
1109
+ # Must add EOS_TOKEN,
1110
+ # otherwise your generation will go on forever
1111
+ text = self.sft_prompt_template.format(query, tools) \
1112
+ + EOS_TOKEN
1113
+ outputs.append(text)
1114
+ return { "text" : outputs, }
1115
+ sft_train_dataset = queries_dataset["train"].map(
1116
+ formatting_prompts_func, batched=True)
1117
+ sft_valid_dataset = queries_dataset["validation"].map(
1118
+ formatting_prompts_func, batched=True,)
1119
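+ # Illustrative example of one resulting "text" record
+ # (query and tool name are hypothetical; EOS token appended):
+ #   ...Be precise and do not guess.
+ #   # query:
+ #   Is 49 a perfect square?
+ #   # response:
+ #   [{"name": "is_perfect_square", "arguments": {"num": 49}}]<EOS_TOKEN>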
+ #######################################
1120
+
1121
+ #######################################
1122
+ # PEFT adapter #
1123
+ # for supervised finetuning #
1124
+ #######################################
1125
+ # for cases where CPT has been merged into overall model
1126
+ # otherwise, keep on training the current LoRA adapter
1127
+ # model = FastLanguageModel.get_peft_model(
1128
+ # model,
1129
+ # r = 128, # any number >0 ; 8, 16, 32, 64, 128, 256
1130
+ # target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
1131
+ # "gate_proj", "up_proj", "down_proj"],
1132
+ # lora_alpha = 32,
1133
+ # lora_dropout = 0, # Supports any, but = 0 is optimized
1134
+ # bias = "none", # Supports any, but = "none" is optimized
1135
+ # # True or "unsloth" for very long context
1136
+ # use_gradient_checkpointing = "unsloth",
1137
+ # random_state = 3407,
1138
+ # use_rslora = True, # rank stabilized LoRA
1139
+ # loftq_config = None, # LoftQ
1140
+ # )
1141
+ #######################################
1142
+
1143
+ #######################################
1144
+ # sft_trainer #
1145
+ #######################################
1146
+ split = sft_train_dataset.train_test_split(
1147
+ test_size=1000,
1148
+ #seed=42
1149
+ )
1150
+ train_dataset = split['train']
1151
+ eval_dataset = split['test']
1152
+ if (
1153
+ "records_cap" in self.sft_training_args and
1154
+ self.sft_training_args["records_cap"] is not None and
1155
+ isinstance(self.sft_training_args["records_cap"], int)
1156
+ ):
1157
+ train_dataset = train_dataset.take(
1158
+ self.sft_training_args["records_cap"])
1159
+ eval_dataset = eval_dataset.take(
1160
+ self.sft_training_args["records_cap"])
1161
+ print(f"train_dataset : {train_dataset}")
1162
+ print(f"eval_dataset : {eval_dataset}")
1163
+
1164
+ train_args = UnslothTrainingArguments(
1165
+ per_device_train_batch_size=2,
1166
+ gradient_accumulation_steps=8,
1167
+
1168
+ **{k: v for k, v in self.sft_training_args.items()
1169
+ if k != "records_cap"},
1170
+
1171
+ per_device_eval_batch_size=2,
1172
+ eval_steps=200,
1173
+ eval_strategy="steps",
1174
+ do_eval=True,
1175
+
1176
+ learning_rate=5e-5,
1177
+ # embedding_learning_rate=1e-5, # Optionally here
1178
+
1179
+ fp16=not is_bfloat16_supported(),
1180
+ bf16=is_bfloat16_supported(),
1181
+
1182
+ optim="adamw_8bit",
1183
+ weight_decay=0.00,
1184
+ lr_scheduler_type="linear",
1185
+ #seed=3407,
1186
+
1187
+ output_dir=os.path.join(
1188
+ self.unsloth_dir, "outputs", "sft"),
1189
+ save_total_limit=2,
1190
+
1191
+ disable_tqdm=True,
1192
+ logging_steps=1,
1193
+ report_to="tensorboard",
1194
+ logging_dir=os.path.join(
1195
+ self.sft_model_dir,
1196
+ "runs", "sft")
1197
+ )
1198
+
1199
+ self.sft_traces_file_fullname = os.path.join(
1200
+ self.unsloth_dir, "sft_trainer_traces.txt")
1201
+ print("Training started. " +
1202
+ f"Check {self.sft_traces_file_fullname} for live traces.",
1203
+ flush=True)
1204
+
1205
+ trainer = UnslothTrainer(
1206
+ model=model, tokenizer=tokenizer,
1207
+ train_dataset=train_dataset,
1208
+ dataset_text_field="text",
1209
+ eval_dataset=eval_dataset,
1210
+ max_seq_length=self.max_seq_length,
1211
+ dataset_num_proc=8,
1212
+ args=train_args
1213
+ )
1214
+ trainer.can_return_loss = True
1215
+ #######################################
1216
+
1217
+ #######################################
1218
+ # Show current memory stats #
1219
+ #######################################
1220
+ torch.cuda.ipc_collect()
1221
+ torch.cuda.empty_cache()
1222
+ _ = gc.collect()
1223
+
1224
+ used_memory = \
1225
+ round(torch.cuda.max_memory_reserved()
1226
+ /1024/1024/1024, 3)
1227
+ used_memory_for_lora = \
1228
+ round(used_memory-self.start_gpu_memory, 3)
1229
+ used_percentage = \
1230
+ round(used_memory/self.max_memory*100, 3)
1231
+ lora_percentage = \
1232
+ round(used_memory_for_lora/self.max_memory*100,
1233
+ 3)
1234
+ print(f"Peak reserved memory = " +
1235
+ f"{used_memory} GB.")
1236
+ print(f"Peak reserved memory for " +
1237
+ f"training = {used_memory_for_lora} " +
1238
+ f"GB.")
1239
+ print(f"Peak reserved memory % of " +
1240
+ f"max memory = {used_percentage} %.")
1241
+ print(f"Peak reserved memory for training " +
1242
+ f"% of max memory = {lora_percentage} %.")
1243
+ #######################################
1244
+
1245
+ with open(self.sft_traces_file_fullname, 'w') as f:
1246
+ with redirect_stdout(f):
1247
+ trainer_stats = trainer.train()
1248
+
1249
+ print(f"{trainer_stats.metrics['train_runtime']} " +
1250
+ f"seconds used for training " +
1251
+ f"({round(trainer_stats.metrics['train_runtime']/60, 2)}" +
1252
+ f" minutes).")
1253
+
1254
+ self.sft_log_history = trainer.state.log_history
1255
+ self.sft_log_history_fig = \
1256
+ plot_log_history(
1257
+ self.sft_log_history,
1258
+ title="Supervised finetuning loss"
1259
+ )
1260
+
1261
+ model.save_pretrained_merged(
1262
+ self.sft_model_dir, tokenizer,
1263
+ save_method = "lora"
1264
+ )
1265
+ print(f"sft_model_dir : {self.sft_model_dir}\n")
1266
+
1267
+ self.next(self.evaluate_model)
1268
+
1269
+
1270
+ @step
1271
+ def evaluate_model(self):
1272
+ """
1273
+ Batch inference on the SFT validation dataset.
1274
+ """
1275
+ from retrain_pipelines.model import \
1276
+ infer_validation, compute_counts_n_metrics, \
1277
+ plot_validation_completions
1278
+
1279
+ torch.cuda.ipc_collect()
1280
+ torch.cuda.empty_cache()
1281
+ _ = gc.collect()
1282
+
1283
+
1284
+ ######################################################
1285
+ # loading trained adapter #
1286
+ ######################################################
1287
+ # Unsloth [and hf transformers before it] #
1288
+ # (if loading both model & tokenizer at once #
1289
+ # same as we did in prior tasks, but now #
1290
+ # with tokenizer.chat_template being set #
1291
+ # in tokenizer.config) is forcing on us some kind of #
1292
+ # chat_template format hard-requirements. #
1293
+ ######################################################
1294
+ # load base from cache
1295
+ # (with base tokenizer, which we ignore)
1296
+ model, _ = FastLanguageModel.from_pretrained(
1297
+ model_name=self.hf_base_model_dict["repo_id"],
1298
+ revision=self.hf_base_model_dict["commit_hash"],
1299
+ max_seq_length=self.max_seq_length,
1300
+ dtype=None,
1301
+ load_in_4bit=False,
1302
+ # case of a gated or private base-model
1303
+ token=os.getenv("HF_TOKEN", None)
1304
+ )
1305
+ model = FastLanguageModel.for_inference(model)
1306
+ # load our CPT+SFT trained & locally-saved adapter
1307
+ model.load_adapter(peft_model_id=self.sft_model_dir)
1308
+ # Separately load our (potentially trained &)
1309
+ # locally-saved adapter-tokenizer
1310
+ # (loading it below via HF and not Unsloth)
1311
+ tokenizer = AutoTokenizer.from_pretrained(
1312
+ pretrained_model_name_or_path=self.sft_model_dir
1313
+ )
1314
+ ######################################################
1315
+
1316
+ ######################################################
1317
+ # validation dataset #
1318
+ ######################################################
1319
+ # download from Hub (or get from local cache)
1320
+ queries_dataset = load_dataset(
1321
+ path=self.dataset_commit_dict["repo_id"],
1322
+ name="supervised_finetuning",
1323
+ revision=self.dataset_commit_dict["commit_hash"],
1324
+ token=os.getenv("HF_TOKEN", None))
1325
+ if (
1326
+ "records_cap" in self.sft_training_args and
1327
+ self.sft_training_args["records_cap"] is not None and
1328
+ isinstance(self.sft_training_args["records_cap"], int)
1329
+ ):
1330
+ validation_data = queries_dataset["validation"].take(
1331
+ self.sft_training_args["records_cap"])
1332
+ else:
1333
+ validation_data = queries_dataset["validation"]
1334
+ print(validation_data, flush=True)
1335
+ ######################################################
1336
+
1337
+ self.max_new_tokens = 400
1338
+ start_time = time.time()
1339
+ validation_results = infer_validation(
1340
+ tokenizer=tokenizer,
1341
+ model=model,
1342
+ validation_data=validation_data,
1343
+ prompt_template=tokenizer.chat_template,
1344
+ batch_size=32, # 64,
1345
+ queries_attr_name=\
1346
+ self.hf_dataset["attributes"]["query_attr"],
1347
+ answers_attr_name=\
1348
+ self.hf_dataset["attributes"]["answers_attr"],
1349
+ max_new_tokens=self.max_new_tokens,
1350
+ device="cuda"
1351
+ )
1352
+ print("infer_validation - Elapsed time: " +
1353
+ f"{(time.time() - start_time):.2f} seconds")
1354
+ self.validation_results = validation_results # <= to artifacts store
1355
+
1356
+ eval_df = pl.LazyFrame(validation_results)
1357
+
1358
+ records = eval_df.with_columns(
1359
+ (pl.col("answer") == pl.col("completion")) \
1360
+ .alias("is_ground_truth_identical")
1361
+ ).collect() #engine=self.engine)
1362
+ print("perfect characters-match accuracy : " +
1363
+ str(records['is_ground_truth_identical'].mean()))
1364
+
1365
+ eval_metrics_df = compute_counts_n_metrics(
1366
+ eval_df, is_format_fault_tolerant=True)
1367
+ overall_metrics_df = eval_metrics_df.select([
1368
+ pl.col("precision").mean(),
1369
+ pl.col("recall").mean(),
1370
+ pl.col("f1").mean(),
1371
+ pl.col("jaccard").mean()
1372
+ ]).collect() #engine=self.engine)
1373
+ self.perf_metrics = overall_metrics_df.row(0, named=True)
1374
+ print(self.perf_metrics)
1375
+
1376
+ self.validation_completions_fig = \
1377
+ plot_validation_completions(
1378
+ eval_metrics_df, engine=self.engine)
1379
+
1380
+ del model
1381
+ del tokenizer
1382
+ torch.cuda.ipc_collect()
1383
+ torch.cuda.empty_cache()
1384
+ _ = gc.collect()
1385
+
1386
+ self.next(self.model_version_blessing)
1387
+
1388
+
1389
+ @step
1390
+ def model_version_blessing(self):
1391
+ """
1392
+ Comparing newly-retrained model version
1393
+ against best-performing predecessor.
1394
+ """
1395
+ """
1396
+ Note: for Hugging Face-integrated pipelines,
1397
+ we compare against the latest commit on the main branch
1398
+ of the model repository there.
1399
+ When it comes to the local "mf_run_id" of the pipeline run
1400
+ that generated that best prior model version
1401
+ (retrieved from the model card's YAML metadata section on HF),
1402
+ we check against records of the herein ML-framework instance,
1403
+ as the "prior best version" of the model being retrained here
1404
+ may have originated from an instance other than
1405
+ the one executing the current retraining
1406
+ (in which case, we simply don't include a "local" hyperlink
1407
+ in the model version pipeline_cards that will be
1408
+ produced later in the herein pipeline run).
1409
+ """
1410
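+ # Hedged illustration only (not the actual retrain_pipelines helper):
+ # reading the prior version's metadata back from the model card's
+ # YAML header could look roughly like this, assuming the card stores
+ # "mf_run_id" and "perf_metrics" as custom metadata keys:
+ # from huggingface_hub import ModelCard
+ # card = ModelCard.load(self.model_repo_id,
+ #                       token=os.getenv("HF_TOKEN", None))
+ # card_meta = card.data.to_dict()
+ # prior = {"mf_run_id": card_meta.get("mf_run_id"),
+ #          "perf_metrics": card_meta.get("perf_metrics", {})}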
+ from retrain_pipelines.model.hf_utils import \
1411
+ current_blessed_model_version_dict
1412
+
1413
+ main_perf_metric_name = "jaccard"
1414
+
1415
+ current_blessed_version_dict = \
1416
+ current_blessed_model_version_dict(
1417
+ repo_id=self.model_repo_id,
1418
+ hf_token=os.getenv("HF_TOKEN", None)
1419
+ )
1420
+ print("current_blessed_version_dict : " +
1421
+ str(current_blessed_version_dict))
1422
+
1423
+ if current_blessed_version_dict is None:
1424
+ print("case 'no prior blessed model version found"
1425
+ " => blessing.'")
1426
+ self.model_version_blessed = True
1427
+
1428
+ elif (
1429
+ main_perf_metric_name in
1430
+ current_blessed_version_dict["perf_metrics"]
1431
+ ):
1432
+ current_blessed_run_id = \
1433
+ current_blessed_version_dict["mf_run_id"]
1434
+ print(f"current_blessed_run_id : {current_blessed_run_id}")
1435
+ current_blessed_metric_value = \
1436
+ current_blessed_version_dict[
1437
+ "perf_metrics"][main_perf_metric_name]
1438
+
1439
+ self.model_version_blessed = (
1440
+ self.perf_metrics[main_perf_metric_name] >=
1441
+ current_blessed_metric_value
1442
+ )
1443
+
1444
+ if not self.model_version_blessed:
1445
+ self.current_blessed_version_dict = \
1446
+ current_blessed_version_dict
1447
+ for run in Flow(self.__class__.__name__):
1448
+ if str(run.id) == current_blessed_run_id:
1449
+ run_steps = iter(run.steps())
1450
+ last_run_step = next(run_steps)
1451
+ last_task = next(iter(last_run_step.tasks()))
1452
+
1453
+ # tasks are listed backwards, so the last task is the first item:
1454
+ # Has the run seen task "pipeline_card" prior to the last task
1455
+ # (meaning, "pipeline_card" completed successfully and
1456
+ # "run" has generated a custom pipeline-card artifact)?
1457
+ # If not, hyperlink generation will later fail.
1458
+ run_has_custom_card_artifact = False
1459
+ for step in run_steps:
1460
+ if "pipeline_card" == step.id:
1461
+ run_has_custom_card_artifact = True
1462
+ break
1463
+
1464
+ if not run_has_custom_card_artifact:
1465
+ print(
1466
+ f"Run #{current_blessed_run_id} " +
1467
+ "Doesn't seem to have successfully " +
1468
+ "generated a pipeline-card artifact.",
1469
+ file=sys.stderr, flush=True)
1470
+ break
1471
+ else:
1472
+ # further filtering on successful runs that are
1473
+ # retraining of a prior version of the same model
1474
+ # (to minimize the risk that this was obtained
1475
+ # on another ML-framework instance)
1476
+ if (
1477
+ # last_task.successful and
1478
+ # may have failed after the "pipeline_card" step
1479
+ # and been resumed
1480
+ hasattr(last_task.artifacts,
1481
+ 'model_version_blessed') and
1482
+ last_task.artifacts.model_version_blessed.data and
1483
+ hasattr(last_task.artifacts,
1484
+ 'model_repo_id') and
1485
+ last_task.artifacts.model_repo_id.data == \
1486
+ self.model_repo_id
1487
+ ):
1488
+ self.current_blessed_run = run
1489
+ break
1490
+
1491
+ if not self.current_blessed_run:
1492
+ print(
1493
+ "Couldn't find blessed run " +
1494
+ f"{current_blessed_run_id} !\n" +
1495
+ "It seems that prior blessed run was " +
1496
+ "executed on another ML framework instance.",
1497
+ file=sys.stderr, flush=True)
1498
+
1499
+ print("new : " +
1500
+ str(self.perf_metrics[main_perf_metric_name]) +
1501
+ " - previous best : " +
1502
+ str(current_blessed_metric_value) +
1503
+ " - model_version_blessing : " +
1504
+ str(self.model_version_blessed))
1505
+
1506
+ else:
1507
+ raise Exception(
1508
+ "Performance metric '" +
1509
+ main_perf_metric_name +
1510
+ "' can't be found in eval results " +
1511
+ "from blessed run " +
1512
+ str(current_blessed_version_dict[
1513
+ "mf_run_id"]) + " !")
1514
+
1515
+ # self.model_version_blessed = True ### DEBUG - DELETE ###
1516
+
1517
+ self.next(self.model_to_hub)
1518
+
1519
+
1520
+ @step
1521
+ def model_to_hub(self):
1522
+ """
1523
+ Push the model version to the hub, including
1524
+ a readme with versioning info.
1525
+ """
1526
+
1527
+ #############################
1528
+ # case of user-provided #
1529
+ # documentation artifact(s) #
1530
+ #############################
1531
+ # note that user can provide either
1532
+ # 'pipeline_card.py' or 'template.html'
1533
+ # or 'dataset_readme.py'
1534
+ # or 'dataset_readme_template.md'
1535
+ # or 'model_readme.py'
1536
+ # or 'model_readme_template.md'
1537
+ # or any combination of those
1538
+ # when specifying custom
1539
+ # 'pipeline_card_artifacts_path'
1540
+ if (
1541
+ "model_readme_template.md" in
1542
+ os.listdir(self.pipeline_card_artifacts_path)
1543
+ ):
1544
+ template_dir = self.pipeline_card_artifacts_path
1545
+ else:
1546
+ template_dir = os.path.dirname(
1547
+ importlib.util.find_spec(
1548
+ f"retrain_pipelines.pipeline_card."+
1549
+ f"{os.getenv('retrain_pipeline_type')}"
1550
+ ).origin)
1551
+ print(f"template_dir : '{template_dir}'")
1552
+ #############################
1553
+ if "model_readme.py" in os.listdir(
1554
+ self.pipeline_card_artifacts_path):
1555
+ from retrain_pipelines.utils import \
1556
+ get_get_model_readme_content
1557
+ get_model_readme_content = \
1558
+ get_get_model_readme_content(
1559
+ self.pipeline_card_artifacts_path)
1560
+ else:
1561
+ from retrain_pipelines.pipeline_card import \
1562
+ get_model_readme_content
1563
+ #############################
1564
+ from retrain_pipelines.model.hf_utils import \
1565
+ push_model_version_to_hub
1566
+
1567
+ #############################
1568
+ # model README #
1569
+ # from template #
1570
+ #############################
1571
+ commit_datetime = datetime.utcnow()
1572
+ new_model_version_label = get_new_repo_minor_version(
1573
+ repo_id=self.model_repo_id,
1574
+ repo_type="model",
1575
+ hf_token=os.getenv("HF_TOKEN", None))
1576
+ readme_content = get_model_readme_content(
1577
+ template_folder=template_dir,
1578
+
1579
+ model_repo_id=self.model_repo_id,
1580
+
1581
+ base_model_dict=self.hf_base_model_dict,
1582
+ training_dataset_dict=self.dataset_commit_dict,
1583
+
1584
+ version_label=new_model_version_label,
1585
+ commit_datetime=commit_datetime,
1586
+ perf_metrics=self.perf_metrics,
1587
+
1588
+ mf_flow_name=current.flow_name,
1589
+ mf_run_id=current.run.id
1590
+ )
1591
+ #############################
1592
+
1593
+ print("Pushing model version to HF hub " +
1594
+ ("(blessed). " if self.model_version_blessed
1595
+ else "(not blessed). ") +
1596
+ "May take a while..",
1597
+ flush=True)
1598
+ model_commit_hash = push_model_version_to_hub(
1599
+ repo_id=self.model_repo_id,
1600
+ model_version_blessed=\
1601
+ self.model_version_blessed,
1602
+ version_label=new_model_version_label,
1603
+ timestamp_str=commit_datetime.strftime(
1604
+ "%Y-%m-%d %H:%M:%S UTC"),
1605
+ model_dir=self.sft_model_dir,
1606
+ model_readme_content=readme_content,
1607
+ hf_token=os.getenv("HF_TOKEN", None)
1608
+ )
1609
+ if not model_commit_hash:
1610
+ raise Exception(
1611
+ "Failed to publish model version.")
1612
+ print("Push of model version to HF hub completed.",
1613
+ flush=True)
1614
+ print(f"https://huggingface.co/{self.model_repo_id}" +
1615
+ f"/blob/{model_commit_hash}/README.md")
1616
+
1617
+ self.model_commit_dict = {
1618
+ "repo_id": self.model_repo_id,
1619
+ "commit_hash": model_commit_hash,
1620
+ "version_label": new_model_version_label,
1621
+ "commit_datetime": commit_datetime,
1622
+ }
1623
+
1624
+ self.next(self.infra_validator)
1625
+
1626
+
1627
+ @step
1628
+ def infra_validator(self):
1629
+ """
1630
+ If the trained model version is blessed,
1631
+ validate serving.
1632
+ """
1633
+ """
1634
+ Note that using an isolated virtual env
1635
+ (via the @conda task decorator)
1636
+ is advisable, so as not to bundle the whole
1637
+ pipeline dependencies into the local server.
1638
+ We don't here, for educational purposes:
1639
+ it keeps things "simple" to grasp
1640
+ and avoids forcing conda
1641
+ (for instance miniconda) on the user
1642
+ as a virtual-environment
1643
+ management tool.
1644
+ """
1645
+ """
1646
+ Note: we load the base model from the HF cache
1647
+ (mounted as the /huggingface_hub_cache
1648
+ docker volume) and the adapter from a local dir
1649
+ (mounted as the /FuncCallAdapter docker volume).
1650
+ """
1651
+
1652
+ self.local_serve_is_ready = LocalServeReadinessEnum.NOT_APPLICABLE
1653
+
1654
+ if self.model_version_blessed:
1655
+ from retrain_pipelines.utils.docker import \
1656
+ env_has_docker
1657
+
1658
+ if env_has_docker():
1659
+ model_module_dir = \
1660
+ os.path.dirname(
1661
+ importlib.util.find_spec(
1662
+ "retrain_pipelines.model." +
1663
+ os.getenv('retrain_pipeline_type')
1664
+ ).origin)
1665
+
1666
+ # server & data-model & server-config modules artifacts
1667
+ files_to_copy = [
1668
+ "litserve_server.py",
1669
+ "litserve_datamodel.py",
1670
+ "litserve_serverconfig.py",
1671
+ ".dockerignore" # docker context loading
1672
+ # at image-build time,
1673
+ # exclude model weights
1674
+ ]
1675
+ for filename in files_to_copy:
1676
+ shutil.copy(
1677
+ os.path.join(model_module_dir, "litserve",
1678
+ filename),
1679
+ os.path.join(self.serving_artifacts_local_folder,
1680
+ filename)
1681
+ )
1682
+
1683
+ # save dependencies as artifact
1684
+ create_requirements(self.serving_artifacts_local_folder,
1685
+ exclude=["cudf-polars-.*", "cuda-python",
1686
+ "nvidia-.*", "(py)?libcudf-.*",
1687
+ "nvtx", "rmm-.*", "litserve",
1688
+ ".*retrain-pipelines.*"]
1689
+ )
1690
+
1691
+ # server config yaml
1692
+ env = Environment(loader=FileSystemLoader(
1693
+ os.path.join(model_module_dir, "litserve")))
1694
+ template = env.get_template(
1695
+ "litserve_serverconfig_template.yaml")
1696
+ server_config_data = {
1697
+ "port": "8000",
1698
+ "max_seq_length": self.max_seq_length,
1699
+ "max_new_token": self.max_new_tokens,
1700
+ "base_model": {
1701
+ "repo_id": self.hf_base_model_dict["repo_id"],
1702
+ "revision": self.hf_base_model_dict["commit_hash"]
1703
+ },
1704
+ "adapters": [
1705
+ {
1706
+ "name": "func_caller",
1707
+ "path": "/FuncCallAdapter"
1708
+ }
1709
+ ]
1710
+ }
1711
+ server_config_yaml = template.render(server_config_data)
1712
+ print(server_config_yaml)
1713
+ with open(os.path.join(
1714
+ self.serving_artifacts_local_folder,
1715
+ "litserve_serverconfig.yaml"), 'w'
1716
+ ) as output_file:
1717
+ output_file.write(server_config_yaml)
1718
+
1719
+ # Dockerfile
1720
+ env = Environment(loader=FileSystemLoader(
1721
+ os.path.join(model_module_dir)))
1722
+ template = env.get_template(
1723
+ "Dockerfile.litserve_template")
1724
+ # Change CUDA version here from available list
1725
+ # @see https://hub.docker.com/r/nvidia/cuda/tags
1726
+ dockerfile_content = template.render(
1727
+ {"cuda_version": "12.0.0"})
1728
+ with open(os.path.join(
1729
+ self.serving_artifacts_local_folder,
1730
+ "Dockerfile.litserve"), 'w'
1731
+ ) as output_file:
1732
+ output_file.write(dockerfile_content)
1733
+
1734
+ os.environ["no_proxy"] = "localhost,127.0.0.1,0.0.0.0"
1735
+
1736
+ ############################################
1737
+ # actually deploy the inference service #
1738
+ ############################################
1739
+ start_time = time.time()
1740
+ from retrain_pipelines.utils.docker import \
1741
+ build_and_run_docker, print_container_log_tail, \
1742
+ cleanup_docker
1743
+ from retrain_pipelines.model.litserve import \
1744
+ endpoint_started, endpoint_is_ready
1745
+
1746
+ self.port = 8765
1747
+ HF_HUB_CACHE = os.path.realpath(os.path.expanduser(
1748
+ os.getenv(
1749
+ "HF_HUB_CACHE",
1750
+ os.path.join(os.getenv("HF_HOME",
1751
+ "~/.cache/huggingface"),
1752
+ "hub")
1753
+ )))
1754
+ print(f"HF_HUB_CACHE : {HF_HUB_CACHE}")
1755
+ image_name = container_name = "litserve-model"
1756
+
1757
+ serving_container = build_and_run_docker(
1758
+ image_name=image_name, image_tag="1.0",
1759
+ build_path=self.serving_artifacts_local_folder,
1760
+ dockerfile="Dockerfile.litserve",
1761
+ ports_publish_dict={'8000/tcp': self.port},
1762
+ env_vars_dict={
1763
+ "HF_HUB_CACHE": "/huggingface_hub_cache",
1764
+ "HF_TOKEN": os.getenv("HF_TOKEN")
1765
+ },
1766
+ volumes_dict={
1767
+ self.sft_model_dir:
1768
+ {"bind": "/FuncCallAdapter",
1769
+ "mode": "ro"},
1770
+ HF_HUB_CACHE:
1771
+ {"bind": "/huggingface_hub_cache",
1772
+ "mode": "ro"}
1773
+ }
1774
+ )
1775
+
1776
+ if not serving_container:
1777
+ print("failed spinning the LitServe container",
1778
+ file=sys.stderr)
1779
+ self.local_serve_is_ready = \
1780
+ LocalServeReadinessEnum.FAILURE
1781
+ try:
1782
+ cleanup_docker(
1783
+ container_name=container_name,
1784
+ image_name=f"{image_name}:1.0",
1785
+ no_pruning=True # for intermediate layers recycling
1786
+ # (during later re-runs)
1787
+ # to avoid long rebuild time
1788
+ # of exactly the same image.
1789
+ )
1790
+ except Exception as cleanup_ex:
1791
+ # fail silently
1792
+ pass
1793
+ else:
1794
+ print("Awaiting endpoint launch..")
1795
+ start_time = time.time()
1796
+ if not endpoint_started(
1797
+ container_name, port=self.port, timeout=10*60
1798
+ ):
1799
+ print(
1800
+ f"The endpoint '{container_name}' " +
1801
+ f"did not start.")
1802
+ self.local_serve_is_ready = \
1803
+ LocalServeReadinessEnum.FAILURE
1804
+ # health check on the spun-up endpoint
1805
+ elif endpoint_is_ready(port=self.port):
1806
+ self.local_serve_is_ready = \
1807
+ LocalServeReadinessEnum.SUCCESS
1808
+ elapsed_time = time.time() - start_time
1809
+ print("deploy_local - Elapsed time: " +
1810
+ f"{elapsed_time:.2f} seconds")
1811
+ ############################################
1812
+ else:
1813
+ # env doesn't have docker
1814
+ self.local_serve_is_ready = \
1815
+ LocalServeReadinessEnum.FAILURE_NO_DOCKER
1816
+
1817
+ if LocalServeReadinessEnum.SUCCESS == self.local_serve_is_ready:
1818
+ from retrain_pipelines.model.litserve.litserve_datamodel \
1819
+ import Response
1820
+
1821
+ import requests
1822
+
1823
+ url = f"http://localhost:{self.port}/predict"
1824
+ headers = {"accept": "application/x-www-form-urlencoded"}
1825
+
1826
+ try:
1827
+ start_time = time.time()
1828
+ data = {
1829
+ "adapter_name": "func_caller",
1830
+ "queries": '["Hello.", "Is 49 a perfect square?"]'
1831
+ }
1832
+ print(f"inference test - data: {data}")
1833
+ response = requests.post(url, headers=headers, data=data)
1834
+ parsed_response = Response(**{"output": response.json()})
1835
+ elapsed_time = time.time() - start_time
1836
+ print("parsed_response ('func_caller' adapter ON) :" +
1837
+ str(parsed_response) +
1838
+ f"\t-\tElapsed time: {elapsed_time:.2f} seconds")
1839
+
1840
+ start_time = time.time()
1841
+ data = {
1842
+ "queries": '["Hello.", "Is 49 a perfect square?"]'
1843
+ }
1844
+ print(f"inference test - data: {data}")
1845
+ response = requests.post(url, headers=headers, data=data)
1846
+ parsed_response = Response(**{"output": response.json()})
1847
+ elapsed_time = time.time() - start_time
1848
+ print(f"parsed_response (no adapter) : {parsed_response}" +
1849
+ f"\t-\tElapsed time: {elapsed_time:.2f} seconds")
1850
+
1851
+ except Exception as ex:
1852
+ print(ex, file=sys.stderr)
1853
+ traceback.print_tb(ex.__traceback__, file=sys.stderr)
1854
+ self.local_serve_is_ready = \
1855
+ LocalServeReadinessEnum.FAILURE
1856
+ pass
1857
+
1858
+ try:
1859
+ cleanup_docker(
1860
+ container_name=container_name,
1861
+ image_name=f"{image_name}:1.0",
1862
+ no_pruning=True # for intermediate layers recycling
1863
+ # (during later re-runs)
1864
+ # to avoid long rebuild time
1865
+ # of exactly the same image.
1866
+ )
1867
+ except Exception as cleanup_ex:
1868
+ # fail silently
1869
+ pass
1870
+
1871
+ self.next(self.pipeline_card)
1872
+
1873
+
1874
+ @card(id='default')
1875
+ @card(type='html', id='custom')
1876
+ @step
1877
+ def pipeline_card(self):
1878
+ import re
1879
+ import datetime
1880
+ import importlib.metadata
1881
+
1882
+ #############################
1883
+ # case of user-provided #
1884
+ # documentation artifact(s) #
1885
+ #############################
1886
+ # note that user can provide either
1887
+ # 'pipeline_card.py' or 'template.html'
1888
+ # or 'dataset_readme.py'
1889
+ # or 'dataset_readme_template.md'
1890
+ # or 'model_readme.py'
1891
+ # or 'model_readme_template.md'
1892
+ # or any combination of those
1893
+ # when specifying custom
1894
+ # 'pipeline_card_artifacts_path'
1895
+ if "template.html" in os.listdir(
1896
+ self.pipeline_card_artifacts_path
1897
+ ):
1898
+ template_dir = self.pipeline_card_artifacts_path
1899
+ else:
1900
+ template_dir = os.path.dirname(
1901
+ importlib.util.find_spec(
1902
+ f"retrain_pipelines.pipeline_card."+
1903
+ f"{os.getenv('retrain_pipeline_type')}"
1904
+ ).origin)
1905
+ #############################
1906
+ if "pipeline_card.py" in os.listdir(
1907
+ self.pipeline_card_artifacts_path
1908
+ ):
1909
+ from retrain_pipelines.utils import get_get_html
1910
+ get_html = \
1911
+ get_get_html(self.pipeline_card_artifacts_path)
1912
+ else:
1913
+ from retrain_pipelines.pipeline_card import \
1914
+ get_html
1915
+ from retrain_pipelines.pipeline_card.helpers import \
1916
+ mf_dag_svg
1917
+ #############################
1918
+
1919
+
1920
+ #############################
1921
+ ## "default" card ##
1922
+ #############################
1923
+ self.metadata = {
1924
+ "name": "TabNet Model",
1925
+ "version": "1.0",
1926
+ "retrain_pipelines": f"retrain-pipelines {__version__}",
1927
+ "retrain_pipeline_type": os.environ["retrain_pipeline_type"],
1928
+ "description": "A PyTorch TabNet model retrained",
1929
+ "authors": [current.username],
1930
+ "tags": ["classification", "tabnet"],
1931
+ "license": "MIT License",
1932
+ "data_augmentation": [
1933
+ {
1934
+ "name": "Augmentation",
1935
+ "description": "Truncating queries and " + \
1936
+ "associate those to " + \
1937
+ "no tool-call answers. " + \
1938
+ "Intent being to instruct on " + \
1939
+ "not hallucinating missing " + \
1940
+ "tool-calls parameters values."
1941
+ },
1942
+ {
1943
+ "name": "Enrichment",
1944
+ "description": "Addition of records " + \
1945
+ "from an external data-source. " + \
1946
+ "Here to instruct on no tool-call."
1947
+ }
1948
+ ],
1949
+ "references": [
1950
+ {
1951
+ "title": "Base model",
1952
+ "link": f"https://hf.co/{self.hf_base_model_dict['repo_id']}"
1953
+ },
1954
+ {
1955
+ "title": "Function-calling dataset",
1956
+ "link": f"https://hf.co/{self.hf_dataset_dict['repo_id']}"
1957
+ },
1958
+ {
1959
+ "title": "Data-enrichment dataset",
1960
+ "link": f"https://hf.co/{self.hf_enrich_dataset_dict['repo_id']}"
1961
+ },
1962
+ {
1963
+ "title": "Unsloth",
1964
+ "link": "https://unsloth.ai/blog/contpretraining"
1965
+ }
1966
+ ]
1967
+ }
1968
+
1969
+ current.card['default'].append(Markdown(
1970
+ "model_version_blessed : **%s**" % str(self.model_version_blessed)))
1971
+ current.card['default'].append(Artifact(
1972
+ {"model_version_blessed": self.model_version_blessed}))
1973
+
1974
+ current.card['default'].append(
1975
+ Image.from_matplotlib(self.sft_log_history_fig))
1976
+ current.card['default'].append(
1977
+ Image.from_matplotlib(self.validation_completions_fig))
1978
+ #############################
1979
+
1980
+ #############################
1981
+ ## html "custom" card ##
1982
+ #############################
1983
+ dt = datetime.datetime.now(tz=datetime.timezone.utc)
1984
+ formatted_dt = dt.strftime("%A %b %d %Y %I:%M:%S %p %Z")
1985
+ task_obj_python_cmd = f"metaflow.Task(" + \
1986
+ f"\"{current.pathspec}\", " + \
1987
+ f"attempt={str(current.retry_count)})"
1988
+ params={
1989
+ 'template_dir': template_dir,
1990
+ 'title': f"{current.flow_name}",
1991
+ "subtitle": f"(flow run # {len(list(current.run.parent.runs()))}," + \
1992
+ f" run_id: {str(current.run.id)} - {formatted_dt})",
1993
+
1994
+ # blessed status / current_blessed version
1995
+ 'model_version_blessed': self.model_version_blessed,
1996
+ 'current_blessed_version_label': (
1997
+ self.current_blessed_version_dict["version_label"]
1998
+ if self.current_blessed_version_dict
1999
+ else None
2000
+ ),
2001
+ 'current_blessed_commit_datetime': (
2002
+ self.current_blessed_version_dict["commit_datetime"]
2003
+ if self.current_blessed_version_dict
2004
+ else None
2005
+ ),
2006
+ 'current_blessed_model_commit_hash': (
2007
+ self.current_blessed_version_dict["commit_hash"]
2008
+ if self.current_blessed_version_dict
2009
+ else None
2010
+ ),
2011
+ 'current_blessed_run': self.current_blessed_run,
2012
+
2013
+ 'LocalServeReadinessEnum': LocalServeReadinessEnum,
2014
+ 'local_serve_is_ready': self.local_serve_is_ready,
2015
+ # EDA
2016
+ 'main_dataset_repo_id': self.hf_dataset['repo_id'],
2017
+ 'main_dataset_commit_hash': self.hf_dataset_dict['commit_hash'],
2018
+ 'main_dataset_commit_datetime': \
2019
+ self.hf_dataset_dict['commit_datetime'],
2020
+
2021
+ 'records_count': self.records_count,
2022
+ 'data_schema': self.data_schema,
2023
+ 'answers_tools_count_fig': self.answers_tools_count_fig,
2024
+ 'words_count_fig': self.words_count_fig,
2025
+
2026
+ # model training
2027
+ 'dataset_repo_id': self.dataset_repo_id,
2028
+ 'dataset_version_label': self.dataset_commit_dict["version_label"],
2029
+ 'dataset_commit_datetime': self.dataset_commit_dict["commit_datetime"],
2030
+ 'dataset_commit_hash': self.dataset_commit_dict["commit_hash"],
2031
+ 'dataset_augmentation_rate': self.actual_augmentation_rate,
2032
+ 'dataset_enrichment_rate': self.enrichment_rate,
2033
+
2034
+ 'model_repo_id': self.model_repo_id,
2035
+ 'model_version_label': self.model_commit_dict["version_label"],
2036
+ 'model_commit_datetime': self.model_commit_dict["commit_datetime"],
2037
+ 'model_commit_hash': self.model_commit_dict["commit_hash"],
2038
+
2039
+ 'cpt_log_history_fig': self.cpt_log_history_fig,
2040
+ 'sft_log_history_fig': self.sft_log_history_fig,
2041
+
2042
+ 'validation_completions_fig': self.validation_completions_fig,
2043
+
2044
+ 'pipeline_parameters_dict': {"cpt": self.cpt_training_args,
2045
+ "sft": self.sft_training_args},
2046
+
2047
+ 'metrics_dict': self.perf_metrics,
2048
+
2049
+ 'task_obj_python_cmd': task_obj_python_cmd,
2050
+ 'dag_svg': mf_dag_svg(self)
2051
+ }
2052
+ self.html = get_html(params)
2053
+ #############################
2054
+ current
2055
+ #############################
2056
+
2057
+ self.next(self.pipeline_to_hub)
2058
+
2059
+
2060
+ @step
2061
+ def pipeline_to_hub(self):
2062
+ """
2063
+ Publish versioned source-code and pipeline-card
2064
+ for this run on the Hugging Face Hub.
2065
+ """
2066
+
2067
+ model_commit_datetime = \
2068
+ self.model_commit_dict["commit_datetime"]
2069
+ timestamp_str = \
2070
+ "{:%Y%m%d_%H%M%S}".format(model_commit_datetime) + \
2071
+ "{:03d}".format(model_commit_datetime.microsecond//1000) + \
2072
+ "_UTC"
2073
+ subfolder_name = \
2074
+ "v" + self.model_commit_dict["version_label"] + \
2075
+ "_" + timestamp_str
2076
+ commit_datetime = datetime.utcnow()
2077
+
2078
+ ###############################
2079
+ # source-code #
2080
+ ###############################
2081
+ # We upload only herein file #
2082
+ # plus user-provided versions #
2083
+ # of the customizable ones #
2084
+ # (if any). #
2085
+ ###############################
2086
+ custom_source_files = [os.path.abspath(__file__)]
2087
+ if (
2088
+ self.pipeline_card_artifacts_path != \
2089
+ self.default_pipeline_card_module_dir
2090
+ ):
2091
+ candidate_source_files = [
2092
+ "pipeline_card.py",
2093
+ "template.html",
2094
+ "dataset_readme.py",
2095
+ "dataset_readme_template.md",
2096
+ "model_readme.py",
2097
+ "model_readme_template.md"
2098
+ ]
2099
+ for candidate_source_file in candidate_source_files:
2100
+ file_fullpath = os.path.join(
2101
+ self.pipeline_card_artifacts_path,
2102
+ candidate_source_file)
2103
+ if os.path.exists(file_fullpath):
2104
+ custom_source_files.append(file_fullpath)
2105
+
2106
+ source_code_commit_hash = \
2107
+ push_files_to_hub_repo_branch(
2108
+ repo_id=self.model_repo_id,
2109
+ branch_name="retrain-pipelines_source-code",
2110
+ file_fullnames=custom_source_files,
2111
+ include_requirements_txt=True,
2112
+ path_in_repo=subfolder_name,
2113
+ commit_message=\
2114
+ "source-code for model version " + \
2115
+ subfolder_name + \
2116
+ f"- retrain-pipelines {__version__}",
2117
+ repo_type="model",
2118
+ hf_token=os.getenv("HF_TOKEN", None)
2119
+ )
2120
+ print(source_code_commit_hash)
2121
+ self.source_code_commit_dict = {
2122
+ "repo_id": self.model_repo_id,
2123
+ "branch_name": "retrain-pipelines_source-code",
2124
+ "commit_datetime": commit_datetime,
2125
+ "commit_hash": source_code_commit_hash
2126
+ }
2127
+ ###############################
2128
+
2129
+ ###############################
2130
+ # pipeline-card #
2131
+ ###############################
2132
+ pipeline_card_fullname = None
2133
+ for run_step in current.run.steps():
2134
+ task = list(run_step.tasks())[0]
2135
+ task_name = task.path_components[2]
2136
+ if "pipeline_card" == task_name:
2137
+ pipeline_card = get_cards(
2138
+ task, id='custom', type='html')[0]
2139
+ pipeline_card_fullname = os.path.realpath(
2140
+ os.path.join(
2141
+ task.metadata_dict.get("ds-root", None),
2142
+ mf_config.CARD_SUFFIX, pipeline_card.path
2143
+ ))
2144
+ print(pipeline_card_fullname)
2145
+ break
2146
+ pipeline_card_commit_hash = \
2147
+ push_files_to_hub_repo_branch(
2148
+ repo_id=self.model_repo_id,
2149
+ branch_name="retrain-pipelines_pipeline-card",
2150
+ file_fullnames=[pipeline_card_fullname],
2151
+ path_in_repo=subfolder_name,
2152
+ commit_message=\
2153
+ "pipeline-card for model version " + \
2154
+ subfolder_name + \
2155
+ f"- retrain-pipelines {__version__}",
2156
+ repo_type="model",
2157
+ hf_token=os.getenv("HF_TOKEN", None)
2158
+ )
2159
+ print(pipeline_card_commit_hash)
2160
+ self.pipeline_card_commit_dict = {
2161
+ "repo_id": self.model_repo_id,
2162
+ "branch_name": "retrain-pipelines_pipeline-card",
2163
+ "commit_datetime": commit_datetime,
2164
+ "commit_hash": pipeline_card_commit_hash
2165
+ }
2166
+ ###############################
2167
+
2168
+ self.next(self.deploy)
2169
+
2170
+
2171
+ @step
2172
+ def deploy(self):
2173
+ """
2174
+ Placeholder for the serving SDK deploy call
2175
+ (on the target production platform).
2176
+ Include any artifact you want;
2177
+ consider including the portable pipeline-card
2178
+ itself!
2179
+ """
2180
+
2181
+ if (
2182
+ self.model_version_blessed and
2183
+ (self.local_serve_is_ready == LocalServeReadinessEnum.SUCCESS)
2184
+ ):
2185
+ pass # your code here
2186
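+ # Hedged, illustrative sketch only -- the serving SDK, endpoint
+ # name and helper below are hypothetical, not part of this repo:
+ # from some_serving_sdk import deploy_endpoint  # hypothetical
+ # deploy_endpoint(
+ #     endpoint_name="func-caller",  # hypothetical name
+ #     model_repo_id=self.model_commit_dict["repo_id"],
+ #     revision=self.model_commit_dict["commit_hash"],
+ #     extra_artifacts=[  # e.g. the portable pipeline-card
+ #         self.pipeline_card_commit_dict["commit_hash"]],
+ # )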
+
2187
+ self.next(self.load_test)
2188
+
2189
+
2190
+ @step
2191
+ def load_test(self):
2192
+ """
2193
+ Placeholder for a load test of the deployed model version.
2194
+ """
2195
+
2196
+ if (
2197
+ self.model_version_blessed and
2198
+ (self.local_serve_is_ready == LocalServeReadinessEnum.SUCCESS)
2199
+ ):
2200
+ pass # your code here
2201
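+ # Hedged, illustrative sketch only (the endpoint URL is hypothetical):
+ # a minimal load test could fire concurrent requests and report latency.
+ # import statistics
+ # from concurrent.futures import ThreadPoolExecutor
+ # import requests
+ # url = "http://your-prod-endpoint/predict"  # hypothetical
+ # payload = {"queries": '["Is 49 a perfect square?"]'}
+ # def _one_call(_):
+ #     start = time.time()
+ #     requests.post(url, data=payload, timeout=30)
+ #     return time.time() - start
+ # with ThreadPoolExecutor(max_workers=8) as pool:
+ #     latencies = list(pool.map(_one_call, range(64)))
+ # print(f"p50 latency : {statistics.median(latencies):.2f}s")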
+
2202
+ self.next(self.end)
2203
+
2204
+
2205
+ @step
2206
+ def end(self):
2207
+ pass
2208
+
2209
+
2210
+ if __name__ == "__main__":
2211
+ UnslothFuncCallFlow()
2212
+