akshatOP committed
Commit 8cd6f6e · 1 Parent(s): a2cfe7a

Update all files: Fix Parler-TTS imports, PyTorch version, and model loading
download_and_finetune_sst.py CHANGED
@@ -0,0 +1,49 @@
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Trainer, TrainingArguments
+ from datasets import load_dataset
+
+ # Download model
+ model_name = "facebook/wav2vec2-base-960h"
+ model = Wav2Vec2ForCTC.from_pretrained(model_name)
+ processor = Wav2Vec2Processor.from_pretrained(model_name)
+
+ # Load dataset (replace with your dataset)
+ dataset = load_dataset("librispeech_asr", "clean", split="train.100")  # Example dataset
+
+ # Preprocess function: extract input values from the raw audio and tokenize the transcript
+ def preprocess_function(example):
+     audio = example["audio"]
+     inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"])
+     # Tokenize the transcript as labels (replaces the deprecated as_target_processor context manager)
+     labels = processor(text=example["text"])
+     return {
+         "input_values": inputs["input_values"][0],
+         "labels": labels["input_ids"],
+     }
+
+ train_dataset = dataset.map(preprocess_function, remove_columns=dataset.column_names)
+
+ # Training arguments
+ training_args = TrainingArguments(
+     output_dir="./sst_finetuned",
+     per_device_train_batch_size=8,
+     num_train_epochs=3,
+     save_steps=500,
+     logging_steps=10,
+ )
+
+ # Initialize Trainer
+ # NOTE: audio clips are variable-length, so a padding data collator is required; see the sketch after this file
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     train_dataset=train_dataset,
+ )
+
+ # Fine-tune
+ trainer.train()
+
+ # Save fine-tuned model
+ trainer.save_model("./sst_finetuned")
+ processor.save_pretrained("./sst_finetuned")
+
+ print("SST model fine-tuned and saved to './sst_finetuned'. Upload to models/sst_model in your Space.")
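
The Trainer above is constructed without a data collator, and the default collator cannot pad variable-length audio, so batching would fail at trainer.train(). Below is a minimal padding collator in the style of the standard wav2vec2 CTC fine-tuning recipe; the class name DataCollatorCTCWithPadding and the -100 label masking follow that recipe, and an instance would be passed to the Trainer via its data_collator argument. This is a sketch, not part of the committed file.

import torch
from dataclasses import dataclass
from typing import Dict, List, Union

from transformers import Wav2Vec2Processor


@dataclass
class DataCollatorCTCWithPadding:
    """Dynamically pads input_values and label ids to the longest item in the batch."""
    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[float], List[int]]]]) -> Dict[str, torch.Tensor]:
        # Audio features and label token ids have different lengths, so pad them separately
        input_features = [{"input_values": f["input_values"]} for f in features]
        label_features = [{"input_ids": f["labels"]} for f in features]

        batch = self.processor.pad(input_features=input_features, padding=self.padding, return_tensors="pt")
        labels_batch = self.processor.pad(labels=label_features, padding=self.padding, return_tensors="pt")

        # Replace tokenizer padding with -100 so padded positions are ignored by the CTC loss
        batch["labels"] = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        return batch


# Usage: trainer = Trainer(..., data_collator=DataCollatorCTCWithPadding(processor=processor))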
download_and_finetune_tts.py CHANGED
@@ -0,0 +1,46 @@
+ # Parler-TTS ships its model class in the standalone parler_tts package, not in
+ # transformers itself: pip install git+https://github.com/huggingface/parler-tts.git
+ from parler_tts import ParlerTTSForConditionalGeneration
+ from transformers import AutoTokenizer, Trainer, TrainingArguments
+ from datasets import load_dataset
+
+ # Download model
+ model_name = "parler-tts/parler-tts-mini-v1"
+ model = ParlerTTSForConditionalGeneration.from_pretrained(model_name)
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ # Load dataset (replace with your dataset)
+ dataset = load_dataset("lj_speech")  # Example dataset; adjust as needed
+
+ # Preprocess function (customize based on your dataset)
+ def preprocess_function(examples):
+     # Tokenize text only; target audio must still be encoded for real fine-tuning (see note below)
+     inputs = tokenizer(examples["text"], return_tensors="pt", padding=True, truncation=True)
+     return {"input_ids": inputs["input_ids"], "attention_mask": inputs["attention_mask"]}
+
+ train_dataset = dataset["train"].map(preprocess_function, batched=True)
+
+ # Training arguments
+ training_args = TrainingArguments(
+     output_dir="./tts_finetuned",
+     per_device_train_batch_size=8,
+     num_train_epochs=3,
+     save_steps=500,
+     logging_steps=10,
+ )
+
+ # Initialize Trainer
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     train_dataset=train_dataset,
+ )
+
+ # Fine-tune
+ trainer.train()
+
+ # Save fine-tuned model
+ trainer.save_model("./tts_finetuned")
+ tokenizer.save_pretrained("./tts_finetuned")
+
+ print("TTS model fine-tuned and saved to './tts_finetuned'. Upload to models/tts_model in your Space.")
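
Note on this script: the preprocessing tokenizes text but never encodes the target audio, so the Trainer has no speech labels to learn from; the upstream parler-tts repository documents the full training recipe, including codec-encoding the audio. As a quick smoke test that a saved checkpoint loads and generates audio, the sketch below follows the generation API from the Parler-TTS README; the description and prompt strings are placeholders.

import soundfile as sf
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer

model = ParlerTTSForConditionalGeneration.from_pretrained("./tts_finetuned")
tokenizer = AutoTokenizer.from_pretrained("./tts_finetuned")

# The description conditions the speaker style; the prompt is the text to speak
description = "A clear female voice with a neutral tone."  # placeholder
prompt = "Testing the fine-tuned model."  # placeholder

input_ids = tokenizer(description, return_tensors="pt").input_ids
prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# generate() returns raw audio samples at model.config.sampling_rate
audio = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
sf.write("tts_check.wav", audio.cpu().numpy().squeeze(), model.config.sampling_rate)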
requirements.txt CHANGED
@@ -1,9 +1,8 @@
  fastapi==0.103.2
  uvicorn==0.23.2
- git+https://github.com/huggingface/transformers.git@main#egg=transformers
+ transformers==4.41.0
  torch==2.1.2
  soundfile==0.12.1
  numpy==1.26.4
- llama-cpp-python==0.2.28
  pydantic==2.5.3
  datasets==2.16.1