dangtr0408 committed on
Commit
6985472
·
1 Parent(s): 0d15013

minor changes

Browse files
Files changed (2) hide show
  1. Models/del_training.ipynb +62 -62
  2. inference.py +4 -4
Models/del_training.ipynb CHANGED
@@ -1,62 +1,62 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 1,
6
- "id": "2b6bb4be",
7
- "metadata": {},
8
- "outputs": [],
9
- "source": [
10
- "import os\n",
11
- "import torch"
12
- ]
13
- },
14
- {
15
- "cell_type": "code",
16
- "execution_count": null,
17
- "id": "dc802b47",
18
- "metadata": {},
19
- "outputs": [],
20
- "source": [
21
- "models_path = \"./current_model_120k_vi.pth\"\n",
22
- "name = \"./model.pth\"\n",
23
- "params_whole = torch.load(models_path, map_location='cpu')\n",
24
- "\n",
25
- "for key in list(params_whole.keys()):\n",
26
- " if key != 'net':\n",
27
- " params_whole.pop(key)\n",
28
- "\n",
29
- "keep = ['decoder', 'predictor', 'text_encoder', 'style_encoder']\n",
30
- "for module_name in list(params_whole['net'].keys()):\n",
31
- " if module_name not in keep:\n",
32
- " params_whole['net'].pop(module_name)\n",
33
- "\n",
34
- "torch.save(params_whole, name)\n",
35
- "\n",
36
- "\n",
37
- "os.remove(models_path)"
38
- ]
39
- }
40
- ],
41
- "metadata": {
42
- "kernelspec": {
43
- "display_name": "base",
44
- "language": "python",
45
- "name": "python3"
46
- },
47
- "language_info": {
48
- "codemirror_mode": {
49
- "name": "ipython",
50
- "version": 3
51
- },
52
- "file_extension": ".py",
53
- "mimetype": "text/x-python",
54
- "name": "python",
55
- "nbconvert_exporter": "python",
56
- "pygments_lexer": "ipython3",
57
- "version": "3.11.7"
58
- }
59
- },
60
- "nbformat": 4,
61
- "nbformat_minor": 5
62
- }
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "2b6bb4be",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import os\n",
11
+ "import torch"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": null,
17
+ "id": "dc802b47",
18
+ "metadata": {},
19
+ "outputs": [],
20
+ "source": [
21
+ "models_path = \"./current_model_120k_vi.pth\"\n",
22
+ "name = \"./model.pth\"\n",
23
+ "params_whole = torch.load(models_path, map_location='cpu')\n",
24
+ "\n",
25
+ "for key in list(params_whole.keys()):\n",
26
+ " if key != 'net':\n",
27
+ " params_whole.pop(key)\n",
28
+ "\n",
29
+ "keep = ['decoder', 'predictor', 'text_encoder', 'style_encoder']\n",
30
+ "for module_name in list(params_whole['net'].keys()):\n",
31
+ " if module_name not in keep:\n",
32
+ " params_whole['net'].pop(module_name)\n",
33
+ "\n",
34
+ "torch.save(params_whole, name)\n",
35
+ "\n",
36
+ "\n",
37
+ "#os.remove(models_path)"
38
+ ]
39
+ }
40
+ ],
41
+ "metadata": {
42
+ "kernelspec": {
43
+ "display_name": "base",
44
+ "language": "python",
45
+ "name": "python3"
46
+ },
47
+ "language_info": {
48
+ "codemirror_mode": {
49
+ "name": "ipython",
50
+ "version": 3
51
+ },
52
+ "file_extension": ".py",
53
+ "mimetype": "text/x-python",
54
+ "name": "python",
55
+ "nbconvert_exporter": "python",
56
+ "pygments_lexer": "ipython3",
57
+ "version": "3.11.7"
58
+ }
59
+ },
60
+ "nbformat": 4,
61
+ "nbformat_minor": 5
62
+ }
inference.py CHANGED
@@ -64,7 +64,7 @@ class TextCleaner:
64
 
65
  class Preprocess:
66
  def __text_normalize(self, text):
67
- punctuation = [",", "、", "،", ";", "(", ".", "。", "…", "!", "–", ":"]
68
  map_to = "."
69
  punctuation_pattern = re.compile(f"[{''.join(re.escape(p) for p in punctuation)}]")
70
  #ensure consistency.
@@ -72,8 +72,8 @@ class Preprocess:
72
  #replace punctuation that acts like a comma or period
73
  #text = re.sub(r'\.{2,}', '.', text)
74
  text = punctuation_pattern.sub(map_to, text)
75
- #remove or replace special chars except . , { } ? ' - \ % $ & /
76
- text = re.sub(r'[^\w\s.,{}?\'\-\[\]\%\$\&\/]', ' ', text)
77
  #replace consecutive whitespace chars with a single space and strip leading/trailing spaces
78
  text = re.sub(r'\s+', ' ', text).strip()
79
  return text
@@ -211,7 +211,7 @@ class StyleTTS2(torch.nn.Module):
211
  audio = audio*(1-denoise) + audio_denoise*denoise
212
 
213
  with torch.no_grad():
214
- if split_dur>0 and len(audio)/sr>split_dur:
215
  #This option will split the ref audio to multiple parts, calculate styles and average them
216
  count = 0
217
  ref_s = None
 
64
 
65
  class Preprocess:
66
  def __text_normalize(self, text):
67
+ punctuation = [",", "、", "،", ";", "(", ".", "。", "…", "!", "–", ":", "?"]
68
  map_to = "."
69
  punctuation_pattern = re.compile(f"[{''.join(re.escape(p) for p in punctuation)}]")
70
  #ensure consistency.
 
72
  #replace punctuation that acts like a comma or period
73
  #text = re.sub(r'\.{2,}', '.', text)
74
  text = punctuation_pattern.sub(map_to, text)
75
+ #replace special chars with a space, except . , { } % $ & ' - [ ] /
76
+ text = re.sub(r'[^\w\s.,{}%$&\'\-\[\]\/]', ' ', text)
77
  #replace consecutive whitespace chars with a single space and strip leading/trailing spaces
78
  text = re.sub(r'\s+', ' ', text).strip()
79
  return text
 
211
  audio = audio*(1-denoise) + audio_denoise*denoise
212
 
213
  with torch.no_grad():
214
+ if split_dur>0 and len(audio)/sr>=4: #Only effective if audio length is >= 4s
215
  #This option will split the ref audio to multiple parts, calculate styles and average them
216
  count = 0
217
  ref_s = None