// tfloat32.h
/***************************************************************************************************
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief Defines a proxy class for storing Tensor Float 32 data type.
*/
#pragma once
#if defined(__CUDACC_RTC__)
#include "cutlass/floating_point_nvrtc.h"
#else
#include <cmath>
#include <limits>
#include <cstdint>
#include <cstring> // std::memcpy
#endif
#include "cutlass/cutlass.h"
namespace cutlass {
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Tensor Float 32 data type.
///
/// The value is kept in a 32-bit word with the same sign / exponent / mantissa layout as an
/// IEEE single-precision float. Only the upper 10 mantissa bits are significant: the low 13
/// bits are don't-care and are cleared whenever the value is read back as `float`.
struct alignas(4) tfloat32_t {

  //
  // Data members
  //

  /// Raw bit pattern (single-precision layout; low 13 mantissa bits are don't-care)
  uint32_t storage;

  //
  // Methods
  //

private:

  /// Returns the bit pattern of `s` unchanged.
  CUTLASS_HOST_DEVICE
  static uint32_t float_to_storage(float s) {
  #if defined(__CUDA_ARCH__)
    uint32_t result = reinterpret_cast<uint32_t const &>(s);
  #else
    uint32_t result;
    std::memcpy(&result, &s, sizeof(float));
  #endif
    return result;
  }

public:

  /// Constructs from a raw bit pattern; no rounding or validation is applied.
  CUTLASS_HOST_DEVICE
  static tfloat32_t bitcast(uint32_t x) {
    tfloat32_t h;
    h.storage = x;
    return h;
  }

  /// Emulated rounding: adds half a TF32 ULP (0x1000) to the magnitude of a finite input so
  /// that the truncation performed when converting back to float yields round-to-nearest with
  /// ties away from zero.
  ///
  /// Non-finite inputs are not incremented. Two corner cases are normalized so that the stored
  /// bits agree with the value obtained from `operator float()`:
  ///  - a finite input that rounds past the largest finite value becomes an exact infinity;
  ///  - a NaN gets its quiet bit set, so it stays a NaN after the low 13 bits are dropped.
  CUTLASS_HOST_DEVICE
  static tfloat32_t round_half_ulp_truncate(float const &s) {
    uint32_t const kExponentMask = 0x7f800000u;
    uint32_t const kMantissaMask = 0x007fffffu;

    uint32_t x = float_to_storage(s);

    if ((x & kExponentMask) != kExponentMask) {
      // Finite: bias by half a ULP of the 10-bit mantissa.
      x += 0x1000u;
      if ((x & kExponentMask) == kExponentMask) {
        // The carry reached the exponent: keep sign and exponent only (+/- infinity).
        x &= ~kMantissaMask;
      }
    }
    else if (x & kMantissaMask) {
      // NaN: bit 22 is inside the retained mantissa, so the payload cannot truncate to zero.
      x |= 0x00400000u;
    }

    return tfloat32_t::bitcast(x);
  }

  tfloat32_t() = default;

  /// Floating-point conversion - round to nearest, ties away from zero
  /// (see round_half_ulp_truncate)
  CUTLASS_HOST_DEVICE
  explicit tfloat32_t(float x): storage(round_half_ulp_truncate(x).raw()) { }

  // Conversion from double (this rounds twice)
  CUTLASS_HOST_DEVICE
  explicit tfloat32_t(double x): tfloat32_t(float(x)) { }

  /// Integer conversion - the float bit pattern is stored as-is, so the low mantissa bits are
  /// simply dropped on read (round toward zero)
  CUTLASS_HOST_DEVICE
  explicit tfloat32_t(int x): storage(float_to_storage(static_cast<float>(x))) { }

  // Conversion to float
  CUTLASS_HOST_DEVICE
  operator float() const {

    // Conversions to IEEE single-precision requires clearing dont-care bits
    // of the mantissa.
    uint32_t const bits = (storage & ~0x1fffu);

  #if defined(__CUDA_ARCH__)
    return reinterpret_cast<float const &>(bits);
  #else
    float flt;
    std::memcpy(&flt, &bits, sizeof(flt));
    return flt;
  #endif
  }

  /// Converts to double
  CUTLASS_HOST_DEVICE
  explicit operator double() const {
    return double(float(*this));
  }

  /// Converts to int
  CUTLASS_HOST_DEVICE
  explicit operator int() const {
    return int(float(*this));
  }

  /// Casts to bool: true when the value read back as float is non-zero
  CUTLASS_HOST_DEVICE
  explicit operator bool() const {
    return (float(*this) != 0.0f);
  }

  /// Obtains raw bits (including the don't-care low mantissa bits)
  CUTLASS_HOST_DEVICE
  uint32_t raw() const {
    return storage;
  }

  /// Returns the sign bit
  CUTLASS_HOST_DEVICE
  bool signbit() const {
    return ((raw() & 0x80000000) != 0);
  }

  /// Returns the biased exponent
  CUTLASS_HOST_DEVICE
  int exponent_biased() const {
    return int((raw() >> 23) & 0x0ff);
  }

  /// Returns the unbiased exponent
  CUTLASS_HOST_DEVICE
  int exponent() const {
    return exponent_biased() - 127;
  }

  /// Returns all 23 stored mantissa bits (the low 13 are don't-care)
  CUTLASS_HOST_DEVICE
  int mantissa() const {
    return int(raw() & 0x7fffff);
  }
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Returns true when the sign bit (bit 31) of the stored pattern is set.
CUTLASS_HOST_DEVICE
bool signbit(cutlass::tfloat32_t const& h) {
  bool const is_negative = h.signbit();
  return is_negative;
}
/// Returns a copy of `h` with the sign bit cleared; all other stored bits are kept as-is.
CUTLASS_HOST_DEVICE
cutlass::tfloat32_t abs(cutlass::tfloat32_t const& h) {
  uint32_t const magnitude_bits = h.raw() & 0x7fffffffu;
  return cutlass::tfloat32_t::bitcast(magnitude_bits);
}
/// Returns true when `h` is a NaN.
///
/// Only the upper 10 mantissa bits are examined. The low 13 bits are don't-care and are
/// cleared by the conversion to float, so a pattern whose only non-zero mantissa bits lie
/// there reads back as infinity and must not be reported as NaN.
CUTLASS_HOST_DEVICE
bool isnan(cutlass::tfloat32_t const& h) {
  return (h.exponent_biased() == 0x0ff) && ((h.mantissa() & 0x7fe000) != 0);
}
/// Returns true unless the biased exponent is all ones (the encoding of infinity and NaN).
CUTLASS_HOST_DEVICE
bool isfinite(cutlass::tfloat32_t const& h) {
  int const biased = h.exponent_biased();
  return biased != 0x0ff;
}
/// Returns a NaN. The string argument is accepted but ignored.
CUTLASS_HOST_DEVICE
cutlass::tfloat32_t nan_tf32(const char*) {
  // NVIDIA canonical NaN
  uint32_t const canonical_nan_bits = 0x7fffffffu;
  return cutlass::tfloat32_t::bitcast(canonical_nan_bits);
}
/// Returns true when `h` is positive or negative infinity.
///
/// Only the upper 10 mantissa bits are examined. The low 13 bits are don't-care and are
/// cleared by the conversion to float, so stray bits there do not stop the value from
/// reading back as infinity.
CUTLASS_HOST_DEVICE
bool isinf(cutlass::tfloat32_t const& h) {
  return (h.exponent_biased() == 0x0ff) && ((h.mantissa() & 0x7fe000) == 0);
}
/// Returns true when the biased exponent is neither zero (zero / subnormal) nor all ones
/// (infinity / NaN).
CUTLASS_HOST_DEVICE
bool isnormal(cutlass::tfloat32_t const& h) {
  int const biased = h.exponent_biased();
  return (biased != 0) && (biased != 0x0ff);
}
/// Classifies `h` as FP_NAN, FP_INFINITE, FP_SUBNORMAL, FP_ZERO or FP_NORMAL.
///
/// Only the upper 10 mantissa bits take part in the decision. The low 13 bits are don't-care
/// and are cleared by the conversion to float, so counting them would label a value that reads
/// back as zero "subnormal", or one that reads back as infinity "NaN".
CUTLASS_HOST_DEVICE
int fpclassify(cutlass::tfloat32_t const& h) {
  int const exp = h.exponent_biased();
  bool const has_mantissa = (h.mantissa() & 0x7fe000) != 0;

  if (exp == 0x0ff) {
    // All-ones exponent: NaN when any significant mantissa bit is set, infinity otherwise.
    return has_mantissa ? FP_NAN : FP_INFINITE;
  }
  if (exp == 0) {
    // Zero exponent: subnormal when any significant mantissa bit is set, zero otherwise.
    return has_mantissa ? FP_SUBNORMAL : FP_ZERO;
  }
  return FP_NORMAL;
}
/// Computes the square root in single precision and converts the result back to tfloat32_t.
CUTLASS_HOST_DEVICE
cutlass::tfloat32_t sqrt(cutlass::tfloat32_t const& h) {
  float const operand = float(h);
#if defined(__CUDACC_RTC__)
  float const root = sqrtf(operand);
#else
  float const root = std::sqrt(operand);
#endif
  return cutlass::tfloat32_t(root);
}
/// Returns a value made of the magnitude bits of `a` and the sign bit of `b`.
CUTLASS_HOST_DEVICE
tfloat32_t copysign(tfloat32_t const& a, tfloat32_t const& b) {
  return tfloat32_t::bitcast((a.raw() & 0x7fffffffu) | (b.raw() & 0x80000000u));
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// Standard Library operations and definitions
//
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace std {
#if !defined(__CUDACC_RTC__)
/// Numeric limits for cutlass::tfloat32_t.
///
/// Bit patterns use the single-precision layout. Only the upper 10 mantissa bits are
/// significant, so one unit in the last place of the mantissa is bit 13 (0x2000).
template <>
struct numeric_limits<cutlass::tfloat32_t> {
  static bool const is_specialized = true;
  static bool const is_signed = true;
  static bool const is_integer = false;
  static bool const is_exact = false;
  static bool const has_infinity = true;
  static bool const has_quiet_NaN = true;
  static bool const has_signaling_NaN = false;
  static std::float_denorm_style const has_denorm = std::denorm_present;
  static bool const has_denorm_loss = true;
  static std::float_round_style const round_style = std::round_to_nearest;
  static bool const is_iec559 = false;
  static bool const is_bounded = true;
  static bool const is_modulo = false;

  /// Number of radix-2 significand digits: 10 stored mantissa bits plus the implicit leading one
  static int const digits = 11;

  /// Smallest positive normalized value (biased exponent 1, zero mantissa)
  static cutlass::tfloat32_t min() { return cutlass::tfloat32_t::bitcast(0x00800000); }

  /// Lowest (most negative) finite value
  static cutlass::tfloat32_t lowest() { return cutlass::tfloat32_t::bitcast(0xff7fffff); }

  /// Largest finite value
  static cutlass::tfloat32_t max() { return cutlass::tfloat32_t::bitcast(0x7f7fffff); }

  /// Difference between 1.0 and the next representable value (2^-10)
  static cutlass::tfloat32_t epsilon() { return cutlass::tfloat32_t::bitcast(0x3a800000); }

  /// Maximum rounding error, in units in the last place
  static cutlass::tfloat32_t round_error() { return cutlass::tfloat32_t(0.5f); }

  /// Positive infinity
  static cutlass::tfloat32_t infinity() { return cutlass::tfloat32_t::bitcast(0x7f800000); }

  /// Quiet NaN
  static cutlass::tfloat32_t quiet_NaN() { return cutlass::tfloat32_t::bitcast(0x7fffffff); }

  /// Same pattern as quiet_NaN(); has_signaling_NaN is false
  static cutlass::tfloat32_t signaling_NaN() { return cutlass::tfloat32_t::bitcast(0x7fffffff); }

  /// Smallest positive subnormal value (only the lowest significant mantissa bit set)
  static cutlass::tfloat32_t denorm_min() { return cutlass::tfloat32_t::bitcast(0x00002000); }
};
#endif
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace std
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// Arithmetic operators
//
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Equality, evaluated on the single-precision values of both operands.
CUTLASS_HOST_DEVICE
bool operator==(tfloat32_t const& lhs, tfloat32_t const& rhs) {
  float const a = static_cast<float>(lhs);
  float const b = static_cast<float>(rhs);
  return a == b;
}
/// Inequality, evaluated on the single-precision values of both operands.
CUTLASS_HOST_DEVICE
bool operator!=(tfloat32_t const& lhs, tfloat32_t const& rhs) {
  float const a = static_cast<float>(lhs);
  float const b = static_cast<float>(rhs);
  return a != b;
}
/// Less-than, evaluated on the single-precision values of both operands.
CUTLASS_HOST_DEVICE
bool operator<(tfloat32_t const& lhs, tfloat32_t const& rhs) {
  float const a = static_cast<float>(lhs);
  float const b = static_cast<float>(rhs);
  return a < b;
}
/// Less-than-or-equal, evaluated on the single-precision values of both operands.
CUTLASS_HOST_DEVICE
bool operator<=(tfloat32_t const& lhs, tfloat32_t const& rhs) {
  float const a = static_cast<float>(lhs);
  float const b = static_cast<float>(rhs);
  return a <= b;
}
/// Greater-than, evaluated on the single-precision values of both operands.
CUTLASS_HOST_DEVICE
bool operator>(tfloat32_t const& lhs, tfloat32_t const& rhs) {
  float const a = static_cast<float>(lhs);
  float const b = static_cast<float>(rhs);
  return a > b;
}
/// Greater-than-or-equal, evaluated on the single-precision values of both operands.
CUTLASS_HOST_DEVICE
bool operator>=(tfloat32_t const& lhs, tfloat32_t const& rhs) {
  float const a = static_cast<float>(lhs);
  float const b = static_cast<float>(rhs);
  return a >= b;
}
/// Adds in single precision and converts the sum back to tfloat32_t.
CUTLASS_HOST_DEVICE
tfloat32_t operator+(tfloat32_t const& lhs, tfloat32_t const& rhs) {
  float const sum = static_cast<float>(lhs) + static_cast<float>(rhs);
  return tfloat32_t(sum);
}
/// Negation: flips the sign bit of the stored pattern and leaves every other bit untouched.
CUTLASS_HOST_DEVICE
tfloat32_t operator-(tfloat32_t const& lhs) {
  uint32_t const flipped = lhs.raw() ^ 0x80000000u;
  return tfloat32_t::bitcast(flipped);
}
/// Subtracts in single precision and converts the difference back to tfloat32_t.
CUTLASS_HOST_DEVICE
tfloat32_t operator-(tfloat32_t const& lhs, tfloat32_t const& rhs) {
  float const difference = static_cast<float>(lhs) - static_cast<float>(rhs);
  return tfloat32_t(difference);
}
/// Multiplies in single precision and converts the product back to tfloat32_t.
CUTLASS_HOST_DEVICE
tfloat32_t operator*(tfloat32_t const& lhs, tfloat32_t const& rhs) {
  float const product = static_cast<float>(lhs) * static_cast<float>(rhs);
  return tfloat32_t(product);
}
/// Divides in single precision and converts the quotient back to tfloat32_t.
CUTLASS_HOST_DEVICE
tfloat32_t operator/(tfloat32_t const& lhs, tfloat32_t const& rhs) {
  float const quotient = static_cast<float>(lhs) / static_cast<float>(rhs);
  return tfloat32_t(quotient);
}
/// In-place addition: computes the sum in single precision, stores it in `lhs` and returns `lhs`.
CUTLASS_HOST_DEVICE
tfloat32_t& operator+=(tfloat32_t & lhs, tfloat32_t const& rhs) {
  float const sum = static_cast<float>(lhs) + static_cast<float>(rhs);
  lhs = tfloat32_t(sum);
  return lhs;
}
/// In-place subtraction: computes the difference in single precision, stores it in `lhs` and
/// returns `lhs`.
CUTLASS_HOST_DEVICE
tfloat32_t& operator-=(tfloat32_t & lhs, tfloat32_t const& rhs) {
  float const difference = static_cast<float>(lhs) - static_cast<float>(rhs);
  lhs = tfloat32_t(difference);
  return lhs;
}
/// In-place multiplication: computes the product in single precision, stores it in `lhs` and
/// returns `lhs`.
CUTLASS_HOST_DEVICE
tfloat32_t& operator*=(tfloat32_t & lhs, tfloat32_t const& rhs) {
  float const product = static_cast<float>(lhs) * static_cast<float>(rhs);
  lhs = tfloat32_t(product);
  return lhs;
}
/// In-place division: computes the quotient in single precision, stores it in `lhs` and
/// returns `lhs`.
CUTLASS_HOST_DEVICE
tfloat32_t& operator/=(tfloat32_t & lhs, tfloat32_t const& rhs) {
  float const quotient = static_cast<float>(lhs) / static_cast<float>(rhs);
  lhs = tfloat32_t(quotient);
  return lhs;
}
/// Pre-increment: adds one in single precision, stores the result in `lhs` and returns `lhs`.
CUTLASS_HOST_DEVICE
tfloat32_t& operator++(tfloat32_t & lhs) {
  float const incremented = static_cast<float>(lhs) + 1.0f;
  lhs = tfloat32_t(incremented);
  return lhs;
}
/// Pre-decrement: subtracts one in single precision, stores the result in `lhs` and returns `lhs`.
CUTLASS_HOST_DEVICE
tfloat32_t& operator--(tfloat32_t & lhs) {
  float const decremented = static_cast<float>(lhs) - 1.0f;
  lhs = tfloat32_t(decremented);
  return lhs;
}
/// Post-increment: adds one to `lhs` in single precision and returns the value `lhs` held
/// before the update.
CUTLASS_HOST_DEVICE
tfloat32_t operator++(tfloat32_t & lhs, int) {
  tfloat32_t const previous = lhs;
  float const incremented = static_cast<float>(lhs) + 1.0f;
  lhs = tfloat32_t(incremented);
  return previous;
}
/// Post-decrement: subtracts one from `lhs` in single precision and returns the value `lhs`
/// held before the update.
CUTLASS_HOST_DEVICE
tfloat32_t operator--(tfloat32_t & lhs, int) {
  tfloat32_t const previous = lhs;
  float const decremented = static_cast<float>(lhs) - 1.0f;
  lhs = tfloat32_t(decremented);
  return previous;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// User-defined literals
//
/// Floating-point literal suffix, e.g. `1.5_tf32`. The literal is narrowed to float and then
/// converted with the tfloat32_t float constructor.
// Declared without whitespace between `""` and the suffix; the spaced form is deprecated.
CUTLASS_HOST_DEVICE
cutlass::tfloat32_t operator""_tf32(long double x) {
  return cutlass::tfloat32_t(static_cast<float>(x));
}
/// Integer literal suffix, e.g. `42_tf32`.
///
/// The literal is converted to float and that bit pattern is stored unchanged, which is what
/// the tfloat32_t(int) constructor does for every value an int can hold. Converting straight
/// from `unsigned long long` avoids the narrowing to int, so literals above INT_MAX no longer
/// wrap around.
// Declared without whitespace between `""` and the suffix; the spaced form is deprecated.
CUTLASS_HOST_DEVICE
cutlass::tfloat32_t operator""_tf32(unsigned long long int x) {
  float const as_float = static_cast<float>(x);
#if defined(__CUDA_ARCH__)
  uint32_t const bits = reinterpret_cast<uint32_t const &>(as_float);
#else
  uint32_t bits;
  std::memcpy(&bits, &as_float, sizeof(bits));
#endif
  return cutlass::tfloat32_t::bitcast(bits);
}
/////////////////////////////////////////////////////////////////////////////////////////////////