x-lai committed on
Commit
ba91034
·
1 Parent(s): da654fa

Release training script

Browse files

Former-commit-id: 16f73b66904d9d45e67de3ae95f7fe0e54f097f8

README.md CHANGED
@@ -169,7 +169,11 @@ Download SAM ViT-H pre-trained weights from the [link](https://dl.fbaipublicfile
169
 
170
  ### Training
171
  ```
172
- deepspeed --master_port=24999 train_ds.py --version="PATH_TO_LLaVA_Wegihts" --dataset_dir='./dataset' --vision_pretrained="PATH_TO_SAM_Weights" --exp_name="lisa-7b"
 
 
 
 
173
  ```
174
  When training is finished, to get the full model weight:
175
  ```
@@ -178,7 +182,13 @@ cd ./runs/lisa-7b/ckpt_model && python zero_to_fp32.py . ../pytorch_model.bin
178
 
179
  ### Validation
180
  ```
181
- deepspeed --master_port=24999 train_ds.py --version="PATH_TO_LLaVA_Wegihts" --dataset_dir='./dataset' --vision_pretrained="PATH_TO_SAM_Weights" --exp_name="lisa-7b" --weight='PATH_TO_pytorch_model.bin' --eval_only
 
 
 
 
 
 
182
  ```
183
 
184
 
 
169
 
170
  ### Training
171
  ```
172
+ deepspeed --master_port=24999 train_ds.py \
173
+ --version="PATH_TO_LLaVA_Weights" \
174
+ --dataset_dir='./dataset' \
175
+ --vision_pretrained="PATH_TO_SAM_Weights" \
176
+ --exp_name="lisa-7b"
177
  ```
178
  When training is finished, to get the full model weight:
179
  ```
 
182
 
183
  ### Validation
184
  ```
185
+ deepspeed --master_port=24999 train_ds.py \
186
+ --version="PATH_TO_LLaVA_Weights" \
187
+ --dataset_dir='./dataset' \
188
+ --vision_pretrained="PATH_TO_SAM_Weights" \
189
+ --exp_name="lisa-7b" \
190
+ --weight='PATH_TO_pytorch_model.bin' \
191
+ --eval_only
192
  ```
193
 
194
 
utils/reason_seg_dataset.py CHANGED
@@ -76,12 +76,13 @@ class ReasonSegDataset(torch.utils.data.Dataset):
76
  ) as f:
77
  items = json.load(f)
78
  for item in items:
79
- img_name = item["image_path"].split("/")[-1]
80
  self.img_to_explanation[img_name] = {
81
  "query": item["query"],
82
  "outputs": item["outputs"],
83
  }
84
 
 
85
 
86
  def __len__(self):
87
  return self.samples_per_epoch
 
76
  ) as f:
77
  items = json.load(f)
78
  for item in items:
79
+ img_name = item["image"]
80
  self.img_to_explanation[img_name] = {
81
  "query": item["query"],
82
  "outputs": item["outputs"],
83
  }
84
 
85
+ print("len(self.img_to_explanation): ", len(self.img_to_explanation))
86
 
87
  def __len__(self):
88
  return self.samples_per_epoch
utils/sem_seg_dataset.py CHANGED
@@ -104,10 +104,6 @@ def init_paco_lvis(base_image_dir):
104
  obj, part = cat_split
105
  obj = obj.split("_(")[0]
106
  part = part.split("_(")[0]
107
- # if random.random() < 0.5:
108
- # name = obj + " " + part
109
- # else:
110
- # name = "the {} of the {}".format(part, obj)
111
  name = (obj, part)
112
  class_map_paco_lvis[cat["id"]] = name
113
  img_ids = coco_api_paco_lvis.getImgIds()
 
104
  obj, part = cat_split
105
  obj = obj.split("_(")[0]
106
  part = part.split("_(")[0]
 
 
 
 
107
  name = (obj, part)
108
  class_map_paco_lvis[cat["id"]] = name
109
  img_ids = coco_api_paco_lvis.getImgIds()