FP8 with Tensor Reorg #760

Draft: wants to merge 27 commits into master

Commits (27)
9d0a2c6
doesn't work but at least it runs (with a loss of -1.0)...
ademeure Aug 31, 2024
03f3136
Everything working up to grad norm (single-gpu, no recompute)
ademeure Sep 1, 2024
3018cc6
More laconic ACT_XL, PARAM_X, etc. indexing...
ademeure Sep 1, 2024
474a60b
broken progress (forward is OK, backward is not) for restructure usin…
ademeure Sep 1, 2024
e27385d
It's alive!!! gpt2_update() is working as is forward/backward/recompu…
ademeure Sep 2, 2024
0c4b1e1
Refactoring into nicer ACT/PARAM_L/etc. macros with get_tensor and au…
ademeure Sep 2, 2024
65264ed
Activation checkpointing for entire layers is working!
ademeure Sep 2, 2024
07ed7ea
First draft of TensorGPU approach
ademeure Sep 4, 2024
a70f322
WIP most things converted to TensorGPU, bit more encoder and a lot mo…
ademeure Sep 6, 2024
a864fe5
More TensorGPU integration + better stochastic rounding + better laye…
ademeure Sep 7, 2024
c02382e
WIP, new unified adam is working, scaling is crashing, should be easy…
ademeure Sep 7, 2024
eedb4d0
WIP, seems to all kinda work (famous last words) - but cuBLAS doesn't…
ademeure Sep 10, 2024
b09dbc9
fake FP8 kinda-sorta works
ademeure Sep 16, 2024
d2b3e82
compilation fix
ademeure Sep 16, 2024
b94c3b7
moved tensor functionality into tensor.cuh, added transpose_simple ke…
ademeure Sep 16, 2024
2900ec9
WIP FP8 forward (doesn't quite work yet - obviously...)
ademeure Sep 16, 2024
2e26a9c
FP8 forward working again (!!!) + AdamW for FP8/BF16/FP32
ademeure Sep 17, 2024
41fd098
FP8 forward+backward+update (again)! FINALLY!
ademeure Sep 17, 2024
0740c2f
1st phase of cleanup
ademeure Sep 18, 2024
acd1058
FP8 cleanup part 2
ademeure Sep 18, 2024
c68bb9f
WIP multi-gpu and new global norm (doesn't work yet)
ademeure Sep 18, 2024
41caf5f
fixed global norm (not bit identical to previous implementation but I…
ademeure Sep 18, 2024
f769bdc
Optimized adam/global_norm using new "gpu_tensor_end_element" array
ademeure Sep 18, 2024
ac0dc6e
more optimization and cleanup for global_norm (83% DRAM efficiency) a…
ademeure Sep 19, 2024
ef3053c
tentative multigpu ZeRO 1 with AllGather + absmax history window
ademeure Sep 19, 2024
0df701b
save absmax/scale/descale to state file (untested, does it work? let'…
ademeure Sep 19, 2024
b1827d1
one last bit of cleanup before travelling
ademeure Sep 19, 2024
2 changes: 1 addition & 1 deletion Makefile
@@ -269,7 +269,7 @@ $(NVCC_CUDNN): llmc/cudnn_att.cpp
$(NVCC) -c $(NVCC_FLAGS) $(PFLAGS) $^ $(NVCC_INCLUDES) -o $@

train_gpt2cu: train_gpt2.cu $(NVCC_CUDNN)
$(NVCC) $(NVCC_FLAGS) $(PFLAGS) $^ $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE)
$(NVCC) $(NVCC_FLAGS) $(PFLAGS) -lineinfo $^ $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE)

train_gpt2fp32cu: train_gpt2_fp32.cu
$(NVCC) $(NVCC_FLAGS) $^ $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE)
217 changes: 142 additions & 75 deletions llmc/adamw.cuh
@@ -15,84 +15,151 @@ __device__ float lerp(float start, float end, float weight) {
return fma(weight, end, fma(-weight, start, start));
}

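As a quick reference, the fused-multiply-add form of lerp above expands to the usual convex combination, which is what turns the moment updates below into exponential moving averages:

$$
\mathrm{lerp}(s, e, w) \;=\; w\,e + \big(s - w\,s\big) \;=\; (1-w)\,s + w\,e,
\qquad\text{so}\qquad
m \leftarrow \mathrm{lerp}(g, m, \beta_1) = \beta_1\,m + (1-\beta_1)\,g .
$$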
template <typename Tp, typename Tg>
__device__ void adamw_update(Tp* params_memory, float* master_params_memory, Tg* grads_memory, float* m_memory, float* v_memory, size_t num_parameters,
float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay,
float grad_scale, unsigned int seed) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= num_parameters) { return; } // guard

// get the gradient, m, and v for this parameter
float grad = grad_scale * (float)grads_memory[idx];
float m = m_memory[idx];
float v = v_memory[idx];
// update the first moment (momentum)
m = lerp(grad, m, beta1);
m_memory[idx] = m;
// update the second moment (RMSprop)
v = lerp(grad * grad, v, beta2);
v_memory[idx] = v;
m /= beta1_correction; // m_hat
v /= beta2_correction; // v_hat
// fetch the old value of this parameter as a float, from either source
float old_param = (master_params_memory != NULL) ? master_params_memory[idx] : (float)params_memory[idx];
// update this parameter
float param = old_param - (learning_rate * (m / (sqrtf(v) + eps) + weight_decay * old_param));
// update our low precision version of the parameters using stochastic rounding
// this will be used in the next forward pass
stochastic_rounding(param, &params_memory[idx], seed);
// write the full, float version of the param into our master copy, if we maintain one
// this will be used in the next update
if (master_params_memory != NULL) { master_params_memory[idx] = param; }
}
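Both the old path above (stochastic_rounding) and the new tensor128 path below (set_stochastic) round the freshly computed FP32 parameter back into the low-precision storage type stochastically, so the rounding error averages out across update steps rather than accumulating. A minimal sketch of the idea for BF16; this is an illustration only, not the helper the repo actually uses, and it ignores NaN/Inf handling:

// Sketch only: unbiased FP32 -> BF16 rounding by injecting random bits below the cutoff.
#include <cuda_bf16.h>

__device__ __nv_bfloat16 stochastic_round_bf16_sketch(float in, unsigned int random) {
    unsigned int bits = __float_as_uint(in);
    bits += (random & 0xFFFFu);   // carries into bit 16 with probability equal to the dropped fraction
    return __ushort_as_bfloat16((unsigned short)(bits >> 16)); // BF16 = top 16 bits of the FP32 encoding
}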
template <bool use_master_weights=true, typename Tparam=floatX, typename Tgrad=floatX, typename Tm=float, typename Tv=float, typename Tmaster=float>
__device__ size_t adamw_update_part(TensorGPU<Tparam> param_tensor,
size_t idx, size_t current_start, size_t current_end, size_t stride, unsigned int seed, unsigned int shard_idx,
TensorGPU<Tgrad> grad_tensor, TensorGPU<Tmaster> master_tensor, TensorGPU<Tm> opt_m_tensor, TensorGPU<Tv> opt_v_tensor,
float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction,
float eps, float wd, float grad_scale, int t) {
auto out_master128 = new_tensor128<use_master_weights>(master_tensor, true);
auto out_opt_m128 = new_tensor128(opt_m_tensor, true);
auto out_opt_v128 = new_tensor128(opt_v_tensor, true);
auto out_param128 = new_tensor128(param_tensor);

template <typename Tp, typename Tg>
__global__ void adamw_kernel3(Tp* params_memory, float* master_params_memory, Tg* grads_memory, float* m_memory, float* v_memory, size_t num_parameters,
ptrdiff_t w_stride, ptrdiff_t g_stride, ptrdiff_t s_stride,
float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay,
float grad_scale, unsigned int seed) {
adamw_update(params_memory + blockIdx.y * w_stride,
master_params_memory ? master_params_memory + blockIdx.y * s_stride : NULL,
grads_memory + blockIdx.y * g_stride,
m_memory + blockIdx.y * s_stride,
v_memory + blockIdx.y * s_stride,
num_parameters, learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay, grad_scale,
seed
);
}
__syncthreads(); // todo - this should improve memory locality
while (idx < current_end) {
unsigned int random = get_random_noise(seed, idx);

template <typename Tp>
__global__ void init_from_master_kernel(Tp* params_memory, float* master_params_memory, size_t num_parameters,
ptrdiff_t w_stride, ptrdiff_t s_stride, unsigned int seed) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= num_parameters) { return; }
params_memory += blockIdx.y * w_stride; // adjust for layer offset
master_params_memory += blockIdx.y * s_stride;
stochastic_rounding(master_params_memory[idx], &params_memory[idx], seed);
}
tensor128<Tparam> param128;
tensor128<Tgrad> grad128;
tensor128<Tm> opt_m128;
tensor128<Tv> opt_v128;
tensor128<Tmaster> master128;
int next_idx[TT::NUM_TYPES_PARAM] = {0};
int current_idx[TT::NUM_TYPES_PARAM] = {0};

// todo - assuming either DDP or ZeRO 1 now (sharded optimizer/master, unsharded gradients/parameters)
// offset is 32-bit (checked <= elements in add_tensor_spec)
unsigned int offset = idx - current_start;
unsigned int unsharded_offset = offset + shard_idx * opt_v_tensor.num_elements;

// this implementation has a stride causing sparse reads/writes and bank conflicts for non-FP8
// todo - compare performance with a version that uses 128-bit for FP32, 64-bit for BF16, 32-bit for FP8 (probably much faster)
#pragma unroll
for (int i = 0; i < 16; i += 4, offset += 4, unsharded_offset += 4) {
if (current_idx[PARAMETER] == 0) param128 = load_tensor128(param_tensor, unsharded_offset);
if (current_idx[PARAMETER_GRAD] == 0) grad128 = load_tensor128(grad_tensor, unsharded_offset, false, true);
if (current_idx[PARAMETER_OPT_M] == 0) opt_m128 = load_tensor128(opt_m_tensor, offset, false, true);
if (current_idx[PARAMETER_OPT_V] == 0) opt_v128 = load_tensor128(opt_v_tensor, offset, false, true);
if (current_idx[PARAMETER_MASTER] == 0 && use_master_weights) master128 = load_tensor128(master_tensor, offset, false, true);

for (int k = 0; k < 4; k++) {
float grad = grad128.get(current_idx[PARAMETER_GRAD] + k);
float m = opt_m128.get(current_idx[PARAMETER_OPT_M] + k);
float v = opt_v128.get(current_idx[PARAMETER_OPT_V] + k);

m = lerp(grad, m, beta1);
v = lerp(grad * grad, v, beta2);
out_opt_m128.set(current_idx[PARAMETER_OPT_M] + k, m);
out_opt_v128.set(current_idx[PARAMETER_OPT_V] + k, v);
m /= beta1_correction;
v /= beta2_correction;

float old_param;
if constexpr (use_master_weights) {
old_param = master128.get(current_idx[PARAMETER_MASTER] + k);
} else {
old_param = param128.get(current_idx[PARAMETER] + k);
}

float param = old_param - (learning_rate * (m / (sqrtf(v) + eps) + wd * old_param));
if constexpr (use_master_weights) {
out_master128.set(current_idx[PARAMETER_MASTER] + k, param);
}
out_param128.set_stochastic(current_idx[PARAMETER] + k, param, random);
}
next_idx[PARAMETER] = (i + 4) % (16 / sizeof(Tparam));
next_idx[PARAMETER_GRAD] = (i + 4) % (16 / sizeof(Tgrad));
next_idx[PARAMETER_OPT_M] = (i + 4) % (16 / sizeof(Tm));
next_idx[PARAMETER_OPT_V] = (i + 4) % (16 / sizeof(Tv));
next_idx[PARAMETER_MASTER] = (i + 4) % (16 / sizeof(Tmaster));

template <typename Tp, typename Tg>
void adamw_update(Tp* params_memory, float* master_params_memory, Tg* grads_memory, float* m_memory, float* v_memory, size_t num_parameters,
ptrdiff_t w_stride, ptrdiff_t g_stride, ptrdiff_t s_stride, int num_slices, float learning_rate, float beta1, float beta2, int t, float eps, float weight_decay,
float grad_scale, unsigned int seed, cudaStream_t stream) {
// AdamW update
int block_size = 512;
int num_blocks = CEIL_DIV(num_parameters, block_size);
float beta1_correction = 1.0f - powf(beta1, t);
float beta2_correction = 1.0f - powf(beta2, t);
adamw_kernel3<<<dim3(num_blocks, num_slices), block_size, 0, stream>>>(params_memory, master_params_memory, grads_memory,
m_memory, v_memory, num_parameters, w_stride, g_stride, s_stride,
learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay,
grad_scale, seed);
cudaCheck(cudaGetLastError());
if (next_idx[PARAMETER] == 0) out_param128.store(unsharded_offset - current_idx[PARAMETER]);
if (next_idx[PARAMETER_OPT_M] == 0) out_opt_m128.store(offset - current_idx[PARAMETER_OPT_M]);
if (next_idx[PARAMETER_OPT_V] == 0) out_opt_v128.store(offset - current_idx[PARAMETER_OPT_V]);
if constexpr (use_master_weights) {
if (next_idx[PARAMETER_MASTER] == 0) out_master128.store(offset - current_idx[PARAMETER_MASTER]);
}

for (int n = 0; n < TT::NUM_TYPES_PARAM; n++) {
current_idx[n] = next_idx[n];
}
}
idx += stride;
}
out_param128.update_absmax(1);
return idx;
}
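A detail worth spelling out in adamw_update_part: each tensor128 covers 16 bytes, so the number of elements per load/store depends on the dtype, which is what the (i + 4) % (16 / sizeof(T)) bookkeeping above tracks. A standalone illustration of the per-dtype counts, compiling on its own against the standard CUDA headers:

// Illustration only: elements per 128-bit (16-byte) vector for the dtypes handled above.
// With the 16-element outer chunk, FP32 tensors reload every 4 elements, BF16 every 8,
// and an FP8 tensor is covered by a single load/store.
#include <cuda_bf16.h>
#include <cuda_fp8.h>

static_assert(16 / sizeof(float)         == 4,  "FP32: 4 elements per 128-bit access");
static_assert(16 / sizeof(__nv_bfloat16) == 8,  "BF16: 8 elements per 128-bit access");
static_assert(16 / sizeof(__nv_fp8_e4m3) == 16, "FP8: 16 elements per 128-bit access");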

template <typename Tp>
void init_from_master(Tp* params_memory, float* master_params_memory, size_t num_parameters,
ptrdiff_t w_stride, ptrdiff_t s_stride, int num_slices, unsigned int seed, cudaStream_t stream) {
int block_size = 512; // must match block size of adamw_update so that RNG also matches
int num_blocks = CEIL_DIV(num_parameters, block_size);
init_from_master_kernel<<<dim3(num_blocks, num_slices), block_size, 0, stream>>>
(params_memory, master_params_memory, num_parameters, w_stride, s_stride, seed);
cudaCheck(cudaGetLastError());
template <bool use_master_weights=true>
__global__ void adamw_update_everything(int num_params_tensors, int start_tensor, int last_tensor, unsigned int seed, unsigned int shard_idx,
float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction,
float eps, float weight_decay, float grad_scale, int t) {
// ...
constexpr size_t block_size = 64;
constexpr size_t iteration_size = 16; // todo - this causes sparsity and bank conflicts for FP32/BF16 loads/stores
size_t idx = (blockIdx.x * block_size * iteration_size) + (threadIdx.x * iteration_size);
unsigned int stride = gridDim.x * blockDim.x * iteration_size;

int opt_m_spec_id = 2 * num_params_tensors;
int last_opt_m_id = opt_m_spec_id + last_tensor; // opt_m is sharded with ZeRO 1 so use it as reference
opt_m_spec_id += start_tensor - 1; // -1 to compensate for the increment at the start of the loop below

while (true) {
size_t current_end;
do {
opt_m_spec_id++;
if (opt_m_spec_id > last_opt_m_id) return; // done!

// on A100+ we can prefetch 256B (32 values) into the L2, on older GPUs just use a regular load
#if __CUDA_ARCH__ < 800
current_end = tensor_end_element_ptr[opt_m_spec_id];
#else
asm("ld.global.L2::256B.u64 {%0}, [%1];" : "=l"(current_end) : "l"(tensor_end_element_ptr + opt_m_spec_id));
#endif
} while (idx >= current_end);

int spec_id = opt_m_spec_id - 2 * num_params_tensors;
size_t current_start = tensor_specs_ptr[opt_m_spec_id].start_element;

TensorSpec param_spec = tensor_specs_ptr[spec_id];
TensorGPU<floatX> grad_tensor = tensor_specs_ptr[spec_id + 1*num_params_tensors];
TensorGPU<float> opt_m_tensor = tensor_specs_ptr[spec_id + 2*num_params_tensors];
TensorGPU<float> opt_v_tensor = tensor_specs_ptr[spec_id + 3*num_params_tensors];
TensorGPU<float> master_tensor = use_master_weights ? tensor_specs_ptr[spec_id + 4*num_params_tensors] : opt_m_tensor;

float wd = (param_spec.tensor_flags & TENSOR_2D) ? weight_decay : 0.0f;

if (param_spec.data_type == DType::FP32) {
idx = adamw_update_part<use_master_weights>((TensorGPU<float>)param_spec,
idx, current_start, current_end, stride, seed, shard_idx,
grad_tensor, master_tensor, opt_m_tensor, opt_v_tensor,
learning_rate, beta1, beta2, beta1_correction, beta2_correction,
eps, wd, grad_scale, t);
} else if (param_spec.data_type == DType::BF16) {
idx = adamw_update_part<use_master_weights>((TensorGPU<__nv_bfloat16>)param_spec,
idx, current_start, current_end, stride, seed, shard_idx,
grad_tensor, master_tensor, opt_m_tensor, opt_v_tensor,
learning_rate, beta1, beta2, beta1_correction, beta2_correction,
eps, wd, grad_scale, t);
} else if (param_spec.data_type == DType::FP8E4M3) {
idx = adamw_update_part<use_master_weights>((TensorGPU<__nv_fp8_e4m3>)param_spec,
idx, current_start, current_end, stride, seed, shard_idx,
grad_tensor, master_tensor, opt_m_tensor, opt_v_tensor,
learning_rate, beta1, beta2, beta1_correction, beta2_correction,
eps, wd, grad_scale, t);
} else {
assert(false); // TODO (no FP16 to avoid compile time increase but trivial to add here)
}
}
}
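For context, a hypothetical host-side launch of adamw_update_everything; the PR's real call site lives in train_gpt2.cu rather than in this file, so the grid size, shard index, and surrounding variables below are assumptions. The block size must be 64 to match the constexpr block_size inside the kernel, and the bias corrections follow the usual Adam form:

// Hypothetical launch sketch, not the PR's actual call site.
float beta1_correction = 1.0f - powf(beta1, t);   // 1 - beta1^t
float beta2_correction = 1.0f - powf(beta2, t);   // 1 - beta2^t
adamw_update_everything<true><<<num_blocks, 64, 0, stream>>>(
    num_params_tensors, /*start_tensor=*/0, /*last_tensor=*/num_params_tensors - 1,
    seed, shard_idx, learning_rate, beta1, beta2, beta1_correction, beta2_correction,
    eps, weight_decay, grad_scale, t);
cudaCheck(cudaGetLastError());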