lzyhha committed
Commit d85f714 · 1 Parent(s): ee318f7

imwatermark
Files changed (2):
  1. flux/util.py +0 -45
  2. requirements.txt +0 -1
flux/util.py CHANGED
@@ -4,7 +4,6 @@ from dataclasses import dataclass
 import torch
 from einops import rearrange
 from huggingface_hub import hf_hub_download
-from imwatermark import WatermarkEncoder
 from PIL import ExifTags, Image
 from safetensors.torch import load_file as load_sft
 
@@ -470,47 +469,3 @@ def optionally_expand_state_dict(model: torch.nn.Module, state_dict: dict) -> dict:
         state_dict[name] = expanded_state_dict_weight
 
     return state_dict
-
-
-class WatermarkEmbedder:
-    def __init__(self, watermark):
-        self.watermark = watermark
-        self.num_bits = len(WATERMARK_BITS)
-        self.encoder = WatermarkEncoder()
-        self.encoder.set_watermark("bits", self.watermark)
-
-    def __call__(self, image: torch.Tensor) -> torch.Tensor:
-        """
-        Adds a predefined watermark to the input image
-
-        Args:
-            image: ([N,] B, RGB, H, W) in range [-1, 1]
-
-        Returns:
-            same as input but watermarked
-        """
-        image = 0.5 * image + 0.5
-        squeeze = len(image.shape) == 4
-        if squeeze:
-            image = image[None, ...]
-        n = image.shape[0]
-        image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
-        # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
-        # watermarking library expects input as cv2 BGR format
-        for k in range(image_np.shape[0]):
-            image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
-        image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to(
-            image.device
-        )
-        image = torch.clamp(image / 255, min=0.0, max=1.0)
-        if squeeze:
-            image = image[0]
-        image = 2 * image - 1
-        return image
-
-
-# A fixed 48-bit message that was chosen at random
-WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
-# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
-WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
-embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
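Downstream code that relied on the removed embedder can still apply the same 48-bit DWT-DCT watermark outside flux/util.py. Below is a minimal sketch, assuming the package that provides the imwatermark module is installed separately; watermark_rgb_uint8 is a hypothetical helper name, not part of this repo. The encoder calls (set_watermark("bits", ...), encode(..., "dwtDct")) and the RGB/BGR handling mirror the removed class.

import numpy as np
from imwatermark import WatermarkEncoder  # no longer pulled in by requirements.txt

# The fixed 48-bit message the removed code embedded (chosen at random upstream)
WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]

def watermark_rgb_uint8(image: np.ndarray) -> np.ndarray:
    """Embed the watermark into a single (H, W, 3) RGB uint8 image."""
    encoder = WatermarkEncoder()
    encoder.set_watermark("bits", WATERMARK_BITS)
    bgr = np.ascontiguousarray(image[:, :, ::-1])  # the encoder expects cv2-style BGR
    bgr = encoder.encode(bgr, "dwtDct")            # DWT-DCT embedding, as in the removed class
    return bgr[:, :, ::-1]                         # back to RGB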
 
requirements.txt CHANGED
@@ -13,5 +13,4 @@ numba
 scipy
 tqdm
 einops
-imwatermark
 https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.2.post1/flash_attn-2.7.2.post1+cu11torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
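With the entry dropped from requirements.txt, a fresh install no longer provides the imwatermark module, so any downstream script that still imports it will fail at import time. One way to degrade gracefully (a sketch, not code from this repo):

try:
    from imwatermark import WatermarkEncoder  # optional dependency after this commit
except ImportError:
    WatermarkEncoder = None  # watermarking unavailable; skip embedding and pass images through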
 