Commit: fixes

abacaj committed Sep 8, 2023
1 parent e914abb commit 4e6a503
Showing 3 changed files with 14 additions and 9 deletions.

megatron/initialize.py (8 changes: 4 additions & 4 deletions)

@@ -130,10 +130,10 @@ def _initialize_torch_distributed(args: argparse.Namespace):
     args.local_rank = device
     torch.cuda.set_device(device)
     # Call the init process
-    if args.cross_block_networking or args.force_socket_networking:
-        os.environ["NCCL_NET"] = "Socket"
-    else:
-        os.environ["NCCL_NET"] = "IB"
+    # if args.cross_block_networking or args.force_socket_networking:
+    #     os.environ["NCCL_NET"] = "Socket"
+    # else:
+    #     os.environ["NCCL_NET"] = "IB"
     torch.distributed.init_process_group(
         backend=args.default_backend,
         world_size=args.world_size,
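
The removed block had pinned NCCL's transport by setting the NCCL_NET environment variable; with it commented out, NCCL is left to auto-detect the fastest transport available. A minimal sketch of the same override pattern, with an illustrative function and flag name (only NCCL_NET and its "Socket"/"IB" values appear in the diff):

import os

import torch.distributed


def init_distributed(rank: int, world_size: int, force_socket: bool = False) -> None:
    # NCCL reads its NCCL_* environment variables when the first
    # communicator is created, so any override must be set before
    # torch.distributed.init_process_group() runs.
    if force_socket:
        # Pin the plain TCP/socket transport, e.g. on hosts without
        # working InfiniBand.
        os.environ["NCCL_NET"] = "Socket"
    # Otherwise leave NCCL_NET unset and let NCCL probe for the fastest
    # transport (InfiniBand when present, sockets otherwise), which is
    # the behavior this commit restores.
    torch.distributed.init_process_group(
        backend="nccl", rank=rank, world_size=world_size
    )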

megatron/mpu/initialize.py (9 changes: 6 additions & 3 deletions)

@@ -631,17 +631,20 @@ def force_communicator_creation(
         torch.distributed.barrier(group=group)
     if all_reduce:
         one_tensor = torch.cuda.FloatTensor([1.0])
-        torch.distributed.all_reduce(one_tensor, op=torch.distributed.ReduceOp.SUM, group=group)
+        if world_size and world_size > 1:
+            torch.distributed.all_reduce(one_tensor, op=torch.distributed.ReduceOp.SUM)
     if all_gather:
         assert (
             rank is not None and world_size is not None
         ), "Must supply rank and world_size for all_gather initialization"
         tensor_list = [torch.empty_like(one_tensor) for _ in range(world_size)]
         tensor_list[rank] = one_tensor
-        torch.distributed.all_gather(tensor_list, one_tensor, group=group)
+        if world_size > 1:
+            torch.distributed.all_gather(tensor_list, one_tensor)
     if broadcast:
         one_tensor = torch.cuda.FloatTensor([1.0])
-        torch.distributed.broadcast(one_tensor, src_rank, group=group)
+        if world_size > 1:
+            torch.distributed.broadcast(one_tensor, src_rank)
 
 
 def force_pipeline_communicator_creation(ignore_virtual=False):
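
The change above guards each warm-up collective with a world-size check, so single-process runs skip NCCL communicator creation entirely; note that the guarded calls also drop group=group and therefore now target the default process group. A standalone sketch of that guard, with an illustrative function name rather than the repository's API:

import torch
import torch.distributed as dist


def warm_up_collectives(rank: int, world_size: int, src_rank: int = 0) -> None:
    # Collectives create their NCCL communicators lazily on first use;
    # with a single rank there is nothing to create, so skip them all.
    if world_size <= 1:
        return
    one_tensor = torch.cuda.FloatTensor([1.0])
    dist.all_reduce(one_tensor, op=dist.ReduceOp.SUM)
    tensor_list = [torch.empty_like(one_tensor) for _ in range(world_size)]
    tensor_list[rank] = one_tensor
    dist.all_gather(tensor_list, one_tensor)
    dist.broadcast(one_tensor, src=src_rank)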

megatron/text_generation_server.py (6 changes: 4 additions & 2 deletions)

@@ -582,16 +582,18 @@ def put(self) -> Union[Tuple[str, int], str]:
             (
                 response,
                 response_logprobs,
-                _,
+                all_tokens,
                 generations,
                 _,
                 human_readable_tokens,
             ) = retval
-
         end_time = datetime.datetime.now()
         print("Query latency: ", end_time - start_time, flush=True)
+        print(all_tokens)
         output = {
             "text": response,
+            "query_time_ms": (end_time - start_time).total_seconds() * 1000,
+            "tokens_generated": sum([len(t) for t in all_tokens]),
             "logprobs": response_logprobs,
             "generations": generations,
             "human_readable_tokens": human_readable_tokens,
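
With these additions the server's JSON response reports per-query latency in milliseconds and the number of tokens generated alongside the text. A hypothetical client call for the PUT handler above; the URL, port, and request payload are assumptions, and only the response fields visible in the diff come from the source:

import requests

# Endpoint and request schema are illustrative; the diff only shows that
# the route answers PUT and returns the JSON fields read below.
resp = requests.put(
    "http://localhost:5000/generate",
    json={"prompts": ["Hello, world"]},
)
data = resp.json()
print(data["text"])
print(f'{data["query_time_ms"]:.1f} ms, {data["tokens_generated"]} tokens')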
