File size: 24,709 Bytes
9d6cb8e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 |
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The intent of the tests contained in this file is to check as many PEFT features as possible with torch.compile. This
# is thus a document on how well torch.compile is supported by PEFT. Currently, we know that certain features do not
# work with torch.compile. The corresponding tests should be marked with `@pytest.mark.xfail(strict=True)`.
#
# When adding a new test that fails with torch.compile, please make sure first that it does NOT fail without
# torch.compile.
import gc
import os
import pytest
import torch
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
DataCollatorForLanguageModeling,
Trainer,
TrainingArguments,
)
from peft import (
AdaLoraConfig,
BOFTConfig,
IA3Config,
LNTuningConfig,
LoHaConfig,
LoKrConfig,
LoraConfig,
OFTConfig,
PeftModel,
TaskType,
VeraConfig,
get_peft_model,
)
# Only run the (very slow) torch.compile tests when explicitly asked to. Give a reason so that the skip summary tells
# the user how to enable them.
if os.environ.get("PEFT_DEBUG_WITH_TORCH_COMPILE") != "1":
    pytest.skip(
        "torch.compile tests are slow; set PEFT_DEBUG_WITH_TORCH_COMPILE=1 to run them",
        allow_module_level=True,
    )
# Parametrization table for the training tests: setting name -> (PEFT config instance, torch.compile kwargs).
# The kwargs dict may also hold `torch_compile_backend` / `torch_compile_mode`; those two keys are only used for the
# Trainer arguments and are stripped before calling torch.compile directly.
# Settings that are known not to work with torch.compile are wrapped in `pytest.param` with a strict xfail mark.
SETTINGS = {
    "adalora": (AdaLoraConfig(task_type=TaskType.CAUSAL_LM), dict()),
    "boft": (BOFTConfig(task_type=TaskType.CAUSAL_LM), dict()),
    "dora": (LoraConfig(task_type=TaskType.CAUSAL_LM, use_dora=True), dict()),
    "ia3": (IA3Config(task_type=TaskType.CAUSAL_LM), dict()),
    "ln_tuning": (
        LNTuningConfig(task_type=TaskType.CAUSAL_LM, target_modules=["final_layer_norm"]),
        dict(),
    ),
    "loha": (
        LoHaConfig(task_type=TaskType.CAUSAL_LM, target_modules=["q_proj", "v_proj"]),
        dict(),
    ),
    "lokr": pytest.param(
        (
            LoKrConfig(task_type=TaskType.CAUSAL_LM, target_modules=["q_proj", "v_proj"]),
            dict(),
        ),
        marks=pytest.mark.xfail(strict=True),
    ),
    "lora": (LoraConfig(task_type=TaskType.CAUSAL_LM), dict()),
    "lora-target-embeddings": pytest.param(
        (
            LoraConfig(task_type=TaskType.CAUSAL_LM, target_modules=["embed_tokens"]),
            dict(),
        ),
        marks=pytest.mark.xfail(strict=True),
    ),
    "lora-with-modules-to-save": (
        LoraConfig(task_type=TaskType.CAUSAL_LM, modules_to_save=["embed_tokens"]),
        dict(),
    ),
    "oft": (
        OFTConfig(task_type=TaskType.CAUSAL_LM, target_modules=["q_proj", "v_proj"]),
        dict(),
    ),
    "vera": (VeraConfig(task_type=TaskType.CAUSAL_LM), dict()),
}
@pytest.mark.single_gpu_tests
class TestTorchCompileCausalLM:
    """
    Tests for using torch.compile with causal LM.

    Tip: When adding a new test, set `fake_compile = True` below. With this setting, torch.compile is being skipped.
    This is useful for two reasons:

    - compile is slow, so to quickly iterate on the test, it's best to disable it and only enable it at the very end
    - even if you expect the test to fail with compile, as compile does not work with every PEFT feature, it still MUST
      succeed without compile, otherwise the test is incorrect.

    Before creating the PR, disable `fake_compile`.
    """

    # When True, `compile` returns the model unchanged and the Trainer is built with `torch_compile=False`.
    fake_compile = False
    model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
    max_train_loss = 15.0  # generous threshold for maximum loss after training

    @pytest.fixture(autouse=True)
    def teardown(self):
        r"""
        Efficient mechanism to free GPU memory after each test. Based on
        https://github.com/huggingface/transformers/issues/21094
        """
        # Let the test run first: without this `yield` the cleanup below would run *before* every test instead of
        # after it, so the memory held by the current test would not be released.
        yield
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    @pytest.fixture(scope="class")
    def tokenizer(self):
        """Tokenizer for `model_id`, shared by all tests of the class."""
        return AutoTokenizer.from_pretrained(self.model_id)

    @pytest.fixture(scope="class")
    def data(self, tokenizer):
        """Tokenized english quotes dataset, with all string columns removed."""

        def tokenize(samples):
            # For some reason, the max sequence length is not honored by the tokenizer, resulting in IndexErrors. Thus,
            # manually ensure that sequences are not too long.
            tokenized = tokenizer(samples["quote"])
            max_length = tokenizer.model_max_length
            tokenized["input_ids"] = [input_ids[:max_length] for input_ids in tokenized["input_ids"]]
            tokenized["attention_mask"] = [mask[:max_length] for mask in tokenized["attention_mask"]]
            return tokenized

        data = load_dataset("ybelkada/english_quotes_copy")
        data = data.map(tokenize, batched=True)
        # We need to manually remove unused columns. This is because we cannot use remove_unused_columns=True in the
        # Trainer, as this leads to errors with torch.compile. We also cannot just leave them in, as they contain
        # strings. Therefore, manually remove all unused columns.
        data = data.remove_columns(["quote", "author", "tags"])
        return data

    def compile(self, model, compile_kwargs):
        """Return `model` wrapped with `torch.compile`, or `model` itself when `fake_compile` is set.

        `compile_kwargs` is copied, so the caller's dict (e.g. an entry of SETTINGS) is not mutated. The
        Trainer-only keys `torch_compile_backend` and `torch_compile_mode` are dropped before the remaining kwargs
        are forwarded to `torch.compile`.
        """
        compile_kwargs = compile_kwargs.copy()
        # those are only for the Trainer arguments
        compile_kwargs.pop("torch_compile_backend", None)
        compile_kwargs.pop("torch_compile_mode", None)
        if self.fake_compile:
            return model
        return torch.compile(model, **compile_kwargs)

    @pytest.mark.parametrize("settings", SETTINGS.values(), ids=SETTINGS.keys())
    def test_causal_lm_training_trainer_compile(self, settings, tokenizer, data, tmp_path):
        r"""Train a PEFT model with torch.compile using Trainer"""
        tmp_dir = tmp_path / "model"
        config, compile_kwargs = settings
        if isinstance(config, AdaLoraConfig):
            pytest.skip(reason="AdaLora does not work correctly with Trainer")

        torch.manual_seed(0)
        model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            device_map="auto",
        )
        model = get_peft_model(model, config)

        # record outputs before training
        model.eval()
        sample = torch.tensor(data["train"][:1]["input_ids"]).to(model.device)
        with torch.inference_mode():
            output_before = model(sample)
        model.train()

        train_kwargs = {
            "per_device_train_batch_size": 4,
            "max_steps": 5,
            "learning_rate": 1e-3,
            "logging_steps": 1,
            "output_dir": tmp_dir,
            "seed": 0,
        }
        # compilation is delegated to the Trainer here, so the model itself is not passed through `self.compile`
        training_args = TrainingArguments(
            torch_compile=not self.fake_compile,
            torch_compile_backend=compile_kwargs.get("torch_compile_backend", None),
            torch_compile_mode=compile_kwargs.get("torch_compile_mode", None),
            **train_kwargs,
        )
        trainer = Trainer(
            model=model,
            train_dataset=data["train"],
            args=training_args,
            data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
        )
        model.config.use_cache = False
        trainer.train()

        model.eval()
        atol, rtol = 1e-4, 1e-4
        with torch.inference_mode():
            output_after = model(sample)
            tokens_after = model.generate(sample)
        assert torch.isfinite(output_after.logits).all()
        # sanity check: model was updated
        assert not torch.allclose(output_before.logits, output_after.logits, atol=atol, rtol=rtol)
        assert trainer.state.log_history[-1]["train_loss"] < self.max_train_loss

        # check saving the model and loading it without compile
        model.save_pretrained(tmp_path)
        del model
        torch.manual_seed(0)
        model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="auto")
        model = PeftModel.from_pretrained(model, tmp_path)
        with torch.inference_mode():
            output_loaded = model(sample)
            tokens_loaded = model.generate(sample)
        assert torch.allclose(output_after.logits, output_loaded.logits, atol=atol, rtol=rtol)
        assert (tokens_after == tokens_loaded).all()

    @pytest.mark.parametrize("settings", SETTINGS.values(), ids=SETTINGS.keys())
    def test_causal_lm_training_pytorch_compile(self, settings, tokenizer, data, tmp_path):
        r"""Train a PEFT model with torch.compile using PyTorch training loop"""
        torch.manual_seed(0)
        model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            device_map="auto",
        )
        config, compile_kwargs = settings
        model = get_peft_model(model, config)
        num_steps = 5
        if isinstance(config, AdaLoraConfig):
            model.base_model.peft_config["default"].total_step = num_steps
        model = self.compile(model, compile_kwargs)

        # record outputs before training
        model.eval()
        sample = torch.tensor(data["train"][:1]["input_ids"]).to(model.device)
        with torch.inference_mode():
            output_before = model(sample)
        model.train()

        model.config.use_cache = False
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
        batch_size = 4
        losses = []
        # number of *samples* consumed; the loop below performs `num_steps` optimizer steps
        max_steps = num_steps * batch_size
        for step, i in enumerate(range(0, max_steps, batch_size)):
            batch = tokenizer.pad(data["train"][i : i + batch_size], return_tensors="pt").to(model.device)
            # add targets
            batch["labels"] = batch["input_ids"].clone()
            optimizer.zero_grad()
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
            if isinstance(config, AdaLoraConfig):
                # AdaLoRA's budget schedule is expressed in optimizer steps (see `total_step` above), so pass the
                # step index, not the sample offset `i`
                model.base_model.update_and_allocate(step)

        model.eval()
        with torch.inference_mode():
            output_after = model(sample)
            tokens_after = model.generate(sample)
        assert torch.isfinite(output_after.logits).all()
        atol, rtol = 1e-4, 1e-4
        # sanity check: model was updated
        assert not torch.allclose(output_before.logits, output_after.logits, atol=atol, rtol=rtol)
        assert losses[-1] < self.max_train_loss

        # check saving the model and loading it without compile
        model.save_pretrained(tmp_path)
        del model
        torch.manual_seed(0)
        model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="auto")
        model = PeftModel.from_pretrained(model, tmp_path)
        with torch.inference_mode():
            output_loaded = model(sample)
            tokens_loaded = model.generate(sample)
        assert torch.allclose(output_after.logits, output_loaded.logits, atol=atol, rtol=rtol)
        assert (tokens_after == tokens_loaded).all()

    @pytest.mark.xfail(strict=True)
    def test_causal_lm_training_lora_bnb_compile(self, tokenizer, data, tmp_path):
        r"""Train a bnb quantized LoRA model with torch.compile using PyTorch training loop"""
        torch.manual_seed(0)
        model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            device_map="auto",
            quantization_config=BitsAndBytesConfig(load_in_4bit=True),
        )
        config = LoraConfig(task_type=TaskType.CAUSAL_LM)
        model = get_peft_model(model, config)
        model = self.compile(model, {})

        # record outputs before training
        model.eval()
        sample = torch.tensor(data["train"][:1]["input_ids"]).to(model.device)
        with torch.inference_mode():
            output_before = model(sample)
        model.train()

        model.config.use_cache = False
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
        batch_size = 4
        losses = []
        # number of *samples* consumed, i.e. 5 optimizer steps
        max_steps = 5 * batch_size
        for i in range(0, max_steps, batch_size):
            batch = tokenizer.pad(data["train"][i : i + batch_size], return_tensors="pt").to(model.device)
            # add targets
            batch["labels"] = batch["input_ids"].clone()
            optimizer.zero_grad()
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

        model.eval()
        with torch.inference_mode():
            output_after = model(sample)
        assert torch.isfinite(output_after.logits).all()
        atol, rtol = 1e-4, 1e-4
        # sanity check: model was updated
        assert not torch.allclose(output_before.logits, output_after.logits, atol=atol, rtol=rtol)
        assert losses[-1] < self.max_train_loss

        # check saving the model and loading it without compile
        model.save_pretrained(tmp_path)
        del model
        torch.manual_seed(0)
        model = AutoModelForCausalLM.from_pretrained(
            self.model_id, device_map="auto", quantization_config=BitsAndBytesConfig(load_in_4bit=True)
        )
        model = PeftModel.from_pretrained(model, tmp_path)
        with torch.inference_mode():
            # after loading, outputs are float32 for some reason
            output_loaded = model(sample)
        assert torch.allclose(output_after.logits, output_loaded.logits, atol=atol, rtol=rtol)

    @pytest.mark.xfail(strict=True)
    def test_causal_lm_multiple_lora_adapter_compile(self, tokenizer, data):
        """Two LoRA adapters on a compiled model give distinct outputs; deleting one keeps the other intact."""
        torch.manual_seed(0)
        model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            device_map="auto",
            quantization_config=BitsAndBytesConfig(load_in_4bit=True),
        ).eval()
        sample = torch.tensor(data["train"][:1]["input_ids"]).to(model.device)
        with torch.inference_mode():
            output_base = model(sample)

        config = LoraConfig(task_type=TaskType.CAUSAL_LM, init_lora_weights=False)
        model = get_peft_model(model, config).eval()
        model = self.compile(model, {})
        model.add_adapter("other", config)
        model = self.compile(model, {})
        with torch.inference_mode():
            output_default_adapter = model(sample)
        model.set_adapter("other")
        with torch.inference_mode():
            output_other_adapter = model(sample)

        atol, rtol = 1e-4, 1e-4
        # outputs of the base model != output of default adapter != output of other adapter
        assert not torch.allclose(output_base.logits, output_default_adapter.logits, atol=atol, rtol=rtol)
        assert not torch.allclose(output_base.logits, output_other_adapter.logits, atol=atol, rtol=rtol)
        assert not torch.allclose(output_default_adapter.logits, output_other_adapter.logits, atol=atol, rtol=rtol)

        # now delete the other adapter
        model.delete_adapter("other")
        model.set_adapter("default")
        with torch.inference_mode():
            output_after_delete = model(sample)

        # outputs after delete == output of default adapter
        assert torch.allclose(output_default_adapter.logits, output_after_delete.logits, atol=atol, rtol=rtol)

    @pytest.mark.xfail(strict=True)
    def test_causal_lm_disable_lora_adapter_compile(self, tokenizer, data):
        """Inside `disable_adapter()`, a compiled LoRA model should reproduce the base model outputs."""
        torch.manual_seed(0)
        model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            device_map="auto",
            quantization_config=BitsAndBytesConfig(load_in_4bit=True),
        ).eval()
        sample = torch.tensor(data["train"][:1]["input_ids"]).to(model.device)
        with torch.inference_mode():
            output_base = model(sample)

        config = LoraConfig(task_type=TaskType.CAUSAL_LM, init_lora_weights=False)
        model = get_peft_model(model, config).eval()
        model = self.compile(model, {})
        # NOTE(review): unlike the other forward passes, this one is not run under inference_mode
        output_lora = model(sample)
        with model.disable_adapter():
            with torch.inference_mode():
                output_disabled = model(sample)

        atol, rtol = 1e-4, 1e-4
        # outputs of the base model == output disabled adapter != output of lora adapter
        assert torch.allclose(output_base.logits, output_disabled.logits, atol=atol, rtol=rtol)
        assert not torch.allclose(output_base.logits, output_lora.logits, atol=atol, rtol=rtol)

    def test_causal_lm_merging_lora_adapter_compile(self, tokenizer, data):
        """A compiled model with a merged LoRA adapter matches the unmerged adapter outputs."""
        # merge the adapter
        torch.manual_seed(0)
        model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            device_map="auto",
            quantization_config=BitsAndBytesConfig(load_in_4bit=True),
        ).eval()
        sample = torch.tensor(data["train"][:1]["input_ids"]).to(model.device)
        with torch.inference_mode():
            output_base = model(sample)

        config = LoraConfig(task_type=TaskType.CAUSAL_LM, init_lora_weights=False)
        model = get_peft_model(model, config).eval()
        with torch.inference_mode():
            output_lora = model(sample)

        model.merge_adapter()
        # without this call the test would never exercise torch.compile
        model = self.compile(model, {})
        with torch.inference_mode():
            output_merged = model(sample)

        # merging is less precise, be more tolerant
        atol, rtol = 1e-1, 1e-1
        # outputs of the base model != output of lora adapter == output of merged adapter
        assert not torch.allclose(output_base.logits, output_lora.logits, atol=atol, rtol=rtol)
        assert torch.allclose(output_lora.logits, output_merged.logits, atol=atol, rtol=rtol)

    def test_causal_lm_merging_multiple_lora_adapters_compile(self, tokenizer, data):
        """A compiled model with two LoRA adapters merged at once differs from the base and from each single adapter."""
        # merge multiple adapters at once
        torch.manual_seed(0)
        model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            device_map="auto",
            quantization_config=BitsAndBytesConfig(load_in_4bit=True),
        ).eval()
        sample = torch.tensor(data["train"][:1]["input_ids"]).to(model.device)
        with torch.inference_mode():
            output_base = model(sample)

        config = LoraConfig(task_type=TaskType.CAUSAL_LM, init_lora_weights=False)
        model = get_peft_model(model, config).eval()
        model.add_adapter("other", config)
        with torch.inference_mode():
            output_default = model(sample)
        model.set_adapter("other")
        with torch.inference_mode():
            output_other = model(sample)

        model.base_model.merge_adapter(["default", "other"])
        # without this call the test would never exercise torch.compile
        model = self.compile(model, {})
        with torch.inference_mode():
            output_merged = model(sample)

        # merging is less precise, be more tolerant
        atol, rtol = 1e-1, 1e-1
        # outputs of the base model != output of default adapter != output of other adapter
        assert not torch.allclose(output_base.logits, output_default.logits, atol=atol, rtol=rtol)
        assert not torch.allclose(output_base.logits, output_other.logits, atol=atol, rtol=rtol)
        assert not torch.allclose(output_default.logits, output_other.logits, atol=atol, rtol=rtol)
        # outputs of merged adapter != all others
        assert not torch.allclose(output_base.logits, output_merged.logits, atol=atol, rtol=rtol)
        assert not torch.allclose(output_default.logits, output_merged.logits, atol=atol, rtol=rtol)
        assert not torch.allclose(output_other.logits, output_merged.logits, atol=atol, rtol=rtol)

    @pytest.mark.xfail(strict=True)
    def test_causal_lm_merge_and_unload_lora_adapter_compile(self, tokenizer, data):
        """`merge_and_unload` on a compiled LoRA model should reproduce the adapter outputs."""
        torch.manual_seed(0)
        model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            device_map="auto",
            quantization_config=BitsAndBytesConfig(load_in_4bit=True),
        ).eval()
        sample = torch.tensor(data["train"][:1]["input_ids"]).to(model.device)
        with torch.inference_mode():
            output_base = model(sample)

        config = LoraConfig(task_type=TaskType.CAUSAL_LM, init_lora_weights=False)
        model = get_peft_model(model, config).eval()
        model = self.compile(model, {})
        with torch.inference_mode():
            output_lora = model(sample)

        unloaded = model.merge_and_unload()
        with torch.inference_mode():
            output_unloaded = unloaded(sample)

        # merging is less precise, be more tolerant
        atol, rtol = 1e-1, 1e-1
        # outputs of the base model != output of lora adapter == output of unloaded adapter
        assert not torch.allclose(output_base.logits, output_lora.logits, atol=atol, rtol=rtol)
        assert torch.allclose(output_lora.logits, output_unloaded.logits, atol=atol, rtol=rtol)

    @pytest.mark.xfail(strict=True)
    def test_causal_lm_mixed_batch_lora_adapter_compile(self, tokenizer, data):
        """Per-sample `adapter_names` on a compiled model routes each row to base / default / other."""
        torch.manual_seed(0)
        model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            device_map="auto",
            quantization_config=BitsAndBytesConfig(load_in_4bit=True),
        ).eval()
        # we need at least 3 samples for this to work!
        # use the model's device (like every other test) instead of hard-coding "cuda"
        sample = {
            "input_ids": torch.arange(12).reshape(3, 4).to(model.device),
            "attention_mask": torch.ones(3, 4).long().to(model.device),
        }
        with torch.inference_mode():
            output_base = model(**sample)

        config = LoraConfig(task_type=TaskType.CAUSAL_LM, init_lora_weights=False)
        model = get_peft_model(model, config).eval()
        with torch.inference_mode():
            output_default = model(**sample)

        model.add_adapter("other", config)
        model.set_adapter("other")
        with torch.inference_mode():
            output_other = model(**sample)

        model = self.compile(model, {})

        # set adapter_indices so that it alternates between 0 (base), lora 1, and lora 2
        adapter_names = ["__base__", "default", "other"]
        with torch.inference_mode():
            output_mixed = model(**sample, adapter_names=adapter_names)

        atol, rtol = 1e-4, 1e-4
        # outputs of the base model != output of lora adapter 1 != output of other adapter
        assert not torch.allclose(output_base.logits, output_default.logits, atol=atol, rtol=rtol)
        assert not torch.allclose(output_default.logits, output_other.logits, atol=atol, rtol=rtol)
        assert not torch.allclose(output_other.logits, output_mixed.logits, atol=atol, rtol=rtol)
        # outputs of mixed adapter is mix of all 3
        assert torch.allclose(output_base.logits[0], output_mixed.logits[0], atol=atol, rtol=rtol)
        assert torch.allclose(output_default.logits[1], output_mixed.logits[1], atol=atol, rtol=rtol)
        assert torch.allclose(output_other.logits[2], output_mixed.logits[2], atol=atol, rtol=rtol)

    def test_causal_lm_add_weighted_adapter_lora_adapter_compile(self, tokenizer, data):
        """A compiled model using an `add_weighted_adapter` combination differs from base and both source adapters."""
        torch.manual_seed(0)
        model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            device_map="auto",
            quantization_config=BitsAndBytesConfig(load_in_4bit=True),
        ).eval()
        sample = torch.tensor(data["train"][:1]["input_ids"]).to(model.device)
        with torch.inference_mode():
            output_base = model(sample)

        config = LoraConfig(task_type=TaskType.CAUSAL_LM, init_lora_weights=False)
        model = get_peft_model(model, config).eval()
        model.add_adapter("other", config)
        with torch.inference_mode():
            output_default = model(sample)
        model.set_adapter("other")
        with torch.inference_mode():
            output_other = model(sample)

        model.add_weighted_adapter(["default", "other"], [0.5, 0.5], adapter_name="combined")
        model.set_adapter("combined")
        # without this call the test would never exercise torch.compile
        model = self.compile(model, {})
        with torch.inference_mode():
            output_combined = model(sample)

        atol, rtol = 1e-4, 1e-4
        # outputs of the base model != output of default adapter != output of other adapter
        assert not torch.allclose(output_base.logits, output_default.logits, atol=atol, rtol=rtol)
        assert not torch.allclose(output_base.logits, output_other.logits, atol=atol, rtol=rtol)
        assert not torch.allclose(output_default.logits, output_other.logits, atol=atol, rtol=rtol)
        # outputs of combined adapter != all others
        assert not torch.allclose(output_base.logits, output_combined.logits, atol=atol, rtol=rtol)
        assert not torch.allclose(output_default.logits, output_combined.logits, atol=atol, rtol=rtol)
        assert not torch.allclose(output_other.logits, output_combined.logits, atol=atol, rtol=rtol)
|