Xenova HF Staff commited on
Commit
b3288a9
·
verified ·
1 Parent(s): 19778b2

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +87 -1
README.md CHANGED
@@ -7,7 +7,93 @@ tags: []
7
 
8
  <!-- Provide a quick summary of what the model is/does. -->
9
 
10
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  ## Model Details
13
 
 
7
 
8
  <!-- Provide a quick summary of what the model is/does. -->
9
 
10
+ ## Code to create model
11
+ ```py
12
+ import torch
13
+ from transformers import MimiConfig, MimiModel, AutoProcessor
14
+
15
+ model_id = 'kyutai/mimi'
16
+ config = MimiConfig.from_pretrained(
17
+ model_id,
18
+ intermediate_size=64,
19
+ hidden_size=16,
20
+ num_hidden_layers=2,
21
+ num_key_value_heads=2,
22
+ upsample_groups=16,
23
+ num_filters=8,
24
+ codebook_dim=8,
25
+ vector_quantization_hidden_dimension=8,
26
+ codebook_size=32,
27
+ )
28
+
29
+ # Create model and randomize all weights
30
+ model = MimiModel(config)
31
+
32
+ torch.manual_seed(0) # Set for reproducibility
33
+ for name, param in model.named_parameters():
34
+ param.data = torch.randn_like(param)
35
+
36
+ processor = AutoProcessor.from_pretrained(model_id)
37
+ ```
38
+
39
+ ## ONNX conversion code
40
+ ```py
41
+ import torch
42
+ import torch.nn as nn
43
+ from transformers import MimiModel
44
+
45
+ class MimiEncoder(nn.Module):
46
+ def __init__(self, model):
47
+ super(MimiEncoder, self).__init__()
48
+ self.model = model
49
+
50
+ def forward(self, input_values, padding_mask=None):
51
+ return self.model.encode(input_values, padding_mask=padding_mask).audio_codes
52
+
53
+ class MimiDecoder(nn.Module):
54
+ def __init__(self, model):
55
+ super(MimiDecoder, self).__init__()
56
+ self.model = model
57
+
58
+ def forward(self, audio_codes, padding_mask=None):
59
+ return self.model.decode(audio_codes, padding_mask=padding_mask).audio_values
60
+
61
+ model = MimiModel.from_pretrained("hf-internal-testing/tiny-random-MimiModel")
62
+ encoder = MimiEncoder(model)
63
+ decoder = MimiDecoder(model)
64
+
65
+ dummy_encoder_inputs = torch.randn((5, 1, 82500))
66
+ torch.onnx.export(
67
+ encoder,
68
+ dummy_encoder_inputs,
69
+ "encoder_model.onnx",
70
+ export_params=True,
71
+ opset_version=14,
72
+ do_constant_folding=True,
73
+ input_names=['input_values'],
74
+ output_names=['audio_codes'],
75
+ dynamic_axes={
76
+ 'input_values': {0: 'batch_size', 1: 'num_channels', 2: 'sequence_length'},
77
+ 'audio_codes': {0: 'batch_size', 2: 'codes_length'},
78
+ },
79
+ )
80
+
81
+ dummy_decoder_inputs = torch.randint(8, (4, 32, 91))
82
+ torch.onnx.export(
83
+ decoder,
84
+ dummy_decoder_inputs,
85
+ "decoder_model.onnx",
86
+ export_params=True,
87
+ opset_version=14,
88
+ do_constant_folding=True,
89
+ input_names=['audio_codes'],
90
+ output_names=['audio_values'],
91
+ dynamic_axes={
92
+ 'audio_codes': {0: 'batch_size', 2: 'codes_length'},
93
+ 'audio_values': {0: 'batch_size', 1: 'num_channels', 2: 'sequence_length'},
94
+ },
95
+ )
96
+ ```
97
 
98
  ## Model Details
99