File size: 573 Bytes
5fa1a76 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
from typing import List

from transformers import PretrainedConfig


class ResnetConfig(PretrainedConfig):
    """Configuration container for a ResNet-style model.

    The constructor validates ``block_type``, stores every argument as an
    attribute of the same name, and forwards any remaining keyword
    arguments to ``PretrainedConfig.__init__``.
    """

    model_type = "resnet"

    def __init__(
        self,
        block_type="bottleneck",
        layers: List[int] = [3, 4, 6, 3],
        num_classes: int = 1000,
        input_channels: int = 3,
        cardinality: int = 1,
        base_width: int = 64,
        stem_width: int = 64,
        stem_type: str = "",
        avg_down: bool = False,
        **kwargs,
    ):
        """Validate and store the configuration values.

        Args:
            block_type: Either ``"basic"`` or ``"bottleneck"``.
            layers: Sequence of integers; copied into a fresh list.
            num_classes: Stored as-is.
            input_channels: Stored as-is.
            cardinality: Stored as-is.
            base_width: Stored as-is.
            stem_width: Stored as-is.
            stem_type: Stored as-is.
                NOTE(review): no validation is done here because the set of
                accepted values is not visible in this file -- confirm
                against the model code.
            avg_down: Stored as-is.
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

        Raises:
            ValueError: If ``block_type`` is not one of the accepted values.
        """
        if block_type not in ("basic", "bottleneck"):
            raise ValueError(
                f"`block_type` must be 'basic' or 'bottleneck', got {block_type}."
            )

        self.block_type = block_type
        # Copy so instances never share (and cannot mutate) the default list
        # object declared in the signature.
        self.layers = list(layers)
        self.num_classes = num_classes
        self.input_channels = input_channels
        self.cardinality = cardinality
        self.base_width = base_width
        self.stem_width = stem_width
        self.stem_type = stem_type
        self.avg_down = avg_down

        # Hand the remaining keyword arguments to the base class so it can
        # complete its own initialisation.
        super().__init__(**kwargs)