ethanlshen committed · verified
Commit d8add38 · 1 Parent(s): 37a92ca

Remove parallel

superposed/llama/superposed_generation.py CHANGED
@@ -34,29 +34,29 @@ class SuperposedLlama:
         model_parallel_size: Optional[int] = None,
         seed: int = 1,
     ):
-        if not torch.distributed.is_initialized():
-            torch.distributed.init_process_group("nccl")
-        if not model_parallel_is_initialized():
-            if model_parallel_size is None:
-                model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
-            initialize_model_parallel(model_parallel_size)
+        # if not torch.distributed.is_initialized():
+        #     torch.distributed.init_process_group("nccl")
+        # if not model_parallel_is_initialized():
+        #     if model_parallel_size is None:
+        #         model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
+        #     initialize_model_parallel(model_parallel_size)
 
-        local_rank = int(os.environ.get("LOCAL_RANK", 0))
+        # local_rank = int(os.environ.get("LOCAL_RANK", 0))
         if device == None:
             torch.cuda.set_device(local_rank)
             device = torch.cuda.current_device()
         torch.manual_seed(seed)
 
-        if local_rank > 0:
-            sys.stdout = open(os.devnull, "w")
+        # if local_rank > 0:
+        #     sys.stdout = open(os.devnull, "w")
 
-        start_time = time.time()
+        # start_time = time.time()
         checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
-        assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
-        assert model_parallel_size == len(
-            checkpoints
-        ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
-        ckpt_path = checkpoints[get_model_parallel_rank()]
+        # assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
+        # assert model_parallel_size == len(
+        #     checkpoints
+        # ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
+        ckpt_path = checkpoints[0]
         checkpoint = torch.load(ckpt_path, map_location="cpu")
         with open(Path(ckpt_dir) / "params.json", "r") as f:
             params = json.loads(f.read())
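
The net effect is a single-process load path: the distributed and model-parallel setup is commented out, and the first checkpoint shard is taken directly instead of indexing by model-parallel rank. Note that the `device == None` branch still reads `local_rank`, whose assignment the commit comments out, so `device` presumably has to be passed explicitly after this change. A minimal standalone sketch of the same pattern, assuming a single GPU; the `local_rank` fallback and the `ckpt_dir` value are illustrative assumptions, not part of the commit:

import json
import os
from pathlib import Path

import torch

ckpt_dir = "path/to/checkpoints"  # hypothetical directory of *.pth shards

local_rank = int(os.environ.get("LOCAL_RANK", 0))  # assumed fallback; the commit removes this line
torch.cuda.set_device(local_rank)
device = torch.cuda.current_device()
torch.manual_seed(1)

# Single-process loading: take the first shard rather than indexing by
# model-parallel rank, mirroring `ckpt_path = checkpoints[0]` in the diff.
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
ckpt_path = checkpoints[0]
checkpoint = torch.load(ckpt_path, map_location="cpu")
with open(Path(ckpt_dir) / "params.json", "r") as f:
    params = json.loads(f.read())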