ramimu commited on
Commit
aff7e63
·
verified ·
1 Parent(s): 2ec882e

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +14 -2
train.py CHANGED
@@ -11,8 +11,20 @@ from transformers import CLIPTextModel, CLIPTokenizer
11
  from peft import LoraConfig, get_peft_model
12
 
13
  MODEL_ID = "black-forest-labs/FLUX.1-dev"
14
- dataset_path = "/workspace/data"
15
- output_dir = "/workspace/lora-trained"
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  # 1) grab the model locally
18
  print("📥 Downloading Flux‑Dev model…")
 
11
  from peft import LoraConfig, get_peft_model
12
 
13
  MODEL_ID = "black-forest-labs/FLUX.1-dev"
14
+
15
+ # download
16
+ model_path = snapshot_download(
17
+ MODEL_ID,
18
+ local_dir="./fluxdev-model",
19
+ use_auth_token=True
20
+ )
21
+
22
+ # later loading
23
+ pipe = StableDiffusionPipeline.from_pretrained(
24
+ model_path,
25
+ torch_dtype=torch.float16,
26
+ use_auth_token=True
27
+ ).to("cuda")
28
 
29
  # 1) grab the model locally
30
  print("📥 Downloading Flux‑Dev model…")