Aurelien-Morgan-Bot committed
Commit 3b99b8c · verified · 1 Parent(s): 20a6c2f

source-code for model version v0.18_20250323_235054255_UTC - retrain-pipelines 0.1.1

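The pinned environment (requirements.txt) and the pipeline source (retraining_pipeline.py) for this model version live under the v0.18_20250323_235054255_UTC/ folder of the repository. As a minimal sketch, pulling this exact snapshot with the huggingface_hub client could look like the following; the repo_id is an assumption borrowed from the pipeline's default model_repo_id parameter, not something stated on this page, and should be replaced with the repository actually hosting this commit.

    # hedged sketch: fetch this source-code snapshot, pinned to commit 3b99b8c
    # repo_id below is an assumption (the flow's default), not taken from this commit page
    from huggingface_hub import snapshot_download

    local_dir = snapshot_download(
        repo_id="retrain-pipelines/function_caller",  # assumed repo_id
        revision="3b99b8c",                           # this commit
        repo_type="model",
        allow_patterns=["v0.18_20250323_235054255_UTC/*"],
    )
    print(local_dir)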
v0.18_20250323_235054255_UTC/requirements.txt ADDED
@@ -0,0 +1,628 @@
1
+ absl-py==1.4.0
2
+ accelerate==1.5.2
3
+ aiohappyeyeballs==2.6.1
4
+ aiohttp==3.11.14
5
+ aiosignal==1.3.2
6
+ alabaster==1.0.0
7
+ albucore==0.0.23
8
+ albumentations==2.0.5
9
+ ale-py==0.10.2
10
+ altair==5.5.0
11
+ annotated-types==0.7.0
12
+ anyio==4.9.0
13
+ argon2-cffi==23.1.0
14
+ argon2-cffi-bindings==21.2.0
15
+ array_record==0.7.1
16
+ arviz==0.21.0
17
+ astropy==7.0.1
18
+ astropy-iers-data==0.2025.3.17.0.34.53
19
+ astunparse==1.6.3
20
+ atpublic==5.1
21
+ attrs==25.3.0
22
+ audioread==3.0.1
23
+ autograd==1.7.0
24
+ babel==2.17.0
25
+ backcall==0.2.0
26
+ beautifulsoup4==4.13.3
27
+ betterproto==2.0.0b6
28
+ bigframes==1.40.0
29
+ bigquery-magics==0.8.0
30
+ bitsandbytes==0.45.3
31
+ bleach==6.2.0
32
+ blinker==1.9.0
33
+ blis==1.2.0
34
+ blosc2==3.2.0
35
+ bokeh==3.6.3
36
+ boto3==1.37.18
37
+ botocore==1.37.18
38
+ Bottleneck==1.4.2
39
+ bqplot==0.12.44
40
+ branca==0.8.1
41
+ CacheControl==0.14.2
42
+ cachetools==5.5.2
43
+ catalogue==2.0.10
44
+ certifi==2025.1.31
45
+ cffi==1.17.1
46
+ chardet==5.2.0
47
+ charset-normalizer==3.4.1
48
+ chex==0.1.89
49
+ clarabel==0.10.0
50
+ click==8.1.8
51
+ cloudpathlib==0.21.0
52
+ cloudpickle==3.1.1
53
+ cmake==3.31.6
54
+ cmdstanpy==1.2.5
55
+ colorama==0.4.6
56
+ colorcet==3.1.0
57
+ colorlover==0.3.0
58
+ colour==0.1.5
59
+ comm==0.2.2
60
+ community==1.0.0b1
61
+ confection==0.1.5
62
+ cons==0.4.6
63
+ contourpy==1.3.1
64
+ cramjam==2.9.1
65
+ cryptography==43.0.3
66
+ cuda-python==12.6.0
67
+ cudf-cu12 @ https://pypi.nvidia.com/cudf-cu12/cudf_cu12-25.2.1-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
68
+ cudf-polars-cu12==24.12.0
69
+ cufflinks==0.17.3
70
+ cuml-cu12==25.2.1
71
+ cupy-cuda12x==13.3.0
72
+ cut-cross-entropy==25.1.1
73
+ cuvs-cu12==25.2.1
74
+ cvxopt==1.3.2
75
+ cvxpy==1.6.4
76
+ cycler==0.12.1
77
+ cyipopt==1.5.0
78
+ cymem==2.0.11
79
+ Cython==3.0.12
80
+ dask==2024.12.1
81
+ dask-cuda==25.2.0
82
+ dask-cudf-cu12==25.2.2
83
+ dask-expr==1.1.21
84
+ datascience==0.17.6
85
+ datasets==3.1.0
86
+ db-dtypes==1.4.2
87
+ dbus-python==1.2.18
88
+ debugpy==1.8.0
89
+ decorator==4.4.2
90
+ defusedxml==0.7.1
91
+ Deprecated==1.2.18
92
+ diffusers==0.32.2
93
+ dill==0.3.8
94
+ distributed==2024.12.1
95
+ distributed-ucxx-cu12==0.42.0
96
+ distro==1.9.0
97
+ dlib==19.24.2
98
+ dm-tree==0.1.9
99
+ docker==7.1.0
100
+ docker-pycreds==0.4.0
101
+ docstring_parser==0.16
102
+ docutils==0.21.2
103
+ dopamine_rl==4.1.2
104
+ duckdb==1.2.1
105
+ earthengine-api==1.5.7
106
+ easydict==1.13
107
+ editdistance==0.8.1
108
+ eerepr==0.1.1
109
+ einops==0.8.1
110
+ en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl#sha256=1932429db727d4bff3deed6b34cfc05df17794f4a52eeb26cf8928f7c1a0fb85
111
+ entrypoints==0.4
112
+ et_xmlfile==2.0.0
113
+ etils==1.12.2
114
+ etuples==0.3.9
115
+ Farama-Notifications==0.0.4
116
+ fastai==2.7.19
117
+ fastapi==0.115.11
118
+ fastcore==1.7.29
119
+ fastdownload==0.0.7
120
+ fastjsonschema==2.21.1
121
+ fastprogress==1.0.3
122
+ fastrlock==0.8.3
123
+ filelock==3.18.0
124
+ firebase-admin==6.7.0
125
+ Flask==3.1.0
126
+ flatbuffers==25.2.10
127
+ flax==0.10.4
128
+ folium==0.19.5
129
+ fonttools==4.56.0
130
+ frozendict==2.4.6
131
+ frozenlist==1.5.0
132
+ fsspec==2024.9.0
133
+ future==1.0.0
134
+ gast==0.6.0
135
+ GDAL==3.6.4
136
+ gdown==5.2.0
137
+ geemap==0.35.3
138
+ geocoder==1.38.1
139
+ geographiclib==2.0
140
+ geopandas==1.0.1
141
+ geopy==2.4.1
142
+ gin-config==0.5.0
143
+ gitdb==4.0.12
144
+ GitPython==3.1.44
145
+ glob2==0.7
146
+ google==2.0.3
147
+ google-ai-generativelanguage==0.6.15
148
+ google-api-core==2.24.2
149
+ google-api-python-client==2.164.0
150
+ google-auth==2.38.0
151
+ google-auth-httplib2==0.2.0
152
+ google-auth-oauthlib==1.2.1
153
+ google-cloud-aiplatform==1.84.0
154
+ google-cloud-bigquery==3.29.0
155
+ google-cloud-bigquery-connection==1.18.2
156
+ google-cloud-bigquery-storage==2.29.1
157
+ google-cloud-bigtable==2.29.0
158
+ google-cloud-core==2.4.3
159
+ google-cloud-dataproc==5.18.1
160
+ google-cloud-datastore==2.20.2
161
+ google-cloud-firestore==2.20.1
162
+ google-cloud-functions==1.20.2
163
+ google-cloud-iam==2.18.2
164
+ google-cloud-language==2.17.1
165
+ google-cloud-pubsub==2.28.0
166
+ google-cloud-resource-manager==1.14.2
167
+ google-cloud-spanner==3.53.0
168
+ google-cloud-storage==2.19.0
169
+ google-cloud-translate==3.20.2
170
+ google-colab @ file:///colabtools/dist/google_colab-1.0.0.tar.gz
171
+ google-crc32c==1.7.0
172
+ google-genai==1.5.0
173
+ google-generativeai==0.8.4
174
+ google-pasta==0.2.0
175
+ google-resumable-media==2.7.2
176
+ google-spark-connect==0.5.2
177
+ googleapis-common-protos==1.69.2
178
+ googledrivedownloader==1.1.0
179
+ graphviz==0.20.3
180
+ greenlet==3.1.1
181
+ grpc-google-iam-v1==0.14.2
182
+ grpc-interceptor==0.15.4
183
+ grpcio==1.71.0
184
+ grpcio-status==1.71.0
185
+ grpclib==0.4.7
186
+ gspread==6.2.0
187
+ gspread-dataframe==4.0.0
188
+ gym==0.25.2
189
+ gym-notices==0.0.8
190
+ gymnasium==1.1.1
191
+ h11==0.14.0
192
+ h2==4.2.0
193
+ h5netcdf==1.6.1
194
+ h5py==3.13.0
195
+ hdbscan==0.8.40
196
+ hf_transfer==0.1.9
197
+ highspy==1.9.0
198
+ holidays==0.69
199
+ holoviews==1.20.2
200
+ hpack==4.1.0
201
+ html5lib==1.1
202
+ httpcore==1.0.7
203
+ httpimport==1.4.1
204
+ httplib2==0.22.0
205
+ httptools==0.6.4
206
+ httpx==0.28.1
207
+ huggingface-hub==0.27.1
208
+ humanize==4.12.1
209
+ hyperframe==6.1.0
210
+ hyperopt==0.2.7
211
+ ibis-framework==9.5.0
212
+ idna==3.10
213
+ imageio==2.37.0
214
+ imageio-ffmpeg==0.6.0
215
+ imagesize==1.4.1
216
+ imbalanced-learn==0.13.0
217
+ immutabledict==4.2.1
218
+ importlib_metadata==8.6.1
219
+ importlib_resources==6.5.2
220
+ imutils==0.5.4
221
+ inflect==7.5.0
222
+ iniconfig==2.0.0
223
+ intel-cmplr-lib-ur==2025.0.5
224
+ intel-openmp==2025.0.5
225
+ ipyevents==2.0.2
226
+ ipyfilechooser==0.6.0
227
+ ipykernel==6.29.5
228
+ ipyleaflet==0.19.2
229
+ ipyparallel==8.8.0
230
+ ipython==7.34.0
231
+ ipython-genutils==0.2.0
232
+ ipython-sql==0.5.0
233
+ ipytree==0.2.2
234
+ ipywidgets==7.7.1
235
+ itsdangerous==2.2.0
236
+ jax==0.5.2
237
+ jax-cuda12-pjrt==0.5.1
238
+ jax-cuda12-plugin==0.5.1
239
+ jaxlib==0.5.1
240
+ jedi==0.19.2
241
+ jeepney==0.7.1
242
+ jellyfish==1.1.0
243
+ jieba==0.42.1
244
+ Jinja2==3.1.4
245
+ jiter==0.9.0
246
+ jmespath==1.0.1
247
+ joblib==1.4.2
248
+ jsonpatch==1.33
249
+ jsonpickle==4.0.2
250
+ jsonpointer==3.0.0
251
+ jsonschema==4.23.0
252
+ jsonschema-specifications==2024.10.1
253
+ jupyter-client==6.1.12
254
+ jupyter-console==6.1.0
255
+ jupyter-leaflet==0.19.2
256
+ jupyter-server==1.16.0
257
+ jupyter_core==5.7.2
258
+ jupyterlab_pygments==0.3.0
259
+ jupyterlab_widgets==3.0.13
260
+ kaggle==1.7.4.2
261
+ kagglehub==0.3.10
262
+ keras==3.8.0
263
+ keras-hub==0.18.1
264
+ keras-nlp==0.18.1
265
+ keyring==23.5.0
266
+ kiwisolver==1.4.8
267
+ langchain==0.3.20
268
+ langchain-core==0.3.45
269
+ langchain-text-splitters==0.3.6
270
+ langcodes==3.5.0
271
+ langsmith==0.3.15
272
+ language_data==1.3.0
273
+ launchpadlib==1.10.16
274
+ lazr.restfulclient==0.14.4
275
+ lazr.uri==1.0.6
276
+ lazy_loader==0.4
277
+ libclang==18.1.1
278
+ libcudf-cu12==24.12.0
279
+ libcugraph-cu12==25.2.0
280
+ libcuml-cu12==25.2.1
281
+ libcuvs-cu12==25.2.1
282
+ libkvikio-cu12==24.12.1
283
+ libraft-cu12==25.2.0
284
+ librosa==0.11.0
285
+ libucx-cu12==1.18.0
286
+ libucxx-cu12==0.42.0
287
+ lightgbm==4.5.0
288
+ linkify-it-py==2.0.3
289
+ litserve==0.2.6
290
+ llvmlite==0.43.0
291
+ locket==1.0.0
292
+ logical-unification==0.4.6
293
+ lxml==5.3.0
294
+ Mako==1.1.3
295
+ marisa-trie==1.2.1
296
+ Markdown==3.7
297
+ markdown-it-py==3.0.0
298
+ MarkupSafe==3.0.2
299
+ matplotlib==3.9.2
300
+ matplotlib-inline==0.1.7
301
+ matplotlib-venn==1.1.2
302
+ mdit-py-plugins==0.4.2
303
+ mdurl==0.1.2
304
+ metaflow==2.10.0
305
+ metaflow-card-html==1.0.2
306
+ miniKanren==1.0.3
307
+ missingno==0.5.2
308
+ mistune==3.1.2
309
+ mizani==0.13.1
310
+ mkl==2025.0.1
311
+ ml-dtypes==0.4.1
312
+ mlxtend==0.23.4
313
+ more-itertools==10.6.0
314
+ moviepy==1.0.3
315
+ mpmath==1.3.0
316
+ msgpack==1.1.0
317
+ multidict==6.2.0
318
+ multipledispatch==1.0.0
319
+ multiprocess==0.70.16
320
+ multitasking==0.0.11
321
+ murmurhash==1.0.12
322
+ music21==9.3.0
323
+ namex==0.0.8
324
+ narwhals==1.31.0
325
+ natsort==8.4.0
326
+ nbclassic==1.2.0
327
+ nbclient==0.10.2
328
+ nbconvert==7.16.6
329
+ nbformat==5.10.4
330
+ ndindex==1.9.2
331
+ nest-asyncio==1.6.0
332
+ networkx==3.2.1
333
+ nibabel==5.3.2
334
+ nltk==3.9.1
335
+ notebook==6.5.7
336
+ notebook_shim==0.2.4
337
+ numba==0.60.0
338
+ numba-cuda==0.2.0
339
+ numexpr==2.10.2
340
+ numpy==1.26.4
341
+ nvidia-cublas-cu12==12.4.5.8
342
+ nvidia-cuda-cupti-cu12==12.4.127
343
+ nvidia-cuda-nvcc-cu12==12.5.82
344
+ nvidia-cuda-nvrtc-cu12==12.4.127
345
+ nvidia-cuda-runtime-cu12==12.4.127
346
+ nvidia-cudnn-cu12==9.1.0.70
347
+ nvidia-cufft-cu12==11.2.1.3
348
+ nvidia-curand-cu12==10.3.5.147
349
+ nvidia-cusolver-cu12==11.6.1.9
350
+ nvidia-cusparse-cu12==12.3.1.170
351
+ nvidia-cusparselt-cu12==0.6.2
352
+ nvidia-ml-py==12.570.86
353
+ nvidia-nccl-cu12==2.21.5
354
+ nvidia-nvcomp-cu12==4.1.0.6
355
+ nvidia-nvjitlink-cu12==12.4.127
356
+ nvidia-nvtx-cu12==12.4.127
357
+ nvtx==0.2.11
358
+ nx-cugraph-cu12 @ https://pypi.nvidia.com/nx-cugraph-cu12/nx_cugraph_cu12-25.2.0-py3-none-any.whl
359
+ oauth2client==4.1.3
360
+ oauthlib==3.2.2
361
+ openai==1.66.3
362
+ opencv-contrib-python==4.11.0.86
363
+ opencv-python==4.11.0.86
364
+ opencv-python-headless==4.11.0.86
365
+ openpyxl==3.1.5
366
+ opentelemetry-api==1.31.0
367
+ opentelemetry-sdk==1.31.0
368
+ opentelemetry-semantic-conventions==0.52b0
369
+ opt_einsum==3.4.0
370
+ optax==0.2.4
371
+ optree==0.14.1
372
+ orbax-checkpoint==0.11.9
373
+ orjson==3.10.15
374
+ osqp==0.6.7.post3
375
+ packaging==24.2
376
+ pandas==2.2.2
377
+ pandas-datareader==0.10.0
378
+ pandas-gbq==0.28.0
379
+ pandas-stubs==2.2.2.240909
380
+ pandocfilters==1.5.1
381
+ panel==1.6.1
382
+ param==2.2.0
383
+ parso==0.8.4
384
+ parsy==2.1
385
+ partd==1.4.2
386
+ pathlib==1.0.1
387
+ patsy==1.0.1
388
+ peewee==3.17.9
389
+ peft==0.14.0
390
+ pexpect==4.9.0
391
+ pickleshare==0.7.5
392
+ pillow==11.1.0
393
+ platformdirs==4.3.6
394
+ plotly==5.24.1
395
+ plotnine==0.14.5
396
+ pluggy==1.5.0
397
+ ply==3.11
398
+ polars==1.11.0
399
+ pooch==1.8.2
400
+ portpicker==1.5.2
401
+ preshed==3.0.9
402
+ prettytable==3.15.1
403
+ proglog==0.1.10
404
+ progressbar2==4.5.0
405
+ prometheus_client==0.21.1
406
+ promise==2.3
407
+ prompt_toolkit==3.0.50
408
+ propcache==0.3.0
409
+ prophet==1.1.6
410
+ proto-plus==1.26.1
411
+ protobuf==3.20.3
412
+ psutil==5.9.5
413
+ psycopg2==2.9.10
414
+ ptyprocess==0.7.0
415
+ py-cpuinfo==9.0.0
416
+ py4j==0.10.9.7
417
+ pyarrow==17.0.0
418
+ pyasn1==0.6.1
419
+ pyasn1_modules==0.4.1
420
+ pycairo==1.27.0
421
+ pycocotools==2.0.8
422
+ pycparser==2.22
423
+ pydantic==2.9.2
424
+ pydantic_core==2.23.4
425
+ pydata-google-auth==1.9.1
426
+ pydot==1.4.2
427
+ pydotplus==2.0.2
428
+ PyDrive==1.3.1
429
+ PyDrive2==1.21.3
430
+ pyerfa==2.0.1.5
431
+ pygame==2.6.1
432
+ pygit2==1.17.0
433
+ Pygments==2.18.0
434
+ PyGObject==3.42.0
435
+ PyJWT==2.10.1
436
+ pylibcudf-cu12==24.12.0
437
+ pylibcugraph-cu12==25.2.0
438
+ pylibraft-cu12==25.2.0
439
+ pymc==5.21.1
440
+ pymystem3==0.2.0
441
+ pynndescent==0.5.13
442
+ pynvjitlink-cu12==0.5.2
443
+ pynvml==12.0.0
444
+ pyogrio==0.10.0
445
+ Pyomo==6.8.2
446
+ PyOpenGL==3.1.9
447
+ pyOpenSSL==24.2.1
448
+ pyparsing==3.2.1
449
+ pyperclip==1.9.0
450
+ pyproj==3.7.1
451
+ pyshp==2.3.1
452
+ PySocks==1.7.1
453
+ pyspark==3.5.5
454
+ pytensor==2.28.3
455
+ pytest==8.3.3
456
+ python-apt==0.0.0
457
+ python-box==7.3.2
458
+ python-dateutil==2.8.2
459
+ python-dotenv==1.0.1
460
+ python-louvain==0.16
461
+ python-multipart==0.0.20
462
+ python-slugify==8.0.4
463
+ python-snappy==0.7.3
464
+ python-utils==3.9.1
465
+ pytz==2025.1
466
+ pyviz_comms==3.0.4
467
+ PyYAML==6.0.2
468
+ pyzmq==24.0.1
469
+ qdldl==0.1.7.post5
470
+ raft-dask-cu12==25.2.0
471
+ rapids-dask-dependency==25.2.0
472
+ ratelim==0.1.6
473
+ referencing==0.36.2
474
+ regex==2024.11.6
475
+ requests==2.32.3
476
+ requests-oauthlib==2.0.0
477
+ requests-toolbelt==1.0.0
478
+ requirements-parser==0.9.0
479
+ retrain_pipelines @ git+https://github.com/aurelienmorgan/retrain-pipelines.git@9a5b7f7992744e6739bc28700f3d5d8915796d71#subdirectory=pkg_src
480
+ rich==13.9.4
481
+ rmm-cu12==24.12.0
482
+ roman-numerals-py==3.1.0
483
+ rpds-py==0.23.1
484
+ rpy2==3.5.17
485
+ rsa==4.9
486
+ s3transfer==0.11.4
487
+ safetensors==0.5.3
488
+ scikit-image==0.25.2
489
+ scikit-learn==1.6.1
490
+ scipy==1.14.1
491
+ scooby==0.10.0
492
+ scs==3.2.7.post2
493
+ seaborn==0.13.2
494
+ SecretStorage==3.3.1
495
+ Send2Trash==1.8.3
496
+ sentence-transformers==3.4.1
497
+ sentencepiece==0.2.0
498
+ sentry-sdk==2.23.1
499
+ setproctitle==1.3.5
500
+ shap==0.47.0
501
+ shapely==2.0.7
502
+ shellingham==1.5.4
503
+ shtab==1.7.1
504
+ simple-parsing==0.1.7
505
+ simplejson==3.20.1
506
+ simsimd==6.2.1
507
+ six==1.17.0
508
+ sklearn-compat==0.1.3
509
+ sklearn-pandas==2.2.0
510
+ slicer==0.0.8
511
+ smart-open==7.1.0
512
+ smmap==5.0.2
513
+ sniffio==1.3.1
514
+ snowballstemmer==2.2.0
515
+ sortedcontainers==2.4.0
516
+ soundfile==0.13.1
517
+ soupsieve==2.6
518
+ soxr==0.5.0.post1
519
+ spacy==3.8.4
520
+ spacy-legacy==3.0.12
521
+ spacy-loggers==1.0.5
522
+ spanner-graph-notebook==1.1.3
523
+ Sphinx==8.2.3
524
+ sphinxcontrib-applehelp==2.0.0
525
+ sphinxcontrib-devhelp==2.0.0
526
+ sphinxcontrib-htmlhelp==2.1.0
527
+ sphinxcontrib-jsmath==1.0.1
528
+ sphinxcontrib-qthelp==2.0.0
529
+ sphinxcontrib-serializinghtml==2.0.0
530
+ SQLAlchemy==2.0.39
531
+ sqlglot==25.20.2
532
+ sqlparse==0.5.3
533
+ srsly==2.5.1
534
+ stanio==0.5.1
535
+ starlette==0.46.1
536
+ statsmodels==0.14.4
537
+ stringzilla==3.12.3
538
+ sympy==1.13.1
539
+ tables==3.10.2
540
+ tabulate==0.9.0
541
+ tbb==2022.0.0
542
+ tblib==3.0.0
543
+ tcmlib==1.2.0
544
+ tenacity==9.0.0
545
+ tensorboard==2.18.0
546
+ tensorboard-data-server==0.7.2
547
+ tensorflow==2.18.0
548
+ tensorflow-datasets==4.9.8
549
+ tensorflow-hub==0.16.1
550
+ tensorflow-io-gcs-filesystem==0.37.1
551
+ tensorflow-metadata==1.16.1
552
+ tensorflow-probability==0.25.0
553
+ tensorflow-text==2.18.1
554
+ tensorstore==0.1.72
555
+ termcolor==2.5.0
556
+ terminado==0.18.1
557
+ text-unidecode==1.3
558
+ textblob==0.19.0
559
+ tf-slim==1.1.0
560
+ tf_keras==2.18.0
561
+ thinc==8.3.4
562
+ threadpoolctl==3.6.0
563
+ tifffile==2025.3.13
564
+ timm==1.0.15
565
+ tinycss2==1.4.0
566
+ tokenizers==0.20.3
567
+ toml==0.10.2
568
+ toolz==0.12.1
569
+ torch==2.5.0
570
+ torchsummary==1.5.1
571
+ torchvision==0.20.0
572
+ tornado==6.4.2
573
+ tqdm==4.67.1
574
+ traitlets==5.7.1
575
+ traittypes==0.2.1
576
+ transformers==4.46.2
577
+ treelite==4.4.1
578
+ treescope==0.1.9
579
+ triton==3.1.0
580
+ trl==0.12.0
581
+ tweepy==4.15.0
582
+ typeguard==4.4.2
583
+ typer==0.15.2
584
+ types-pytz==2025.1.0.20250318
585
+ types-setuptools==76.0.0.20250313
586
+ typing_extensions==4.12.2
587
+ tyro==0.9.17
588
+ tzdata==2025.1
589
+ tzlocal==5.3.1
590
+ uc-micro-py==1.0.3
591
+ ucx-py-cu12==0.42.0
592
+ ucxx-cu12==0.42.0
593
+ umap-learn==0.5.7
594
+ umf==0.9.1
595
+ unsloth @ git+https://github.com/unslothai/unsloth.git@3a1e7ef8299f3c96fa6e8de11fd0772af3cbc83f
596
+ unsloth_zoo==2024.11.4
597
+ uritemplate==4.1.1
598
+ urllib3==2.3.0
599
+ uvicorn==0.34.0
600
+ uvloop==0.21.0
601
+ vega-datasets==0.9.0
602
+ wadllib==1.3.6
603
+ wandb==0.19.8
604
+ wasabi==1.1.3
605
+ watchfiles==1.0.4
606
+ wcwidth==0.2.13
607
+ weasel==0.4.1
608
+ webcolors==24.11.1
609
+ webencodings==0.5.1
610
+ websocket-client==1.8.0
611
+ websockets==14.2
612
+ Werkzeug==3.1.3
613
+ widgetsnbextension==3.6.10
614
+ wordcloud==1.9.4
615
+ wrapt==1.17.2
616
+ xarray==2025.1.2
617
+ xarray-einstats==0.8.0
618
+ xformers==0.0.28.post2
619
+ xgboost==2.1.4
620
+ xlrd==2.0.1
621
+ xxhash==3.5.0
622
+ xyzservices==2025.1.0
623
+ yarl==1.18.3
624
+ yellowbrick==1.5
625
+ yfinance==0.2.54
626
+ zict==3.0.0
627
+ zipp==3.21.0
628
+ zstandard==0.23.0
v0.18_20250323_235054255_UTC/retraining_pipeline.py ADDED
@@ -0,0 +1,2219 @@
1
+
2
+ from unsloth import FastLanguageModel, \
3
+ is_bfloat16_supported, UnslothTrainer, \
4
+ UnslothTrainingArguments
5
+
6
+ import torch
7
+
8
+ import os
9
+ import sys
10
+
11
+ import gc
12
+ import json
13
+ import time
14
+ import shutil
15
+ import logging
16
+ import traceback
17
+ import subprocess
18
+ import importlib.util
19
+ from enum import Enum
20
+ from io import StringIO
21
+ from textwrap import dedent
22
+ from datetime import datetime
23
+ from contextlib import redirect_stdout
24
+
25
+ import numpy as np
26
+ import pandas as pd
27
+
28
+ import polars as pl
29
+ from polars.exceptions import ComputeError
30
+
31
+ import matplotlib
32
+ import matplotlib.pyplot as plt
33
+
34
+ from jinja2 import Environment, FileSystemLoader
35
+
36
+ from metaflow import FlowSpec, step, Parameter, JSONType, \
37
+ IncludeFile, current, metaflow_config as mf_config, \
38
+ resources, Flow, Task, card
39
+ from metaflow.current import Current
40
+ from metaflow.cards import Image, Table, Markdown, \
41
+ Artifact, get_cards
42
+
43
+ from datasets import load_dataset, Dataset, DatasetDict
44
+ from datasets.config import HF_DATASETS_CACHE, HF_CACHE_HOME
45
+ from huggingface_hub import list_repo_commits
46
+ from transformers import AutoTokenizer
47
+ from transformers.utils import logging as hf_logging
48
+
49
+ from retrain_pipelines import __version__
50
+ from retrain_pipelines.dataset.hf_utils import get_lazy_df, \
51
+ get_column_info, iterable_dataset_multi_buffer_sampler, \
52
+ push_dataset_version_to_hub
53
+ from retrain_pipelines.dataset.tool_calls import \
54
+ get_unique_tools, count_tool_occurrences, \
55
+ plot_tools_occurences, column_words_stats, \
56
+ plot_words_count
57
+ from retrain_pipelines.utils.hf_utils import \
58
+ get_new_repo_minor_version, push_files_to_hub_repo_branch
59
+ from retrain_pipelines.utils import create_requirements
60
+
61
+
62
+ class LocalServeReadinessEnum(Enum):
63
+ """
64
+ tracking local-serve (infra-validation)
65
+ status using a "3+"-state enum:
66
+ - "-1" for "not applicable"
67
+ (i.e. "model version not blessed"),
68
+ - "0/1" bool for failure/success.
69
+ """
70
+ NOT_APPLICABLE = -1
71
+ FAILURE = 0
72
+ FAILURE_NO_DOCKER = 2
73
+ SUCCESS = 1
74
+
75
+
76
+ class UnslothFuncCallFlow(FlowSpec):
77
+ """
78
+ Training pipeline
79
+ """
80
+ # @see https://github.com/unslothai/unsloth/wiki
81
+
82
+ #--- flow parameters -------------------------------------------------------
83
+
84
+ RETRAIN_PIPELINE_TYPE = "mf_unsloth_func_call_litserve"
85
+ # in order to share the config across subprocesses
86
+ os.environ["retrain_pipeline_type"] = RETRAIN_PIPELINE_TYPE
87
+
88
+ hf_dataset = Parameter(
89
+ "hf_dataset",
90
+ help="dict with 'repo_id' and 'commit_hash' keys. " + \
91
+ "if 'commit_hash' is None, falls back to latest version " +\
92
+ "of the dataset available in parquet format.\n" +
93
+ "Note that there are 3 required 'attributes' of type " + \
94
+ "str, list[str], list[str]",
95
+ type=JSONType,
96
+ default=dedent("""{
97
+ "repo_id": "Salesforce/xlam-function-calling-60k",
98
+ "config_name": "",
99
+ "commit_hash": "",
100
+ "attributes": {
101
+ "query_attr": "query",
102
+ "answers_attr": "answers",
103
+ "tools_attr": "tools"
104
+ }
105
+ }""").replace("'", '"').strip('"')
106
+ )
107
+
108
+ augmentation_rate = Parameter(
109
+ "augmentation_rate",
110
+ type=float,
111
+ default=.05,
112
+ help="proportion of records to be augmented "+\
113
+ "(x% of original dataset is created"+\
114
+ " as additional augmented datapoints), i.e. "+\
115
+ "truncated queries to serve as negative examples, "+\
116
+ "meaning they trigger no tool call "+\
117
+ "due to info incompleteness."
118
+ )
119
+
120
+ hf_enrich_dataset = Parameter(
121
+ "hf_enrich_dataset",
122
+ help="dict with 'repo_id', 'config_name' and 'commit_hash', "+\
123
+ "'query_attribute' and 'query_attribute_handler' keys. "+\
124
+ "if 'commit_hash' is None, falls back to latest version "+\
125
+ "of the dataset available in parquet format. "+\
126
+ "'query_attribute' depicts the dataset attribute "+\
127
+ "from which 'queries' are to be sampled. "+\
128
+ "'query_attribute_handler' serves for attributes "+\
129
+ "that have complex structure, "+\
130
+ "other than 'string' datatype.",
131
+ type=JSONType,
132
+ # @see https://huggingface.co/datasets/google-research-datasets/natural_questions
133
+ default=dedent("""{
134
+ "repo_id": "lighteval/natural_questions_clean",
135
+ "config_name": "",
136
+ "commit_hash": "",
137
+ "query_attribute": "question",
138
+ "query_attribute_handler": "lambda x: x"
139
+ }""").replace("'", '"').strip('"')
140
+ )
141
+
142
+ enrichment_rate = Parameter(
143
+ "enrichment_rate",
144
+ type=float,
145
+ default=.1,
146
+ help="proportion of records "+\
147
+ "to be added from the 'hf_enrich_dataset' "+\
148
+ "(x% of original dataset is sampled and"+\
149
+ " added as enriching datapoints), i.e. "+\
150
+ "queries to serve as negative examples, "+\
151
+ "due to their complete disconnection "+\
152
+ "from tool-calling situations."
153
+ )
154
+
155
+ dataset_repo_id = Parameter(
156
+ "dataset_repo_id",
157
+ type=str,
158
+ default="retrain-pipelines/func_calls",
159
+ help="The 'repo_id' to be used " + \
160
+ "for the Hugging Face dataset version push " + \
161
+ "(will be created at runtime" + \
162
+ " if it doesn't already exist)."
163
+ )
164
+
165
+ hf_base_model = Parameter(
166
+ "hf_base_model",
167
+ help="dict with 'repo_id' and 'commit_hash' keys."+\
168
+ "if 'commit_hash' is None, falls back "+\
169
+ "to latest available version of the model.",
170
+ type=JSONType,
171
+ default=dedent("""{
172
+ "repo_id": "unsloth/Qwen2.5-1.5B",
173
+ "commit_hash": ""
174
+ }""").replace("'", '"').strip('"')
175
+ )
176
+
177
+ cpt_training_args = Parameter(
178
+ "cpt_training_args",
179
+ help="dict with `TrainingArguments` params "+\
180
+ "for the CPT job.",
181
+ type=JSONType,
182
+ default=dedent("""{
183
+ "warmup_ratio": 0.1,
184
+ "num_train_epochs": 1
185
+ }""").replace("'", '"').strip('"')
186
+ )
187
+
188
+ sft_training_args = Parameter(
189
+ "sft_training_args",
190
+ help="dict with `TrainingArguments` params "+\
191
+ "for the SFT job.",
192
+ type=JSONType,
193
+ default=dedent("""{
194
+ "warmup_ratio": 0.1,
195
+ "num_train_epochs": 1
196
+ }""").replace("'", '"').strip('"')
197
+ )
198
+
199
+ model_repo_id = Parameter(
200
+ "model_repo_id",
201
+ type=str,
202
+ default="retrain-pipelines/function_caller",
203
+ help="The 'repo_id' to be used " + \
204
+ "for the Hugging Face model version push " + \
205
+ "(will be created at runtime" + \
206
+ " if it doesn't already exist)."
207
+ )
208
+
209
+ default_pipeline_card_module_dir = \
210
+ os.path.dirname(
211
+ importlib.util.find_spec(
212
+ f"retrain_pipelines.pipeline_card."+
213
+ f"{RETRAIN_PIPELINE_TYPE}"
214
+ ).origin)
215
+ pipeline_card_artifacts_path = Parameter(
216
+ "pipeline_card_artifacts_path",
217
+ type=str,
218
+ default=default_pipeline_card_module_dir,
219
+ help="pipeline_card artifacts location "+\
220
+ "(i.e. dir hosting your optional " + \
221
+ " custom documentation files :" + \
222
+ " 'pipeline_card.py' and/or 'template.html'"+\
223
+ " and/or 'model_readme.py'"+\
224
+ " and/or 'model_readme_template.md'," +\
225
+ " and/or 'dataset_readme.py'"+\
226
+ " and/or 'dataset_readme_template.md' file), " +\
227
+ "if different from default."
228
+ )
229
+ @staticmethod
230
+ def copy_default_dataset_readme_module(
231
+ target_dir: str,
232
+ exists_ok: bool = False
233
+ ) -> None:
234
+ os.makedirs(target_dir, exist_ok=True)
235
+ if (
236
+ not exists_ok and
237
+ os.path.exists(os.path.join(target_dir, "dataset_readme.py"))
238
+ ):
239
+ print("File already exists. Skipping copy.")
240
+ else:
241
+ filefullname = os.path.join(
242
+ UnslothFuncCallFlow.default_pipeline_card_module_dir,
243
+ "dataset_readme.py"
244
+ )
245
+ shutil.copy(filefullname, target_dir)
246
+ print(filefullname)
247
+ @staticmethod
248
+ def copy_default_dataset_readme_template(
249
+ target_dir: str,
250
+ exists_ok: bool = False
251
+ ) -> None:
252
+ os.makedirs(target_dir, exist_ok=True)
253
+ if (
254
+ not exists_ok and
255
+ os.path.exists(os.path.join(target_dir,
256
+ "dataset_readme_template.md"))
257
+ ):
258
+ print("File already exists. Skipping copy.")
259
+ else:
260
+ filefullname = os.path.join(
261
+ UnslothFuncCallFlow.default_pipeline_card_module_dir,
262
+ "dataset_readme_template.md")
263
+ shutil.copy(filefullname, target_dir)
264
+ print(filefullname)
265
+ @staticmethod
266
+ def copy_default_model_readme_module(
267
+ target_dir: str,
268
+ exists_ok: bool = False
269
+ ) -> None:
270
+ os.makedirs(target_dir, exist_ok=True)
271
+ if (
272
+ not exists_ok and
273
+ os.path.exists(os.path.join(target_dir, "model_readme.py"))
274
+ ):
275
+ print("File already exists. Skipping copy.")
276
+ else:
277
+ filefullname = os.path.join(
278
+ UnslothFuncCallFlow.default_pipeline_card_module_dir,
279
+ "model_readme.py"
280
+ )
281
+ shutil.copy(filefullname, target_dir)
282
+ print(filefullname)
283
+ @staticmethod
284
+ def copy_default_model_readme_template(
285
+ target_dir: str,
286
+ exists_ok: bool = False
287
+ ) -> None:
288
+ os.makedirs(target_dir, exist_ok=True)
289
+ if (
290
+ not exists_ok and
291
+ os.path.exists(os.path.join(target_dir,
292
+ "model_readme_template.md"))
293
+ ):
294
+ print("File already exists. Skipping copy.")
295
+ else:
296
+ filefullname = os.path.join(
297
+ UnslothFuncCallFlow.default_pipeline_card_module_dir,
298
+ "model_readme_template.md")
299
+ shutil.copy(filefullname, target_dir)
300
+ print(filefullname)
301
+ @staticmethod
302
+ def copy_default_pipeline_card_module(
303
+ target_dir: str,
304
+ exists_ok: bool = False
305
+ ) -> None:
306
+ os.makedirs(target_dir, exist_ok=True)
307
+ if (
308
+ not exists_ok and
309
+ os.path.exists(os.path.join(target_dir, "pipeline_card.py"))
310
+ ):
311
+ print("File already exists. Skipping copy.")
312
+ else:
313
+ filefullname = os.path.join(
314
+ UnslothFuncCallFlow.default_pipeline_card_module_dir,
315
+ "pipeline_card.py"
316
+ )
317
+ shutil.copy(filefullname, target_dir)
318
+ print(filefullname)
319
+ @staticmethod
320
+ def copy_default_pipeline_card_html_template(
321
+ target_dir: str,
322
+ exists_ok: bool = False
323
+ ) -> None:
324
+ os.makedirs(target_dir, exist_ok=True)
325
+ if (
326
+ not exists_ok and
327
+ os.path.exists(os.path.join(target_dir, "template.html"))
328
+ ):
329
+ print("File already exists. Skipping copy.")
330
+ else:
331
+ filefullname = os.path.join(
332
+ UnslothFuncCallFlow.default_pipeline_card_module_dir,
333
+ "template.html")
334
+ shutil.copy(filefullname, target_dir)
335
+ print(filefullname)
336
+
337
+ del RETRAIN_PIPELINE_TYPE
338
+
339
+ #---------------------------------------------------------------------------
340
+
341
+ @step
342
+ def start(self):
343
+ print(f"{current.flow_name} - {current.run_id}")
344
+
345
+ # GPU availability
346
+ print(torch.cuda.get_device_name(0))
347
+ print(torch.__version__)
348
+ self.engine = "gpu" if torch.cuda.is_available() else "cpu"
349
+
350
+ # hf_dataset
351
+ hf_dataset_dict = \
352
+ get_lazy_df(
353
+ repo_id=self.hf_dataset["repo_id"],
354
+ commit_hash=self.hf_dataset["commit_hash"],
355
+ files_filter=(
356
+ self.hf_dataset['config_name']+"/.*\\.parquet"
357
+ if (
358
+ self.hf_dataset["config_name"] and
359
+ "" < self.hf_dataset["config_name"]
360
+ ) else ".*\\.parquet"
361
+ ),
362
+ hf_token=os.getenv("HF_TOKEN", None)
363
+ )
364
+ try:
365
+ print(hf_dataset_dict["repo_id"], ", ",
366
+ hf_dataset_dict["commit_hash"], " - ",
367
+ hf_dataset_dict["commit_datetime"], "\n",
368
+ hf_dataset_dict["lazy_df"].explain())
369
+ except ComputeError as ex:
370
+ if "HF_TOKEN" not in os.environ:
371
+ print("Does the Hugging Face-hosted dataset " +
372
+ "require authentication?",
373
+ file=sys.stderr, flush=True)
374
+ raise ex
375
+ self.hf_dataset_dict = hf_dataset_dict
376
+
377
+ # hf_enrich_dataset
378
+ print(self.hf_enrich_dataset)
379
+ hf_enrich_dataset_dict = \
380
+ get_lazy_df(
381
+ repo_id=self.hf_enrich_dataset["repo_id"],
382
+ commit_hash=self.hf_enrich_dataset["commit_hash"],
383
+ files_filter=(
384
+ self.hf_enrich_dataset['config_name']+"/.*\\.parquet"
385
+ if (
386
+ self.hf_enrich_dataset["config_name"] and
387
+ "" < self.hf_enrich_dataset["config_name"]
388
+ ) else ".*\\.parquet"
389
+ ),
390
+ hf_token=os.getenv("HF_TOKEN", None)
391
+ )
392
+ print(' ; '.join(f"{k}: {hf_enrich_dataset_dict[k]}"
393
+ for k in ['commit_hash',
394
+ 'commit_datetime']))
395
+ self.hf_enrich_dataset_dict = hf_enrich_dataset_dict
396
+
397
+ # hf_base_model
398
+ hf_base_model_commits = list_repo_commits(
399
+ repo_id=self.hf_base_model["repo_id"],
400
+ revision=(
401
+ None if (rev_commit_hash:=self.hf_base_model["commit_hash"]) == ""
402
+ else rev_commit_hash
403
+ ),
404
+ repo_type="model",
405
+ token=os.getenv("HF_TOKEN", None))
406
+ self.hf_base_model_dict = {
407
+ "repo_id": self.hf_base_model["repo_id"],
408
+ "commit_hash": hf_base_model_commits[0].commit_id,
409
+ "commit_datetime": \
410
+ hf_base_model_commits[0].created_at
411
+ }
412
+
413
+ self.model_version_blessed = False
414
+ self.current_blessed_run = None
415
+ self.current_blessed_version_dict = None
416
+ current.run.remove_tag("model_version_blessed")
417
+
418
+ self.retrain_pipelines = f"retrain-pipelines {__version__}"
419
+ self.retrain_pipeline_type = os.environ["retrain_pipeline_type"]
420
+
421
+ self.serving_artifacts_local_folder = \
422
+ os.path.realpath(os.path.join(
423
+ os.path.dirname(__file__),
424
+ '..', '..', 'serving_artifacts',
425
+ os.path.sep.join(current.run.path_components)
426
+ ))
427
+
428
+ if not os.path.exists(self.serving_artifacts_local_folder):
429
+ os.makedirs(self.serving_artifacts_local_folder)
430
+
431
+ self.unsloth_dir = os.path.join(
432
+ self.serving_artifacts_local_folder,
433
+ "Unsloth"
434
+ )
435
+ print(f"unsloth_dir : {self.unsloth_dir}")
436
+ self.cpt_model_dir = os.path.join(
437
+ self.unsloth_dir, "cpt_model")
438
+ self.sft_model_dir = os.path.join(
439
+ self.unsloth_dir, "sft_model")
440
+
441
+ self.next(self.eda)
442
+
443
+
444
+ @step
445
+ def eda(self):
446
+ """
447
+ exploratory data analysis.
448
+ """
449
+
450
+ ############################
451
+ # features and label #
452
+ # basic counts #
453
+ ############################
454
+ self.records_count = self.hf_dataset_dict["lazy_df"] \
455
+ .select(pl.len()).collect(engine=self.engine).item()
456
+ self.data_schema = get_column_info(
457
+ self.hf_dataset_dict["lazy_df"], engine=self.engine)
458
+ ############################
459
+
460
+ ############################
461
+ # Answers #
462
+ # tools count #
463
+ ############################
464
+ struct_schema = pl.Struct([
465
+ pl.Field("name",
466
+ pl.String
467
+ ),
468
+ pl.Field("arguments",
469
+ pl.List(pl.String) # we retrieve list of args names
470
+ # (without assigned values)
471
+ )
472
+ ])
473
+ tool_answer_occurrences_df = \
474
+ count_tool_occurrences(
475
+ self.hf_dataset_dict["lazy_df"],
476
+ self.hf_dataset["attributes"]["answers_attr"],
477
+ struct_schema) \
478
+ .collect(engine=self.engine)
479
+ print(f"{tool_answer_occurrences_df['occurrences'].sum():,} " +
480
+ f"query/tool-calls pairs")
481
+ fig = plot_tools_occurences(tool_answer_occurrences_df,
482
+ title_prefix="Dataset answers - ")
483
+ self.answers_tools_count_fig = fig
484
+ ############################
485
+
486
+ ############################
487
+ # Query #
488
+ # words count #
489
+ ############################
490
+ queries_max_length = self.hf_dataset_dict["lazy_df"].select(
491
+ pl.col(
492
+ self.hf_dataset["attributes"]["query_attr"]
493
+ ).str.len_chars().max().alias("max_query_length")
494
+ ).collect(engine=self.engine)
495
+ print(f"longest query contains " +
496
+ f"{queries_max_length['max_query_length'][0]:,} characters")
497
+
498
+ # queries length quartiles
499
+ self.query_words_stats = \
500
+ column_words_stats(
501
+ self.hf_dataset_dict["lazy_df"],
502
+ self.hf_dataset["attributes"]["query_attr"]
503
+ ).collect(engine=self.engine)
504
+ print(self.query_words_stats.to_pandas().to_string(index=False))
505
+ print("Three quarters of the records have a query with fewer than " +
506
+ f"{self.query_words_stats['q3'][0]} words.")
507
+
508
+ fig = plot_words_count(
509
+ self.hf_dataset_dict["lazy_df"],
510
+ column_name=self.hf_dataset["attributes"]["query_attr"],
511
+ engine=self.engine)
512
+ self.words_count_fig = fig
513
+ ############################
514
+
515
+ ############################
516
+ # hf_enrich_dataset #
517
+ # Query words count #
518
+ ############################
519
+ enrich_question_words_stats = \
520
+ column_words_stats(
521
+ self.hf_enrich_dataset_dict['lazy_df'],
522
+ self.hf_enrich_dataset["query_attribute"],
523
+ column_attr_handler=eval(
524
+ self.hf_enrich_dataset["query_attribute_handler"])
525
+ ).collect(engine=self.engine)
526
+ print(enrich_question_words_stats.to_pandas()
527
+ .to_string(index=False))
528
+ del enrich_question_words_stats
529
+ ############################
530
+
531
+ self.next(self.augment_data)
532
+
533
+
534
+ @step
535
+ def augment_data(self):
536
+ """
537
+ Add 'negative' examples, where
538
+ queries do not trigger any tool call.
539
+ To achieve that, we sample long user queries,
540
+ truncate them at half their word count, and
541
+ associate them with an empty list of tool-calls.
542
+ """
543
+ """
544
+ We only consider:
545
+ - records with the longest queries,
546
+ i.e. queries in the last quartile
547
+ of "queries with most word-counts"
548
+ (this is to prevent 'truncated' queries
549
+ from getting really short)
550
+ - records with answers consisting
551
+ in a single tool-call
552
+ (in order to minimize the risk
553
+ that truncating actually gives
554
+ a valid answer with
555
+ one tool-call [or more])
556
+
557
+ Note on flow 'augmentation_rate' :
558
+ we add that many records (at most),
559
+ as the quartile's size permits.
560
+ """
561
+
562
+ print("Sampling within the population with more than " +
563
+ str(self.query_words_stats['q3'][0]) +
564
+ " words (longest queries quartile) =>")
565
+
566
+ samples_count = \
567
+ int(self.records_count * self.augmentation_rate)
568
+ print(f"would represent {samples_count:,.0f} " +
569
+ f"records to be sampled")
570
+
571
+ eligible_records_df = \
572
+ self.hf_dataset_dict["lazy_df"].filter(
573
+ pl.col(
574
+ self.hf_dataset["attributes"]["query_attr"]
575
+ )
576
+ .str.extract_all(r"\w+")
577
+ .map_elements(
578
+ lambda arr: len(arr),
579
+ return_dtype=pl.Int16)
580
+ .gt(self.query_words_stats['q3'][0])
581
+ & pl.col("answers")
582
+ .map_elements(
583
+ lambda x: len(json.loads(x)) == 1
584
+ if isinstance(x, str)
585
+ else False,
586
+ return_dtype=pl.Boolean)
587
+ ) \
588
+ .collect(engine=self.engine)
589
+ eligible_records_count = \
590
+ eligible_records_df.select(pl.len())["len"][0]
591
+ print(f"eligible_records_count : " +
592
+ f"{eligible_records_count:,.0f}")
593
+ samples_count = min(samples_count, eligible_records_count)
594
+ self.actual_augmentation_rate = \
595
+ samples_count / self.records_count
596
+ print("actual augmentation rate : " +
597
+ f"{self.actual_augmentation_rate:.1%}")
598
+ sampled_records_df = eligible_records_df.sample(
599
+ n=samples_count
600
+ )
601
+
602
+ self.augmented_records_df = \
603
+ sampled_records_df.with_columns(
604
+ pl.col("query")
605
+ .map_elements(
606
+ lambda query:
607
+ " ".join(
608
+ query.split()[
609
+ :len(query.split()) // 2]),
610
+ return_dtype=pl.Utf8)
611
+ .alias("truncated_query")
612
+ ).select([
613
+ pl.col("truncated_query").alias("query"),
614
+ pl.lit("[]").alias("answers")
615
+ ])
616
+ print(self.augmented_records_df.height,
617
+ self.augmented_records_df.columns)
618
+
619
+ self.next(self.enrich_data)
620
+
621
+
622
+ @step
623
+ def enrich_data(self):
624
+ """
625
+ Further enrich our dataset with 'negative' records from
626
+ another dataset (can be general-purpose text dataset)
627
+ as specified by the flow's 'hf_enrich_dataset' argument.
628
+ """
629
+ """
630
+ Note : we here use the Hugging Face `datasets` library
631
+ in 'streaming' mode for records sampling.
632
+ """
633
+
634
+ hf_enrich_ds = load_dataset(
635
+ path=self.hf_enrich_dataset["repo_id"],
636
+ name=self.hf_enrich_dataset["config_name"],
637
+ revision=self.hf_enrich_dataset_dict["commit_hash"],
638
+ streaming=True)
639
+ print(hf_enrich_ds["train"])
640
+
641
+ samples_count = \
642
+ int(self.records_count * self.enrichment_rate)
643
+ print(f"Sampling {samples_count:,.0f} records")
644
+
645
+ query_attribute_handler = \
646
+ eval(self.hf_enrich_dataset["query_attribute_handler"])
647
+ samples_iterator = iterable_dataset_multi_buffer_sampler(
648
+ hf_enrich_ds["train"],
649
+ total_samples=samples_count,
650
+ attributes_selector=\
651
+ (lambda x:query_attribute_handler(
652
+ x[self.hf_enrich_dataset["query_attribute"]])),
653
+ buffer_size=3_000,
654
+ num_passes=3,
655
+ seed=None
656
+ )
657
+ # Capitalize and add end punctuation if missing
658
+ start_time = time.time()
659
+ print("Starting to sample enriching records, " +
660
+ "this may take some time if the source dataset " +
661
+ "has a complex structure..")
662
+ samples_list = [
663
+ s.capitalize() + ("" if s[-1] in ".!?" else "?")
664
+ for s in samples_iterator]
665
+ elapsed_time = time.time() - start_time
666
+ print(f".. sampling completed " +
667
+ f"({int(elapsed_time // 3_600)}h:" +
668
+ f"{int((elapsed_time % 3_600) // 60)}m:" +
669
+ f"{int(elapsed_time % 60)}s).")
670
+ enriched_records_df = pl.DataFrame(
671
+ {"query": samples_list,
672
+ "answers": \
673
+ ["[]"] * \
674
+ len(samples_list)}
675
+ )
676
+ self.enriched_records_df = enriched_records_df
677
+
678
+ self.next(self.dataset_to_hub)
679
+
680
+
681
+ @step
682
+ def dataset_to_hub(self):
683
+ """
684
+ Push the dataset version to the hub:
685
+ - continued pre-training dataset
686
+ - training and validation splits of the
687
+ augmented and enriched
688
+ supervised finetuning dataset
689
+ - readme with versioning info
690
+ """
691
+
692
+ #############################
693
+ # case of user-provided #
694
+ # documentation artifact(s) #
695
+ #############################
696
+ # note that user can provide either
697
+ # 'pipeline_card.py' or 'template.html'
698
+ # or 'dataset_readme.py'
699
+ # or 'dataset_readme_template.md'
700
+ # or 'model_readme.py'
701
+ # or 'model_readme_template.md'
702
+ # or any combination of those
703
+ # when specifying custom
704
+ # 'pipeline_card_artifacts_path'
705
+ if (
706
+ "dataset_readme_template.md" in
707
+ os.listdir(self.pipeline_card_artifacts_path)
708
+ ):
709
+ template_dir = self.pipeline_card_artifacts_path
710
+ else:
711
+ template_dir = os.path.dirname(
712
+ importlib.util.find_spec(
713
+ f"retrain_pipelines.pipeline_card."+
714
+ f"{os.getenv('retrain_pipeline_type')}"
715
+ ).origin)
716
+ print(f"template_dir : '{template_dir}'")
717
+ #############################
718
+ if "dataset_readme.py" in os.listdir(
719
+ self.pipeline_card_artifacts_path):
720
+ from retrain_pipelines.utils import \
721
+ get_get_dataset_readme_content
722
+ get_dataset_readme_content = \
723
+ get_get_dataset_readme_content(
724
+ self.pipeline_card_artifacts_path)
725
+ else:
726
+ from retrain_pipelines.pipeline_card import \
727
+ get_dataset_readme_content
728
+ #############################
729
+
730
+
731
+ #############################
732
+ # augmented & enriched #
733
+ # finetuning dataset #
734
+ #############################
735
+ merged_df = pl.concat([
736
+ # dataset
737
+ self.hf_dataset_dict["lazy_df"].select([
738
+ self.hf_dataset["attributes"]["query_attr"],
739
+ self.hf_dataset["attributes"]["answers_attr"]
740
+ ]).collect(engine=self.engine),
741
+ # truncated queries augmentation
742
+ self.augmented_records_df,
743
+ # enriching dataset
744
+ self.enriched_records_df
745
+ ]).sample(
746
+ # shuffling
747
+ fraction=1,
748
+ shuffle=True,
749
+ with_replacement=False
750
+ )
751
+ merged_df = merged_df.sample(fraction=1, shuffle=True)
752
+ merged_df.rechunk()
753
+ print(("merged_df", f"{merged_df.shape[0]:,.0F}",
754
+ merged_df.columns))
755
+
756
+ pandas_df = merged_df.to_pandas()
757
+ train_size = int(0.8 * len(pandas_df))
758
+ print(f"validation : {len(pandas_df) - train_size}")
759
+ sft_dataset = DatasetDict({
760
+ "train": Dataset.from_pandas(pandas_df[:train_size]),
761
+ "validation": Dataset.from_pandas(pandas_df[train_size:])
762
+ })
763
+ #############################
764
+
765
+ #############################
766
+ # continued pre-training #
767
+ # dataset #
768
+ #############################
769
+ struct_schema = pl.Struct([
770
+ pl.Field("name", pl.String),
771
+ pl.Field("description", pl.String),
772
+ pl.Field(
773
+ "parameters",
774
+ pl.String # Use String to allow
775
+ # for varying structures
776
+ # (different tools indeed having
777
+ # different sets of parameters
778
+ # i.e. different parameters counts,
779
+ # datatypes and names)
780
+ # so parsing must be tolerant.
781
+ )
782
+ ])
783
+ unique_tools_df = get_unique_tools(
784
+ self.hf_dataset_dict["lazy_df"],
785
+ tools_attr_name=\
786
+ self.hf_dataset["attributes"]["tools_attr"],
787
+ struct_schema=struct_schema
788
+ ).collect(engine=self.engine)
789
+ unique_tools_arrow_table = unique_tools_df.to_arrow()
790
+ self.unique_tools_dataset = \
791
+ Dataset(unique_tools_arrow_table)
792
+ print(self.unique_tools_dataset)
793
+ #############################
794
+
795
+ #############################
796
+ # DatasetDict #
797
+ # with multiple tables #
798
+ #############################
799
+ dataset_dict = DatasetDict({
800
+ "continued_pre_training": \
801
+ self.unique_tools_dataset,
802
+ "supervised_finetuning": sft_dataset
803
+ })
804
+ print(dataset_dict, flush=True)
805
+ #############################
806
+
807
+ #############################
808
+ # dataset README #
809
+ # from template #
810
+ #############################
811
+ commit_datetime = datetime.utcnow()
812
+ new_dataset_version_label = get_new_repo_minor_version(
813
+ repo_id=self.dataset_repo_id,
814
+ repo_type="dataset",
815
+ hf_token=os.getenv("HF_TOKEN", None))
816
+ readme_content = get_dataset_readme_content(
817
+ template_folder=template_dir,
818
+
819
+ hf_dataset_dict=self.hf_dataset_dict,
820
+ hf_enrich_dataset_dict=self.hf_enrich_dataset_dict,
821
+ dataset_dict=dataset_dict,
822
+
823
+ augmentation_rate=self.actual_augmentation_rate,
824
+ enrichment_rate=self.enrichment_rate,
825
+
826
+ version_label=new_dataset_version_label,
827
+ commit_datetime=commit_datetime,
828
+
829
+ mf_flow_name=current.flow_name,
830
+ mf_run_id=current.run.id,
831
+ engine=self.engine
832
+ )
833
+ #############################
834
+
835
+ dataset_commit_hash = push_dataset_version_to_hub(
836
+ repo_id=self.dataset_repo_id,
837
+ version_label=new_dataset_version_label,
838
+ timestamp_str=commit_datetime.strftime(
839
+ "%Y-%m-%d %H:%M:%S UTC"),
840
+ dataset_dict=dataset_dict,
841
+ dataset_readme_content=readme_content,
842
+ hf_token=os.getenv("HF_TOKEN", None)
843
+ )
844
+ if not dataset_commit_hash:
845
+ raise Exception(
846
+ "Failed to publish dataset version.")
847
+ print(f"https://huggingface.co/datasets/{self.dataset_repo_id}" +
848
+ f"/blob/{dataset_commit_hash}/README.md")
849
+ self.dataset_commit_dict = {
850
+ "repo_id": self.dataset_repo_id,
851
+ "commit_hash": dataset_commit_hash,
852
+ "version_label": new_dataset_version_label,
853
+ "commit_datetime": commit_datetime,
854
+ }
855
+
856
+ self.next(self.continued_pre_training)
857
+
858
+
859
+ @step
860
+ def continued_pre_training(self):
861
+ """
862
+ Gives the base model some additional intrinsic knowledge
863
+ through continued pre-training.
864
+ See unsloth.ai/blog/contpretraining
865
+ """
866
+ from retrain_pipelines.model.hf_utils import \
867
+ plot_log_history
868
+
869
+ #######################################
870
+ # base-model and associated tokenizer #
871
+ # from Hub (or local cache) #
872
+ #######################################
873
+ self.max_seq_length = 2048
874
+ model, tokenizer = FastLanguageModel.from_pretrained(
875
+ model_name=self.hf_base_model_dict["repo_id"],
876
+ revision=self.hf_base_model_dict["commit_hash"],
877
+ max_seq_length=self.max_seq_length,
878
+ dtype=None,
879
+ load_in_4bit=False,
880
+ # case of a gated or private base-model
881
+ token=os.getenv("HF_TOKEN", None)
882
+ )
883
+ #######################################
884
+
885
+ #######################################
886
+ # dataset prompt_template mapping #
887
+ #######################################
888
+ tools_dataset = DatasetDict(
889
+ {"train": self.unique_tools_dataset})
890
+ print(tools_dataset)
891
+ tool_prompt_template = "tool: {}"
892
+ def formatting_prompts_func(tools_batch):
893
+ tools_batch = tools_batch["tool"]
894
+ outputs = []
895
+ for tool in tools_batch:
896
+ # Must add EOS_TOKEN,
897
+ # otherwise generation will go on forever!
898
+ text = tool_prompt_template.format(tool) + \
899
+ tokenizer.eos_token
900
+ outputs.append(text)
901
+ return { "tools" : outputs, }
902
+ cpt_dataset = tools_dataset["train"].map(
903
+ formatting_prompts_func, batched=True,)
904
+ #######################################
905
+
906
+ #######################################
907
+ # PEFT adapter #
908
+ # for continued pre-training #
909
+ #######################################
910
+ model = FastLanguageModel.get_peft_model(
911
+ model,
912
+ r = 128, # any number >0 ; 8, 16, 32, 64, 128, 256
913
+ target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
914
+ "gate_proj", "up_proj", "down_proj",
915
+ # Add for continued pretraining
916
+ "embed_tokens", "lm_head",],
917
+ lora_alpha = 32,
918
+ lora_dropout = 0, # Supports any, 0 is optimized
919
+ bias = "none", # Supports any, "none" is optimized
920
+ # True or "unsloth" for very long context
921
+ use_gradient_checkpointing = "unsloth",
922
+ use_rslora = True, # rank-stabilized LoRA
923
+ loftq_config = None, # LoftQ
924
+ #random_state = 3407,
925
+ )
926
+ #######################################
927
+
928
+ #######################################
929
+ # cpt_trainer #
930
+ #######################################
931
+ if (
932
+ "records_cap" in self.cpt_training_args and
933
+ self.cpt_training_args["records_cap"] is not None and
934
+ isinstance(self.cpt_training_args["records_cap"], int)
935
+ ):
936
+ cpt_dataset = cpt_dataset.take(
937
+ self.cpt_training_args["records_cap"])
938
+ print(f"cpt_dataset : {cpt_dataset}")
939
+
940
+ train_args = UnslothTrainingArguments(
941
+ # https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.save_strategy
942
+ per_device_train_batch_size=2,
943
+ gradient_accumulation_steps=8,
944
+
945
+ **{k: v for k, v in self.cpt_training_args.items()
946
+ if k != "records_cap"},
947
+
948
+ # 2 to 10x smaller learning rate
949
+ # for the embedding matrices
950
+ learning_rate=5e-5,
951
+ embedding_learning_rate=1e-5,
952
+
953
+ fp16=not is_bfloat16_supported(),
954
+ bf16=is_bfloat16_supported(),
955
+ logging_steps=1,
956
+ optim="adamw_8bit",
957
+ weight_decay=0.01,
958
+ lr_scheduler_type="linear",
959
+ #seed=3407,
960
+
961
+ output_dir=os.path.join(
962
+ self.unsloth_dir, "outputs", "cpt"),
963
+ save_total_limit = 2,
964
+
965
+ report_to="tensorboard",
966
+ logging_dir=os.path.join(
967
+ self.sft_model_dir,
968
+ "runs", "cpt")
969
+ )
970
+
971
+ self.cpt_traces_file_fullname = os.path.join(
972
+ self.unsloth_dir, "cpt_trainer_traces.txt")
973
+ print("Training started. " +
974
+ f"Check {self.cpt_traces_file_fullname} for live traces.",
975
+ flush=True)
976
+
977
+ trainer = UnslothTrainer(
978
+ model=model, tokenizer=tokenizer,
979
+ train_dataset=cpt_dataset,
980
+ dataset_text_field="tools",
981
+ max_seq_length=self.max_seq_length,
982
+ dataset_num_proc=2,
983
+ args=train_args,
984
+ )
985
+ #######################################
986
+
987
+ #######################################
988
+ # Show current memory stats #
989
+ #######################################
990
+ torch.cuda.ipc_collect()
991
+ torch.cuda.empty_cache()
992
+ _ = gc.collect()
993
+
994
+ gpu_stats = torch.cuda.get_device_properties(0)
995
+ self.start_gpu_memory = \
996
+ round(torch.cuda.max_memory_reserved()
997
+ / 1024 / 1024 / 1024, 3)
998
+ self.max_memory = \
999
+ round(gpu_stats.total_memory
1000
+ / 1024 / 1024 / 1024, 3)
1001
+ print(f"GPU = {gpu_stats.name}. " +
1002
+ f"Max memory = {self.max_memory} GB.")
1003
+ print(f"{self.start_gpu_memory} GB of memory reserved.")
1004
+ #######################################
1005
+
1006
+ with open(self.cpt_traces_file_fullname, 'w') as f:
1007
+ with redirect_stdout(f):
1008
+ hf_logging.set_verbosity_error()
1009
+ hf_logging.disable_progress_bar()
1010
+ trainer_stats = trainer.train()
1011
+ hf_logging.set_verbosity_info()
1012
+ hf_logging.enable_progress_bar()
1013
+ print(f"{trainer_stats.metrics['train_runtime']} " +
1014
+ f"seconds used for training " +
1015
+ f"({round(trainer_stats.metrics['train_runtime']/60, 2)}" +
1016
+ f" minutes).")
1017
+
1018
+ self.cpt_log_history = trainer.state.log_history
1019
+ # print(self.cpt_log_history)
1020
+ self.cpt_log_history_fig = \
1021
+ plot_log_history(
1022
+ self.cpt_log_history,
1023
+ title="Continued pretraining loss"
1024
+ )
1025
+
1026
+ model.save_pretrained_merged(
1027
+ save_directory=self.cpt_model_dir,
1028
+ tokenizer=tokenizer,
1029
+ save_method="lora"
1030
+ )
1031
+ print(f"cpt_model_dir : {self.cpt_model_dir}\n")
1032
+
1033
+ self.next(self.supervised_finetuning)
1034
+
1035
+
1036
+ @step
1037
+ def supervised_finetuning(self):
1038
+ """
1039
+ Trains the model on tool-calling
1040
+ task specialization.
1041
+ """
1042
+ from retrain_pipelines.model.hf_utils import \
1043
+ plot_log_history
1044
+
1045
+ torch.cuda.ipc_collect()
1046
+ torch.cuda.empty_cache()
1047
+ _ = gc.collect()
1048
+
1049
+ model, tokenizer = FastLanguageModel.from_pretrained(
1050
+ model_name=self.cpt_model_dir,
1051
+ max_seq_length=self.max_seq_length,
1052
+ dtype=None,
1053
+ load_in_4bit=False,
1054
+ )
1055
+ # !!!! bug fix BEGIN !!!!
1056
+ # otherwise, 'embed_tokens' and 'lm_head'
1057
+ # trained during CPT are "ignored",
1058
+ # i.e. not saved after SFT
1059
+ # (note that, alternatively, we could also
1060
+ # do this fix after sft-training and
1061
+ # just before saving ;
1062
+ # which would be equivalent to
1063
+ # freezing embeddings during finetuning
1064
+ # for better pretrained knowledge retention)
1065
+ # @see https://www.reddit.com/r/unsloth/comments/1dtzcd6/fastlanguagemodelpatch_peft_model_changing/
1066
+ model.model.model.embed_tokens.modules_to_save.default.to(
1067
+ device="cuda:0",
1068
+ dtype=torch.float32,
1069
+ non_blocking=True)
1070
+ model.model.model.embed_tokens.modules_to_save.default \
1071
+ .requires_grad_(True)
1072
+ model.model.lm_head.modules_to_save.default.to(
1073
+ device="cuda:0",
1074
+ dtype=torch.float32,
1075
+ non_blocking=True)
1076
+ model.model.lm_head.modules_to_save.default \
1077
+ .requires_grad_(True)
1078
+ # !!!! bug fix END !!!!
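+ # (optional sanity check, illustrative only :
+ #  assert model.model.model.embed_tokens \
+ #      .modules_to_save.default.weight.requires_grad
+ #  assert model.model.lm_head \
+ #      .modules_to_save.default.weight.requires_grad)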
1079
+
1080
+ #######################################
1081
+ # dataset prompt_template mapping #
1082
+ #######################################
1083
+ # download from Hub (or get from local cache)
1084
+ queries_dataset = load_dataset(
1085
+ path=self.dataset_commit_dict["repo_id"],
1086
+ name="supervised_finetuning",
1087
+ revision=self.dataset_commit_dict["commit_hash"],
1088
+ token=os.getenv("HF_TOKEN", None))
1089
+ print(f"HF_DATASETS_CACHE : {HF_DATASETS_CACHE}") # HF_CACHE_HOME
1090
+ self.sft_prompt_template = dedent("""
1091
+ You specialize in generating tool calls. Given a query, your task is to return a list of tool calls based on your knowledge of known tools.
1092
+
1093
+ Rules:
1094
+ 1. You can only use tools you know. Do not create new tools under any circumstances.
1095
+ 2. If a query does not match any known tool, return an empty list ([]).
1096
+ 3. If information is missing to use a known tool, do not attempt to use it.
1097
+ 4. Your response must always be a valid JSON array, and nothing else.
1098
+
1099
+ Be precise and do not guess.
1100
+
1101
+ # query:
1102
+ {}
1103
+ # response:
1104
+ {}
1105
+ """).strip()
1106
+ tokenizer.chat_template = self.sft_prompt_template
1107
+
1108
+ EOS_TOKEN = tokenizer.eos_token
1109
+ def formatting_prompts_func(records):
1110
+ query = records["query"]
1111
+ tools = records["answers"]
1112
+ outputs = []
1113
+ for query, tools in zip(query, tools):
1114
+ # Must add EOS_TOKEN,
1115
+ # otherwise your generation will go on forever
1116
+ text = self.sft_prompt_template.format(query, tools) \
1117
+ + EOS_TOKEN
1118
+ outputs.append(text)
1119
+ return { "text" : outputs, }
1120
+ sft_train_dataset = queries_dataset["train"].map(
1121
+ formatting_prompts_func, batched=True)
1122
+ sft_valid_dataset = queries_dataset["validation"].map(
1123
+ formatting_prompts_func, batched=True,)
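+ # Illustrative only (values are made up) : for a record with
+ # query "What's the weather in Paris?" and answers
+ # '[{"name": "get_weather", "arguments": {"city": "Paris"}}]',
+ # formatting_prompts_func yields the sft_prompt_template with
+ # those two fields substituted, followed by the EOS token,
+ # stored under the "text" column consumed by the trainer below.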
1124
+ #######################################
1125
+
1126
+ #######################################
1127
+ # PEFT adapter #
1128
+ # for supervised finetuning #
1129
+ #######################################
1130
+ # for cases where CPT has been merged into the overall model
1131
+ # otherwise, keep on training the current LoRA adapter
1132
+ # model = FastLanguageModel.get_peft_model(
1133
+ # model,
1134
+ # r = 128, # any number >0 ; 8, 16, 32, 64, 128, 256
1135
+ # target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
1136
+ # "gate_proj", "up_proj", "down_proj"],
1137
+ # lora_alpha = 32,
1138
+ # lora_dropout = 0, # Supports any, but = 0 is optimized
1139
+ # bias = "none", # Supports any, but = "none" is optimized
1140
+ # # True or "unsloth" for very long context
1141
+ # use_gradient_checkpointing = "unsloth",
1142
+ # random_state = 3407,
1143
+ # use_rslora = True, # rank stabilized LoRA
1144
+ # loftq_config = None, # LoftQ
1145
+ # )
1146
+ #######################################
1147
+
1148
+ #######################################
1149
+ # sft_trainer #
1150
+ #######################################
1151
+ split = sft_train_dataset.train_test_split(
1152
+ test_size=1000,
1153
+ #seed=42
1154
+ )
1155
+ train_dataset = split['train']
1156
+ eval_dataset = split['test']
1157
+ if (
1158
+ "records_cap" in self.sft_training_args and
1159
+ self.sft_training_args["records_cap"] is not None and
1160
+ isinstance(self.sft_training_args["records_cap"], int)
1161
+ ):
1162
+ train_dataset = train_dataset.take(
1163
+ self.sft_training_args["records_cap"])
1164
+ eval_dataset = eval_dataset.take(
1165
+ self.sft_training_args["records_cap"])
1166
+ print(f"train_dataset : {train_dataset}")
1167
+ print(f"eval_dataset : {eval_dataset}")
1168
+
1169
+ train_args = UnslothTrainingArguments(
1170
+ per_device_train_batch_size=2,
1171
+ gradient_accumulation_steps=8,
1172
+
1173
+ **{k: v for k, v in self.sft_training_args.items()
1174
+ if k != "records_cap"},
1175
+
1176
+ per_device_eval_batch_size=2,
1177
+ eval_steps=200,
1178
+ eval_strategy="steps",
1179
+ do_eval=True,
1180
+
1181
+ learning_rate=5e-5,
1182
+ # embedding_learning_rate=1e-5, # Optionally here
1183
+
1184
+ fp16=not is_bfloat16_supported(),
1185
+ bf16=is_bfloat16_supported(),
1186
+
1187
+ optim="adamw_8bit",
1188
+ weight_decay=0.00,
1189
+ lr_scheduler_type="linear",
1190
+ #seed=3407,
1191
+
1192
+ output_dir=os.path.join(
1193
+ self.unsloth_dir, "outputs", "sft"),
1194
+ save_total_limit=2,
1195
+
1196
+ logging_steps=1,
1197
+ report_to="tensorboard",
1198
+ logging_dir=os.path.join(
1199
+ self.sft_model_dir,
1200
+ "runs", "sft")
1201
+ )
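+ # (with the values above : 2 sequences per device
+ #  x 8 gradient-accumulation steps
+ #  = an effective batch of 16 sequences per optimizer update)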
1202
+
1203
+ self.sft_traces_file_fullname = os.path.join(
1204
+ self.unsloth_dir, "sft_trainer_traces.txt")
1205
+ print("Training started. " +
1206
+ f"Check {self.sft_traces_file_fullname} for live traces.",
1207
+ flush=True)
1208
+
1209
+ trainer = UnslothTrainer(
1210
+ model=model, tokenizer=tokenizer,
1211
+ train_dataset=train_dataset,
1212
+ dataset_text_field="text",
1213
+ eval_dataset=eval_dataset,
1214
+ max_seq_length=self.max_seq_length,
1215
+ dataset_num_proc=8,
1216
+ args=train_args
1217
+ )
1218
+ trainer.can_return_loss = True
1219
+ #######################################
1220
+
1221
+ #######################################
1222
+ # Show current memory stats #
1223
+ #######################################
1224
+ torch.cuda.ipc_collect()
1225
+ torch.cuda.empty_cache()
1226
+ _ = gc.collect()
1227
+
1228
+ used_memory = \
1229
+ round(torch.cuda.max_memory_reserved()
1230
+ /1024/1024/1024, 3)
1231
+ used_memory_for_lora = \
1232
+ round(used_memory-self.start_gpu_memory, 3)
1233
+ used_percentage = \
1234
+ round(used_memory/self.max_memory*100, 3)
1235
+ lora_percentage = \
1236
+ round(used_memory_for_lora/self.max_memory*100,
1237
+ 3)
1238
+ print(f"Peak reserved memory = " +
1239
+ f"{used_memory} GB.")
1240
+ print(f"Peak reserved memory for " +
1241
+ f"training = {used_memory_for_lora} " +
1242
+ f"GB.")
1243
+ print(f"Peak reserved memory % of " +
1244
+ f"max memory = {used_percentage} %.")
1245
+ print(f"Peak reserved memory for training " +
1246
+ f"% of max memory = {lora_percentage} %.")
1247
+ #######################################
1248
+
1249
+ with open(self.sft_traces_file_fullname, 'w') as f:
1250
+ with redirect_stdout(f):
1251
+ hf_logging.set_verbosity_error()
1252
+ hf_logging.disable_progress_bar()
1253
+ trainer_stats = trainer.train()
1254
+ hf_logging.set_verbosity_info()
1255
+ hf_logging.enable_progress_bar()
1256
+ print(f"{trainer_stats.metrics['train_runtime']} " +
1257
+ f"seconds used for training " +
1258
+ f"({round(trainer_stats.metrics['train_runtime']/60, 2)}" +
1259
+ f" minutes).")
1260
+
1261
+ self.sft_log_history = trainer.state.log_history
1262
+ self.sft_log_history_fig = \
1263
+ plot_log_history(
1264
+ self.sft_log_history,
1265
+ title="Supervised finetuning loss"
1266
+ )
1267
+
1268
+ model.save_pretrained_merged(
1269
+ self.sft_model_dir, tokenizer,
1270
+ save_method = "lora"
1271
+ )
1272
+ print(f"sft_model_dir : {self.sft_model_dir}\n")
1273
+
1274
+ self.next(self.evaluate_model)
1275
+
1276
+
1277
+ @step
1278
+ def evaluate_model(self):
1279
+ """
1280
+ Batch inference on the SFT validation dataset.
1281
+ """
1282
+ from retrain_pipelines.model import \
1283
+ infer_validation, compute_counts_n_metrics, \
1284
+ plot_validation_completions
1285
+
1286
+ torch.cuda.ipc_collect()
1287
+ torch.cuda.empty_cache()
1288
+ _ = gc.collect()
1289
+
1290
+
1291
+ ######################################################
1292
+ # loading trained adapter #
1293
+ ######################################################
1294
+ # Unsloth (and hf transformers before it)             #
1295
+ # enforces hard requirements on the                   #
1296
+ # chat_template format when loading model             #
1297
+ # & tokenizer at once (as we did in prior tasks),     #
1298
+ # now that tokenizer.chat_template is set             #
1299
+ # in the tokenizer config.                            #
1300
+ ######################################################
1301
+ # load base from cache
1302
+ # (with base tokenizer, which we ignore)
1303
+ model, _ = FastLanguageModel.from_pretrained(
1304
+ model_name=self.hf_base_model_dict["repo_id"],
1305
+ revision=self.hf_base_model_dict["commit_hash"],
1306
+ max_seq_length=self.max_seq_length,
1307
+ dtype=None,
1308
+ load_in_4bit=False,
1309
+ # case of a gated or private base-model
1310
+ token=os.getenv("HF_TOKEN", None)
1311
+ )
1312
+ model = FastLanguageModel.for_inference(model)
1313
+ # load our CPT+SFT trained & locally-saved adapter
1314
+ model.load_adapter(peft_model_id=self.sft_model_dir)
1315
+ # Separately load our (potentially trained &)
1316
+ # locally-saved adapter-tokenizer
1317
+ # (loading it below via HF and not Unsloth)
1318
+ tokenizer = AutoTokenizer.from_pretrained(
1319
+ pretrained_model_name_or_path=self.sft_model_dir
1320
+ )
1321
+ ######################################################
1322
+
1323
+ ######################################################
1324
+ # validation dataset #
1325
+ ######################################################
1326
+ # download from Hub (or get from local cache)
1327
+ queries_dataset = load_dataset(
1328
+ path=self.dataset_commit_dict["repo_id"],
1329
+ name="supervised_finetuning",
1330
+ revision=self.dataset_commit_dict["commit_hash"],
1331
+ token=os.getenv("HF_TOKEN", None))
1332
+ if (
1333
+ "records_cap" in self.sft_training_args and
1334
+ self.sft_training_args["records_cap"] is not None and
1335
+ isinstance(self.sft_training_args["records_cap"], int)
1336
+ ):
1337
+ validation_data = queries_dataset["validation"].take(
1338
+ self.sft_training_args["records_cap"])
1339
+ else:
1340
+ validation_data = queries_dataset["validation"]
1341
+ print(validation_data, flush=True)
1342
+ ######################################################
1343
+
1344
+ self.max_new_tokens = 400
1345
+ start_time = time.time()
1346
+ validation_results = infer_validation(
1347
+ tokenizer=tokenizer,
1348
+ model=model,
1349
+ validation_data=validation_data,
1350
+ prompt_template=tokenizer.chat_template,
1351
+ batch_size=32, # 64,
1352
+ queries_attr_name=\
1353
+ self.hf_dataset["attributes"]["query_attr"],
1354
+ answers_attr_name=\
1355
+ self.hf_dataset["attributes"]["answers_attr"],
1356
+ max_new_tokens=self.max_new_tokens,
1357
+ device="cuda"
1358
+ )
1359
+ print("infer_validation - Elapsed time: " +
1360
+ f"{(time.time() - start_time):.2f} seconds")
1361
+ self.validation_results = validation_results # <= to artifacts store
1362
+
1363
+ eval_df = pl.LazyFrame(validation_results)
1364
+
1365
+ records = eval_df.with_columns(
1366
+ (pl.col("answer") == pl.col("completion")) \
1367
+ .alias("is_ground_truth_identical")
1368
+ ).collect() #engine=self.engine)
1369
+ print("perfect characters-match accuracy : " +
1370
+ str(records['is_ground_truth_identical'].mean()))
1371
+
1372
+ eval_metrics_df = compute_counts_n_metrics(
1373
+ eval_df, is_format_fault_tolerant=True)
1374
+ overall_metrics_df = eval_metrics_df.select([
1375
+ pl.col("precision").mean(),
1376
+ pl.col("recall").mean(),
1377
+ pl.col("f1").mean(),
1378
+ pl.col("jaccard").mean()
1379
+ ]).collect() #engine=self.engine)
1380
+ self.perf_metrics = overall_metrics_df.row(0, named=True)
1381
+ print(self.perf_metrics)
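+ # For intuition (illustrative, not necessarily the library's
+ # exact implementation) : a set-based Jaccard score per record is
+ #   jaccard = len(predicted_calls & expected_calls)
+ #             / len(predicted_calls | expected_calls)
+ # with precision, recall, f1 and jaccard then averaged over
+ # the validation records, as in overall_metrics_df above.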
1382
+
1383
+ self.validation_completions_fig = \
1384
+ plot_validation_completions(
1385
+ eval_metrics_df, engine=self.engine)
1386
+
1387
+ del model
1388
+ del tokenizer
1389
+ torch.cuda.ipc_collect()
1390
+ torch.cuda.empty_cache()
1391
+ _ = gc.collect()
1392
+
1393
+ self.next(self.model_version_blessing)
1394
+
1395
+
1396
+ @step
1397
+ def model_version_blessing(self):
1398
+ """
1399
+ Compares the newly-retrained model version
1400
+ against its best-performing predecessor.
1401
+ """
1402
+ """
1403
+ Note: for Hugging Face integrated pipelines,
1404
+ we compare against the latest commit on the main branch
1405
+ of the model repository there.
1406
+ When it comes to the local "mf_run_id" of the pipeline run
1407
+ that generated that best prior model version
1408
+ (retrieved from the model-card yaml metadata on HF),
1409
+ we check against the records of this ML-framework instance,
1410
+ as the "prior best version" of the model being retrained here
1411
+ may have originated from an instance other
1412
+ than the one executing the current retraining
1413
+ (in which case, we simply don't include a "local" hyperlink
1414
+ in the model version pipeline_cards that will be
1415
+ produced later in this pipeline run).
1416
+ """
1417
+ from retrain_pipelines.model.hf_utils import \
1418
+ current_blessed_model_version_dict
1419
+
1420
+ main_perf_metric_name = "jaccard"
1421
+
1422
+ current_blessed_version_dict = \
1423
+ current_blessed_model_version_dict(
1424
+ repo_id=self.model_repo_id,
1425
+ hf_token=os.getenv("HF_TOKEN", None)
1426
+ )
1427
+ print("current_blessed_version_dict : " +
1428
+ str(current_blessed_version_dict))
1429
+
1430
+ if current_blessed_version_dict is None:
1431
+ print("case 'no prior blessed model version found"
1432
+ " => blessing.'")
1433
+ self.model_version_blessed = True
1434
+
1435
+ elif (
1436
+ main_perf_metric_name in
1437
+ current_blessed_version_dict["perf_metrics"]
1438
+ ):
1439
+ current_blessed_run_id = \
1440
+ current_blessed_version_dict["mf_run_id"]
1441
+ print(f"current_blessed_run_id : {current_blessed_run_id}")
1442
+ current_blessed_metric_value = \
1443
+ current_blessed_version_dict[
1444
+ "perf_metrics"][main_perf_metric_name]
1445
+
1446
+ self.model_version_blessed = (
1447
+ self.perf_metrics[main_perf_metric_name] >=
1448
+ current_blessed_metric_value
1449
+ )
1450
+
1451
+ if not self.model_version_blessed:
1452
+ self.current_blessed_version_dict = \
1453
+ current_blessed_version_dict
1454
+ for run in Flow(self.__class__.__name__):
1455
+ if str(run.id) == current_blessed_run_id:
1456
+ run_steps = iter(run.steps())
1457
+ last_run_step = next(run_steps)
1458
+ last_task = next(iter(last_run_step.tasks()))
1459
+
1460
+ # tasks are listed backwards, so last task is first item :
1461
+ # Has the run seen task "pipeline_card" prior to last task
1462
+ # (meaning, "pipeline_card" completed successfully and
1463
+ # "run" has generated a sutom pipeline-card artifact) ?
1464
+ # If not, hyperlink generation will later fail.
1465
+ run_has_custom_card_artifact = False
1466
+ for step in run_steps:
1467
+ if "pipeline_card" == step.id:
1468
+ run_has_custom_card_artifact = True
1469
+ break
1470
+
1471
+ if not run_has_custom_card_artifact:
1472
+ print(
1473
+ f"Run #{current_blessed_run_id} " +
1474
+ "Doesn't seem to have successfully " +
1475
+ "generated a pipeline-card artifact.",
1476
+ file=sys.stderr, flush=True)
1477
+ break
1478
+ else:
1479
+ # further filtering on successful runs that are
1480
+ # retraining of a prior version of the same model
1481
+ # (to minimize the risk that this was obtained
1482
+ # on another ML-framework instance)
1483
+ if (
1484
+ # last_task.successful and
1485
+ # may have failed after the "pipeline_card" step
1486
+ # and been resumed
1487
+ hasattr(last_task.artifacts,
1488
+ 'model_version_blessed') and
1489
+ last_task.artifacts.model_version_blessed.data and
1490
+ hasattr(last_task.artifacts,
1491
+ 'model_repo_id') and
1492
+ last_task.artifacts.model_repo_id.data == \
1493
+ self.model_repo_id
1494
+ ):
1495
+ self.current_blessed_run = run
1496
+ break
1497
+
1498
+ if not self.current_blessed_run:
1499
+ print(
1500
+ "Couldn't find blessed run " +
1501
+ f"{current_blessed_run_id} !\n" +
1502
+ "It seems that prior blessed run was " +
1503
+ "executed on another ML framework instance.",
1504
+ file=sys.stderr, flush=True)
1505
+
1506
+ print("new : " +
1507
+ str(self.perf_metrics[main_perf_metric_name]) +
1508
+ " - previous best : " +
1509
+ str(current_blessed_metric_value) +
1510
+ " - model_version_blessing : " +
1511
+ str(self.model_version_blessed))
1512
+
1513
+ else:
1514
+ raise Exception(
1515
+ "Performance metric '" +
1516
+ main_perf_metric_name +
1517
+ "' can't be found in eval results " +
1518
+ "from blessed run " +
1519
+ str(current_blessed_version_dict[
1520
+ "mf_run_id"]) + " !")
1521
+
1522
+ # self.model_version_blessed = True ### DEBUG - DELETE ###
1523
+
1524
+ self.next(self.model_to_hub)
1525
+
1526
+
1527
+ @step
1528
+ def model_to_hub(self):
1529
+ """
1530
+ Push the model version to the Hub, including
1531
+ a README with versioning info.
1532
+ """
1533
+
1534
+ #############################
1535
+ # case of user-provided #
1536
+ # documentation artifact(s) #
1537
+ #############################
1538
+ # note that user can provide either
1539
+ # 'pipeline_card.py' or 'template.html'
1540
+ # or 'dataset_readme.py'
1541
+ # or 'dataset_readme_template.md'
1542
+ # or 'model_readme.py'
1543
+ # or 'model_readme_template.md'
1544
+ # or any combination of those
1545
+ # when specifying custom
1546
+ # 'pipeline_card_artifacts_path'
1547
+ if (
1548
+ "model_readme_template.md" in
1549
+ os.listdir(self.pipeline_card_artifacts_path)
1550
+ ):
1551
+ template_dir = self.pipeline_card_artifacts_path
1552
+ else:
1553
+ template_dir = os.path.dirname(
1554
+ importlib.util.find_spec(
1555
+ f"retrain_pipelines.pipeline_card."+
1556
+ f"{os.getenv('retrain_pipeline_type')}"
1557
+ ).origin)
1558
+ print(f"template_dir : '{template_dir}'")
1559
+ #############################
1560
+ if "model_readme.py" in os.listdir(
1561
+ self.pipeline_card_artifacts_path):
1562
+ from retrain_pipelines.utils import \
1563
+ get_get_model_readme_content
1564
+ get_model_readme_content = \
1565
+ get_get_model_readme_content(
1566
+ self.pipeline_card_artifacts_path)
1567
+ else:
1568
+ from retrain_pipelines.pipeline_card import \
1569
+ get_model_readme_content
1570
+ #############################
1571
+ from retrain_pipelines.model.hf_utils import \
1572
+ push_model_version_to_hub
1573
+
1574
+ #############################
1575
+ # model README #
1576
+ # from template #
1577
+ #############################
1578
+ commit_datetime = datetime.utcnow()
1579
+ new_model_version_label = get_new_repo_minor_version(
1580
+ repo_id=self.model_repo_id,
1581
+ repo_type="model",
1582
+ hf_token=os.getenv("HF_TOKEN", None))
1583
+ readme_content = get_model_readme_content(
1584
+ template_folder=template_dir,
1585
+
1586
+ model_repo_id=self.model_repo_id,
1587
+
1588
+ base_model_dict=self.hf_base_model_dict,
1589
+ training_dataset_dict=self.dataset_commit_dict,
1590
+
1591
+ version_label=new_model_version_label,
1592
+ commit_datetime=commit_datetime,
1593
+ perf_metrics=self.perf_metrics,
1594
+
1595
+ mf_flow_name=current.flow_name,
1596
+ mf_run_id=current.run.id
1597
+ )
1598
+ #############################
1599
+
1600
+ print("Pushing model version to HF hub " +
1601
+ ("(blessed). " if self.model_version_blessed
1602
+ else "(not blessed). ") +
1603
+ "May take a while..",
1604
+ flush=True)
1605
+ model_commit_hash = push_model_version_to_hub(
1606
+ repo_id=self.model_repo_id,
1607
+ model_version_blessed=\
1608
+ self.model_version_blessed,
1609
+ version_label=new_model_version_label,
1610
+ timestamp_str=commit_datetime.strftime(
1611
+ "%Y-%m-%d %H:%M:%S UTC"),
1612
+ model_dir=self.sft_model_dir,
1613
+ model_readme_content=readme_content,
1614
+ hf_token=os.getenv("HF_TOKEN", None)
1615
+ )
1616
+ if not model_commit_hash:
1617
+ raise Exception(
1618
+ "Failed to publish model version.")
1619
+ print("Push of model version to HF hub completed.",
1620
+ flush=True)
1621
+ print(f"https://huggingface.co/{self.model_repo_id}" +
1622
+ f"/blob/{model_commit_hash}/README.md")
1623
+
1624
+ self.model_commit_dict = {
1625
+ "repo_id": self.model_repo_id,
1626
+ "commit_hash": model_commit_hash,
1627
+ "version_label": new_model_version_label,
1628
+ "commit_datetime": commit_datetime,
1629
+ }
1630
+
1631
+ self.next(self.infra_validator)
1632
+
1633
+
1634
+ @step
1635
+ def infra_validator(self):
1636
+ """
1637
+ If the trained model version is blessed,
1638
+ validate serving.
1639
+ """
1640
+ """
1641
+ Note that using an isolated virtual env
1642
+ (via the @conda task decorator)
1643
+ is advisable, so as not to ship the whole
1644
+ pipeline's dependencies into the local server.
1645
+ We don't here, for educational purposes :
1646
+ it keeps things "simple" to grasp
1647
+ and avoids forcing conda
1648
+ (for instance miniconda) on the user
1649
+ as a virtual-environment
1650
+ management tool.
1651
+ """
1652
+ """
1653
+ Note : We load the base model from the HF cache
1654
+ (mounted as the /huggingface_hub_cache
1655
+ docker volume) and the adapter from a local dir
1656
+ (mounted as the /FuncCallAdapter docker volume).
1657
+ """
1658
+
1659
+ self.local_serve_is_ready = LocalServeReadinessEnum.NOT_APPLICABLE
1660
+
1661
+ if self.model_version_blessed:
1662
+ from retrain_pipelines.utils.docker import \
1663
+ env_has_docker
1664
+
1665
+ if env_has_docker():
1666
+ model_module_dir = \
1667
+ os.path.dirname(
1668
+ importlib.util.find_spec(
1669
+ "retrain_pipelines.model." +
1670
+ os.getenv('retrain_pipeline_type')
1671
+ ).origin)
1672
+
1673
+ # server & data-model & server-config modules artifacts
1674
+ files_to_copy = [
1675
+ "litserve_server.py",
1676
+ "litserve_datamodel.py",
1677
+ "litserve_serverconfig.py",
1678
+ ".dockerignore" # docker context loading
1679
+ # at image-build time,
1680
+ # exclude model weights
1681
+ ]
1682
+ for filename in files_to_copy:
1683
+ shutil.copy(
1684
+ os.path.join(model_module_dir, "litserve",
1685
+ filename),
1686
+ os.path.join(self.serving_artifacts_local_folder,
1687
+ filename)
1688
+ )
1689
+
1690
+ # save dependencies as artifact
1691
+ create_requirements(self.serving_artifacts_local_folder,
1692
+ exclude=["cudf-polars-.*", "cuda-python",
1693
+ "nvidia-.*", "(py)?libcudf-.*",
1694
+ "nvtx", "rmm-.*", "litserve",
1695
+ ".*retrain-pipelines.*"]
1696
+ )
1697
+
1698
+ # server config yaml
1699
+ env = Environment(loader=FileSystemLoader(
1700
+ os.path.join(model_module_dir, "litserve")))
1701
+ template = env.get_template(
1702
+ "litserve_serverconfig_template.yaml")
1703
+ server_config_data = {
1704
+ "port": "8000",
1705
+ "max_seq_length": self.max_seq_length,
1706
+ "max_new_token": self.max_new_tokens,
1707
+ "base_model": {
1708
+ "repo_id": self.hf_base_model_dict["repo_id"],
1709
+ "revision": self.hf_base_model_dict["commit_hash"]
1710
+ },
1711
+ "adapters": [
1712
+ {
1713
+ "name": "func_caller",
1714
+ "path": "/FuncCallAdapter"
1715
+ }
1716
+ ]
1717
+ }
1718
+ server_config_yaml = template.render(server_config_data)
1719
+ print(server_config_yaml)
1720
+ with open(os.path.join(
1721
+ self.serving_artifacts_local_folder,
1722
+ "litserve_serverconfig.yaml"), 'w'
1723
+ ) as output_file:
1724
+ output_file.write(server_config_yaml)
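+ # For reference, the rendered config is expected to look roughly like
+ # (illustrative ; the exact layout depends on the yaml template) :
+ #   port: 8000
+ #   max_seq_length: <self.max_seq_length>
+ #   max_new_token: <self.max_new_tokens>
+ #   base_model:
+ #     repo_id: <base-model repo_id>
+ #     revision: <base-model commit_hash>
+ #   adapters:
+ #     - name: func_caller
+ #       path: /FuncCallAdapter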
1725
+
1726
+ # Dockerfile
1727
+ env = Environment(loader=FileSystemLoader(
1728
+ os.path.join(model_module_dir)))
1729
+ template = env.get_template(
1730
+ "Dockerfile.litserve_template")
1731
+ # Change CUDA version here from available list
1732
+ # @see https://hub.docker.com/r/nvidia/cuda/tags
1733
+ dockerfile_content = template.render(
1734
+ {"cuda_version": "12.0.0"})
1735
+ with open(os.path.join(
1736
+ self.serving_artifacts_local_folder,
1737
+ "Dockerfile.litserve"), 'w'
1738
+ ) as output_file:
1739
+ output_file.write(dockerfile_content)
1740
+
1741
+ os.environ["no_proxy"] = "localhost,127.0.0.1,0.0.0.0"
1742
+
1743
+ ############################################
1744
+ # actually deploy the inference service #
1745
+ ############################################
1746
+ start_time = time.time()
1747
+ from retrain_pipelines.utils.docker import \
1748
+ build_and_run_docker, print_container_log_tail, \
1749
+ cleanup_docker
1750
+ from retrain_pipelines.model.litserve import \
1751
+ endpoint_started, endpoint_is_ready
1752
+
1753
+ self.port = 8765
1754
+ HF_HUB_CACHE = os.path.realpath(os.path.expanduser(
1755
+ os.getenv(
1756
+ "HF_HUB_CACHE",
1757
+ os.path.join(os.getenv("HF_HOME",
1758
+ "~/.cache/huggingface"),
1759
+ "hub")
1760
+ )))
1761
+ print(f"HF_HUB_CACHE : {HF_HUB_CACHE}")
1762
+ image_name = container_name = "litserve-model"
1763
+
1764
+ serving_container = build_and_run_docker(
1765
+ image_name=image_name, image_tag="1.0",
1766
+ build_path=self.serving_artifacts_local_folder,
1767
+ dockerfile="Dockerfile.litserve",
1768
+ ports_publish_dict={'8000/tcp': self.port},
1769
+ env_vars_dict={
1770
+ "HF_HUB_CACHE": "/huggingface_hub_cache",
1771
+ "HF_TOKEN": os.getenv("HF_TOKEN")
1772
+ },
1773
+ volumes_dict={
1774
+ self.sft_model_dir:
1775
+ {"bind": "/FuncCallAdapter",
1776
+ "mode": "ro"},
1777
+ HF_HUB_CACHE:
1778
+ {"bind": "/huggingface_hub_cache",
1779
+ "mode": "ro"}
1780
+ }
1781
+ )
1782
+
1783
+ if not serving_container:
1784
+ print("failed spinning the LitServe container",
1785
+ file=sys.stderr)
1786
+ self.local_serve_is_ready = \
1787
+ LocalServeReadinessEnum.FAILURE
1788
+ try:
1789
+ cleanup_docker(
1790
+ container_name=container_name,
1791
+ image_name=f"{image_name}:1.0",
1792
+ no_pruning=True # for intermediate layers recycling
1793
+ # (during later re-runs)
1794
+ # to avoid long rebuild time
1795
+ # of exactly the same.
1796
+ )
1797
+ except Exception as cleanup_ex:
1798
+ # fail silently
1799
+ pass
1800
+ else:
1801
+ print("Awaiting endpoint launch..")
1802
+ start_time = time.time()
1803
+ if not endpoint_started(
1804
+ container_name, port=self.port, timeout=10*60
1805
+ ):
1806
+ print(
1807
+ f"The endpoint '{container_name}' " +
1808
+ f"did not start.")
1809
+ self.local_serve_is_ready = \
1810
+ LocalServeReadinessEnum.FAILURE
1811
+ # health check on the spun-up endpoint
1812
+ elif endpoint_is_ready(port=self.port):
1813
+ self.local_serve_is_ready = \
1814
+ LocalServeReadinessEnum.SUCCESS
1815
+ elapsed_time = time.time() - start_time
1816
+ print("deploy_local - Elapsed time: " +
1817
+ f"{elapsed_time:.2f} seconds")
1818
+ ############################################
1819
+ else:
1820
+ # env doesn't have docker
1821
+ self.local_serve_is_ready = \
1822
+ LocalServeReadinessEnum.FAILURE_NO_DOCKER
1823
+
1824
+ if LocalServeReadinessEnum.SUCCESS == self.local_serve_is_ready:
1825
+ from retrain_pipelines.model.litserve.litserve_datamodel \
1826
+ import Response
1827
+
1828
+ import requests
1829
+
1830
+ url = f"http://localhost:{self.port}/predict"
1831
+ headers = {"accept": "application/x-www-form-urlencoded"}
1832
+
1833
+ try:
1834
+ start_time = time.time()
1835
+ data = {
1836
+ "adapter_name": "func_caller",
1837
+ "queries": '["Hello.", "Is 49 a perfect square?"]'
1838
+ }
1839
+ print(f"inference test - data: {data}")
1840
+ response = requests.post(url, headers=headers, data=data)
1841
+ parsed_response = Response(**{"output": response.json()})
1842
+ elapsed_time = time.time() - start_time
1843
+ print("parsed_response ('func_caller' adapter ON) :" +
1844
+ str(parsed_response) +
1845
+ f"\t-\tElapsed time: {elapsed_time:.2f} seconds")
1846
+
1847
+ start_time = time.time()
1848
+ data = {
1849
+ "queries": '["Hello.", "Is 49 a perfect square?"]'
1850
+ }
1851
+ print(f"inference test - data: {data}")
1852
+ response = requests.post(url, headers=headers, data=data)
1853
+ parsed_response = Response(**{"output": response.json()})
1854
+ elapsed_time = time.time() - start_time
1855
+ print(f"parsed_response (no adapter) : {parsed_response}" +
1856
+ f"\t-\tElapsed time: {elapsed_time:.2f} seconds")
1857
+
1858
+ except Exception as ex:
1859
+ print(ex, file=sys.stderr)
1860
+ traceback.print_tb(ex.__traceback__, file=sys.stderr)
1861
+ self.local_serve_is_ready = \
1862
+ LocalServeReadinessEnum.FAILURE
1863
+ pass
1864
+
1865
+ try:
1866
+ cleanup_docker(
1867
+ container_name=container_name,
1868
+ image_name=f"{image_name}:1.0",
1869
+ no_pruning=True # for intermediate layers recycling
1870
+ # (during later re-runs)
1871
+ # to avoid long rebuild time
1872
+ # of exactly the same.
1873
+ )
1874
+ except Exception as cleanup_ex:
1875
+ # fail silently
1876
+ pass
1877
+
1878
+ self.next(self.pipeline_card)
1879
+
1880
+
1881
+ @card(id='default')
1882
+ @card(type='html', id='custom')
1883
+ @step
1884
+ def pipeline_card(self):
1885
+ import re
1886
+ import datetime
1887
+ import importlib.metadata
1888
+
1889
+ #############################
1890
+ # case of user-provided #
1891
+ # documentation artifact(s) #
1892
+ #############################
1893
+ # note that user can provide either
1894
+ # 'pipeline_card.py' or 'template.html'
1895
+ # or 'dataset_readme.py'
1896
+ # or 'dataset_readme_template.md'
1897
+ # or 'model_readme.py'
1898
+ # or 'model_readme_template.md'
1899
+ # or any combination of those
1900
+ # when specifying custom
1901
+ # 'pipeline_card_artifacts_path'
1902
+ if "template.html" in os.listdir(
1903
+ self.pipeline_card_artifacts_path
1904
+ ):
1905
+ template_dir = self.pipeline_card_artifacts_path
1906
+ else:
1907
+ template_dir = os.path.dirname(
1908
+ importlib.util.find_spec(
1909
+ f"retrain_pipelines.pipeline_card."+
1910
+ f"{os.getenv('retrain_pipeline_type')}"
1911
+ ).origin)
1912
+ #############################
1913
+ if "pipeline_card.py" in os.listdir(
1914
+ self.pipeline_card_artifacts_path
1915
+ ):
1916
+ from retrain_pipelines.utils import get_get_html
1917
+ get_html = \
1918
+ get_get_html(self.pipeline_card_artifacts_path)
1919
+ else:
1920
+ from retrain_pipelines.pipeline_card import \
1921
+ get_html
1922
+ from retrain_pipelines.pipeline_card.helpers import \
1923
+ mf_dag_svg
1924
+ #############################
1925
+
1926
+
1927
+ #############################
1928
+ ## "default" card ##
1929
+ #############################
1930
+ self.metadata = {
1931
+ "name": "TabNet Model",
1932
+ "version": "1.0",
1933
+ "retrain_pipelines": f"retrain-pipelines {__version__}",
1934
+ "retrain_pipeline_type": os.environ["retrain_pipeline_type"],
1935
+ "description": "A PyTorch TabNet model retrained",
1936
+ "authors": [current.username],
1937
+ "tags": ["classification", "tabnet"],
1938
+ "license": "MIT License",
1939
+ "data_augmentation": [
1940
+ {
1941
+ "name": "Augmentation",
1942
+ "description": "Truncating queries and " + \
1943
+ "associate those to " + \
1944
+ "no tool-call answers. " + \
1945
+ "Intent being to instruct on " + \
1946
+ "not hallucinating missing " + \
1947
+ "tool-calls parameters values."
1948
+ },
1949
+ {
1950
+ "name": "Enrichment",
1951
+ "description": "Addition of records " + \
1952
+ "from an external data-source. " + \
1953
+ "Here to instruct on no tool-call."
1954
+ }
1955
+ ],
1956
+ "references": [
1957
+ {
1958
+ "title": "Base model",
1959
+ "link": f"https://hf.co/{self.hf_base_model_dict['repo_id']}"
1960
+ },
1961
+ {
1962
+ "title": "Function-calling dataset",
1963
+ "link": f"https://hf.co/{self.hf_dataset_dict['repo_id']}"
1964
+ },
1965
+ {
1966
+ "title": "Data-enrichment dataset",
1967
+ "link": f"https://hf.co/{self.hf_enrich_dataset_dict['repo_id']}"
1968
+ },
1969
+ {
1970
+ "title": "Unsloth",
1971
+ "link": "https://unsloth.ai/blog/contpretraining"
1972
+ }
1973
+ ]
1974
+ }
1975
+
1976
+ current.card['default'].append(Markdown(
1977
+ "model_version_blessed : **%s**" % str(self.model_version_blessed)))
1978
+ current.card['default'].append(Artifact(
1979
+ {"model_version_blessed": self.model_version_blessed}))
1980
+
1981
+ current.card['default'].append(
1982
+ Image.from_matplotlib(self.sft_log_history_fig))
1983
+ current.card['default'].append(
1984
+ Image.from_matplotlib(self.validation_completions_fig))
1985
+ #############################
1986
+
1987
+ #############################
1988
+ ## html "custom" card ##
1989
+ #############################
1990
+ dt = datetime.datetime.now(tz=datetime.timezone.utc)
1991
+ formatted_dt = dt.strftime("%A %b %d %Y %I:%M:%S %p %Z")
1992
+ task_obj_python_cmd = f"metaflow.Task(" + \
1993
+ f"\"{current.pathspec}\", " + \
1994
+ f"attempt={str(current.retry_count)})"
1995
+ params={
1996
+ 'template_dir': template_dir,
1997
+ 'title': f"{current.flow_name}",
1998
+ "subtitle": f"(flow run # {len(list(current.run.parent.runs()))}," + \
1999
+ f" run_id: {str(current.run.id)} - {formatted_dt})",
2000
+
2001
+ # blessed status / current_blessed version
2002
+ 'model_version_blessed': self.model_version_blessed,
2003
+ 'current_blessed_version_label': (
2004
+ self.current_blessed_version_dict["version_label"]
2005
+ if self.current_blessed_version_dict
2006
+ else None
2007
+ ),
2008
+ 'current_blessed_commit_datetime': (
2009
+ self.current_blessed_version_dict["commit_datetime"]
2010
+ if self.current_blessed_version_dict
2011
+ else None
2012
+ ),
2013
+ 'current_blessed_model_commit_hash': (
2014
+ self.current_blessed_version_dict["commit_hash"]
2015
+ if self.current_blessed_version_dict
2016
+ else None
2017
+ ),
2018
+ 'current_blessed_run': self.current_blessed_run,
2019
+
2020
+ 'LocalServeReadinessEnum': LocalServeReadinessEnum,
2021
+ 'local_serve_is_ready': self.local_serve_is_ready,
2022
+ # EDA
2023
+ 'main_dataset_repo_id': self.hf_dataset['repo_id'],
2024
+ 'main_dataset_commit_hash': self.hf_dataset_dict['commit_hash'],
2025
+ 'main_dataset_commit_datetime': \
2026
+ self.hf_dataset_dict['commit_datetime'],
2027
+
2028
+ 'records_count': self.records_count,
2029
+ 'data_schema': self.data_schema,
2030
+ 'answers_tools_count_fig': self.answers_tools_count_fig,
2031
+ 'words_count_fig': self.words_count_fig,
2032
+
2033
+ # model training
2034
+ 'dataset_repo_id': self.dataset_repo_id,
2035
+ 'dataset_version_label': self.dataset_commit_dict["version_label"],
2036
+ 'dataset_commit_datetime': self.dataset_commit_dict["commit_datetime"],
2037
+ 'dataset_commit_hash': self.dataset_commit_dict["commit_hash"],
2038
+ 'dataset_augmentation_rate': self.actual_augmentation_rate,
2039
+ 'dataset_enrichment_rate': self.enrichment_rate,
2040
+
2041
+ 'model_repo_id': self.model_repo_id,
2042
+ 'model_version_label': self.model_commit_dict["version_label"],
2043
+ 'model_commit_datetime': self.model_commit_dict["commit_datetime"],
2044
+ 'model_commit_hash': self.model_commit_dict["commit_hash"],
2045
+
2046
+ 'cpt_log_history_fig': self.cpt_log_history_fig,
2047
+ 'sft_log_history_fig': self.sft_log_history_fig,
2048
+
2049
+ 'validation_completions_fig': self.validation_completions_fig,
2050
+
2051
+ 'pipeline_parameters_dict': {"cpt": self.cpt_training_args,
2052
+ "sft": self.sft_training_args},
2053
+
2054
+ 'metrics_dict': self.perf_metrics,
2055
+
2056
+ 'task_obj_python_cmd': task_obj_python_cmd,
2057
+ 'dag_svg': mf_dag_svg(self)
2058
+ }
2059
+ self.html = get_html(params)
2060
+ #############################
2061
+ current
2062
+ #############################
2063
+
2064
+ self.next(self.pipeline_to_hub)
2065
+
2066
+
2067
+ @step
2068
+ def pipeline_to_hub(self):
2069
+ """
2070
+ Publish versioned source-code and pipeline-card
2071
+ for this run on the Hugging Face Hub.
2072
+ """
2073
+
2074
+ model_commit_datetime = \
2075
+ self.model_commit_dict["commit_datetime"]
2076
+ timestamp_str = \
2077
+ "{:%Y%m%d_%H%M%S}".format(model_commit_datetime) + \
2078
+ "{:03d}".format(model_commit_datetime.microsecond//1000) + \
2079
+ "_UTC"
2080
+ subfolder_name = \
2081
+ "v" + self.model_commit_dict["version_label"] + \
2082
+ "_" + timestamp_str
2083
+ commit_datetime = datetime.utcnow()
2084
+
2085
+ ###############################
2086
+ # source-code #
2087
+ ###############################
2088
+ # We upload only herein file #
2089
+ # plus user-provided versions #
2090
+ # of the customizable ones #
2091
+ # (if any). #
2092
+ ###############################
2093
+ custom_source_files = [os.path.abspath(__file__)]
2094
+ if (
2095
+ self.pipeline_card_artifacts_path != \
2096
+ self.default_pipeline_card_module_dir
2097
+ ):
2098
+ candidate_source_files = [
2099
+ "pipeline_card.py",
2100
+ "template.html",
2101
+ "dataset_readme.py",
2102
+ "dataset_readme_template.md",
2103
+ "model_readme.py",
2104
+ "model_readme_template.md"
2105
+ ]
2106
+ for candidate_source_file in candidate_source_files:
2107
+ file_fullpath = os.path.join(
2108
+ self.pipeline_card_artifacts_path,
2109
+ candidate_source_file)
2110
+ if os.path.exists(file_fullpath):
2111
+ custom_source_files.append(file_fullpath)
2112
+
2113
+ source_code_commit_hash = \
2114
+ push_files_to_hub_repo_branch(
2115
+ repo_id=self.model_repo_id,
2116
+ branch_name="retrain-pipelines_source-code",
2117
+ file_fullnames=custom_source_files,
2118
+ include_requirements_txt=True,
2119
+ path_in_repo=subfolder_name,
2120
+ commit_message=\
2121
+ "source-code for model version " + \
2122
+ subfolder_name + \
2123
+ f"- retrain-pipelines {__version__}",
2124
+ repo_type="model",
2125
+ hf_token=os.getenv("HF_TOKEN", None)
2126
+ )
2127
+ print(source_code_commit_hash)
2128
+ self.source_code_commit_dict = {
2129
+ "repo_id": self.model_repo_id,
2130
+ "branch_name": "retrain-pipelines_source-code",
2131
+ "commit_datetime": commit_datetime,
2132
+ "commit_hash": source_code_commit_hash
2133
+ }
2134
+ ###############################
2135
+
2136
+ ###############################
2137
+ # pipeline-card #
2138
+ ###############################
2139
+ pipeline_card_fullname = None
2140
+ for run_step in current.run.steps():
2141
+ task = list(run_step.tasks())[0]
2142
+ task_name = task.path_components[2]
2143
+ if "pipeline_card" == task_name:
2144
+ pipeline_card = get_cards(
2145
+ task, id='custom', type='html')[0]
2146
+ pipeline_card_fullname = os.path.realpath(
2147
+ os.path.join(
2148
+ task.metadata_dict.get("ds-root", None),
2149
+ mf_config.CARD_SUFFIX, pipeline_card.path
2150
+ ))
2151
+ print(pipeline_card_fullname)
2152
+ break
2153
+ pipeline_card_commit_hash = \
2154
+ push_files_to_hub_repo_branch(
2155
+ repo_id=self.model_repo_id,
2156
+ branch_name="retrain-pipelines_pipeline-card",
2157
+ file_fullnames=[pipeline_card_fullname],
2158
+ path_in_repo=subfolder_name,
2159
+ commit_message=\
2160
+ "pipeline-card for model version " + \
2161
+ subfolder_name + \
2162
+ f"- retrain-pipelines {__version__}",
2163
+ repo_type="model",
2164
+ hf_token=os.getenv("HF_TOKEN", None)
2165
+ )
2166
+ print(pipeline_card_commit_hash)
2167
+ self.pipeline_card_commit_dict = {
2168
+ "repo_id": self.model_repo_id,
2169
+ "branch_name": "retrain-pipelines_pipeline-card",
2170
+ "commit_datetime": commit_datetime,
2171
+ "commit_hash": pipeline_card_commit_hash
2172
+ }
2173
+ ###############################
2174
+
2175
+ self.next(self.deploy)
2176
+
2177
+
2178
+ @step
2179
+ def deploy(self):
2180
+ """
2181
+ Placeholder for the serving SDK deploy call
2182
+ (on the target production platform).
2183
+ Include any artifact you want ;
2184
+ consider including the portable pipeline-card
2185
+ itself !
2186
+ """
2187
+
2188
+ if (
2189
+ self.model_version_blessed and
2190
+ (self.local_serve_is_ready == LocalServeReadinessEnum.SUCCESS)
2191
+ ):
2192
+ pass # your code here
2193
+
2194
+ self.next(self.load_test)
2195
+
2196
+
2197
+ @step
2198
+ def load_test(self):
2199
+ """
2200
+ Placeholder (e.g. for load-testing the deployed endpoint).
2201
+ """
2202
+
2203
+ if (
2204
+ self.model_version_blessed and
2205
+ (self.local_serve_is_ready == LocalServeReadinessEnum.SUCCESS)
2206
+ ):
2207
+ pass # your code here
2208
+
2209
+ self.next(self.end)
2210
+
2211
+
2212
+ @step
2213
+ def end(self):
2214
+ pass
2215
+
2216
+
2217
+ if __name__ == "__main__":
2218
+ UnslothFuncCallFlow()
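+ # Typically launched via the Metaflow CLI, e.g. (illustrative) :
+ #   python <this_flow_file>.py run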
2219
+