# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Utility functions for Transformer Engine modules"""
from __future__ import annotations
import functools
import math
from typing import Any, Callable, Optional, Tuple

import torch

import transformer_engine.pytorch.cpp_extensions as ext


def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> bool:
    """Check whether any of the given tensors requires gradient."""
    for tensor in tensors:
        if tensor is not None and tensor.requires_grad:
            return True
    return False


def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
    """
    Trick to deallocate tensor memory when `del` does not release the
    tensor because of a PyTorch override.

    Must be used carefully.
    """
    from .float8_tensor import Float8Tensor

    for t in tensors:
        if t is not None:
            if isinstance(t, Float8Tensor):
                t._data.data = torch.Tensor()
            else:
                t.data = torch.Tensor()
            del t
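
# Hypothetical usage sketch (kept as comments so it is not executed at
# import time): when saved references keep a large activation alive,
# swapping its `.data` for an empty tensor drops the underlying storage.
#
#   act = torch.empty(4096, 4096, device="cuda")
#   ...  # other objects may still reference `act`, so `del act` alone won't free it
#   clear_tensor_data(act)

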
def get_device_compute_capability() -> Tuple[int, int]:
    """CUDA compute capability of the current GPU"""
    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    return (props.major, props.minor)
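
# For reference: a return value of (8, 0) corresponds to sm_80 (e.g. A100)
# and (9, 0) to sm_90 (e.g. H100).

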
def attention_mask_func(
    attention_scores: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    """Apply the attention mask to the attention scores in place."""
    attention_scores.masked_fill_(attention_mask, -10000.0)
    return attention_scores
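
# Hypothetical usage sketch: positions where the mask is True are filled
# with -10000.0, so a subsequent softmax assigns them near-zero weight.
#
#   scores = torch.randn(1, 1, 4, 4)
#   mask = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)  # causal mask
#   probs = torch.softmax(attention_mask_func(scores, mask), dim=-1)

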
def get_default_init_method() -> Callable:
    """Weight initialization method if not provided by user"""
    return init_method_normal(0.023)


def init_method_constant(val: float) -> Callable:
    """Init method to set all tensor elements to a constant value."""
    if val == 1.0:

        def init_(tensor: torch.Tensor) -> torch.Tensor:
            return torch.nn.init.ones_(tensor)

    elif val == 0.0:

        def init_(tensor: torch.Tensor) -> torch.Tensor:
            return torch.nn.init.zeros_(tensor)

    else:

        def init_(tensor: torch.Tensor) -> torch.Tensor:
            return torch.nn.init.constant_(tensor, val)

    return init_


def init_method_normal(sigma: float) -> Callable:
    """Init method based on N(0, sigma)."""

    def init_(tensor: torch.Tensor) -> torch.Tensor:
        return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)

    return init_


def scaled_init_method_normal(sigma: float, num_layers: int) -> Callable:
    """Init method based on N(0, sigma/sqrt(2*num_layers))."""
    std = sigma / math.sqrt(2.0 * num_layers)

    def init_(tensor: torch.Tensor) -> torch.Tensor:
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)

    return init_
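
# Hypothetical usage sketch of the init-method factories above: each returns
# a callable that initializes a tensor in place.
#
#   weight = torch.empty(1024, 1024)
#   init_method_normal(0.023)(weight)             # N(0, 0.023)
#   scaled_init_method_normal(0.023, 24)(weight)  # std scaled by 1/sqrt(2*24)

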
def all_close(a: torch.Tensor, b: torch.Tensor) -> bool:
    """torch.allclose, computed on CPU copies to avoid GPU OOMs"""
    return torch.allclose(a.cpu(), b.cpu())


def print_rank_0(*args: Any) -> None:
    """Print on rank 0 only"""
    # Uses the current CUDA device index as a proxy for the global rank.
    if torch.cuda.current_device() == 0:
        print(*args)


def compare_tensors(a: torch.Tensor, b: torch.Tensor) -> None:
    """Utility function to show some tensor stats"""
    if a.shape != b.shape:
        print_rank_0("Tensors have different shape")
        return
    print_rank_0(a)
    print_rank_0(b)
    max_err = torch.max(torch.abs(a - b))
    max_a = torch.max(a)
    max_b = torch.max(b)
    print_rank_0(f"max_err={max_err}, max_a={max_a}, max_b={max_b}")


def ensure_divisibility(numerator: int, denominator: int) -> None:
    """Ensure that the numerator is divisible by the denominator."""
    assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}"


def divide(numerator: int, denominator: int) -> int:
    """Ensure that the numerator is divisible by the denominator and
    return the quotient."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator


def split_tensor_along_dim(
    tensor: torch.Tensor, dim: int, num_partitions: int, contiguous_split_chunks: bool = False
) -> Tuple[torch.Tensor, ...]:
    """Split a tensor along the given dimension.

    Arguments:
        tensor: input tensor.
        dim: dimension along which to split the tensor.
        num_partitions: number of partitions to split the tensor into.
        contiguous_split_chunks: If True, make each chunk contiguous
            in memory.
    """
    # Get the chunk size along the split dimension.
    split_size = divide(tensor.size()[dim], num_partitions)
    # Split.
    tensor_list = torch.split(tensor, split_size, dim=dim)
    # Note: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)
    return tensor_list
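
# Hypothetical usage sketch:
#
#   x = torch.arange(24).reshape(4, 6)
#   a, b, c = split_tensor_along_dim(x, dim=1, num_partitions=3)
#   # a, b, c each have shape (4, 2); they are views of x unless
#   # contiguous_split_chunks=True is passed.

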
def validate_ctx_manager(ctx: Callable) -> None:
    """Checks if the passed-in object can be used as a context manager."""
    try:
        with ctx():
            pass
    except Exception as e:
        raise ValueError("Object must be a valid ctx manager") from e


def validate_rng_states_func(get_rng_tracker: Callable) -> None:
    """Checks that the passed-in function provides everything required
    for tensor/model and sequence parallelism.
    """
    assert callable(get_rng_tracker), "get_rng_tracker is not a valid function"

    rng_tracker = None
    try:
        rng_tracker = get_rng_tracker()
    except Exception as e:
        raise RuntimeError("Cannot call get_rng_tracker function") from e

    assert hasattr(rng_tracker, "get_states") and callable(
        rng_tracker.get_states
    ), "rng_tracker object does not have valid method get_states"
    assert hasattr(rng_tracker, "set_states") and callable(
        rng_tracker.set_states
    ), "rng_tracker object does not have valid method set_states"
    assert hasattr(rng_tracker, "fork") and callable(
        rng_tracker.fork
    ), "rng_tracker object does not have valid method fork"
    validate_ctx_manager(rng_tracker.fork)
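
# Minimal sketch of an object that would pass this validation (hypothetical;
# a real tracker manages per-device CUDA RNG states):
#
#   from contextlib import contextmanager
#
#   class _ToyRngTracker:
#       def __init__(self):
#           self._states = {}
#       def get_states(self):
#           return dict(self._states)
#       def set_states(self, states):
#           self._states = dict(states)
#       @contextmanager
#       def fork(self):
#           yield  # a real tracker would swap CUDA RNG state here
#
#   validate_rng_states_func(lambda: _ToyRngTracker())

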
def assert_viewless_tensor(tensor: torch.Tensor, extra_msg: Optional[str] = None) -> torch.Tensor:
    """Assert that a tensor is not a view (i.e., its '._base' field is
    not set)."""
    if isinstance(tensor, list):
        return [assert_viewless_tensor(t) for t in tensor]
    if not isinstance(tensor, torch.Tensor):
        return tensor
    assert tensor._base is None, (
        "Ensure tensor._base is None before setting tensor.data or storing "
        "tensor to memory buffer. Otherwise, a memory leak will occur (and "
        f"likely accumulate over iterations). {extra_msg}"
    )
    return tensor


def safely_set_viewless_tensor_data(tensor: torch.Tensor, new_data_tensor: torch.Tensor) -> None:
    """Safely set a tensor's '.data' field.

    Check first that the tensor is viewless (i.e., '._base' not set). If it
    is not, raise an exception.
    """
    extra_msg = (
        "FYI, tensor._base has shape "
        f"{'--' if tensor._base is None else tensor._base.shape}, "
        f"and new_data_tensor has shape {new_data_tensor.shape}."
    )
    assert_viewless_tensor(tensor, extra_msg=extra_msg)
    tensor.data = new_data_tensor
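
# Hypothetical usage sketch: a view holds a reference to its parent tensor
# in `._base`, so swapping a view's `.data` would silently keep the parent's
# storage alive (a leak). The assert above catches that case.
#
#   x = torch.zeros(8)
#   safely_set_viewless_tensor_data(x, torch.ones(8))    # OK: x._base is None
#   v = x[:4]                                            # v is a view of x
#   # safely_set_viewless_tensor_data(v, torch.ones(4))  # would trip the assert

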
def cast_if_needed(tensor: Optional[torch.Tensor], dtype: torch.dtype) -> Optional[torch.Tensor]:
    """Cast tensor to dtype"""
    if tensor is None:
        return None
    if tensor.dtype == dtype:
        return tensor
    # Cast with autograd enabled so the op is recorded in the graph even
    # when called under `torch.no_grad()`.
    with torch.enable_grad():
        return tensor.to(dtype=dtype)


def check_dim_for_fp8_exec(tensor: torch.Tensor) -> bool:
    """Check if tensor dimensions are supported for FP8 TN GEMM"""
    return tensor.dim() == 2 and tensor.size(0) % 8 == 0 and tensor.size(1) % 16 == 0


def assert_dim_for_fp8_exec(tensor: torch.Tensor) -> None:
    """Assert that tensor dimensions are supported for FP8 TN GEMM"""
    # Single-tensor check so it is clear which tensor triggers the assertion.
    assert tensor.dim() == 2 and tensor.size(0) % 8 == 0 and tensor.size(1) % 16 == 0, (
        "FP8 execution requires 2D input matrices with "
        "height divisible by 8 and width divisible by 16, "
        f"but got tensor with dims={list(tensor.size())}"
    )
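
# Hypothetical usage sketch:
#
#   check_dim_for_fp8_exec(torch.empty(64, 128))  # True: 64 % 8 == 0, 128 % 16 == 0
#   check_dim_for_fp8_exec(torch.empty(60, 128))  # False: 60 % 8 != 0

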
def is_bf16_compatible() -> bool:
    """Replaces torch.cuda.is_bf16_supported() with an explicit check on
    device compute capability to enforce sm_80 or higher.
    """
    return torch.cuda.get_device_capability()[0] >= 8


@functools.lru_cache(maxsize=None)
def get_cudnn_version() -> Tuple[int, int, int]:
    """Runtime cuDNN version (major, minor, patch)"""
    encoded_version = ext.get_cudnn_version()
    # cuDNN < 9.0 encodes the version as major*1000 + minor*100 + patch;
    # cuDNN >= 9.0 uses major*10000 + minor*100 + patch.
    major_version_magnitude = 1000 if encoded_version < 90000 else 10000
    major, encoded_version = divmod(encoded_version, major_version_magnitude)
    minor, patch = divmod(encoded_version, 100)
    return (major, minor, patch)
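
# Worked examples of the decoding above:
#   8904  -> divmod(8904, 1000)   = (8, 904); divmod(904, 100) = (9, 4) -> (8, 9, 4)
#   90100 -> divmod(90100, 10000) = (9, 100); divmod(100, 100) = (1, 0) -> (9, 1, 0)

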
def canonicalize_device(device: Optional[torch.device | str]) -> torch.device:
    """Canonicalize PyTorch device

    If `None`, then returns the default CUDA device.
    """
    if device is None:
        # Use default CUDA device
        device = torch.get_default_device()
        if device.type != "cuda":
            device = torch.device("cuda", torch.cuda.current_device())
    elif not isinstance(device, torch.device):
        device = torch.device(device)
    if device.type == "cuda" and device.index is None:
        device = torch.device("cuda", torch.cuda.current_device())
    return device
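
# Hypothetical usage sketch (assuming CUDA device 0 is current):
#
#   canonicalize_device(None)    # -> device("cuda:0") unless the default device is CUDA
#   canonicalize_device("cuda")  # -> device("cuda:0"), index filled in
#   canonicalize_device("cpu")   # -> device("cpu"), passed through

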
def canonicalize_dtype(dtype: Optional[torch.dtype]) -> torch.dtype:
    """Canonicalize PyTorch datatype

    If `None`, then returns the default PyTorch datatype.
    """
    if dtype is None:
        # Use default dtype
        dtype = torch.get_default_dtype()
    return dtype


def devices_match(device1: torch.device, device2: torch.device) -> bool:
    """Whether two devices are the same"""
    device1 = torch.device(device1)
    device2 = torch.device(device2)
    if device1.type != device2.type:
        return False
    if device1.type == "cuda":
        index1 = device1.index
        index2 = device2.index
        if index1 == index2:
            return True
        if index1 is None:
            index1 = torch.cuda.current_device()
        if index2 is None:
            index2 = torch.cuda.current_device()
        return index1 == index2
    return device1 == device2
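
# Hypothetical usage sketch: a bare "cuda" resolves to the current device,
# so it matches an explicitly indexed device when the indices agree.
#
#   devices_match("cuda", "cuda:0")    # True when device 0 is current
#   devices_match("cuda:0", "cuda:1")  # False
#   devices_match("cpu", "cuda:0")     # False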