Aurelien-Morgan-Bot committed on
Commit
7323ff8
·
verified ·
1 Parent(s): 6b0f156

source-code for model version v0.10_20250318_214952149_UTC- retrain-pipelines 0.1.1

Browse files
v0.10_20250318_214952149_UTC/requirements.txt ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.1.0
2
+ accelerate==1.1.1
3
+ aiohappyeyeballs==2.4.3
4
+ aiohttp==3.10.10
5
+ aiosignal==1.3.1
6
+ airportsdata==20241001
7
+ annotated-types==0.7.0
8
+ anyio==4.8.0
9
+ asttokens==2.4.1
10
+ async-timeout==4.0.3
11
+ attrs==24.2.0
12
+ bitsandbytes==0.44.1
13
+ boto3==1.35.58
14
+ botocore==1.35.58
15
+ certifi==2024.8.30
16
+ charset-normalizer==3.4.0
17
+ click==8.1.7
18
+ cloudpickle==3.1.0
19
+ colorama==0.4.6
20
+ comm==0.2.2
21
+ contourpy==1.3.1
22
+ cuda-python==12.6.2
23
+ cudf-polars-cu12==24.10.1
24
+ cycler==0.12.1
25
+ datasets==3.1.0
26
+ debugpy==1.8.8
27
+ decorator==5.1.1
28
+ dill==0.3.8
29
+ diskcache==5.6.3
30
+ docker==7.1.0
31
+ docker-pycreds==0.4.0
32
+ docstring_parser==0.16
33
+ exceptiongroup==1.2.2
34
+ executing==2.1.0
35
+ fastapi==0.115.8
36
+ fastjsonschema==2.20.0
37
+ filelock==3.16.1
38
+ fonttools==4.54.1
39
+ frozenlist==1.5.0
40
+ fsspec==2024.9.0
41
+ gitdb==4.0.11
42
+ GitPython==3.1.43
43
+ graphviz==0.20.3
44
+ grpcio==1.68.1
45
+ h11==0.14.0
46
+ hf_transfer==0.1.8
47
+ httptools==0.6.4
48
+ huggingface-hub==0.27.1
49
+ idna==3.10
50
+ iniconfig==2.0.0
51
+ interegular==0.3.3
52
+ ipykernel==6.29.5
53
+ ipython==8.29.0
54
+ ipywidgets==8.1.5
55
+ jedi==0.19.2
56
+ Jinja2==3.1.4
57
+ jmespath==1.0.1
58
+ joblib==1.4.2
59
+ jsonschema==4.23.0
60
+ jsonschema-specifications==2024.10.1
61
+ jupyter_client==8.6.3
62
+ jupyter_core==5.7.2
63
+ jupyterlab_widgets==3.0.13
64
+ kiwisolver==1.4.7
65
+ lark==1.2.2
66
+ libcudf-cu12==24.10.1
67
+ litserve==0.2.6
68
+ llvmlite==0.43.0
69
+ lxml==5.3.0
70
+ Markdown==3.7
71
+ markdown-it-py==3.0.0
72
+ MarkupSafe==3.0.2
73
+ matplotlib==3.9.2
74
+ matplotlib-inline==0.1.7
75
+ mdurl==0.1.2
76
+ metaflow==2.10.0
77
+ metaflow-card-html==1.0.2
78
+ mpmath==1.3.0
79
+ multidict==6.1.0
80
+ multiprocess==0.70.16
81
+ nbformat==5.10.4
82
+ nest-asyncio==1.6.0
83
+ networkx==3.2.1
84
+ numba==0.60.0
85
+ numpy==1.26.4
86
+ nvidia-cublas-cu11==11.11.3.6
87
+ nvidia-cublas-cu12==12.4.5.8
88
+ nvidia-cuda-cupti-cu11==11.8.87
89
+ nvidia-cuda-cupti-cu12==12.4.127
90
+ nvidia-cuda-nvrtc-cu11==11.8.89
91
+ nvidia-cuda-nvrtc-cu12==12.4.127
92
+ nvidia-cuda-runtime-cu11==11.8.89
93
+ nvidia-cuda-runtime-cu12==12.4.127
94
+ nvidia-cudnn-cu11==9.1.0.70
95
+ nvidia-cudnn-cu12==9.1.0.70
96
+ nvidia-cufft-cu11==10.9.0.58
97
+ nvidia-cufft-cu12==11.2.1.3
98
+ nvidia-curand-cu11==10.3.0.86
99
+ nvidia-curand-cu12==10.3.5.147
100
+ nvidia-cusolver-cu11==11.4.1.48
101
+ nvidia-cusolver-cu12==11.6.1.9
102
+ nvidia-cusparse-cu11==11.7.5.86
103
+ nvidia-cusparse-cu12==12.3.1.170
104
+ nvidia-nccl-cu11==2.21.5
105
+ nvidia-nccl-cu12==2.21.5
106
+ nvidia-nvjitlink-cu12==12.4.127
107
+ nvidia-nvtx-cu11==11.8.86
108
+ nvidia-nvtx-cu12==12.4.127
109
+ nvtx==0.2.10
110
+ outlines==0.1.3
111
+ outlines_core==0.1.14
112
+ packaging==24.2
113
+ pandas==2.2.3
114
+ parso==0.8.4
115
+ peft==0.13.2
116
+ pexpect==4.9.0
117
+ pillow==11.0.0
118
+ platformdirs==4.3.6
119
+ plotly==5.24.0
120
+ pluggy==1.5.0
121
+ polars==1.8.2
122
+ prompt_toolkit==3.0.48
123
+ propcache==0.2.0
124
+ protobuf==3.20.3
125
+ psutil==6.1.0
126
+ ptyprocess==0.7.0
127
+ pure_eval==0.2.3
128
+ pyarrow==17.0.0
129
+ pycountry==24.6.1
130
+ pydantic==2.9.2
131
+ pydantic_core==2.23.4
132
+ pydot==1.4.2
133
+ Pygments==2.18.0
134
+ pylibcudf-cu12==24.10.1
135
+ pyparsing==3.2.0
136
+ pytest==8.3.3
137
+ python-dateutil==2.9.0.post0
138
+ python-dotenv==1.0.1
139
+ python-multipart==0.0.20
140
+ pytz==2024.2
141
+ PyYAML==6.0.2
142
+ pyzmq==26.2.0
143
+ referencing==0.35.1
144
+ regex==2024.11.6
145
+ requests==2.32.3
146
+ -e git+https://github.com/aurelienmorgan/retrain-pipelines.git@9bbdca8a19b421b90a2640a250d8549680898f9b#egg=retrain_pipelines&subdirectory=pkg_src
147
+ rich==13.9.4
148
+ rmm-cu12==24.10.0
149
+ rpds-py==0.21.0
150
+ s3transfer==0.10.3
151
+ safetensors==0.4.5
152
+ scikit-learn==1.5.1
153
+ scipy==1.14.1
154
+ sentencepiece==0.2.0
155
+ sentry-sdk==2.19.0
156
+ setproctitle==1.3.4
157
+ shtab==1.7.1
158
+ six==1.16.0
159
+ smmap==5.0.1
160
+ sniffio==1.3.1
161
+ stack-data==0.6.3
162
+ starlette==0.45.3
163
+ sympy==1.13.1
164
+ tenacity==9.0.0
165
+ tensorboard==2.18.0
166
+ tensorboard-data-server==0.7.2
167
+ threadpoolctl==3.5.0
168
+ tokenizers==0.20.3
169
+ tomli==2.1.0
170
+ torch==2.5.0
171
+ tornado==6.4.1
172
+ tqdm==4.67.0
173
+ traitlets==5.14.3
174
+ transformers==4.46.2
175
+ triton==3.1.0
176
+ trl==0.12.0
177
+ typing_extensions==4.12.2
178
+ tyro==0.8.14
179
+ tzdata==2024.2
180
+ unsloth==2024.11.5
181
+ unsloth_zoo==2024.11.4
182
+ urllib3==2.2.3
183
+ uvicorn==0.34.0
184
+ uvloop==0.21.0
185
+ wandb==0.18.7
186
+ watchfiles==1.0.4
187
+ wcwidth==0.2.13
188
+ websockets==15.0
189
+ Werkzeug==3.1.3
190
+ widgetsnbextension==4.0.13
191
+ xformers==0.0.28.post2
192
+ xxhash==3.5.0
193
+ yarl==1.17.1
v0.10_20250318_214952149_UTC/retraining_pipeline.py ADDED
@@ -0,0 +1,2179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from unsloth import FastLanguageModel, \
3
+ is_bfloat16_supported, UnslothTrainer, \
4
+ UnslothTrainingArguments
5
+
6
+ import torch
7
+
8
+ import os
9
+ import sys
10
+
11
+ import gc
12
+ import json
13
+ import time
14
+ import shutil
15
+ import logging
16
+ import traceback
17
+ import subprocess
18
+ import importlib.util
19
+ from enum import Enum
20
+ from io import StringIO
21
+ from textwrap import dedent
22
+ from datetime import datetime
23
+ from contextlib import redirect_stdout
24
+
25
+ import numpy as np
26
+ import pandas as pd
27
+
28
+ import polars as pl
29
+ from polars.exceptions import ComputeError
30
+
31
+ import matplotlib
32
+ import matplotlib.pyplot as plt
33
+
34
+ from jinja2 import Environment, FileSystemLoader
35
+
36
+ from metaflow import FlowSpec, step, Parameter, JSONType, \
37
+ IncludeFile, current, metaflow_config as mf_config, \
38
+ resources, Flow, Task, card
39
+ from metaflow.current import Current
40
+ from metaflow.cards import Image, Table, Markdown, \
41
+ Artifact, get_cards
42
+
43
+ from datasets import load_dataset, Dataset, DatasetDict
44
+ from datasets.config import HF_DATASETS_CACHE, HF_CACHE_HOME
45
+ from huggingface_hub import list_repo_commits
46
+ from transformers import AutoTokenizer
47
+ from transformers.utils import logging as hf_logging
48
+
49
+ from retrain_pipelines import __version__
50
+ from retrain_pipelines.dataset.hf_utils import get_lazy_df, \
51
+ get_column_info, iterable_dataset_multi_buffer_sampler, \
52
+ push_dataset_version_to_hub
53
+ from retrain_pipelines.dataset.tool_calls import \
54
+ get_unique_tools, count_tool_occurrences, \
55
+ plot_tools_occurences, column_words_stats, \
56
+ plot_words_count
57
+ from retrain_pipelines.utils.hf_utils import \
58
+ get_new_repo_minor_version, push_files_to_hub_repo_branch
59
+ from retrain_pipelines.utils import create_requirements
60
+
61
+
62
class LocalServeReadinessEnum(Enum):
    """Outcome of the local-serve (infra-validation) stage.

    A "3+"-state status rather than a plain boolean:
      - NOT_APPLICABLE (-1): validation skipped
        (i.e. the model version was not blessed);
      - FAILURE (0) / SUCCESS (1): boolean-like outcome;
      - FAILURE_NO_DOCKER (2): distinct failure mode
        (presumably Docker unavailable on the host —
        name-based inference, confirm against usage).
    """
    NOT_APPLICABLE = -1     # model version not blessed => nothing to serve
    FAILURE = 0             # infra validation ran and failed
    FAILURE_NO_DOCKER = 2   # infra validation could not run
    SUCCESS = 1             # infra validation passed
74
+
75
+
76
+ class UnslothFuncCallFlow(FlowSpec):
77
+ """
78
+ Training pipeline
79
+ """
80
+ # @see https://github.com/unslothai/unsloth/wiki
81
+
82
    #--- flow parameters -------------------------------------------------------

    # pipeline flavor identifier, also exported as an env-var
    # so subprocesses can resolve the same config
    RETRAIN_PIPELINE_TYPE = "mf_unsloth_func_call_litserve"
    # in order to share the config across subprocesses
    os.environ["retrain_pipeline_type"] = RETRAIN_PIPELINE_TYPE

    # main source dataset (query / tool-calls pairs) on the HF hub
    hf_dataset = Parameter(
        "hf_dataset",
        help="dict with 'repo_id' and 'commit_hash' keys. " + \
             "if 'commit_hash is None, falls back to latest version " +\
             "of the dataset available in parquet format.\n" +
             "Note that there are 3 required 'attributes' of type " + \
             "str, list[str], list[str]",
        type=JSONType,
        default=dedent("""{
                "repo_id": "Salesforce/xlam-function-calling-60k",
                "config_name": "",
                "commit_hash": "",
                "attributes": {
                    "query_attr": "query",
                    "answers_attr": "answers",
                    "tools_attr": "tools"
                }
            }""").replace("'", '"').strip('"')
    )

    # share of synthetic "truncated query" negative examples to add
    augmentation_rate = Parameter(
        "augmentation_rate",
        type=float,
        default=.05,
        help="proportion of records to be augmented "+\
             "(x% of original dataset is created"+\
             " as additional augmented datapoints), i.e. "+\
             "truncated queries to serve as negative examples, "+\
             "meaning they trigger no tool call "+\
             "due to info incompleteness."
    )

    # secondary HF dataset sampled for fully off-topic negative queries
    hf_enrich_dataset = Parameter(
        "hf_enrich_dataset",
        help="dict with 'repo_id', 'config_name' and 'commit_hash', "+\
             "query_attribute' and 'query_attribute_handler' keys. "+\
             "if 'commit_hash is None, falls back to latest version "+\
             "of the dataset available in parquet format."+\
             "'query_attribute' depicts the dataset attribute "+\
             "from which 'queries' are to be sampled."+\
             "'query_attribute_handler' serves for attributes "+\
             "that have complex structure, "+\
             "other than 'string' datatype.",
        type=JSONType,
        # @see https://huggingface.co/datasets/google-research-datasets/natural_questions
        default=dedent("""{
                "repo_id": "lighteval/natural_questions_clean",
                "config_name": "",
                "commit_hash": "",
                "query_attribute": "question",
                "query_attribute_handler": "lambda x: x"
            }""").replace("'", '"').strip('"')
    )

    # share of records sampled from 'hf_enrich_dataset'
    enrichment_rate = Parameter(
        "enrichment_rate",
        type=float,
        default=.1,
        help="proportion of records "+\
             "to be added from the 'hf_enrich_dataset'"+\
             "(x% of original dataset is sampled and"+\
             " added as enriching datapoints), i.e. "+\
             "queries to serve as negative examples, "+\
             "due to their complete disconnexion "+\
             "to tool calling situations."
    )

    # target HF repo for the generated dataset version
    dataset_repo_id = Parameter(
        "dataset_repo_id",
        type=str,
        default="retrain-pipelines/func_calls",
        help="The 'repo_id' to be used " + \
             "for the Hugging Face dataset version push " + \
             "(will be created at runtime" + \
             " if doesn't already exist)."
    )

    # base model to continue-pretrain / finetune
    hf_base_model = Parameter(
        "hf_base_model",
        help="dict with 'repo_id' and 'commit_hash' keys."+\
             "if 'commit_hash is None, falls back "+\
             "to latest available version of the model.",
        type=JSONType,
        default=dedent("""{
                "repo_id": "unsloth/Qwen2.5-1.5B",
                "commit_hash": ""
            }""").replace("'", '"').strip('"')
    )

    # `TrainingArguments` overrides for the continued-pretraining job
    cpt_training_args = Parameter(
        "cpt_training_args",
        help="dict with `TrainingArguments` params "+\
             "for the CPT job.",
        type=JSONType,
        default=dedent("""{
                "warmup_ratio": 0.1,
                "num_train_epochs": 1
            }""").replace("'", '"').strip('"')
    )

    # `TrainingArguments` overrides for the supervised-finetuning job
    sft_training_args = Parameter(
        "sft_training_args",
        help="dict with `TrainingArguments` params "+\
             "for the SFT job.",
        type=JSONType,
        default=dedent("""{
                "warmup_ratio": 0.1,
                "num_train_epochs": 1
            }""").replace("'", '"').strip('"')
    )

    # target HF repo for the trained model version
    model_repo_id = Parameter(
        "model_repo_id",
        type=str,
        default="retrain-pipelines/function_caller",
        help="The 'repo_id' to be used " + \
             "for the Hugging Face model version push " + \
             "(will be created at runtime" + \
             " if doesn't already exist)."
    )

    # resolved filesystem location of the packaged default
    # pipeline-card module for this pipeline flavor
    default_pipeline_card_module_dir = \
        os.path.dirname(
            importlib.util.find_spec(
                f"retrain_pipelines.pipeline_card."+
                f"{RETRAIN_PIPELINE_TYPE}"
            ).origin)
    pipeline_card_artifacts_path = Parameter(
        "pipeline_card_artifacts_path",
        type=str,
        default=default_pipeline_card_module_dir,
        help="pipeline_card artifacts location "+\
             "(i.e. dir hosting your optional " + \
             " custom documentation files :" + \
             " 'pipeline_card.py' and/or 'template.html'"+\
             " and/or 'model_readme.py'"+\
             " and/or 'model_readme_template.md'," +\
             " and/or 'dataset_readme.py'"+\
             " and/or 'dataset_readme_template.md' file), " +\
             "if different from default."
    )
229
+ @staticmethod
230
+ def copy_default_dataset_readme_module(
231
+ target_dir: str,
232
+ exists_ok: bool = False
233
+ ) -> None:
234
+ os.makedirs(target_dir, exist_ok=True)
235
+ if (
236
+ not exists_ok and
237
+ os.path.exists(os.path.join(target_dir, "dataset_readme.py"))
238
+ ):
239
+ print("File already exists. Skipping copy.")
240
+ else:
241
+ filefullname = os.path.join(
242
+ UnslothFuncCallFlow.default_pipeline_card_module_dir,
243
+ "dataset_readme.py"
244
+ )
245
+ shutil.copy(filefullname, target_dir)
246
+ print(filefullname)
247
+ @staticmethod
248
+ def copy_default_dataset_readme_template(
249
+ target_dir: str,
250
+ exists_ok: bool = False
251
+ ) -> None:
252
+ os.makedirs(target_dir, exist_ok=True)
253
+ if (
254
+ not exists_ok and
255
+ os.path.exists(os.path.join(target_dir,
256
+ "dataset_readme_template.md"))
257
+ ):
258
+ print("File already exists. Skipping copy.")
259
+ else:
260
+ filefullname = os.path.join(
261
+ UnslothFuncCallFlow.default_pipeline_card_module_dir,
262
+ "dataset_readme_template.md")
263
+ shutil.copy(filefullname, target_dir)
264
+ print(filefullname)
265
+ @staticmethod
266
+ def copy_default_model_readme_module(
267
+ target_dir: str,
268
+ exists_ok: bool = False
269
+ ) -> None:
270
+ os.makedirs(target_dir, exist_ok=True)
271
+ if (
272
+ not exists_ok and
273
+ os.path.exists(os.path.join(target_dir, "model_readme.py"))
274
+ ):
275
+ print("File already exists. Skipping copy.")
276
+ else:
277
+ filefullname = os.path.join(
278
+ UnslothFuncCallFlow.default_pipeline_card_module_dir,
279
+ "model_readme.py"
280
+ )
281
+ shutil.copy(filefullname, target_dir)
282
+ print(filefullname)
283
+ @staticmethod
284
+ def copy_default_model_readme_template(
285
+ target_dir: str,
286
+ exists_ok: bool = False
287
+ ) -> None:
288
+ os.makedirs(target_dir, exist_ok=True)
289
+ if (
290
+ not exists_ok and
291
+ os.path.exists(os.path.join(target_dir,
292
+ "model_readme_template.md"))
293
+ ):
294
+ print("File already exists. Skipping copy.")
295
+ else:
296
+ filefullname = os.path.join(
297
+ UnslothFuncCallFlow.default_pipeline_card_module_dir,
298
+ "model_readme_template.md")
299
+ shutil.copy(filefullname, target_dir)
300
+ print(filefullname)
301
+ @staticmethod
302
+ def copy_default_pipeline_card_module(
303
+ target_dir: str,
304
+ exists_ok: bool = False
305
+ ) -> None:
306
+ os.makedirs(target_dir, exist_ok=True)
307
+ if (
308
+ not exists_ok and
309
+ os.path.exists(os.path.join(target_dir, "pipeline_card.py"))
310
+ ):
311
+ print("File already exists. Skipping copy.")
312
+ else:
313
+ filefullname = os.path.join(
314
+ UnslothFuncCallFlow.default_pipeline_card_module_dir,
315
+ "pipeline_card.py"
316
+ )
317
+ shutil.copy(filefullname, target_dir)
318
+ print(filefullname)
319
+ @staticmethod
320
+ def copy_default_pipeline_card_html_template(
321
+ target_dir: str,
322
+ exists_ok: bool = False
323
+ ) -> None:
324
+ os.makedirs(target_dir, exist_ok=True)
325
+ if (
326
+ not exists_ok and
327
+ os.path.exists(os.path.join(target_dir, "template.html"))
328
+ ):
329
+ print("File already exists. Skipping copy.")
330
+ else:
331
+ filefullname = os.path.join(
332
+ UnslothFuncCallFlow.default_pipeline_card_module_dir,
333
+ "template.html")
334
+ shutil.copy(filefullname, target_dir)
335
+ print(filefullname)
336
+
337
+ del RETRAIN_PIPELINE_TYPE
338
+
339
+ #---------------------------------------------------------------------------
340
+
341
+ @step
342
+ def start(self):
343
+ print(f"{current.flow_name} - {current.run_id}")
344
+
345
+ # GPU availability
346
+ print(torch.cuda.get_device_name(0))
347
+ print(torch.__version__)
348
+ self.engine = "gpu" if torch.cuda.is_available() else "cpu"
349
+
350
+ # hf_dataset
351
+ hf_dataset_dict = \
352
+ get_lazy_df(
353
+ repo_id=self.hf_dataset["repo_id"],
354
+ commit_hash=self.hf_dataset["commit_hash"],
355
+ files_filter=(
356
+ self.hf_dataset['config_name']+"/.*\\.parquet"
357
+ if (
358
+ self.hf_dataset["config_name"] and
359
+ "" < self.hf_dataset["config_name"]
360
+ ) else ".*\\.parquet"
361
+ ),
362
+ hf_token=os.getenv("HF_TOKEN", None)
363
+ )
364
+ try:
365
+ print(hf_dataset_dict["repo_id"], ", ",
366
+ hf_dataset_dict["commit_hash"], " - ",
367
+ hf_dataset_dict["commit_datetime"], "\n",
368
+ hf_dataset_dict["lazy_df"].explain())
369
+ except ComputeError as ex:
370
+ if "HF_TOKEN" not in os.environ:
371
+ print("Does the Hugging Face-hosted dataset " +
372
+ "require authentication ?",
373
+ file=sys.stderr, flush=True)
374
+ raise ex
375
+ self.hf_dataset_dict = hf_dataset_dict
376
+
377
+ # hf_enrich_dataset
378
+ print(self.hf_enrich_dataset)
379
+ hf_enrich_dataset_dict = \
380
+ get_lazy_df(
381
+ repo_id=self.hf_enrich_dataset["repo_id"],
382
+ commit_hash=self.hf_enrich_dataset["commit_hash"],
383
+ files_filter=(
384
+ self.hf_enrich_dataset['config_name']+"/.*\\.parquet"
385
+ if (
386
+ self.hf_enrich_dataset["config_name"] and
387
+ "" < self.hf_enrich_dataset["config_name"]
388
+ ) else ".*\\.parquet"
389
+ ),
390
+ hf_token=os.getenv("HF_TOKEN", None)
391
+ )
392
+ print(' ; '.join(f"{k}: {hf_enrich_dataset_dict[k]}"
393
+ for k in ['commit_hash',
394
+ 'commit_datetime']))
395
+ self.hf_enrich_dataset_dict = hf_enrich_dataset_dict
396
+
397
+ # hf_base_model
398
+ hf_base_model_commits = list_repo_commits(
399
+ repo_id=self.hf_base_model["repo_id"],
400
+ revision=(
401
+ None if (rev_commit_hash:=self.hf_base_model["commit_hash"]) == ""
402
+ else rev_commit_hash
403
+ ),
404
+ repo_type="model",
405
+ token=os.getenv("HF_TOKEN", None))
406
+ self.hf_base_model_dict = {
407
+ "repo_id": self.hf_base_model["repo_id"],
408
+ "commit_hash": hf_base_model_commits[0].commit_id,
409
+ "commit_datetime": \
410
+ hf_base_model_commits[0].created_at
411
+ }
412
+
413
+ self.model_version_blessed = False
414
+ self.current_blessed_run = None
415
+ self.current_blessed_version_dict = None
416
+ current.run.remove_tag("model_version_blessed")
417
+
418
+ self.retrain_pipelines = f"retrain-pipelines {__version__}"
419
+ self.retrain_pipeline_type = os.environ["retrain_pipeline_type"]
420
+
421
+ self.serving_artifacts_local_folder = \
422
+ os.path.realpath(os.path.join(
423
+ os.path.dirname(__file__),
424
+ '..', '..', 'serving_artifacts',
425
+ os.path.sep.join(current.run.path_components)
426
+ ))
427
+
428
+ if not os.path.exists(self.serving_artifacts_local_folder):
429
+ os.makedirs(self.serving_artifacts_local_folder)
430
+
431
+ self.unsloth_dir = os.path.join(
432
+ self.serving_artifacts_local_folder,
433
+ "Unsloth"
434
+ )
435
+ print(f"unsloth_dir : {self.unsloth_dir}")
436
+ self.cpt_model_dir = os.path.join(
437
+ self.unsloth_dir, "cpt_model")
438
+ self.sft_model_dir = os.path.join(
439
+ self.unsloth_dir, "sft_model")
440
+
441
+ self.next(self.eda)
442
+
443
+
444
    @step
    def eda(self):
        """
        exploratory data analysis.

        Produces run artifacts used downstream and by the pipeline card :
        record count, column schema, per-tool occurrence figure,
        query word-count stats/figure, and word-count stats
        for the enrichment dataset's query attribute.
        """

        ############################
        #   features and label     #
        #   basic counts           #
        ############################
        # total record count + column name/type info of the lazy frame
        self.records_count = self.hf_dataset_dict["lazy_df"] \
            .select(pl.len()).collect(engine=self.engine).item()
        self.data_schema = get_column_info(
            self.hf_dataset_dict["lazy_df"], engine=self.engine)
        ############################

        ############################
        #        Answers           #
        #       tools count        #
        ############################
        # schema used to parse each answer's tool-call structs
        struct_schema = pl.Struct([
            pl.Field("name",
                     pl.String
            ),
            pl.Field("arguments",
                     pl.List(pl.String) # we retrieve list of args names
                                        # (without assigned values)
            )
        ])
        tool_answer_occurrences_df = \
            count_tool_occurrences(
                self.hf_dataset_dict["lazy_df"],
                self.hf_dataset["attributes"]["answers_attr"],
                struct_schema) \
            .collect(engine=self.engine)
        print(f"{tool_answer_occurrences_df['occurrences'].sum():,} " +
              f"query/tool-calls pairs")
        fig = plot_tools_occurences(tool_answer_occurrences_df,
                                    title_prefix="Dataset answers - ")
        # kept as artifact for the pipeline card
        self.answers_tools_count_fig = fig
        ############################

        ############################
        #          Query           #
        #       words count        #
        ############################
        queries_max_length = self.hf_dataset_dict["lazy_df"].select(
            pl.col(
                self.hf_dataset["attributes"]["query_attr"]
            ).str.len_chars().max().alias("max_query_length")
        ).collect(engine=self.engine)
        print(f"longuest query counts " +
              f"{queries_max_length['max_query_length'][0]:,} characters")

        # queries length quartiles
        # (the 'q3' value drives eligibility in the augment_data step)
        self.query_words_stats = \
            column_words_stats(
                self.hf_dataset_dict["lazy_df"],
                self.hf_dataset["attributes"]["query_attr"]
            ).collect(engine=self.engine)
        print(self.query_words_stats.to_pandas().to_string(index=False))
        print("Two thirds of the records have a query with less than " +
              f"{self.query_words_stats['q3'][0]} words.")

        fig = plot_words_count(
            self.hf_dataset_dict["lazy_df"],
            column_name=self.hf_dataset["attributes"]["query_attr"],
            engine=self.engine)
        self.words_count_fig = fig
        ############################

        ############################
        #    hf_enrich_dataset     #
        #    Query words count     #
        ############################
        # NOTE(security): 'query_attribute_handler' is eval'ed here —
        # flow parameters are assumed to come from a trusted operator
        enrich_question_words_stats = \
            column_words_stats(
                self.hf_enrich_dataset_dict['lazy_df'],
                self.hf_enrich_dataset["query_attribute"],
                column_attr_handler=eval(
                    self.hf_enrich_dataset["query_attribute_handler"])
            ).collect(engine=self.engine)
        print(enrich_question_words_stats.to_pandas()
              .to_string(index=False))
        # stats only printed, not persisted as a run artifact
        del enrich_question_words_stats
        ############################

        self.next(self.augment_data)
532
+
533
+
534
    @step
    def augment_data(self):
        """
        Add 'negative' examples, where
        queries do not trigger any tool call.
        To achieve that, we sample long user queries,
        truncate at half words count, and
        associate this to an empty list of tool-calls.

        We only consider :
          - records with longuest queries,
            i.e. queries in the last quartile
            of "queries with most word-counts"
            (this is to avoid that 'truncated' queries
             get really short)
          - records with answers consisting
            in a single tool-call
            (in order to minimize the risk
             that truncating actually gives
             a valid answer with
             one tool-call [or more])

        Note on flow 'augmentation_rate' :
        we add that many records (at most),
        as quartiles size permits.
        """

        print("Sampling within the population with more than " +
              str(self.query_words_stats['q3'][0]) +
              " words (longest queries quartile) =>")

        # target sample size derived from the flow parameter
        samples_count = \
            int(self.records_count * self.augmentation_rate)
        print(f"would represent {samples_count:,.0f} " +
              f"records to be sampled")

        # eligibility filter : query word-count above q3
        # AND answer holds exactly one tool-call
        eligible_records_df = \
            self.hf_dataset_dict["lazy_df"].filter(
                pl.col(
                    self.hf_dataset["attributes"]["query_attr"]
                )
                .str.extract_all(r"\w+")
                .map_elements(
                    lambda arr: len(arr),
                    return_dtype=pl.Int16)
                .gt(self.query_words_stats['q3'][0])
                & pl.col("answers")
                  .map_elements(
                      lambda x: len(json.loads(x)) == 1
                                if isinstance(x, str)
                                else False,
                      return_dtype=pl.Boolean)
            ) \
            .collect(engine=self.engine)
        eligible_records_count = \
            eligible_records_df.select(pl.len())["len"][0]
        print(f"eligible_records_count : " +
              f"{eligible_records_count:,.0f}")
        # can't sample more than the eligible population holds
        samples_count = min(samples_count, eligible_records_count)
        self.actual_augmentation_rate = \
            samples_count / self.records_count
        print("actual augmentation rate : " +
              f"{self.actual_augmentation_rate:.1%}")
        sampled_records_df = eligible_records_df.sample(
            n=samples_count
        )

        # truncate each sampled query to its first half (word-wise)
        # and pair it with an empty tool-call list ("[]")
        self.augmented_records_df = \
            sampled_records_df.with_columns(
                pl.col("query")
                .map_elements(
                    lambda query:
                        " ".join(
                            query.split()[
                                :len(query.split()) // 2]),
                    return_dtype=pl.Utf8)
                .alias("truncated_query")
            ).select([
                pl.col("truncated_query").alias("query"),
                pl.lit("[]").alias("answers")
            ])
        print(self.augmented_records_df.height,
              self.augmented_records_df.columns)

        self.next(self.enrich_data)
620
+
621
+
622
+ @step
623
+ def enrich_data(self):
624
+ """
625
+ Further enrich our dataset with 'negative' records from
626
+ another dataset (can be general-purpose text dataset)
627
+ as specified by the the flow 'hf_enrich_dataset' argument.
628
+ """
629
+ """
630
+ Note : we here use the Hugging Face `datasets` library
631
+ in 'streaming' mode for records sampling.
632
+ """
633
+
634
+ hf_enrich_ds = load_dataset(
635
+ path=self.hf_enrich_dataset["repo_id"],
636
+ name=self.hf_enrich_dataset["config_name"],
637
+ revision=self.hf_enrich_dataset_dict["commit_hash"],
638
+ streaming=True)
639
+ print(hf_enrich_ds["train"])
640
+
641
+ samples_count = \
642
+ int(self.records_count * self.enrichment_rate)
643
+ print(f"Samplig {samples_count:,.0f} records")
644
+
645
+ query_attribute_handler = \
646
+ eval(self.hf_enrich_dataset["query_attribute_handler"])
647
+ samples_iterator = iterable_dataset_multi_buffer_sampler(
648
+ hf_enrich_ds["train"],
649
+ total_samples=samples_count,
650
+ attributes_selector=\
651
+ (lambda x:query_attribute_handler(
652
+ x[self.hf_enrich_dataset["query_attribute"]])),
653
+ buffer_size=3_000,
654
+ num_passes=3,
655
+ seed=None
656
+ )
657
+ # Capitalize and add end punctuation if missing
658
+ start_time = time.time()
659
+ print("Starting sample enriching records, " +
660
+ "this may take some time if the source dataset " +
661
+ "has a complex structure..")
662
+ samples_list = [
663
+ s.capitalize() + ("" if s[-1] in ".!?" else "?")
664
+ for s in samples_iterator]
665
+ elapsed_time = time.time() - start_time
666
+ print(f".. sampling completed " +
667
+ f"({int(elapsed_time // 3_600)}h:" +
668
+ f"{int((elapsed_time % 3_600) // 60)}m:" +
669
+ f"{int(elapsed_time % 60)}s).")
670
+ enriched_records_df = pl.DataFrame(
671
+ {"query": samples_list,
672
+ "answers": \
673
+ ["[]"] * \
674
+ len(samples_list)}
675
+ )
676
+ self.enriched_records_df = enriched_records_df
677
+
678
+ self.next(self.dataset_to_hub)
679
+
680
+
681
+ @step
682
+ def dataset_to_hub(self):
683
+ """
684
+ Push to hub dataset version
685
+ - continued pre-training dataset
686
+ - training and validation splits of the
687
+ augmented and enriched
688
+ supervised finetuning dataset
689
+ - readme with versioning info
690
+ """
691
+
692
+ #############################
693
+ # case of user-provided #
694
+ # documentation artifact(s) #
695
+ #############################
696
+ # note that user can provide either
697
+ # 'pipeline_card.py' or 'template.html'
698
+ # or 'dataset_readme.py'
699
+ # or 'dataset_readme_template.md'
700
+ # or 'model_readme.py'
701
+ # or 'model_readme_template.md'
702
+ # or any combination of those
703
+ # when specifying custom
704
+ # 'pipeline_card_artifacts_path'
705
+ if (
706
+ "dataset_readme_template.md" in
707
+ os.listdir(self.pipeline_card_artifacts_path)
708
+ ):
709
+ template_dir = self.pipeline_card_artifacts_path
710
+ else:
711
+ template_dir = os.path.dirname(
712
+ importlib.util.find_spec(
713
+ f"retrain_pipelines.pipeline_card."+
714
+ f"{os.getenv('retrain_pipeline_type')}"
715
+ ).origin)
716
+ print(f"template_dir : '{template_dir}'")
717
+ #############################
718
+ if "dataset_readme.py" in os.listdir(
719
+ self.pipeline_card_artifacts_path):
720
+ from retrain_pipelines.utils import \
721
+ get_get_dataset_readme_content
722
+ get_dataset_readme_content = \
723
+ get_get_dataset_readme_content(
724
+ self.pipeline_card_artifacts_path)
725
+ else:
726
+ from retrain_pipelines.pipeline_card import \
727
+ get_dataset_readme_content
728
+ #############################
729
+
730
+
731
+ #############################
732
+ # augmented & enriched #
733
+ # finetuning dataset #
734
+ #############################
735
+ merged_df = pl.concat([
736
+ # dataset
737
+ self.hf_dataset_dict["lazy_df"].select([
738
+ self.hf_dataset["attributes"]["query_attr"],
739
+ self.hf_dataset["attributes"]["answers_attr"]
740
+ ]).collect(engine=self.engine),
741
+ # truncated queries augmentation
742
+ self.augmented_records_df,
743
+ # enriching dataset
744
+ self.enriched_records_df
745
+ ]).sample(
746
+ # shuffling
747
+ fraction=1,
748
+ shuffle=True,
749
+ with_replacement=False
750
+ )
751
+ merged_df = merged_df.sample(fraction=1, shuffle=True)
752
+ merged_df.rechunk()
753
+ print(("merged_df", f"{merged_df.shape[0]:,.0F}",
754
+ merged_df.columns))
755
+
756
+ pandas_df = merged_df.to_pandas()
757
+ train_size = int(0.8 * len(pandas_df))
758
+ print(f"validation : {len(pandas_df) - train_size}")
759
+ sft_dataset = DatasetDict({
760
+ "train": Dataset.from_pandas(pandas_df[:train_size]),
761
+ "validation": Dataset.from_pandas(pandas_df[train_size:])
762
+ })
763
+ #############################
764
+
765
+ #############################
766
+ # continued pre-training #
767
+ # dataset #
768
+ #############################
769
+ struct_schema = pl.Struct([
770
+ pl.Field("name", pl.String),
771
+ pl.Field("description", pl.String),
772
+ pl.Field(
773
+ "parameters",
774
+ pl.String # Use String to allow
775
+ # for varying structures
776
+ # (different tools indeed having
777
+ # different sets of parameters
778
+ # i.e. different parameters counts,
779
+ # datatypes and names)
780
+ # so parsing must be tolerant.
781
+ )
782
+ ])
783
+ unique_tools_df = get_unique_tools(
784
+ self.hf_dataset_dict["lazy_df"],
785
+ tools_attr_name=\
786
+ self.hf_dataset["attributes"]["tools_attr"],
787
+ struct_schema=struct_schema
788
+ ).collect(engine=self.engine)
789
+ unique_tools_arrow_table = unique_tools_df.to_arrow()
790
+ self.unique_tools_dataset = \
791
+ Dataset(unique_tools_arrow_table)
792
+ print(self.unique_tools_dataset)
793
+ #############################
794
+
795
+ #############################
796
+ # DatasetDict #
797
+ # with multiple tables #
798
+ #############################
799
+ dataset_dict = DatasetDict({
800
+ "continued_pre_training": \
801
+ self.unique_tools_dataset,
802
+ "supervised_finetuning": sft_dataset
803
+ })
804
+ print(dataset_dict, flush=True)
805
+ #############################
806
+
807
+ #############################
808
+ # dataset README #
809
+ # from template #
810
+ #############################
811
+ commit_datetime = datetime.utcnow()
812
+ new_dataset_version_label = get_new_repo_minor_version(
813
+ repo_id=self.dataset_repo_id,
814
+ repo_type="dataset",
815
+ hf_token=os.getenv("HF_TOKEN", None))
816
+ readme_content = get_dataset_readme_content(
817
+ template_folder=template_dir,
818
+
819
+ hf_dataset_dict=self.hf_dataset_dict,
820
+ hf_enrich_dataset_dict=self.hf_enrich_dataset_dict,
821
+ dataset_dict=dataset_dict,
822
+
823
+ augmentation_rate=self.actual_augmentation_rate,
824
+ enrichment_rate=self.enrichment_rate,
825
+
826
+ version_label=new_dataset_version_label,
827
+ commit_datetime=commit_datetime,
828
+
829
+ mf_flow_name=current.flow_name,
830
+ mf_run_id=current.run.id,
831
+ engine=self.engine
832
+ )
833
+ #############################
834
+
835
+ dataset_commit_hash = push_dataset_version_to_hub(
836
+ repo_id=self.dataset_repo_id,
837
+ version_label=new_dataset_version_label,
838
+ timestamp_str=commit_datetime.strftime(
839
+ "%Y-%m-%d %H:%M:%S UTC"),
840
+ dataset_dict=dataset_dict,
841
+ dataset_readme_content=readme_content,
842
+ hf_token=os.getenv("HF_TOKEN", None)
843
+ )
844
+ if not dataset_commit_hash:
845
+ raise Exception(
846
+ "Failed to publish dataset version.")
847
+ print(f"https://huggingface.co/datasets/{self.dataset_repo_id}" +
848
+ f"/blob/{dataset_commit_hash}/README.md")
849
+ self.dataset_commit_dict = {
850
+ "repo_id": self.dataset_repo_id,
851
+ "commit_hash": dataset_commit_hash,
852
+ "version_label": new_dataset_version_label,
853
+ "commit_datetime": commit_datetime,
854
+ }
855
+
856
+ self.next(self.continued_pre_training)
857
+
858
+
859
    @step
    def continued_pre_training(self):
        """
        Gives the base model some additional intrinsic knowledge
        through continued pre-training.
        See unsloth.ai/blog/contpretraining
        """
        from retrain_pipelines.model.hf_utils import \
            plot_log_history

        #######################################
        # base-model and associated tokenizer #
        #     from Hub (or local cache)       #
        #######################################
        self.max_seq_length = 2048
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=self.hf_base_model_dict["repo_id"],
            revision=self.hf_base_model_dict["commit_hash"],
            max_seq_length=self.max_seq_length,
            dtype=None,
            load_in_4bit=False,
            # case of a gated or private base-model
            token=os.getenv("HF_TOKEN", None)
        )
        #######################################

        #######################################
        #   dataset prompt_template mapping   #
        #######################################
        # the unique-tools table built in `dataset_to_hub`
        # is the continued-pre-training corpus
        tools_dataset = DatasetDict(
            {"train": self.unique_tools_dataset})
        print(tools_dataset)
        tool_prompt_template = "tool: {}"
        def formatting_prompts_func(tools_batch):
            # batched `map` callable: one "tools" text per record
            tools_batch = tools_batch["tool"]
            outputs = []
            for tool in tools_batch:
                # Must add EOS_TOKEN,
                # otherwise generation will go on forever!
                text = tool_prompt_template.format(tool) + \
                       tokenizer.eos_token
                outputs.append(text)
            return { "tools" : outputs, }
        cpt_dataset = tools_dataset["train"].map(
            formatting_prompts_func, batched=True,)
        #######################################

        #######################################
        #            PEFT adapter             #
        #     for continued pre-training      #
        #######################################
        model = FastLanguageModel.get_peft_model(
            model,
            r = 128, # any number >0 ; 8, 16, 32, 64, 128, 256
            target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                              "gate_proj", "up_proj", "down_proj",
                              # Add for continued pretraining
                              "embed_tokens", "lm_head",],
            lora_alpha = 32,
            lora_dropout = 0, # Supports any, 0 is optimized
            bias = "none",    # Supports any, "none" is optimized
            # True or "unsloth" for very long context
            use_gradient_checkpointing = "unsloth",
            use_rslora = True,  # rank-stabilized LoRA
            loftq_config = None, # LoftQ
            #random_state = 3407,
        )
        #######################################

        #######################################
        #             cpt_trainer             #
        #######################################
        # optional dev-time cap on the number of training records
        if (
            "records_cap" in self.cpt_training_args and
            self.cpt_training_args["records_cap"] is not None and
            isinstance(self.cpt_training_args["records_cap"], int)
        ):
            cpt_dataset = cpt_dataset.take(
                self.cpt_training_args["records_cap"])
        print(f"cpt_dataset : {cpt_dataset}")

        train_args = UnslothTrainingArguments(
            # https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.save_strategy
            per_device_train_batch_size=2,
            gradient_accumulation_steps=8,

            # flow-level overrides win over the defaults above
            **{k: v for k, v in self.cpt_training_args.items()
               if k != "records_cap"},

            # 2 to 10x smaller learning rate
            # for the embedding matrices
            learning_rate=5e-5,
            embedding_learning_rate=1e-5,

            fp16=not is_bfloat16_supported(),
            bf16=is_bfloat16_supported(),
            logging_steps=1,
            optim="adamw_8bit",
            weight_decay=0.01,
            lr_scheduler_type="linear",
            #seed=3407,

            output_dir=os.path.join(
                self.unsloth_dir, "outputs", "cpt"),
            save_total_limit = 2,

            report_to="tensorboard",
            logging_dir=os.path.join(
                self.sft_model_dir,
                "runs", "cpt")
        )

        self.cpt_traces_file_fullname = os.path.join(
            self.unsloth_dir, "cpt_trainer_traces.txt")
        print("Training started. " +
              f"Check {self.cpt_traces_file_fullname} for live traces.",
              flush=True)

        trainer = UnslothTrainer(
            model=model, tokenizer=tokenizer,
            train_dataset=cpt_dataset,
            dataset_text_field="tools",
            max_seq_length=self.max_seq_length,
            dataset_num_proc=2,
            args=train_args,
        )
        #######################################

        #######################################
        #      Show current memory stats      #
        #######################################
        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()
        gc.collect()

        gpu_stats = torch.cuda.get_device_properties(0)
        # baseline reserved GPU memory (GB), reused by the
        # SFT step to compute training-only memory usage
        self.start_gpu_memory = \
            round(torch.cuda.max_memory_reserved()
                  / 1024 / 1024 / 1024, 3)
        self.max_memory = \
            round(gpu_stats.total_memory
                  / 1024 / 1024 / 1024, 3)
        print(f"GPU = {gpu_stats.name}. " +
              f"Max memory = {self.max_memory} GB.")
        print(f"{self.start_gpu_memory} GB of memory reserved.")
        #######################################

        # redirect trainer chatter to a traces file
        # so it can be tailed while the step runs
        with open(self.cpt_traces_file_fullname, 'w') as f:
            with redirect_stdout(f):
                hf_logging.set_verbosity_error()
                hf_logging.disable_progress_bar()
                trainer_stats = trainer.train()
                hf_logging.set_verbosity_info()
                hf_logging.enable_progress_bar()
        print(f"{trainer_stats.metrics['train_runtime']} " +
              f"seconds used for training " +
              f"({round(trainer_stats.metrics['train_runtime']/60, 2)}" +
              f" minutes).")

        self.cpt_log_history = trainer.state.log_history
        # print(self.cpt_log_history)
        self.cpt_log_history_fig = \
            plot_log_history(
                self.cpt_log_history,
                title="Continued pretraining loss"
            )

        # persist adapter + tokenizer locally for the SFT step
        model.save_pretrained_merged(
            save_directory=self.cpt_model_dir,
            tokenizer=tokenizer,
            save_method="lora"
        )
        print(f"cpt_model_dir : {self.cpt_model_dir}\n")

        self.next(self.supervised_finetuning)
1034
+
1035
+
1036
    @step
    def supervised_finetuning(self):
        """
        Trains the model on tool-calling
        task specialization.
        """
        from retrain_pipelines.model.hf_utils import \
            plot_log_history

        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()
        gc.collect()

        # resume from the locally-saved CPT adapter
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=self.cpt_model_dir,
            max_seq_length=self.max_seq_length,
            dtype=None,
            load_in_4bit=False,
        )
        # !!!! bug fix BEGIN !!!!
        # otherwise, 'embed_tokens' and 'lm_head'
        # trained during CPT are "ignored",
        # i.e. not saved after SFT
        # (note that, alternatively, we could also
        #  do this fix after sft-training and
        #  just before saving ;
        #  which would be equivalent to
        #  freezing embeddings during finetuning
        #  for better pretrained knowledge retention)
        # @see https://www.reddit.com/r/unsloth/comments/1dtzcd6/fastlanguagemodelpatch_peft_model_changing/
        model.model.model.embed_tokens.modules_to_save.default.to(
            device="cuda:0",
            dtype=torch.float32,
            non_blocking=True)
        model.model.model.embed_tokens.modules_to_save.default \
            .requires_grad_(True)
        model.model.lm_head.modules_to_save.default.to(
            device="cuda:0",
            dtype=torch.float32,
            non_blocking=True)
        model.model.lm_head.modules_to_save.default \
            .requires_grad_(True)
        # !!!! bug fix END !!!!

        #######################################
        #   dataset prompt_template mapping   #
        #######################################
        # download from Hub (or get from local cache)
        # the exact dataset version published by `dataset_to_hub`
        queries_dataset = load_dataset(
            path=self.dataset_commit_dict["repo_id"],
            name="supervised_finetuning",
            revision=self.dataset_commit_dict["commit_hash"],
            token=os.getenv("HF_TOKEN", None))
        print(f"HF_DATASETS_CACHE : {HF_DATASETS_CACHE}") # HF_CACHE_HOME
        self.sft_prompt_template = dedent("""
            You specialize in generating tool calls. Given a query, your task is to return a list of tool calls based on your knowledge of known tools.

            Rules:
            1. You can only use tools you know. Do not create new tools under any circumstances.
            2. If a query does not match any known tool, return an empty list ([]).
            3. If information is missing to use a known tool, do not attempt to use it.
            4. Your response must always be a valid JSON array, and nothing else.

            Be precise and do not guess.

            # query:
            {}
            # response:
            {}
            """).strip()
        tokenizer.chat_template = self.sft_prompt_template

        EOS_TOKEN = tokenizer.eos_token
        def formatting_prompts_func(records):
            # batched `map` callable: one prompt text per record
            query = records["query"]
            tools = records["answers"]
            outputs = []
            for query, tools in zip(query, tools):
                # Must add EOS_TOKEN,
                # otherwise your generation will go on forever
                text = self.sft_prompt_template.format(query, tools) \
                       + EOS_TOKEN
                outputs.append(text)
            return { "text" : outputs, }
        sft_train_dataset = queries_dataset["train"].map(
            formatting_prompts_func, batched=True)
        sft_valid_dataset = queries_dataset["validation"].map(
            formatting_prompts_func, batched=True,)
        #######################################

        #######################################
        #            PEFT adapter             #
        #      for supervised finetuning      #
        #######################################
        # for cases where CPT has been merged into overall model
        # otherwize, keep on training current LoRa adapter
        # model = FastLanguageModel.get_peft_model(
        #     model,
        #     r = 128, # any number >0 ; 8, 16, 32, 64, 128, 256
        #     target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
        #                       "gate_proj", "up_proj", "down_proj"],
        #     lora_alpha = 32,
        #     lora_dropout = 0, # Supports any, but = 0 is optimized
        #     bias = "none",    # Supports any, but = "none" is optimized
        #     # True or "unsloth" for very long context
        #     use_gradient_checkpointing = "unsloth",
        #     random_state = 3407,
        #     use_rslora = True,  # rank stabilized LoRA
        #     loftq_config = None, # LoftQ
        # )
        #######################################

        #######################################
        #             sft_trainer             #
        #######################################
        # hold out 1,000 records for in-training eval
        split = sft_train_dataset.train_test_split(
            test_size=1000,
            #seed=42
        )
        train_dataset = split['train']
        eval_dataset = split['test']
        # optional dev-time cap on the number of records
        if (
            "records_cap" in self.sft_training_args and
            self.sft_training_args["records_cap"] is not None and
            isinstance(self.sft_training_args["records_cap"], int)
        ):
            train_dataset = train_dataset.take(
                self.sft_training_args["records_cap"])
            eval_dataset = eval_dataset.take(
                self.sft_training_args["records_cap"])
        print(f"train_dataset : {train_dataset}")
        print(f"eval_dataset : {eval_dataset}")

        train_args = UnslothTrainingArguments(
            per_device_train_batch_size=2,
            gradient_accumulation_steps=8,

            # flow-level overrides win over the defaults above
            **{k: v for k, v in self.sft_training_args.items()
               if k != "records_cap"},

            per_device_eval_batch_size=2,
            eval_steps=200,
            eval_strategy="steps",
            do_eval=True,

            learning_rate=5e-5,
            # embedding_learning_rate=1e-5, # Optionally here

            fp16=not is_bfloat16_supported(),
            bf16=is_bfloat16_supported(),

            optim="adamw_8bit",
            weight_decay=0.00,
            lr_scheduler_type="linear",
            #seed=3407,

            output_dir=os.path.join(
                self.unsloth_dir, "outputs", "sft"),
            save_total_limit=2,

            logging_steps=1,
            report_to="tensorboard",
            logging_dir=os.path.join(
                self.sft_model_dir,
                "runs", "sft")
        )

        self.sft_traces_file_fullname = os.path.join(
            self.unsloth_dir, "sft_trainer_traces.txt")
        print("Training started. " +
              f"Check {self.sft_traces_file_fullname} for live traces.",
              flush=True)

        trainer = UnslothTrainer(
            model=model, tokenizer=tokenizer,
            train_dataset=train_dataset,
            dataset_text_field="text",
            eval_dataset=eval_dataset,
            max_seq_length=self.max_seq_length,
            dataset_num_proc=8,
            args=train_args
        )
        trainer.can_return_loss = True
        #######################################

        #######################################
        #      Show current memory stats      #
        #######################################
        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()
        gc.collect()

        used_memory = \
            round(torch.cuda.max_memory_reserved()
                  /1024/1024/1024, 3)
        # training-only usage, relative to the baseline
        # recorded in `continued_pre_training`
        used_memory_for_lora = \
            round(used_memory-self.start_gpu_memory, 3)
        used_percentage = \
            round(used_memory/self.max_memory*100, 3)
        lora_percentage = \
            round(used_memory_for_lora/self.max_memory*100,
                  3)
        print(f"Peak reserved memory = " +
              f"{used_memory} GB.")
        print(f"Peak reserved memory for " +
              f"training = {used_memory_for_lora} " +
              f"GB.")
        print(f"Peak reserved memory % of " +
              f"max memory = {used_percentage} %.")
        print(f"Peak reserved memory for training " +
              f"% of max memory = {lora_percentage} %.")
        #######################################

        # redirect trainer chatter to a traces file
        # so it can be tailed while the step runs
        with open(self.sft_traces_file_fullname, 'w') as f:
            with redirect_stdout(f):
                hf_logging.set_verbosity_error()
                hf_logging.disable_progress_bar()
                trainer_stats = trainer.train()
                hf_logging.set_verbosity_info()
                hf_logging.enable_progress_bar()
        print(f"{trainer_stats.metrics['train_runtime']} " +
              f"seconds used for training " +
              f"({round(trainer_stats.metrics['train_runtime']/60, 2)}" +
              f" minutes).")

        self.sft_log_history = trainer.state.log_history
        self.sft_log_history_fig = \
            plot_log_history(
                self.sft_log_history,
                title="Supervised finetuning loss"
            )

        # persist adapter + tokenizer locally,
        # consumed by `evaluate_model` and `model_to_hub`
        model.save_pretrained_merged(
            self.sft_model_dir, tokenizer,
            save_method = "lora"
        )
        print(f"sft_model_dir : {self.sft_model_dir}\n")

        self.next(self.evaluate_model)
1275
+
1276
+
1277
    @step
    def evaluate_model(self):
        """
        Batch inference on the SFT validation dataset.
        """
        from retrain_pipelines.model import \
            infer_validation, compute_counts_n_metrics, \
            plot_validation_completions

        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()
        gc.collect()


        ######################################################
        #             loading trained adapter                #
        ######################################################
        # Unsloth (if loading both model & tokenizer at once #
        # same as we did in prior tasks, but now             #
        # with tokenizer.chat_template being set             #
        # in tokenizer.config) is forcing on us some kind of #
        # chat_template format hard-requirements             #
        # coming from their dream-fantasmagorical world..    #
        ######################################################
        # load base from cache
        # (with base tokenizer, which we ignore)
        model, _ = FastLanguageModel.from_pretrained(
            model_name=self.hf_base_model_dict["repo_id"],
            revision=self.hf_base_model_dict["commit_hash"],
            max_seq_length=self.max_seq_length,
            dtype=None,
            load_in_4bit=False,
            # case of a gated or private base-model
            token=os.getenv("HF_TOKEN", None)
        )
        model = FastLanguageModel.for_inference(model)
        # load our CPT+SFT trained & locally-saved adapter
        model.load_adapter(peft_model_id=self.sft_model_dir)
        # Separately load our (potentially trained &)
        # locally-saved adapter-tokenizer
        # (loading it below via HF and not Unsloth)
        tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=self.sft_model_dir
        )
        ######################################################

        ######################################################
        #                validation dataset                  #
        ######################################################
        # download from Hub (or get from local cache)
        # the exact version published by `dataset_to_hub`
        queries_dataset = load_dataset(
            path=self.dataset_commit_dict["repo_id"],
            name="supervised_finetuning",
            revision=self.dataset_commit_dict["commit_hash"],
            token=os.getenv("HF_TOKEN", None))
        # same optional dev-time cap as used during training
        if (
            "records_cap" in self.sft_training_args and
            self.sft_training_args["records_cap"] is not None and
            isinstance(self.sft_training_args["records_cap"], int)
        ):
            validation_data = queries_dataset["validation"].take(
                self.sft_training_args["records_cap"])
        else:
            validation_data = queries_dataset["validation"]
        print(validation_data, flush=True)
        ######################################################

        self.max_new_tokens = 400
        start_time = time.time()
        validation_results = infer_validation(
            tokenizer=tokenizer,
            model=model,
            validation_data=validation_data,
            prompt_template=tokenizer.chat_template,
            batch_size=32, # 64,
            queries_attr_name=\
                self.hf_dataset["attributes"]["query_attr"],
            answers_attr_name=\
                self.hf_dataset["attributes"]["answers_attr"],
            max_new_tokens=self.max_new_tokens,
            device="cuda"
        )
        print("infer_validation - Elapsed time: " +
              f"{(time.time() - start_time):.2f} seconds")
        self.validation_results = validation_results # <= to artifacts store

        eval_df = pl.LazyFrame(validation_results)

        # strict string equality between expected answer
        # and generated completion
        records = eval_df.with_columns(
            (pl.col("answer") == pl.col("completion")) \
                .alias("is_ground_truth_identical")
        ).collect() #engine=self.engine)
        print("perfect characters-match accuracy : " +
              str(records['is_ground_truth_identical'].mean()))

        # per-record counts/metrics, then averaged overall ;
        # `perf_metrics` is later consumed by
        # `model_version_blessing` and `model_to_hub`
        eval_metrics_df = compute_counts_n_metrics(
            eval_df, is_format_fault_tolerant=True)
        overall_metrics_df = eval_metrics_df.select([
            pl.col("precision").mean(),
            pl.col("recall").mean(),
            pl.col("f1").mean(),
            pl.col("jaccard").mean()
        ]).collect() #engine=self.engine)
        self.perf_metrics = overall_metrics_df.row(0, named=True)
        print(self.perf_metrics)

        self.validation_completions_fig = \
            plot_validation_completions(
                eval_metrics_df, engine=self.engine)

        # release GPU memory before the next step
        del model
        del tokenizer
        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()
        gc.collect()

        self.next(self.model_version_blessing)
1394
+
1395
+
1396
+ @step
1397
+ def model_version_blessing(self):
1398
+ """
1399
+ Comparing newly-retrained model version
1400
+ against best-performing predecessor.
1401
+ """
1402
+ """
1403
+ Note: for Hugging Face integrated pipelines,
1404
+ we compare against lastest commit of main branch
1405
+ of the model repository there.
1406
+ When it comes to local "mf_run_id" of the pipeline run
1407
+ having generated that best prior model version
1408
+ (retrieved from model card metadata from HF yaml section),
1409
+ we check against records of the herein ML-framework instance,
1410
+ as "prior best version" of the model here beign retrained
1411
+ may have been originated from another one
1412
+ than the one executing the current retraining
1413
+ (in which case, we simply don't includ a "local" hyperlink
1414
+ in the model version pipeline_cards that will be
1415
+ produced later in the herein pipeline run).
1416
+ """
1417
+ from retrain_pipelines.model.hf_utils import \
1418
+ current_blessed_model_version_dict
1419
+
1420
+ main_perf_metric_name = "jaccard"
1421
+
1422
+ current_blessed_version_dict = \
1423
+ current_blessed_model_version_dict(
1424
+ repo_id=self.model_repo_id,
1425
+ hf_token=os.getenv("HF_TOKEN", None)
1426
+ )
1427
+ print("current_blessed_version_dict : " +
1428
+ str(current_blessed_version_dict))
1429
+
1430
+ if current_blessed_version_dict is None:
1431
+ print("case 'no prior blessed model version found"
1432
+ " => blessing.'")
1433
+ self.model_version_blessed = True
1434
+
1435
+ elif (
1436
+ main_perf_metric_name in
1437
+ current_blessed_version_dict["perf_metrics"]
1438
+ ):
1439
+ current_blessed_run_id = \
1440
+ current_blessed_version_dict["mf_run_id"]
1441
+ current_blessed_metric_value = \
1442
+ current_blessed_version_dict[
1443
+ "perf_metrics"][main_perf_metric_name]
1444
+
1445
+ self.model_version_blessed = (
1446
+ self.perf_metrics[main_perf_metric_name] >=
1447
+ current_blessed_metric_value
1448
+ )
1449
+
1450
+ if not self.model_version_blessed:
1451
+ self.current_blessed_version_dict = \
1452
+ current_blessed_version_dict
1453
+ for run in Flow(self.__class__.__name__):
1454
+ if str(run.id) == current_blessed_run_id:
1455
+ last_run_step = next(iter(run.steps()))
1456
+ last_task = next(iter(last_run_step.tasks()))
1457
+ # further filtering on successful runs that are
1458
+ # retraining of a prior version of the same model
1459
+ # (to minimize the risk that this was obtained
1460
+ # on another ML-framework instance)
1461
+ if (
1462
+ last_task.successful and
1463
+ hasattr(last_task.artifacts, 'model_version_blessed') and
1464
+ last_task.artifacts.model_version_blessed.data and
1465
+ hasattr(last_task.artifacts, 'model_repo_id') and
1466
+ last_task.artifacts.model_repo_id.data == self.model_repo_id
1467
+ ):
1468
+ self.current_blessed_run = run
1469
+ break
1470
+
1471
+ if not self.current_blessed_run:
1472
+ print(
1473
+ "Couldn't find blessed run " +
1474
+ f"{current_blessed_run_id} !\n" +
1475
+ "It seems that prior blessed run was " +
1476
+ "executed on another ML framework instance.",
1477
+ file=sys.stderr, flush=True)
1478
+
1479
+ print("new : " +
1480
+ str(self.perf_metrics[main_perf_metric_name]) +
1481
+ " - previous best : " +
1482
+ str(current_blessed_metric_value) +
1483
+ " - model_version_blessing : " +
1484
+ str(self.model_version_blessed))
1485
+
1486
+ else:
1487
+ raise Exception(
1488
+ "Performance metric '" +
1489
+ main_perf_metric_name +
1490
+ "' can't be found in eval results " +
1491
+ "from blessed run " +
1492
+ str(current_blessed_version_dict[
1493
+ "mf_run_id"]) + " !")
1494
+
1495
+ # self.model_version_blessed = True ### DEBUG - DELETE ###
1496
+
1497
+ self.next(self.model_to_hub)
1498
+
1499
+
1500
    @step
    def model_to_hub(self):
        """
        Push to hub model version, including
        readme with versioning info.
        """

        #############################
        # case of user-provided     #
        # documentation artifact(s) #
        #############################
        # note that user can provide either
        # 'pipeline_card.py' or 'template.html'
        # or 'dataset_readme.py'
        # or 'dataset_readme_template.md'
        # or 'model_readme.py'
        # or 'model_readme_template.md'
        # or any combination of those
        # when specifying custom
        # 'pipeline_card_artifacts_path'
        if (
            "model_readme_template.md" in
            os.listdir(self.pipeline_card_artifacts_path)
        ):
            template_dir = self.pipeline_card_artifacts_path
        else:
            # fall back on the default template shipped
            # with the installed retrain_pipelines package
            template_dir = os.path.dirname(
                importlib.util.find_spec(
                    f"retrain_pipelines.pipeline_card."+
                    f"{os.getenv('retrain_pipeline_type')}"
                ).origin)
        print(f"template_dir : '{template_dir}'")
        #############################
        # user-provided readme-generator module, if any
        if "model_readme.py" in os.listdir(
                self.pipeline_card_artifacts_path):
            from retrain_pipelines.utils import \
                get_get_model_readme_content
            get_model_readme_content = \
                get_get_model_readme_content(
                    self.pipeline_card_artifacts_path)
        else:
            from retrain_pipelines.pipeline_card import \
                get_model_readme_content
        #############################
        from retrain_pipelines.model.hf_utils import \
            push_model_version_to_hub

        #############################
        #       model README        #
        #       from template       #
        #############################
        # NOTE(review): `datetime.utcnow()` is deprecated
        # (Python 3.12+) ; switching to the tz-aware
        # `datetime.now(timezone.utc)` would change the artifact
        # type — confirm downstream consumers before migrating.
        commit_datetime = datetime.utcnow()
        new_model_version_label = get_new_repo_minor_version(
            repo_id=self.model_repo_id,
            repo_type="model",
            hf_token=os.getenv("HF_TOKEN", None))
        readme_content = get_model_readme_content(
            template_folder=template_dir,

            model_repo_id=self.model_repo_id,

            base_model_dict=self.hf_base_model_dict,
            training_dataset_dict=self.dataset_commit_dict,

            version_label=new_model_version_label,
            commit_datetime=commit_datetime,
            perf_metrics=self.perf_metrics,

            mf_flow_name=current.flow_name,
            mf_run_id=current.run.id
        )
        #############################

        # note: the model version is pushed even when not blessed
        # (blessing status is recorded alongside the version)
        print("Pushing model version to HF hub " +
              ("(blessed). " if self.model_version_blessed
               else "(not blessed). ") +
              "May take a while..",
              flush=True)
        model_commit_hash = push_model_version_to_hub(
            repo_id=self.model_repo_id,
            model_version_blessed=\
                self.model_version_blessed,
            version_label=new_model_version_label,
            timestamp_str=commit_datetime.strftime(
                "%Y-%m-%d %H:%M:%S UTC"),
            model_dir=self.sft_model_dir,
            model_readme_content=readme_content,
            hf_token=os.getenv("HF_TOKEN", None)
        )
        if not model_commit_hash:
            raise Exception(
                "Failed to publish model version.")
        print("Push of model version to HF hub completed.",
              flush=True)
        print(f"https://huggingface.co/{self.model_repo_id}" +
              f"/blob/{model_commit_hash}/README.md")

        # versioning info consumed by downstream steps
        self.model_commit_dict = {
            "repo_id": self.model_repo_id,
            "commit_hash": model_commit_hash,
            "version_label": new_model_version_label,
            "commit_datetime": commit_datetime,
        }

        self.next(self.infra_validator)
1605
+
1606
+
1607
    @step
    def infra_validator(self):
        """
        If the trained model version is blessed,
        validate serving.

        Spins up a local LitServe docker container and probes it,
        recording the outcome in `self.local_serve_is_ready`
        (a `LocalServeReadinessEnum` member).
        """
        """
        Note that using isolated virtual env
        (using @conda task decorator)
        is advisable to not embark the whole
        pipeline dependencies into the local server.
        We don't for educational purpose,
        keep things "simple" to grasp
        as well as to avoid forcing conda
        (for instance miniconda) as
        a virtual environment management mean
        to the user.
        """
        """
        Note : We load base model from HF-cache
        (mounted as /huggingface_hub_cache
        docker volume) and adapter from local dir
        (mounted as /FuncCallAdapter docker volume).
        """

        # default outcome (used when the model version is not blessed)
        self.local_serve_is_ready = LocalServeReadinessEnum.NOT_APPLICABLE

        if self.model_version_blessed:
            from retrain_pipelines.utils.docker import \
                env_has_docker

            if env_has_docker():
                # directory of the pipeline-type-specific model module
                # (where the LitServe serving artifacts live)
                model_module_dir = \
                    os.path.dirname(
                        importlib.util.find_spec(
                            "retrain_pipelines.model." +
                            os.getenv('retrain_pipeline_type')
                        ).origin)

                # server & data-model & server-config modules artifacts
                files_to_copy = [
                    "litserve_server.py",
                    "litserve_datamodel.py",
                    "litserve_serverconfig.py",
                    ".dockerignore" # docker context loading
                                    # at image-build time,
                                    # exclude model weights
                ]
                for filename in files_to_copy:
                    shutil.copy(
                        os.path.join(model_module_dir, "litserve",
                                     filename),
                        os.path.join(self.serving_artifacts_local_folder,
                                     filename)
                    )

                # save dependencies as artifact
                # (CUDA/GPU packages and the pipeline's own tooling are
                #  excluded from the serving container's requirements)
                create_requirements(self.serving_artifacts_local_folder,
                    exclude=["cudf-polars-.*", "cuda-python",
                             "nvidia-.*", "(py)?libcudf-.*",
                             "nvtx", "rmm-.*", "litserve",
                             ".*retrain-pipelines.*"]
                )

                # server config yaml
                env = Environment(loader=FileSystemLoader(
                    os.path.join(model_module_dir, "litserve")))
                template = env.get_template(
                    "litserve_serverconfig_template.yaml")
                server_config_data = {
                    # NOTE(review): key is "max_new_token" (singular) while
                    # the flow attribute is `max_new_tokens` — confirm the
                    # template expects the singular form.
                    "port": "8000",
                    "max_seq_length": self.max_seq_length,
                    "max_new_token": self.max_new_tokens,
                    "base_model": {
                        "repo_id": self.hf_base_model_dict["repo_id"],
                        "revision": self.hf_base_model_dict["commit_hash"]
                    },
                    "adapters": [
                        {
                            "name": "func_caller",
                            "path": "/FuncCallAdapter"
                        }
                    ]
                }
                server_config_yaml = template.render(server_config_data)
                print(server_config_yaml)
                with open(os.path.join(
                        self.serving_artifacts_local_folder,
                        "litserve_serverconfig.yaml"), 'w'
                ) as output_file:
                    output_file.write(server_config_yaml)

                # Dockerfile
                env = Environment(loader=FileSystemLoader(
                    os.path.join(model_module_dir)))
                template = env.get_template(
                    "Dockerfile.litserve_template")
                # Change CUDA version here from available list
                # @see https://hub.docker.com/r/nvidia/cuda/tags
                dockerfile_content = template.render(
                    {"cuda_version": "12.0.0"})
                with open(os.path.join(
                        self.serving_artifacts_local_folder,
                        "Dockerfile.litserve"), 'w'
                ) as output_file:
                    output_file.write(dockerfile_content)

                # make sure probes against the local endpoint bypass any proxy
                os.environ["no_proxy"] = "localhost,127.0.0.1,0.0.0.0"

                ############################################
                #   actually deploy the inference service  #
                ############################################
                start_time = time.time()
                from retrain_pipelines.utils.docker import \
                    build_and_run_docker, print_container_log_tail, \
                    cleanup_docker
                from retrain_pipelines.model.litserve import \
                    endpoint_started, endpoint_is_ready

                # host port the container's 8000/tcp is published on
                self.port = 8765
                # resolve the host-side HF cache dir, honoring
                # HF_HUB_CACHE then HF_HOME, defaulting to ~/.cache/huggingface/hub
                HF_HUB_CACHE = os.path.realpath(os.path.expanduser(
                    os.getenv(
                        "HF_HUB_CACHE",
                        os.path.join(os.getenv("HF_HOME",
                                               "~/.cache/huggingface"),
                                     "hub")
                    )))
                print(f"HF_HUB_CACHE : {HF_HUB_CACHE}")
                image_name = container_name = "litserve-model"

                serving_container = build_and_run_docker(
                    image_name=image_name, image_tag="1.0",
                    build_path=self.serving_artifacts_local_folder,
                    dockerfile="Dockerfile.litserve",
                    ports_publish_dict={'8000/tcp': self.port},
                    env_vars_dict={
                        "HF_HUB_CACHE": "/huggingface_hub_cache",
                        "HF_TOKEN": os.getenv("HF_TOKEN")
                    },
                    volumes_dict={
                        self.sft_model_dir:
                            {"bind": "/FuncCallAdapter",
                             "mode": "ro"},
                        HF_HUB_CACHE:
                            {"bind": "/huggingface_hub_cache",
                             "mode": "ro"}
                    }
                )

                if not serving_container:
                    print("failed spinning the LitServe container",
                          file=sys.stderr)
                    self.local_serve_is_ready = \
                        LocalServeReadinessEnum.FAILURE
                    try:
                        cleanup_docker(
                            container_name=container_name,
                            image_name=f"{image_name}:1.0",
                            no_pruning=True # for intermediate layers recycling
                                            # (during later re-runs)
                                            # to avoid long rebuild time
                                            # of exactly the same.
                        )
                    except Exception as cleanup_ex:
                        # fail silently
                        pass
                else:
                    print("Awaiting endpoint launch..")
                    start_time = time.time()
                    if not endpoint_started(
                        container_name, port=self.port, timeout=10*60
                    ):
                        # NOTE(review): on this path the container is left
                        # running — cleanup only happens further below when
                        # readiness was reached. Confirm this is intended.
                        print(
                            f"The endpoint '{container_name}' " +
                            f"did not start.")
                        self.local_serve_is_ready = \
                            LocalServeReadinessEnum.FAILURE
                    # health check on the spun-up endpoint
                    elif endpoint_is_ready(port=self.port):
                        self.local_serve_is_ready = \
                            LocalServeReadinessEnum.SUCCESS
                    elapsed_time = time.time() - start_time
                    print("deploy_local - Elapsed time: " +
                          f"{elapsed_time:.2f} seconds")
                ############################################
            else:
                # env doesn't have docker
                self.local_serve_is_ready = \
                    LocalServeReadinessEnum.FAILURE_NO_DOCKER

        if LocalServeReadinessEnum.SUCCESS == self.local_serve_is_ready:
            # smoke-test the live endpoint :
            # one request with the func_caller adapter, one without
            from retrain_pipelines.model.litserve.litserve_datamodel \
                import Response

            import requests

            url = f"http://localhost:{self.port}/predict"
            headers = {"accept": "application/x-www-form-urlencoded"}

            try:
                start_time = time.time()
                data = {
                    "adapter_name": "func_caller",
                    "queries": '["Hello.", "Is 49 a perfect square?"]'
                }
                print(f"inference test - data: {data}")
                response = requests.post(url, headers=headers, data=data)
                parsed_response = Response(**{"output": response.json()})
                elapsed_time = time.time() - start_time
                print("parsed_response ('func_caller' adapter ON) :" +
                      str(parsed_response) +
                      f"\t-\tElapsed time: {elapsed_time:.2f} seconds")

                start_time = time.time()
                data = {
                    "queries": '["Hello.", "Is 49 a perfect square?"]'
                }
                print(f"inference test - data: {data}")
                response = requests.post(url, headers=headers, data=data)
                parsed_response = Response(**{"output": response.json()})
                elapsed_time = time.time() - start_time
                print(f"parsed_response (no adapter) : {parsed_response}" +
                      f"\t-\tElapsed time: {elapsed_time:.2f} seconds")

            except Exception as ex:
                # a failed smoke-test downgrades readiness to FAILURE
                print(ex, file=sys.stderr)
                traceback.print_tb(ex.__traceback__, file=sys.stderr)
                self.local_serve_is_ready = \
                    LocalServeReadinessEnum.FAILURE
                pass

            try:
                cleanup_docker(
                    container_name=container_name,
                    image_name=f"{image_name}:1.0",
                    no_pruning=True # for intermediate layers recycling
                                    # (during later re-runs)
                                    # to avoid long rebuild time
                                    # of exactly the same.
                )
            except Exception as cleanup_ex:
                # fail silently
                pass

        self.next(self.pipeline_card)
1852
+
1853
+
1854
    @card(id='default')
    @card(type='html', id='custom')
    @step
    def pipeline_card(self):
        """
        Build the run's documentation cards :
        the Metaflow 'default' card (metadata + figures)
        and a 'custom' html card rendered from a template,
        honoring user-provided overrides when present in
        `self.pipeline_card_artifacts_path`.
        """
        import re
        import datetime
        import importlib.metadata

        #############################
        # case of user-provided     #
        # documentation artifact(s) #
        #############################
        # note that user can provide either
        # 'pipeline_card.py' or 'template.html'
        # or 'dataset_readme.py'
        # or 'dataset_readme_template.md'
        # or 'model_readme.py'
        # or 'model_readme_template.md'
        # or any combination of those
        # when specifying custom
        # 'pipeline_card_artifacts_path'
        if "template.html" in os.listdir(
            self.pipeline_card_artifacts_path
        ):
            template_dir = self.pipeline_card_artifacts_path
        else:
            # fall back to the packaged, pipeline-type-specific template
            template_dir = os.path.dirname(
                importlib.util.find_spec(
                    f"retrain_pipelines.pipeline_card."+
                    f"{os.getenv('retrain_pipeline_type')}"
                ).origin)
        #############################
        if "pipeline_card.py" in os.listdir(
            self.pipeline_card_artifacts_path
        ):
            # user-provided card-rendering function
            from retrain_pipelines.utils import get_get_html
            get_html = \
                get_get_html(self.pipeline_card_artifacts_path)
        else:
            from retrain_pipelines.pipeline_card import \
                get_html
        from retrain_pipelines.pipeline_card.helpers import \
            mf_dag_svg
        #############################


        #############################
        ##      "default" card     ##
        #############################
        # NOTE(review): name/description mention "TabNet" while this flow
        # fine-tunes an Unsloth function-calling model — looks like a
        # copy/paste leftover; confirm before relying on this metadata.
        self.metadata = {
            "name": "TabNet Model",
            "version": "1.0",
            "retrain_pipelines": f"retrain-pipelines {__version__}",
            "retrain_pipeline_type": os.environ["retrain_pipeline_type"],
            "description": "A PyTorch TabNet model retrained",
            "authors": [current.username],
            "tags": ["classification", "tabnet"],
            "license": "MIT License",
            "data_augmentation": [
                {
                    "name": "Augmentation",
                    "description": "Truncating queries and " + \
                                   "associate those to " + \
                                   "no tool-call answers. " + \
                                   "Intent being to instruct on " + \
                                   "not hallucinating missing " + \
                                   "tool-calls parameters values."
                },
                {
                    "name": "Enrichment",
                    "description": "Addition of records " + \
                                   "from an external data-source. " + \
                                   "Here to instruct on no tool-call."
                }
            ],
            "references": [
                {
                    "title": "Base model",
                    "link": f"https://hf.co/{self.hf_base_model_dict['repo_id']}"
                },
                {
                    "title": "Function-calling dataset",
                    "link": f"https://hf.co/{self.hf_dataset_dict['repo_id']}"
                },
                {
                    "title": "Data-enrichment dataset",
                    "link": f"https://hf.co/{self.hf_enrich_dataset_dict['repo_id']}"
                },
                {
                    "title": "Unsloth",
                    "link": "https://unsloth.ai/blog/contpretraining"
                }
            ]
        }

        current.card['default'].append(Markdown(
            "model_version_blessed : **%s**" % str(self.model_version_blessed)))
        current.card['default'].append(Artifact(
            {"model_version_blessed": self.model_version_blessed}))

        current.card['default'].append(
            Image.from_matplotlib(self.sft_log_history_fig))
        current.card['default'].append(
            Image.from_matplotlib(self.validation_completions_fig))
        #############################

        #############################
        ##    html "custom" card   ##
        #############################
        dt = datetime.datetime.now(tz=datetime.timezone.utc)
        formatted_dt = dt.strftime("%A %b %d %Y %I:%M:%S %p %Z")
        # python snippet users can paste to retrieve this very task object
        task_obj_python_cmd = f"metaflow.Task(" + \
                              f"\"{current.pathspec}\", " + \
                              f"attempt={str(current.retry_count)})"
        params={
            'template_dir': template_dir,
            'title': f"{current.flow_name}",
            "subtitle": f"(flow run # {len(list(current.run.parent.runs()))}," + \
                        f" run_id: {str(current.run.id)} - {formatted_dt})",
            'model_version_blessed': self.model_version_blessed,
            # 'current_blessed_run': self.current_blessed_run,
            'current_blessed_model_commit_hash': (
                self.current_blessed_version_dict["commit_hash"]
                if self.current_blessed_version_dict
                else None
            ),
            'LocalServeReadinessEnum': LocalServeReadinessEnum,
            'local_serve_is_ready': self.local_serve_is_ready,
            # EDA
            'main_dataset_repo_id': self.hf_dataset['repo_id'],
            'main_dataset_commit_hash': self.hf_dataset_dict['commit_hash'],
            'main_dataset_commit_datetime': \
                self.hf_dataset_dict['commit_datetime'],

            'records_count': self.records_count,
            'data_schema': self.data_schema,
            'answers_tools_count_fig': self.answers_tools_count_fig,
            'words_count_fig': self.words_count_fig,

            # model training
            'dataset_repo_id': self.dataset_repo_id,
            'dataset_version_label': self.dataset_commit_dict["version_label"],
            'dataset_commit_datetime': self.dataset_commit_dict["commit_datetime"],
            'dataset_commit_hash': self.dataset_commit_dict["commit_hash"],
            'dataset_augmentation_rate': self.actual_augmentation_rate,
            'dataset_enrichment_rate': self.enrichment_rate,

            'model_repo_id': self.model_repo_id,
            'model_version_label': self.model_commit_dict["version_label"],
            'model_commit_datetime': self.model_commit_dict["commit_datetime"],
            'model_commit_hash': self.model_commit_dict["commit_hash"],

            'cpt_log_history_fig': self.cpt_log_history_fig,
            'sft_log_history_fig': self.sft_log_history_fig,

            'validation_completions_fig': self.validation_completions_fig,

            'pipeline_parameters_dict': {"cpt": self.cpt_training_args,
                                         "sft": self.sft_training_args},

            'metrics_dict': self.perf_metrics,

            'task_obj_python_cmd': task_obj_python_cmd,
            'dag_svg': mf_dag_svg(self)
        }
        self.html = get_html(params)
        #############################
        # NOTE(review): bare `current` below is a no-op expression —
        # looks like a leftover; candidate for removal.
        current
        #############################

        self.next(self.pipeline_to_hub)
2025
+
2026
+
2027
    @step
    def pipeline_to_hub(self):
        """
        publish versioned source-code and pipeline-card
        for this run on the Hugging Face Hub.

        Pushes to two dedicated branches of the model repo
        ('retrain-pipelines_source-code' and
        'retrain-pipelines_pipeline-card'), under a subfolder
        named after the just-published model version.
        Records the resulting commits in
        `self.source_code_commit_dict` and
        `self.pipeline_card_commit_dict`.
        """

        # derive the versioned subfolder name, e.g.
        # "v0.10_20250318_214952149_UTC", from the model commit
        model_commit_datetime = \
            self.model_commit_dict["commit_datetime"]
        timestamp_str = \
            "{:%Y%m%d_%H%M%S}".format(model_commit_datetime) + \
            "{:03d}".format(model_commit_datetime.microsecond//1000) + \
            "_UTC"
        subfolder_name = \
            "v" + self.model_commit_dict["version_label"] + \
            "_" + timestamp_str
        # NOTE(review): naive utcnow() — consider a tz-aware datetime
        # for consistency with datetimes compared elsewhere.
        commit_datetime = datetime.utcnow()

        ###############################
        #         source-code         #
        ###############################
        # We upload only herein file  #
        # plus user-provided versions #
        # of the customizable ones    #
        # (if any).                   #
        ###############################
        custom_source_files = [os.path.abspath(__file__)]
        if (
            self.pipeline_card_artifacts_path != \
            self.default_pipeline_card_module_dir
        ):
            # user-supplied documentation artifacts, any subset of :
            candidate_source_files = [
                "pipeline_card.py",
                "template.html",
                "dataset_readme.py",
                "dataset_readme_template.md",
                "model_readme.py",
                "model_readme_template.md"
            ]
            for candidate_source_file in candidate_source_files:
                file_fullpath = os.path.join(
                    self.pipeline_card_artifacts_path,
                    candidate_source_file)
                if os.path.exists(file_fullpath):
                    custom_source_files.append(file_fullpath)

        source_code_commit_hash = \
            push_files_to_hub_repo_branch(
                repo_id=self.model_repo_id,
                branch_name="retrain-pipelines_source-code",
                file_fullnames=custom_source_files,
                include_requirements_txt=True,
                path_in_repo=subfolder_name,
                commit_message=\
                    "source-code for model version " + \
                    subfolder_name + \
                    f"- retrain-pipelines {__version__}",
                repo_type="model",
                hf_token=os.getenv("HF_TOKEN", None)
            )
        print(source_code_commit_hash)
        self.source_code_commit_dict = {
            "repo_id": self.model_repo_id,
            "branch_name": "retrain-pipelines_source-code",
            "commit_datetime": commit_datetime,
            "commit_hash": source_code_commit_hash
        }
        ###############################

        ###############################
        #        pipeline-card        #
        ###############################
        # locate the html card file generated by the
        # 'pipeline_card' step of this very run
        pipeline_card_fullname = None
        for run_step in current.run.steps():
            task = list(run_step.tasks())[0]
            task_name = task.path_components[2]
            if "pipeline_card" == task_name:
                pipeline_card = get_cards(
                    task, id='custom', type='html')[0]
                pipeline_card_fullname = os.path.realpath(
                    os.path.join(
                        task.metadata_dict.get("ds-root", None),
                        mf_config.CARD_SUFFIX, pipeline_card.path
                    ))
                print(pipeline_card_fullname)
                break
        pipeline_card_commit_hash = \
            push_files_to_hub_repo_branch(
                repo_id=self.model_repo_id,
                branch_name="retrain-pipelines_pipeline-card",
                file_fullnames=[pipeline_card_fullname],
                path_in_repo=subfolder_name,
                commit_message=\
                    "pipeline-card for model version " + \
                    subfolder_name + \
                    f"- retrain-pipelines {__version__}",
                repo_type="model",
                hf_token=os.getenv("HF_TOKEN", None)
            )
        print(pipeline_card_commit_hash)
        self.pipeline_card_commit_dict = {
            "repo_id": self.model_repo_id,
            "branch_name": "retrain-pipelines_pipeline-card",
            "commit_datetime": commit_datetime,
            "commit_hash": pipeline_card_commit_hash
        }
        ###############################

        self.next(self.deploy)
2136
+
2137
+
2138
+ @step
2139
+ def deploy(self):
2140
+ """
2141
+ placeholder for the serving SDK deploy call
2142
+ (on the target production platform).
2143
+ Include any artifact you want,
2144
+ consider including the portable pipelione-card
2145
+ itself !
2146
+ """
2147
+
2148
+ if (
2149
+ self.model_version_blessed and
2150
+ (self.local_serve_is_ready == LocalServeReadinessEnum.SUCCESS)
2151
+ ):
2152
+ pass # your code here
2153
+
2154
+ self.next(self.load_test)
2155
+
2156
+
2157
+ @step
2158
+ def load_test(self):
2159
+ """
2160
+ placeholder
2161
+ """
2162
+
2163
+ if (
2164
+ self.model_version_blessed and
2165
+ (self.local_serve_is_ready == LocalServeReadinessEnum.SUCCESS)
2166
+ ):
2167
+ pass # your code here
2168
+
2169
+ self.next(self.end)
2170
+
2171
+
2172
+ @step
2173
+ def end(self):
2174
+ pass
2175
+
2176
+
2177
    # script entry point : instantiating the flow hands control
    # over to the Metaflow runtime (parses CLI args, runs the DAG)
    if __name__ == "__main__":
        UnslothFuncCallFlow()
2179
+