source-code for model version v0.18_20250323_235054255_UTC - retrain-pipelines 0.1.1
v0.18_20250323_235054255_UTC/requirements.txt (ADDED, 628 lines)
absl-py==1.4.0
accelerate==1.5.2
aiohappyeyeballs==2.6.1
aiohttp==3.11.14
aiosignal==1.3.2
alabaster==1.0.0
albucore==0.0.23
albumentations==2.0.5
ale-py==0.10.2
altair==5.5.0
annotated-types==0.7.0
anyio==4.9.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
array_record==0.7.1
arviz==0.21.0
astropy==7.0.1
astropy-iers-data==0.2025.3.17.0.34.53
astunparse==1.6.3
atpublic==5.1
attrs==25.3.0
audioread==3.0.1
autograd==1.7.0
babel==2.17.0
backcall==0.2.0
beautifulsoup4==4.13.3
betterproto==2.0.0b6
bigframes==1.40.0
bigquery-magics==0.8.0
bitsandbytes==0.45.3
bleach==6.2.0
blinker==1.9.0
blis==1.2.0
blosc2==3.2.0
bokeh==3.6.3
boto3==1.37.18
botocore==1.37.18
Bottleneck==1.4.2
bqplot==0.12.44
branca==0.8.1
CacheControl==0.14.2
cachetools==5.5.2
catalogue==2.0.10
certifi==2025.1.31
cffi==1.17.1
chardet==5.2.0
charset-normalizer==3.4.1
chex==0.1.89
clarabel==0.10.0
click==8.1.8
cloudpathlib==0.21.0
cloudpickle==3.1.1
cmake==3.31.6
cmdstanpy==1.2.5
colorama==0.4.6
colorcet==3.1.0
colorlover==0.3.0
colour==0.1.5
comm==0.2.2
community==1.0.0b1
confection==0.1.5
cons==0.4.6
contourpy==1.3.1
cramjam==2.9.1
cryptography==43.0.3
cuda-python==12.6.0
cudf-cu12 @ https://pypi.nvidia.com/cudf-cu12/cudf_cu12-25.2.1-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
cudf-polars-cu12==24.12.0
cufflinks==0.17.3
cuml-cu12==25.2.1
cupy-cuda12x==13.3.0
cut-cross-entropy==25.1.1
cuvs-cu12==25.2.1
cvxopt==1.3.2
cvxpy==1.6.4
cycler==0.12.1
cyipopt==1.5.0
cymem==2.0.11
Cython==3.0.12
dask==2024.12.1
dask-cuda==25.2.0
dask-cudf-cu12==25.2.2
dask-expr==1.1.21
datascience==0.17.6
datasets==3.1.0
db-dtypes==1.4.2
dbus-python==1.2.18
debugpy==1.8.0
decorator==4.4.2
defusedxml==0.7.1
Deprecated==1.2.18
diffusers==0.32.2
dill==0.3.8
distributed==2024.12.1
distributed-ucxx-cu12==0.42.0
distro==1.9.0
dlib==19.24.2
dm-tree==0.1.9
docker==7.1.0
docker-pycreds==0.4.0
docstring_parser==0.16
docutils==0.21.2
dopamine_rl==4.1.2
duckdb==1.2.1
earthengine-api==1.5.7
easydict==1.13
editdistance==0.8.1
eerepr==0.1.1
einops==0.8.1
en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl#sha256=1932429db727d4bff3deed6b34cfc05df17794f4a52eeb26cf8928f7c1a0fb85
entrypoints==0.4
et_xmlfile==2.0.0
etils==1.12.2
etuples==0.3.9
Farama-Notifications==0.0.4
fastai==2.7.19
fastapi==0.115.11
fastcore==1.7.29
fastdownload==0.0.7
fastjsonschema==2.21.1
fastprogress==1.0.3
fastrlock==0.8.3
filelock==3.18.0
firebase-admin==6.7.0
Flask==3.1.0
flatbuffers==25.2.10
flax==0.10.4
folium==0.19.5
fonttools==4.56.0
frozendict==2.4.6
frozenlist==1.5.0
fsspec==2024.9.0
future==1.0.0
gast==0.6.0
GDAL==3.6.4
gdown==5.2.0
geemap==0.35.3
geocoder==1.38.1
geographiclib==2.0
geopandas==1.0.1
geopy==2.4.1
gin-config==0.5.0
gitdb==4.0.12
GitPython==3.1.44
glob2==0.7
google==2.0.3
google-ai-generativelanguage==0.6.15
google-api-core==2.24.2
google-api-python-client==2.164.0
google-auth==2.38.0
google-auth-httplib2==0.2.0
google-auth-oauthlib==1.2.1
google-cloud-aiplatform==1.84.0
google-cloud-bigquery==3.29.0
google-cloud-bigquery-connection==1.18.2
google-cloud-bigquery-storage==2.29.1
google-cloud-bigtable==2.29.0
google-cloud-core==2.4.3
google-cloud-dataproc==5.18.1
google-cloud-datastore==2.20.2
google-cloud-firestore==2.20.1
google-cloud-functions==1.20.2
google-cloud-iam==2.18.2
google-cloud-language==2.17.1
google-cloud-pubsub==2.28.0
google-cloud-resource-manager==1.14.2
google-cloud-spanner==3.53.0
google-cloud-storage==2.19.0
google-cloud-translate==3.20.2
google-colab @ file:///colabtools/dist/google_colab-1.0.0.tar.gz
google-crc32c==1.7.0
google-genai==1.5.0
google-generativeai==0.8.4
google-pasta==0.2.0
google-resumable-media==2.7.2
google-spark-connect==0.5.2
googleapis-common-protos==1.69.2
googledrivedownloader==1.1.0
graphviz==0.20.3
greenlet==3.1.1
grpc-google-iam-v1==0.14.2
grpc-interceptor==0.15.4
grpcio==1.71.0
grpcio-status==1.71.0
grpclib==0.4.7
gspread==6.2.0
gspread-dataframe==4.0.0
gym==0.25.2
gym-notices==0.0.8
gymnasium==1.1.1
h11==0.14.0
h2==4.2.0
h5netcdf==1.6.1
h5py==3.13.0
hdbscan==0.8.40
hf_transfer==0.1.9
highspy==1.9.0
holidays==0.69
holoviews==1.20.2
hpack==4.1.0
html5lib==1.1
httpcore==1.0.7
httpimport==1.4.1
httplib2==0.22.0
httptools==0.6.4
httpx==0.28.1
huggingface-hub==0.27.1
humanize==4.12.1
hyperframe==6.1.0
hyperopt==0.2.7
ibis-framework==9.5.0
idna==3.10
imageio==2.37.0
imageio-ffmpeg==0.6.0
imagesize==1.4.1
imbalanced-learn==0.13.0
immutabledict==4.2.1
importlib_metadata==8.6.1
importlib_resources==6.5.2
imutils==0.5.4
inflect==7.5.0
iniconfig==2.0.0
intel-cmplr-lib-ur==2025.0.5
intel-openmp==2025.0.5
ipyevents==2.0.2
ipyfilechooser==0.6.0
ipykernel==6.29.5
ipyleaflet==0.19.2
ipyparallel==8.8.0
ipython==7.34.0
ipython-genutils==0.2.0
ipython-sql==0.5.0
ipytree==0.2.2
ipywidgets==7.7.1
itsdangerous==2.2.0
jax==0.5.2
jax-cuda12-pjrt==0.5.1
jax-cuda12-plugin==0.5.1
jaxlib==0.5.1
jedi==0.19.2
jeepney==0.7.1
jellyfish==1.1.0
jieba==0.42.1
Jinja2==3.1.4
jiter==0.9.0
jmespath==1.0.1
joblib==1.4.2
jsonpatch==1.33
jsonpickle==4.0.2
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
jupyter-client==6.1.12
jupyter-console==6.1.0
jupyter-leaflet==0.19.2
jupyter-server==1.16.0
jupyter_core==5.7.2
jupyterlab_pygments==0.3.0
jupyterlab_widgets==3.0.13
kaggle==1.7.4.2
kagglehub==0.3.10
keras==3.8.0
keras-hub==0.18.1
keras-nlp==0.18.1
keyring==23.5.0
kiwisolver==1.4.8
langchain==0.3.20
langchain-core==0.3.45
langchain-text-splitters==0.3.6
langcodes==3.5.0
langsmith==0.3.15
language_data==1.3.0
launchpadlib==1.10.16
lazr.restfulclient==0.14.4
lazr.uri==1.0.6
lazy_loader==0.4
libclang==18.1.1
libcudf-cu12==24.12.0
libcugraph-cu12==25.2.0
libcuml-cu12==25.2.1
libcuvs-cu12==25.2.1
libkvikio-cu12==24.12.1
libraft-cu12==25.2.0
librosa==0.11.0
libucx-cu12==1.18.0
libucxx-cu12==0.42.0
lightgbm==4.5.0
linkify-it-py==2.0.3
litserve==0.2.6
llvmlite==0.43.0
locket==1.0.0
logical-unification==0.4.6
lxml==5.3.0
Mako==1.1.3
marisa-trie==1.2.1
Markdown==3.7
markdown-it-py==3.0.0
MarkupSafe==3.0.2
matplotlib==3.9.2
matplotlib-inline==0.1.7
matplotlib-venn==1.1.2
mdit-py-plugins==0.4.2
mdurl==0.1.2
metaflow==2.10.0
metaflow-card-html==1.0.2
miniKanren==1.0.3
missingno==0.5.2
mistune==3.1.2
mizani==0.13.1
mkl==2025.0.1
ml-dtypes==0.4.1
mlxtend==0.23.4
more-itertools==10.6.0
moviepy==1.0.3
mpmath==1.3.0
msgpack==1.1.0
multidict==6.2.0
multipledispatch==1.0.0
multiprocess==0.70.16
multitasking==0.0.11
murmurhash==1.0.12
music21==9.3.0
namex==0.0.8
narwhals==1.31.0
natsort==8.4.0
nbclassic==1.2.0
nbclient==0.10.2
nbconvert==7.16.6
nbformat==5.10.4
ndindex==1.9.2
nest-asyncio==1.6.0
networkx==3.2.1
nibabel==5.3.2
nltk==3.9.1
notebook==6.5.7
notebook_shim==0.2.4
numba==0.60.0
numba-cuda==0.2.0
numexpr==2.10.2
numpy==1.26.4
nvidia-cublas-cu12==12.4.5.8
nvidia-cuda-cupti-cu12==12.4.127
nvidia-cuda-nvcc-cu12==12.5.82
nvidia-cuda-nvrtc-cu12==12.4.127
nvidia-cuda-runtime-cu12==12.4.127
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparse-cu12==12.3.1.170
nvidia-cusparselt-cu12==0.6.2
nvidia-ml-py==12.570.86
nvidia-nccl-cu12==2.21.5
nvidia-nvcomp-cu12==4.1.0.6
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127
nvtx==0.2.11
nx-cugraph-cu12 @ https://pypi.nvidia.com/nx-cugraph-cu12/nx_cugraph_cu12-25.2.0-py3-none-any.whl
oauth2client==4.1.3
oauthlib==3.2.2
openai==1.66.3
opencv-contrib-python==4.11.0.86
opencv-python==4.11.0.86
opencv-python-headless==4.11.0.86
openpyxl==3.1.5
opentelemetry-api==1.31.0
opentelemetry-sdk==1.31.0
opentelemetry-semantic-conventions==0.52b0
opt_einsum==3.4.0
optax==0.2.4
optree==0.14.1
orbax-checkpoint==0.11.9
orjson==3.10.15
osqp==0.6.7.post3
packaging==24.2
pandas==2.2.2
pandas-datareader==0.10.0
pandas-gbq==0.28.0
pandas-stubs==2.2.2.240909
pandocfilters==1.5.1
panel==1.6.1
param==2.2.0
parso==0.8.4
parsy==2.1
partd==1.4.2
pathlib==1.0.1
patsy==1.0.1
peewee==3.17.9
peft==0.14.0
pexpect==4.9.0
pickleshare==0.7.5
pillow==11.1.0
platformdirs==4.3.6
plotly==5.24.1
plotnine==0.14.5
pluggy==1.5.0
ply==3.11
polars==1.11.0
pooch==1.8.2
portpicker==1.5.2
preshed==3.0.9
prettytable==3.15.1
proglog==0.1.10
progressbar2==4.5.0
prometheus_client==0.21.1
promise==2.3
prompt_toolkit==3.0.50
propcache==0.3.0
prophet==1.1.6
proto-plus==1.26.1
protobuf==3.20.3
psutil==5.9.5
psycopg2==2.9.10
ptyprocess==0.7.0
py-cpuinfo==9.0.0
py4j==0.10.9.7
pyarrow==17.0.0
pyasn1==0.6.1
pyasn1_modules==0.4.1
pycairo==1.27.0
pycocotools==2.0.8
pycparser==2.22
pydantic==2.9.2
pydantic_core==2.23.4
pydata-google-auth==1.9.1
pydot==1.4.2
pydotplus==2.0.2
PyDrive==1.3.1
PyDrive2==1.21.3
pyerfa==2.0.1.5
pygame==2.6.1
pygit2==1.17.0
Pygments==2.18.0
PyGObject==3.42.0
PyJWT==2.10.1
pylibcudf-cu12==24.12.0
pylibcugraph-cu12==25.2.0
pylibraft-cu12==25.2.0
pymc==5.21.1
pymystem3==0.2.0
pynndescent==0.5.13
pynvjitlink-cu12==0.5.2
pynvml==12.0.0
pyogrio==0.10.0
Pyomo==6.8.2
PyOpenGL==3.1.9
pyOpenSSL==24.2.1
pyparsing==3.2.1
pyperclip==1.9.0
pyproj==3.7.1
pyshp==2.3.1
PySocks==1.7.1
pyspark==3.5.5
pytensor==2.28.3
pytest==8.3.3
python-apt==0.0.0
python-box==7.3.2
python-dateutil==2.8.2
python-dotenv==1.0.1
python-louvain==0.16
python-multipart==0.0.20
python-slugify==8.0.4
python-snappy==0.7.3
python-utils==3.9.1
pytz==2025.1
pyviz_comms==3.0.4
PyYAML==6.0.2
pyzmq==24.0.1
qdldl==0.1.7.post5
raft-dask-cu12==25.2.0
rapids-dask-dependency==25.2.0
ratelim==0.1.6
referencing==0.36.2
regex==2024.11.6
requests==2.32.3
requests-oauthlib==2.0.0
requests-toolbelt==1.0.0
requirements-parser==0.9.0
retrain_pipelines @ git+https://github.com/aurelienmorgan/retrain-pipelines.git@9a5b7f7992744e6739bc28700f3d5d8915796d71#subdirectory=pkg_src
rich==13.9.4
rmm-cu12==24.12.0
roman-numerals-py==3.1.0
rpds-py==0.23.1
rpy2==3.5.17
rsa==4.9
s3transfer==0.11.4
safetensors==0.5.3
scikit-image==0.25.2
scikit-learn==1.6.1
scipy==1.14.1
scooby==0.10.0
scs==3.2.7.post2
seaborn==0.13.2
SecretStorage==3.3.1
Send2Trash==1.8.3
sentence-transformers==3.4.1
sentencepiece==0.2.0
sentry-sdk==2.23.1
setproctitle==1.3.5
shap==0.47.0
shapely==2.0.7
shellingham==1.5.4
shtab==1.7.1
simple-parsing==0.1.7
simplejson==3.20.1
simsimd==6.2.1
six==1.17.0
sklearn-compat==0.1.3
sklearn-pandas==2.2.0
slicer==0.0.8
smart-open==7.1.0
smmap==5.0.2
sniffio==1.3.1
snowballstemmer==2.2.0
sortedcontainers==2.4.0
soundfile==0.13.1
soupsieve==2.6
soxr==0.5.0.post1
spacy==3.8.4
spacy-legacy==3.0.12
spacy-loggers==1.0.5
spanner-graph-notebook==1.1.3
Sphinx==8.2.3
sphinxcontrib-applehelp==2.0.0
sphinxcontrib-devhelp==2.0.0
sphinxcontrib-htmlhelp==2.1.0
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==2.0.0
sphinxcontrib-serializinghtml==2.0.0
SQLAlchemy==2.0.39
sqlglot==25.20.2
sqlparse==0.5.3
srsly==2.5.1
stanio==0.5.1
starlette==0.46.1
statsmodels==0.14.4
stringzilla==3.12.3
sympy==1.13.1
tables==3.10.2
tabulate==0.9.0
tbb==2022.0.0
tblib==3.0.0
tcmlib==1.2.0
tenacity==9.0.0
tensorboard==2.18.0
tensorboard-data-server==0.7.2
tensorflow==2.18.0
tensorflow-datasets==4.9.8
tensorflow-hub==0.16.1
tensorflow-io-gcs-filesystem==0.37.1
tensorflow-metadata==1.16.1
tensorflow-probability==0.25.0
tensorflow-text==2.18.1
tensorstore==0.1.72
termcolor==2.5.0
terminado==0.18.1
text-unidecode==1.3
textblob==0.19.0
tf-slim==1.1.0
tf_keras==2.18.0
thinc==8.3.4
threadpoolctl==3.6.0
tifffile==2025.3.13
timm==1.0.15
tinycss2==1.4.0
tokenizers==0.20.3
toml==0.10.2
toolz==0.12.1
torch==2.5.0
torchsummary==1.5.1
torchvision==0.20.0
tornado==6.4.2
tqdm==4.67.1
traitlets==5.7.1
traittypes==0.2.1
transformers==4.46.2
treelite==4.4.1
treescope==0.1.9
triton==3.1.0
trl==0.12.0
tweepy==4.15.0
typeguard==4.4.2
typer==0.15.2
types-pytz==2025.1.0.20250318
types-setuptools==76.0.0.20250313
typing_extensions==4.12.2
tyro==0.9.17
tzdata==2025.1
tzlocal==5.3.1
uc-micro-py==1.0.3
ucx-py-cu12==0.42.0
ucxx-cu12==0.42.0
umap-learn==0.5.7
umf==0.9.1
unsloth @ git+https://github.com/unslothai/unsloth.git@3a1e7ef8299f3c96fa6e8de11fd0772af3cbc83f
unsloth_zoo==2024.11.4
uritemplate==4.1.1
urllib3==2.3.0
uvicorn==0.34.0
uvloop==0.21.0
vega-datasets==0.9.0
wadllib==1.3.6
wandb==0.19.8
wasabi==1.1.3
watchfiles==1.0.4
wcwidth==0.2.13
weasel==0.4.1
webcolors==24.11.1
webencodings==0.5.1
websocket-client==1.8.0
websockets==14.2
Werkzeug==3.1.3
widgetsnbextension==3.6.10
wordcloud==1.9.4
wrapt==1.17.2
xarray==2025.1.2
xarray-einstats==0.8.0
xformers==0.0.28.post2
xgboost==2.1.4
xlrd==2.0.1
xxhash==3.5.0
xyzservices==2025.1.0
yarl==1.18.3
yellowbrick==1.5
yfinance==0.2.54
zict==3.0.0
zipp==3.21.0
zstandard==0.23.0
v0.18_20250323_235054255_UTC/retraining_pipeline.py (ADDED, 2,219 lines; excerpt below)

from unsloth import FastLanguageModel, \
    is_bfloat16_supported, UnslothTrainer, \
    UnslothTrainingArguments

import torch

import os
import sys

import gc
import json
import time
import shutil
import logging
import traceback
import subprocess
import importlib.util
from enum import Enum
from io import StringIO
from textwrap import dedent
from datetime import datetime
from contextlib import redirect_stdout

import numpy as np
import pandas as pd

import polars as pl
from polars.exceptions import ComputeError

import matplotlib
import matplotlib.pyplot as plt

from jinja2 import Environment, FileSystemLoader

from metaflow import FlowSpec, step, Parameter, JSONType, \
    IncludeFile, current, metaflow_config as mf_config, \
    resources, Flow, Task, card
from metaflow.current import Current
from metaflow.cards import Image, Table, Markdown, \
    Artifact, get_cards

from datasets import load_dataset, Dataset, DatasetDict
from datasets.config import HF_DATASETS_CACHE, HF_CACHE_HOME
from huggingface_hub import list_repo_commits
from transformers import AutoTokenizer
from transformers.utils import logging as hf_logging

from retrain_pipelines import __version__
from retrain_pipelines.dataset.hf_utils import get_lazy_df, \
    get_column_info, iterable_dataset_multi_buffer_sampler, \
    push_dataset_version_to_hub
from retrain_pipelines.dataset.tool_calls import \
    get_unique_tools, count_tool_occurrences, \
    plot_tools_occurences, column_words_stats, \
    plot_words_count
from retrain_pipelines.utils.hf_utils import \
    get_new_repo_minor_version, push_files_to_hub_repo_branch
from retrain_pipelines.utils import create_requirements


class LocalServeReadinessEnum(Enum):
    """
    Tracks local-serve (infra-validation)
    status using a "3+"-states enum:
    - "-1" for "not applicable"
      (i.e. "model version not blessed"),
    - "0/1" bool for failure/success.
    """
    NOT_APPLICABLE = -1
    FAILURE = 0
    FAILURE_NO_DOCKER = 2
    SUCCESS = 1

+
class UnslothFuncCallFlow(FlowSpec):
|
77 |
+
"""
|
78 |
+
Training pipeline
|
79 |
+
"""
|
80 |
+
# @see https://github.com/unslothai/unsloth/wiki
|
81 |
+
|
82 |
+
#--- flow parameters -------------------------------------------------------
|
83 |
+
|
84 |
+
RETRAIN_PIPELINE_TYPE = "mf_unsloth_func_call_litserve"
|
85 |
+
# in order to share the config across subprocesses
|
86 |
+
os.environ["retrain_pipeline_type"] = RETRAIN_PIPELINE_TYPE
|
87 |
+
|
88 |
+
hf_dataset = Parameter(
|
89 |
+
"hf_dataset",
|
90 |
+
help="dict with 'repo_id' and 'commit_hash' keys. " + \
|
91 |
+
"if 'commit_hash is None, falls back to latest version " +\
|
92 |
+
"of the dataset available in parquet format.\n" +
|
93 |
+
"Note that there are 3 required 'attributes' of type " + \
|
94 |
+
"str, list[str], list[str]",
|
95 |
+
type=JSONType,
|
96 |
+
default=dedent("""{
|
97 |
+
"repo_id": "Salesforce/xlam-function-calling-60k",
|
98 |
+
"config_name": "",
|
99 |
+
"commit_hash": "",
|
100 |
+
"attributes": {
|
101 |
+
"query_attr": "query",
|
102 |
+
"answers_attr": "answers",
|
103 |
+
"tools_attr": "tools"
|
104 |
+
}
|
105 |
+
}""").replace("'", '"').strip('"')
|
106 |
+
)
|
107 |
+
|
108 |
+
augmentation_rate = Parameter(
|
109 |
+
"augmentation_rate",
|
110 |
+
type=float,
|
111 |
+
default=.05,
|
112 |
+
help="proportion of records to be augmented "+\
|
113 |
+
"(x% of original dataset is created"+\
|
114 |
+
" as additional augmented datapoints), i.e. "+\
|
115 |
+
"truncated queries to serve as negative examples, "+\
|
116 |
+
"meaning they trigger no tool call "+\
|
117 |
+
"due to info incompleteness."
|
118 |
+
)
|
119 |
+
|
120 |
+
hf_enrich_dataset = Parameter(
|
121 |
+
"hf_enrich_dataset",
|
122 |
+
help="dict with 'repo_id', 'config_name' and 'commit_hash', "+\
|
123 |
+
"query_attribute' and 'query_attribute_handler' keys. "+\
|
124 |
+
"if 'commit_hash is None, falls back to latest version "+\
|
125 |
+
"of the dataset available in parquet format."+\
|
126 |
+
"'query_attribute' depicts the dataset attribute "+\
|
127 |
+
"from which 'queries' are to be sampled."+\
|
128 |
+
"'query_attribute_handler' serves for attributes "+\
|
129 |
+
"that have complex structure, "+\
|
130 |
+
"other than 'string' datatype.",
|
131 |
+
type=JSONType,
|
132 |
+
# @see https://huggingface.co/datasets/google-research-datasets/natural_questions
|
133 |
+
default=dedent("""{
|
134 |
+
"repo_id": "lighteval/natural_questions_clean",
|
135 |
+
"config_name": "",
|
136 |
+
"commit_hash": "",
|
137 |
+
"query_attribute": "question",
|
138 |
+
"query_attribute_handler": "lambda x: x"
|
139 |
+
}""").replace("'", '"').strip('"')
|
140 |
+
)
|
141 |
+
|
142 |
+
enrichment_rate = Parameter(
|
143 |
+
"enrichment_rate",
|
144 |
+
type=float,
|
145 |
+
default=.1,
|
146 |
+
help="proportion of records "+\
|
147 |
+
"to be added from the 'hf_enrich_dataset'"+\
|
148 |
+
"(x% of original dataset is sampled and"+\
|
149 |
+
" added as enriching datapoints), i.e. "+\
|
150 |
+
"queries to serve as negative examples, "+\
|
151 |
+
"due to their complete disconnexion "+\
|
152 |
+
"to tool calling situations."
|
153 |
+
)
|
154 |
+
|
155 |
+
dataset_repo_id = Parameter(
|
156 |
+
"dataset_repo_id",
|
157 |
+
type=str,
|
158 |
+
default="retrain-pipelines/func_calls",
|
159 |
+
help="The 'repo_id' to be used " + \
|
160 |
+
"for the Hugging Face dataset version push " + \
|
161 |
+
"(will be created at runtime" + \
|
162 |
+
" if doesn't already exist)."
|
163 |
+
)
|
164 |
+
|
165 |
+
hf_base_model = Parameter(
|
166 |
+
"hf_base_model",
|
167 |
+
help="dict with 'repo_id' and 'commit_hash' keys."+\
|
168 |
+
"if 'commit_hash is None, falls back "+\
|
169 |
+
"to latest available version of the model.",
|
170 |
+
type=JSONType,
|
171 |
+
default=dedent("""{
|
172 |
+
"repo_id": "unsloth/Qwen2.5-1.5B",
|
173 |
+
"commit_hash": ""
|
174 |
+
}""").replace("'", '"').strip('"')
|
175 |
+
)
|
176 |
+
|
177 |
+
cpt_training_args = Parameter(
|
178 |
+
"cpt_training_args",
|
179 |
+
help="dict with `TrainingArguments` params "+\
|
180 |
+
"for the CPT job.",
|
181 |
+
type=JSONType,
|
182 |
+
default=dedent("""{
|
183 |
+
"warmup_ratio": 0.1,
|
184 |
+
"num_train_epochs": 1
|
185 |
+
}""").replace("'", '"').strip('"')
|
186 |
+
)
|
187 |
+
|
188 |
+
sft_training_args = Parameter(
|
189 |
+
"sft_training_args",
|
190 |
+
help="dict with `TrainingArguments` params "+\
|
191 |
+
"for the SFT job.",
|
192 |
+
type=JSONType,
|
193 |
+
default=dedent("""{
|
194 |
+
"warmup_ratio": 0.1,
|
195 |
+
"num_train_epochs": 1
|
196 |
+
}""").replace("'", '"').strip('"')
|
197 |
+
)
|
198 |
+
|
199 |
+
model_repo_id = Parameter(
|
200 |
+
"model_repo_id",
|
201 |
+
type=str,
|
202 |
+
default="retrain-pipelines/function_caller",
|
203 |
+
help="The 'repo_id' to be used " + \
|
204 |
+
"for the Hugging Face model version push " + \
|
205 |
+
"(will be created at runtime" + \
|
206 |
+
" if doesn't already exist)."
|
207 |
+
)
|
208 |
+
|
209 |
+
default_pipeline_card_module_dir = \
|
210 |
+
os.path.dirname(
|
211 |
+
importlib.util.find_spec(
|
212 |
+
f"retrain_pipelines.pipeline_card."+
|
213 |
+
f"{RETRAIN_PIPELINE_TYPE}"
|
214 |
+
).origin)
|
215 |
+
pipeline_card_artifacts_path = Parameter(
|
216 |
+
"pipeline_card_artifacts_path",
|
217 |
+
type=str,
|
218 |
+
default=default_pipeline_card_module_dir,
|
219 |
+
help="pipeline_card artifacts location "+\
|
220 |
+
"(i.e. dir hosting your optional " + \
|
221 |
+
" custom documentation files :" + \
|
222 |
+
" 'pipeline_card.py' and/or 'template.html'"+\
|
223 |
+
" and/or 'model_readme.py'"+\
|
224 |
+
" and/or 'model_readme_template.md'," +\
|
225 |
+
" and/or 'dataset_readme.py'"+\
|
226 |
+
" and/or 'dataset_readme_template.md' file), " +\
|
227 |
+
"if different from default."
|
228 |
+
)
|
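
    # Invocation sketch (assumed, for illustration): with Metaflow, each
    # Parameter above becomes a CLI option of the `run` command, JSONType
    # ones taking a JSON string, e.g.:
    #   python retraining_pipeline.py run \
    #       --augmentation_rate 0.05 \
    #       --hf_base_model '{"repo_id": "unsloth/Qwen2.5-1.5B", "commit_hash": ""}'
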
    @staticmethod
    def copy_default_dataset_readme_module(
        target_dir: str,
        exists_ok: bool = False
    ) -> None:
        os.makedirs(target_dir, exist_ok=True)
        if (
            not exists_ok and
            os.path.exists(os.path.join(target_dir, "dataset_readme.py"))
        ):
            print("File already exists. Skipping copy.")
        else:
            filefullname = os.path.join(
                UnslothFuncCallFlow.default_pipeline_card_module_dir,
                "dataset_readme.py"
            )
            shutil.copy(filefullname, target_dir)
            print(filefullname)

    @staticmethod
    def copy_default_dataset_readme_template(
        target_dir: str,
        exists_ok: bool = False
    ) -> None:
        os.makedirs(target_dir, exist_ok=True)
        if (
            not exists_ok and
            os.path.exists(os.path.join(target_dir,
                                        "dataset_readme_template.md"))
        ):
            print("File already exists. Skipping copy.")
        else:
            filefullname = os.path.join(
                UnslothFuncCallFlow.default_pipeline_card_module_dir,
                "dataset_readme_template.md")
            shutil.copy(filefullname, target_dir)
            print(filefullname)

    @staticmethod
    def copy_default_model_readme_module(
        target_dir: str,
        exists_ok: bool = False
    ) -> None:
        os.makedirs(target_dir, exist_ok=True)
        if (
            not exists_ok and
            os.path.exists(os.path.join(target_dir, "model_readme.py"))
        ):
            print("File already exists. Skipping copy.")
        else:
            filefullname = os.path.join(
                UnslothFuncCallFlow.default_pipeline_card_module_dir,
                "model_readme.py"
            )
            shutil.copy(filefullname, target_dir)
            print(filefullname)

    @staticmethod
    def copy_default_model_readme_template(
        target_dir: str,
        exists_ok: bool = False
    ) -> None:
        os.makedirs(target_dir, exist_ok=True)
        if (
            not exists_ok and
            os.path.exists(os.path.join(target_dir,
                                        "model_readme_template.md"))
        ):
            print("File already exists. Skipping copy.")
        else:
            filefullname = os.path.join(
                UnslothFuncCallFlow.default_pipeline_card_module_dir,
                "model_readme_template.md")
            shutil.copy(filefullname, target_dir)
            print(filefullname)

    @staticmethod
    def copy_default_pipeline_card_module(
        target_dir: str,
        exists_ok: bool = False
    ) -> None:
        os.makedirs(target_dir, exist_ok=True)
        if (
            not exists_ok and
            os.path.exists(os.path.join(target_dir, "pipeline_card.py"))
        ):
            print("File already exists. Skipping copy.")
        else:
            filefullname = os.path.join(
                UnslothFuncCallFlow.default_pipeline_card_module_dir,
                "pipeline_card.py"
            )
            shutil.copy(filefullname, target_dir)
            print(filefullname)

    @staticmethod
    def copy_default_pipeline_card_html_template(
        target_dir: str,
        exists_ok: bool = False
    ) -> None:
        os.makedirs(target_dir, exist_ok=True)
        if (
            not exists_ok and
            os.path.exists(os.path.join(target_dir, "template.html"))
        ):
            print("File already exists. Skipping copy.")
        else:
            filefullname = os.path.join(
                UnslothFuncCallFlow.default_pipeline_card_module_dir,
                "template.html")
            shutil.copy(filefullname, target_dir)
            print(filefullname)

    del RETRAIN_PIPELINE_TYPE

    #---------------------------------------------------------------------------

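    # Step sequence visible in this excerpt (each step hands off via
    # `self.next(...)`): start -> eda -> augment_data -> enrich_data
    # -> dataset_to_hub -> ... (the flow continues past this excerpt).
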
    @step
    def start(self):
        print(f"{current.flow_name} - {current.run_id}")

        # GPU availability
        print(torch.cuda.get_device_name(0))
        print(torch.__version__)
        self.engine = "gpu" if torch.cuda.is_available() else "cpu"

        # hf_dataset
        hf_dataset_dict = \
            get_lazy_df(
                repo_id=self.hf_dataset["repo_id"],
                commit_hash=self.hf_dataset["commit_hash"],
                files_filter=(
                    self.hf_dataset['config_name']+"/.*\\.parquet"
                    if (
                        self.hf_dataset["config_name"] and
                        "" < self.hf_dataset["config_name"]
                    ) else ".*\\.parquet"
                ),
                hf_token=os.getenv("HF_TOKEN", None)
            )
        try:
            print(hf_dataset_dict["repo_id"], ", ",
                  hf_dataset_dict["commit_hash"], " - ",
                  hf_dataset_dict["commit_datetime"], "\n",
                  hf_dataset_dict["lazy_df"].explain())
        except ComputeError as ex:
            if "HF_TOKEN" not in os.environ:
                print("Does the Hugging Face-hosted dataset " +
                      "require authentication?",
                      file=sys.stderr, flush=True)
            raise ex
        self.hf_dataset_dict = hf_dataset_dict

        # hf_enrich_dataset
        print(self.hf_enrich_dataset)
        hf_enrich_dataset_dict = \
            get_lazy_df(
                repo_id=self.hf_enrich_dataset["repo_id"],
                commit_hash=self.hf_enrich_dataset["commit_hash"],
                files_filter=(
                    self.hf_enrich_dataset['config_name']+"/.*\\.parquet"
                    if (
                        self.hf_enrich_dataset["config_name"] and
                        "" < self.hf_enrich_dataset["config_name"]
                    ) else ".*\\.parquet"
                ),
                hf_token=os.getenv("HF_TOKEN", None)
            )
        print(' ; '.join(f"{k}: {hf_enrich_dataset_dict[k]}"
                         for k in ['commit_hash',
                                   'commit_datetime']))
        self.hf_enrich_dataset_dict = hf_enrich_dataset_dict

        # hf_base_model
        hf_base_model_commits = list_repo_commits(
            repo_id=self.hf_base_model["repo_id"],
            revision=(
                None if (rev_commit_hash:=self.hf_base_model["commit_hash"]) == ""
                else rev_commit_hash
            ),
            repo_type="model",
            token=os.getenv("HF_TOKEN", None))
        self.hf_base_model_dict = {
            "repo_id": self.hf_base_model["repo_id"],
            "commit_hash": hf_base_model_commits[0].commit_id,
            "commit_datetime": \
                hf_base_model_commits[0].created_at
        }

        self.model_version_blessed = False
        self.current_blessed_run = None
        self.current_blessed_version_dict = None
        current.run.remove_tag("model_version_blessed")

        self.retrain_pipelines = f"retrain-pipelines {__version__}"
        self.retrain_pipeline_type = os.environ["retrain_pipeline_type"]

        self.serving_artifacts_local_folder = \
            os.path.realpath(os.path.join(
                os.path.dirname(__file__),
                '..', '..', 'serving_artifacts',
                os.path.sep.join(current.run.path_components)
            ))

        if not os.path.exists(self.serving_artifacts_local_folder):
            os.makedirs(self.serving_artifacts_local_folder)

        self.unsloth_dir = os.path.join(
            self.serving_artifacts_local_folder,
            "Unsloth"
        )
        print(f"unsloth_dir : {self.unsloth_dir}")
        self.cpt_model_dir = os.path.join(
            self.unsloth_dir, "cpt_model")
        self.sft_model_dir = os.path.join(
            self.unsloth_dir, "sft_model")

        self.next(self.eda)

    @step
    def eda(self):
        """
        Exploratory data analysis.
        """

        ############################
        #    features and label    #
        #       basic counts       #
        ############################
        self.records_count = self.hf_dataset_dict["lazy_df"] \
            .select(pl.len()).collect(engine=self.engine).item()
        self.data_schema = get_column_info(
            self.hf_dataset_dict["lazy_df"], engine=self.engine)
        ############################

        ############################
        #          Answers         #
        #        tools count       #
        ############################
        struct_schema = pl.Struct([
            pl.Field("name",
                     pl.String
            ),
            pl.Field("arguments",
                     pl.List(pl.String)  # we retrieve list of args names
                                         # (without assigned values)
            )
        ])
        tool_answer_occurrences_df = \
            count_tool_occurrences(
                self.hf_dataset_dict["lazy_df"],
                self.hf_dataset["attributes"]["answers_attr"],
                struct_schema) \
            .collect(engine=self.engine)
        print(f"{tool_answer_occurrences_df['occurrences'].sum():,} " +
              f"query/tool-calls pairs")
        fig = plot_tools_occurences(tool_answer_occurrences_df,
                                    title_prefix="Dataset answers - ")
        self.answers_tools_count_fig = fig
        ############################

        ############################
        #           Query          #
        #        words count       #
        ############################
        queries_max_length = self.hf_dataset_dict["lazy_df"].select(
            pl.col(
                self.hf_dataset["attributes"]["query_attr"]
            ).str.len_chars().max().alias("max_query_length")
        ).collect(engine=self.engine)
        print(f"longest query counts " +
              f"{queries_max_length['max_query_length'][0]:,} characters")

        # queries length quartiles
        self.query_words_stats = \
            column_words_stats(
                self.hf_dataset_dict["lazy_df"],
                self.hf_dataset["attributes"]["query_attr"]
            ).collect(engine=self.engine)
        print(self.query_words_stats.to_pandas().to_string(index=False))
        print("Three quarters of the records have a query with fewer than " +
              f"{self.query_words_stats['q3'][0]} words.")

        fig = plot_words_count(
            self.hf_dataset_dict["lazy_df"],
            column_name=self.hf_dataset["attributes"]["query_attr"],
            engine=self.engine)
        self.words_count_fig = fig
        ############################

        ############################
        #    hf_enrich_dataset     #
        #    Query words count     #
        ############################
        enrich_question_words_stats = \
            column_words_stats(
                self.hf_enrich_dataset_dict['lazy_df'],
                self.hf_enrich_dataset["query_attribute"],
                column_attr_handler=eval(
                    self.hf_enrich_dataset["query_attribute_handler"])
            ).collect(engine=self.engine)
        print(enrich_question_words_stats.to_pandas()
              .to_string(index=False))
        del enrich_question_words_stats
        ############################

        self.next(self.augment_data)

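    # Worked example (assumed figures, for illustration): with the default
    # 60k-record xlam-function-calling-60k dataset and augmentation_rate=.05,
    # the step below targets int(60_000 * 0.05) = 3_000 truncated-query
    # negatives, capped by the count of eligible (long, single-tool-call)
    # records actually found.
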
    @step
    def augment_data(self):
        """
        Add 'negative' examples, where
        queries do not trigger any tool call.
        To achieve that, we sample long user queries,
        truncate at half words count, and
        associate this to an empty list of tool-calls.
        """
        """
        We only consider:
          - records with longest queries,
            i.e. queries in the last quartile
            of "queries with most word-counts"
            (this is to avoid that 'truncated' queries
             get really short)
          - records with answers consisting
            in a single tool-call
            (in order to minimize the risk
             that truncating actually gives
             a valid answer with
             one tool-call [or more])

        Note on flow 'augmentation_rate':
        we add that many records (at most),
        as quartile size permits.
        """

        print("Sampling within the population with more than " +
              str(self.query_words_stats['q3'][0]) +
              " words (longest queries quartile) =>")

        samples_count = \
            int(self.records_count * self.augmentation_rate)
        print(f"would represent {samples_count:,.0f} " +
              f"records to be sampled")

        eligible_records_df = \
            self.hf_dataset_dict["lazy_df"].filter(
                pl.col(
                    self.hf_dataset["attributes"]["query_attr"]
                )
                .str.extract_all(r"\w+")
                .map_elements(
                    lambda arr: len(arr),
                    return_dtype=pl.Int16)
                .gt(self.query_words_stats['q3'][0])
                & pl.col("answers")
                .map_elements(
                    lambda x: len(json.loads(x)) == 1
                              if isinstance(x, str)
                              else False,
                    return_dtype=pl.Boolean)
            ) \
            .collect(engine=self.engine)
        eligible_records_count = \
            eligible_records_df.select(pl.len())["len"][0]
        print(f"eligible_records_count : " +
              f"{eligible_records_count:,.0f}")
        samples_count = min(samples_count, eligible_records_count)
        self.actual_augmentation_rate = \
            samples_count / self.records_count
        print("actual augmentation rate : " +
              f"{self.actual_augmentation_rate:.1%}")
        sampled_records_df = eligible_records_df.sample(
            n=samples_count
        )

        self.augmented_records_df = \
            sampled_records_df.with_columns(
                pl.col("query")
                .map_elements(
                    lambda query:
                        " ".join(
                            query.split()[
                                :len(query.split()) // 2]),
                    return_dtype=pl.Utf8)
                .alias("truncated_query")
            ).select([
                pl.col("truncated_query").alias("query"),
                pl.lit("[]").alias("answers")
            ])
        print(self.augmented_records_df.height,
              self.augmented_records_df.columns)

        self.next(self.enrich_data)

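    # Hypothetical example of the truncation above: the 10-word query
    # "Find flights from Paris to Rome on Friday under 200" keeps its
    # first 10 // 2 = 5 words, "Find flights from Paris to", and is
    # paired with the empty answers list "[]".
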
    @step
    def enrich_data(self):
        """
        Further enrich our dataset with 'negative' records from
        another dataset (can be a general-purpose text dataset)
        as specified by the flow 'hf_enrich_dataset' argument.
        """
        """
        Note : we here use the Hugging Face `datasets` library
        in 'streaming' mode for records sampling.
        """

        hf_enrich_ds = load_dataset(
            path=self.hf_enrich_dataset["repo_id"],
            name=self.hf_enrich_dataset["config_name"],
            revision=self.hf_enrich_dataset_dict["commit_hash"],
            streaming=True)
        print(hf_enrich_ds["train"])

        samples_count = \
            int(self.records_count * self.enrichment_rate)
        print(f"Sampling {samples_count:,.0f} records")

        query_attribute_handler = \
            eval(self.hf_enrich_dataset["query_attribute_handler"])
        samples_iterator = iterable_dataset_multi_buffer_sampler(
            hf_enrich_ds["train"],
            total_samples=samples_count,
            attributes_selector=\
                (lambda x: query_attribute_handler(
                    x[self.hf_enrich_dataset["query_attribute"]])),
            buffer_size=3_000,
            num_passes=3,
            seed=None
        )
        # Capitalize and add end punctuation if missing
        start_time = time.time()
        print("Starting to sample enriching records, " +
              "this may take some time if the source dataset " +
              "has a complex structure..")
        samples_list = [
            s.capitalize() + ("" if s[-1] in ".!?" else "?")
            for s in samples_iterator]
        elapsed_time = time.time() - start_time
        print(f".. sampling completed " +
              f"({int(elapsed_time // 3_600)}h:" +
              f"{int((elapsed_time % 3_600) // 60)}m:" +
              f"{int(elapsed_time % 60)}s).")
        enriched_records_df = pl.DataFrame(
            {"query": samples_list,
             "answers": \
                ["[]"] * \
                len(samples_list)}
        )
        self.enriched_records_df = enriched_records_df

        self.next(self.dataset_to_hub)

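    # Hypothetical example of the normalization above: str.capitalize()
    # upper-cases the first character and lower-cases the rest, and a "?"
    # is appended when no end punctuation is present, so a sampled query
    # "who wrote Hamlet" becomes "Who wrote hamlet?".
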
681 |
+
    @step
    def dataset_to_hub(self):
        """
        Push to hub the dataset version:
        - continued pre-training dataset
        - training and validation splits of the
          augmented and enriched
          supervised finetuning dataset
        - readme with versioning info
        """

        #############################
        # case of user-provided     #
        # documentation artifact(s) #
        #############################
        # note that the user can provide either
        # 'pipeline_card.py' or 'template.html'
        # or 'dataset_readme.py'
        # or 'dataset_readme_template.md'
        # or 'model_readme.py'
        # or 'model_readme_template.md'
        # or any combination of those
        # when specifying a custom
        # 'pipeline_card_artifacts_path'
        if (
            "dataset_readme_template.md" in
            os.listdir(self.pipeline_card_artifacts_path)
        ):
            template_dir = self.pipeline_card_artifacts_path
        else:
            template_dir = os.path.dirname(
                importlib.util.find_spec(
                    "retrain_pipelines.pipeline_card." +
                    f"{os.getenv('retrain_pipeline_type')}"
                ).origin)
        print(f"template_dir : '{template_dir}'")
        #############################
        if "dataset_readme.py" in os.listdir(
                self.pipeline_card_artifacts_path):
            from retrain_pipelines.utils import \
                get_get_dataset_readme_content
            get_dataset_readme_content = \
                get_get_dataset_readme_content(
                    self.pipeline_card_artifacts_path)
        else:
            from retrain_pipelines.pipeline_card import \
                get_dataset_readme_content
        #############################


        #############################
        # augmented & enriched      #
        # finetuning dataset        #
        #############################
        merged_df = pl.concat([
            # dataset
            self.hf_dataset_dict["lazy_df"].select([
                self.hf_dataset["attributes"]["query_attr"],
                self.hf_dataset["attributes"]["answers_attr"]
            ]).collect(engine=self.engine),
            # truncated queries augmentation
            self.augmented_records_df,
            # enriching dataset
            self.enriched_records_df
        ]).sample(
            # shuffling
            fraction=1,
            shuffle=True,
            with_replacement=False
        )
        merged_df = merged_df.sample(fraction=1, shuffle=True)
        # rechunk returns the contiguous frame, it is not in-place
        merged_df = merged_df.rechunk()
        print(("merged_df", f"{merged_df.shape[0]:,.0f}",
               merged_df.columns))

        pandas_df = merged_df.to_pandas()
        train_size = int(0.8 * len(pandas_df))
        print(f"validation : {len(pandas_df) - train_size}")
        sft_dataset = DatasetDict({
            "train": Dataset.from_pandas(pandas_df[:train_size]),
            "validation": Dataset.from_pandas(pandas_df[train_size:])
        })
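        # (slice-based 80/20 split ; its order-dependence is fine here
        #  only because merged_df was globally shuffled just above)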
        #############################

        #############################
        # continued pre-training    #
        # dataset                   #
        #############################
        struct_schema = pl.Struct([
            pl.Field("name", pl.String),
            pl.Field("description", pl.String),
            pl.Field(
                "parameters",
                pl.String  # Use String to allow
                           # for varying structures
                           # (different tools indeed have
                           # different sets of parameters,
                           # i.e. different parameter counts,
                           # datatypes and names),
                           # so parsing must be tolerant.
            )
        ])
        unique_tools_df = get_unique_tools(
            self.hf_dataset_dict["lazy_df"],
            tools_attr_name=\
                self.hf_dataset["attributes"]["tools_attr"],
            struct_schema=struct_schema
        ).collect(engine=self.engine)
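        # (assuming get_unique_tools, a retrain_pipelines helper, explodes
        #  the tools attribute, casts each entry to struct_schema and
        #  deduplicates, so each tool signature appears exactly once
        #  in the continued pre-training corpus)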
        unique_tools_arrow_table = unique_tools_df.to_arrow()
        self.unique_tools_dataset = \
            Dataset(unique_tools_arrow_table)
        print(self.unique_tools_dataset)
        #############################

        #############################
        # DatasetDict               #
        # with multiple tables      #
        #############################
        dataset_dict = DatasetDict({
            "continued_pre_training": \
                self.unique_tools_dataset,
            "supervised_finetuning": sft_dataset
        })
        print(dataset_dict, flush=True)
        #############################

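        # (nesting a DatasetDict under "supervised_finetuning" is not
        #  vanilla `datasets` usage ; assuming push_dataset_version_to_hub
        #  flattens it into per-config, per-split files on the Hub)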
        #############################
        # dataset README            #
        # from template             #
        #############################
        commit_datetime = datetime.utcnow()
        new_dataset_version_label = get_new_repo_minor_version(
            repo_id=self.dataset_repo_id,
            repo_type="dataset",
            hf_token=os.getenv("HF_TOKEN", None))
        readme_content = get_dataset_readme_content(
            template_folder=template_dir,

            hf_dataset_dict=self.hf_dataset_dict,
            hf_enrich_dataset_dict=self.hf_enrich_dataset_dict,
            dataset_dict=dataset_dict,

            augmentation_rate=self.actual_augmentation_rate,
            enrichment_rate=self.enrichment_rate,

            version_label=new_dataset_version_label,
            commit_datetime=commit_datetime,

            mf_flow_name=current.flow_name,
            mf_run_id=current.run.id,
            engine=self.engine
        )
        #############################

        dataset_commit_hash = push_dataset_version_to_hub(
            repo_id=self.dataset_repo_id,
            version_label=new_dataset_version_label,
            timestamp_str=commit_datetime.strftime(
                "%Y-%m-%d %H:%M:%S UTC"),
            dataset_dict=dataset_dict,
            dataset_readme_content=readme_content,
            hf_token=os.getenv("HF_TOKEN", None)
        )
        if not dataset_commit_hash:
            raise Exception(
                "Failed to publish dataset version.")
        print(f"https://huggingface.co/datasets/{self.dataset_repo_id}" +
              f"/blob/{dataset_commit_hash}/README.md")
        self.dataset_commit_dict = {
            "repo_id": self.dataset_repo_id,
            "commit_hash": dataset_commit_hash,
            "version_label": new_dataset_version_label,
            "commit_datetime": commit_datetime,
        }

        self.next(self.continued_pre_training)

    @step
    def continued_pre_training(self):
        """
        Gives the base model some additional intrinsic knowledge
        through continued pre-training.
        See unsloth.ai/blog/contpretraining
        """
        from retrain_pipelines.model.hf_utils import \
            plot_log_history

        #######################################
        # base-model and associated tokenizer #
        # from Hub (or local cache)           #
        #######################################
        self.max_seq_length = 2048
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=self.hf_base_model_dict["repo_id"],
            revision=self.hf_base_model_dict["commit_hash"],
            max_seq_length=self.max_seq_length,
            dtype=None,
            load_in_4bit=False,
            # case of a gated or private base-model
            token=os.getenv("HF_TOKEN", None)
        )
        #######################################

        #######################################
        # dataset prompt_template mapping     #
        #######################################
        tools_dataset = DatasetDict(
            {"train": self.unique_tools_dataset})
        print(tools_dataset)
        tool_prompt_template = "tool: {}"
        def formatting_prompts_func(tools_batch):
            tools_batch = tools_batch["tool"]
            outputs = []
            for tool in tools_batch:
                # Must add EOS_TOKEN,
                # otherwise generation will go on forever!
                text = tool_prompt_template.format(tool) + \
                       tokenizer.eos_token
                outputs.append(text)
            return { "tools" : outputs, }
        cpt_dataset = tools_dataset["train"].map(
            formatting_prompts_func, batched=True,)
        #######################################

        #######################################
        # PEFT adapter                        #
        # for continued pre-training          #
        #######################################
        model = FastLanguageModel.get_peft_model(
            model,
            r = 128, # any number > 0 ; 8, 16, 32, 64, 128, 256
            target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                              "gate_proj", "up_proj", "down_proj",
                              # Add for continued pretraining
                              "embed_tokens", "lm_head",],
            lora_alpha = 32,
            lora_dropout = 0, # Supports any, 0 is optimized
            bias = "none",    # Supports any, "none" is optimized
            # True or "unsloth" for very long context
            use_gradient_checkpointing = "unsloth",
            use_rslora = True, # rank-stabilized LoRA
            loftq_config = None, # LoftQ
            #random_state = 3407,
        )
        #######################################
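        # (training "embed_tokens" and "lm_head" alongside the usual
        #  attention/MLP projections is what lets continued pre-training
        #  inject new factual knowledge ; with use_rslora=True the
        #  adapter scaling becomes lora_alpha / sqrt(r) instead of
        #  lora_alpha / r, which stabilizes the high rank r=128)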

        #######################################
        # cpt_trainer                         #
        #######################################
        if (
            "records_cap" in self.cpt_training_args and
            self.cpt_training_args["records_cap"] is not None and
            isinstance(self.cpt_training_args["records_cap"], int)
        ):
            cpt_dataset = cpt_dataset.take(
                self.cpt_training_args["records_cap"])
        print(f"cpt_dataset : {cpt_dataset}")

        train_args = UnslothTrainingArguments(
            # https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.save_strategy
            per_device_train_batch_size=2,
            gradient_accumulation_steps=8,

            **{k: v for k, v in self.cpt_training_args.items()
               if k != "records_cap"},

            # 2 to 10x smaller learning rate
            # for the embedding matrices
            learning_rate=5e-5,
            embedding_learning_rate=1e-5,

            fp16=not is_bfloat16_supported(),
            bf16=is_bfloat16_supported(),
            logging_steps=1,
            optim="adamw_8bit",
            weight_decay=0.01,
            lr_scheduler_type="linear",
            #seed=3407,

            output_dir=os.path.join(
                self.unsloth_dir, "outputs", "cpt"),
            save_total_limit=2,

            report_to="tensorboard",
            logging_dir=os.path.join(
                self.sft_model_dir,
                "runs", "cpt")
        )

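        # (effective batch size = per_device_train_batch_size 2
        #  x gradient_accumulation_steps 8 = 16 sequences per optimizer
        #  step ; note that the explicit kwargs after the **splat would
        #  raise a duplicate-keyword TypeError if cpt_training_args
        #  also set any of them)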
        self.cpt_traces_file_fullname = os.path.join(
            self.unsloth_dir, "cpt_trainer_traces.txt")
        print("Training started. " +
              f"Check {self.cpt_traces_file_fullname} for live traces.",
              flush=True)

        trainer = UnslothTrainer(
            model=model, tokenizer=tokenizer,
            train_dataset=cpt_dataset,
            dataset_text_field="tools",
            max_seq_length=self.max_seq_length,
            dataset_num_proc=2,
            args=train_args,
        )
        #######################################

        #######################################
        # Show current memory stats           #
        #######################################
        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()
        _ = gc.collect()

        gpu_stats = torch.cuda.get_device_properties(0)
        self.start_gpu_memory = \
            round(torch.cuda.max_memory_reserved()
                  / 1024 / 1024 / 1024, 3)
        self.max_memory = \
            round(gpu_stats.total_memory
                  / 1024 / 1024 / 1024, 3)
        print(f"GPU = {gpu_stats.name}. " +
              f"Max memory = {self.max_memory} GB.")
        print(f"{self.start_gpu_memory} GB of memory reserved.")
        #######################################

        with open(self.cpt_traces_file_fullname, 'w') as f:
            with redirect_stdout(f):
                hf_logging.set_verbosity_error()
                hf_logging.disable_progress_bar()
                trainer_stats = trainer.train()
                hf_logging.set_verbosity_info()
                hf_logging.enable_progress_bar()
        print(f"{trainer_stats.metrics['train_runtime']} " +
              f"seconds used for training " +
              f"({round(trainer_stats.metrics['train_runtime']/60, 2)}" +
              f" minutes).")

        self.cpt_log_history = trainer.state.log_history
        # print(self.cpt_log_history)
        self.cpt_log_history_fig = \
            plot_log_history(
                self.cpt_log_history,
                title="Continued pretraining loss"
            )

        model.save_pretrained_merged(
            save_directory=self.cpt_model_dir,
            tokenizer=tokenizer,
            save_method="lora"
        )
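        # (assuming Unsloth's save_method="lora" behavior : only the
        #  adapter weights, the modules_to_save copies and the tokenizer
        #  are written to cpt_model_dir, not a full merged checkpoint)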
print(f"cpt_model_dir : {self.cpt_model_dir}\n")
|
1032 |
+
|
1033 |
+
self.next(self.supervised_finetuning)
|
1034 |
+
|
1035 |
+
|
1036 |
+
@step
|
1037 |
+
def supervised_finetuning(self):
|
1038 |
+
"""
|
1039 |
+
Trains the model on tool-calling
|
1040 |
+
task specialization.
|
1041 |
+
"""
|
1042 |
+
from retrain_pipelines.model.hf_utils import \
|
1043 |
+
plot_log_history
|
1044 |
+
|
1045 |
+
torch.cuda.ipc_collect()
|
1046 |
+
torch.cuda.empty_cache()
|
1047 |
+
_ = gc.collect()
|
1048 |
+
|
1049 |
+
model, tokenizer = FastLanguageModel.from_pretrained(
|
1050 |
+
model_name=self.cpt_model_dir,
|
1051 |
+
max_seq_length=self.max_seq_length,
|
1052 |
+
dtype=None,
|
1053 |
+
load_in_4bit=False,
|
1054 |
+
)
|
1055 |
+
# !!!! bug fix BEGIN !!!!
|
1056 |
+
# otherwise, 'embed_tokens' and 'lm_head'
|
1057 |
+
# trained during CPT are "ignored",
|
1058 |
+
# i.e. not saved after SFT
|
1059 |
+
# (note that, alternatively, we could also
|
1060 |
+
# do this fix after sft-training and
|
1061 |
+
# just before saving ;
|
1062 |
+
# which would be equivalent to
|
1063 |
+
# freezing embeddings during finetuning
|
1064 |
+
# for better pretrained knowledge retention)
|
1065 |
+
# @see https://www.reddit.com/r/unsloth/comments/1dtzcd6/fastlanguagemodelpatch_peft_model_changing/
|
1066 |
+
model.model.model.embed_tokens.modules_to_save.default.to(
|
1067 |
+
device="cuda:0",
|
1068 |
+
dtype=torch.float32,
|
1069 |
+
non_blocking=True)
|
1070 |
+
model.model.model.embed_tokens.modules_to_save.default \
|
1071 |
+
.requires_grad_(True)
|
1072 |
+
model.model.lm_head.modules_to_save.default.to(
|
1073 |
+
device="cuda:0",
|
1074 |
+
dtype=torch.float32,
|
1075 |
+
non_blocking=True)
|
1076 |
+
model.model.lm_head.modules_to_save.default \
|
1077 |
+
.requires_grad_(True)
|
1078 |
+
# !!!! bug fix END !!!!
|
1079 |
+
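        # (PEFT wraps the embed_tokens / lm_head modules listed in
        #  modules_to_save as trainable full copies of the base weights ;
        #  reloading from the CPT checkpoint leaves those copies
        #  frozen/offloaded, hence the re-cast to float32 on GPU and the
        #  re-enabled gradients above, so SFT keeps updating them)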

        #######################################
        # dataset prompt_template mapping     #
        #######################################
        # download from Hub (or get from local cache)
        queries_dataset = load_dataset(
            path=self.dataset_commit_dict["repo_id"],
            name="supervised_finetuning",
            revision=self.dataset_commit_dict["commit_hash"],
            token=os.getenv("HF_TOKEN", None))
        print(f"HF_DATASETS_CACHE : {HF_DATASETS_CACHE}") # HF_CACHE_HOME
        self.sft_prompt_template = dedent("""
            You specialize in generating tool calls. Given a query, your task is to return a list of tool calls based on your knowledge of known tools.

            Rules:
            1. You can only use tools you know. Do not create new tools under any circumstances.
            2. If a query does not match any known tool, return an empty list ([]).
            3. If information is missing to use a known tool, do not attempt to use it.
            4. Your response must always be a valid JSON array, and nothing else.

            Be precise and do not guess.

            # query:
            {}
            # response:
            {}
            """).strip()
        tokenizer.chat_template = self.sft_prompt_template
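        # (a rendered training example hence looks like
        #  "...# query:\n<user query>\n# response:\n<tool-calls JSON><eos>" ;
        #  stashing the plain str.format template in tokenizer.chat_template,
        #  rather than a Jinja chat template, lets it travel with the saved
        #  tokenizer and be read back as prompt_template in the
        #  evaluate_model step below)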

        EOS_TOKEN = tokenizer.eos_token
        def formatting_prompts_func(records):
            query = records["query"]
            tools = records["answers"]
            outputs = []
            for query, tools in zip(query, tools):
                # Must add EOS_TOKEN,
                # otherwise your generation will go on forever
                text = self.sft_prompt_template.format(query, tools) \
                       + EOS_TOKEN
                outputs.append(text)
            return { "text" : outputs, }
        sft_train_dataset = queries_dataset["train"].map(
            formatting_prompts_func, batched=True)
        sft_valid_dataset = queries_dataset["validation"].map(
            formatting_prompts_func, batched=True,)
        #######################################

        #######################################
        # PEFT adapter                        #
        # for supervised finetuning           #
        #######################################
        # for cases where CPT has been merged into the overall model ;
        # otherwise, keep on training the current LoRA adapter
        # model = FastLanguageModel.get_peft_model(
        #     model,
        #     r = 128, # any number > 0 ; 8, 16, 32, 64, 128, 256
        #     target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
        #                       "gate_proj", "up_proj", "down_proj"],
        #     lora_alpha = 32,
        #     lora_dropout = 0, # Supports any, but = 0 is optimized
        #     bias = "none",    # Supports any, but = "none" is optimized
        #     # True or "unsloth" for very long context
        #     use_gradient_checkpointing = "unsloth",
        #     random_state = 3407,
        #     use_rslora = True, # rank-stabilized LoRA
        #     loftq_config = None, # LoftQ
        # )
        #######################################

        #######################################
        # sft_trainer                         #
        #######################################
        split = sft_train_dataset.train_test_split(
            test_size=1000,
            #seed=42
        )
        train_dataset = split['train']
        eval_dataset = split['test']
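        # (the 1,000-record eval_dataset used during training is carved
        #  out of the SFT train split ; the "validation" split proper
        #  stays untouched for the final evaluate_model step)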
        if (
            "records_cap" in self.sft_training_args and
            self.sft_training_args["records_cap"] is not None and
            isinstance(self.sft_training_args["records_cap"], int)
        ):
            train_dataset = train_dataset.take(
                self.sft_training_args["records_cap"])
            eval_dataset = eval_dataset.take(
                self.sft_training_args["records_cap"])
        print(f"train_dataset : {train_dataset}")
        print(f"eval_dataset  : {eval_dataset}")

        train_args = UnslothTrainingArguments(
            per_device_train_batch_size=2,
            gradient_accumulation_steps=8,

            **{k: v for k, v in self.sft_training_args.items()
               if k != "records_cap"},

            per_device_eval_batch_size=2,
            eval_steps=200,
            eval_strategy="steps",
            do_eval=True,

            learning_rate=5e-5,
            # embedding_learning_rate=1e-5, # Optionally here

            fp16=not is_bfloat16_supported(),
            bf16=is_bfloat16_supported(),

            optim="adamw_8bit",
            weight_decay=0.00,
            lr_scheduler_type="linear",
            #seed=3407,

            output_dir=os.path.join(
                self.unsloth_dir, "outputs", "sft"),
            save_total_limit=2,

            logging_steps=1,
            report_to="tensorboard",
            logging_dir=os.path.join(
                self.sft_model_dir,
                "runs", "sft")
        )

        self.sft_traces_file_fullname = os.path.join(
            self.unsloth_dir, "sft_trainer_traces.txt")
        print("Training started. " +
              f"Check {self.sft_traces_file_fullname} for live traces.",
              flush=True)

        trainer = UnslothTrainer(
            model=model, tokenizer=tokenizer,
            train_dataset=train_dataset,
            dataset_text_field="text",
            eval_dataset=eval_dataset,
            max_seq_length=self.max_seq_length,
            dataset_num_proc=8,
            args=train_args
        )
        trainer.can_return_loss = True
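        # (assumption : can_return_loss is forced so the eval passes of
        #  the "steps" evaluation strategy report a loss for the
        #  PEFT-wrapped model instead of skipping loss computation)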
        #######################################

        #######################################
        # Show current memory stats           #
        #######################################
        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()
        _ = gc.collect()

        used_memory = \
            round(torch.cuda.max_memory_reserved()
                  / 1024 / 1024 / 1024, 3)
        used_memory_for_lora = \
            round(used_memory - self.start_gpu_memory, 3)
        used_percentage = \
            round(used_memory / self.max_memory * 100, 3)
        lora_percentage = \
            round(used_memory_for_lora / self.max_memory * 100,
                  3)
        print(f"Peak reserved memory = " +
              f"{used_memory} GB.")
        print(f"Peak reserved memory for " +
              f"training = {used_memory_for_lora} " +
              f"GB.")
        print(f"Peak reserved memory % of " +
              f"max memory = {used_percentage} %.")
        print(f"Peak reserved memory for training " +
              f"% of max memory = {lora_percentage} %.")
        #######################################

        with open(self.sft_traces_file_fullname, 'w') as f:
            with redirect_stdout(f):
                hf_logging.set_verbosity_error()
                hf_logging.disable_progress_bar()
                trainer_stats = trainer.train()
                hf_logging.set_verbosity_info()
                hf_logging.enable_progress_bar()
        print(f"{trainer_stats.metrics['train_runtime']} " +
              f"seconds used for training " +
              f"({round(trainer_stats.metrics['train_runtime']/60, 2)}" +
              f" minutes).")

        self.sft_log_history = trainer.state.log_history
        self.sft_log_history_fig = \
            plot_log_history(
                self.sft_log_history,
                title="Supervised finetuning loss"
            )

        model.save_pretrained_merged(
            self.sft_model_dir, tokenizer,
            save_method="lora"
        )
        print(f"sft_model_dir : {self.sft_model_dir}\n")

        self.next(self.evaluate_model)


    @step
    def evaluate_model(self):
        """
        Batch inference on the SFT validation dataset.
        """
        from retrain_pipelines.model import \
            infer_validation, compute_counts_n_metrics, \
            plot_validation_completions

        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()
        _ = gc.collect()


        ######################################################
        # loading trained adapter                            #
        ######################################################
        # Unsloth [and hf transformers before it]            #
        # (if loading both model & tokenizer at once,        #
        #  same as we did in prior tasks, but now with       #
        #  tokenizer.chat_template being set in the          #
        #  tokenizer config) enforces some chat_template     #
        #  format hard-requirements on us.                   #
        ######################################################
        # load base from cache
        # (with base tokenizer, which we ignore)
        model, _ = FastLanguageModel.from_pretrained(
            model_name=self.hf_base_model_dict["repo_id"],
            revision=self.hf_base_model_dict["commit_hash"],
            max_seq_length=self.max_seq_length,
            dtype=None,
            load_in_4bit=False,
            # case of a gated or private base-model
            token=os.getenv("HF_TOKEN", None)
        )
        model = FastLanguageModel.for_inference(model)
        # load our CPT+SFT trained & locally-saved adapter
        model.load_adapter(peft_model_id=self.sft_model_dir)
        # Separately load our (potentially trained &)
        # locally-saved adapter-tokenizer
        # (loading it below via HF and not Unsloth)
        tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=self.sft_model_dir
        )
        ######################################################

        ######################################################
        # validation dataset                                 #
        ######################################################
        # download from Hub (or get from local cache)
        queries_dataset = load_dataset(
            path=self.dataset_commit_dict["repo_id"],
            name="supervised_finetuning",
            revision=self.dataset_commit_dict["commit_hash"],
            token=os.getenv("HF_TOKEN", None))
        if (
            "records_cap" in self.sft_training_args and
            self.sft_training_args["records_cap"] is not None and
            isinstance(self.sft_training_args["records_cap"], int)
        ):
            validation_data = queries_dataset["validation"].take(
                self.sft_training_args["records_cap"])
        else:
            validation_data = queries_dataset["validation"]
        print(validation_data, flush=True)
        ######################################################

        self.max_new_tokens = 400
        start_time = time.time()
        validation_results = infer_validation(
            tokenizer=tokenizer,
            model=model,
            validation_data=validation_data,
            prompt_template=tokenizer.chat_template,
            batch_size=32, # 64,
            queries_attr_name=\
                self.hf_dataset["attributes"]["query_attr"],
            answers_attr_name=\
                self.hf_dataset["attributes"]["answers_attr"],
            max_new_tokens=self.max_new_tokens,
            device="cuda"
        )
        print("infer_validation - Elapsed time: " +
              f"{(time.time() - start_time):.2f} seconds")
        self.validation_results = validation_results # <= to artifacts store

        eval_df = pl.LazyFrame(validation_results)

        records = eval_df.with_columns(
            (pl.col("answer") == pl.col("completion")) \
                .alias("is_ground_truth_identical")
        ).collect() #engine=self.engine)
        print("perfect character-match accuracy : " +
              str(records['is_ground_truth_identical'].mean()))

        eval_metrics_df = compute_counts_n_metrics(
            eval_df, is_format_fault_tolerant=True)
        overall_metrics_df = eval_metrics_df.select([
            pl.col("precision").mean(),
            pl.col("recall").mean(),
            pl.col("f1").mean(),
            pl.col("jaccard").mean()
        ]).collect() #engine=self.engine)
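        # (assuming compute_counts_n_metrics parses each completion and
        #  ground-truth answer as a set of tool calls and scores them
        #  set-wise ; precision/recall/f1/jaccard are then averaged over
        #  validation records, with is_format_fault_tolerant=True scoring
        #  malformed JSON completions instead of failing on them)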
        self.perf_metrics = overall_metrics_df.row(0, named=True)
        print(self.perf_metrics)

        self.validation_completions_fig = \
            plot_validation_completions(
                eval_metrics_df, engine=self.engine)

        del model
        del tokenizer
        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()
        _ = gc.collect()

        self.next(self.model_version_blessing)


    @step
    def model_version_blessing(self):
        """
        Compares the newly-retrained model version
        against its best-performing predecessor.
        """
        """
        Note: for Hugging Face integrated pipelines,
        we compare against the latest commit on the main branch
        of the model repository there.
        When it comes to the local "mf_run_id" of the pipeline run
        having generated that best prior model version
        (retrieved from the yaml section of the HF model-card metadata),
        we check against records of the herein ML-framework instance,
        as the "prior best version" of the model being retrained here
        may have originated from another instance
        than the one executing the current retraining
        (in which case, we simply don't include a "local" hyperlink
        in the model version pipeline_cards that will be
        produced later in the herein pipeline run).
        """
        from retrain_pipelines.model.hf_utils import \
            current_blessed_model_version_dict

        main_perf_metric_name = "jaccard"

        # defaults ; read unconditionally by the later
        # pipeline_card step and by the branches below
        self.current_blessed_version_dict = None
        self.current_blessed_run = None

        current_blessed_version_dict = \
            current_blessed_model_version_dict(
                repo_id=self.model_repo_id,
                hf_token=os.getenv("HF_TOKEN", None)
            )
        print("current_blessed_version_dict : " +
              str(current_blessed_version_dict))

        if current_blessed_version_dict is None:
            print("case 'no prior blessed model version found"
                  " => blessing.'")
            self.model_version_blessed = True

        elif (
            main_perf_metric_name in
            current_blessed_version_dict["perf_metrics"]
        ):
            current_blessed_run_id = \
                current_blessed_version_dict["mf_run_id"]
            print(f"current_blessed_run_id : {current_blessed_run_id}")
            current_blessed_metric_value = \
                current_blessed_version_dict[
                    "perf_metrics"][main_perf_metric_name]

            self.model_version_blessed = (
                self.perf_metrics[main_perf_metric_name] >=
                current_blessed_metric_value
            )
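            # (">=" blesses ties : a retrained version matching the
            #  incumbent's jaccard score replaces it as current best)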

            if not self.model_version_blessed:
                self.current_blessed_version_dict = \
                    current_blessed_version_dict
                for run in Flow(self.__class__.__name__):
                    if str(run.id) == current_blessed_run_id:
                        run_steps = iter(run.steps())
                        last_run_step = next(run_steps)
                        last_task = next(iter(last_run_step.tasks()))

                        # steps are listed backwards, so the last one comes first :
                        # has the run seen the "pipeline_card" step prior to its
                        # last step (meaning "pipeline_card" completed successfully
                        # and "run" has generated a custom pipeline-card artifact) ?
                        # If not, hyperlink generation would later fail.
                        run_has_custom_card_artifact = False
                        for run_step in run_steps:
                            if "pipeline_card" == run_step.id:
                                run_has_custom_card_artifact = True
                                break

                        if not run_has_custom_card_artifact:
                            print(
                                f"Run #{current_blessed_run_id} " +
                                "doesn't seem to have successfully " +
                                "generated a pipeline-card artifact.",
                                file=sys.stderr, flush=True)
                            break
                        else:
                            # further filtering on successful runs that are
                            # retrainings of a prior version of the same model
                            # (to minimize the risk that this was obtained
                            #  on another ML-framework instance)
                            if (
                                # last_task.successful and
                                # may have failed after the "pipeline_card" step
                                # and been resumed
                                hasattr(last_task.artifacts,
                                        'model_version_blessed') and
                                last_task.artifacts.model_version_blessed.data and
                                hasattr(last_task.artifacts,
                                        'model_repo_id') and
                                last_task.artifacts.model_repo_id.data == \
                                    self.model_repo_id
                            ):
                                self.current_blessed_run = run
                                break

                if not self.current_blessed_run:
                    print(
                        "Couldn't find blessed run " +
                        f"{current_blessed_run_id} !\n" +
                        "It seems that prior blessed run was " +
                        "executed on another ML-framework instance.",
                        file=sys.stderr, flush=True)

                print("new : " +
                      str(self.perf_metrics[main_perf_metric_name]) +
                      " - previous best : " +
                      str(current_blessed_metric_value) +
                      " - model_version_blessing : " +
                      str(self.model_version_blessed))

        else:
            raise Exception(
                "Performance metric '" +
                main_perf_metric_name +
                "' can't be found in eval results " +
                "from blessed run " +
                str(current_blessed_version_dict[
                    "mf_run_id"]) + " !")

        # self.model_version_blessed = True ### DEBUG - DELETE ###

        self.next(self.model_to_hub)


    @step
    def model_to_hub(self):
        """
        Push to hub the model version, including
        a readme with versioning info.
        """

        #############################
        # case of user-provided     #
        # documentation artifact(s) #
        #############################
        # note that the user can provide either
        # 'pipeline_card.py' or 'template.html'
        # or 'dataset_readme.py'
        # or 'dataset_readme_template.md'
        # or 'model_readme.py'
        # or 'model_readme_template.md'
        # or any combination of those
        # when specifying a custom
        # 'pipeline_card_artifacts_path'
        if (
            "model_readme_template.md" in
            os.listdir(self.pipeline_card_artifacts_path)
        ):
            template_dir = self.pipeline_card_artifacts_path
        else:
            template_dir = os.path.dirname(
                importlib.util.find_spec(
                    "retrain_pipelines.pipeline_card." +
                    f"{os.getenv('retrain_pipeline_type')}"
                ).origin)
        print(f"template_dir : '{template_dir}'")
        #############################
        if "model_readme.py" in os.listdir(
                self.pipeline_card_artifacts_path):
            from retrain_pipelines.utils import \
                get_get_model_readme_content
            get_model_readme_content = \
                get_get_model_readme_content(
                    self.pipeline_card_artifacts_path)
        else:
            from retrain_pipelines.pipeline_card import \
                get_model_readme_content
        #############################
        from retrain_pipelines.model.hf_utils import \
            push_model_version_to_hub

        #############################
        # model README              #
        # from template             #
        #############################
        commit_datetime = datetime.utcnow()
        new_model_version_label = get_new_repo_minor_version(
            repo_id=self.model_repo_id,
            repo_type="model",
            hf_token=os.getenv("HF_TOKEN", None))
        readme_content = get_model_readme_content(
            template_folder=template_dir,

            model_repo_id=self.model_repo_id,

            base_model_dict=self.hf_base_model_dict,
            training_dataset_dict=self.dataset_commit_dict,

            version_label=new_model_version_label,
            commit_datetime=commit_datetime,
            perf_metrics=self.perf_metrics,

            mf_flow_name=current.flow_name,
            mf_run_id=current.run.id
        )
        #############################

        print("Pushing model version to HF hub " +
              ("(blessed). " if self.model_version_blessed
               else "(not blessed). ") +
              "May take a while..",
              flush=True)
        model_commit_hash = push_model_version_to_hub(
            repo_id=self.model_repo_id,
            model_version_blessed=\
                self.model_version_blessed,
            version_label=new_model_version_label,
            timestamp_str=commit_datetime.strftime(
                "%Y-%m-%d %H:%M:%S UTC"),
            model_dir=self.sft_model_dir,
            model_readme_content=readme_content,
            hf_token=os.getenv("HF_TOKEN", None)
        )
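        # (the blessed flag is passed through ; assuming
        #  push_model_version_to_hub uses it to decide whether the new
        #  version also becomes the repo's main revision or is merely
        #  archived alongside it)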
        if not model_commit_hash:
            raise Exception(
                "Failed to publish model version.")
        print("Push of model version to HF hub completed.",
              flush=True)
        print(f"https://huggingface.co/{self.model_repo_id}" +
              f"/blob/{model_commit_hash}/README.md")

        self.model_commit_dict = {
            "repo_id": self.model_repo_id,
            "commit_hash": model_commit_hash,
            "version_label": new_model_version_label,
            "commit_datetime": commit_datetime,
        }

        self.next(self.infra_validator)


    @step
    def infra_validator(self):
        """
        If the trained model version is blessed,
        validate serving.
        """
        """
        Note that using an isolated virtual env
        (via the @conda task decorator)
        is advisable, so as not to embark the whole
        pipeline dependencies into the local server.
        We don't here, for educational purposes :
        it keeps things "simple" to grasp,
        and avoids forcing conda
        (for instance miniconda)
        on the user as a virtual-environment
        management means.
        """
        """
        Note : we load the base model from the HF cache
        (mounted as the /huggingface_hub_cache
        docker volume) and the adapter from a local dir
        (mounted as the /FuncCallAdapter docker volume).
        """

        self.local_serve_is_ready = LocalServeReadinessEnum.NOT_APPLICABLE

        if self.model_version_blessed:
            from retrain_pipelines.utils.docker import \
                env_has_docker

            if env_has_docker():
                model_module_dir = \
                    os.path.dirname(
                        importlib.util.find_spec(
                            "retrain_pipelines.model." +
                            os.getenv('retrain_pipeline_type')
                        ).origin)

                # server & data-model & server-config modules artifacts
                files_to_copy = [
                    "litserve_server.py",
                    "litserve_datamodel.py",
                    "litserve_serverconfig.py",
                    ".dockerignore" # at image-build time,
                                    # excludes model weights from
                                    # docker-context loading
                ]
                for filename in files_to_copy:
                    shutil.copy(
                        os.path.join(model_module_dir, "litserve",
                                     filename),
                        os.path.join(self.serving_artifacts_local_folder,
                                     filename)
                    )

                # save dependencies as artifact
                create_requirements(self.serving_artifacts_local_folder,
                                    exclude=["cudf-polars-.*", "cuda-python",
                                             "nvidia-.*", "(py)?libcudf-.*",
                                             "nvtx", "rmm-.*", "litserve",
                                             ".*retrain-pipelines.*"]
                                    )
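                # (the excludes keep GPU-stack pins (cudf/cuda/nvidia/rmm)
                #  and litserve itself out of the serving requirements.txt ;
                #  assuming the Dockerfile template brings its own CUDA base
                #  image and litserve install)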

                # server config yaml
                env = Environment(loader=FileSystemLoader(
                    os.path.join(model_module_dir, "litserve")))
                template = env.get_template(
                    "litserve_serverconfig_template.yaml")
                server_config_data = {
                    "port": "8000",
                    "max_seq_length": self.max_seq_length,
                    "max_new_token": self.max_new_tokens,
                    "base_model": {
                        "repo_id": self.hf_base_model_dict["repo_id"],
                        "revision": self.hf_base_model_dict["commit_hash"]
                    },
                    "adapters": [
                        {
                            "name": "func_caller",
                            "path": "/FuncCallAdapter"
                        }
                    ]
                }
                server_config_yaml = template.render(server_config_data)
                print(server_config_yaml)
                with open(os.path.join(
                            self.serving_artifacts_local_folder,
                            "litserve_serverconfig.yaml"), 'w'
                ) as output_file:
                    output_file.write(server_config_yaml)

                # Dockerfile
                env = Environment(loader=FileSystemLoader(
                    os.path.join(model_module_dir)))
                template = env.get_template(
                    "Dockerfile.litserve_template")
                # Change the CUDA version here from the available list
                # @see https://hub.docker.com/r/nvidia/cuda/tags
                dockerfile_content = template.render(
                    {"cuda_version": "12.0.0"})
                with open(os.path.join(
                            self.serving_artifacts_local_folder,
                            "Dockerfile.litserve"), 'w'
                ) as output_file:
                    output_file.write(dockerfile_content)

                os.environ["no_proxy"] = "localhost,127.0.0.1,0.0.0.0"

                ############################################
                # actually deploy the inference service    #
                ############################################
                start_time = time.time()
                from retrain_pipelines.utils.docker import \
                    build_and_run_docker, print_container_log_tail, \
                    cleanup_docker
                from retrain_pipelines.model.litserve import \
                    endpoint_started, endpoint_is_ready

                self.port = 8765
                HF_HUB_CACHE = os.path.realpath(os.path.expanduser(
                    os.getenv(
                        "HF_HUB_CACHE",
                        os.path.join(os.getenv("HF_HOME",
                                               "~/.cache/huggingface"),
                                     "hub")
                    )))
                print(f"HF_HUB_CACHE : {HF_HUB_CACHE}")
                image_name = container_name = "litserve-model"

                serving_container = build_and_run_docker(
                    image_name=image_name, image_tag="1.0",
                    build_path=self.serving_artifacts_local_folder,
                    dockerfile="Dockerfile.litserve",
                    ports_publish_dict={'8000/tcp': self.port},
                    env_vars_dict={
                        "HF_HUB_CACHE": "/huggingface_hub_cache",
                        "HF_TOKEN": os.getenv("HF_TOKEN")
                    },
                    volumes_dict={
                        self.sft_model_dir:
                            {"bind": "/FuncCallAdapter",
                             "mode": "ro"},
                        HF_HUB_CACHE:
                            {"bind": "/huggingface_hub_cache",
                             "mode": "ro"}
                    }
                )
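                # (the SFT adapter dir is read-only-mounted at
                #  /FuncCallAdapter, matching the "path" declared in
                #  server_config_data above ; the host HF cache is shared
                #  read-only so the container reuses the already-downloaded
                #  base model)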

                if not serving_container:
                    print("failed spinning the LitServe container",
                          file=sys.stderr)
                    self.local_serve_is_ready = \
                        LocalServeReadinessEnum.FAILURE
                    try:
                        cleanup_docker(
                            container_name=container_name,
                            image_name=f"{image_name}:1.0",
                            no_pruning=True # for intermediate-layers recycling
                                            # (during later re-runs),
                                            # to avoid long rebuild times
                                            # of exactly the same image.
                        )
                    except Exception as cleanup_ex:
                        # fail silently
                        pass
                else:
                    print("Awaiting endpoint launch..")
                    start_time = time.time()
                    if not endpoint_started(
                        container_name, port=self.port, timeout=10*60
                    ):
                        print(
                            f"The endpoint '{container_name}' " +
                            f"did not start.")
                        self.local_serve_is_ready = \
                            LocalServeReadinessEnum.FAILURE
                    # health check on the spun-up endpoint
                    elif endpoint_is_ready(port=self.port):
                        self.local_serve_is_ready = \
                            LocalServeReadinessEnum.SUCCESS
                    elapsed_time = time.time() - start_time
                    print("deploy_local - Elapsed time: " +
                          f"{elapsed_time:.2f} seconds")
                ############################################
            else:
                # env doesn't have docker
                self.local_serve_is_ready = \
                    LocalServeReadinessEnum.FAILURE_NO_DOCKER

        if LocalServeReadinessEnum.SUCCESS == self.local_serve_is_ready:
            from retrain_pipelines.model.litserve.litserve_datamodel \
                import Response

            import requests

            url = f"http://localhost:{self.port}/predict"
            headers = {"accept": "application/x-www-form-urlencoded"}
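            # (the smoke-test payload is form-encoded ; "queries" carries
            #  a JSON-encoded list of strings that litserve_server.py is
            #  assumed to json.loads server-side)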

            try:
                start_time = time.time()
                data = {
                    "adapter_name": "func_caller",
                    "queries": '["Hello.", "Is 49 a perfect square?"]'
                }
                print(f"inference test - data: {data}")
                response = requests.post(url, headers=headers, data=data)
                parsed_response = Response(**{"output": response.json()})
                elapsed_time = time.time() - start_time
                print("parsed_response ('func_caller' adapter ON) : " +
                      str(parsed_response) +
                      f"\t-\tElapsed time: {elapsed_time:.2f} seconds")

                start_time = time.time()
                data = {
                    "queries": '["Hello.", "Is 49 a perfect square?"]'
                }
                print(f"inference test - data: {data}")
                response = requests.post(url, headers=headers, data=data)
                parsed_response = Response(**{"output": response.json()})
                elapsed_time = time.time() - start_time
                print(f"parsed_response (no adapter) : {parsed_response}" +
                      f"\t-\tElapsed time: {elapsed_time:.2f} seconds")

            except Exception as ex:
                print(ex, file=sys.stderr)
                traceback.print_tb(ex.__traceback__, file=sys.stderr)
                self.local_serve_is_ready = \
                    LocalServeReadinessEnum.FAILURE

            try:
                cleanup_docker(
                    container_name=container_name,
                    image_name=f"{image_name}:1.0",
                    no_pruning=True # for intermediate-layers recycling
                                    # (during later re-runs),
                                    # to avoid long rebuild times
                                    # of exactly the same image.
                )
            except Exception as cleanup_ex:
                # fail silently
                pass

        self.next(self.pipeline_card)


    @card(id='default')
    @card(type='html', id='custom')
    @step
    def pipeline_card(self):
        import re
        import datetime
        import importlib.metadata

        #############################
        # case of user-provided     #
        # documentation artifact(s) #
        #############################
        # note that the user can provide either
        # 'pipeline_card.py' or 'template.html'
        # or 'dataset_readme.py'
        # or 'dataset_readme_template.md'
        # or 'model_readme.py'
        # or 'model_readme_template.md'
        # or any combination of those
        # when specifying a custom
        # 'pipeline_card_artifacts_path'
        if "template.html" in os.listdir(
            self.pipeline_card_artifacts_path
        ):
            template_dir = self.pipeline_card_artifacts_path
        else:
            template_dir = os.path.dirname(
                importlib.util.find_spec(
                    "retrain_pipelines.pipeline_card." +
                    f"{os.getenv('retrain_pipeline_type')}"
                ).origin)
        #############################
        if "pipeline_card.py" in os.listdir(
            self.pipeline_card_artifacts_path
        ):
            from retrain_pipelines.utils import get_get_html
            get_html = \
                get_get_html(self.pipeline_card_artifacts_path)
        else:
            from retrain_pipelines.pipeline_card import \
                get_html
        from retrain_pipelines.pipeline_card.helpers import \
            mf_dag_svg
        #############################


        #############################
        ##      "default" card     ##
        #############################
        self.metadata = {
            "name": "Unsloth function-calling model",
            "version": "1.0",
            "retrain_pipelines": f"retrain-pipelines {__version__}",
            "retrain_pipeline_type": os.environ["retrain_pipeline_type"],
            "description": "A function-calling model retrained "
                           "(continued pre-training + supervised finetuning)",
            "authors": [current.username],
            "tags": ["function-calling", "unsloth", "LoRA"],
            "license": "MIT License",
            "data_augmentation": [
                {
                    "name": "Augmentation",
                    "description": "Truncating queries and " + \
                                   "associating those to " + \
                                   "no-tool-call answers. " + \
                                   "Intent being to instruct on " + \
                                   "not hallucinating missing " + \
                                   "tool-call parameter values."
                },
                {
                    "name": "Enrichment",
                    "description": "Addition of records " + \
                                   "from an external data-source. " + \
                                   "Here to instruct on no tool-call."
                }
            ],
            "references": [
                {
                    "title": "Base model",
                    "link": f"https://hf.co/{self.hf_base_model_dict['repo_id']}"
                },
                {
                    "title": "Function-calling dataset",
                    "link": f"https://hf.co/{self.hf_dataset_dict['repo_id']}"
                },
                {
                    "title": "Data-enrichment dataset",
                    "link": f"https://hf.co/{self.hf_enrich_dataset_dict['repo_id']}"
                },
                {
                    "title": "Unsloth",
                    "link": "https://unsloth.ai/blog/contpretraining"
                }
            ]
        }

        current.card['default'].append(Markdown(
            "model_version_blessed : **%s**" % str(self.model_version_blessed)))
        current.card['default'].append(Artifact(
            {"model_version_blessed": self.model_version_blessed}))

        current.card['default'].append(
            Image.from_matplotlib(self.sft_log_history_fig))
        current.card['default'].append(
            Image.from_matplotlib(self.validation_completions_fig))
        #############################

        #############################
        ##    html "custom" card   ##
        #############################
        dt = datetime.datetime.now(tz=datetime.timezone.utc)
        formatted_dt = dt.strftime("%A %b %d %Y %I:%M:%S %p %Z")
        task_obj_python_cmd = f"metaflow.Task(" + \
                              f"\"{current.pathspec}\", " + \
                              f"attempt={str(current.retry_count)})"
        params = {
            'template_dir': template_dir,
            'title': f"{current.flow_name}",
            "subtitle": f"(flow run # {len(list(current.run.parent.runs()))}," + \
                        f" run_id: {str(current.run.id)} - {formatted_dt})",

            # blessed status / current_blessed version
            'model_version_blessed': self.model_version_blessed,
            'current_blessed_version_label': (
                self.current_blessed_version_dict["version_label"]
                if self.current_blessed_version_dict
                else None
            ),
            'current_blessed_commit_datetime': (
                self.current_blessed_version_dict["commit_datetime"]
                if self.current_blessed_version_dict
                else None
            ),
            'current_blessed_model_commit_hash': (
                self.current_blessed_version_dict["commit_hash"]
                if self.current_blessed_version_dict
                else None
            ),
            'current_blessed_run': self.current_blessed_run,

            'LocalServeReadinessEnum': LocalServeReadinessEnum,
            'local_serve_is_ready': self.local_serve_is_ready,
            # EDA
            'main_dataset_repo_id': self.hf_dataset['repo_id'],
            'main_dataset_commit_hash': self.hf_dataset_dict['commit_hash'],
            'main_dataset_commit_datetime': \
                self.hf_dataset_dict['commit_datetime'],

            'records_count': self.records_count,
            'data_schema': self.data_schema,
            'answers_tools_count_fig': self.answers_tools_count_fig,
            'words_count_fig': self.words_count_fig,

            # model training
            'dataset_repo_id': self.dataset_repo_id,
            'dataset_version_label': self.dataset_commit_dict["version_label"],
            'dataset_commit_datetime': self.dataset_commit_dict["commit_datetime"],
            'dataset_commit_hash': self.dataset_commit_dict["commit_hash"],
            'dataset_augmentation_rate': self.actual_augmentation_rate,
            'dataset_enrichment_rate': self.enrichment_rate,

            'model_repo_id': self.model_repo_id,
            'model_version_label': self.model_commit_dict["version_label"],
            'model_commit_datetime': self.model_commit_dict["commit_datetime"],
            'model_commit_hash': self.model_commit_dict["commit_hash"],

            'cpt_log_history_fig': self.cpt_log_history_fig,
            'sft_log_history_fig': self.sft_log_history_fig,

            'validation_completions_fig': self.validation_completions_fig,

            'pipeline_parameters_dict': {"cpt": self.cpt_training_args,
                                         "sft": self.sft_training_args},

            'metrics_dict': self.perf_metrics,

            'task_obj_python_cmd': task_obj_python_cmd,
            'dag_svg': mf_dag_svg(self)
        }
        self.html = get_html(params)
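        # (nothing else to append for the custom card : Metaflow's
        #  @card(type='html') renders the artifact named "html",
        #  i.e. self.html set just above, when the card is generated)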
        #############################

        self.next(self.pipeline_to_hub)


    @step
    def pipeline_to_hub(self):
        """
        Publish the versioned source-code and pipeline-card
        for this run on the Hugging Face Hub.
        """

        model_commit_datetime = \
            self.model_commit_dict["commit_datetime"]
        timestamp_str = \
            "{:%Y%m%d_%H%M%S}".format(model_commit_datetime) + \
            "{:03d}".format(model_commit_datetime.microsecond//1000) + \
            "_UTC"
        subfolder_name = \
            "v" + self.model_commit_dict["version_label"] + \
            "_" + timestamp_str
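        # (yields folder names like the hypothetical
        #  "v1.2_20250101_120000000_UTC" :
        #  "v" + minor version label + model-commit datetime
        #  down to milliseconds, in UTC)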
        commit_datetime = datetime.utcnow()

        ###############################
        # source-code                 #
        ###############################
        # We upload only herein file  #
        # plus user-provided versions #
        # of the customizable ones    #
        # (if any).                   #
        ###############################
        custom_source_files = [os.path.abspath(__file__)]
        if (
            self.pipeline_card_artifacts_path != \
                self.default_pipeline_card_module_dir
        ):
            candidate_source_files = [
                "pipeline_card.py",
                "template.html",
                "dataset_readme.py",
                "dataset_readme_template.md",
                "model_readme.py",
                "model_readme_template.md"
            ]
            for candidate_source_file in candidate_source_files:
                file_fullpath = os.path.join(
                    self.pipeline_card_artifacts_path,
                    candidate_source_file)
                if os.path.exists(file_fullpath):
                    custom_source_files.append(file_fullpath)

        source_code_commit_hash = \
            push_files_to_hub_repo_branch(
                repo_id=self.model_repo_id,
                branch_name="retrain-pipelines_source-code",
                file_fullnames=custom_source_files,
                include_requirements_txt=True,
                path_in_repo=subfolder_name,
                commit_message=\
                    "source-code for model version " + \
                    subfolder_name + \
                    f" - retrain-pipelines {__version__}",
                repo_type="model",
                hf_token=os.getenv("HF_TOKEN", None)
            )
        print(source_code_commit_hash)
        self.source_code_commit_dict = {
            "repo_id": self.model_repo_id,
            "branch_name": "retrain-pipelines_source-code",
            "commit_datetime": commit_datetime,
            "commit_hash": source_code_commit_hash
        }
        ###############################

        ###############################
        # pipeline-card               #
        ###############################
        pipeline_card_fullname = None
        for run_step in current.run.steps():
            task = list(run_step.tasks())[0]
            task_name = task.path_components[2]
            if "pipeline_card" == task_name:
                pipeline_card = get_cards(
                    task, id='custom', type='html')[0]
                pipeline_card_fullname = os.path.realpath(
                    os.path.join(
                        task.metadata_dict.get("ds-root", None),
                        mf_config.CARD_SUFFIX, pipeline_card.path
                    ))
                print(pipeline_card_fullname)
                break
        pipeline_card_commit_hash = \
            push_files_to_hub_repo_branch(
                repo_id=self.model_repo_id,
                branch_name="retrain-pipelines_pipeline-card",
                file_fullnames=[pipeline_card_fullname],
                path_in_repo=subfolder_name,
                commit_message=\
                    "pipeline-card for model version " + \
                    subfolder_name + \
                    f" - retrain-pipelines {__version__}",
                repo_type="model",
                hf_token=os.getenv("HF_TOKEN", None)
            )
        print(pipeline_card_commit_hash)
        self.pipeline_card_commit_dict = {
            "repo_id": self.model_repo_id,
            "branch_name": "retrain-pipelines_pipeline-card",
            "commit_datetime": commit_datetime,
            "commit_hash": pipeline_card_commit_hash
        }
        ###############################

        self.next(self.deploy)


    @step
    def deploy(self):
        """
        Placeholder for the serving SDK deploy call
        (on the target production platform).
        Include any artifact you want ;
        consider including the portable pipeline-card
        itself !
        """

        if (
            self.model_version_blessed and
            (self.local_serve_is_ready == LocalServeReadinessEnum.SUCCESS)
        ):
            pass # your code here

        self.next(self.load_test)


    @step
    def load_test(self):
        """
        placeholder
        """

        if (
            self.model_version_blessed and
            (self.local_serve_is_ready == LocalServeReadinessEnum.SUCCESS)
        ):
            pass # your code here

        self.next(self.end)


    @step
    def end(self):
        pass


if __name__ == "__main__":
    UnslothFuncCallFlow()