Commit ba91034 · x-lai committed
Parent(s): da654fa

Release training script

Former-commit-id: 16f73b66904d9d45e67de3ae95f7fe0e54f097f8

Files changed:
- README.md +12 -2
- utils/reason_seg_dataset.py +2 -1
- utils/sem_seg_dataset.py +0 -4
README.md CHANGED

@@ -169,7 +169,11 @@ Download SAM ViT-H pre-trained weights from the [link](https://dl.fbaipublicfile
 
 ### Training
 ```
-deepspeed --master_port=24999 train_ds.py
+deepspeed --master_port=24999 train_ds.py \
+  --version="PATH_TO_LLaVA_Wegihts" \
+  --dataset_dir='./dataset' \
+  --vision_pretrained="PATH_TO_SAM_Weights" \
+  --exp_name="lisa-7b"
 ```
 When training is finished, to get the full model weight:
 ```
@@ -178,7 +182,13 @@ cd ./runs/lisa-7b/ckpt_model && python zero_to_fp32.py . ../pytorch_model.bin
 
 ### Validation
 ```
-deepspeed --master_port=24999 train_ds.py
+deepspeed --master_port=24999 train_ds.py \
+  --version="PATH_TO_LLaVA_Wegihts" \
+  --dataset_dir='./dataset' \
+  --vision_pretrained="PATH_TO_SAM_Weights" \
+  --exp_name="lisa-7b" \
+  --weight='PATH_TO_pytorch_model.bin' \
+  --eval_only
 ```
 
 
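The validation command above expects the consolidated `pytorch_model.bin` produced by the `zero_to_fp32.py` step shown in the hunk context (`cd ./runs/lisa-7b/ckpt_model && python zero_to_fp32.py . ../pytorch_model.bin`). For reference, the same consolidation can be done programmatically with DeepSpeed's helper; a minimal sketch, assuming the `./runs/lisa-7b/ckpt_model` layout that follows from `--exp_name="lisa-7b"` and a DeepSpeed build that ships `zero_to_fp32`:

```python
# Sketch: consolidate the ZeRO-sharded checkpoint into a single fp32 state dict,
# equivalent to running zero_to_fp32.py from inside the checkpoint directory.
import torch
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

ckpt_dir = "./runs/lisa-7b/ckpt_model"  # written during training with --exp_name="lisa-7b"
state_dict = get_fp32_state_dict_from_zero_checkpoint(ckpt_dir)
torch.save(state_dict, "./runs/lisa-7b/pytorch_model.bin")  # path passed later via --weight
```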
utils/reason_seg_dataset.py CHANGED

@@ -76,12 +76,13 @@ class ReasonSegDataset(torch.utils.data.Dataset):
         ) as f:
             items = json.load(f)
         for item in items:
-            img_name = item["
+            img_name = item["image"]
             self.img_to_explanation[img_name] = {
                 "query": item["query"],
                 "outputs": item["outputs"],
             }
 
+        print("len(self.img_to_explanation): ", len(self.img_to_explanation))
 
     def __len__(self):
         return self.samples_per_epoch
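For context, the fixed loop keys `img_to_explanation` by `item["image"]` and stores the `query`/`outputs` fields. A minimal standalone sketch of that loading pattern, with an illustrative placeholder entry (the file name and text below are made up, not taken from the dataset):

```python
# Sketch of the explanatory-JSON loading pattern from the diff above.
# The keys "image", "query", and "outputs" come from the diff; the values
# here are placeholders for illustration only.
import json

items = json.loads("""
[
  {
    "image": "example_0001.jpg",
    "query": "an example question about the image",
    "outputs": "an example textual answer"
  }
]
""")

img_to_explanation = {}
for item in items:
    img_name = item["image"]  # keyed by image file name after the fix
    img_to_explanation[img_name] = {
        "query": item["query"],
        "outputs": item["outputs"],
    }

print("len(img_to_explanation): ", len(img_to_explanation))  # mirrors the added debug print
```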
utils/sem_seg_dataset.py CHANGED

@@ -104,10 +104,6 @@ def init_paco_lvis(base_image_dir):
             obj, part = cat_split
             obj = obj.split("_(")[0]
             part = part.split("_(")[0]
-            # if random.random() < 0.5:
-            #     name = obj + " " + part
-            # else:
-            #     name = "the {} of the {}".format(part, obj)
             name = (obj, part)
         class_map_paco_lvis[cat["id"]] = name
     img_ids = coco_api_paco_lvis.getImgIds()
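The deleted lines were commented-out alternatives for turning a PACO-LVIS `(object, part)` category into text; after this change `class_map_paco_lvis` simply stores the tuple. A minimal sketch of the two renderings those comments described (`format_part_name` is a hypothetical helper for illustration, not part of the repository):

```python
import random

def format_part_name(name):
    """Hypothetical helper: render a class-map entry as text.

    Entries in class_map_paco_lvis are either a plain string or an
    (obj, part) tuple, as produced by init_paco_lvis above.
    """
    if isinstance(name, tuple):
        obj, part = name
        # The two renderings from the removed commented-out code:
        if random.random() < 0.5:
            return obj + " " + part
        return "the {} of the {}".format(part, obj)
    return name

print(format_part_name(("dog", "ear")))  # e.g. "dog ear" or "the ear of the dog"
```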