drewThomasson commited on
Commit
31c247f
·
verified ·
1 Parent(s): ae03f2c

Update watermarking.py

Browse files
Files changed (1) hide show
  1. watermarking.py +8 -17
watermarking.py CHANGED
@@ -1,21 +1,22 @@
1
  import os
2
  import argparse
3
-
4
  import silentcipher
5
  import torch
6
  import torchaudio
7
 
8
- CSM_1B_HF_WATERMARK = list(map(int, os.getenv("WATERMARK_KEY").split(" ")))
9
-
 
 
 
 
10
 
11
  def cli_check_audio() -> None:
12
  parser = argparse.ArgumentParser()
13
  parser.add_argument("--audio_path", type=str, required=True)
14
  args = parser.parse_args()
15
-
16
  check_audio_from_file(args.audio_path)
17
 
18
-
19
  def load_watermarker(device: str = "cuda") -> silentcipher.server.Model:
20
  model = silentcipher.get_model(
21
  model_type="44.1k",
@@ -23,7 +24,6 @@ def load_watermarker(device: str = "cuda") -> silentcipher.server.Model:
23
  )
24
  return model
25
 
26
-
27
  @torch.inference_mode()
28
  def watermark(
29
  watermarker: silentcipher.server.Model,
@@ -32,13 +32,11 @@ def watermark(
32
  watermark_key: list[int],
33
  ) -> tuple[torch.Tensor, int]:
34
  audio_array_44khz = torchaudio.functional.resample(audio_array, orig_freq=sample_rate, new_freq=44100)
35
- encoded, _ = watermarker.encode_wav(audio_array_44khz, 44100, watermark_key, calc_sdr=False, message_sdr=36)
36
-
37
  output_sample_rate = min(44100, sample_rate)
38
  encoded = torchaudio.functional.resample(encoded, orig_freq=44100, new_freq=output_sample_rate)
39
  return encoded, output_sample_rate
40
 
41
-
42
  @torch.inference_mode()
43
  def verify(
44
  watermarker: silentcipher.server.Model,
@@ -48,31 +46,24 @@ def verify(
48
  ) -> bool:
49
  watermarked_audio_44khz = torchaudio.functional.resample(watermarked_audio, orig_freq=sample_rate, new_freq=44100)
50
  result = watermarker.decode_wav(watermarked_audio_44khz, 44100, phase_shift_decoding=True)
51
-
52
  is_watermarked = result["status"]
53
  if is_watermarked:
54
  is_csm_watermarked = result["messages"][0] == watermark_key
55
  else:
56
  is_csm_watermarked = False
57
-
58
  return is_watermarked and is_csm_watermarked
59
 
60
-
61
  def check_audio_from_file(audio_path: str) -> None:
62
  watermarker = load_watermarker(device="cuda")
63
-
64
  audio_array, sample_rate = load_audio(audio_path)
65
  is_watermarked = verify(watermarker, audio_array, sample_rate, CSM_1B_HF_WATERMARK)
66
-
67
  outcome = "Watermarked" if is_watermarked else "Not watermarked"
68
  print(f"{outcome}: {audio_path}")
69
 
70
-
71
  def load_audio(audio_path: str) -> tuple[torch.Tensor, int]:
72
  audio_array, sample_rate = torchaudio.load(audio_path)
73
  audio_array = audio_array.mean(dim=0)
74
  return audio_array, int(sample_rate)
75
 
76
-
77
  if __name__ == "__main__":
78
- cli_check_audio()
 
1
  import os
2
  import argparse
 
3
  import silentcipher
4
  import torch
5
  import torchaudio
6
 
7
+ # Set a default watermark key if environment variable is not set
8
+ watermark_key_str = os.getenv("WATERMARK_KEY")
9
+ if watermark_key_str is None:
10
+ CSM_1B_HF_WATERMARK = [0, 0, 0, 0] # Default placeholder
11
+ else:
12
+ CSM_1B_HF_WATERMARK = list(map(int, watermark_key_str.split(" ")))
13
 
14
  def cli_check_audio() -> None:
15
  parser = argparse.ArgumentParser()
16
  parser.add_argument("--audio_path", type=str, required=True)
17
  args = parser.parse_args()
 
18
  check_audio_from_file(args.audio_path)
19
 
 
20
  def load_watermarker(device: str = "cuda") -> silentcipher.server.Model:
21
  model = silentcipher.get_model(
22
  model_type="44.1k",
 
24
  )
25
  return model
26
 
 
27
  @torch.inference_mode()
28
  def watermark(
29
  watermarker: silentcipher.server.Model,
 
32
  watermark_key: list[int],
33
  ) -> tuple[torch.Tensor, int]:
34
  audio_array_44khz = torchaudio.functional.resample(audio_array, orig_freq=sample_rate, new_freq=44100)
35
+ encoded, * = watermarker.encode_wav(audio_array_44khz, 44100, watermark_key, calc_sdr=False, message_sdr=36)
 
36
  output_sample_rate = min(44100, sample_rate)
37
  encoded = torchaudio.functional.resample(encoded, orig_freq=44100, new_freq=output_sample_rate)
38
  return encoded, output_sample_rate
39
 
 
40
  @torch.inference_mode()
41
  def verify(
42
  watermarker: silentcipher.server.Model,
 
46
  ) -> bool:
47
  watermarked_audio_44khz = torchaudio.functional.resample(watermarked_audio, orig_freq=sample_rate, new_freq=44100)
48
  result = watermarker.decode_wav(watermarked_audio_44khz, 44100, phase_shift_decoding=True)
 
49
  is_watermarked = result["status"]
50
  if is_watermarked:
51
  is_csm_watermarked = result["messages"][0] == watermark_key
52
  else:
53
  is_csm_watermarked = False
 
54
  return is_watermarked and is_csm_watermarked
55
 
 
56
  def check_audio_from_file(audio_path: str) -> None:
57
  watermarker = load_watermarker(device="cuda")
 
58
  audio_array, sample_rate = load_audio(audio_path)
59
  is_watermarked = verify(watermarker, audio_array, sample_rate, CSM_1B_HF_WATERMARK)
 
60
  outcome = "Watermarked" if is_watermarked else "Not watermarked"
61
  print(f"{outcome}: {audio_path}")
62
 
 
63
  def load_audio(audio_path: str) -> tuple[torch.Tensor, int]:
64
  audio_array, sample_rate = torchaudio.load(audio_path)
65
  audio_array = audio_array.mean(dim=0)
66
  return audio_array, int(sample_rate)
67
 
 
68
  if __name__ == "__main__":
69
+ cli_check_audio()