Spaces:
Runtime error
Runtime error
Remove parallel
Browse files
superposed/llama/superposed_generation.py
CHANGED
@@ -34,29 +34,29 @@ class SuperposedLlama:
|
|
34 |
model_parallel_size: Optional[int] = None,
|
35 |
seed: int = 1,
|
36 |
):
|
37 |
-
if not torch.distributed.is_initialized():
|
38 |
-
torch.distributed.init_process_group("nccl")
|
39 |
-
if not model_parallel_is_initialized():
|
40 |
-
if model_parallel_size is None:
|
41 |
-
model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
|
42 |
-
initialize_model_parallel(model_parallel_size)
|
43 |
|
44 |
-
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
45 |
if device == None:
|
46 |
torch.cuda.set_device(local_rank)
|
47 |
device = torch.cuda.current_device()
|
48 |
torch.manual_seed(seed)
|
49 |
|
50 |
-
if local_rank > 0:
|
51 |
-
sys.stdout = open(os.devnull, "w")
|
52 |
|
53 |
-
start_time = time.time()
|
54 |
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
|
55 |
-
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
|
56 |
-
assert model_parallel_size == len(
|
57 |
-
checkpoints
|
58 |
-
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
|
59 |
-
ckpt_path = checkpoints[get_model_parallel_rank()]
|
60 |
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
61 |
with open(Path(ckpt_dir) / "params.json", "r") as f:
|
62 |
params = json.loads(f.read())
|
|
|
34 |
model_parallel_size: Optional[int] = None,
|
35 |
seed: int = 1,
|
36 |
):
|
37 |
+
# if not torch.distributed.is_initialized():
|
38 |
+
# torch.distributed.init_process_group("nccl")
|
39 |
+
# if not model_parallel_is_initialized():
|
40 |
+
# if model_parallel_size is None:
|
41 |
+
# model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
|
42 |
+
# initialize_model_parallel(model_parallel_size)
|
43 |
|
44 |
+
# local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
45 |
if device == None:
|
46 |
torch.cuda.set_device(local_rank)
|
47 |
device = torch.cuda.current_device()
|
48 |
torch.manual_seed(seed)
|
49 |
|
50 |
+
# if local_rank > 0:
|
51 |
+
# sys.stdout = open(os.devnull, "w")
|
52 |
|
53 |
+
# start_time = time.time()
|
54 |
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
|
55 |
+
# assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
|
56 |
+
# assert model_parallel_size == len(
|
57 |
+
# checkpoints
|
58 |
+
# ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
|
59 |
+
ckpt_path = checkpoints[0]
|
60 |
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
61 |
with open(Path(ckpt_dir) / "params.json", "r") as f:
|
62 |
params = json.loads(f.read())
|