From 9d0a2c69072550809133dccffc646c5a08187607 Mon Sep 17 00:00:00 2001
From: ademeure
Date: Sat, 31 Aug 2024 22:57:12 +0000
Subject: [PATCH 01/27] doesn't work but at least it runs (with a loss of -1.0)...

---
 llmc/cuda_common.h |   3 +
 train_gpt2.cu      | 776 ++++++++++++++++++++++++----------------
 2 files changed, 422 insertions(+), 357 deletions(-)

diff --git a/llmc/cuda_common.h b/llmc/cuda_common.h
index 006ad3010..0baade163 100644
--- a/llmc/cuda_common.h
+++ b/llmc/cuda_common.h
@@ -82,13 +82,16 @@ enum PrecisionMode {
 #if defined(ENABLE_FP32)
 typedef float floatX;
 #define PRECISION_MODE PRECISION_FP32
+#define DTYPE_FLOATX DType::FP32
 // use fp16 (note: this may require gradient scaler, currently not implemented!)
 #elif defined(ENABLE_FP16)
 typedef half floatX;
 #define PRECISION_MODE PRECISION_FP16
+#define DTYPE_FLOATX DType::FP16
 #else // Default to bfloat16
 typedef __nv_bfloat16 floatX;
 #define PRECISION_MODE PRECISION_BF16
+#define DTYPE_FLOATX DType::BF16
 #endif

 // ----------------------------------------------------------------------------
diff --git a/train_gpt2.cu b/train_gpt2.cu
index 16f801387..63283ebcb 100644
--- a/train_gpt2.cu
+++ b/train_gpt2.cu
@@ -1,6 +1,8 @@
 /*
 GPT-2 Transformer Neural Net training loop. See README.md for usage.
 */
+bool UNIQUE_TENSOR_MEMORY = false;
+
 #include <unistd.h>
 #include <stdio.h>
 #include <stdlib.h>
@@ -84,6 +86,27 @@ constexpr const size_t IO_BUF_SIZE = 32 * 1024 * 1024;
 // ----------------------------------------------------------------------------
 // GPT-2 model definition

+enum TT : uint8_t {
+    PARAMETER=0, PARAMETER_GRADIENT, PARAMETER_MASTER, PARAMETER_OPT_M, PARAMETER_OPT_V,
+    MULTIUSE, ACTIVATION, ACTIVATION_GRADIENT,
+    DEFAULT, COUNT=DEFAULT
+};
+
+typedef struct {
+    int wte, wpe, lnfw, lnfb; // not per layer
+    int ln1w, ln1b, qkvw, qkvb, attprojw, attprojb, ln2w, ln2b, fcw, fcb, fcprojw, fcprojb; // per layer
+} ParameterTensors;
+
+typedef struct {
+    int encoded, lnf, lnf_mean, lnf_rstd, losses, output; // not per layer
+    int ln1, ln1_mean, ln1_rstd, atty, att, attproj, residual2, ln2, ln2_mean, ln2_rstd, fch, fch_gelu, fcproj, residual3, qkvr; // per layer
+} ActivationTensors;
+
+typedef struct {
+    int bt4c; // (B, T, 4*C)
+    int btc;  // (B, T, C)
+} MultiuseTensors;
+
 typedef struct {
     int max_seq_len; // max sequence length, e.g. 1024
     int vocab_size; // vocab size, e.g. 50257
@@ -93,134 +116,173 @@ typedef struct {
     int channels; // number of channels, e.g.
768 } GPT2Config; -// the parameters of the model -constexpr const int NUM_PARAMETER_TENSORS = 16; typedef struct { - floatX* wte; // (V, C) - floatX* wpe; // (maxT, C) - floatX* ln1w; // (L, C) - floatX* ln1b; // (L, C) - floatX* qkvw; // (L, 3*C, C) - floatX* qkvb; // (L, 3*C) - floatX* attprojw; // (L, C, C) - floatX* attprojb; // (L, C) - floatX* ln2w; // (L, C) - floatX* ln2b; // (L, C) - floatX* fcw; // (L, 4*C, C) - floatX* fcb; // (L, 4*C) - floatX* fcprojw; // (L, C, 4*C) - floatX* fcprojb; // (L, C) - floatX* lnfw; // (C) - floatX* lnfb; // (C) -} ParameterTensors; -static_assert(sizeof(ParameterTensors) == NUM_PARAMETER_TENSORS * sizeof(void*), "Inconsistent sizes!"); - -void fill_in_parameter_sizes(size_t* param_sizes, size_t* param_sizeof, GPT2Config config) { - size_t Vp = config.padded_vocab_size; - size_t C = config.channels; - size_t maxT = config.max_seq_len; - size_t L = config.num_layers; - param_sizes[0] = Vp * C; // wte - param_sizes[1] = maxT * C; // wpe - param_sizes[2] = L * C; // ln1w - param_sizes[3] = L * C; // ln1b - param_sizes[4] = L * (3 * C) * C; // qkvw - param_sizes[5] = L * (3 * C); // qkvb - param_sizes[6] = L * C * C; // attprojw - param_sizes[7] = L * C; // attprojb - param_sizes[8] = L * C; // ln2w - param_sizes[9] = L * C; // ln2b - param_sizes[10] = L * (4 * C) * C; // fcw - param_sizes[11] = L * (4 * C); // fcb - param_sizes[12] = L * C * (4 * C); // fcprojw - param_sizes[13] = L * C; // fcprojb - param_sizes[14] = C; // lnfw - param_sizes[15] = C; // lnfb - - // populate the parameter sizes in bytes (all the same for now, keeping for future use) - for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { - param_sizeof[i] = sizeof(floatX); - } -} + GPT2Config config; + ParameterTensors params; + ParameterTensors params_grads; + ParameterTensors params_master; + ParameterTensors params_opt_m; + ParameterTensors params_opt_v; + ActivationTensors acts; + ActivationTensors acts_grads; + MultiuseTensors multiuse; -// allocate memory for the parameters and point the individual tensors to the right places -void* malloc_and_point_parameters(ParameterTensors* params, size_t* param_elements, size_t *param_sizeof) { - // calculate the total number of parameters and bytes across all tensors - size_t num_parameters_bytes = 0; - for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { - num_parameters_bytes += param_elements[i] * param_sizeof[i]; - } - // malloc all parameters all at once on the device + size_t num_parameters; + size_t num_parameters_bytes; + + char* gpu_mem; void* params_memory; - cudaCheck(cudaMalloc((void**)¶ms_memory, num_parameters_bytes)); - // assign all the tensors their place in the array - floatX** ptrs[] = { - ¶ms->wte, ¶ms->wpe, ¶ms->ln1w, ¶ms->ln1b, ¶ms->qkvw, ¶ms->qkvb, - ¶ms->attprojw, ¶ms->attprojb, ¶ms->ln2w, ¶ms->ln2b, ¶ms->fcw, ¶ms->fcb, - ¶ms->fcprojw, ¶ms->fcprojb, ¶ms->lnfw, ¶ms->lnfb - }; - char* params_memory_iterator = (char*)params_memory; - for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { - *(ptrs[i]) = (floatX*)params_memory_iterator; - params_memory_iterator += param_elements[i] * param_sizeof[i]; - } - return params_memory; -} + void* grads_memory; + float* m_memory; + float* v_memory; + float* master_weights; + + // other run state configuration + int batch_size = 0; // the batch size (B) of current forward pass + int seq_len = 0; // the sequence length (T) of current forward pass + int* inputs = NULL; // the input tokens for the current forward pass + int* targets = NULL; // the target tokens for the current forward pass + 
float mean_loss = -1.0f; // after the last backward micro-batch, will be populated with mean loss across all GPUs and micro-steps + float* accumulated_mean_loss = NULL; // GPU buffer used to accumulate loss across micro-steps + float* cpu_losses = NULL; // CPU buffer to copy the losses to, allocated with cudaMallocHost + bool init_state = true; // set to true if master weights need to be initialized + int use_master_weights = 1; // keep master weights copy in float for optim update? 0|1 + int gelu_fusion = 0; // fuse gelu via cuBLASLt (0=none, 1=forward, 2=forward+backward) + int recompute = 0; // recompute gelu | layernorm forward during model backward? 0|1|2 + // todo - if other functions need cpu scratch buffers in the future, reuse as generic scratch? + int* workload_indices = NULL; // encoder_backward, B*T*num_c_groups (int) + int4* bucket_info = NULL; // encoder_backward, B*T*num_c_groups (int4) - size for worst case + + unsigned long long rng_state; // the RNG state for seeding stochastic rounding etc. + unsigned long long rng_state_last_update; // RNG before last gpt2_update() to re-round identically from master weights +} GPT2; -constexpr int NUM_ACTIVATION_TENSORS = 21; typedef struct { - floatX* encoded; // (B, T, C) - floatX* ln1; // (L, B, T, C) - float* ln1_mean; // (L, B, T) - float* ln1_rstd; // (L, B, T) - floatX* atty; // (L, B, T, C) - // cuDNN saves only some statistics information -#if ENABLE_CUDNN - float* att; // (L, B, NH, T) -#else - floatX* att; // (L, B, NH, T, T) -#endif + char name[16]; + size_t offset; // into base pointer + size_t num_elements; // per shard + size_t num_shards; + DType data_type; + TT tensor_type; +} TensorSpec; + +TensorSpec tensor_specs[16*1024]; +size_t num_tensor_specs = 0; +size_t current_tensor_offset = 0; +TT current_tensor_type = TT::PARAMETER; +size_t tensors_start[TT::COUNT] = {0}; +size_t tensors_bytes[TT::COUNT] = {0}; +size_t tensors_elements[TT::COUNT] = {0}; + +int add_tensor_spec(const char* name, size_t num_elements, size_t num_shards, DType data_type, int copy_offset_from=-1, TT tensor_type=TT::DEFAULT) { + assert(num_tensor_specs < 16*1024); + assert((num_elements % num_shards) == 0); + TensorSpec* spec = &tensor_specs[num_tensor_specs++]; + + strncpy(spec->name, name, 16); + spec->num_elements = num_elements / num_shards; + spec->num_shards = num_shards; + spec->data_type = data_type; + spec->tensor_type = (tensor_type == TT::DEFAULT) ? 
current_tensor_type : tensor_type; + tensors_elements[spec->tensor_type] += spec->num_elements; + + + if (copy_offset_from >= 0) { + spec->offset = tensor_specs[copy_offset_from].offset; + size_t original_tensor_bytes = tensor_specs[copy_offset_from].num_elements * sizeof_dtype(tensor_specs[copy_offset_from].data_type); + size_t new_tensor_bytes = spec->num_elements * sizeof_dtype(data_type); + assert(new_tensor_bytes <= original_tensor_bytes); + } else { + spec->offset = current_tensor_offset; + current_tensor_offset += spec->num_elements * sizeof_dtype(data_type); - floatX* residual2; // (L, B, T, C) - floatX* ln2; // (L, B, T, C) - float* ln2_mean; // (L, B, T) - float* ln2_rstd; // (L, B, T) - floatX* fch; // (L, B, T, 4*C) - floatX* fch_gelu; // (L, B, T, 4*C) - floatX* residual3; // (L, B, T, C) - floatX* lnf; // (B, T, C); if LN recomputation is enabled (-r 2 and above), will be used for _all_ layernorms - float* lnf_mean; // (B, T) - float* lnf_rstd; // (B, T) - float* losses; // (B, T), will be accumulated in micro-steps - // adding these two compared to the CPU .c code, needed for attention kernel as buffers - floatX* qkvr; // (L, B, T, 3*C) - // in inference mode, this buffer will store the logits - // in training mode, this buffer will contain the *gradients* of the logits. - // during the processing of transformer blocks, we will also use this as a - // general scratchpad buffer. Allocation is made large enough to hold (B, T, 3C), - // (B, NH, T, T), and (B, T, V) shaped tensors. - floatX* output; - - // some additional scratch buffers - floatX* scratch_bt4c; // (B, T, 4*C) - floatX* scratch_btc; // (B, T, C) -} ActivationTensors; + tensors_bytes[spec->tensor_type] += spec->num_elements * sizeof_dtype(data_type); + if (tensors_start[spec->tensor_type] == 0 && spec->tensor_type != 0) { + tensors_start[spec->tensor_type] = spec->offset; + } + } + return num_tensor_specs - 1; +} +int add_layer_specs(int num_layers, const char* name, size_t num_elements, size_t num_shards, DType data_type, int copy_offset_from=-1, bool copy_per_layer=false, TT tensor_type=TT::DEFAULT) { + int first_tensor_id = num_tensor_specs; + for (int l = 0; l < num_layers; l++) { + char layer_name[16]; + assert(snprintf(layer_name, 16, "%s_%d", name, l) >= 0); + add_tensor_spec(num_layers > 1 ? layer_name : name, num_elements, num_shards, data_type, copy_offset_from, tensor_type); + if (copy_per_layer) { + copy_offset_from++; + } + } + return first_tensor_id; +} -struct TensorSpec { - void** ptr; - size_t size; - DType type; -}; +#define TENSOR_SPECS(name, dim1, dim2) spec->name = add_layer_specs(dim1, #name, dim2, shards, dtype) +#define TENSOR_SPECS_LOWP(name, dim1, dim2) spec->name = add_layer_specs(dim1, #name, dim2, shards, dtype_lowp) +#define TENSOR_SPECS_FP32(name, dim1, dim2) spec->name = add_layer_specs(dim1, #name, dim2, shards, DType::FP32) // todo - won't work loading model +void gpt2_allocate(GPT2 *model) { + size_t Vp = model->config.padded_vocab_size; + size_t C = model->config.channels; + size_t maxT = model->config.max_seq_len; + size_t L = model->config.num_layers; + size_t B = model->batch_size; + size_t T = model->seq_len; + size_t NH = model->config.num_heads; + size_t output_size = B*T * max(4*C, max(NH*T, Vp)); + size_t BTC = B*T*C; + + size_t shards = 1; + int num_gpu = multi_gpu_config.num_processes; + int shards_opt = (multi_gpu_config.zero_stage >= 1) ? num_gpu : 1; + int shards_grad = (multi_gpu_config.zero_stage >= 2) ? 
num_gpu : 1; + + // 1) parameters & optimizer state + for (int t = TT::PARAMETER; t <= TT::PARAMETER_OPT_V; t++) { + DType dtype = (t <= TT::PARAMETER_GRADIENT) ? DTYPE_FLOATX : DType::FP32; + DType dtype_lowp = (t <= TT::PARAMETER_GRADIENT) ? DTYPE_FLOATX : DType::FP32; // FP8 in the future + + current_tensor_type = (TT)t; + ParameterTensors* spec; + switch (t) { + case TT::PARAMETER: spec = &model->params; shards = 1; break; + case TT::PARAMETER_GRADIENT: spec = &model->params_grads; shards = shards_grad; break; + case TT::PARAMETER_MASTER: spec = &model->params_master; shards = shards_opt; break; + case TT::PARAMETER_OPT_M: spec = &model->params_opt_m; shards = shards_opt; break; + case TT::PARAMETER_OPT_V: spec = &model->params_opt_v; shards = shards_opt; break; + } + if (t == PARAMETER_MASTER && !model->use_master_weights) { + continue; + } + + TENSOR_SPECS (wte, 1, Vp * C); + TENSOR_SPECS (wpe, 1, maxT * C); + TENSOR_SPECS (ln1w, L, C); + TENSOR_SPECS (ln1b, L, C); + TENSOR_SPECS_LOWP(qkvw, L, 3 * C * C); + TENSOR_SPECS (qkvb, L, 3 * C); + TENSOR_SPECS_LOWP(attprojw, L, C * C); + TENSOR_SPECS (attprojb, L, C); + TENSOR_SPECS (ln2w, L, C); + TENSOR_SPECS (ln2b, L, C); + TENSOR_SPECS_LOWP(fcw, L, 4 * C * C); + TENSOR_SPECS_LOWP(fcb, L, 4 * C); + TENSOR_SPECS_LOWP(fcprojw, L, 4 * C * C); + TENSOR_SPECS (fcprojb, L, C); + TENSOR_SPECS (lnfw, 1, C); + TENSOR_SPECS (lnfb, 1, C); + } -#define TENSOR_SPEC(pointer, size) TensorSpec{(void**)(&pointer), (size), dtype_of(pointer)}; + // 2) multiuse & scratch tensors + if (UNIQUE_TENSOR_MEMORY) { + model->multiuse.bt4c = -1; + model->multiuse.btc = -1; + } else { + model->multiuse.bt4c = add_tensor_spec("multiuse_bt4c", 4 * BTC, 1, DTYPE_FLOATX, -1, TT::MULTIUSE); + model->multiuse.btc = add_tensor_spec("multiuse_btc", BTC, 1, DTYPE_FLOATX, -1, TT::MULTIUSE); + } -void fill_in_activation_sizes(const ActivationTensors* data, TensorSpec (&tensors)[NUM_ACTIVATION_TENSORS], size_t B, size_t T, GPT2Config config, int recompute) { - size_t Vp = config.padded_vocab_size; - size_t L = config.num_layers; - size_t NH = config.num_heads; - size_t C = config.channels; + /* tensors[0] = TENSOR_SPEC(data->encoded, B * T * C); // if recompute >= 1 then we will recompute the layernorm forward activation during backward pass tensors[1] = TENSOR_SPEC(data->ln1, (recompute < 2) ? 
L * B * T * C : 0); @@ -251,132 +313,138 @@ void fill_in_activation_sizes(const ActivationTensors* data, TensorSpec (&tensor tensors[19] = TENSOR_SPEC(data->scratch_bt4c, B * T * 4 * C); tensors[20] = TENSOR_SPEC(data->scratch_btc, B * T * C); -} + */ -void* malloc_and_point_activations(TensorSpec (&tensors)[NUM_ACTIVATION_TENSORS]) { - size_t bytes = 0; - for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) { - bytes += tensors[i].size * sizeof_dtype(tensors[i].type); - } - printf0("allocating %d MiB for activations\n", (int)round(bytes / (1024 * 1024))); + // 3) activations + current_tensor_type = TT::ACTIVATION; + ActivationTensors* spec = &model->acts; + DType dtype = DTYPE_FLOATX; + DType dtype_lowp = DTYPE_FLOATX; // todo FP8 + shards = 1; + + TENSOR_SPECS (encoded, 1, BTC); + TENSOR_SPECS (lnf, 1, BTC); + TENSOR_SPECS_FP32(lnf_mean, 1, B*T); + TENSOR_SPECS_FP32(lnf_rstd, 1, B*T); + TENSOR_SPECS_FP32(losses, 1, B*T); + TENSOR_SPECS (output, 1, output_size); + TENSOR_SPECS_FP32(ln1_mean, L, B*T); + TENSOR_SPECS_FP32(ln1_rstd, L, B*T); + TENSOR_SPECS (atty, L, BTC); + TENSOR_SPECS (residual2, L, BTC); + TENSOR_SPECS_FP32(ln2_mean, L, B*T); + TENSOR_SPECS_FP32(ln2_rstd, L, B*T); + TENSOR_SPECS (residual3, L, BTC); + TENSOR_SPECS_LOWP(fch, L, 4 * BTC); + TENSOR_SPECS (qkvr, L, 3 * BTC); + #ifdef ENABLE_CUDNN + TENSOR_SPECS_FP32(att, L, NH * B * T); + #else + TENSOR_SPECS (att, L, NH * B * T * T); + #endif - void* acts_memory; - cudaCheck(cudaMalloc((void**)&acts_memory, bytes)); + if (UNIQUE_TENSOR_MEMORY) { + TENSOR_SPECS_LOWP(fcproj, L, BTC); + TENSOR_SPECS_LOWP(attproj, L, BTC); + } else { + spec->fcproj = add_layer_specs(L, "fcproj", BTC, L, dtype_lowp, model->multiuse.btc); + spec->attproj = add_layer_specs(L, "attproj", BTC, L, dtype_lowp, model->multiuse.btc); + } - // cudaMalloc does not guarantee initial memory values so we memset the allocation here - // this matters because e.g. 
non-cuDNN attention assumes the attention buffer is zeroed - // todo - up to ~100ms on slow GPUs, could theoretically be more selective, but this is safer - cudaCheck(cudaMemset(acts_memory, 0, bytes)); + if (model->recompute < 1 || UNIQUE_TENSOR_MEMORY) { + TENSOR_SPECS(ln1, L, BTC); + TENSOR_SPECS(ln2, L, BTC); + TENSOR_SPECS(fch_gelu, L, 4 * BTC); + } else if (model->recompute < 2) { + TENSOR_SPECS(ln1, L, BTC); + TENSOR_SPECS(ln2, L, BTC); + spec->fch_gelu = add_layer_specs(L, "fch_gelu", 4 * BTC, L, dtype_lowp, model->acts.output); + } else { + spec->ln1 = add_layer_specs(L, "ln1", BTC, L, dtype, model->acts.output); + spec->ln2 = add_layer_specs(L, "ln2", BTC, L, dtype, model->acts.output); + spec->fch_gelu = add_layer_specs(L, "fch_gelu", 4 * BTC, L, dtype_lowp, model->acts.output); + } - char* acts_memory_iterator = (char*)acts_memory; - for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) { - // extra protection so we don't accidentally use an empty buffer - if(tensors[i].size == 0) { - *(tensors[i].ptr) = NULL; - }else { - *(tensors[i].ptr) = acts_memory_iterator; - acts_memory_iterator += tensors[i].size * sizeof_dtype(tensors[i].type); - } + // 4) activation gradients + current_tensor_type = TT::ACTIVATION_GRADIENT; + dtype_lowp = DTYPE_FLOATX; // todo FP8 + shards = 1; + + if (UNIQUE_TENSOR_MEMORY) { + TENSOR_SPECS(output, 1, B*T * max(3*C, max(NH*T, Vp))); + TENSOR_SPECS(lnf, 1, BTC); + TENSOR_SPECS(ln1, L, BTC); + TENSOR_SPECS(atty, L, BTC); + TENSOR_SPECS(residual2, L, BTC); + TENSOR_SPECS(ln2, L, BTC); + TENSOR_SPECS(fch, L, 4 * BTC); + TENSOR_SPECS(fch_gelu, L, 4 * BTC); + TENSOR_SPECS(residual3, L, BTC); + TENSOR_SPECS(qkvr, L, 3 * BTC); + } else { + spec->output = add_layer_specs(1, "output", output_size, 1, dtype, model->acts.output); + spec->lnf = add_layer_specs(1, "lnf", BTC, 1, dtype, model->multiuse.btc); + spec->ln1 = add_layer_specs(L, "ln1", BTC, 1, dtype, model->multiuse.btc); + spec->atty = add_layer_specs(L, "atty", BTC, 1, dtype, model->multiuse.btc); + spec->residual2 = add_layer_specs(L, "residual2", BTC, 1, dtype, model->multiuse.btc); + spec->ln2 = add_layer_specs(L, "ln2", BTC, 1, dtype, model->multiuse.btc); + spec->fch = add_layer_specs(L, "fch", 4 * BTC, 1, dtype, model->multiuse.bt4c); + spec->fch_gelu = add_layer_specs(L, "fch_gelu", 4 * BTC, 1, dtype, model->multiuse.bt4c); + spec->residual3 = add_layer_specs(L, "residual3", BTC, 1, dtype, model->multiuse.btc); + spec->qkvr = add_layer_specs(L, "qkvr", 3 * BTC, 1, dtype, model->multiuse.bt4c); } - return acts_memory; -} -typedef struct { - GPT2Config config; - // the weights of the model, and their sizes - ParameterTensors params; - size_t param_elements[NUM_PARAMETER_TENSORS]; - size_t param_sizeof[NUM_PARAMETER_TENSORS]; - void* params_memory; - size_t num_parameters; - size_t num_parameters_bytes; - // gradients of the weights - ParameterTensors grads; - void* grads_memory; - // buffers for the AdamW optimizer - float* m_memory; - float* v_memory; - float* master_weights; // is NULL unless fp32 weights is enabled. 
- // the activations of the model, and their sizes - ActivationTensors acts; - TensorSpec acts_specs[NUM_ACTIVATION_TENSORS]; - void* acts_memory; - // other run state configuration - int batch_size; // the batch size (B) of current forward pass - int seq_len; // the sequence length (T) of current forward pass - int* inputs; // the input tokens for the current forward pass - int* targets; // the target tokens for the current forward pass - float mean_loss; // after the last backward micro-batch, will be populated with mean loss across all GPUs and micro-steps - float* accumulated_mean_loss; // GPU buffer used to accumulate loss across micro-steps - float* cpu_losses; // CPU buffer to copy the losses to, allocated with cudaMallocHost - unsigned long long rng_state; // the RNG state for seeding stochastic rounding etc. - unsigned long long rng_state_last_update; // RNG before last gpt2_update() to re-round identically from master weights - int use_master_weights; // keep master weights copy in float for optim update? 0|1 - bool init_state; // set to true if master weights need to be initialized - int gelu_fusion; // fuse gelu via cuBLASLt (0=none, 1=forward, 2=forward+backward) - int recompute; // recompute gelu | layernorm forward during model backward? 0|1|2 - // todo - if other functions need cpu scratch buffers in the future, reuse as generic scratch? - int* workload_indices; // encoder_backward, B*T*num_c_groups (int) - int4* bucket_info; // encoder_backward, B*T*num_c_groups (int4) - size for worst case -} GPT2; + // allocate a single huge GPU buffer for all the tensors + printf("Current tensor offset in MiB: %zu", current_tensor_offset / (1024*1024)); + cudaCheck(cudaMalloc(&model->gpu_mem, current_tensor_offset)); -void gpt2_init_common(GPT2 *model) { - // common inits outside of the model weights - // memory lazily initialized in forward() - model->acts_memory = NULL; - model->inputs = NULL; - model->targets = NULL; - model->accumulated_mean_loss = NULL; - model->cpu_losses = NULL; - // the B,T params are determined and set, fixed on first batch in forward() - model->batch_size = 0; - model->seq_len = 0; - model->mean_loss = -1.0f; // -1.0f designates no loss, set at end of forward() - model->params_memory = NULL; - // memory lazily initialized in backward() - model->grads_memory = NULL; - model->workload_indices = NULL; // on cpu, for encoder_backward - model->bucket_info = NULL; // on cpu, for encoder_backward - // memory lazily initialized in update() - model->m_memory = NULL; - model->v_memory = NULL; - model->master_weights = NULL; - // other default settings - model->rng_state = 13371337 + multi_gpu_config.process_rank; // used in stochastic rounding - model->use_master_weights = 1; // safe default: do keep master weights in fp32 - model->init_state = true; - model->recompute = 1; // good default: recompute gelu but not layernorm - model->gelu_fusion = 0; //deviceProp.major >= 9 ? 
2 : 0; // default: off for now (default must match main()) -} + //initialise helper variables + model->num_parameters = tensors_elements[TT::PARAMETER]; + model->num_parameters_bytes = tensors_bytes[TT::PARAMETER]; -void gpt2_allocate_weights(GPT2 *model) { - // fill in all the parameter tensor dimensions and types - fill_in_parameter_sizes(model->param_elements, model->param_sizeof, model->config); - model->num_parameters = 0; - model->num_parameters_bytes = 0; - for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { - model->num_parameters += model->param_elements[i]; - model->num_parameters_bytes += model->param_elements[i] * model->param_sizeof[i]; + model->params_memory = (void*)(model->gpu_mem + tensors_start[TT::PARAMETER]); + model->grads_memory = (void*)(model->gpu_mem + tensors_start[TT::PARAMETER_GRADIENT]); + model->m_memory = (float*)(model->gpu_mem + tensors_start[TT::PARAMETER_OPT_M]); + model->v_memory = (float*)(model->gpu_mem + tensors_start[TT::PARAMETER_OPT_V]); + if (model->use_master_weights) { + model->master_weights = (float*)(model->gpu_mem + tensors_start[TT::PARAMETER_MASTER]); } - // create memory for model parameters on the device - assert(model->params_memory == nullptr); - model->params_memory = malloc_and_point_parameters(&model->params, model->param_elements, model->param_sizeof); -} - -void gpt2_allocate_state(GPT2 *model, int B, int T) { - printf0("allocating %d MiB for parameter gradients\n", (int)round(model->num_parameters * sizeof(floatX) / (1024 * 1024))); - assert(model->grads_memory == nullptr); - model->grads_memory = malloc_and_point_parameters(&model->grads, model->param_elements, model->param_sizeof); - - // record the current B,T as well - model->batch_size = B; - model->seq_len = T; + // printf gpu_mem and params_memory + printf("gpu_mem: %p\n", model->gpu_mem); + printf("params_memory: %p\n", model->params_memory); + printf("number of parameters: %zu\n", model->num_parameters); + printf("number of parameters bytes: %zu\n", model->num_parameters_bytes); + + // parameter gradient bytes + size_t param_grad_bytes = tensors_bytes[TT::PARAMETER_GRADIENT]; + printf("number of parameter gradient bytes: %zu\n", param_grad_bytes / (1024*1024)); + // number of master weight bytes + size_t master_weight_bytes = tensors_bytes[TT::PARAMETER_MASTER]; + printf("number of master weight bytes: %zu\n", master_weight_bytes / (1024*1024)); + // opt state m + size_t m_bytes = tensors_bytes[TT::PARAMETER_OPT_M]; + printf("number of m bytes: %zu\n", m_bytes / (1024*1024)); + // opt state v + size_t v_bytes = tensors_bytes[TT::PARAMETER_OPT_V]; + printf("number of v bytes: %zu\n", v_bytes / (1024*1024)); + // number of act bytes + size_t act_bytes = tensors_bytes[TT::ACTIVATION]; + printf("number of act bytes: %zu\n", act_bytes / (1024*1024)); + // number of act gradient bytes + size_t act_grad_bytes = tensors_bytes[TT::ACTIVATION_GRADIENT]; + printf("number of act grad bytes: %zu\n", act_grad_bytes / (1024*1024)); + // number of multiuse bytes + size_t multiuse_bytes = tensors_bytes[TT::MULTIUSE]; + printf("number of multiuse bytes: %zu\n", multiuse_bytes / (1024*1024)); + + printf("number of act+actgrad+multiuse bytes: %zu\n", (multiuse_bytes + act_bytes + act_grad_bytes) / (1024*1024)); + + // ======================= + // allocate_state stuff + // ======================= // allocate the space - fill_in_activation_sizes(&model->acts, model->acts_specs, B, T, model->config, model->recompute); - model->acts_memory = malloc_and_point_activations(model->acts_specs); - // 
also create memory for caching inputs and targets cudaCheck(cudaMalloc((void**)&model->inputs, B * T * sizeof(int))); cudaCheck(cudaMalloc((void**)&model->targets, B * T * sizeof(int))); cudaCheck(cudaMalloc(((void**)&model->accumulated_mean_loss), sizeof(float))); @@ -388,33 +456,23 @@ void gpt2_allocate_state(GPT2 *model, int B, int T) { model->workload_indices = (int*)mallocCheck(sizeof(int) * model->batch_size * model->seq_len * num_c_groups); model->bucket_info = (int4*)mallocCheck(sizeof(int4) * model->batch_size * model->seq_len * num_c_groups); - size_t shard_num_parameters = multi_gpu_config.shard_num_parameters; // num parameters we are responsible for - printf0("allocating %zu MiB for AdamW optimizer state m\n", (shard_num_parameters * sizeof(float)) >> 20); - printf0("allocating %zu MiB for AdamW optimizer state v\n", (shard_num_parameters * sizeof(float)) >> 20); - assert(model->m_memory == nullptr); - assert(model->v_memory == nullptr); - cudaCheck(cudaMalloc((void**)&model->m_memory, shard_num_parameters * sizeof(float))); - cudaCheck(cudaMalloc((void**)&model->v_memory, shard_num_parameters * sizeof(float))); - - if (model->use_master_weights == 1) { - assert(model->master_weights == nullptr); - printf0("allocating %zu MiB for master copy of params\n", (shard_num_parameters * sizeof(float)) >> 20); - cudaCheck(cudaMalloc((void**) &model->master_weights, shard_num_parameters * sizeof(float))); - } - size_t free, total; cudaCheck(cudaMemGetInfo(&free, &total)); printf0("device memory usage: %zd MiB / %zd MiB\n", (total-free) / 1024 / 1024, total / 1024 / 1024); // give an estimate of the maximum batch size - size_t bytes_per_sequence = 0; - for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) { - bytes_per_sequence += model->acts_specs[i].size * sizeof_dtype(model->acts_specs[i].type) / B; - } + size_t bytes_per_sequence = tensors_bytes[TT::ACTIVATION] / B; + bytes_per_sequence += tensors_bytes[TT::ACTIVATION_GRADIENT] / B; + bytes_per_sequence += tensors_bytes[TT::MULTIUSE] / B; // pessimistic? 
printf0("memory per sequence: %zu MiB\n", bytes_per_sequence / 1024 / 1024); printf0(" -> estimated maximum batch size: %zu\n", B + free / bytes_per_sequence); } +void gpt2_init_common(GPT2 *model) { + // other default settings + model->rng_state = 13371337 + multi_gpu_config.process_rank; // used in stochastic rounding +} + void gpt2_write_to_checkpoint(GPT2 *model, const char* checkpoint_path) { // write the model to a checkpoint file printf0("Writing model to %s\n", checkpoint_path); @@ -433,8 +491,7 @@ void gpt2_write_to_checkpoint(GPT2 *model, const char* checkpoint_path) { model_header[7] = model->config.padded_vocab_size; fwriteCheck(model_header, sizeof(int), 256, model_file); // write the parameters - device_to_file(model_file, model->params_memory, model->num_parameters_bytes, - IO_BUF_SIZE, main_stream); + device_to_file(model_file, model->params_memory, model->num_parameters_bytes, IO_BUF_SIZE, main_stream); // close file, we're done fcloseCheck(model_file); } @@ -490,12 +547,10 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path, bool w model->config.channels = model_header[6]; model->config.padded_vocab_size = model_header[7]; - // allocate memory for the model parameters - gpt2_allocate_weights(model); + gpt2_allocate(model); // read in the parameters if weight_init is true if (weight_init) { - assert(model->params_memory != NULL); file_to_device(model->params_memory, model_file, model->num_parameters_bytes, IO_BUF_SIZE, main_stream); } fcloseCheck(model_file); @@ -570,7 +625,7 @@ void gpt_build_from_descriptor(GPT2 *model, const char* descriptor) { model->config.vocab_size = 50257; model->config.padded_vocab_size = 50304; // padded to 128 for CUDA kernel efficiency - gpt2_allocate_weights(model); + gpt2_allocate(model); // allocate and random init the memory for all the parameters with GPT-2 schema // weights ~N(0, 0.02), biases 0, c_proj weights ~N(0, 0.02/(2*L)**0.5) @@ -583,6 +638,7 @@ void gpt_build_from_descriptor(GPT2 *model, const char* descriptor) { float residual_scale = 1.0f / sqrtf(2.0f * model->config.num_layers); // we have to init all these tensors exactly in the order that PyTorch initializes them // so that we can match them up and get correctness and exactly the same initial conditions + /* size_t L = model->config.num_layers; size_t offset = 0; for (int l = 0; l < L; l++) { @@ -623,12 +679,15 @@ void gpt_build_from_descriptor(GPT2 *model, const char* descriptor) { offset += model->param_elements[i]; } } - + */ // copy them to GPU cudaCheck(cudaMemcpy(model->params_memory, params_memory_cpu, model->num_parameters_bytes, cudaMemcpyHostToDevice)); free(params_memory_cpu); } +#define GPU_X(x) (floatX*)((char*)model->gpu_mem + tensor_specs[x].offset) +#define GPU_F32(x) (float*)((char*)model->gpu_mem + tensor_specs[x].offset) + // propagate inputs through the network to produce logits. // right now, this function is fully synchronous with the host void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { @@ -636,12 +695,6 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { // we must be careful and use size_t instead of int, otherwise // we could overflow int. E.g. l * B * NH * T * T overflows int at B 16. 
- // ensure the model was initialized or error out - if (model->params_memory == NULL) { - printf("Error: model was not initialized properly.\n"); - exit(EXIT_FAILURE); - } - // convenience parameters const size_t V = model->config.vocab_size; const size_t Vp = model->config.padded_vocab_size; @@ -665,42 +718,42 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { // forward pass ParameterTensors params = model->params; // for brevity ActivationTensors acts = model->acts; - encoder_forward(acts.encoded, model->inputs, params.wte, params.wpe, B, T, C, main_stream); // encoding goes into residual[0] + encoder_forward(GPU_X(acts.encoded), model->inputs, GPU_X(params.wte), GPU_X(params.wpe), B, T, C, main_stream); // encoding goes into residual[0] // first layernorm isn't fused - layernorm_forward((model->recompute < 2) ? acts.ln1 : acts.lnf, acts.ln1_mean, acts.ln1_rstd, acts.encoded, params.ln1w, params.ln1b, B, T, C, main_stream); + layernorm_forward(GPU_X((model->recompute < 2) ? acts.ln1 : acts.lnf), GPU_F32(acts.ln1_mean), GPU_F32(acts.ln1_rstd), GPU_X(acts.encoded), GPU_X(params.ln1w), GPU_X(params.ln1b), B, T, C, main_stream); for (int l = 0; l < L; l++) { NvtxRange layer_range("Layer", l); - floatX* residual = l == 0 ? acts.encoded : acts.residual3 + (l-1) * B * T * C; + floatX* residual = GPU_X(l == 0 ? acts.encoded : (acts.residual3 + l - 1)); // get the pointers of the weights for this layer - floatX* l_qkvw = params.qkvw + l * 3*C * C; - floatX* l_qkvb = params.qkvb + l * 3*C; - floatX* l_attprojw = params.attprojw + l * C * C; - floatX* l_attprojb = params.attprojb + l * C; - floatX* l_ln2w = params.ln2w + l * C; - floatX* l_ln2b = params.ln2b + l * C; - floatX* l_fcw = params.fcw + l * 4*C * C; - floatX* l_fcb = params.fcb + l * 4*C; - floatX* l_fcprojw = params.fcprojw + l * C * 4*C; - floatX* l_fcprojb = params.fcprojb + l * C; + floatX* l_qkvw = GPU_X(params.qkvw + l); + floatX* l_qkvb = GPU_X(params.qkvb + l); + floatX* l_attprojw = GPU_X(params.attprojw); + floatX* l_attprojb = GPU_X(params.attprojb); + floatX* l_ln2w = GPU_X(params.ln2w); + floatX* l_ln2b = GPU_X(params.ln2b); + floatX* l_fcw = GPU_X(params.fcw); + floatX* l_fcb = GPU_X(params.fcb); + floatX* l_fcprojw = GPU_X(params.fcprojw); + floatX* l_fcprojb = GPU_X(params.fcprojb); // get the pointers of the activations for this layer - floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.lnf; - floatX* l_qkvr = acts.qkvr + l * B * T * 3*C; - floatX* l_atty = acts.atty + l * B * T * C; - floatX* l_residual2 = acts.residual2 + l * B * T * C; - floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.lnf; - float* l_ln2_mean = acts.ln2_mean + l * B * T; - float* l_ln2_rstd = acts.ln2_rstd + l * B * T; - floatX* l_fch = acts.fch + l * B * T * 4*C; + floatX* l_ln1 = GPU_X((model->recompute < 2) ? acts.ln1 + l : acts.lnf); + floatX* l_qkvr = GPU_X(acts.qkvr + l); + floatX* l_atty = GPU_X(acts.atty + l); + floatX* l_residual2 = GPU_X(acts.residual2 + l); + floatX* l_ln2 = GPU_X((model->recompute < 2) ? acts.ln2 + l : acts.lnf); + float* l_ln2_mean = GPU_F32(acts.ln2_mean + l); + float* l_ln2_rstd = GPU_F32(acts.ln2_rstd + l); + floatX* l_fch = GPU_X(acts.fch + l); // reuse the same activation buffer at each layer, as we'll re-compute the gelu during backward // very useful because we dramatically reduce VRAM usage, and may be able to fit larger batch size - floatX* l_fch_gelu = (model->recompute < 1) ? 
acts.fch_gelu + l * B * T * 4*C : acts.fch_gelu; - floatX* l_residual3 = acts.residual3 + l * B * T * C; - floatX* scratch = (floatX*)acts.output; // used for non-cudnn attention, fcproj, attproj, etc. + floatX* l_fch_gelu = GPU_X((model->recompute < 1) ? acts.fch_gelu + l : acts.fch_gelu); + floatX* l_residual3 = GPU_X(acts.residual3 + l); + floatX* scratch = GPU_X(acts.output); // used for non-cudnn attention, fcproj, attproj, etc. // now do the forward pass #ifdef ENABLE_CUDNN @@ -708,7 +761,7 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { matmul_forward_cublaslt(l_qkvr, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C, main_stream); attention_forward_cudnn(l_atty, (float*)l_att, l_qkvr, B, T, NH, C, main_stream); #else - floatX* l_att = acts.att + l * B * NH * T * T; + floatX* l_att = GPU_X(acts.att + l); if (T != model->seq_len) { // unused parts of attention buffer must be zeroed (T-dependent) cudaCheck(cudaMemset(l_att, 0, B * NH * T * T * sizeof(floatX))); } @@ -724,21 +777,21 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { matmul_forward_cublaslt(scratch, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C, main_stream); // OK, fusion across blocks. if(l+1 != L) { - floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + (l + 1) * B * T * C : acts.lnf; - float* l_ln1_mean = acts.ln1_mean + (l + 1) * B * T; - float* l_ln1_rstd = acts.ln1_rstd + (l + 1) * B * T; - const floatX* l_ln1w = params.ln1w + (l + 1) * C; - const floatX* l_ln1b = params.ln1b + (l + 1) * C; + floatX* l_ln1 = GPU_X((model->recompute < 2) ? acts.ln1 + (l + 1) : acts.lnf); + float* l_ln1_mean = GPU_F32(acts.ln1_mean + (l + 1)); + float* l_ln1_rstd = GPU_F32(acts.ln1_rstd + (l + 1)); + const floatX* l_ln1w = GPU_X(params.ln1w + (l + 1)); + const floatX* l_ln1b = GPU_X(params.ln1b + (l + 1)); fused_residual_forward5(l_residual3, l_ln1, l_ln1_mean, l_ln1_rstd, l_residual2, scratch, l_ln1w, l_ln1b, B * T, C, main_stream); } else { - fused_residual_forward5(l_residual3, acts.lnf, acts.lnf_mean, acts.lnf_rstd, l_residual2, scratch, - params.lnfw, params.lnfb, + fused_residual_forward5(l_residual3, GPU_X(acts.lnf), GPU_F32(acts.lnf_mean), GPU_F32(acts.lnf_rstd), l_residual2, scratch, + GPU_X(params.lnfw), GPU_X(params.lnfb), B * T, C, main_stream); } } - matmul_forward_cublaslt(acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp, main_stream); + matmul_forward_cublaslt(GPU_X(acts.output), GPU_X(acts.lnf), GPU_X(params.wte), NULL, B, T, C, Vp, main_stream); cudaCheck(cudaDeviceSynchronize()); } @@ -760,11 +813,11 @@ float gpt2_validate(GPT2 *model, const int* inputs, const int* targets, size_t B // fused classifier: does the forward pass and first part of the backward pass const float dloss = 1.0f / (B * T); // results in the uniform average loss over all elements // note: we don't need to generate dlogits here - cudaCheck(cudaMemset(acts.losses, 0, B*T*sizeof(float))); + cudaCheck(cudaMemset(GPU_F32(acts.losses), 0, B*T*sizeof(float))); cudaCheck(cudaMemcpy(model->targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice)); tokenCheck(targets, B*T, V); // while the memcpy is underway, validate the targets - fused_classifier(acts.output, acts.losses, dloss, model->targets, B, T, V, Vp, False, main_stream); - cudaCheck(cudaMemcpy(model->cpu_losses, acts.losses, B * T * sizeof(float), cudaMemcpyDeviceToHost)); + fused_classifier(GPU_X(acts.output), GPU_F32(acts.losses), dloss, model->targets, B, T, V, Vp, False, main_stream); + cudaCheck(cudaMemcpy(model->cpu_losses, 
GPU_F32(acts.losses), B * T * sizeof(float), cudaMemcpyDeviceToHost)); for (int i = 0; i < B*T; i++) { mean_loss += model->cpu_losses[i]; } @@ -774,6 +827,8 @@ float gpt2_validate(GPT2 *model, const int* inputs, const int* targets, size_t B } void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int grad_accum_steps, int micro_step) { + return; // debugging forward only first... + if(model->grads_memory == nullptr) { fprintf(stderr, "Need to allocate gradients before backward"); exit(EXIT_FAILURE); @@ -785,8 +840,8 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int // there are currently two state vars during the gradient accumulation inner loop: // 1) the losses accumulate += into acts.losses, reset here // 2) the gradients accumulate += into grads_memory, reset here - cudaCheck(cudaMemsetAsync(model->acts.losses, 0, model->batch_size * model->seq_len * sizeof(float), main_stream)); - cudaCheck(cudaMemsetAsync(model->grads_memory, 0, model->num_parameters * sizeof(floatX), main_stream)); + cudaCheck(cudaMemsetAsync(GPU_X(model->acts.losses), 0, model->batch_size * model->seq_len * sizeof(float), main_stream)); + cudaCheck(cudaMemsetAsync(model->grads_memory, 0, tensors_bytes[TT::PARAMETER_GRADIENT], main_stream)); } // convenience shortcuts, size_t instead of int so that pointer arithmetics don't overflow @@ -799,7 +854,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int const size_t C = model->config.channels; ParameterTensors params = model->params; // for brevity - ParameterTensors grads = model->grads; + ParameterTensors grads = model->params_grads; ActivationTensors acts = model->acts; // accumulate the losses inside acts.losses, and kick off the backward pass inside the fused classifier @@ -807,27 +862,27 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int const float dloss = 1.0f / (float)(B * T * grad_accum_steps); // results in the uniform average loss over all elements cudaCheck(cudaMemcpy(model->targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice)); tokenCheck(targets, B*T, V); - fused_classifier(acts.output, acts.losses, dloss, model->targets, B, T, V, Vp, True, main_stream); + fused_classifier(GPU_X(acts.output), GPU_F32(acts.losses), dloss, model->targets, B, T, V, Vp, True, main_stream); // backward pass: go in the reverse order of the forward pass, and call backward() functions // reset residual stream gradients (put here to work with gradient accumulation) - floatX* dresidual = (floatX*)model->acts.scratch_btc; // the main buffer holding the gradient in the backward pass + floatX* dresidual = GPU_X(model->multiuse.btc); // the main buffer holding the gradient in the backward pass cudaCheck(cudaMemset(dresidual, 0, B * T * C * sizeof(floatX))); // re-use the output buffer of the forward pass as a scratchpad during backward pass - float* scratchF = (float*)acts.output; - floatX* scratchX = (floatX*)acts.output; + float* scratchF = GPU_F32(acts.output); + floatX* scratchX = GPU_X(acts.output); // we kick off the chain rule by filling in dlosses with 1.0f/(B*T) // this was done in the fused classifier kernel as last step of forward pass // technically that is a small, inline backward() pass of calculating // total, final loss as the mean over all losses over all (B,T) positions in the batch // next: backward the classifier matmul - matmul_backward(model->acts.scratch_bt4c, grads.wte, NULL, acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp, 
main_stream); + matmul_backward(GPU_X(model->multiuse.bt4c), GPU_X(grads.wte), NULL, GPU_X(acts.output), GPU_X(acts.lnf), GPU_X(params.wte), NULL, B, T, C, Vp, main_stream); // backward the final layernorm - floatX* residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3 - layernorm_backward(dresidual, grads.lnfw, grads.lnfb, scratchF, model->acts.scratch_bt4c, residual, params.lnfw, acts.lnf_mean, acts.lnf_rstd, B, T, C, main_stream); + floatX* residual = GPU_X(acts.residual3 + (L-1)); // last residual is in residual3 + layernorm_backward(dresidual, GPU_X(grads.lnfw), GPU_X(grads.lnfb), scratchF, GPU_X(model->multiuse.bt4c), residual, GPU_X(params.lnfw), GPU_F32(acts.lnf_mean), GPU_F32(acts.lnf_rstd), B, T, C, main_stream); // from this point on, we no longer need the values stored in the last residual, so we can reuse that memory as generic // scratch for backward computations @@ -837,47 +892,47 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int for (int l = L-1; l >= 0; l--) { NvtxRange layer_range("Layer", l); - residual = l == 0 ? acts.encoded : acts.residual3 + (l-1) * B * T * C; + residual = GPU_X(l == 0 ? acts.encoded : acts.residual3 + (l-1)); // get the pointers of the weights for this layer - floatX* l_ln1w = params.ln1w + l * C; - floatX* l_ln1b = params.ln1b + l * C; - floatX* l_qkvw = params.qkvw + l * 3*C * C; - floatX* l_attprojw = params.attprojw + l * C * C; - floatX* l_ln2w = params.ln2w + l * C; - floatX* l_ln2b = params.ln2b + l * C; - floatX* l_fcw = params.fcw + l * 4*C * C; - floatX* l_fcprojw = params.fcprojw + l * C * 4*C; + floatX* l_ln1w = GPU_X(params.ln1w + l); + floatX* l_ln1b = GPU_X(params.ln1b + l); + floatX* l_qkvw = GPU_X(params.qkvw + l); + floatX* l_attprojw = GPU_X(params.attprojw + l); + floatX* l_ln2w = GPU_X(params.ln2w + l); + floatX* l_ln2b = GPU_X(params.ln2b + l); + floatX* l_fcw = GPU_X(params.fcw + l); + floatX* l_fcprojw = GPU_X(params.fcprojw + l); // get the pointers of the gradients of the weights for this layer - floatX* dl_ln1w = grads.ln1w + l * C; - floatX* dl_ln1b = grads.ln1b + l * C; - floatX* dl_qkvw = grads.qkvw + l * 3*C * C; - floatX* dl_qkvb = grads.qkvb + l * 3*C; - floatX* dl_attprojw = grads.attprojw + l * C * C; - floatX* dl_attprojb = grads.attprojb + l * C; - floatX* dl_ln2w = grads.ln2w + l * C; - floatX* dl_ln2b = grads.ln2b + l * C; - floatX* dl_fcw = grads.fcw + l * 4*C * C; - floatX* dl_fcb = grads.fcb + l * 4*C; - floatX* dl_fcprojw = grads.fcprojw + l * C * 4*C; - floatX* dl_fcprojb = grads.fcprojb + l * C; + floatX* dl_ln1w = GPU_X(grads.ln1w + l); + floatX* dl_ln1b = GPU_X(grads.ln1b + l); + floatX* dl_qkvw = GPU_X(grads.qkvw + l); + floatX* dl_qkvb = GPU_X(grads.qkvb + l); + floatX* dl_attprojw = GPU_X(grads.attprojw + l); + floatX* dl_attprojb = GPU_X(grads.attprojb + l); + floatX* dl_ln2w = GPU_X(grads.ln2w + l); + floatX* dl_ln2b = GPU_X(grads.ln2b + l); + floatX* dl_fcw = GPU_X(grads.fcw + l); + floatX* dl_fcb = GPU_X(grads.fcb + l); + floatX* dl_fcprojw = GPU_X(grads.fcprojw + l); + floatX* dl_fcprojb = GPU_X(grads.fcprojb + l); // get the pointers of the activations for this layer - floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.lnf; - float* l_ln1_mean = acts.ln1_mean + l * B * T; - float* l_ln1_rstd = acts.ln1_rstd + l * B * T; - floatX* l_qkvr = acts.qkvr + l * B * T * 3*C; - floatX* l_atty = acts.atty + l * B * T * C; - floatX* l_residual2 = acts.residual2 + l * B * T * C; - floatX* l_ln2 = (model->recompute < 2) ? 
acts.ln2 + l * B * T * C : acts.lnf; - float* l_ln2_mean = acts.ln2_mean + l * B * T; - float* l_ln2_rstd = acts.ln2_rstd + l * B * T; - floatX* l_fch_pre_gelu = acts.fch + l * B * T * 4*C; - floatX* l_fch_gelu = (model->recompute < 1) ? acts.fch_gelu + l * B * T * 4*C : acts.fch_gelu; + floatX* l_ln1 = GPU_X((model->recompute < 2) ? acts.ln1 + l : acts.lnf); + float* l_ln1_mean = GPU_F32(acts.ln1_mean + l); + float* l_ln1_rstd = GPU_F32(acts.ln1_rstd + l); + floatX* l_qkvr = GPU_X(acts.qkvr + l * B * T * 3*C); + floatX* l_atty = GPU_X(acts.atty + l * B * T * C); + floatX* l_residual2 = GPU_X(acts.residual2 + l); + floatX* l_ln2 = GPU_X((model->recompute < 2) ? acts.ln2 + l : acts.lnf); + float* l_ln2_mean = GPU_F32(acts.ln2_mean + l); + float* l_ln2_rstd = GPU_F32(acts.ln2_rstd + l); + floatX* l_fch_pre_gelu = GPU_X(acts.fch + l); + floatX* l_fch_gelu = GPU_X((model->recompute < 1) ? acts.fch_gelu + l : acts.fch_gelu); // get the pointers of the gradients of the activations for this layer // notice that there is no l *, because we just have a single copy, and keep // re-using this memory in every Transformer block as we calculate backward pass - floatX* dl_bt4c = (floatX*)model->acts.scratch_bt4c; + floatX* dl_bt4c = (floatX*)GPU_X(model->multiuse.bt4c); // start the backward pass for this layer if(model->recompute >= 1) { @@ -899,7 +954,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int float* l_att = (float*)acts.att + l * B * NH * T; // cuDNN needs a smaller FP32 tensor attention_backward_cudnn(dl_bt4c, dl_btc, l_qkvr, l_atty, (float*)l_att, B, T, NH, C, main_stream); #else - floatX* l_att = acts.att + l * B * NH * T * T; + floatX* l_att = GPU_X(acts.att + l); // we need B x T x (4)C buffers. l_atty and l_fch aren't needed anymore at this point, so reuse their memory floatX* buffer_a = l_atty; floatX* buffer_b = l_fch_pre_gelu; // this is B x T x 4C, so even larger than what we need @@ -934,20 +989,20 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int multi_gpu_async_reduce_gradient(pointers, nelem, &multi_gpu_config, main_stream); } } - encoder_backward(grads.wte, grads.wpe, scratchX, model->workload_indices, model->bucket_info, + encoder_backward(GPU_X(grads.wte), GPU_X(grads.wpe), scratchX, model->workload_indices, model->bucket_info, dresidual, model->inputs, inputs, B, T, C, random_u32(&model->rng_state), main_stream); // Aggregate all gradients that are not part of the transformer blocks if(last_step) { // reduce all the losses within the current GPU (across all microsteps) - global_sum_deterministic(model->accumulated_mean_loss, acts.losses, B*T, main_stream); + global_sum_deterministic(model->accumulated_mean_loss, GPU_X(acts.losses), B*T, main_stream); // reduce loss across GPUs to a single, final float across all microsteps and GPUs #if MULTI_GPU ncclCheck(ncclAllReduce(model->accumulated_mean_loss, model->accumulated_mean_loss, sizeof(float), ncclFloat, ncclAvg, multi_gpu_config.nccl_comm, main_stream)); #endif cudaCheck(cudaMemcpyAsync(&model->mean_loss, model->accumulated_mean_loss, sizeof(float), cudaMemcpyDeviceToHost, main_stream)); // reduce the gradients for non-transformer block parameters - floatX* const pointers[] = {grads.wte, grads.wpe, grads.lnfw, grads.lnfb}; + floatX* const pointers[] = {GPU_X(grads.wte), GPU_X(grads.wpe), GPU_X(grads.lnfw), GPU_X(grads.lnfb)}; const size_t nelem[] = {Vp * C, T * C, C, C}; multi_gpu_async_reduce_gradient(pointers, nelem, &multi_gpu_config, main_stream); } @@ 
-962,6 +1017,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int // Gets the offset of a specific tensor for a specific layer in the GPT2 model // layer_id is ignored for weights that are not part of a transformer block +/* ShardInfo gpt2_get_tensor_at_layer(const GPT2 *model, int layer_id, int param_tensor_id) { // first offset our way to the parameter tensor start ptrdiff_t offset = 0; @@ -976,15 +1032,18 @@ ShardInfo gpt2_get_tensor_at_layer(const GPT2 *model, int layer_id, int param_te } return {offset, size}; } +*/ float gpt2_calculate_grad_norm(GPT2 *model, MultiGpuConfig* multi_gpu_config) { NVTX_RANGE_FN(); floatX* grads_memory = (floatX*)model->grads_memory; // repurposing this buffer (which isn't needed now) to write grad norm into it - float* grad_norm_squared = (float*)model->acts.output; + float* grad_norm_squared = GPU_F32(model->acts.output); float grad_norm_squared_cpu = 0.0f; + /* + int num_slices[2] = {1, model->config.num_layers}; int max_num_block_sums = get_max_num_block_sums(num_slices, 2); if (multi_gpu_config->zero_stage == 1) { @@ -1016,12 +1075,16 @@ float gpt2_calculate_grad_norm(GPT2 *model, MultiGpuConfig* multi_gpu_config) { global_sum_deterministic(grad_norm_squared, grad_norm_squared, max_num_block_sums, main_stream); } cudaCheck(cudaMemcpy(&grad_norm_squared_cpu, grad_norm_squared, sizeof(float), cudaMemcpyDeviceToHost)); + */ + float grad_norm_cpu = sqrtf(grad_norm_squared_cpu); return grad_norm_cpu; } void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, float grad_scale, int t, MultiGpuConfig* multi_gpu_config, bool init_from_master_only=false) { + return; // debugging forward only for now + // update the model parameters using the AdamW optimizer // keep in mind that optimizer sharding (ZeRO-1) assigns different parameters to different GPUs // so we may not be responsible for the entire parameter tensor @@ -1047,6 +1110,7 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo // AdamW update // handle adamw for all the transformer blocks + /* for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { // generate a unique seed for each tensor unsigned int seed = random_u32(&model->rng_state); @@ -1107,6 +1171,7 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo #endif } } + */ cudaCheck(cudaDeviceSynchronize()); } @@ -1141,12 +1206,7 @@ float gpt2_estimate_mfu(GPT2 *model, int num_tokens, float dt) { } void gpt2_free(GPT2 *model) { - cudaFreeCheck(&model->params_memory); - cudaFreeCheck(&model->grads_memory); - cudaFreeCheck(&model->m_memory); - cudaFreeCheck(&model->v_memory); - cudaFreeCheck(&model->master_weights); - cudaFreeCheck(&model->acts_memory); + cudaFreeCheck(&model->gpu_mem); cudaFreeCheck(&model->inputs); cudaFreeCheck(&model->targets); cudaFreeCheck(&model->accumulated_mean_loss); @@ -1559,6 +1619,12 @@ int main(int argc, char *argv[]) { // build the GPT-2 model GPT2 model; gpt2_init_common(&model); + model.use_master_weights = use_master_weights; + model.gelu_fusion = gelu_fusion; + model.recompute = recompute; + model.batch_size = B; + model.seq_len = T; + if (resuming == 1) { // if `-y 1` was set, then we are resuming from the latest checkpoint // if we are using master weights, we'll init them later inside load_state() @@ -1573,9 +1639,6 @@ int main(int argc, char *argv[]) { gpt_build_from_descriptor(&model, load_filename); } - model.use_master_weights = use_master_weights; - model.gelu_fusion = 
gelu_fusion; - model.recompute = recompute; printf0("| weight init method | %-50s |\n", resuming == 1 ? "intermediate checkpoint" : load_filename); printf0("| max_sequence_length T | %-50d |\n", model.config.max_seq_len); printf0("| vocab_size V | %-50d |\n", model.config.vocab_size); @@ -1662,7 +1725,6 @@ int main(int argc, char *argv[]) { // if we found a checkpoint to resume from, load the optimization state int step = 0; - gpt2_allocate_state(&model, B, T); if (resuming == 1) { snprintf(filename_buffer, sizeof(filename_buffer), "%s/state_%08d_%05d.bin", output_log_dir, resume_max_step, multi_gpu_config.process_rank); load_state(&step, &model, &train_loader, filename_buffer); @@ -1759,7 +1821,7 @@ int main(int argc, char *argv[]) { // note this is still somewhat wasteful because we don't have a KV cache! gpt2_forward(&model, gen_tokens, 1, CEIL_DIV(t, min(T,256)) * min(T,256)); // get the V-dimensional vector probs[0, t-1, :] - floatX* logits = model.acts.output + (t - 1) * model.config.padded_vocab_size; + floatX* logits = (floatX*)(model.gpu_mem + tensors_start[model.acts.output] + (t - 1) * model.config.padded_vocab_size); // move probs back to CPU and sample (note we only move the first vocab_size logits, ignoring the padding) cudaCheck(cudaMemcpy(cpu_logits_raw, logits, model.config.vocab_size * sizeof(floatX), cudaMemcpyDeviceToHost)); // convert to FP32 into cpu_logits (this does nothing useful if floatX == float) From 03f31367b878c79aac8f076e163ff8d08a895f23 Mon Sep 17 00:00:00 2001 From: ademeure Date: Sun, 1 Sep 2024 15:28:31 +0000 Subject: [PATCH 02/27] Everything working up to grad norm (single-gpu, no recompute) --- train_gpt2.cu | 207 +++++++++++++++++++++++++++----------------------- 1 file changed, 113 insertions(+), 94 deletions(-) diff --git a/train_gpt2.cu b/train_gpt2.cu index 63283ebcb..710515a26 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -231,6 +231,7 @@ void gpt2_allocate(GPT2 *model) { size_t NH = model->config.num_heads; size_t output_size = B*T * max(4*C, max(NH*T, Vp)); size_t BTC = B*T*C; + size_t BT = B*T; size_t shards = 1; int num_gpu = multi_gpu_config.num_processes; @@ -282,40 +283,6 @@ void gpt2_allocate(GPT2 *model) { model->multiuse.btc = add_tensor_spec("multiuse_btc", BTC, 1, DTYPE_FLOATX, -1, TT::MULTIUSE); } - /* - tensors[0] = TENSOR_SPEC(data->encoded, B * T * C); - // if recompute >= 1 then we will recompute the layernorm forward activation during backward pass - tensors[1] = TENSOR_SPEC(data->ln1, (recompute < 2) ? L * B * T * C : 0); - tensors[2] = TENSOR_SPEC(data->ln1_mean, L * B * T); - tensors[3] = TENSOR_SPEC(data->ln1_rstd, L * B * T); - tensors[4] = TENSOR_SPEC(data->atty, L * B * T * C); - #ifdef ENABLE_CUDNN - // FP32 stats tensor for cuDNN to be passed to backward pass - tensors[5] = TENSOR_SPEC(data->att, L * B * NH * T); - #else - tensors[5] = TENSOR_SPEC(data->att, L * B * NH * T * T); - #endif - tensors[6] = TENSOR_SPEC(data->residual2, L * B * T * C); - // if recompute >= 1 then we will recompute the layernorm forward activation during backward pass - tensors[7] = TENSOR_SPEC(data->ln2, (recompute < 2) ? L * B * T * C : 0); - tensors[8] = TENSOR_SPEC(data->ln2_mean, L * B * T); - tensors[9] = TENSOR_SPEC(data->ln2_rstd, L * B * T); - tensors[10] = TENSOR_SPEC(data->fch, L * B * T * 4*C); - // if recompute >= 1 then we will recompute gelu_forward during backward and use this as scratch buffer - tensors[11] = TENSOR_SPEC(data->fch_gelu, (recompute < 1) ? 
L * B * T * 4*C : B * T * 4*C); - tensors[12] = TENSOR_SPEC(data->residual3, L * B * T * C); - tensors[13] = TENSOR_SPEC(data->lnf, B * T * C); - tensors[14] = TENSOR_SPEC(data->lnf_mean, B * T); - tensors[15] = TENSOR_SPEC(data->lnf_rstd, B * T); - tensors[16] = TENSOR_SPEC(data->losses, B * T); - tensors[17] = TENSOR_SPEC(data->qkvr, L * B * T * 3*C); - tensors[18] = TENSOR_SPEC(data->output, B * T * max(3*C, max(NH*T, Vp))); - - tensors[19] = TENSOR_SPEC(data->scratch_bt4c, B * T * 4 * C); - tensors[20] = TENSOR_SPEC(data->scratch_btc, B * T * C); - */ - - // 3) activations current_tensor_type = TT::ACTIVATION; ActivationTensors* spec = &model->acts; @@ -325,16 +292,17 @@ void gpt2_allocate(GPT2 *model) { TENSOR_SPECS (encoded, 1, BTC); TENSOR_SPECS (lnf, 1, BTC); - TENSOR_SPECS_FP32(lnf_mean, 1, B*T); - TENSOR_SPECS_FP32(lnf_rstd, 1, B*T); - TENSOR_SPECS_FP32(losses, 1, B*T); + TENSOR_SPECS_FP32(lnf_mean, 1, BT); + TENSOR_SPECS_FP32(lnf_rstd, 1, BT); + TENSOR_SPECS_FP32(losses, 1, BT); TENSOR_SPECS (output, 1, output_size); - TENSOR_SPECS_FP32(ln1_mean, L, B*T); - TENSOR_SPECS_FP32(ln1_rstd, L, B*T); + + TENSOR_SPECS_FP32(ln1_mean, L, BT); + TENSOR_SPECS_FP32(ln1_rstd, L, BT); TENSOR_SPECS (atty, L, BTC); TENSOR_SPECS (residual2, L, BTC); - TENSOR_SPECS_FP32(ln2_mean, L, B*T); - TENSOR_SPECS_FP32(ln2_rstd, L, B*T); + TENSOR_SPECS_FP32(ln2_mean, L, BT); + TENSOR_SPECS_FP32(ln2_rstd, L, BT); TENSOR_SPECS (residual3, L, BTC); TENSOR_SPECS_LOWP(fch, L, 4 * BTC); TENSOR_SPECS (qkvr, L, 3 * BTC); @@ -348,10 +316,12 @@ void gpt2_allocate(GPT2 *model) { TENSOR_SPECS_LOWP(fcproj, L, BTC); TENSOR_SPECS_LOWP(attproj, L, BTC); } else { - spec->fcproj = add_layer_specs(L, "fcproj", BTC, L, dtype_lowp, model->multiuse.btc); - spec->attproj = add_layer_specs(L, "attproj", BTC, L, dtype_lowp, model->multiuse.btc); + spec->fcproj = add_layer_specs(L, "fcproj", BTC, shards, dtype_lowp, model->multiuse.btc); + spec->attproj = add_layer_specs(L, "attproj", BTC, shards, dtype_lowp, model->multiuse.btc); } + // optionally reuse the same activation buffer at each layer and re-compute the gelu during backward + // very useful because we dramatically reduce VRAM usage, and may be able to fit larger batch size if (model->recompute < 1 || UNIQUE_TENSOR_MEMORY) { TENSOR_SPECS(ln1, L, BTC); TENSOR_SPECS(ln2, L, BTC); @@ -359,20 +329,21 @@ void gpt2_allocate(GPT2 *model) { } else if (model->recompute < 2) { TENSOR_SPECS(ln1, L, BTC); TENSOR_SPECS(ln2, L, BTC); - spec->fch_gelu = add_layer_specs(L, "fch_gelu", 4 * BTC, L, dtype_lowp, model->acts.output); + spec->fch_gelu = add_layer_specs(L, "fch_gelu", 4 * BTC, shards, dtype_lowp, model->acts.output); } else { - spec->ln1 = add_layer_specs(L, "ln1", BTC, L, dtype, model->acts.output); - spec->ln2 = add_layer_specs(L, "ln2", BTC, L, dtype, model->acts.output); - spec->fch_gelu = add_layer_specs(L, "fch_gelu", 4 * BTC, L, dtype_lowp, model->acts.output); + spec->ln1 = add_layer_specs(L, "ln1", BTC, shards, dtype, model->multiuse.btc); // todo - not OK for backwards + spec->ln2 = add_layer_specs(L, "ln2", BTC, shards, dtype, model->multiuse.btc); // todo - not OK for backwards + spec->fch_gelu = add_layer_specs(L, "fch_gelu", 4 * BTC, shards, dtype_lowp, model->acts.output); } // 4) activation gradients current_tensor_type = TT::ACTIVATION_GRADIENT; + spec = &model->acts_grads; dtype_lowp = DTYPE_FLOATX; // todo FP8 shards = 1; if (UNIQUE_TENSOR_MEMORY) { - TENSOR_SPECS(output, 1, B*T * max(3*C, max(NH*T, Vp))); + TENSOR_SPECS(output, 1, BT * max(3*C, max(NH*T, Vp))); 
TENSOR_SPECS(lnf, 1, BTC); TENSOR_SPECS(ln1, L, BTC); TENSOR_SPECS(atty, L, BTC); @@ -396,8 +367,9 @@ void gpt2_allocate(GPT2 *model) { } // allocate a single huge GPU buffer for all the tensors - printf("Current tensor offset in MiB: %zu", current_tensor_offset / (1024*1024)); + printf("Current tensor offset in MiB: %zu\n", current_tensor_offset / (1024*1024)); cudaCheck(cudaMalloc(&model->gpu_mem, current_tensor_offset)); + cudaCheck(cudaMemset(model->gpu_mem, 0, current_tensor_offset)); //initialise helper variables model->num_parameters = tensors_elements[TT::PARAMETER]; @@ -419,27 +391,27 @@ void gpt2_allocate(GPT2 *model) { // parameter gradient bytes size_t param_grad_bytes = tensors_bytes[TT::PARAMETER_GRADIENT]; - printf("number of parameter gradient bytes: %zu\n", param_grad_bytes / (1024*1024)); + printf("number of parameter gradient bytes: %zu MiB\n", param_grad_bytes / (1024*1024)); // number of master weight bytes size_t master_weight_bytes = tensors_bytes[TT::PARAMETER_MASTER]; - printf("number of master weight bytes: %zu\n", master_weight_bytes / (1024*1024)); + printf("number of master weight bytes: %zu MiB\n", master_weight_bytes / (1024*1024)); // opt state m size_t m_bytes = tensors_bytes[TT::PARAMETER_OPT_M]; - printf("number of m bytes: %zu\n", m_bytes / (1024*1024)); + printf("number of m bytes: %zu MiB\n", m_bytes / (1024*1024)); // opt state v size_t v_bytes = tensors_bytes[TT::PARAMETER_OPT_V]; - printf("number of v bytes: %zu\n", v_bytes / (1024*1024)); + printf("number of v bytes: %zu MiB\n", v_bytes / (1024*1024)); // number of act bytes size_t act_bytes = tensors_bytes[TT::ACTIVATION]; - printf("number of act bytes: %zu\n", act_bytes / (1024*1024)); + printf("number of act bytes: %zu MiB\n", act_bytes / (1024*1024)); // number of act gradient bytes size_t act_grad_bytes = tensors_bytes[TT::ACTIVATION_GRADIENT]; - printf("number of act grad bytes: %zu\n", act_grad_bytes / (1024*1024)); + printf("number of act grad bytes: %zu MiB\n", act_grad_bytes / (1024*1024)); // number of multiuse bytes size_t multiuse_bytes = tensors_bytes[TT::MULTIUSE]; - printf("number of multiuse bytes: %zu\n", multiuse_bytes / (1024*1024)); + printf("number of multiuse bytes: %zu MiB\n", multiuse_bytes / (1024*1024)); - printf("number of act+actgrad+multiuse bytes: %zu\n", (multiuse_bytes + act_bytes + act_grad_bytes) / (1024*1024)); + printf("number of act+actgrad+multiuse bytes: %zu MiB\n", (multiuse_bytes + act_bytes + act_grad_bytes) / (1024*1024)); // ======================= // allocate_state stuff @@ -461,9 +433,9 @@ void gpt2_allocate(GPT2 *model) { printf0("device memory usage: %zd MiB / %zd MiB\n", (total-free) / 1024 / 1024, total / 1024 / 1024); // give an estimate of the maximum batch size - size_t bytes_per_sequence = tensors_bytes[TT::ACTIVATION] / B; + size_t bytes_per_sequence = tensors_bytes[TT::ACTIVATION] / B; // pessimistic (output buffer) bytes_per_sequence += tensors_bytes[TT::ACTIVATION_GRADIENT] / B; - bytes_per_sequence += tensors_bytes[TT::MULTIUSE] / B; // pessimistic? 
+ bytes_per_sequence += tensors_bytes[TT::MULTIUSE] / B; printf0("memory per sequence: %zu MiB\n", bytes_per_sequence / 1024 / 1024); printf0(" -> estimated maximum batch size: %zu\n", B + free / bytes_per_sequence); } @@ -687,6 +659,57 @@ void gpt_build_from_descriptor(GPT2 *model, const char* descriptor) { #define GPU_X(x) (floatX*)((char*)model->gpu_mem + tensor_specs[x].offset) #define GPU_F32(x) (float*)((char*)model->gpu_mem + tensor_specs[x].offset) +#define GPU_VOID(x) (void*)((char*)model->gpu_mem + tensor_specs[x].offset) + +// debug helper function +void print_tensor_elements(GPT2 *model, int tensor_id) { + const char* tensor_name = tensor_specs[tensor_id].name; + size_t num_elements = tensor_specs[tensor_id].num_elements; + DType dtype = tensor_specs[tensor_id].data_type; + size_t element_size = sizeof_dtype(dtype); + void* gpu_tensor = GPU_VOID(tensor_id); + void* cpu_tensor = malloc(num_elements * element_size); + cudaCheck(cudaMemcpy(cpu_tensor, gpu_tensor, num_elements * element_size, cudaMemcpyDeviceToHost)); + + printf("First 4 of %s: ", tensor_name); + for (int i = 0; i < num_elements && i < 4; i++) { + if (dtype == DType::FP32) { + printf("%.16f ", ((float*)cpu_tensor)[i]); + } else if (dtype == DType::FP16) { + printf("%.16f ", (float)((__nv_half*)cpu_tensor)[i]); + } else if (dtype == DType::BF16) { + printf("%.16f ", (float)((__nv_bfloat16*)cpu_tensor)[i]); + } + } + printf("\n"); + + printf("Middle 4 of %s: ", tensor_name); + for (int i = (num_elements/2) + 4; i < num_elements && i < (num_elements/2 + 8); i++) { + if (dtype == DType::FP32) { + printf("%.16f ", ((float*)cpu_tensor)[i]); + } else if (dtype == DType::FP16) { + printf("%.16f ", (float)((__nv_half*)cpu_tensor)[i]); + } else if (dtype == DType::BF16) { + printf("%.16f ", (float)((__nv_bfloat16*)cpu_tensor)[i]); + } + } + printf("\n"); + + printf("Last 4 of %s: ", tensor_name); + for (int i = num_elements - 4; i < num_elements; i++) { + if (dtype == DType::FP32) { + printf("%.16f ", ((float*)cpu_tensor)[i]); + } else if (dtype == DType::FP16) { + printf("%.16f ", (float)((__nv_half*)cpu_tensor)[i]); + } else if (dtype == DType::BF16) { + printf("%.16f ", (float)((__nv_bfloat16*)cpu_tensor)[i]); + } + } + printf("\n"); + printf("\n"); + + free(cpu_tensor); +} // propagate inputs through the network to produce logits. // right now, this function is fully synchronous with the host @@ -721,7 +744,7 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { encoder_forward(GPU_X(acts.encoded), model->inputs, GPU_X(params.wte), GPU_X(params.wpe), B, T, C, main_stream); // encoding goes into residual[0] // first layernorm isn't fused - layernorm_forward(GPU_X((model->recompute < 2) ? 
acts.ln1 : acts.lnf), GPU_F32(acts.ln1_mean), GPU_F32(acts.ln1_rstd), GPU_X(acts.encoded), GPU_X(params.ln1w), GPU_X(params.ln1b), B, T, C, main_stream); + layernorm_forward(GPU_X(acts.ln1), GPU_F32(acts.ln1_mean), GPU_F32(acts.ln1_rstd), GPU_X(acts.encoded), GPU_X(params.ln1w), GPU_X(params.ln1b), B, T, C, main_stream); for (int l = 0; l < L; l++) { NvtxRange layer_range("Layer", l); @@ -731,29 +754,28 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { // get the pointers of the weights for this layer floatX* l_qkvw = GPU_X(params.qkvw + l); floatX* l_qkvb = GPU_X(params.qkvb + l); - floatX* l_attprojw = GPU_X(params.attprojw); - floatX* l_attprojb = GPU_X(params.attprojb); - floatX* l_ln2w = GPU_X(params.ln2w); - floatX* l_ln2b = GPU_X(params.ln2b); - floatX* l_fcw = GPU_X(params.fcw); - floatX* l_fcb = GPU_X(params.fcb); - floatX* l_fcprojw = GPU_X(params.fcprojw); - floatX* l_fcprojb = GPU_X(params.fcprojb); + floatX* l_attprojw = GPU_X(params.attprojw + l); + floatX* l_attprojb = GPU_X(params.attprojb + l); + floatX* l_ln2w = GPU_X(params.ln2w + l); + floatX* l_ln2b = GPU_X(params.ln2b + l); + floatX* l_fcw = GPU_X(params.fcw + l); + floatX* l_fcb = GPU_X(params.fcb + l); + floatX* l_fcprojw = GPU_X(params.fcprojw + l); + floatX* l_fcprojb = GPU_X(params.fcprojb + l); // get the pointers of the activations for this layer - floatX* l_ln1 = GPU_X((model->recompute < 2) ? acts.ln1 + l : acts.lnf); + floatX* l_ln1 = GPU_X(acts.ln1 + l); floatX* l_qkvr = GPU_X(acts.qkvr + l); floatX* l_atty = GPU_X(acts.atty + l); floatX* l_residual2 = GPU_X(acts.residual2 + l); - floatX* l_ln2 = GPU_X((model->recompute < 2) ? acts.ln2 + l : acts.lnf); + floatX* l_ln2 = GPU_X(acts.ln2 + l); float* l_ln2_mean = GPU_F32(acts.ln2_mean + l); float* l_ln2_rstd = GPU_F32(acts.ln2_rstd + l); floatX* l_fch = GPU_X(acts.fch + l); - // reuse the same activation buffer at each layer, as we'll re-compute the gelu during backward - // very useful because we dramatically reduce VRAM usage, and may be able to fit larger batch size - floatX* l_fch_gelu = GPU_X((model->recompute < 1) ? acts.fch_gelu + l : acts.fch_gelu); + floatX* l_fch_gelu = GPU_X(acts.fch_gelu + l); floatX* l_residual3 = GPU_X(acts.residual3 + l); - floatX* scratch = GPU_X(acts.output); // used for non-cudnn attention, fcproj, attproj, etc. 
+ floatX* l_fcproj = GPU_X(acts.fcproj + l); + floatX* l_attproj = GPU_X(acts.attproj + l); // now do the forward pass #ifdef ENABLE_CUDNN @@ -762,6 +784,7 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { attention_forward_cudnn(l_atty, (float*)l_att, l_qkvr, B, T, NH, C, main_stream); #else floatX* l_att = GPU_X(acts.att + l); + floatX* scratch = GPU_X(acts.output); if (T != model->seq_len) { // unused parts of attention buffer must be zeroed (T-dependent) cudaCheck(cudaMemset(l_att, 0, B * NH * T * T * sizeof(floatX))); } @@ -771,21 +794,21 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { attention_forward(l_atty, l_qkvr, l_att, scratch, B, T, C, NH, main_stream); #endif - matmul_forward_cublaslt(scratch, l_atty, l_attprojw, l_attprojb, B, T, C, C, main_stream); - fused_residual_forward5(l_residual2, l_ln2, l_ln2_mean, l_ln2_rstd, residual, scratch, l_ln2w, l_ln2b, B*T, C, main_stream); + matmul_forward_cublaslt(l_attproj, l_atty, l_attprojw, l_attprojb, B, T, C, C, main_stream); + fused_residual_forward5(l_residual2, l_ln2, l_ln2_mean, l_ln2_rstd, residual, l_attproj, l_ln2w, l_ln2b, B*T, C, main_stream); matmul_forward_cublaslt(l_fch_gelu, l_ln2, l_fcw, l_fcb, B, T, C, 4*C, main_stream, l_fch, model->gelu_fusion); - matmul_forward_cublaslt(scratch, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C, main_stream); + matmul_forward_cublaslt(l_fcproj, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C, main_stream); // OK, fusion across blocks. if(l+1 != L) { - floatX* l_ln1 = GPU_X((model->recompute < 2) ? acts.ln1 + (l + 1) : acts.lnf); + floatX* l_ln1 = GPU_X(acts.ln1 + (l + 1)); float* l_ln1_mean = GPU_F32(acts.ln1_mean + (l + 1)); float* l_ln1_rstd = GPU_F32(acts.ln1_rstd + (l + 1)); const floatX* l_ln1w = GPU_X(params.ln1w + (l + 1)); const floatX* l_ln1b = GPU_X(params.ln1b + (l + 1)); - fused_residual_forward5(l_residual3, l_ln1, l_ln1_mean, l_ln1_rstd, l_residual2, scratch, l_ln1w, l_ln1b, + fused_residual_forward5(l_residual3, l_ln1, l_ln1_mean, l_ln1_rstd, l_residual2, l_fcproj, l_ln1w, l_ln1b, B * T, C, main_stream); } else { - fused_residual_forward5(l_residual3, GPU_X(acts.lnf), GPU_F32(acts.lnf_mean), GPU_F32(acts.lnf_rstd), l_residual2, scratch, + fused_residual_forward5(l_residual3, GPU_X(acts.lnf), GPU_F32(acts.lnf_mean), GPU_F32(acts.lnf_rstd), l_residual2, l_fcproj, GPU_X(params.lnfw), GPU_X(params.lnfb), B * T, C, main_stream); } @@ -827,8 +850,6 @@ float gpt2_validate(GPT2 *model, const int* inputs, const int* targets, size_t B } void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int grad_accum_steps, int micro_step) { - return; // debugging forward only first... - if(model->grads_memory == nullptr) { fprintf(stderr, "Need to allocate gradients before backward"); exit(EXIT_FAILURE); @@ -867,6 +888,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int // backward pass: go in the reverse order of the forward pass, and call backward() functions // reset residual stream gradients (put here to work with gradient accumulation) + // todo - this should be a dedicated tensor, not addressing multiuse.btc directly! 
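One way the dedicated tensor mentioned in the todo above could look, sketched as a fragment against the add_tensor_spec()/GPU_X() helpers from this same patch; the dresidual field and the "dresidual" spec name are hypothetical and not part of the diff:

// hypothetical extra entry in MultiuseTensors
int dresidual; // (B, T, C) gradient of the residual stream (hypothetical)

// in gpt2_allocate(), next to the existing multiuse specs (sketch only)
model->multiuse.dresidual = add_tensor_spec("dresidual", BTC, 1, DTYPE_FLOATX, -1, TT::MULTIUSE);

// in gpt2_backward_and_reduce(), instead of addressing multiuse.btc directly
floatX* dresidual = GPU_X(model->multiuse.dresidual);
cudaCheck(cudaMemset(dresidual, 0, B * T * C * sizeof(floatX)));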
floatX* dresidual = GPU_X(model->multiuse.btc); // the main buffer holding the gradient in the backward pass cudaCheck(cudaMemset(dresidual, 0, B * T * C * sizeof(floatX))); @@ -917,22 +939,22 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int floatX* dl_fcprojw = GPU_X(grads.fcprojw + l); floatX* dl_fcprojb = GPU_X(grads.fcprojb + l); // get the pointers of the activations for this layer - floatX* l_ln1 = GPU_X((model->recompute < 2) ? acts.ln1 + l : acts.lnf); + floatX* l_ln1 = GPU_X(acts.ln1 + l); float* l_ln1_mean = GPU_F32(acts.ln1_mean + l); float* l_ln1_rstd = GPU_F32(acts.ln1_rstd + l); - floatX* l_qkvr = GPU_X(acts.qkvr + l * B * T * 3*C); - floatX* l_atty = GPU_X(acts.atty + l * B * T * C); + floatX* l_qkvr = GPU_X(acts.qkvr + l); + floatX* l_atty = GPU_X(acts.atty + l); floatX* l_residual2 = GPU_X(acts.residual2 + l); - floatX* l_ln2 = GPU_X((model->recompute < 2) ? acts.ln2 + l : acts.lnf); + floatX* l_ln2 = GPU_X(acts.ln2 + l); float* l_ln2_mean = GPU_F32(acts.ln2_mean + l); float* l_ln2_rstd = GPU_F32(acts.ln2_rstd + l); floatX* l_fch_pre_gelu = GPU_X(acts.fch + l); - floatX* l_fch_gelu = GPU_X((model->recompute < 1) ? acts.fch_gelu + l : acts.fch_gelu); + floatX* l_fch_gelu = GPU_X(acts.fch_gelu + l); // get the pointers of the gradients of the activations for this layer // notice that there is no l *, because we just have a single copy, and keep // re-using this memory in every Transformer block as we calculate backward pass - floatX* dl_bt4c = (floatX*)GPU_X(model->multiuse.bt4c); + floatX* dl_bt4c = (floatX*)GPU_X(model->multiuse.bt4c); // TODO - should use dedicated tensors! // start the backward pass for this layer if(model->recompute >= 1) { @@ -995,7 +1017,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int // Aggregate all gradients that are not part of the transformer blocks if(last_step) { // reduce all the losses within the current GPU (across all microsteps) - global_sum_deterministic(model->accumulated_mean_loss, GPU_X(acts.losses), B*T, main_stream); + global_sum_deterministic(model->accumulated_mean_loss, GPU_F32(acts.losses), B*T, main_stream); // reduce loss across GPUs to a single, final float across all microsteps and GPUs #if MULTI_GPU ncclCheck(ncclAllReduce(model->accumulated_mean_loss, model->accumulated_mean_loss, sizeof(float), ncclFloat, ncclAvg, multi_gpu_config.nccl_comm, main_stream)); @@ -1042,11 +1064,9 @@ float gpt2_calculate_grad_norm(GPT2 *model, MultiGpuConfig* multi_gpu_config) { float* grad_norm_squared = GPU_F32(model->acts.output); float grad_norm_squared_cpu = 0.0f; - /* - int num_slices[2] = {1, model->config.num_layers}; int max_num_block_sums = get_max_num_block_sums(num_slices, 2); - if (multi_gpu_config->zero_stage == 1) { + /*if (multi_gpu_config->zero_stage == 1) { // because of the ncclReduceScatter() in backward, // grads_memory only contains the averaged gradients at the local shards, // so we only calculate the grad norm at the grads_memory belonging to the local shards @@ -1068,14 +1088,13 @@ float gpt2_calculate_grad_norm(GPT2 *model, MultiGpuConfig* multi_gpu_config) { // further sum the (partial) squared norm across all GPUs ncclCheck(ncclAllReduce(grad_norm_squared, grad_norm_squared, sizeof(float), ncclFloat, ncclSum, multi_gpu_config->nccl_comm, main_stream)); #endif - } else { + } else*/ { // in regular DDP, backward has averaged the gradients across all GPUs // so each GPU can compute the squared norm over the whole grad vector, with no added comms 
needed global_norm_squared(grad_norm_squared, grads_memory, model->num_parameters, 0, 1, max_num_block_sums, true, main_stream); global_sum_deterministic(grad_norm_squared, grad_norm_squared, max_num_block_sums, main_stream); } cudaCheck(cudaMemcpy(&grad_norm_squared_cpu, grad_norm_squared, sizeof(float), cudaMemcpyDeviceToHost)); - */ float grad_norm_cpu = sqrtf(grad_norm_squared_cpu); return grad_norm_cpu; From 3018cc61082a706656b2f448451c9814fecd2e0c Mon Sep 17 00:00:00 2001 From: ademeure Date: Sun, 1 Sep 2024 17:37:13 +0000 Subject: [PATCH 03/27] More laconic ACT_XL, PARAM_X, etc. indexing... --- train_gpt2.cu | 132 +++++++++++++++++++------------------------------- 1 file changed, 49 insertions(+), 83 deletions(-) diff --git a/train_gpt2.cu b/train_gpt2.cu index 710515a26..088840f24 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -54,9 +54,11 @@ bool UNIQUE_TENSOR_MEMORY = false; #ifdef ENABLE_CUDNN // defines: create_cudnn, destroy_cudnn, attention_forward_cudnn, attention_backward_cudnn #include "llmc/cudnn_att.h" +#define CUDNN_ENABLED 1 #else // defines: attention_forward, attention_backward #include "llmc/attention.cuh" +#define CUDNN_ENABLED 0 #endif // defines: fused_classifier #include "llmc/fused_classifier.cuh" @@ -87,7 +89,9 @@ constexpr const size_t IO_BUF_SIZE = 32 * 1024 * 1024; // GPT-2 model definition enum TT : uint8_t { - PARAMETER=0, PARAMETER_GRADIENT, PARAMETER_MASTER, PARAMETER_OPT_M, PARAMETER_OPT_V, + PARAMETER=0, + PARAMETER_GRADIENT, + PARAMETER_MASTER, PARAMETER_OPT_M, PARAMETER_OPT_V, MULTIUSE, ACTIVATION, ACTIVATION_GRADIENT, DEFAULT, COUNT=DEFAULT }; @@ -174,13 +178,13 @@ size_t tensors_start[TT::COUNT] = {0}; size_t tensors_bytes[TT::COUNT] = {0}; size_t tensors_elements[TT::COUNT] = {0}; -int add_tensor_spec(const char* name, size_t num_elements, size_t num_shards, DType data_type, int copy_offset_from=-1, TT tensor_type=TT::DEFAULT) { +int add_tensor_spec(const char* name, size_t total_elements, size_t num_shards, DType data_type, int copy_offset_from=-1, TT tensor_type=TT::DEFAULT) { assert(num_tensor_specs < 16*1024); - assert((num_elements % num_shards) == 0); + assert((total_elements % num_shards) == 0); TensorSpec* spec = &tensor_specs[num_tensor_specs++]; strncpy(spec->name, name, 16); - spec->num_elements = num_elements / num_shards; + spec->num_elements = total_elements / num_shards; spec->num_shards = num_shards; spec->data_type = data_type; spec->tensor_type = (tensor_type == TT::DEFAULT) ? current_tensor_type : tensor_type; @@ -204,12 +208,12 @@ int add_tensor_spec(const char* name, size_t num_elements, size_t num_shards, DT return num_tensor_specs - 1; } -int add_layer_specs(int num_layers, const char* name, size_t num_elements, size_t num_shards, DType data_type, int copy_offset_from=-1, bool copy_per_layer=false, TT tensor_type=TT::DEFAULT) { +int add_layer_specs(int num_layers, const char* name, size_t total_elements, size_t num_shards, DType data_type, int copy_offset_from=-1, bool copy_per_layer=false, TT tensor_type=TT::DEFAULT) { int first_tensor_id = num_tensor_specs; for (int l = 0; l < num_layers; l++) { char layer_name[16]; assert(snprintf(layer_name, 16, "%s_%d", name, l) >= 0); - add_tensor_spec(num_layers > 1 ? layer_name : name, num_elements, num_shards, data_type, copy_offset_from, tensor_type); + add_tensor_spec(num_layers > 1 ? 
layer_name : name, total_elements, num_shards, data_type, copy_offset_from, tensor_type); if (copy_per_layer) { copy_offset_from++; } @@ -661,6 +665,15 @@ void gpt_build_from_descriptor(GPT2 *model, const char* descriptor) { #define GPU_F32(x) (float*)((char*)model->gpu_mem + tensor_specs[x].offset) #define GPU_VOID(x) (void*)((char*)model->gpu_mem + tensor_specs[x].offset) +#define ACT_X(x) GPU_X(acts.x) +#define ACT_XL(x) GPU_X(acts.x + l) +#define ACT_32(x) GPU_F32(acts.x) +#define ACT_32L(x) GPU_F32(acts.x + l) +#define PARAM_X(x) GPU_X(params.x) +#define PARAM_XL(x) GPU_X(params.x + l) +#define PARAM_32(x) GPU_F32(params.x) +#define PARAM_32L(x) GPU_F32(params.x + l) + // debug helper function void print_tensor_elements(GPT2 *model, int tensor_id) { const char* tensor_name = tensor_specs[tensor_id].name; @@ -712,13 +725,11 @@ void print_tensor_elements(GPT2 *model, int tensor_id) { } // propagate inputs through the network to produce logits. -// right now, this function is fully synchronous with the host void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { NVTX_RANGE_FN(); - // we must be careful and use size_t instead of int, otherwise - // we could overflow int. E.g. l * B * NH * T * T overflows int at B 16. - - // convenience parameters + // we must be careful and use size_t instead of int, otherwise we could overflow + ParameterTensors params = model->params; + ActivationTensors acts = model->acts; const size_t V = model->config.vocab_size; const size_t Vp = model->config.padded_vocab_size; const size_t L = model->config.num_layers; @@ -731,91 +742,46 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { printf("Model: B=%d T=%d, Desired: B=%d T=%d\n", model->batch_size, model->seq_len, (int)B, (int)T); exit(EXIT_FAILURE); } + // unused parts of attention buffer must be zeroed for non-cuDNN path + if (!CUDNN_ENABLED && T != model->seq_len) { + cudaCheck(cudaMemset(ACT_X(att), 0, L * B * NH * T * T * sizeof(floatX))); + } - // copy inputs/targets to the model + // copy inputs/targets to the model (fully synchronous with the host for now) cudaCheck(cudaMemcpy(model->inputs, inputs, B * T * sizeof(int), cudaMemcpyHostToDevice)); // validate inputs, all indices must be in the range [0, V) - // we can do this while the copies are already underway tokenCheck(inputs, B*T, V); - // forward pass - ParameterTensors params = model->params; // for brevity - ActivationTensors acts = model->acts; - encoder_forward(GPU_X(acts.encoded), model->inputs, GPU_X(params.wte), GPU_X(params.wpe), B, T, C, main_stream); // encoding goes into residual[0] - - // first layernorm isn't fused - layernorm_forward(GPU_X(acts.ln1), GPU_F32(acts.ln1_mean), GPU_F32(acts.ln1_rstd), GPU_X(acts.encoded), GPU_X(params.ln1w), GPU_X(params.ln1b), B, T, C, main_stream); + // start of forward pass with encoder + encoder_forward(ACT_X(encoded), model->inputs, PARAM_X(wte), PARAM_X(wpe), B, T, C, main_stream); // encoding goes into residual[0] + layernorm_forward(ACT_X(ln1), ACT_32(ln1_mean), ACT_32(ln1_rstd), ACT_X(encoded), PARAM_X(ln1w), PARAM_X(ln1b), B, T, C, main_stream); for (int l = 0; l < L; l++) { NvtxRange layer_range("Layer", l); + floatX* input_residual = GPU_X(l == 0 ? acts.encoded : (acts.residual3 + l - 1)); - floatX* residual = GPU_X(l == 0 ? 
acts.encoded : (acts.residual3 + l - 1)); - - // get the pointers of the weights for this layer - floatX* l_qkvw = GPU_X(params.qkvw + l); - floatX* l_qkvb = GPU_X(params.qkvb + l); - floatX* l_attprojw = GPU_X(params.attprojw + l); - floatX* l_attprojb = GPU_X(params.attprojb + l); - floatX* l_ln2w = GPU_X(params.ln2w + l); - floatX* l_ln2b = GPU_X(params.ln2b + l); - floatX* l_fcw = GPU_X(params.fcw + l); - floatX* l_fcb = GPU_X(params.fcb + l); - floatX* l_fcprojw = GPU_X(params.fcprojw + l); - floatX* l_fcprojb = GPU_X(params.fcprojb + l); - - // get the pointers of the activations for this layer - floatX* l_ln1 = GPU_X(acts.ln1 + l); - floatX* l_qkvr = GPU_X(acts.qkvr + l); - floatX* l_atty = GPU_X(acts.atty + l); - floatX* l_residual2 = GPU_X(acts.residual2 + l); - floatX* l_ln2 = GPU_X(acts.ln2 + l); - float* l_ln2_mean = GPU_F32(acts.ln2_mean + l); - float* l_ln2_rstd = GPU_F32(acts.ln2_rstd + l); - floatX* l_fch = GPU_X(acts.fch + l); - floatX* l_fch_gelu = GPU_X(acts.fch_gelu + l); - floatX* l_residual3 = GPU_X(acts.residual3 + l); - floatX* l_fcproj = GPU_X(acts.fcproj + l); - floatX* l_attproj = GPU_X(acts.attproj + l); - - // now do the forward pass + matmul_forward_cublaslt(CUDNN_ENABLED ? ACT_XL(qkvr) : ACT_X(output), ACT_XL(ln1), PARAM_XL(qkvw), PARAM_XL(qkvb), B, T, C, 3*C, main_stream); #ifdef ENABLE_CUDNN - float* l_att = (float*)acts.att + l * B * NH * T; // cuDNN needs a smaller FP32 tensor - matmul_forward_cublaslt(l_qkvr, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C, main_stream); - attention_forward_cudnn(l_atty, (float*)l_att, l_qkvr, B, T, NH, C, main_stream); + attention_forward_cudnn(ACT_XL(atty), ACT_32L(att), ACT_XL(qkvr), B, T, NH, C, main_stream); #else - floatX* l_att = GPU_X(acts.att + l); - floatX* scratch = GPU_X(acts.output); - if (T != model->seq_len) { // unused parts of attention buffer must be zeroed (T-dependent) - cudaCheck(cudaMemset(l_att, 0, B * NH * T * T * sizeof(floatX))); - } - // these are only needed as scratchpads for the forward pass, but - // need not be stored for backward - matmul_forward_cublaslt(scratch, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C, main_stream); - attention_forward(l_atty, l_qkvr, l_att, scratch, B, T, C, NH, main_stream); + attention_forward(ACT_XL(atty), ACT_XL(qkvr), ACT_XL(att), ACT_X(output), B, T, C, NH, main_stream); #endif - matmul_forward_cublaslt(l_attproj, l_atty, l_attprojw, l_attprojb, B, T, C, C, main_stream); - fused_residual_forward5(l_residual2, l_ln2, l_ln2_mean, l_ln2_rstd, residual, l_attproj, l_ln2w, l_ln2b, B*T, C, main_stream); - matmul_forward_cublaslt(l_fch_gelu, l_ln2, l_fcw, l_fcb, B, T, C, 4*C, main_stream, l_fch, model->gelu_fusion); - matmul_forward_cublaslt(l_fcproj, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C, main_stream); - // OK, fusion across blocks. 
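The ACT_XL()/PARAM_XL() indexing used throughout this loop works because add_layer_specs() registers its L per-layer entries at consecutive spec indices, so an integer tensor id plus the layer number selects layer l, and the macro then turns that id into a pointer via the spec's byte offset. A stripped-down, standalone sketch of that lookup (simplified types and sizes; the real code uses tensor_specs[], GPU_X() and the offsets computed in gpt2_allocate()):

#include <cstddef>
#include <cstdio>

struct MiniSpec { char name[16]; size_t offset; size_t bytes; };
static MiniSpec specs[64];
static int num_specs = 0;
static size_t next_offset = 0;

// register L consecutive per-layer entries and return the id of layer 0
static int add_layer(const char* name, int L, size_t bytes_per_layer) {
    int first = num_specs;
    for (int l = 0; l < L; l++) {
        snprintf(specs[num_specs].name, sizeof(specs[num_specs].name), "%s_%d", name, l);
        specs[num_specs].offset = next_offset;
        specs[num_specs].bytes  = bytes_per_layer;
        next_offset += bytes_per_layer;
        num_specs++;
    }
    return first; // id of layer 0; layer l is simply first + l
}

int main() {
    static char gpu_mem[1 << 16];           // stand-in for the single big cudaMalloc'd buffer
    int qkvr = add_layer("qkvr", 4, 1024);  // like acts.qkvr in the real code
    for (int l = 0; l < 4; l++) {
        // equivalent of ACT_XL(qkvr): base pointer + per-spec byte offset
        char* ptr = gpu_mem + specs[qkvr + l].offset;
        printf("%s -> offset %zu (%p)\n", specs[qkvr + l].name, specs[qkvr + l].offset, (void*)ptr);
    }
    return 0;
}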
- if(l+1 != L) { - floatX* l_ln1 = GPU_X(acts.ln1 + (l + 1)); - float* l_ln1_mean = GPU_F32(acts.ln1_mean + (l + 1)); - float* l_ln1_rstd = GPU_F32(acts.ln1_rstd + (l + 1)); - const floatX* l_ln1w = GPU_X(params.ln1w + (l + 1)); - const floatX* l_ln1b = GPU_X(params.ln1b + (l + 1)); - fused_residual_forward5(l_residual3, l_ln1, l_ln1_mean, l_ln1_rstd, l_residual2, l_fcproj, l_ln1w, l_ln1b, - B * T, C, main_stream); + matmul_forward_cublaslt(ACT_XL(attproj), ACT_XL(atty), PARAM_XL(attprojw), PARAM_XL(attprojb), B, T, C, C, main_stream); + fused_residual_forward5(ACT_XL(residual2), ACT_XL(ln2), ACT_32L(ln2_mean), ACT_32L(ln2_rstd), input_residual, ACT_XL(attproj), PARAM_XL(ln2w), PARAM_XL(ln2b), B*T, C, main_stream); + matmul_forward_cublaslt(ACT_XL(fch_gelu), ACT_XL(ln2), PARAM_XL(fcw), PARAM_XL(fcb), B, T, C, 4*C, main_stream, ACT_XL(fch), model->gelu_fusion); + matmul_forward_cublaslt(ACT_XL(fcproj), ACT_XL(fch_gelu), PARAM_XL(fcprojw), PARAM_XL(fcprojb), B, T, 4*C, C, main_stream); + + if(l+1 != L) { // fusion across layers + fused_residual_forward5(ACT_XL(residual3), ACT_XL(ln1 + 1), ACT_32L(ln1_mean + 1), ACT_32L(ln1_rstd + 1), ACT_XL(residual2), ACT_XL(fcproj), + PARAM_XL(ln1w + 1), PARAM_XL(ln1b + 1), B * T, C, main_stream); } else { - fused_residual_forward5(l_residual3, GPU_X(acts.lnf), GPU_F32(acts.lnf_mean), GPU_F32(acts.lnf_rstd), l_residual2, l_fcproj, - GPU_X(params.lnfw), GPU_X(params.lnfb), - B * T, C, main_stream); + fused_residual_forward5(ACT_XL(residual3), ACT_X(lnf), ACT_32(lnf_mean), ACT_32(lnf_rstd), ACT_XL(residual2), ACT_XL(fcproj), + PARAM_X(lnfw), PARAM_X(lnfb), B * T, C, main_stream); } } - matmul_forward_cublaslt(GPU_X(acts.output), GPU_X(acts.lnf), GPU_X(params.wte), NULL, B, T, C, Vp, main_stream); - cudaCheck(cudaDeviceSynchronize()); + matmul_forward_cublaslt(ACT_X(output), ACT_X(lnf), PARAM_X(wte), NULL, B, T, C, Vp, main_stream); } @@ -836,11 +802,11 @@ float gpt2_validate(GPT2 *model, const int* inputs, const int* targets, size_t B // fused classifier: does the forward pass and first part of the backward pass const float dloss = 1.0f / (B * T); // results in the uniform average loss over all elements // note: we don't need to generate dlogits here - cudaCheck(cudaMemset(GPU_F32(acts.losses), 0, B*T*sizeof(float))); + cudaCheck(cudaMemset(ACT_32(losses), 0, B*T*sizeof(float))); cudaCheck(cudaMemcpy(model->targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice)); tokenCheck(targets, B*T, V); // while the memcpy is underway, validate the targets - fused_classifier(GPU_X(acts.output), GPU_F32(acts.losses), dloss, model->targets, B, T, V, Vp, False, main_stream); - cudaCheck(cudaMemcpy(model->cpu_losses, GPU_F32(acts.losses), B * T * sizeof(float), cudaMemcpyDeviceToHost)); + fused_classifier(ACT_X(output), ACT_32(losses), dloss, model->targets, B, T, V, Vp, False, main_stream); + cudaCheck(cudaMemcpy(model->cpu_losses, ACT_32(losses), B * T * sizeof(float), cudaMemcpyDeviceToHost)); for (int i = 0; i < B*T; i++) { mean_loss += model->cpu_losses[i]; } @@ -861,7 +827,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int // there are currently two state vars during the gradient accumulation inner loop: // 1) the losses accumulate += into acts.losses, reset here // 2) the gradients accumulate += into grads_memory, reset here - cudaCheck(cudaMemsetAsync(GPU_X(model->acts.losses), 0, model->batch_size * model->seq_len * sizeof(float), main_stream)); + cudaCheck(cudaMemsetAsync(GPU_F32(model->acts.losses), 0, model->batch_size * 
model->seq_len * sizeof(float), main_stream)); cudaCheck(cudaMemsetAsync(model->grads_memory, 0, tensors_bytes[TT::PARAMETER_GRADIENT], main_stream)); } From 474a60b296549c44f105b13b569143e4aeb502d9 Mon Sep 17 00:00:00 2001 From: ademeure Date: Sun, 1 Sep 2024 23:00:02 +0000 Subject: [PATCH 04/27] broken progress (forward is OK, backward is not) for restructure using ACT_X/PARAM_XL/etc... --- llmc/fused_classifier.cuh | 10 +- llmc/layernorm.cuh | 20 ++- train_gpt2.cu | 342 ++++++++++++++++---------------------- 3 files changed, 159 insertions(+), 213 deletions(-) diff --git a/llmc/fused_classifier.cuh b/llmc/fused_classifier.cuh index 4837d4cb0..ba5a1e554 100644 --- a/llmc/fused_classifier.cuh +++ b/llmc/fused_classifier.cuh @@ -67,7 +67,7 @@ __device__ SoftmaxParams prepare_softmax_blockwide3(int64_t idx, const floatX* i // split both loops in "multiple-of-x128-size" and "bounds-checked remainder" parts template __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS) - fused_classifier_kernel5(floatX* logits, float* losses, floatX* probs, + fused_classifier_kernel5(floatX* dlogits, floatX* logits, float* losses, floatX* probs, const float dloss, const int* targets, int B, int T, int V, int P, std::bool_constant) { // note: idx is small enough that it easily fits into 32 bit; @@ -109,7 +109,7 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS) if (WriteDLogits){ // reduce cache persistence for the overwritten logits // to maximise probability that logits remain in cache between prepare_softmax and here - store128cs(logits + idx * P + i * x128::size, packed_logits_vec); + store128cs(dlogits + idx * P + i * x128::size, packed_logits_vec); } if (WriteProbs) { store128(probs + idx * P + i * x128::size, packed_probs); @@ -124,7 +124,7 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS) float indicator = (i == ix) ? 1.0f : 0.0f; float dlogit = (prob - indicator) * dloss; if (WriteDLogits){ - __stcs(logits + idx * P + i, (floatX)dlogit); + __stcs(dlogits + idx * P + i, (floatX)dlogit); } if (WriteProbs) { probs[idx * P + i] = (floatX)prob; @@ -137,13 +137,13 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS) // replaces logits with logit gradients template -void fused_classifier(Type* logits, float* losses, +void fused_classifier(Type* dlogits, Type* logits, float* losses, const float dloss, const int* targets, int B, int T, int V, int P, std::bool_constant write_dlogits, cudaStream_t stream) { NVTX_RANGE_FN(); const int block_size = 1024; const int N = B * T; const int grid_size = N; - fused_classifier_kernel5<<>>(logits, losses, (floatX*)NULL, dloss, targets, B, T, V, P, write_dlogits); + fused_classifier_kernel5<<>>(dlogits, logits, losses, (floatX*)NULL, dloss, targets, B, T, V, P, write_dlogits); cudaCheck(cudaGetLastError()); } diff --git a/llmc/layernorm.cuh b/llmc/layernorm.cuh index 9777d0658..da0bf9591 100644 --- a/llmc/layernorm.cuh +++ b/llmc/layernorm.cuh @@ -230,8 +230,9 @@ __global__ void residual_forward_kernel(floatX* out, const floatX* inp1, const f store128(out + idx, packed_out); } +template __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with only 1024 threads? 
- layernorm_backward_kernel10(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch, + layernorm_backward_kernel10(floatX* dinp_new, floatX* dinp_old, floatX* dweight, floatX* dbias, float* scratch, const floatX* dout, const floatX* inp, const floatX* weight, const float* mean, const float* rstd, int B, int T, int C) { @@ -266,7 +267,8 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with for (int bt = baseIdx; bt < B * T; bt += warpsInGrid) { const floatX* dout_bt = dout + bt * C; const floatX* inp_bt = inp +bt * C; - floatX* dinp_bt = dinp + bt * C; + floatX* dinp_bt = dinp_old + bt * C; + floatX* dinp_new_bt = dinp_new + bt * C; // first: two reduce operations float dnorm_mean = 0.0f; @@ -298,8 +300,10 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with if(global_index < C) { dout128 = load128cs(dout_bt + global_index); inp128 = load128cs(inp_bt + global_index); - dinp128 = load128(dinp_bt + global_index); weight128 = load128(weight + global_index); + if constexpr (!zero_dinp_old) { + dinp128 = load128(dinp_bt + global_index); + } } for(int o = 0; o < x128::size / f128::size; ++o) { @@ -353,7 +357,7 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with } if(global_index < C) { // cache in L2 as this is read by the next kernel, but bypass L1 to minimise thrashing - store128cg(dinp_bt + global_index, dinp128); + store128cg(dinp_new_bt + global_index, dinp128); } } } @@ -489,7 +493,7 @@ void fused_residual_forward5(floatX* residual, floatX* normed, float* mean, floa cudaCheck(cudaGetLastError()); } -void layernorm_backward(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch, +void layernorm_backward(floatX* dinp_new, floatX* dinp_old, floatX* dweight, floatX* dbias, float* scratch, const floatX* dout, const floatX* inp, const floatX* weight, const float* mean, const float* rstd, int B, int T, int C, cudaStream_t stream) { NVTX_RANGE_FN(); @@ -500,6 +504,10 @@ void layernorm_backward(floatX* dinp, floatX* dweight, floatX* dbias, float* scr size_t shared_mem_size = (2 * rounded_C + 2 * (block_size - 32) * f128::size) * sizeof(float); cudaCheck(cudaMemsetAsync(scratch, 0, 1 * sizeof(float), stream)); // only need to reset the flag to 0 - layernorm_backward_kernel10<<>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C); + if (dinp_old == nullptr) { + layernorm_backward_kernel10<<>>(dinp_new, dinp_old, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C); + } else { + layernorm_backward_kernel10<<>>(dinp_new, dinp_old, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C); + } cudaCheck(cudaGetLastError()); } diff --git a/train_gpt2.cu b/train_gpt2.cu index 088840f24..ed71e6afc 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -1,7 +1,7 @@ /* GPT-2 Transformer Neural Net training loop. See README.md for usage. 
*/ -bool UNIQUE_TENSOR_MEMORY = false; +bool UNIQUE_TENSOR_MEMORY = true; #include #include @@ -89,11 +89,9 @@ constexpr const size_t IO_BUF_SIZE = 32 * 1024 * 1024; // GPT-2 model definition enum TT : uint8_t { - PARAMETER=0, - PARAMETER_GRADIENT, - PARAMETER_MASTER, PARAMETER_OPT_M, PARAMETER_OPT_V, - MULTIUSE, ACTIVATION, ACTIVATION_GRADIENT, - DEFAULT, COUNT=DEFAULT + PARAMETER=0, PARAMETER_GRADIENT, PARAMETER_MASTER, PARAMETER_OPT_M, PARAMETER_OPT_V, // 1 allocation each + ACTIVATIONS_MULTIUSE, // single buffer shared for activations, activation gradients, and scratch + DEFAULT, COUNT=DEFAULT, NUM_TYPES_PARAM=PARAMETER_OPT_V+1 }; typedef struct { @@ -122,11 +120,7 @@ typedef struct { typedef struct { GPT2Config config; - ParameterTensors params; - ParameterTensors params_grads; - ParameterTensors params_master; - ParameterTensors params_opt_m; - ParameterTensors params_opt_v; + ParameterTensors params[NUM_TYPES_PARAM]; ActivationTensors acts; ActivationTensors acts_grads; MultiuseTensors multiuse; @@ -134,12 +128,8 @@ typedef struct { size_t num_parameters; size_t num_parameters_bytes; - char* gpu_mem; - void* params_memory; - void* grads_memory; - float* m_memory; - float* v_memory; - float* master_weights; + void* multiuse_memory = NULL; + void* params_memory[NUM_TYPES_PARAM] = {0}; // other run state configuration int batch_size = 0; // the batch size (B) of current forward pass @@ -172,7 +162,6 @@ typedef struct { TensorSpec tensor_specs[16*1024]; size_t num_tensor_specs = 0; -size_t current_tensor_offset = 0; TT current_tensor_type = TT::PARAMETER; size_t tensors_start[TT::COUNT] = {0}; size_t tensors_bytes[TT::COUNT] = {0}; @@ -195,14 +184,13 @@ int add_tensor_spec(const char* name, size_t total_elements, size_t num_shards, spec->offset = tensor_specs[copy_offset_from].offset; size_t original_tensor_bytes = tensor_specs[copy_offset_from].num_elements * sizeof_dtype(tensor_specs[copy_offset_from].data_type); size_t new_tensor_bytes = spec->num_elements * sizeof_dtype(data_type); + assert(tensor_specs[copy_offset_from].tensor_type == spec->tensor_type); assert(new_tensor_bytes <= original_tensor_bytes); } else { - spec->offset = current_tensor_offset; - current_tensor_offset += spec->num_elements * sizeof_dtype(data_type); - + spec->offset = tensors_bytes[spec->tensor_type]; tensors_bytes[spec->tensor_type] += spec->num_elements * sizeof_dtype(data_type); if (tensors_start[spec->tensor_type] == 0 && spec->tensor_type != 0) { - tensors_start[spec->tensor_type] = spec->offset; + tensors_start[spec->tensor_type] = num_tensor_specs - 1; } } return num_tensor_specs - 1; @@ -243,19 +231,13 @@ void gpt2_allocate(GPT2 *model) { int shards_grad = (multi_gpu_config.zero_stage >= 2) ? num_gpu : 1; // 1) parameters & optimizer state - for (int t = TT::PARAMETER; t <= TT::PARAMETER_OPT_V; t++) { - DType dtype = (t <= TT::PARAMETER_GRADIENT) ? DTYPE_FLOATX : DType::FP32; - DType dtype_lowp = (t <= TT::PARAMETER_GRADIENT) ? DTYPE_FLOATX : DType::FP32; // FP8 in the future + for (int t = PARAMETER; t <= PARAMETER_OPT_V; t++) { + DType dtype = (t <= PARAMETER_GRADIENT) ? DTYPE_FLOATX : DType::FP32; + DType dtype_lowp = (t <= PARAMETER_GRADIENT) ? 
DTYPE_FLOATX : DType::FP32; // FP8 in the future current_tensor_type = (TT)t; - ParameterTensors* spec; - switch (t) { - case TT::PARAMETER: spec = &model->params; shards = 1; break; - case TT::PARAMETER_GRADIENT: spec = &model->params_grads; shards = shards_grad; break; - case TT::PARAMETER_MASTER: spec = &model->params_master; shards = shards_opt; break; - case TT::PARAMETER_OPT_M: spec = &model->params_opt_m; shards = shards_opt; break; - case TT::PARAMETER_OPT_V: spec = &model->params_opt_v; shards = shards_opt; break; - } + ParameterTensors* spec = &model->params[t]; + shards = (t == PARAMETER) ? 1 : (t == PARAMETER_GRADIENT) ? shards_grad : shards_opt; if (t == PARAMETER_MASTER && !model->use_master_weights) { continue; } @@ -279,19 +261,19 @@ void gpt2_allocate(GPT2 *model) { } // 2) multiuse & scratch tensors - if (UNIQUE_TENSOR_MEMORY) { + current_tensor_type = ACTIVATIONS_MULTIUSE; + /*if (UNIQUE_TENSOR_MEMORY) { model->multiuse.bt4c = -1; model->multiuse.btc = -1; - } else { - model->multiuse.bt4c = add_tensor_spec("multiuse_bt4c", 4 * BTC, 1, DTYPE_FLOATX, -1, TT::MULTIUSE); - model->multiuse.btc = add_tensor_spec("multiuse_btc", BTC, 1, DTYPE_FLOATX, -1, TT::MULTIUSE); + } else*/ { + model->multiuse.bt4c = add_tensor_spec("multiuse_bt4c", 4 * BTC, 1, DTYPE_FLOATX); + model->multiuse.btc = add_tensor_spec("multiuse_btc", BTC, 1, DTYPE_FLOATX); } // 3) activations - current_tensor_type = TT::ACTIVATION; ActivationTensors* spec = &model->acts; - DType dtype = DTYPE_FLOATX; DType dtype_lowp = DTYPE_FLOATX; // todo FP8 + DType dtype = DTYPE_FLOATX; shards = 1; TENSOR_SPECS (encoded, 1, BTC); @@ -341,13 +323,13 @@ void gpt2_allocate(GPT2 *model) { } // 4) activation gradients - current_tensor_type = TT::ACTIVATION_GRADIENT; spec = &model->acts_grads; dtype_lowp = DTYPE_FLOATX; // todo FP8 shards = 1; if (UNIQUE_TENSOR_MEMORY) { - TENSOR_SPECS(output, 1, BT * max(3*C, max(NH*T, Vp))); + TENSOR_SPECS(encoded, 1, BTC); + TENSOR_SPECS(output, 1, output_size); TENSOR_SPECS(lnf, 1, BTC); TENSOR_SPECS(ln1, L, BTC); TENSOR_SPECS(atty, L, BTC); @@ -371,28 +353,22 @@ void gpt2_allocate(GPT2 *model) { } // allocate a single huge GPU buffer for all the tensors - printf("Current tensor offset in MiB: %zu\n", current_tensor_offset / (1024*1024)); - cudaCheck(cudaMalloc(&model->gpu_mem, current_tensor_offset)); - cudaCheck(cudaMemset(model->gpu_mem, 0, current_tensor_offset)); + cudaCheck(cudaMalloc(&model->multiuse_memory, tensors_bytes[ACTIVATIONS_MULTIUSE])); + cudaCheck(cudaMemset(model->multiuse_memory, 0, tensors_bytes[ACTIVATIONS_MULTIUSE])); + + cudaCheck(cudaMalloc(&model->params_memory[PARAMETER], tensors_bytes[PARAMETER])); + cudaCheck(cudaMalloc(&model->params_memory[PARAMETER_GRADIENT], tensors_bytes[PARAMETER_GRADIENT])); + cudaCheck(cudaMalloc(&model->params_memory[PARAMETER_OPT_M], tensors_bytes[PARAMETER_OPT_M])); + cudaCheck(cudaMalloc(&model->params_memory[PARAMETER_OPT_V], tensors_bytes[PARAMETER_OPT_V])); + if (model->use_master_weights) { + cudaCheck(cudaMalloc(&model->params_memory[PARAMETER_MASTER], tensors_bytes[PARAMETER_MASTER])); + } //initialise helper variables model->num_parameters = tensors_elements[TT::PARAMETER]; model->num_parameters_bytes = tensors_bytes[TT::PARAMETER]; - model->params_memory = (void*)(model->gpu_mem + tensors_start[TT::PARAMETER]); - model->grads_memory = (void*)(model->gpu_mem + tensors_start[TT::PARAMETER_GRADIENT]); - model->m_memory = (float*)(model->gpu_mem + tensors_start[TT::PARAMETER_OPT_M]); - model->v_memory = (float*)(model->gpu_mem 
+ tensors_start[TT::PARAMETER_OPT_V]); - if (model->use_master_weights) { - model->master_weights = (float*)(model->gpu_mem + tensors_start[TT::PARAMETER_MASTER]); - } - // printf gpu_mem and params_memory - printf("gpu_mem: %p\n", model->gpu_mem); - printf("params_memory: %p\n", model->params_memory); - printf("number of parameters: %zu\n", model->num_parameters); - printf("number of parameters bytes: %zu\n", model->num_parameters_bytes); - // parameter gradient bytes size_t param_grad_bytes = tensors_bytes[TT::PARAMETER_GRADIENT]; printf("number of parameter gradient bytes: %zu MiB\n", param_grad_bytes / (1024*1024)); @@ -405,17 +381,9 @@ void gpt2_allocate(GPT2 *model) { // opt state v size_t v_bytes = tensors_bytes[TT::PARAMETER_OPT_V]; printf("number of v bytes: %zu MiB\n", v_bytes / (1024*1024)); - // number of act bytes - size_t act_bytes = tensors_bytes[TT::ACTIVATION]; - printf("number of act bytes: %zu MiB\n", act_bytes / (1024*1024)); - // number of act gradient bytes - size_t act_grad_bytes = tensors_bytes[TT::ACTIVATION_GRADIENT]; - printf("number of act grad bytes: %zu MiB\n", act_grad_bytes / (1024*1024)); // number of multiuse bytes - size_t multiuse_bytes = tensors_bytes[TT::MULTIUSE]; - printf("number of multiuse bytes: %zu MiB\n", multiuse_bytes / (1024*1024)); - - printf("number of act+actgrad+multiuse bytes: %zu MiB\n", (multiuse_bytes + act_bytes + act_grad_bytes) / (1024*1024)); + size_t multiuse_bytes = tensors_bytes[TT::ACTIVATIONS_MULTIUSE]; + printf("number of act+actgrad+multiuse bytes: %zu MiB\n", (multiuse_bytes) / (1024*1024)); // ======================= // allocate_state stuff @@ -437,9 +405,7 @@ void gpt2_allocate(GPT2 *model) { printf0("device memory usage: %zd MiB / %zd MiB\n", (total-free) / 1024 / 1024, total / 1024 / 1024); // give an estimate of the maximum batch size - size_t bytes_per_sequence = tensors_bytes[TT::ACTIVATION] / B; // pessimistic (output buffer) - bytes_per_sequence += tensors_bytes[TT::ACTIVATION_GRADIENT] / B; - bytes_per_sequence += tensors_bytes[TT::MULTIUSE] / B; + size_t bytes_per_sequence = tensors_bytes[TT::ACTIVATIONS_MULTIUSE] / B; // pessimistic (output buffer etc.) 
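With offsets now accumulated per tensor type rather than into one global running total, each spec's offset is relative to its own pool, so resolving a tensor means picking the right base allocation first and then adding the offset, which is what the ACT_X/PARAM_X/PGRAD_X/AGRAD_X macros below do. A minimal standalone sketch of that two-level lookup (pool names and sizes are illustrative only):

#include <cstddef>
#include <cstdio>
#include <cstdlib>

// one base pointer per pool, mirroring params_memory[...] plus multiuse_memory
enum Pool { POOL_PARAMETER = 0, POOL_PARAMETER_GRADIENT, POOL_ACTIVATIONS_MULTIUSE, POOL_COUNT };
struct PoolSpec { Pool pool; size_t offset; };

static void* resolve(void* const base[POOL_COUNT], PoolSpec spec) {
    // the offset is relative to the tensor's own pool, not to one shared buffer
    return (char*)base[spec.pool] + spec.offset;
}

int main() {
    // malloc stands in for the per-pool cudaMalloc calls above
    void* base[POOL_COUNT] = { malloc(1024), malloc(1024), malloc(4096) };
    PoolSpec wte      = { POOL_PARAMETER, 0 };
    PoolSpec wte_grad = { POOL_PARAMETER_GRADIENT, 0 };  // same offset, different allocation
    printf("wte at %p, wte gradient at %p\n", resolve(base, wte), resolve(base, wte_grad));
    for (void* p : base) free(p);
    return 0;
}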
printf0("memory per sequence: %zu MiB\n", bytes_per_sequence / 1024 / 1024); printf0(" -> estimated maximum batch size: %zu\n", B + free / bytes_per_sequence); } @@ -527,7 +493,7 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path, bool w // read in the parameters if weight_init is true if (weight_init) { - file_to_device(model->params_memory, model_file, model->num_parameters_bytes, IO_BUF_SIZE, main_stream); + file_to_device(model->params_memory[PARAMETER], model_file, model->num_parameters_bytes, IO_BUF_SIZE, main_stream); } fcloseCheck(model_file); @@ -657,30 +623,46 @@ void gpt_build_from_descriptor(GPT2 *model, const char* descriptor) { } */ // copy them to GPU - cudaCheck(cudaMemcpy(model->params_memory, params_memory_cpu, model->num_parameters_bytes, cudaMemcpyHostToDevice)); + cudaCheck(cudaMemcpy(model->params_memory[PARAMETER], params_memory_cpu, model->num_parameters_bytes, cudaMemcpyHostToDevice)); free(params_memory_cpu); } -#define GPU_X(x) (floatX*)((char*)model->gpu_mem + tensor_specs[x].offset) -#define GPU_F32(x) (float*)((char*)model->gpu_mem + tensor_specs[x].offset) -#define GPU_VOID(x) (void*)((char*)model->gpu_mem + tensor_specs[x].offset) +#define ACT_X(x) (floatX*)((char*)model->multiuse_memory + tensor_specs[acts.x].offset) +#define ACT_32(x) (float*)((char*)model->multiuse_memory + tensor_specs[acts.x].offset) +#define ACT_XL(x) ACT_X(x + l) +#define ACT_32L(x) ACT_32(x + l) + +#define PARAM_X(x) (floatX*)((char*)model->params_memory[PARAMETER] + tensor_specs[params.x].offset) +#define PARAM_32(x) (float*)((char*)model->params_memory[PARAMETER] + tensor_specs[params.x].offset) +#define PARAM_XL(x) PARAM_X(x + l) +#define PARAM_32L(x) PARAM_32(x + l) + +#define PGRAD_X(x) (floatX*)((char*)model->params_memory[PARAMETER_GRADIENT] + tensor_specs[grads.x].offset) +#define PGRAD_32(x) (float*)((char*)model->params_memory[PARAMETER_GRADIENT] + tensor_specs[grads.x].offset) +#define PGRAD_XL(x) PGRAD_X(x + l) +#define PGRAD_32L(x) PGRAD_32(x + l) + +#define AGRAD_X(x) (floatX*)((char*)model->multiuse_memory + tensor_specs[acts_grads.x].offset) +#define AGRAD_32(x) (float*)((char*)model->multiuse_memory + tensor_specs[acts_grads.x].offset) +#define AGRAD_XL(x) AGRAD_X(x + l) +#define AGRAD_32L(x) AGRAD_32(x + l) + +#define MULTI_X(x) (floatX*)((char*)model->multiuse_memory + tensor_specs[x].offset) +#define MULTI_32(x) (float*)((char*)model->multiuse_memory + tensor_specs[x].offset) +#define MULTI_XL(x) PARAM_X(x + l) +#define MULTI_32L(x) PARAM_32(x + l) -#define ACT_X(x) GPU_X(acts.x) -#define ACT_XL(x) GPU_X(acts.x + l) -#define ACT_32(x) GPU_F32(acts.x) -#define ACT_32L(x) GPU_F32(acts.x + l) -#define PARAM_X(x) GPU_X(params.x) -#define PARAM_XL(x) GPU_X(params.x + l) -#define PARAM_32(x) GPU_F32(params.x) -#define PARAM_32L(x) GPU_F32(params.x + l) // debug helper function void print_tensor_elements(GPT2 *model, int tensor_id) { const char* tensor_name = tensor_specs[tensor_id].name; size_t num_elements = tensor_specs[tensor_id].num_elements; + TT tensor_type = tensor_specs[tensor_id].tensor_type; DType dtype = tensor_specs[tensor_id].data_type; size_t element_size = sizeof_dtype(dtype); - void* gpu_tensor = GPU_VOID(tensor_id); + + void* gpu_memory = (tensor_id == TT::ACTIVATIONS_MULTIUSE) ? 
model->multiuse_memory : model->params_memory[tensor_type]; + void* gpu_tensor = (void*)((char*)gpu_memory + tensor_specs[tensor_id].offset); void* cpu_tensor = malloc(num_elements * element_size); cudaCheck(cudaMemcpy(cpu_tensor, gpu_tensor, num_elements * element_size, cudaMemcpyDeviceToHost)); @@ -728,7 +710,7 @@ void print_tensor_elements(GPT2 *model, int tensor_id) { void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { NVTX_RANGE_FN(); // we must be careful and use size_t instead of int, otherwise we could overflow - ParameterTensors params = model->params; + ParameterTensors params = model->params[PARAMETER]; ActivationTensors acts = model->acts; const size_t V = model->config.vocab_size; const size_t Vp = model->config.padded_vocab_size; @@ -758,7 +740,7 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { for (int l = 0; l < L; l++) { NvtxRange layer_range("Layer", l); - floatX* input_residual = GPU_X(l == 0 ? acts.encoded : (acts.residual3 + l - 1)); + floatX* input_residual = l == 0 ? ACT_X(encoded) : ACT_X(residual3 + l-1); matmul_forward_cublaslt(CUDNN_ENABLED ? ACT_XL(qkvr) : ACT_X(output), ACT_XL(ln1), PARAM_XL(qkvw), PARAM_XL(qkvb), B, T, C, 3*C, main_stream); #ifdef ENABLE_CUDNN @@ -805,7 +787,7 @@ float gpt2_validate(GPT2 *model, const int* inputs, const int* targets, size_t B cudaCheck(cudaMemset(ACT_32(losses), 0, B*T*sizeof(float))); cudaCheck(cudaMemcpy(model->targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice)); tokenCheck(targets, B*T, V); // while the memcpy is underway, validate the targets - fused_classifier(ACT_X(output), ACT_32(losses), dloss, model->targets, B, T, V, Vp, False, main_stream); + fused_classifier(ACT_X(output), ACT_X(output), ACT_32(losses), dloss, model->targets, B, T, V, Vp, False, main_stream); cudaCheck(cudaMemcpy(model->cpu_losses, ACT_32(losses), B * T * sizeof(float), cudaMemcpyDeviceToHost)); for (int i = 0; i < B*T; i++) { mean_loss += model->cpu_losses[i]; @@ -816,22 +798,17 @@ float gpt2_validate(GPT2 *model, const int* inputs, const int* targets, size_t B } void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int grad_accum_steps, int micro_step) { - if(model->grads_memory == nullptr) { + if(model->params_memory[PARAMETER_GRADIENT] == nullptr) { fprintf(stderr, "Need to allocate gradients before backward"); exit(EXIT_FAILURE); } NVTX_RANGE_FN(); - bool last_step = micro_step == grad_accum_steps - 1; - // on the first micro-step zero the gradients, as we're about to += accumulate into them - if (micro_step == 0) { - // there are currently two state vars during the gradient accumulation inner loop: - // 1) the losses accumulate += into acts.losses, reset here - // 2) the gradients accumulate += into grads_memory, reset here - cudaCheck(cudaMemsetAsync(GPU_F32(model->acts.losses), 0, model->batch_size * model->seq_len * sizeof(float), main_stream)); - cudaCheck(cudaMemsetAsync(model->grads_memory, 0, tensors_bytes[TT::PARAMETER_GRADIENT], main_stream)); - } - // convenience shortcuts, size_t instead of int so that pointer arithmetics don't overflow + // convenience shortcuts (size_t instead of int so that pointer arithmetics don't overflow) + ParameterTensors params = model->params[PARAMETER]; + ParameterTensors grads = model->params[PARAMETER_GRADIENT]; + ActivationTensors acts = model->acts; + ActivationTensors acts_grads = model->acts_grads; const size_t B = model->batch_size; const size_t T = model->seq_len; const size_t V = model->config.vocab_size; @@ 
-840,131 +817,86 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int const size_t NH = model->config.num_heads; const size_t C = model->config.channels; - ParameterTensors params = model->params; // for brevity - ParameterTensors grads = model->params_grads; - ActivationTensors acts = model->acts; + bool last_step = micro_step == grad_accum_steps - 1; + // on the first micro-step zero the gradients, as we're about to += accumulate into them + if (micro_step == 0) { + // there are currently two state vars during the gradient accumulation inner loop: + // 1) the losses accumulate += into acts.losses, reset here + // 2) the gradients accumulate += into grads_memory, reset here + cudaCheck(cudaMemsetAsync(ACT_32(losses), 0, B * T * sizeof(float), main_stream)); + cudaCheck(cudaMemsetAsync(model->params_memory[PARAMETER_GRADIENT], 0, tensors_bytes[PARAMETER_GRADIENT], main_stream)); + } // accumulate the losses inside acts.losses, and kick off the backward pass inside the fused classifier NvtxRange classifier_and_loss_range("classifier_and_loss"); const float dloss = 1.0f / (float)(B * T * grad_accum_steps); // results in the uniform average loss over all elements cudaCheck(cudaMemcpy(model->targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice)); tokenCheck(targets, B*T, V); - fused_classifier(GPU_X(acts.output), GPU_F32(acts.losses), dloss, model->targets, B, T, V, Vp, True, main_stream); - - // backward pass: go in the reverse order of the forward pass, and call backward() functions - - // reset residual stream gradients (put here to work with gradient accumulation) - // todo - this should be a dedicated tensor, not addressing multiuse.btc directly! - floatX* dresidual = GPU_X(model->multiuse.btc); // the main buffer holding the gradient in the backward pass - cudaCheck(cudaMemset(dresidual, 0, B * T * C * sizeof(floatX))); + fused_classifier(AGRAD_X(output), ACT_X(output), ACT_32(losses), dloss, model->targets, B, T, V, Vp, True, main_stream); // todo - split output & doutput // re-use the output buffer of the forward pass as a scratchpad during backward pass - float* scratchF = GPU_F32(acts.output); - floatX* scratchX = GPU_X(acts.output); + float* scratchF = ACT_32(output); + floatX* scratchX = ACT_X(output); + + // backward pass: go in the reverse order of the forward pass, and call backward() functions // we kick off the chain rule by filling in dlosses with 1.0f/(B*T) // this was done in the fused classifier kernel as last step of forward pass // technically that is a small, inline backward() pass of calculating // total, final loss as the mean over all losses over all (B,T) positions in the batch // next: backward the classifier matmul - matmul_backward(GPU_X(model->multiuse.bt4c), GPU_X(grads.wte), NULL, GPU_X(acts.output), GPU_X(acts.lnf), GPU_X(params.wte), NULL, B, T, C, Vp, main_stream); + matmul_backward(AGRAD_X(lnf), PGRAD_X(wte), NULL, ACT_X(output), ACT_X(lnf), PARAM_X(wte), NULL, B, T, C, Vp, main_stream); // backward the final layernorm - floatX* residual = GPU_X(acts.residual3 + (L-1)); // last residual is in residual3 - layernorm_backward(dresidual, GPU_X(grads.lnfw), GPU_X(grads.lnfb), scratchF, GPU_X(model->multiuse.bt4c), residual, GPU_X(params.lnfw), GPU_F32(acts.lnf_mean), GPU_F32(acts.lnf_rstd), B, T, C, main_stream); - - // from this point on, we no longer need the values stored in the last residual, so we can reuse that memory as generic - // scratch for backward computations - floatX* dl_btc = residual; + 
layernorm_backward(AGRAD_X(residual3 + L-1), NULL, PGRAD_X(lnfw), PGRAD_X(lnfb), scratchF, AGRAD_X(lnf), ACT_X(residual3 + L-1), + PARAM_X(lnfw), ACT_32(lnf_mean), ACT_32(lnf_rstd), B, T, C, main_stream); // now backward all the layers for (int l = L-1; l >= 0; l--) { NvtxRange layer_range("Layer", l); - - residual = GPU_X(l == 0 ? acts.encoded : acts.residual3 + (l-1)); - - // get the pointers of the weights for this layer - floatX* l_ln1w = GPU_X(params.ln1w + l); - floatX* l_ln1b = GPU_X(params.ln1b + l); - floatX* l_qkvw = GPU_X(params.qkvw + l); - floatX* l_attprojw = GPU_X(params.attprojw + l); - floatX* l_ln2w = GPU_X(params.ln2w + l); - floatX* l_ln2b = GPU_X(params.ln2b + l); - floatX* l_fcw = GPU_X(params.fcw + l); - floatX* l_fcprojw = GPU_X(params.fcprojw + l); - // get the pointers of the gradients of the weights for this layer - floatX* dl_ln1w = GPU_X(grads.ln1w + l); - floatX* dl_ln1b = GPU_X(grads.ln1b + l); - floatX* dl_qkvw = GPU_X(grads.qkvw + l); - floatX* dl_qkvb = GPU_X(grads.qkvb + l); - floatX* dl_attprojw = GPU_X(grads.attprojw + l); - floatX* dl_attprojb = GPU_X(grads.attprojb + l); - floatX* dl_ln2w = GPU_X(grads.ln2w + l); - floatX* dl_ln2b = GPU_X(grads.ln2b + l); - floatX* dl_fcw = GPU_X(grads.fcw + l); - floatX* dl_fcb = GPU_X(grads.fcb + l); - floatX* dl_fcprojw = GPU_X(grads.fcprojw + l); - floatX* dl_fcprojb = GPU_X(grads.fcprojb + l); - // get the pointers of the activations for this layer - floatX* l_ln1 = GPU_X(acts.ln1 + l); - float* l_ln1_mean = GPU_F32(acts.ln1_mean + l); - float* l_ln1_rstd = GPU_F32(acts.ln1_rstd + l); - floatX* l_qkvr = GPU_X(acts.qkvr + l); - floatX* l_atty = GPU_X(acts.atty + l); - floatX* l_residual2 = GPU_X(acts.residual2 + l); - floatX* l_ln2 = GPU_X(acts.ln2 + l); - float* l_ln2_mean = GPU_F32(acts.ln2_mean + l); - float* l_ln2_rstd = GPU_F32(acts.ln2_rstd + l); - floatX* l_fch_pre_gelu = GPU_X(acts.fch + l); - floatX* l_fch_gelu = GPU_X(acts.fch_gelu + l); - // get the pointers of the gradients of the activations for this layer - // notice that there is no l *, because we just have a single copy, and keep - // re-using this memory in every Transformer block as we calculate backward pass - - floatX* dl_bt4c = (floatX*)GPU_X(model->multiuse.bt4c); // TODO - should use dedicated tensors! + floatX* residual = (l == 0) ? ACT_X(encoded) : ACT_X(residual3 + (l-1)); + floatX* dresidual = (l == 0) ? AGRAD_X(encoded) : AGRAD_X(residual3 + (l-1)); // start the backward pass for this layer if(model->recompute >= 1) { // recompute >= 1 means we recompute gelu. 
in this case, // l_fch_gelu is just a buffer, so re-compute the gelu from l_fch here - gelu_forward(l_fch_gelu, l_fch_pre_gelu, B*T*4*C, main_stream); + gelu_forward(ACT_XL(fch_gelu), ACT_XL(fch), B*T*4*C, main_stream); } - matmul_backward(dl_bt4c, dl_fcprojw, dl_fcprojb, dresidual, l_fch_gelu, l_fcprojw, scratchF, B, T, 4*C, C, main_stream, l_fch_pre_gelu, model->gelu_fusion); if(model->recompute >= 2) { // same as gelu above, l_ln1 and l_ln2 are just buffers if recompute >= 2, recompute them here on demand - layernorm_forward(l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C, main_stream); + layernorm_forward(ACT_XL(ln2), ACT_32L(ln2_mean), ACT_32L(ln2_rstd), ACT_XL(residual2), PARAM_XL(ln2w), PARAM_XL(ln2b), B, T, C, main_stream); + layernorm_forward(ACT_XL(ln1), ACT_32L(ln1_mean), ACT_32L(ln1_rstd), residual, PARAM_XL(ln1w), PARAM_XL(ln1b), B, T, C, main_stream); } - matmul_backward(dl_btc, dl_fcw, dl_fcb, dl_bt4c, l_ln2, l_fcw, scratchF, B, T, C, 4 * C, main_stream); + + matmul_backward(AGRAD_XL(fch), PGRAD_XL(fcprojw), PGRAD_XL(fcprojb), AGRAD_XL(residual3), ACT_XL(fch_gelu), PARAM_XL(fcprojw), scratchF, B, T, 4*C, C, main_stream, ACT_XL(fch), model->gelu_fusion); + matmul_backward(AGRAD_XL(ln2), PGRAD_XL(fcw), PGRAD_XL(fcb), AGRAD_XL(fch), ACT_XL(ln2), PARAM_XL(fcw), scratchF, B, T, C, 4 * C, main_stream); // layernorm backward does += to the dresidual, so it correctly accumulates grad from the MLP block above - layernorm_backward(dresidual, dl_ln2w, dl_ln2b, scratchF, dl_btc, l_residual2, l_ln2w, l_ln2_mean, l_ln2_rstd, B, T, C, main_stream); - matmul_backward(dl_btc, dl_attprojw, dl_attprojb, dresidual, l_atty, l_attprojw, scratchF, B, T, C, C, main_stream); + layernorm_backward(AGRAD_XL(residual2), AGRAD_XL(residual3), PGRAD_XL(ln2w), PGRAD_XL(ln2b), scratchF, AGRAD_XL(ln2), ACT_XL(residual2), PARAM_XL(ln2w), ACT_32L(ln2_mean), ACT_32L(ln2_rstd), B, T, C, main_stream); + matmul_backward(AGRAD_XL(atty), PGRAD_XL(attprojw), PGRAD_XL(attprojb), AGRAD_XL(residual2), ACT_XL(atty), PARAM_XL(attprojw), scratchF, B, T, C, C, main_stream); #ifdef ENABLE_CUDNN float* l_att = (float*)acts.att + l * B * NH * T; // cuDNN needs a smaller FP32 tensor - attention_backward_cudnn(dl_bt4c, dl_btc, l_qkvr, l_atty, (float*)l_att, B, T, NH, C, main_stream); + attention_backward_cudnn(AGRAD_XL(qkvr), AGRAD_XL(atty), ACT_XL(qkvr), ACT_XL(atty), (float*)l_att, B, T, NH, C, main_stream); #else - floatX* l_att = GPU_X(acts.att + l); // we need B x T x (4)C buffers. 
l_atty and l_fch aren't needed anymore at this point, so reuse their memory - floatX* buffer_a = l_atty; - floatX* buffer_b = l_fch_pre_gelu; // this is B x T x 4C, so even larger than what we need - attention_backward(dl_bt4c, buffer_b, scratchX, buffer_a, dl_btc, l_qkvr, l_att, B, T, C, NH, main_stream); + floatX* buffer_a = ACT_XL(atty); + floatX* buffer_b = ACT_XL(fch); // this is B x T x 4C, so even larger than what we need + attention_backward(AGRAD_XL(qkvr), buffer_b, scratchX, buffer_a, AGRAD_XL(atty), ACT_XL(qkvr), ACT_XL(att), B, T, C, NH, main_stream); #endif - if(model->recompute >= 2) { - layernorm_forward(l_ln1, l_ln1_mean, l_ln1_rstd, residual, l_ln1w, l_ln1b, B, T, C, main_stream); - } // QKV parameter gradients - matmul_backward(dl_btc, dl_qkvw, dl_qkvb, dl_bt4c, l_ln1, l_qkvw, scratchF, B, T, C, 3 * C, main_stream); + matmul_backward(AGRAD_XL(ln1), PGRAD_XL(qkvw), PGRAD_XL(qkvb), AGRAD_XL(qkvr), ACT_XL(ln1), PARAM_XL(qkvw), scratchF, B, T, C, 3 * C, main_stream); // layernorm backward does += to dresidual, so it correctly accumulates gradient for the Attention block above - layernorm_backward(dresidual, dl_ln1w, dl_ln1b, scratchF, dl_btc, residual, l_ln1w, l_ln1_mean, l_ln1_rstd, B, T, C, main_stream); + layernorm_backward(dresidual, AGRAD_XL(residual2), PGRAD_XL(ln1w), PGRAD_XL(ln1b), scratchF, AGRAD_XL(ln1), residual, PARAM_XL(ln1w), ACT_32L(ln1_mean), ACT_32L(ln1_rstd), B, T, C, main_stream); // Accumulate gradients from this layer in a background stream. if(last_step) { floatX* const pointers[] = { - dl_ln1w, dl_ln1b, - dl_qkvw, dl_qkvb, - dl_attprojw, dl_attprojb, - dl_ln2w, dl_ln2b, - dl_fcw, dl_fcb, - dl_fcprojw, dl_fcprojb + PGRAD_XL(ln1w), PGRAD_XL(ln1b), + PGRAD_XL(qkvw), PGRAD_XL(qkvb), + PGRAD_XL(attprojw), PGRAD_XL(attprojb), + PGRAD_XL(ln2w), PGRAD_XL(ln2b), + PGRAD_XL(fcw), PGRAD_XL(fcb), + PGRAD_XL(fcprojw), PGRAD_XL(fcprojb) }; const size_t nelem[] = { C, C, @@ -977,20 +909,21 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int multi_gpu_async_reduce_gradient(pointers, nelem, &multi_gpu_config, main_stream); } } - encoder_backward(GPU_X(grads.wte), GPU_X(grads.wpe), scratchX, model->workload_indices, model->bucket_info, - dresidual, model->inputs, inputs, B, T, C, random_u32(&model->rng_state), main_stream); + + encoder_backward(PGRAD_X(wte), PGRAD_X(wpe), scratchX, model->workload_indices, model->bucket_info, + AGRAD_X(encoded), model->inputs, inputs, B, T, C, random_u32(&model->rng_state), main_stream); // Aggregate all gradients that are not part of the transformer blocks if(last_step) { // reduce all the losses within the current GPU (across all microsteps) - global_sum_deterministic(model->accumulated_mean_loss, GPU_F32(acts.losses), B*T, main_stream); + global_sum_deterministic(model->accumulated_mean_loss, ACT_32(losses), B*T, main_stream); // reduce loss across GPUs to a single, final float across all microsteps and GPUs #if MULTI_GPU ncclCheck(ncclAllReduce(model->accumulated_mean_loss, model->accumulated_mean_loss, sizeof(float), ncclFloat, ncclAvg, multi_gpu_config.nccl_comm, main_stream)); #endif cudaCheck(cudaMemcpyAsync(&model->mean_loss, model->accumulated_mean_loss, sizeof(float), cudaMemcpyDeviceToHost, main_stream)); // reduce the gradients for non-transformer block parameters - floatX* const pointers[] = {GPU_X(grads.wte), GPU_X(grads.wpe), GPU_X(grads.lnfw), GPU_X(grads.lnfb)}; + floatX* const pointers[] = {PGRAD_X(wte), PGRAD_X(wpe), PGRAD_X(lnfw), PGRAD_X(lnfb)}; const size_t nelem[] = {Vp * C, T * C, 
C, C}; multi_gpu_async_reduce_gradient(pointers, nelem, &multi_gpu_config, main_stream); } @@ -1024,10 +957,11 @@ ShardInfo gpt2_get_tensor_at_layer(const GPT2 *model, int layer_id, int param_te float gpt2_calculate_grad_norm(GPT2 *model, MultiGpuConfig* multi_gpu_config) { NVTX_RANGE_FN(); - floatX* grads_memory = (floatX*)model->grads_memory; + floatX* grads_memory = (floatX*)model->params_memory[PARAMETER_GRADIENT]; + ActivationTensors acts = model->acts; // repurposing this buffer (which isn't needed now) to write grad norm into it - float* grad_norm_squared = GPU_F32(model->acts.output); + float* grad_norm_squared = ACT_32(output); float grad_norm_squared_cpu = 0.0f; int num_slices[2] = {1, model->config.num_layers}; @@ -1077,7 +1011,7 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo // selectively weight decay some, but not all tensors :( // TODO: revisit and probably refactor this entire function NVTX_RANGE_FN(); - if(model->grads_memory == nullptr || model->m_memory == nullptr || model->v_memory == nullptr) { + if(model->params_memory[PARAMETER] == nullptr || model->params_memory[PARAMETER_OPT_M] == nullptr || model->params_memory[PARAMETER_OPT_V] == nullptr) { fprintf(stderr, "Need to allocate optimizer state before update"); exit(EXIT_FAILURE); } @@ -1086,8 +1020,8 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo if(init_state) { model->init_state = false; NvtxRange rng("InitOpt"); - cudaCheck(cudaMemset(model->m_memory, 0, multi_gpu_config->shard_num_parameters * sizeof(float))); - cudaCheck(cudaMemset(model->v_memory, 0, multi_gpu_config->shard_num_parameters * sizeof(float))); + cudaCheck(cudaMemset(model->params_memory[PARAMETER_OPT_M], 0, multi_gpu_config->shard_num_parameters * sizeof(float))); + cudaCheck(cudaMemset(model->params_memory[PARAMETER_OPT_V], 0, multi_gpu_config->shard_num_parameters * sizeof(float))); } // save RNG state at this point so we can round from master weights identically when restoring from a checkpoint @@ -1191,7 +1125,11 @@ float gpt2_estimate_mfu(GPT2 *model, int num_tokens, float dt) { } void gpt2_free(GPT2 *model) { - cudaFreeCheck(&model->gpu_mem); + cudaFreeCheck(&model->multiuse_memory); + for (int i = 0; i < TT::NUM_TYPES_PARAM; i++) { + cudaFreeCheck(&model->params_memory[i]); + } + cudaFreeCheck(&model->inputs); cudaFreeCheck(&model->targets); cudaFreeCheck(&model->accumulated_mean_loss); @@ -1263,10 +1201,10 @@ void save_state(const char* filename, int step, GPT2* model, DataLoader* loader) // write AdamW m, v, and master_weights here (they are all float) size_t shard_num_parameters = multi_gpu_config.shard_num_parameters; - device_to_file(state_file, model->m_memory, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); - device_to_file(state_file, model->v_memory, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + device_to_file(state_file, model->params_memory[PARAMETER_OPT_M], shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + device_to_file(state_file, model->params_memory[PARAMETER_OPT_V], shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); if(model->use_master_weights) { - device_to_file(state_file, model->master_weights, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + device_to_file(state_file, model->params_memory[PARAMETER_MASTER], shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); } // write dataloader state if we are using the Permuted version of it @@ -1307,13 +1245,13 
@@ void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename } model->init_state = false; // we just got the state from file, no need to do first-touch init - assert(model->m_memory != nullptr); - assert(model->v_memory != nullptr); - file_to_device(model->m_memory, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); - file_to_device(model->v_memory, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + assert(model->params_memory[PARAMETER_OPT_M] != nullptr); + assert(model->params_memory[PARAMETER_OPT_V] != nullptr); + file_to_device(model->params_memory[PARAMETER_OPT_M], state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + file_to_device(model->params_memory[PARAMETER_OPT_V], state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); if(model->use_master_weights) { - assert(model->master_weights != nullptr); - file_to_device(model->master_weights, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + assert(model->params_memory[PARAMETER_MASTER] != nullptr); + file_to_device(model->params_memory[PARAMETER_MASTER], state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); // restore weights from the master weights using the RNG state before last weight update model->rng_state = model->rng_state_last_update; gpt2_update(model, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0, &multi_gpu_config, /* init_from_master_only*/ true); @@ -1806,7 +1744,7 @@ int main(int argc, char *argv[]) { // note this is still somewhat wasteful because we don't have a KV cache! gpt2_forward(&model, gen_tokens, 1, CEIL_DIV(t, min(T,256)) * min(T,256)); // get the V-dimensional vector probs[0, t-1, :] - floatX* logits = (floatX*)(model.gpu_mem + tensors_start[model.acts.output] + (t - 1) * model.config.padded_vocab_size); + floatX* logits = ((floatX*)model.multiuse_memory) + tensor_specs[model.acts.output].offset + (t - 1) * model.config.padded_vocab_size; // move probs back to CPU and sample (note we only move the first vocab_size logits, ignoring the padding) cudaCheck(cudaMemcpy(cpu_logits_raw, logits, model.config.vocab_size * sizeof(floatX), cudaMemcpyDeviceToHost)); // convert to FP32 into cpu_logits (this does nothing useful if floatX == float) From e27385d802d50b373929d53f73e61b2c8edbf044 Mon Sep 17 00:00:00 2001 From: ademeure Date: Mon, 2 Sep 2024 01:20:31 +0000 Subject: [PATCH 05/27] It's alive!!! gpt2_update() is working as is forward/backward/recompute/etc... still lots of work left though! 
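For orientation: every PARAM_X / PGRAD_X / ACT_X / AGRAD_X macro used in the hunks above resolves an integer tensor id into a typed device pointer by adding the byte offset recorded in its TensorSpec to the base allocation for that tensor type, and the *_XL variants simply index id + l for layer l. Roughly, per the macro definitions in train_gpt2.cu at this stage (a paraphrase of those macros, not new code in the patch):

    // with ParameterTensors params / ActivationTensors acts holding the layer-0 spec ids:
    floatX* wte_ptr   = (floatX*)((char*)model->params_memory[PARAMETER] + tensor_specs[params.wte].offset);   // PARAM_X(wte)
    floatX* l_fch_ptr = (floatX*)((char*)model->multiuse_memory          + tensor_specs[acts.fch + l].offset); // ACT_XL(fch)

Gradients follow the same pattern against params_memory[PARAMETER_GRADIENT] (PGRAD_*) and the activation-gradient specs in multiuse memory (AGRAD_*).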
--- llmc/fused_classifier.cuh | 4 +- train_gpt2.cu | 121 +++++++++++++++++++++++++++----------- 2 files changed, 88 insertions(+), 37 deletions(-) diff --git a/llmc/fused_classifier.cuh b/llmc/fused_classifier.cuh index ba5a1e554..c01bbf578 100644 --- a/llmc/fused_classifier.cuh +++ b/llmc/fused_classifier.cuh @@ -67,7 +67,7 @@ __device__ SoftmaxParams prepare_softmax_blockwide3(int64_t idx, const floatX* i // split both loops in "multiple-of-x128-size" and "bounds-checked remainder" parts template __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS) - fused_classifier_kernel5(floatX* dlogits, floatX* logits, float* losses, floatX* probs, + fused_classifier_kernel5(floatX* dlogits, const floatX* logits, float* losses, floatX* probs, const float dloss, const int* targets, int B, int T, int V, int P, std::bool_constant) { // note: idx is small enough that it easily fits into 32 bit; @@ -137,7 +137,7 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS) // replaces logits with logit gradients template -void fused_classifier(Type* dlogits, Type* logits, float* losses, +void fused_classifier(Type* dlogits, const Type* logits, float* losses, const float dloss, const int* targets, int B, int T, int V, int P, std::bool_constant write_dlogits, cudaStream_t stream) { NVTX_RANGE_FN(); diff --git a/train_gpt2.cu b/train_gpt2.cu index ed71e6afc..f84f15762 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -1,7 +1,7 @@ /* GPT-2 Transformer Neural Net training loop. See README.md for usage. */ -bool UNIQUE_TENSOR_MEMORY = true; +bool UNIQUE_TENSOR_MEMORY = false; #include #include @@ -107,6 +107,7 @@ typedef struct { typedef struct { int bt4c; // (B, T, 4*C) int btc; // (B, T, C) + int local_scratch; // (B, T, C) } MultiuseTensors; typedef struct { @@ -128,8 +129,8 @@ typedef struct { size_t num_parameters; size_t num_parameters_bytes; - void* multiuse_memory = NULL; - void* params_memory[NUM_TYPES_PARAM] = {0}; + char* multiuse_memory = NULL; + char* params_memory[NUM_TYPES_PARAM] = {0}; // other run state configuration int batch_size = 0; // the batch size (B) of current forward pass @@ -156,6 +157,7 @@ typedef struct { size_t offset; // into base pointer size_t num_elements; // per shard size_t num_shards; + int remaining_layers; DType data_type; TT tensor_type; } TensorSpec; @@ -175,6 +177,7 @@ int add_tensor_spec(const char* name, size_t total_elements, size_t num_shards, strncpy(spec->name, name, 16); spec->num_elements = total_elements / num_shards; spec->num_shards = num_shards; + spec->remaining_layers = 0; spec->data_type = data_type; spec->tensor_type = (tensor_type == TT::DEFAULT) ? current_tensor_type : tensor_type; tensors_elements[spec->tensor_type] += spec->num_elements; @@ -201,10 +204,11 @@ int add_layer_specs(int num_layers, const char* name, size_t total_elements, siz for (int l = 0; l < num_layers; l++) { char layer_name[16]; assert(snprintf(layer_name, 16, "%s_%d", name, l) >= 0); - add_tensor_spec(num_layers > 1 ? layer_name : name, total_elements, num_shards, data_type, copy_offset_from, tensor_type); + int spec = add_tensor_spec(num_layers > 1 ? 
layer_name : name, total_elements, num_shards, data_type, copy_offset_from, tensor_type); if (copy_per_layer) { copy_offset_from++; } + tensor_specs[spec].remaining_layers = num_layers - (l + 1); } return first_tensor_id; } @@ -268,6 +272,7 @@ void gpt2_allocate(GPT2 *model) { } else*/ { model->multiuse.bt4c = add_tensor_spec("multiuse_bt4c", 4 * BTC, 1, DTYPE_FLOATX); model->multiuse.btc = add_tensor_spec("multiuse_btc", BTC, 1, DTYPE_FLOATX); + model->multiuse.local_scratch = add_tensor_spec("local_scratch", BTC, 1, DType::FP32); // todo - is this oversized? } // 3) activations @@ -317,8 +322,8 @@ void gpt2_allocate(GPT2 *model) { TENSOR_SPECS(ln2, L, BTC); spec->fch_gelu = add_layer_specs(L, "fch_gelu", 4 * BTC, shards, dtype_lowp, model->acts.output); } else { - spec->ln1 = add_layer_specs(L, "ln1", BTC, shards, dtype, model->multiuse.btc); // todo - not OK for backwards - spec->ln2 = add_layer_specs(L, "ln2", BTC, shards, dtype, model->multiuse.btc); // todo - not OK for backwards + spec->ln1 = add_layer_specs(L, "ln1", BTC, shards, dtype, model->acts.lnf); + spec->ln2 = add_layer_specs(L, "ln2", BTC, shards, dtype, model->acts.lnf); spec->fch_gelu = add_layer_specs(L, "fch_gelu", 4 * BTC, shards, dtype_lowp, model->acts.output); } @@ -341,14 +346,18 @@ void gpt2_allocate(GPT2 *model) { TENSOR_SPECS(qkvr, L, 3 * BTC); } else { spec->output = add_layer_specs(1, "output", output_size, 1, dtype, model->acts.output); + + int reused_btc = model->acts.residual3 + (L-1); + spec->ln1 = add_layer_specs(L, "ln1", BTC, 1, dtype, reused_btc); + spec->atty = add_layer_specs(L, "atty", BTC, 1, dtype, reused_btc); + spec->ln2 = add_layer_specs(L, "ln2", BTC, 1, dtype, reused_btc); + spec->lnf = add_layer_specs(1, "lnf", BTC, 1, dtype, model->multiuse.btc); - spec->ln1 = add_layer_specs(L, "ln1", BTC, 1, dtype, model->multiuse.btc); - spec->atty = add_layer_specs(L, "atty", BTC, 1, dtype, model->multiuse.btc); + spec->encoded = add_layer_specs(1, "encoded", BTC, 1, dtype, model->multiuse.btc); spec->residual2 = add_layer_specs(L, "residual2", BTC, 1, dtype, model->multiuse.btc); - spec->ln2 = add_layer_specs(L, "ln2", BTC, 1, dtype, model->multiuse.btc); + spec->residual3 = add_layer_specs(L, "residual3", BTC, 1, dtype, model->multiuse.btc); spec->fch = add_layer_specs(L, "fch", 4 * BTC, 1, dtype, model->multiuse.bt4c); spec->fch_gelu = add_layer_specs(L, "fch_gelu", 4 * BTC, 1, dtype, model->multiuse.bt4c); - spec->residual3 = add_layer_specs(L, "residual3", BTC, 1, dtype, model->multiuse.btc); spec->qkvr = add_layer_specs(L, "qkvr", 3 * BTC, 1, dtype, model->multiuse.bt4c); } @@ -834,9 +843,9 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int tokenCheck(targets, B*T, V); fused_classifier(AGRAD_X(output), ACT_X(output), ACT_32(losses), dloss, model->targets, B, T, V, Vp, True, main_stream); // todo - split output & doutput - // re-use the output buffer of the forward pass as a scratchpad during backward pass - float* scratchF = ACT_32(output); - floatX* scratchX = ACT_X(output); + // re-use the output buffer of the forward pass as a scratchpad during backward pass + dedicated buffer + float* scratchF = MULTI_32(model->multiuse.local_scratch); + floatX* scratchX_HUGE = ACT_X(output); // backward pass: go in the reverse order of the forward pass, and call backward() functions @@ -845,7 +854,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int // technically that is a small, inline backward() pass of calculating // total, final loss as the 
mean over all losses over all (B,T) positions in the batch // next: backward the classifier matmul - matmul_backward(AGRAD_X(lnf), PGRAD_X(wte), NULL, ACT_X(output), ACT_X(lnf), PARAM_X(wte), NULL, B, T, C, Vp, main_stream); + matmul_backward(AGRAD_X(lnf), PGRAD_X(wte), NULL, AGRAD_X(output), ACT_X(lnf), PARAM_X(wte), NULL, B, T, C, Vp, main_stream); // backward the final layernorm layernorm_backward(AGRAD_X(residual3 + L-1), NULL, PGRAD_X(lnfw), PGRAD_X(lnfb), scratchF, AGRAD_X(lnf), ACT_X(residual3 + L-1), PARAM_X(lnfw), ACT_32(lnf_mean), ACT_32(lnf_rstd), B, T, C, main_stream); @@ -856,36 +865,31 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int floatX* residual = (l == 0) ? ACT_X(encoded) : ACT_X(residual3 + (l-1)); floatX* dresidual = (l == 0) ? AGRAD_X(encoded) : AGRAD_X(residual3 + (l-1)); - // start the backward pass for this layer - if(model->recompute >= 1) { - // recompute >= 1 means we recompute gelu. in this case, - // l_fch_gelu is just a buffer, so re-compute the gelu from l_fch here + if(model->recompute >= 1) { // recompute >= 1 means we recompute gelu gelu_forward(ACT_XL(fch_gelu), ACT_XL(fch), B*T*4*C, main_stream); } - if(model->recompute >= 2) { - // same as gelu above, l_ln1 and l_ln2 are just buffers if recompute >= 2, recompute them here on demand + matmul_backward(AGRAD_XL(fch), PGRAD_XL(fcprojw), PGRAD_XL(fcprojb), AGRAD_XL(residual3), ACT_XL(fch_gelu), PARAM_XL(fcprojw), scratchF, B, T, 4*C, C, main_stream, ACT_XL(fch), model->gelu_fusion); + + if(model->recompute >= 2) { // recompute >= 2 means we recompute layernorm layernorm_forward(ACT_XL(ln2), ACT_32L(ln2_mean), ACT_32L(ln2_rstd), ACT_XL(residual2), PARAM_XL(ln2w), PARAM_XL(ln2b), B, T, C, main_stream); - layernorm_forward(ACT_XL(ln1), ACT_32L(ln1_mean), ACT_32L(ln1_rstd), residual, PARAM_XL(ln1w), PARAM_XL(ln1b), B, T, C, main_stream); } - - matmul_backward(AGRAD_XL(fch), PGRAD_XL(fcprojw), PGRAD_XL(fcprojb), AGRAD_XL(residual3), ACT_XL(fch_gelu), PARAM_XL(fcprojw), scratchF, B, T, 4*C, C, main_stream, ACT_XL(fch), model->gelu_fusion); matmul_backward(AGRAD_XL(ln2), PGRAD_XL(fcw), PGRAD_XL(fcb), AGRAD_XL(fch), ACT_XL(ln2), PARAM_XL(fcw), scratchF, B, T, C, 4 * C, main_stream); - // layernorm backward does += to the dresidual, so it correctly accumulates grad from the MLP block above layernorm_backward(AGRAD_XL(residual2), AGRAD_XL(residual3), PGRAD_XL(ln2w), PGRAD_XL(ln2b), scratchF, AGRAD_XL(ln2), ACT_XL(residual2), PARAM_XL(ln2w), ACT_32L(ln2_mean), ACT_32L(ln2_rstd), B, T, C, main_stream); matmul_backward(AGRAD_XL(atty), PGRAD_XL(attprojw), PGRAD_XL(attprojb), AGRAD_XL(residual2), ACT_XL(atty), PARAM_XL(attprojw), scratchF, B, T, C, C, main_stream); #ifdef ENABLE_CUDNN - float* l_att = (float*)acts.att + l * B * NH * T; // cuDNN needs a smaller FP32 tensor - attention_backward_cudnn(AGRAD_XL(qkvr), AGRAD_XL(atty), ACT_XL(qkvr), ACT_XL(atty), (float*)l_att, B, T, NH, C, main_stream); + attention_backward_cudnn(AGRAD_XL(qkvr), AGRAD_XL(atty), ACT_XL(qkvr), ACT_XL(atty), ACT_32L(att), B, T, NH, C, main_stream); #else // we need B x T x (4)C buffers. 
l_atty and l_fch aren't needed anymore at this point, so reuse their memory floatX* buffer_a = ACT_XL(atty); - floatX* buffer_b = ACT_XL(fch); // this is B x T x 4C, so even larger than what we need - attention_backward(AGRAD_XL(qkvr), buffer_b, scratchX, buffer_a, AGRAD_XL(atty), ACT_XL(qkvr), ACT_XL(att), B, T, C, NH, main_stream); + floatX* buffer_b = ACT_XL(fch); + attention_backward(AGRAD_XL(qkvr), buffer_b, scratchX_HUGE, buffer_a, AGRAD_XL(atty), ACT_XL(qkvr), ACT_XL(att), B, T, C, NH, main_stream); #endif - // QKV parameter gradients + + if(model->recompute >= 2) { + layernorm_forward(ACT_XL(ln1), ACT_32L(ln1_mean), ACT_32L(ln1_rstd), residual, PARAM_XL(ln1w), PARAM_XL(ln1b), B, T, C, main_stream); + } matmul_backward(AGRAD_XL(ln1), PGRAD_XL(qkvw), PGRAD_XL(qkvb), AGRAD_XL(qkvr), ACT_XL(ln1), PARAM_XL(qkvw), scratchF, B, T, C, 3 * C, main_stream); - // layernorm backward does += to dresidual, so it correctly accumulates gradient for the Attention block above layernorm_backward(dresidual, AGRAD_XL(residual2), PGRAD_XL(ln1w), PGRAD_XL(ln1b), scratchF, AGRAD_XL(ln1), residual, PARAM_XL(ln1w), ACT_32L(ln1_mean), ACT_32L(ln1_rstd), B, T, C, main_stream); // Accumulate gradients from this layer in a background stream. @@ -910,7 +914,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int } } - encoder_backward(PGRAD_X(wte), PGRAD_X(wpe), scratchX, model->workload_indices, model->bucket_info, + encoder_backward(PGRAD_X(wte), PGRAD_X(wpe), scratchX_HUGE, model->workload_indices, model->bucket_info, AGRAD_X(encoded), model->inputs, inputs, B, T, C, random_u32(&model->rng_state), main_stream); // Aggregate all gradients that are not part of the transformer blocks @@ -1002,8 +1006,6 @@ float gpt2_calculate_grad_norm(GPT2 *model, MultiGpuConfig* multi_gpu_config) { void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, float grad_scale, int t, MultiGpuConfig* multi_gpu_config, bool init_from_master_only=false) { - return; // debugging forward only for now - // update the model parameters using the AdamW optimizer // keep in mind that optimizer sharding (ZeRO-1) assigns different parameters to different GPUs // so we may not be responsible for the entire parameter tensor @@ -1027,6 +1029,55 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo // save RNG state at this point so we can round from master weights identically when restoring from a checkpoint model->rng_state_last_update = model->rng_state; + // todo: merge all tensors into 1 kerne + for (int i = 0; i < tensors_start[PARAMETER_GRADIENT];) { + unsigned int seed = random_u32(&model->rng_state); + + TensorSpec param_spec = tensor_specs[i]; + TensorSpec grad_spec = tensor_specs[i + tensors_start[PARAMETER_GRADIENT]]; + TensorSpec master_spec = tensor_specs[i + tensors_start[PARAMETER_MASTER]]; + TensorSpec opt_m_spec = tensor_specs[i + tensors_start[PARAMETER_OPT_M]]; + TensorSpec opt_v_spec = tensor_specs[i + tensors_start[PARAMETER_OPT_V]]; + + floatX* param_ptr = (floatX*)(&model->params_memory[PARAMETER][param_spec.offset]); + floatX* grad_ptr = (floatX*)(&model->params_memory[PARAMETER_GRADIENT][grad_spec.offset]); + float* m_ptr = (float*)(&model->params_memory[PARAMETER_OPT_M][opt_m_spec.offset]); + float* v_ptr = (float*)(&model->params_memory[PARAMETER_OPT_V][opt_v_spec.offset]); + + float* master_ptr = NULL; + if (model->params_memory[PARAMETER_MASTER] != NULL) { + master_ptr = 
(float*)(&model->params_memory[PARAMETER_MASTER][master_spec.offset]); + } + + size_t tensor_elements = param_spec.num_elements; + size_t shard_elements = master_spec.num_elements; + int num_layers = param_spec.remaining_layers + 1; + + if(init_state && model->use_master_weights) { + size_t grid_size = CEIL_DIV(shard_elements, 512); + copy_and_cast_kernel<<>>(master_ptr, param_ptr, shard_elements, shard_elements, tensor_elements); + cudaCheck(cudaGetLastError()); + } + + // hack - todo - 2D tensors only check... + float wd = (param_spec.num_elements > (4 * model->config.channels)) ? weight_decay : 0.0f; + + if (init_from_master_only) { + // when resuming training from a checkpoint with master weights (allows changing precision) + //init_from_master(param_ptr, master_ptr, shard.size, tensor.size, shard.size, num_layers, seed, main_stream); + assert(false); + } else { + // ok finally call the kernel to update the weights with AdamW + adamw_update(param_ptr, master_ptr, grad_ptr, + m_ptr, v_ptr, + shard_elements, tensor_elements, tensor_elements, shard_elements, num_layers, + learning_rate, + beta1, beta2, t, eps, wd, grad_scale, seed, main_stream); + } + + i += num_layers; + } + // AdamW update // handle adamw for all the transformer blocks /* @@ -1742,9 +1793,9 @@ int main(int argc, char *argv[]) { // on cuDNN 9.2.1 with cuDNN FrontEnd 1.5.2, T >= 256 seems bit-for-bit identical // (but even if it wasn't fully identical that's probably not the end of the world) // note this is still somewhat wasteful because we don't have a KV cache! - gpt2_forward(&model, gen_tokens, 1, CEIL_DIV(t, min(T,256)) * min(T,256)); + gpt2_forward(&model, gen_tokens, 1, T); // get the V-dimensional vector probs[0, t-1, :] - floatX* logits = ((floatX*)model.multiuse_memory) + tensor_specs[model.acts.output].offset + (t - 1) * model.config.padded_vocab_size; + floatX* logits = ((floatX*)&model.multiuse_memory[tensor_specs[model.acts.output].offset]) + (t - 1) * model.config.padded_vocab_size; // move probs back to CPU and sample (note we only move the first vocab_size logits, ignoring the padding) cudaCheck(cudaMemcpy(cpu_logits_raw, logits, model.config.vocab_size * sizeof(floatX), cudaMemcpyDeviceToHost)); // convert to FP32 into cpu_logits (this does nothing useful if floatX == float) From 0c4b1e1ca60097f1d9f2c9721094e45eb0fa2cca Mon Sep 17 00:00:00 2001 From: ademeure Date: Mon, 2 Sep 2024 15:22:07 +0000 Subject: [PATCH 06/27] Refactoring into nicer ACT/PARAM_L/etc. macros with get_tensor and autocasting + make main_stream default. New gpt2_forward() feels clean! 
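For reference, the per-element update that adamw_update() applies inside the gpt2_update() loop above is standard AdamW; the real kernel additionally maintains the float master weights and stochastically rounds the low-precision parameter copy using the per-tensor seed. A minimal CPU sketch of just the math (illustrative, not the kernel itself):

    #include <math.h>

    // one AdamW step for a single parameter element at optimizer step t (t >= 1)
    void adamw_element(float* p, float* m, float* v, float grad, int t,
                       float lr, float beta1, float beta2, float eps, float wd, float grad_scale) {
        float g = grad * grad_scale;                        // grad_scale comes from gradient-norm clipping
        *m = beta1 * (*m) + (1.0f - beta1) * g;             // first moment
        *v = beta2 * (*v) + (1.0f - beta2) * g * g;         // second moment
        float m_hat = *m / (1.0f - powf(beta1, (float)t));  // bias correction
        float v_hat = *v / (1.0f - powf(beta2, (float)t));
        *p -= lr * (m_hat / (sqrtf(v_hat) + eps) + wd * (*p));
    }

The wd argument corresponds to the weight_decay-or-zero value chosen by the 2D-tensor check in the loop above.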
--- llmc/adamw.cuh | 4 +- llmc/attention.cuh | 4 +- llmc/cuda_common.h | 7 +- llmc/cudnn_att.cpp | 4 +- llmc/encoder.cuh | 4 +- llmc/fused_classifier.cuh | 13 +- llmc/gelu.cuh | 4 +- llmc/global_norm.cuh | 2 +- llmc/layernorm.cuh | 19 ++- llmc/matmul.cuh | 28 ++-- llmc/zero.cuh | 1 - train_gpt2.cu | 301 ++++++++++++++++++++------------------ 12 files changed, 207 insertions(+), 184 deletions(-) diff --git a/llmc/adamw.cuh b/llmc/adamw.cuh index 4453576ee..84d64f391 100644 --- a/llmc/adamw.cuh +++ b/llmc/adamw.cuh @@ -74,7 +74,7 @@ __global__ void init_from_master_kernel(Tp* params_memory, float* master_params_ template void adamw_update(Tp* params_memory, float* master_params_memory, Tg* grads_memory, float* m_memory, float* v_memory, size_t num_parameters, ptrdiff_t w_stride, ptrdiff_t g_stride, ptrdiff_t s_stride, int num_slices, float learning_rate, float beta1, float beta2, int t, float eps, float weight_decay, - float grad_scale, unsigned int seed, cudaStream_t stream) { + float grad_scale, unsigned int seed, cudaStream_t stream=main_stream) { // AdamW update int block_size = 512; int num_blocks = CEIL_DIV(num_parameters, block_size); @@ -89,7 +89,7 @@ void adamw_update(Tp* params_memory, float* master_params_memory, Tg* grads_memo template void init_from_master(Tp* params_memory, float* master_params_memory, size_t num_parameters, - ptrdiff_t w_stride, ptrdiff_t s_stride, int num_slices, unsigned int seed, cudaStream_t stream) { + ptrdiff_t w_stride, ptrdiff_t s_stride, int num_slices, unsigned int seed, cudaStream_t stream=main_stream) { int block_size = 512; // must match block size of adamw_update so that RNG also matches int num_blocks = CEIL_DIV(num_parameters, block_size); init_from_master_kernel<<>> diff --git a/llmc/attention.cuh b/llmc/attention.cuh index f6294a213..3dc5cd52f 100644 --- a/llmc/attention.cuh +++ b/llmc/attention.cuh @@ -194,7 +194,7 @@ __global__ void softmax_autoregressive_backward_inplace_kernel(floatX* datt, con void attention_forward(floatX* out, floatX* qkvr, floatX* att, floatX* inp, - int B, int T, int C, int NH, cudaStream_t stream) { + int B, int T, int C, int NH, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); // Note: `inp` is not needed for backward pass, so we re-use it as a scratch buffer. // Its contents will be overwritten by this function. @@ -239,7 +239,7 @@ void attention_forward(floatX* out, floatX* qkvr, floatX* att, void attention_backward(floatX* dinp, floatX* dqkvr, floatX* datt, floatX* scratch, const floatX* dout, const floatX* qkvr, const floatX* att, - int B, int T, int C, int NH, cudaStream_t stream) { + int B, int T, int C, int NH, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); const int block_size = 256; const int HS = C / NH; // head size diff --git a/llmc/cuda_common.h b/llmc/cuda_common.h index 0baade163..49c2b910d 100644 --- a/llmc/cuda_common.h +++ b/llmc/cuda_common.h @@ -26,6 +26,9 @@ Common utilities for CUDA code. // but it is actually created and instantiated in the main program file extern cudaDeviceProp deviceProp; +// Main stream used by default for all CUDA operations +extern cudaStream_t main_stream; + // WarpSize is not a compile time constant // Defining here like this possibly allows the compiler to optimize better #define WARP_SIZE 32U @@ -130,7 +133,7 @@ class NvtxRange { // Utilities to Read & Write between CUDA memory <-> files // copy num_bytes from device pointer src into file dest, using double buffering running on the given stream. 
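The same cudaStream_t stream=main_stream defaulting is applied to the I/O helpers below. Because a default argument is evaluated at each call site, it picks up whatever main_stream currently holds, and callers can still pass an explicit stream when they need one. A minimal standalone illustration of the pattern (assumed example, not code from the patch):

    #include <cstddef>
    #include <cuda_runtime.h>

    cudaStream_t main_stream;   // in llm.c: defined in train_gpt2.cu, declared extern in llmc/cuda_common.h

    // trailing defaulted stream parameter: most call sites can simply omit it
    void zero_async(void* ptr, size_t bytes, cudaStream_t stream = main_stream) {
        cudaMemsetAsync(ptr, 0, bytes, stream);
    }

    int main() {
        cudaStreamCreate(&main_stream);
        float* buf; cudaMalloc(&buf, 1024 * sizeof(float));
        zero_async(buf, 1024 * sizeof(float));                        // runs on main_stream
        zero_async(buf, 1024 * sizeof(float), cudaStreamPerThread);   // explicit override still possible
        cudaDeviceSynchronize();
        cudaFree(buf); cudaStreamDestroy(main_stream);
        return 0;
    }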
-inline void device_to_file(FILE* dest, void* src, size_t num_bytes, size_t buffer_size, cudaStream_t stream) { +inline void device_to_file(FILE* dest, void* src, size_t num_bytes, size_t buffer_size, cudaStream_t stream=main_stream) { // allocate pinned buffer for faster, async transfer char* buffer_space; cudaCheck(cudaMallocHost(&buffer_space, 2*buffer_size)); @@ -169,7 +172,7 @@ inline void device_to_file(FILE* dest, void* src, size_t num_bytes, size_t buffe } // copy num_bytes from file src into device pointer dest, using double buffering running on the given stream. -inline void file_to_device(void* dest, FILE* src, size_t num_bytes, size_t buffer_size, cudaStream_t stream) { +inline void file_to_device(void* dest, FILE* src, size_t num_bytes, size_t buffer_size, cudaStream_t stream=main_stream) { // allocate pinned buffer for faster, async transfer // from the docs (https://developer.download.nvidia.com/compute/DevZone/docs/html/C/doc/html/group__CUDART__HIGHLEVEL_ge439496de696b166ba457dab5dd4f356.html) // WC memory is a good option for buffers that will be written by the CPU and read by the device via mapped pinned memory or host->device transfers. diff --git a/llmc/cudnn_att.cpp b/llmc/cudnn_att.cpp index 0330abe20..3d2f8af4d 100644 --- a/llmc/cudnn_att.cpp +++ b/llmc/cudnn_att.cpp @@ -222,7 +222,7 @@ auto lookup_cache_or_build_graph_bwd(int B, int NH, int T, int HS) { void attention_forward_cudnn(floatX* out, // output: (B, T, NH, HS) float* stats, // output for backward pass: (B, NH, T) floatX* inp, // input: (B, T, 3, NH, HS) QKV - int B, int T, int NH, int C, cudaStream_t stream) { + int B, int T, int NH, int C, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); int HS = C / NH; // number of features per head bool is_inference_only = (stats == nullptr); @@ -255,7 +255,7 @@ void attention_forward_cudnn(floatX* out, // output: (B, T, NH, HS) void attention_backward_cudnn(floatX* dqkvr, // output floatX* dout, floatX* qkvr, floatX* o, float* stats, // inputs - int B, int T, int NH, int C, cudaStream_t stream) { + int B, int T, int NH, int C, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); int HS = C / NH; // number of features per head diff --git a/llmc/encoder.cuh b/llmc/encoder.cuh index 3aa63e175..5af09476c 100644 --- a/llmc/encoder.cuh +++ b/llmc/encoder.cuh @@ -156,7 +156,7 @@ __global__ void wpe_backward_kernel(floatX* dwpe, void encoder_forward(floatX* out, const int* inp, const floatX* wte, const floatX* wpe, - int B, int T, int C, cudaStream_t stream) { + int B, int T, int C, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); const int block_size = 256; const int N = B * T * C; @@ -169,7 +169,7 @@ void encoder_forward(floatX* out, void encoder_backward(floatX* dwte, floatX* dwpe, floatX* scratch, // gpu outputs & scratch int* workload_indices, int4* bucket_info, // cpu scratch buffers const floatX* dout, const int* inp, const int* inputs_cpu, // cpu/gpu inputs - int B, int T, int C, unsigned int seed, cudaStream_t stream) { + int B, int T, int C, unsigned int seed, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); // Launch wpe kernel first (so it runs on the GPU in parallel with the CPU pre-processing for wte) diff --git a/llmc/fused_classifier.cuh b/llmc/fused_classifier.cuh index c01bbf578..8b29ca233 100644 --- a/llmc/fused_classifier.cuh +++ b/llmc/fused_classifier.cuh @@ -69,7 +69,7 @@ template __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS) fused_classifier_kernel5(floatX* dlogits, const floatX* logits, float* losses, floatX* probs, const 
float dloss, const int* targets, - int B, int T, int V, int P, std::bool_constant) { + int V, int P, std::bool_constant) { // note: idx is small enough that it easily fits into 32 bit; // by making it a long here, we ensure that any offsets calculated with it (e.g., idx * P) // are done is 64 bit @@ -136,14 +136,13 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS) // kernel launchers // replaces logits with logit gradients -template -void fused_classifier(Type* dlogits, const Type* logits, float* losses, +template +void fused_classifier(floatX* dlogits, const floatX* logits, float* losses, const float dloss, const int* targets, - int B, int T, int V, int P, std::bool_constant write_dlogits, cudaStream_t stream) { + int BT, int V, int P, std::bool_constant write_dlogits, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); const int block_size = 1024; - const int N = B * T; - const int grid_size = N; - fused_classifier_kernel5<<>>(dlogits, logits, losses, (floatX*)NULL, dloss, targets, B, T, V, P, write_dlogits); + const int grid_size = BT; + fused_classifier_kernel5<<>>(dlogits, logits, losses, (floatX*)NULL, dloss, targets, V, P, write_dlogits); cudaCheck(cudaGetLastError()); } diff --git a/llmc/gelu.cuh b/llmc/gelu.cuh index cd5c297b6..138daa40a 100644 --- a/llmc/gelu.cuh +++ b/llmc/gelu.cuh @@ -47,7 +47,7 @@ __global__ void gelu_backward_inplace_kernel(floatX* d_in_out, const floatX* inp // ---------------------------------------------------------------------------- // kernel launchers -void gelu_forward(floatX* out, const floatX* inp, int N, cudaStream_t stream) { +void gelu_forward(floatX* out, const floatX* inp, int N, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); const int block_size = 512; assert(N % (block_size * x128::size) == 0); @@ -56,7 +56,7 @@ void gelu_forward(floatX* out, const floatX* inp, int N, cudaStream_t stream) { cudaCheck(cudaGetLastError()); } -void gelu_backward_inplace(floatX* d_in_out, const floatX* inp, const int N, cudaStream_t stream) { +void gelu_backward_inplace(floatX* d_in_out, const floatX* inp, const int N, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); const int block_size = 128; assert(N % (block_size * x128::size) == 0); diff --git a/llmc/global_norm.cuh b/llmc/global_norm.cuh index e0e23b08a..968171a81 100644 --- a/llmc/global_norm.cuh +++ b/llmc/global_norm.cuh @@ -66,7 +66,7 @@ int get_max_num_block_sums(int* num_slices_all, int numel) { } template -void global_norm_squared(float* out, const T* values, size_t count, ptrdiff_t stride, int num_slices, int max_num_block_sums, bool reset, cudaStream_t stream) { +void global_norm_squared(float* out, const T* values, size_t count, ptrdiff_t stride, int num_slices, int max_num_block_sums, bool reset, cudaStream_t stream=main_stream) { const int block_size = 512; // launch just enough blocks to fill the grid. deliberately no DIV_CEIL. 
// having one block less than possible is a tiny performance hit, having diff --git a/llmc/layernorm.cuh b/llmc/layernorm.cuh index da0bf9591..cd66dbf60 100644 --- a/llmc/layernorm.cuh +++ b/llmc/layernorm.cuh @@ -235,7 +235,7 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with layernorm_backward_kernel10(floatX* dinp_new, floatX* dinp_old, floatX* dweight, floatX* dbias, float* scratch, const floatX* dout, const floatX* inp, const floatX* weight, const float* mean, const float* rstd, - int B, int T, int C) { + int BT, int C) { int BLOCK_SIZE = blockDim.x; int warpsInBlock = BLOCK_SIZE / WARP_SIZE; //number of warps in block extern __shared__ float shared[]; @@ -264,7 +264,7 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with } __syncthreads(); - for (int bt = baseIdx; bt < B * T; bt += warpsInGrid) { + for (int bt = baseIdx; bt < BT; bt += warpsInGrid) { const floatX* dout_bt = dout + bt * C; const floatX* inp_bt = inp +bt * C; floatX* dinp_bt = dinp_old + bt * C; @@ -436,11 +436,10 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with // similar to `fused_residual_forward5` void layernorm_forward(floatX* out, float* mean, float* rstd, floatX* inp, const floatX* weight, const floatX* bias, - int B, int T, int C, cudaStream_t stream) { + int N, int C, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); const int block_size = 256; int block_y = block_size / WARP_SIZE; - const int N = B * T; const int grid_size = CEIL_DIV(N, block_y); size_t smem = (2 + block_y) * C * sizeof(floatX); @@ -459,7 +458,7 @@ void layernorm_forward(floatX* out, float* mean, float* rstd, cudaCheck(cudaGetLastError()); } -void residual_forward(floatX* out, const floatX* inp1, const floatX* inp2, int N, cudaStream_t stream) { +void residual_forward(floatX* out, const floatX* inp1, const floatX* inp2, int N, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); const int block_size = 256; assert(N % (block_size * x128::size) == 0); @@ -471,7 +470,7 @@ void residual_forward(floatX* out, const floatX* inp1, const floatX* inp2, int N void fused_residual_forward5(floatX* residual, floatX* normed, float* mean, float* rstd, const floatX* inp1, const floatX* inp2, const floatX* weight, const floatX* bias, - int N, int C, cudaStream_t stream) { + int N, int C, cudaStream_t stream=main_stream) { const int block_size = 256; int block_y = block_size / WARP_SIZE; const int grid_size = CEIL_DIV(N, block_y); @@ -488,14 +487,14 @@ void fused_residual_forward5(floatX* residual, floatX* normed, float* mean, floa weight, bias, N, C); } else { residual_forward(residual, inp1, inp2, N*C, stream); - layernorm_forward(normed, mean, rstd, residual, weight, bias, N, 1, C, stream); + layernorm_forward(normed, mean, rstd, residual, weight, bias, N, C, stream); } cudaCheck(cudaGetLastError()); } void layernorm_backward(floatX* dinp_new, floatX* dinp_old, floatX* dweight, floatX* dbias, float* scratch, const floatX* dout, const floatX* inp, const floatX* weight, const float* mean, const float* rstd, - int B, int T, int C, cudaStream_t stream) { + int BT, int C, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); const int block_size = 512; const int blocks_per_sm = 2; // supported on every architecture and less cache thrashing than 3 @@ -505,9 +504,9 @@ void layernorm_backward(floatX* dinp_new, floatX* dinp_old, floatX* dweight, flo cudaCheck(cudaMemsetAsync(scratch, 0, 1 * sizeof(float), stream)); // only need to reset the flag to 0 if (dinp_old == nullptr) 
{ - layernorm_backward_kernel10<<>>(dinp_new, dinp_old, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C); + layernorm_backward_kernel10<<>>(dinp_new, dinp_old, dweight, dbias, scratch, dout, inp, weight, mean, rstd, BT, C); } else { - layernorm_backward_kernel10<<>>(dinp_new, dinp_old, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C); + layernorm_backward_kernel10<<>>(dinp_new, dinp_old, dweight, dbias, scratch, dout, inp, weight, mean, rstd, BT, C); } cudaCheck(cudaGetLastError()); } diff --git a/llmc/matmul.cuh b/llmc/matmul.cuh index becc372c6..af3398e78 100644 --- a/llmc/matmul.cuh +++ b/llmc/matmul.cuh @@ -14,7 +14,7 @@ Matrix Multiplication, with help from cuBLASLt // CUDA kernels template -__global__ void matmul_backward_bias_kernel9(OutFloat* dbias, const floatX* dout, int B, int T, int OC, +__global__ void matmul_backward_bias_kernel9(OutFloat* dbias, const floatX* dout, int BT, int OC, std::bool_constant) { constexpr const int bdx = 4; constexpr const int bdy = WARP_SIZE / bdx; @@ -40,7 +40,7 @@ __global__ void matmul_backward_bias_kernel9(OutFloat* dbias, const floatX* dout if(global_oc < OC) { // sum up over all bt within registers - for (int idx = blockIdx.y * bt_per_block + local_bt; idx < B * T; idx += gridDim.y * bt_per_block) { + for (int idx = blockIdx.y * bt_per_block + local_bt; idx < BT; idx += gridDim.y * bt_per_block) { x128 packed_dout = load128(dout + global_oc + idx*OC); for (int k = 0; k < x128::size; k++) { accumulators[k] += (float)packed_dout[k]; @@ -230,22 +230,22 @@ void matmul_cublaslt(floatX* d, const floatX* a, const floatX* b, const floatX* // small wrapper around matmul_cublaslt for the forward pass (keeping historical order of arguments) void matmul_forward_cublaslt(floatX* out, floatX* inp, floatX* weight, floatX* bias, - int B, int T, int C, int OC, cudaStream_t stream, - floatX* pre_gelu=NULL, int gelu_fusion=1) { + int BT, int C, int OC, + floatX* pre_gelu=NULL, int gelu_fusion=1, cudaStream_t stream=main_stream) { // By default only fuse GELU for H100+ as cuBLAS seems to be inefficient for fused GELU on Ada/Ampere (?) if (gelu_fusion < 1 && pre_gelu) { - matmul_cublaslt(pre_gelu, weight, inp, bias, OC, B*T, C, stream, true, false, 0, 0, 0, 0, false, NULL, false); - gelu_forward(out, pre_gelu, B*T*OC, stream); + matmul_cublaslt(pre_gelu, weight, inp, bias, OC, BT, C, stream, true, false, 0, 0, 0, 0, false, NULL, false); + gelu_forward(out, pre_gelu, BT*OC, stream); } else { - matmul_cublaslt(out, weight, inp, bias, OC, B*T, C, stream, true, false, 0, 0, 0, 0, false, pre_gelu, false); + matmul_cublaslt(out, weight, inp, bias, OC, BT, C, stream, true, false, 0, 0, 0, 0, false, pre_gelu, false); } } void matmul_backward(floatX* dinp, floatX* dweight, floatX* dbias, floatX* dout, floatX* inp, floatX* weight, float* dbias_buffer, - int B, int T, int C, int OC, cudaStream_t stream, - floatX* pre_gelu=NULL, int gelu_fusion=1) { + int BT, int C, int OC, + floatX* pre_gelu=NULL, int gelu_fusion=1, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); // backward to bias, if given, does a += @@ -263,11 +263,11 @@ void matmul_backward(floatX* dinp, floatX* dweight, floatX* dbias, // If we have enough OC that we don't need cross-block reductions, we can skip the bias_buffer accumulation // and write results directly to the output. 
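Stepping back, the full computation matmul_backward performs (for inp of shape (BT, C), weight (OC, C), dout (BT, OC)) is easiest to see in a naive CPU form; this is only an illustrative float sketch of the math, not the cuBLASLt / fused-GELU path being edited here:

    #include <stddef.h>

    // dinp is set with '=', while dweight/dbias accumulate with '+=' across micro-steps
    void matmul_backward_ref(float* dinp, float* dweight, float* dbias,
                             const float* dout, const float* inp, const float* weight,
                             int BT, int C, int OC) {
        for (int bt = 0; bt < BT; bt++) {
            for (int c = 0; c < C; c++) {
                float acc = 0.0f;
                for (int o = 0; o < OC; o++) { acc += dout[bt * OC + o] * weight[o * C + c]; }
                dinp[bt * C + c] = acc;                 // dinp = dout @ weight
            }
        }
        for (int o = 0; o < OC; o++) {
            for (int c = 0; c < C; c++) {
                float acc = 0.0f;
                for (int bt = 0; bt < BT; bt++) { acc += inp[bt * C + c] * dout[bt * OC + o]; }
                dweight[o * C + c] += acc;              // dweight += dout^T @ inp
            }
            if (dbias != NULL) {
                float acc = 0.0f;
                for (int bt = 0; bt < BT; bt++) { acc += dout[bt * OC + o]; }
                dbias[o] += acc;                        // dbias += column sums of dout
            }
        }
    }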
if(grid_size_y == 1) { - matmul_backward_bias_kernel9<<>>(dbias, dout, B, T, OC, False); + matmul_backward_bias_kernel9<<>>(dbias, dout, BT, OC, False); cudaCheck(cudaGetLastError()); } else { // kernel 9 overwrites temp buffer, so no need to memset - matmul_backward_bias_kernel9<<>>(dbias_buffer, dout, B, T, OC, True); + matmul_backward_bias_kernel9<<>>(dbias_buffer, dout, BT, OC, True); cudaCheck(cudaGetLastError()); reduce_add_sum_kernel<<>>(dbias, dbias_buffer, OC, grid_size_y); cudaCheck(cudaGetLastError()); @@ -276,15 +276,15 @@ void matmul_backward(floatX* dinp, floatX* dweight, floatX* dbias, } // backward to input, uses = in the backward pass (set the gradient) - matmul_cublaslt(dinp, weight, dout, NULL, C, B*T, OC, stream, false, false, 0, 0, 0, 0, false, + matmul_cublaslt(dinp, weight, dout, NULL, C, BT, OC, stream, false, false, 0, 0, 0, 0, false, gelu_fusion >= 2 ? pre_gelu : NULL, true); // backward GELU (if it wasn't fused into the matmul above) if (gelu_fusion < 2 && pre_gelu) { - gelu_backward_inplace(dinp, pre_gelu, B*T*C, stream); + gelu_backward_inplace(dinp, pre_gelu, BT*C, stream); } // backward to weight, uses += in the backward pass (accumulate the gradient) by setting alpha=one - matmul_cublaslt(dweight, inp, dout, NULL /*dbias*/, C, OC, B*T, stream, false, true, 0, 0, 0, 0, + matmul_cublaslt(dweight, inp, dout, NULL /*dbias*/, C, OC, BT, stream, false, true, 0, 0, 0, 0, true /* accumulate */, NULL, true); } diff --git a/llmc/zero.cuh b/llmc/zero.cuh index e6c5b6e7c..37f8c1b1f 100644 --- a/llmc/zero.cuh +++ b/llmc/zero.cuh @@ -594,4 +594,3 @@ float multi_gpu_cpu_float_sum(float value, MultiGpuConfig* config) { } #endif - diff --git a/train_gpt2.cu b/train_gpt2.cu index f84f15762..256bfba46 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -1,7 +1,7 @@ /* GPT-2 Transformer Neural Net training loop. See README.md for usage. */ -bool UNIQUE_TENSOR_MEMORY = false; +#define UNIQUE_TENSOR_MEMORY false #include #include @@ -108,6 +108,8 @@ typedef struct { int bt4c; // (B, T, 4*C) int btc; // (B, T, C) int local_scratch; // (B, T, C) + int output_scratch; // huge + int output_scratch_fp32; // typically same buffer as above } MultiuseTensors; typedef struct { @@ -152,29 +154,57 @@ typedef struct { unsigned long long rng_state_last_update; // RNG before last gpt2_update() to re-round identically from master weights } GPT2; +// todo: need flags, subtypes (e.g. act gradient), etc... 
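All of the B,T -> BT / N signature changes above work because these kernels treat a (B, T, C) activation as B*T contiguous rows of C values and never index batch and time separately. A trivial standalone check of that indexing equivalence (illustrative only):

    #include <assert.h>
    #include <stddef.h>

    static size_t idx_btc(size_t b, size_t t, size_t c, size_t T, size_t C) { return (b * T + t) * C + c; }
    static size_t idx_nc(size_t n, size_t c, size_t C) { return n * C + c; }

    int main(void) {
        size_t B = 4, T = 64, C = 768;
        for (size_t b = 0; b < B; b++) {
            for (size_t t = 0; t < T; t++) {
                // the (b, t) element is the same memory location as flat row n = b*T + t
                assert(idx_btc(b, t, 5, T, C) == idx_nc(b * T + t, 5, C));
            }
        }
        return 0;
    }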
typedef struct { - char name[16]; + char* ptr; size_t offset; // into base pointer size_t num_elements; // per shard - size_t num_shards; - int remaining_layers; + int id; + short num_shards; + short remaining_layers; DType data_type; TT tensor_type; + char name[16]; + + template + operator T*() const { + if (std::is_same::value && data_type != DType::FP32 || + std::is_same::value && data_type != DType::FP16 || + std::is_same::value && data_type != DType::BF16) { + printf("ERROR: Unexpected data type (%d) for tensor %s\n", (int)data_type, name); + exit(EXIT_FAILURE); + } + return reinterpret_cast(ptr); + } } TensorSpec; -TensorSpec tensor_specs[16*1024]; +constexpr size_t MAX_TENSORS = 16*1024; +TensorSpec tensor_specs[MAX_TENSORS] = {0}; size_t num_tensor_specs = 0; TT current_tensor_type = TT::PARAMETER; size_t tensors_start[TT::COUNT] = {0}; size_t tensors_bytes[TT::COUNT] = {0}; size_t tensors_elements[TT::COUNT] = {0}; +TensorSpec get_tensor(int spec_index, TT tensor_type, int layer) { + TensorSpec spec = tensor_specs[spec_index]; + if (layer > 0 && spec.remaining_layers >= layer) { + spec = tensor_specs[spec_index + layer]; + } else if (layer > 0 && spec.remaining_layers > 0) { + printf("ERROR: get_tensor() for %s layer %d but only %d layers remaining\n", spec.name, layer, spec.remaining_layers); + assert(false); + } + assert(spec.tensor_type == tensor_type || tensor_type == DEFAULT); + return spec; +} + int add_tensor_spec(const char* name, size_t total_elements, size_t num_shards, DType data_type, int copy_offset_from=-1, TT tensor_type=TT::DEFAULT) { assert(num_tensor_specs < 16*1024); assert((total_elements % num_shards) == 0); - TensorSpec* spec = &tensor_specs[num_tensor_specs++]; - + TensorSpec* spec = &tensor_specs[num_tensor_specs]; strncpy(spec->name, name, 16); + + spec->id = num_tensor_specs; spec->num_elements = total_elements / num_shards; spec->num_shards = num_shards; spec->remaining_layers = 0; @@ -182,7 +212,6 @@ int add_tensor_spec(const char* name, size_t total_elements, size_t num_shards, spec->tensor_type = (tensor_type == TT::DEFAULT) ? current_tensor_type : tensor_type; tensors_elements[spec->tensor_type] += spec->num_elements; - if (copy_offset_from >= 0) { spec->offset = tensor_specs[copy_offset_from].offset; size_t original_tensor_bytes = tensor_specs[copy_offset_from].num_elements * sizeof_dtype(tensor_specs[copy_offset_from].data_type); @@ -193,10 +222,10 @@ int add_tensor_spec(const char* name, size_t total_elements, size_t num_shards, spec->offset = tensors_bytes[spec->tensor_type]; tensors_bytes[spec->tensor_type] += spec->num_elements * sizeof_dtype(data_type); if (tensors_start[spec->tensor_type] == 0 && spec->tensor_type != 0) { - tensors_start[spec->tensor_type] = num_tensor_specs - 1; + tensors_start[spec->tensor_type] = num_tensor_specs; } } - return num_tensor_specs - 1; + return num_tensor_specs++; } int add_layer_specs(int num_layers, const char* name, size_t total_elements, size_t num_shards, DType data_type, int copy_offset_from=-1, bool copy_per_layer=false, TT tensor_type=TT::DEFAULT) { @@ -273,6 +302,9 @@ void gpt2_allocate(GPT2 *model) { model->multiuse.bt4c = add_tensor_spec("multiuse_bt4c", 4 * BTC, 1, DTYPE_FLOATX); model->multiuse.btc = add_tensor_spec("multiuse_btc", BTC, 1, DTYPE_FLOATX); model->multiuse.local_scratch = add_tensor_spec("local_scratch", BTC, 1, DType::FP32); // todo - is this oversized? + model->multiuse.local_scratch = add_tensor_spec("local_scratch", BTC, 1, DType::FP32); // todo - is this oversized? 
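The implicit conversion operator added to TensorSpec above is what lets the ACT()/PARAM()/MULTI() macros defined later in this patch hand a spec straight to kernel launchers that expect a floatX* or float*, while trapping dtype mismatches at runtime. A stripped-down standalone sketch of the same mechanism (simplified to FP32 only; names here are illustrative, not the real struct):

    #include <cstdio>
    #include <cstdlib>
    #include <type_traits>

    enum class DT { FP32, BF16 };

    struct Spec {
        char* ptr;
        DT data_type;
        template<typename T>
        operator T*() const {
            // refuse the conversion if the requested element type doesn't match the stored dtype
            if (std::is_same<T, float>::value && data_type != DT::FP32) {
                std::printf("unexpected data type for tensor\n");
                std::exit(EXIT_FAILURE);
            }
            return reinterpret_cast<T*>(ptr);
        }
    };

    static void scale(float* x, int n) { for (int i = 0; i < n; i++) { x[i] *= 2.0f; } }

    int main() {
        float storage[4] = {1.0f, 2.0f, 3.0f, 4.0f};
        Spec s{reinterpret_cast<char*>(storage), DT::FP32};
        scale(s, 4);   // Spec converts to float* implicitly, which is what the TENSOR()/ACT() call sites rely on
        return 0;
    }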
+ model->multiuse.output_scratch = add_tensor_spec("output_fpx", output_size, 1, DTYPE_FLOATX); + model->multiuse.output_scratch_fp32 = add_tensor_spec("output_fp32", output_size / 2, 1, DType::FP32, model->multiuse.output_scratch); } // 3) activations @@ -286,7 +318,6 @@ void gpt2_allocate(GPT2 *model) { TENSOR_SPECS_FP32(lnf_mean, 1, BT); TENSOR_SPECS_FP32(lnf_rstd, 1, BT); TENSOR_SPECS_FP32(losses, 1, BT); - TENSOR_SPECS (output, 1, output_size); TENSOR_SPECS_FP32(ln1_mean, L, BT); TENSOR_SPECS_FP32(ln1_rstd, L, BT); @@ -304,9 +335,11 @@ void gpt2_allocate(GPT2 *model) { #endif if (UNIQUE_TENSOR_MEMORY) { + TENSOR_SPECS (output, 1, output_size); TENSOR_SPECS_LOWP(fcproj, L, BTC); TENSOR_SPECS_LOWP(attproj, L, BTC); } else { + spec->output = add_tensor_spec("output", output_size, shards, dtype, model->multiuse.output_scratch); spec->fcproj = add_layer_specs(L, "fcproj", BTC, shards, dtype_lowp, model->multiuse.btc); spec->attproj = add_layer_specs(L, "attproj", BTC, shards, dtype_lowp, model->multiuse.btc); } @@ -345,7 +378,7 @@ void gpt2_allocate(GPT2 *model) { TENSOR_SPECS(residual3, L, BTC); TENSOR_SPECS(qkvr, L, 3 * BTC); } else { - spec->output = add_layer_specs(1, "output", output_size, 1, dtype, model->acts.output); + spec->output = add_layer_specs(1, "output", output_size, 1, dtype, model->multiuse.output_scratch); int reused_btc = model->acts.residual3 + (L-1); spec->ln1 = add_layer_specs(L, "ln1", BTC, 1, dtype, reused_btc); @@ -373,26 +406,38 @@ void gpt2_allocate(GPT2 *model) { cudaCheck(cudaMalloc(&model->params_memory[PARAMETER_MASTER], tensors_bytes[PARAMETER_MASTER])); } + // Set the ptr for each tensor spec based on type and offset + for (size_t i = 0; i < num_tensor_specs; i++) { + TensorSpec* spec = &tensor_specs[i]; + switch (spec->tensor_type) { + case PARAMETER: + case PARAMETER_GRADIENT: + case PARAMETER_OPT_M: + case PARAMETER_OPT_V: + spec->ptr = model->params_memory[spec->tensor_type] + spec->offset; + break; + case PARAMETER_MASTER: + if (model->use_master_weights) { + spec->ptr = model->params_memory[PARAMETER_MASTER] + spec->offset; + } + break; + case ACTIVATIONS_MULTIUSE: + spec->ptr = model->multiuse_memory + spec->offset; + break; + default: assert(false); + } + } + //initialise helper variables model->num_parameters = tensors_elements[TT::PARAMETER]; model->num_parameters_bytes = tensors_bytes[TT::PARAMETER]; - // printf gpu_mem and params_memory - // parameter gradient bytes - size_t param_grad_bytes = tensors_bytes[TT::PARAMETER_GRADIENT]; - printf("number of parameter gradient bytes: %zu MiB\n", param_grad_bytes / (1024*1024)); - // number of master weight bytes - size_t master_weight_bytes = tensors_bytes[TT::PARAMETER_MASTER]; - printf("number of master weight bytes: %zu MiB\n", master_weight_bytes / (1024*1024)); - // opt state m - size_t m_bytes = tensors_bytes[TT::PARAMETER_OPT_M]; - printf("number of m bytes: %zu MiB\n", m_bytes / (1024*1024)); - // opt state v - size_t v_bytes = tensors_bytes[TT::PARAMETER_OPT_V]; - printf("number of v bytes: %zu MiB\n", v_bytes / (1024*1024)); - // number of multiuse bytes - size_t multiuse_bytes = tensors_bytes[TT::ACTIVATIONS_MULTIUSE]; - printf("number of act+actgrad+multiuse bytes: %zu MiB\n", (multiuse_bytes) / (1024*1024)); + printf("number of parameter bytes: %zu MiB\n", tensors_bytes[TT::PARAMETER] / (1024*1024)); + printf("number of parameter gradient bytes: %zu MiB\n", tensors_bytes[TT::PARAMETER_GRADIENT] / (1024*1024)); + printf("number of master weight bytes: %zu MiB\n", 
tensors_bytes[TT::PARAMETER_MASTER] / (1024*1024)); + printf("number of m bytes: %zu MiB\n", tensors_bytes[TT::PARAMETER_OPT_M] / (1024*1024)); + printf("number of v bytes: %zu MiB\n", tensors_bytes[TT::PARAMETER_OPT_V] / (1024*1024)); + printf("number of act+actgrad+multiuse bytes: %zu MiB\n", tensors_bytes[TT::ACTIVATIONS_MULTIUSE] / (1024*1024)); // ======================= // allocate_state stuff @@ -442,7 +487,7 @@ void gpt2_write_to_checkpoint(GPT2 *model, const char* checkpoint_path) { model_header[7] = model->config.padded_vocab_size; fwriteCheck(model_header, sizeof(int), 256, model_file); // write the parameters - device_to_file(model_file, model->params_memory, model->num_parameters_bytes, IO_BUF_SIZE, main_stream); + device_to_file(model_file, model->params_memory, model->num_parameters_bytes, IO_BUF_SIZE); // close file, we're done fcloseCheck(model_file); } @@ -502,7 +547,7 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path, bool w // read in the parameters if weight_init is true if (weight_init) { - file_to_device(model->params_memory[PARAMETER], model_file, model->num_parameters_bytes, IO_BUF_SIZE, main_stream); + file_to_device(model->params_memory[PARAMETER], model_file, model->num_parameters_bytes, IO_BUF_SIZE); } fcloseCheck(model_file); @@ -606,7 +651,7 @@ void gpt_build_from_descriptor(GPT2 *model, const char* descriptor) { || i == 4 || i == 6 || i == 10 || i == 12) { size_t n = model->param_elements[i]; size_t layer_offset = 0; - if (i == 0) { + if (i == 0) {rer // for wte tensor (padded vocab) override to init V instead of Vp rows n = model->config.vocab_size * model->config.channels; } @@ -636,38 +681,13 @@ void gpt_build_from_descriptor(GPT2 *model, const char* descriptor) { free(params_memory_cpu); } -#define ACT_X(x) (floatX*)((char*)model->multiuse_memory + tensor_specs[acts.x].offset) -#define ACT_32(x) (float*)((char*)model->multiuse_memory + tensor_specs[acts.x].offset) -#define ACT_XL(x) ACT_X(x + l) -#define ACT_32L(x) ACT_32(x + l) - -#define PARAM_X(x) (floatX*)((char*)model->params_memory[PARAMETER] + tensor_specs[params.x].offset) -#define PARAM_32(x) (float*)((char*)model->params_memory[PARAMETER] + tensor_specs[params.x].offset) -#define PARAM_XL(x) PARAM_X(x + l) -#define PARAM_32L(x) PARAM_32(x + l) - -#define PGRAD_X(x) (floatX*)((char*)model->params_memory[PARAMETER_GRADIENT] + tensor_specs[grads.x].offset) -#define PGRAD_32(x) (float*)((char*)model->params_memory[PARAMETER_GRADIENT] + tensor_specs[grads.x].offset) -#define PGRAD_XL(x) PGRAD_X(x + l) -#define PGRAD_32L(x) PGRAD_32(x + l) - -#define AGRAD_X(x) (floatX*)((char*)model->multiuse_memory + tensor_specs[acts_grads.x].offset) -#define AGRAD_32(x) (float*)((char*)model->multiuse_memory + tensor_specs[acts_grads.x].offset) -#define AGRAD_XL(x) AGRAD_X(x + l) -#define AGRAD_32L(x) AGRAD_32(x + l) - -#define MULTI_X(x) (floatX*)((char*)model->multiuse_memory + tensor_specs[x].offset) -#define MULTI_32(x) (float*)((char*)model->multiuse_memory + tensor_specs[x].offset) -#define MULTI_XL(x) PARAM_X(x + l) -#define MULTI_32L(x) PARAM_32(x + l) - - // debug helper function void print_tensor_elements(GPT2 *model, int tensor_id) { - const char* tensor_name = tensor_specs[tensor_id].name; - size_t num_elements = tensor_specs[tensor_id].num_elements; - TT tensor_type = tensor_specs[tensor_id].tensor_type; - DType dtype = tensor_specs[tensor_id].data_type; + TensorSpec spec = tensor_specs[tensor_id]; + size_t num_elements = spec.num_elements; + const char* tensor_name = 
spec.name; + TT tensor_type = spec.tensor_type; + DType dtype = spec.data_type; size_t element_size = sizeof_dtype(dtype); void* gpu_memory = (tensor_id == TT::ACTIVATIONS_MULTIUSE) ? model->multiuse_memory : model->params_memory[tensor_type]; @@ -715,64 +735,73 @@ void print_tensor_elements(GPT2 *model, int tensor_id) { free(cpu_tensor); } +// Helper macros for accessing tensors +#define ACT_L(x,l) get_tensor(model->acts.x, ACTIVATIONS_MULTIUSE, l) +#define AGRAD_L(x,l) get_tensor(model->acts_grads.x, ACTIVATIONS_MULTIUSE, l) +#define PARAM_L(x,l) get_tensor(model->params[PARAMETER].x, PARAMETER, l) +#define PGRAD_L(x,l) get_tensor(model->params[PARAMETER_GRADIENT].x, PARAMETER_GRADIENT, l) +#define MULTI_L(x,l) get_tensor(model->multiuse.x, ACTIVATIONS_MULTIUSE, l) +#define TENSOR(x,l) get_tensor(x, DEFAULT, l) +#define ACT(x) ACT_L(x,l) +#define AGRAD(x) AGRAD_L(x,l) +#define PARAM(x) PARAM_L(x,l) +#define PGRAD(x) PGRAD_L(x,l) +#define MULTI(x) MULTI_L(x,l) + // propagate inputs through the network to produce logits. void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { NVTX_RANGE_FN(); - // we must be careful and use size_t instead of int, otherwise we could overflow - ParameterTensors params = model->params[PARAMETER]; - ActivationTensors acts = model->acts; const size_t V = model->config.vocab_size; const size_t Vp = model->config.padded_vocab_size; const size_t L = model->config.num_layers; const size_t NH = model->config.num_heads; const size_t C = model->config.channels; - // validate B,T are not larger than the values used at initialisation - // (smaller B,T are okay for inference only) + // validate B,T are not larger than the values at initialisation (smaller is OK for inference) if (B > model->batch_size || T > model->seq_len) { printf("Model: B=%d T=%d, Desired: B=%d T=%d\n", model->batch_size, model->seq_len, (int)B, (int)T); exit(EXIT_FAILURE); } // unused parts of attention buffer must be zeroed for non-cuDNN path if (!CUDNN_ENABLED && T != model->seq_len) { - cudaCheck(cudaMemset(ACT_X(att), 0, L * B * NH * T * T * sizeof(floatX))); + cudaCheck(cudaMemset(ACT_L(att, 0), 0, L * B * NH * T * T * sizeof(floatX))); } - - // copy inputs/targets to the model (fully synchronous with the host for now) - cudaCheck(cudaMemcpy(model->inputs, inputs, B * T * sizeof(int), cudaMemcpyHostToDevice)); // validate inputs, all indices must be in the range [0, V) tokenCheck(inputs, B*T, V); - // start of forward pass with encoder - encoder_forward(ACT_X(encoded), model->inputs, PARAM_X(wte), PARAM_X(wpe), B, T, C, main_stream); // encoding goes into residual[0] - layernorm_forward(ACT_X(ln1), ACT_32(ln1_mean), ACT_32(ln1_rstd), ACT_X(encoded), PARAM_X(ln1w), PARAM_X(ln1b), B, T, C, main_stream); + // copy inputs/targets to the model (fully synchronous with the host for now) + cudaCheck(cudaMemcpy(model->inputs, inputs, B * T * sizeof(int), cudaMemcpyHostToDevice)); + // start of forward pass with encoder (layer 0) + int l = 0; + encoder_forward(ACT(encoded), model->inputs, PARAM(wte), PARAM(wpe), B, T, C); + layernorm_forward(ACT(ln1), ACT(ln1_mean), ACT(ln1_rstd), ACT(encoded), PARAM(ln1w), PARAM(ln1b), B*T, C); - for (int l = 0; l < L; l++) { + for (; l < L; l++) { NvtxRange layer_range("Layer", l); - floatX* input_residual = l == 0 ? ACT_X(encoded) : ACT_X(residual3 + l-1); + floatX* residual = (l == 0) ? ACT(encoded) : ACT_L(residual3, l-1); - matmul_forward_cublaslt(CUDNN_ENABLED ? 
ACT_XL(qkvr) : ACT_X(output), ACT_XL(ln1), PARAM_XL(qkvw), PARAM_XL(qkvb), B, T, C, 3*C, main_stream); + matmul_forward_cublaslt(CUDNN_ENABLED ? ACT(qkvr) : MULTI(output_scratch), ACT(ln1), PARAM(qkvw), PARAM(qkvb), B*T, C, 3*C); #ifdef ENABLE_CUDNN - attention_forward_cudnn(ACT_XL(atty), ACT_32L(att), ACT_XL(qkvr), B, T, NH, C, main_stream); + attention_forward_cudnn(ACT(atty), ACT(att), ACT(qkvr), B, T, NH, C); #else - attention_forward(ACT_XL(atty), ACT_XL(qkvr), ACT_XL(att), ACT_X(output), B, T, C, NH, main_stream); + attention_forward(ACT(atty), ACT(qkvr), ACT(att), MULTI(output_scratch), B, T, C, NH); #endif - matmul_forward_cublaslt(ACT_XL(attproj), ACT_XL(atty), PARAM_XL(attprojw), PARAM_XL(attprojb), B, T, C, C, main_stream); - fused_residual_forward5(ACT_XL(residual2), ACT_XL(ln2), ACT_32L(ln2_mean), ACT_32L(ln2_rstd), input_residual, ACT_XL(attproj), PARAM_XL(ln2w), PARAM_XL(ln2b), B*T, C, main_stream); - matmul_forward_cublaslt(ACT_XL(fch_gelu), ACT_XL(ln2), PARAM_XL(fcw), PARAM_XL(fcb), B, T, C, 4*C, main_stream, ACT_XL(fch), model->gelu_fusion); - matmul_forward_cublaslt(ACT_XL(fcproj), ACT_XL(fch_gelu), PARAM_XL(fcprojw), PARAM_XL(fcprojb), B, T, 4*C, C, main_stream); + matmul_forward_cublaslt(ACT(attproj), ACT(atty), PARAM(attprojw), PARAM(attprojb), B*T, C, C); + fused_residual_forward5(ACT(residual2), ACT(ln2), ACT(ln2_mean), ACT(ln2_rstd), residual, ACT(attproj), PARAM(ln2w), PARAM(ln2b), B*T, C); + matmul_forward_cublaslt(ACT(fch_gelu), ACT(ln2), PARAM(fcw), PARAM(fcb), B*T, C, 4*C, ACT(fch), model->gelu_fusion); + matmul_forward_cublaslt(ACT(fcproj), ACT(fch_gelu), PARAM(fcprojw), PARAM(fcprojb), B*T, 4*C, C); - if(l+1 != L) { // fusion across layers - fused_residual_forward5(ACT_XL(residual3), ACT_XL(ln1 + 1), ACT_32L(ln1_mean + 1), ACT_32L(ln1_rstd + 1), ACT_XL(residual2), ACT_XL(fcproj), - PARAM_XL(ln1w + 1), PARAM_XL(ln1b + 1), B * T, C, main_stream); + if(l+1 != L) { + fused_residual_forward5(ACT(residual3), ACT_L(ln1, l+1), ACT_L(ln1_mean, l+1), ACT_L(ln1_rstd, l+1), ACT(residual2), ACT(fcproj), + PARAM_L(ln1w, l+1), PARAM_L(ln1b, l+1), B*T, C); } else { - fused_residual_forward5(ACT_XL(residual3), ACT_X(lnf), ACT_32(lnf_mean), ACT_32(lnf_rstd), ACT_XL(residual2), ACT_XL(fcproj), - PARAM_X(lnfw), PARAM_X(lnfb), B * T, C, main_stream); + fused_residual_forward5(ACT(residual3), ACT(lnf), ACT(lnf_mean), ACT(lnf_rstd), ACT(residual2), ACT(fcproj), + PARAM(lnfw), PARAM(lnfb), B*T, C); } } - matmul_forward_cublaslt(ACT_X(output), ACT_X(lnf), PARAM_X(wte), NULL, B, T, C, Vp, main_stream); + matmul_forward_cublaslt(ACT(output), ACT(lnf), PARAM(wte), NULL, B*T, C, Vp); } @@ -786,6 +815,7 @@ float gpt2_validate(GPT2 *model, const int* inputs, const int* targets, size_t B // convenience shortcuts, size_t instead of int so that pointer arithmetics don't overflow const size_t V = model->config.vocab_size; const size_t Vp = model->config.padded_vocab_size; + int l = 0; NvtxRange classifier_and_loss_range("classifier_and_loss"); ActivationTensors acts = model->acts; @@ -793,11 +823,11 @@ float gpt2_validate(GPT2 *model, const int* inputs, const int* targets, size_t B // fused classifier: does the forward pass and first part of the backward pass const float dloss = 1.0f / (B * T); // results in the uniform average loss over all elements // note: we don't need to generate dlogits here - cudaCheck(cudaMemset(ACT_32(losses), 0, B*T*sizeof(float))); + cudaCheck(cudaMemset(ACT(losses), 0, B*T*sizeof(float))); cudaCheck(cudaMemcpy(model->targets, targets, B * T * sizeof(int), 
cudaMemcpyHostToDevice)); tokenCheck(targets, B*T, V); // while the memcpy is underway, validate the targets - fused_classifier(ACT_X(output), ACT_X(output), ACT_32(losses), dloss, model->targets, B, T, V, Vp, False, main_stream); - cudaCheck(cudaMemcpy(model->cpu_losses, ACT_32(losses), B * T * sizeof(float), cudaMemcpyDeviceToHost)); + fused_classifier(ACT(output), ACT(output), ACT(losses), dloss, model->targets, B*T, V, Vp, False); + cudaCheck(cudaMemcpy(model->cpu_losses, ACT(losses), B * T * sizeof(float), cudaMemcpyDeviceToHost)); for (int i = 0; i < B*T; i++) { mean_loss += model->cpu_losses[i]; } @@ -814,10 +844,6 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int NVTX_RANGE_FN(); // convenience shortcuts (size_t instead of int so that pointer arithmetics don't overflow) - ParameterTensors params = model->params[PARAMETER]; - ParameterTensors grads = model->params[PARAMETER_GRADIENT]; - ActivationTensors acts = model->acts; - ActivationTensors acts_grads = model->acts_grads; const size_t B = model->batch_size; const size_t T = model->seq_len; const size_t V = model->config.vocab_size; @@ -825,6 +851,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int const size_t L = model->config.num_layers; const size_t NH = model->config.num_heads; const size_t C = model->config.channels; + int l = L-1; bool last_step = micro_step == grad_accum_steps - 1; // on the first micro-step zero the gradients, as we're about to += accumulate into them @@ -832,7 +859,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int // there are currently two state vars during the gradient accumulation inner loop: // 1) the losses accumulate += into acts.losses, reset here // 2) the gradients accumulate += into grads_memory, reset here - cudaCheck(cudaMemsetAsync(ACT_32(losses), 0, B * T * sizeof(float), main_stream)); + cudaCheck(cudaMemsetAsync(ACT(losses), 0, B * T * sizeof(float), main_stream)); cudaCheck(cudaMemsetAsync(model->params_memory[PARAMETER_GRADIENT], 0, tensors_bytes[PARAMETER_GRADIENT], main_stream)); } @@ -841,11 +868,11 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int const float dloss = 1.0f / (float)(B * T * grad_accum_steps); // results in the uniform average loss over all elements cudaCheck(cudaMemcpy(model->targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice)); tokenCheck(targets, B*T, V); - fused_classifier(AGRAD_X(output), ACT_X(output), ACT_32(losses), dloss, model->targets, B, T, V, Vp, True, main_stream); // todo - split output & doutput + fused_classifier(AGRAD(output), ACT(output), ACT(losses), dloss, model->targets, B*T, V, Vp, True); // todo - split output & doutput // re-use the output buffer of the forward pass as a scratchpad during backward pass + dedicated buffer - float* scratchF = MULTI_32(model->multiuse.local_scratch); - floatX* scratchX_HUGE = ACT_X(output); + float* scratchF = TENSOR(model->multiuse.local_scratch, 0); + floatX* scratchX_HUGE = TENSOR(model->multiuse.output_scratch, 0); // backward pass: go in the reverse order of the forward pass, and call backward() functions @@ -854,53 +881,53 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int // technically that is a small, inline backward() pass of calculating // total, final loss as the mean over all losses over all (B,T) positions in the batch // next: backward the classifier matmul - matmul_backward(AGRAD_X(lnf), PGRAD_X(wte), NULL, AGRAD_X(output), 
ACT_X(lnf), PARAM_X(wte), NULL, B, T, C, Vp, main_stream); + matmul_backward(AGRAD(lnf), PGRAD(wte), NULL, AGRAD(output), ACT(lnf), PARAM(wte), NULL, B*T, C, Vp); // backward the final layernorm - layernorm_backward(AGRAD_X(residual3 + L-1), NULL, PGRAD_X(lnfw), PGRAD_X(lnfb), scratchF, AGRAD_X(lnf), ACT_X(residual3 + L-1), - PARAM_X(lnfw), ACT_32(lnf_mean), ACT_32(lnf_rstd), B, T, C, main_stream); + layernorm_backward(AGRAD(residual3 + L-1), NULL, PGRAD(lnfw), PGRAD(lnfb), scratchF, AGRAD(lnf), ACT(residual3 + L-1), + PARAM(lnfw), ACT(lnf_mean), ACT(lnf_rstd), B*T, C); // now backward all the layers - for (int l = L-1; l >= 0; l--) { + for (; l >= 0; l--) { NvtxRange layer_range("Layer", l); - floatX* residual = (l == 0) ? ACT_X(encoded) : ACT_X(residual3 + (l-1)); - floatX* dresidual = (l == 0) ? AGRAD_X(encoded) : AGRAD_X(residual3 + (l-1)); + floatX* residual = (l == 0) ? ACT(encoded) : ACT_L(residual3, l-1); + floatX* dresidual = (l == 0) ? AGRAD(encoded) : AGRAD_L(residual3, l-1); if(model->recompute >= 1) { // recompute >= 1 means we recompute gelu - gelu_forward(ACT_XL(fch_gelu), ACT_XL(fch), B*T*4*C, main_stream); + gelu_forward(ACT(fch_gelu), ACT(fch), B*T*4*C); } - matmul_backward(AGRAD_XL(fch), PGRAD_XL(fcprojw), PGRAD_XL(fcprojb), AGRAD_XL(residual3), ACT_XL(fch_gelu), PARAM_XL(fcprojw), scratchF, B, T, 4*C, C, main_stream, ACT_XL(fch), model->gelu_fusion); + matmul_backward(AGRAD(fch), PGRAD(fcprojw), PGRAD(fcprojb), AGRAD(residual3), ACT(fch_gelu), PARAM(fcprojw), scratchF, B*T, 4*C, C, ACT(fch), model->gelu_fusion); if(model->recompute >= 2) { // recompute >= 2 means we recompute layernorm - layernorm_forward(ACT_XL(ln2), ACT_32L(ln2_mean), ACT_32L(ln2_rstd), ACT_XL(residual2), PARAM_XL(ln2w), PARAM_XL(ln2b), B, T, C, main_stream); + layernorm_forward(ACT(ln2), ACT(ln2_mean), ACT(ln2_rstd), ACT(residual2), PARAM(ln2w), PARAM(ln2b), B*T, C); } - matmul_backward(AGRAD_XL(ln2), PGRAD_XL(fcw), PGRAD_XL(fcb), AGRAD_XL(fch), ACT_XL(ln2), PARAM_XL(fcw), scratchF, B, T, C, 4 * C, main_stream); - layernorm_backward(AGRAD_XL(residual2), AGRAD_XL(residual3), PGRAD_XL(ln2w), PGRAD_XL(ln2b), scratchF, AGRAD_XL(ln2), ACT_XL(residual2), PARAM_XL(ln2w), ACT_32L(ln2_mean), ACT_32L(ln2_rstd), B, T, C, main_stream); - matmul_backward(AGRAD_XL(atty), PGRAD_XL(attprojw), PGRAD_XL(attprojb), AGRAD_XL(residual2), ACT_XL(atty), PARAM_XL(attprojw), scratchF, B, T, C, C, main_stream); + matmul_backward(AGRAD(ln2), PGRAD(fcw), PGRAD(fcb), AGRAD(fch), ACT(ln2), PARAM(fcw), scratchF, B*T, C, 4 * C); + layernorm_backward(AGRAD(residual2), AGRAD(residual3), PGRAD(ln2w), PGRAD(ln2b), scratchF, AGRAD(ln2), ACT(residual2), PARAM(ln2w), ACT(ln2_mean), ACT(ln2_rstd), B*T, C); + matmul_backward(AGRAD(atty), PGRAD(attprojw), PGRAD(attprojb), AGRAD(residual2), ACT(atty), PARAM(attprojw), scratchF, B*T, C, C); #ifdef ENABLE_CUDNN - attention_backward_cudnn(AGRAD_XL(qkvr), AGRAD_XL(atty), ACT_XL(qkvr), ACT_XL(atty), ACT_32L(att), B, T, NH, C, main_stream); + attention_backward_cudnn(AGRAD(qkvr), AGRAD(atty), ACT(qkvr), ACT(atty), ACT(att), B, T, NH, C); #else // we need B x T x (4)C buffers. 
l_atty and l_fch aren't needed anymore at this point, so reuse their memory - floatX* buffer_a = ACT_XL(atty); - floatX* buffer_b = ACT_XL(fch); - attention_backward(AGRAD_XL(qkvr), buffer_b, scratchX_HUGE, buffer_a, AGRAD_XL(atty), ACT_XL(qkvr), ACT_XL(att), B, T, C, NH, main_stream); + floatX* buffer_a = ACT(atty); + floatX* buffer_b = ACT(fch); + attention_backward(AGRAD(qkvr), buffer_b, scratchX_HUGE, buffer_a, AGRAD(atty), ACT(qkvr), ACT(att), B, T, C, NH); #endif if(model->recompute >= 2) { - layernorm_forward(ACT_XL(ln1), ACT_32L(ln1_mean), ACT_32L(ln1_rstd), residual, PARAM_XL(ln1w), PARAM_XL(ln1b), B, T, C, main_stream); + layernorm_forward(ACT(ln1), ACT(ln1_mean), ACT(ln1_rstd), residual, PARAM(ln1w), PARAM(ln1b), B*T, C); } - matmul_backward(AGRAD_XL(ln1), PGRAD_XL(qkvw), PGRAD_XL(qkvb), AGRAD_XL(qkvr), ACT_XL(ln1), PARAM_XL(qkvw), scratchF, B, T, C, 3 * C, main_stream); - layernorm_backward(dresidual, AGRAD_XL(residual2), PGRAD_XL(ln1w), PGRAD_XL(ln1b), scratchF, AGRAD_XL(ln1), residual, PARAM_XL(ln1w), ACT_32L(ln1_mean), ACT_32L(ln1_rstd), B, T, C, main_stream); + matmul_backward(AGRAD(ln1), PGRAD(qkvw), PGRAD(qkvb), AGRAD(qkvr), ACT(ln1), PARAM(qkvw), scratchF, B*T, C, 3 * C); + layernorm_backward(dresidual, AGRAD(residual2), PGRAD(ln1w), PGRAD(ln1b), scratchF, AGRAD(ln1), residual, PARAM(ln1w), ACT(ln1_mean), ACT(ln1_rstd), B*T, C); // Accumulate gradients from this layer in a background stream. if(last_step) { floatX* const pointers[] = { - PGRAD_XL(ln1w), PGRAD_XL(ln1b), - PGRAD_XL(qkvw), PGRAD_XL(qkvb), - PGRAD_XL(attprojw), PGRAD_XL(attprojb), - PGRAD_XL(ln2w), PGRAD_XL(ln2b), - PGRAD_XL(fcw), PGRAD_XL(fcb), - PGRAD_XL(fcprojw), PGRAD_XL(fcprojb) + PGRAD(ln1w), PGRAD(ln1b), + PGRAD(qkvw), PGRAD(qkvb), + PGRAD(attprojw), PGRAD(attprojb), + PGRAD(ln2w), PGRAD(ln2b), + PGRAD(fcw), PGRAD(fcb), + PGRAD(fcprojw), PGRAD(fcprojb) }; const size_t nelem[] = { C, C, @@ -914,20 +941,20 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int } } - encoder_backward(PGRAD_X(wte), PGRAD_X(wpe), scratchX_HUGE, model->workload_indices, model->bucket_info, - AGRAD_X(encoded), model->inputs, inputs, B, T, C, random_u32(&model->rng_state), main_stream); + encoder_backward(PGRAD(wte), PGRAD(wpe), scratchX_HUGE, model->workload_indices, model->bucket_info, + AGRAD(encoded), model->inputs, inputs, B, T, C, random_u32(&model->rng_state)); // Aggregate all gradients that are not part of the transformer blocks if(last_step) { // reduce all the losses within the current GPU (across all microsteps) - global_sum_deterministic(model->accumulated_mean_loss, ACT_32(losses), B*T, main_stream); + global_sum_deterministic(model->accumulated_mean_loss, ACT(losses), B*T, main_stream); // reduce loss across GPUs to a single, final float across all microsteps and GPUs #if MULTI_GPU ncclCheck(ncclAllReduce(model->accumulated_mean_loss, model->accumulated_mean_loss, sizeof(float), ncclFloat, ncclAvg, multi_gpu_config.nccl_comm, main_stream)); #endif cudaCheck(cudaMemcpyAsync(&model->mean_loss, model->accumulated_mean_loss, sizeof(float), cudaMemcpyDeviceToHost, main_stream)); // reduce the gradients for non-transformer block parameters - floatX* const pointers[] = {PGRAD_X(wte), PGRAD_X(wpe), PGRAD_X(lnfw), PGRAD_X(lnfb)}; + floatX* const pointers[] = {PGRAD(wte), PGRAD(wpe), PGRAD(lnfw), PGRAD(lnfb)}; const size_t nelem[] = {Vp * C, T * C, C, C}; multi_gpu_async_reduce_gradient(pointers, nelem, &multi_gpu_config, main_stream); } @@ -962,10 +989,9 @@ ShardInfo 
gpt2_get_tensor_at_layer(const GPT2 *model, int layer_id, int param_te float gpt2_calculate_grad_norm(GPT2 *model, MultiGpuConfig* multi_gpu_config) { NVTX_RANGE_FN(); floatX* grads_memory = (floatX*)model->params_memory[PARAMETER_GRADIENT]; - ActivationTensors acts = model->acts; // repurposing this buffer (which isn't needed now) to write grad norm into it - float* grad_norm_squared = ACT_32(output); + float* grad_norm_squared = MULTI_L(output_scratch_fp32, 0); float grad_norm_squared_cpu = 0.0f; int num_slices[2] = {1, model->config.num_layers}; @@ -1029,7 +1055,7 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo // save RNG state at this point so we can round from master weights identically when restoring from a checkpoint model->rng_state_last_update = model->rng_state; - // todo: merge all tensors into 1 kerne + // todo: merge everything into 1 kernel call for (int i = 0; i < tensors_start[PARAMETER_GRADIENT];) { unsigned int seed = random_u32(&model->rng_state); @@ -1039,14 +1065,11 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo TensorSpec opt_m_spec = tensor_specs[i + tensors_start[PARAMETER_OPT_M]]; TensorSpec opt_v_spec = tensor_specs[i + tensors_start[PARAMETER_OPT_V]]; - floatX* param_ptr = (floatX*)(&model->params_memory[PARAMETER][param_spec.offset]); - floatX* grad_ptr = (floatX*)(&model->params_memory[PARAMETER_GRADIENT][grad_spec.offset]); - float* m_ptr = (float*)(&model->params_memory[PARAMETER_OPT_M][opt_m_spec.offset]); - float* v_ptr = (float*)(&model->params_memory[PARAMETER_OPT_V][opt_v_spec.offset]); - + // todo - adjust offset into params/grads when optimiser state is sharded + floatX* param_ptr = (floatX*)param_spec.ptr; float* master_ptr = NULL; if (model->params_memory[PARAMETER_MASTER] != NULL) { - master_ptr = (float*)(&model->params_memory[PARAMETER_MASTER][master_spec.offset]); + master_ptr = (float*)master_spec.ptr; } size_t tensor_elements = param_spec.num_elements; @@ -1068,8 +1091,8 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo assert(false); } else { // ok finally call the kernel to update the weights with AdamW - adamw_update(param_ptr, master_ptr, grad_ptr, - m_ptr, v_ptr, + adamw_update(param_ptr, master_ptr, (floatX*)grad_spec.ptr, + (float*)opt_m_spec.ptr, (float*)opt_v_spec.ptr, shard_elements, tensor_elements, tensor_elements, shard_elements, num_layers, learning_rate, beta1, beta2, t, eps, wd, grad_scale, seed, main_stream); From 65264ed2f20e21077b76dda6227873c07f3df3b7 Mon Sep 17 00:00:00 2001 From: ademeure Date: Mon, 2 Sep 2024 22:22:13 +0000 Subject: [PATCH 07/27] Activation checkpointing for entire layers is working! --- train_gpt2.cu | 330 +++++++++++++++++++++++++++++--------------------- 1 file changed, 191 insertions(+), 139 deletions(-) diff --git a/train_gpt2.cu b/train_gpt2.cu index 256bfba46..66dba6ce7 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -2,6 +2,7 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage. 
*/ #define UNIQUE_TENSOR_MEMORY false +#define LAYERS_PER_ACTIVATION_CHECKPOINT 1 // 0 = disabled #include #include @@ -89,11 +90,23 @@ constexpr const size_t IO_BUF_SIZE = 32 * 1024 * 1024; // GPT-2 model definition enum TT : uint8_t { - PARAMETER=0, PARAMETER_GRADIENT, PARAMETER_MASTER, PARAMETER_OPT_M, PARAMETER_OPT_V, // 1 allocation each + PARAMETER=0, PARAMETER_GRAD, PARAMETER_MASTER, PARAMETER_OPT_M, PARAMETER_OPT_V, // 1 allocation each ACTIVATIONS_MULTIUSE, // single buffer shared for activations, activation gradients, and scratch DEFAULT, COUNT=DEFAULT, NUM_TYPES_PARAM=PARAMETER_OPT_V+1 }; +enum TFlags : uint8_t { + NONE=0, + REUSED_MEMORY=1, + GRADIENT=2, + TENSOR_2D=4, + BIAS=8, + LAYERNORM=16, + RESIDUAL=32, + EMBEDDING=64, + STATS=128 +}; + typedef struct { int wte, wpe, lnfw, lnfb; // not per layer int ln1w, ln1b, qkvw, qkvb, attprojw, attprojb, ln2w, ln2b, fcw, fcb, fcprojw, fcprojb; // per layer @@ -164,6 +177,7 @@ typedef struct { short remaining_layers; DType data_type; TT tensor_type; + int flags; char name[16]; template @@ -198,7 +212,7 @@ TensorSpec get_tensor(int spec_index, TT tensor_type, int layer) { return spec; } -int add_tensor_spec(const char* name, size_t total_elements, size_t num_shards, DType data_type, int copy_offset_from=-1, TT tensor_type=TT::DEFAULT) { +int add_tensor_spec(const char* name, size_t total_elements, size_t num_shards, DType data_type, int copy_offset_from=-1, int flags=TFlags::NONE, TT tensor_type=TT::DEFAULT) { assert(num_tensor_specs < 16*1024); assert((total_elements % num_shards) == 0); TensorSpec* spec = &tensor_specs[num_tensor_specs]; @@ -210,13 +224,20 @@ int add_tensor_spec(const char* name, size_t total_elements, size_t num_shards, spec->remaining_layers = 0; spec->data_type = data_type; spec->tensor_type = (tensor_type == TT::DEFAULT) ? 
current_tensor_type : tensor_type; + spec->flags = flags; tensors_elements[spec->tensor_type] += spec->num_elements; if (copy_offset_from >= 0) { - spec->offset = tensor_specs[copy_offset_from].offset; - size_t original_tensor_bytes = tensor_specs[copy_offset_from].num_elements * sizeof_dtype(tensor_specs[copy_offset_from].data_type); + TensorSpec base_spec = tensor_specs[copy_offset_from]; + spec->offset = base_spec.offset; + size_t original_tensor_bytes = base_spec.num_elements * sizeof_dtype(base_spec.data_type); size_t new_tensor_bytes = spec->num_elements * sizeof_dtype(data_type); - assert(tensor_specs[copy_offset_from].tensor_type == spec->tensor_type); + if (base_spec.tensor_type != spec->tensor_type) { + printf("ERROR: tensor_type mismatch for %s: %d vs %d\n", + spec->name, (int)base_spec.tensor_type, (int)spec->tensor_type); + assert(false); + } + assert(base_spec.tensor_type == spec->tensor_type); assert(new_tensor_bytes <= original_tensor_bytes); } else { spec->offset = tensors_bytes[spec->tensor_type]; @@ -228,12 +249,20 @@ int add_tensor_spec(const char* name, size_t total_elements, size_t num_shards, return num_tensor_specs++; } -int add_layer_specs(int num_layers, const char* name, size_t total_elements, size_t num_shards, DType data_type, int copy_offset_from=-1, bool copy_per_layer=false, TT tensor_type=TT::DEFAULT) { +int add_layer_specs(int num_layers, const char* name, size_t total_elements, size_t num_shards, DType data_type, + int copy_offset_from=-1, int flags=TFlags::NONE, bool copy_per_layer=false, + int reuse_every_n_layers=0, TT tensor_type=TT::DEFAULT) { int first_tensor_id = num_tensor_specs; + if (reuse_every_n_layers > 0 && num_layers > 1) { + flags |= REUSED_MEMORY; + } for (int l = 0; l < num_layers; l++) { char layer_name[16]; assert(snprintf(layer_name, 16, "%s_%d", name, l) >= 0); - int spec = add_tensor_spec(num_layers > 1 ? layer_name : name, total_elements, num_shards, data_type, copy_offset_from, tensor_type); + if (reuse_every_n_layers > 0 && l >= reuse_every_n_layers) { + copy_offset_from = first_tensor_id + (l % reuse_every_n_layers); + } + int spec = add_tensor_spec(num_layers > 1 ? 
layer_name : name, total_elements, num_shards, data_type, copy_offset_from, flags, tensor_type); if (copy_per_layer) { copy_offset_from++; } @@ -242,9 +271,9 @@ int add_layer_specs(int num_layers, const char* name, size_t total_elements, siz return first_tensor_id; } -#define TENSOR_SPECS(name, dim1, dim2) spec->name = add_layer_specs(dim1, #name, dim2, shards, dtype) -#define TENSOR_SPECS_LOWP(name, dim1, dim2) spec->name = add_layer_specs(dim1, #name, dim2, shards, dtype_lowp) -#define TENSOR_SPECS_FP32(name, dim1, dim2) spec->name = add_layer_specs(dim1, #name, dim2, shards, DType::FP32) // todo - won't work loading model +#define TENSOR_SPECS(name, layers, dim, flags) spec->name = add_layer_specs(layers, #name, dim, shards, dtype, -1, flags, false, reuse_every_n) +#define TENSOR_SPECS_LOWP(name, layers, dim, flags) spec->name = add_layer_specs(layers, #name, dim, shards, dtype_lowp, -1, flags, false, reuse_every_n) +#define TENSOR_SPECS_FP32(name, layers, dim, flags) spec->name = add_layer_specs(layers, #name, dim, shards, DType::FP32, -1, flags, false, reuse_every_n) void gpt2_allocate(GPT2 *model) { size_t Vp = model->config.padded_vocab_size; @@ -258,54 +287,54 @@ void gpt2_allocate(GPT2 *model) { size_t BTC = B*T*C; size_t BT = B*T; - size_t shards = 1; + int reuse_every_n = 0; + int shards = 1; int num_gpu = multi_gpu_config.num_processes; int shards_opt = (multi_gpu_config.zero_stage >= 1) ? num_gpu : 1; int shards_grad = (multi_gpu_config.zero_stage >= 2) ? num_gpu : 1; // 1) parameters & optimizer state for (int t = PARAMETER; t <= PARAMETER_OPT_V; t++) { - DType dtype = (t <= PARAMETER_GRADIENT) ? DTYPE_FLOATX : DType::FP32; - DType dtype_lowp = (t <= PARAMETER_GRADIENT) ? DTYPE_FLOATX : DType::FP32; // FP8 in the future + DType dtype = (t <= PARAMETER_GRAD) ? DTYPE_FLOATX : DType::FP32; + DType dtype_lowp = (t <= PARAMETER_GRAD) ? DTYPE_FLOATX : DType::FP32; // FP8 in the future current_tensor_type = (TT)t; ParameterTensors* spec = &model->params[t]; - shards = (t == PARAMETER) ? 1 : (t == PARAMETER_GRADIENT) ? shards_grad : shards_opt; + shards = (t == PARAMETER) ? 1 : (t == PARAMETER_GRAD) ? 
shards_grad : shards_opt; if (t == PARAMETER_MASTER && !model->use_master_weights) { continue; } - TENSOR_SPECS (wte, 1, Vp * C); - TENSOR_SPECS (wpe, 1, maxT * C); - TENSOR_SPECS (ln1w, L, C); - TENSOR_SPECS (ln1b, L, C); - TENSOR_SPECS_LOWP(qkvw, L, 3 * C * C); - TENSOR_SPECS (qkvb, L, 3 * C); - TENSOR_SPECS_LOWP(attprojw, L, C * C); - TENSOR_SPECS (attprojb, L, C); - TENSOR_SPECS (ln2w, L, C); - TENSOR_SPECS (ln2b, L, C); - TENSOR_SPECS_LOWP(fcw, L, 4 * C * C); - TENSOR_SPECS_LOWP(fcb, L, 4 * C); - TENSOR_SPECS_LOWP(fcprojw, L, 4 * C * C); - TENSOR_SPECS (fcprojb, L, C); - TENSOR_SPECS (lnfw, 1, C); - TENSOR_SPECS (lnfb, 1, C); + TENSOR_SPECS (wte, 1, Vp * C, TENSOR_2D | EMBEDDING); + TENSOR_SPECS (wpe, 1, maxT * C, TENSOR_2D | EMBEDDING); + TENSOR_SPECS (ln1w, L, C, LAYERNORM); + TENSOR_SPECS (ln1b, L, C, LAYERNORM | BIAS); + TENSOR_SPECS_LOWP(qkvw, L, 3 * C * C, TENSOR_2D); + TENSOR_SPECS (qkvb, L, 3 * C, BIAS); + TENSOR_SPECS_LOWP(attprojw, L, C * C, TENSOR_2D); + TENSOR_SPECS (attprojb, L, C, BIAS); + TENSOR_SPECS (ln2w, L, C, LAYERNORM); + TENSOR_SPECS (ln2b, L, C, LAYERNORM | BIAS); + TENSOR_SPECS_LOWP(fcw, L, 4 * C * C, TENSOR_2D); + TENSOR_SPECS_LOWP(fcb, L, 4 * C, BIAS); + TENSOR_SPECS_LOWP(fcprojw, L, 4 * C * C, TENSOR_2D); + TENSOR_SPECS (fcprojb, L, C, BIAS); + TENSOR_SPECS (lnfw, 1, C, LAYERNORM); + TENSOR_SPECS (lnfb, 1, C, LAYERNORM | BIAS); } // 2) multiuse & scratch tensors current_tensor_type = ACTIVATIONS_MULTIUSE; - /*if (UNIQUE_TENSOR_MEMORY) { + if (UNIQUE_TENSOR_MEMORY) { model->multiuse.bt4c = -1; model->multiuse.btc = -1; - } else*/ { - model->multiuse.bt4c = add_tensor_spec("multiuse_bt4c", 4 * BTC, 1, DTYPE_FLOATX); - model->multiuse.btc = add_tensor_spec("multiuse_btc", BTC, 1, DTYPE_FLOATX); - model->multiuse.local_scratch = add_tensor_spec("local_scratch", BTC, 1, DType::FP32); // todo - is this oversized? - model->multiuse.local_scratch = add_tensor_spec("local_scratch", BTC, 1, DType::FP32); // todo - is this oversized? - model->multiuse.output_scratch = add_tensor_spec("output_fpx", output_size, 1, DTYPE_FLOATX); - model->multiuse.output_scratch_fp32 = add_tensor_spec("output_fp32", output_size / 2, 1, DType::FP32, model->multiuse.output_scratch); + } else { + model->multiuse.bt4c = add_tensor_spec("multiuse_bt4c", 4 * BTC, 1, DTYPE_FLOATX, -1, REUSED_MEMORY); + model->multiuse.btc = add_tensor_spec("multiuse_btc", BTC, 1, DTYPE_FLOATX, -1, REUSED_MEMORY); } + model->multiuse.local_scratch = add_tensor_spec("local_scratch", BTC, 1, DType::FP32, -1, REUSED_MEMORY); // todo - is this avoidable (or oversized)? 
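// A rough sketch of the aliasing set up for the two output scratch views below, assuming
// floatX is a 2-byte type (BF16/FP16) and output_size is even:
//   output_fpx  : output_size       elements * sizeof(floatX) = 2 * output_size bytes
//   output_fp32 : (output_size / 2) elements * sizeof(float)  = 2 * output_size bytes
// Passing model->multiuse.output_scratch as copy_offset_from makes add_tensor_spec reuse the
// base tensor's byte offset instead of allocating a new range, and only asserts that the new
// view is no larger, so both specs resolve to the same region of the single multiuse buffer.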
+ model->multiuse.output_scratch = add_tensor_spec("output_fpx", output_size, 1, DTYPE_FLOATX, -1, REUSED_MEMORY); + model->multiuse.output_scratch_fp32 = add_tensor_spec("output_fp32", output_size / 2, 1, DType::FP32, model->multiuse.output_scratch, REUSED_MEMORY); // 3) activations ActivationTensors* spec = &model->acts; @@ -313,85 +342,96 @@ void gpt2_allocate(GPT2 *model) { DType dtype = DTYPE_FLOATX; shards = 1; - TENSOR_SPECS (encoded, 1, BTC); - TENSOR_SPECS (lnf, 1, BTC); - TENSOR_SPECS_FP32(lnf_mean, 1, BT); - TENSOR_SPECS_FP32(lnf_rstd, 1, BT); - TENSOR_SPECS_FP32(losses, 1, BT); - - TENSOR_SPECS_FP32(ln1_mean, L, BT); - TENSOR_SPECS_FP32(ln1_rstd, L, BT); - TENSOR_SPECS (atty, L, BTC); - TENSOR_SPECS (residual2, L, BTC); - TENSOR_SPECS_FP32(ln2_mean, L, BT); - TENSOR_SPECS_FP32(ln2_rstd, L, BT); - TENSOR_SPECS (residual3, L, BTC); - TENSOR_SPECS_LOWP(fch, L, 4 * BTC); - TENSOR_SPECS (qkvr, L, 3 * BTC); + // with activation checkpointing, we keep every layer's residual3 for simplicity + // in theory, if we have e.g. 4 layers per checkpoint, we could have 1/4 as many residual3 + // but that would complicate everything a lot for relatively little benefit... + TENSOR_SPECS (residual3, L, BTC, RESIDUAL); + reuse_every_n = LAYERS_PER_ACTIVATION_CHECKPOINT; + assert(!reuse_every_n || (L % reuse_every_n) == 0); + + TENSOR_SPECS (encoded, 1, BTC, 0); + TENSOR_SPECS (lnf, 1, BTC, 0); + TENSOR_SPECS_FP32(lnf_mean, 1, BT, LAYERNORM | STATS); + TENSOR_SPECS_FP32(lnf_rstd, 1, BT, LAYERNORM | STATS); + TENSOR_SPECS_FP32(losses, 1, BT, 0); + + TENSOR_SPECS_FP32(ln1_mean, L, BT, LAYERNORM | STATS); + TENSOR_SPECS_FP32(ln1_rstd, L, BT, LAYERNORM | STATS); + TENSOR_SPECS (atty, L, BTC, 0); + TENSOR_SPECS (residual2, L, BTC, RESIDUAL); + TENSOR_SPECS_FP32(ln2_mean, L, BT, LAYERNORM | STATS); + TENSOR_SPECS_FP32(ln2_rstd, L, BT, LAYERNORM | STATS); + TENSOR_SPECS_LOWP(fch, L, 4 * BTC, 0); + TENSOR_SPECS (qkvr, L, 3 * BTC, 0); #ifdef ENABLE_CUDNN - TENSOR_SPECS_FP32(att, L, NH * B * T); + TENSOR_SPECS_FP32(att, L, NH * B * T, 0); #else - TENSOR_SPECS (att, L, NH * B * T * T); + TENSOR_SPECS (att, L, NH * B * T * T, 0); #endif if (UNIQUE_TENSOR_MEMORY) { - TENSOR_SPECS (output, 1, output_size); - TENSOR_SPECS_LOWP(fcproj, L, BTC); - TENSOR_SPECS_LOWP(attproj, L, BTC); + TENSOR_SPECS (output, 1, output_size, 0); + TENSOR_SPECS_LOWP(fcproj, L, BTC, 0); + TENSOR_SPECS_LOWP(attproj, L, BTC, 0); } else { - spec->output = add_tensor_spec("output", output_size, shards, dtype, model->multiuse.output_scratch); - spec->fcproj = add_layer_specs(L, "fcproj", BTC, shards, dtype_lowp, model->multiuse.btc); - spec->attproj = add_layer_specs(L, "attproj", BTC, shards, dtype_lowp, model->multiuse.btc); + spec->output = add_tensor_spec("output", output_size, shards, dtype, model->multiuse.output_scratch, REUSED_MEMORY); + spec->fcproj = add_layer_specs(L, "fcproj", BTC, shards, dtype_lowp, model->multiuse.btc, REUSED_MEMORY); + spec->attproj = add_layer_specs(L, "attproj", BTC, shards, dtype_lowp, model->multiuse.btc, REUSED_MEMORY); } // optionally reuse the same activation buffer at each layer and re-compute the gelu during backward // very useful because we dramatically reduce VRAM usage, and may be able to fit larger batch size if (model->recompute < 1 || UNIQUE_TENSOR_MEMORY) { - TENSOR_SPECS(ln1, L, BTC); - TENSOR_SPECS(ln2, L, BTC); - TENSOR_SPECS(fch_gelu, L, 4 * BTC); + TENSOR_SPECS(ln1, L, BTC, LAYERNORM); + TENSOR_SPECS(ln2, L, BTC, LAYERNORM); + TENSOR_SPECS(fch_gelu, L, 4 * BTC, 0); } else if 
(model->recompute < 2) { - TENSOR_SPECS(ln1, L, BTC); - TENSOR_SPECS(ln2, L, BTC); - spec->fch_gelu = add_layer_specs(L, "fch_gelu", 4 * BTC, shards, dtype_lowp, model->acts.output); + TENSOR_SPECS(ln1, L, BTC, LAYERNORM); + TENSOR_SPECS(ln2, L, BTC, LAYERNORM); + spec->fch_gelu = add_layer_specs(L, "fch_gelu", 4 * BTC, shards, dtype_lowp, model->acts.output, REUSED_MEMORY); } else { - spec->ln1 = add_layer_specs(L, "ln1", BTC, shards, dtype, model->acts.lnf); - spec->ln2 = add_layer_specs(L, "ln2", BTC, shards, dtype, model->acts.lnf); - spec->fch_gelu = add_layer_specs(L, "fch_gelu", 4 * BTC, shards, dtype_lowp, model->acts.output); + spec->ln1 = add_layer_specs(L, "ln1", BTC, shards, dtype, model->multiuse.btc, LAYERNORM | REUSED_MEMORY); + spec->ln2 = add_layer_specs(L, "ln2", BTC, shards, dtype, model->multiuse.btc, LAYERNORM | REUSED_MEMORY); + spec->fch_gelu = add_layer_specs(L, "fch_gelu", 4 * BTC, shards, dtype_lowp, model->acts.output, REUSED_MEMORY); } // 4) activation gradients + // todo - specify subtype! + reuse_every_n = 0; spec = &model->acts_grads; dtype_lowp = DTYPE_FLOATX; // todo FP8 shards = 1; if (UNIQUE_TENSOR_MEMORY) { - TENSOR_SPECS(encoded, 1, BTC); - TENSOR_SPECS(output, 1, output_size); - TENSOR_SPECS(lnf, 1, BTC); - TENSOR_SPECS(ln1, L, BTC); - TENSOR_SPECS(atty, L, BTC); - TENSOR_SPECS(residual2, L, BTC); - TENSOR_SPECS(ln2, L, BTC); - TENSOR_SPECS(fch, L, 4 * BTC); - TENSOR_SPECS(fch_gelu, L, 4 * BTC); - TENSOR_SPECS(residual3, L, BTC); - TENSOR_SPECS(qkvr, L, 3 * BTC); + TENSOR_SPECS(encoded, 1, BTC, GRADIENT); + TENSOR_SPECS(output, 1, output_size, GRADIENT); + TENSOR_SPECS(lnf, 1, BTC, GRADIENT | LAYERNORM); + TENSOR_SPECS(ln1, L, BTC, GRADIENT | LAYERNORM); + TENSOR_SPECS(atty, L, BTC, GRADIENT); + TENSOR_SPECS(residual2, L, BTC, GRADIENT | RESIDUAL); + TENSOR_SPECS(ln2, L, BTC, GRADIENT | LAYERNORM); + TENSOR_SPECS(fch, L, 4 * BTC, GRADIENT); + TENSOR_SPECS(fch_gelu, L, 4 * BTC, GRADIENT); + TENSOR_SPECS(residual3, L, BTC, GRADIENT | RESIDUAL); + TENSOR_SPECS(qkvr, L, 3 * BTC, GRADIENT); } else { - spec->output = add_layer_specs(1, "output", output_size, 1, dtype, model->multiuse.output_scratch); - - int reused_btc = model->acts.residual3 + (L-1); - spec->ln1 = add_layer_specs(L, "ln1", BTC, 1, dtype, reused_btc); - spec->atty = add_layer_specs(L, "atty", BTC, 1, dtype, reused_btc); - spec->ln2 = add_layer_specs(L, "ln2", BTC, 1, dtype, reused_btc); - - spec->lnf = add_layer_specs(1, "lnf", BTC, 1, dtype, model->multiuse.btc); - spec->encoded = add_layer_specs(1, "encoded", BTC, 1, dtype, model->multiuse.btc); - spec->residual2 = add_layer_specs(L, "residual2", BTC, 1, dtype, model->multiuse.btc); - spec->residual3 = add_layer_specs(L, "residual3", BTC, 1, dtype, model->multiuse.btc); - spec->fch = add_layer_specs(L, "fch", 4 * BTC, 1, dtype, model->multiuse.bt4c); - spec->fch_gelu = add_layer_specs(L, "fch_gelu", 4 * BTC, 1, dtype, model->multiuse.bt4c); - spec->qkvr = add_layer_specs(L, "qkvr", 3 * BTC, 1, dtype, model->multiuse.bt4c); + spec->output = add_layer_specs(1, "output", output_size, 1, dtype, model->multiuse.output_scratch, GRADIENT); + + int reused_btc = model->acts.residual3 + (L-1); // todo - check if this works with activation checkpointing + spec->ln1 = add_layer_specs(L, "ln1", BTC, 1, dtype, reused_btc, GRADIENT | LAYERNORM); + spec->ln2 = add_layer_specs(L, "ln2", BTC, 1, dtype, reused_btc, GRADIENT | LAYERNORM); + spec->atty = add_layer_specs(L, "atty", BTC, 1, dtype, reused_btc, GRADIENT); + + int reused_btc2 = model->acts.lnf; + 
spec->residual2 = add_layer_specs(L, "residual2", BTC, 1, dtype, reused_btc2, GRADIENT | RESIDUAL); + spec->residual3 = add_layer_specs(L, "residual3", BTC, 1, dtype, reused_btc2, GRADIENT | RESIDUAL); + spec->encoded = add_layer_specs(1, "encoded", BTC, 1, dtype, reused_btc2, GRADIENT); + + // (lnf doesn't need bt4c but it's free at this point unlike the other buffers) + spec->lnf = add_layer_specs(1, "lnf", BTC, 1, dtype, model->multiuse.bt4c, GRADIENT | LAYERNORM); + spec->fch_gelu = add_layer_specs(L, "fch_gelu", 4 * BTC, 1, dtype, model->multiuse.bt4c, GRADIENT); + spec->fch = add_layer_specs(L, "fch", 4 * BTC, 1, dtype, model->multiuse.bt4c, GRADIENT); + spec->qkvr = add_layer_specs(L, "qkvr", 3 * BTC, 1, dtype, model->multiuse.bt4c, GRADIENT); } // allocate a single huge GPU buffer for all the tensors @@ -399,7 +439,7 @@ void gpt2_allocate(GPT2 *model) { cudaCheck(cudaMemset(model->multiuse_memory, 0, tensors_bytes[ACTIVATIONS_MULTIUSE])); cudaCheck(cudaMalloc(&model->params_memory[PARAMETER], tensors_bytes[PARAMETER])); - cudaCheck(cudaMalloc(&model->params_memory[PARAMETER_GRADIENT], tensors_bytes[PARAMETER_GRADIENT])); + cudaCheck(cudaMalloc(&model->params_memory[PARAMETER_GRAD], tensors_bytes[PARAMETER_GRAD])); cudaCheck(cudaMalloc(&model->params_memory[PARAMETER_OPT_M], tensors_bytes[PARAMETER_OPT_M])); cudaCheck(cudaMalloc(&model->params_memory[PARAMETER_OPT_V], tensors_bytes[PARAMETER_OPT_V])); if (model->use_master_weights) { @@ -410,21 +450,12 @@ void gpt2_allocate(GPT2 *model) { for (size_t i = 0; i < num_tensor_specs; i++) { TensorSpec* spec = &tensor_specs[i]; switch (spec->tensor_type) { - case PARAMETER: - case PARAMETER_GRADIENT: - case PARAMETER_OPT_M: - case PARAMETER_OPT_V: - spec->ptr = model->params_memory[spec->tensor_type] + spec->offset; - break; - case PARAMETER_MASTER: - if (model->use_master_weights) { - spec->ptr = model->params_memory[PARAMETER_MASTER] + spec->offset; - } - break; case ACTIVATIONS_MULTIUSE: spec->ptr = model->multiuse_memory + spec->offset; break; - default: assert(false); + default: + assert(spec->tensor_type <= PARAMETER_OPT_V); + spec->ptr = model->params_memory[spec->tensor_type] + spec->offset; } } @@ -433,7 +464,7 @@ void gpt2_allocate(GPT2 *model) { model->num_parameters_bytes = tensors_bytes[TT::PARAMETER]; printf("number of parameter bytes: %zu MiB\n", tensors_bytes[TT::PARAMETER] / (1024*1024)); - printf("number of parameter gradient bytes: %zu MiB\n", tensors_bytes[TT::PARAMETER_GRADIENT] / (1024*1024)); + printf("number of parameter gradient bytes: %zu MiB\n", tensors_bytes[TT::PARAMETER_GRAD] / (1024*1024)); printf("number of master weight bytes: %zu MiB\n", tensors_bytes[TT::PARAMETER_MASTER] / (1024*1024)); printf("number of m bytes: %zu MiB\n", tensors_bytes[TT::PARAMETER_OPT_M] / (1024*1024)); printf("number of v bytes: %zu MiB\n", tensors_bytes[TT::PARAMETER_OPT_V] / (1024*1024)); @@ -736,17 +767,17 @@ void print_tensor_elements(GPT2 *model, int tensor_id) { } // Helper macros for accessing tensors -#define ACT_L(x,l) get_tensor(model->acts.x, ACTIVATIONS_MULTIUSE, l) -#define AGRAD_L(x,l) get_tensor(model->acts_grads.x, ACTIVATIONS_MULTIUSE, l) -#define PARAM_L(x,l) get_tensor(model->params[PARAMETER].x, PARAMETER, l) -#define PGRAD_L(x,l) get_tensor(model->params[PARAMETER_GRADIENT].x, PARAMETER_GRADIENT, l) -#define MULTI_L(x,l) get_tensor(model->multiuse.x, ACTIVATIONS_MULTIUSE, l) -#define TENSOR(x,l) get_tensor(x, DEFAULT, l) -#define ACT(x) ACT_L(x,l) -#define AGRAD(x) AGRAD_L(x,l) -#define PARAM(x) PARAM_L(x,l) 
-#define PGRAD(x) PGRAD_L(x,l) -#define MULTI(x) MULTI_L(x,l) +#define TENSOR(x,layer) get_tensor(x, DEFAULT, layer) +#define ACT_L(x,layer) get_tensor(model->acts.x, ACTIVATIONS_MULTIUSE, layer) +#define MULTI_L(x,layer) get_tensor(model->multiuse.x, ACTIVATIONS_MULTIUSE, layer) +#define AGRAD_L(x,layer) get_tensor(model->acts_grads.x, ACTIVATIONS_MULTIUSE, layer) +#define PARAM_L(x,layer) get_tensor(model->params[PARAMETER].x, PARAMETER, layer) +#define PGRAD_L(x,layer) get_tensor(model->params[PARAMETER_GRAD].x, PARAMETER_GRAD, layer) +#define ACT(x) ACT_L(x,l) +#define MULTI(x) MULTI_L(x,l) +#define AGRAD(x) AGRAD_L(x,l) +#define PARAM(x) PARAM_L(x,l) +#define PGRAD(x) PGRAD_L(x,l) // propagate inputs through the network to produce logits. void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { @@ -815,19 +846,17 @@ float gpt2_validate(GPT2 *model, const int* inputs, const int* targets, size_t B // convenience shortcuts, size_t instead of int so that pointer arithmetics don't overflow const size_t V = model->config.vocab_size; const size_t Vp = model->config.padded_vocab_size; - int l = 0; NvtxRange classifier_and_loss_range("classifier_and_loss"); - ActivationTensors acts = model->acts; float mean_loss = 0.0f; // fused classifier: does the forward pass and first part of the backward pass const float dloss = 1.0f / (B * T); // results in the uniform average loss over all elements // note: we don't need to generate dlogits here - cudaCheck(cudaMemset(ACT(losses), 0, B*T*sizeof(float))); + cudaCheck(cudaMemset(ACT_L(losses, 0), 0, B*T*sizeof(float))); cudaCheck(cudaMemcpy(model->targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice)); tokenCheck(targets, B*T, V); // while the memcpy is underway, validate the targets - fused_classifier(ACT(output), ACT(output), ACT(losses), dloss, model->targets, B*T, V, Vp, False); - cudaCheck(cudaMemcpy(model->cpu_losses, ACT(losses), B * T * sizeof(float), cudaMemcpyDeviceToHost)); + fused_classifier(ACT_L(output, 0), ACT_L(output, 0), ACT_L(losses, 0), dloss, model->targets, B*T, V, Vp, False); + cudaCheck(cudaMemcpy(model->cpu_losses, ACT_L(losses, 0), B * T * sizeof(float), cudaMemcpyDeviceToHost)); for (int i = 0; i < B*T; i++) { mean_loss += model->cpu_losses[i]; } @@ -837,10 +866,6 @@ float gpt2_validate(GPT2 *model, const int* inputs, const int* targets, size_t B } void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int grad_accum_steps, int micro_step) { - if(model->params_memory[PARAMETER_GRADIENT] == nullptr) { - fprintf(stderr, "Need to allocate gradients before backward"); - exit(EXIT_FAILURE); - } NVTX_RANGE_FN(); // convenience shortcuts (size_t instead of int so that pointer arithmetics don't overflow) @@ -851,7 +876,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int const size_t L = model->config.num_layers; const size_t NH = model->config.num_heads; const size_t C = model->config.channels; - int l = L-1; + int l = L-1; // start from the last layer bool last_step = micro_step == grad_accum_steps - 1; // on the first micro-step zero the gradients, as we're about to += accumulate into them @@ -860,7 +885,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int // 1) the losses accumulate += into acts.losses, reset here // 2) the gradients accumulate += into grads_memory, reset here cudaCheck(cudaMemsetAsync(ACT(losses), 0, B * T * sizeof(float), main_stream)); - cudaCheck(cudaMemsetAsync(model->params_memory[PARAMETER_GRADIENT], 
0, tensors_bytes[PARAMETER_GRADIENT], main_stream)); + cudaCheck(cudaMemsetAsync(model->params_memory[PARAMETER_GRAD], 0, tensors_bytes[PARAMETER_GRAD], main_stream)); } // accumulate the losses inside acts.losses, and kick off the backward pass inside the fused classifier @@ -871,8 +896,8 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int fused_classifier(AGRAD(output), ACT(output), ACT(losses), dloss, model->targets, B*T, V, Vp, True); // todo - split output & doutput // re-use the output buffer of the forward pass as a scratchpad during backward pass + dedicated buffer - float* scratchF = TENSOR(model->multiuse.local_scratch, 0); - floatX* scratchX_HUGE = TENSOR(model->multiuse.output_scratch, 0); + float* scratchF = MULTI_L(local_scratch, 0); + floatX* scratchX_HUGE = MULTI_L(output_scratch, 0); // backward pass: go in the reverse order of the forward pass, and call backward() functions @@ -883,7 +908,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int // next: backward the classifier matmul matmul_backward(AGRAD(lnf), PGRAD(wte), NULL, AGRAD(output), ACT(lnf), PARAM(wte), NULL, B*T, C, Vp); // backward the final layernorm - layernorm_backward(AGRAD(residual3 + L-1), NULL, PGRAD(lnfw), PGRAD(lnfb), scratchF, AGRAD(lnf), ACT(residual3 + L-1), + layernorm_backward(AGRAD_L(residual3, L-1), NULL, PGRAD(lnfw), PGRAD(lnfb), scratchF, AGRAD(lnf), ACT_L(residual3, L-1), PARAM(lnfw), ACT(lnf_mean), ACT(lnf_rstd), B*T, C); // now backward all the layers @@ -939,6 +964,33 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int }; multi_gpu_async_reduce_gradient(pointers, nelem, &multi_gpu_config, main_stream); } + + // Is it time to redo the forward pass from our activation checkpoints? + if (LAYERS_PER_ACTIVATION_CHECKPOINT && (l % max(1, LAYERS_PER_ACTIVATION_CHECKPOINT)) == 0 && l > 0) { + int old_l = l; + // forward pass time! + l -= LAYERS_PER_ACTIVATION_CHECKPOINT; + for (int i = 0; i < LAYERS_PER_ACTIVATION_CHECKPOINT; i++, l++) { + // non-fused layernorm as we already (only!) have the residual + // (for the original forward pass, residual of l-1 is fused with layernorm of l) + floatX* residual = (l == 0) ? ACT(encoded) : ACT_L(residual3, l-1); + layernorm_forward(ACT(ln1), ACT(ln1_mean), ACT(ln1_rstd), residual, PARAM(ln1w), PARAM(ln1b), B*T, C); + + matmul_forward_cublaslt(CUDNN_ENABLED ? 
ACT(qkvr) : MULTI(output_scratch), ACT(ln1), PARAM(qkvw), PARAM(qkvb), B*T, C, 3*C); + #ifdef ENABLE_CUDNN + attention_forward_cudnn(ACT(atty), ACT(att), ACT(qkvr), B, T, NH, C); + #else + attention_forward(ACT(atty), ACT(qkvr), ACT(att), MULTI(output_scratch), B, T, C, NH); + #endif + + matmul_forward_cublaslt(ACT(attproj), ACT(atty), PARAM(attprojw), PARAM(attprojb), B*T, C, C); + fused_residual_forward5(ACT(residual2), ACT(ln2), ACT(ln2_mean), ACT(ln2_rstd), residual, ACT(attproj), PARAM(ln2w), PARAM(ln2b), B*T, C); + matmul_forward_cublaslt(ACT(fch_gelu), ACT(ln2), PARAM(fcw), PARAM(fcb), B*T, C, 4*C, ACT(fch), model->gelu_fusion); + matmul_forward_cublaslt(ACT(fcproj), ACT(fch_gelu), PARAM(fcprojw), PARAM(fcprojb), B*T, 4*C, C); + } + l = old_l; + } + } encoder_backward(PGRAD(wte), PGRAD(wpe), scratchX_HUGE, model->workload_indices, model->bucket_info, @@ -988,7 +1040,7 @@ ShardInfo gpt2_get_tensor_at_layer(const GPT2 *model, int layer_id, int param_te float gpt2_calculate_grad_norm(GPT2 *model, MultiGpuConfig* multi_gpu_config) { NVTX_RANGE_FN(); - floatX* grads_memory = (floatX*)model->params_memory[PARAMETER_GRADIENT]; + floatX* grads_memory = (floatX*)model->params_memory[PARAMETER_GRAD]; // repurposing this buffer (which isn't needed now) to write grad norm into it float* grad_norm_squared = MULTI_L(output_scratch_fp32, 0); @@ -1056,11 +1108,11 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo model->rng_state_last_update = model->rng_state; // todo: merge everything into 1 kernel call - for (int i = 0; i < tensors_start[PARAMETER_GRADIENT];) { + for (int i = 0; i < tensors_start[PARAMETER_GRAD];) { unsigned int seed = random_u32(&model->rng_state); TensorSpec param_spec = tensor_specs[i]; - TensorSpec grad_spec = tensor_specs[i + tensors_start[PARAMETER_GRADIENT]]; + TensorSpec grad_spec = tensor_specs[i + tensors_start[PARAMETER_GRAD]]; TensorSpec master_spec = tensor_specs[i + tensors_start[PARAMETER_MASTER]]; TensorSpec opt_m_spec = tensor_specs[i + tensors_start[PARAMETER_OPT_M]]; TensorSpec opt_v_spec = tensor_specs[i + tensors_start[PARAMETER_OPT_V]]; @@ -1082,8 +1134,8 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo cudaCheck(cudaGetLastError()); } - // hack - todo - 2D tensors only check... - float wd = (param_spec.num_elements > (4 * model->config.channels)) ? weight_decay : 0.0f; + // todo - make it configurable whether weight decay applies to e.g. bias or not + float wd = (param_spec.flags & TENSOR_2D) ? 
weight_decay : 0.0f; if (init_from_master_only) { // when resuming training from a checkpoint with master weights (allows changing precision) From 07ed7eaf1b45ff0b37dd545778cb7f3484f7262d Mon Sep 17 00:00:00 2001 From: ademeure Date: Wed, 4 Sep 2024 15:37:26 +0000 Subject: [PATCH 08/27] First draft of TensorGPU approach --- llmc/cuda_utils.cuh | 305 +++++++++++++++++++++++++++++++++++++++++++- llmc/layernorm.cuh | 7 +- train_gpt2.cu | 82 ++---------- 3 files changed, 317 insertions(+), 77 deletions(-) diff --git a/llmc/cuda_utils.cuh b/llmc/cuda_utils.cuh index 0ce728ee1..b0266d51d 100644 --- a/llmc/cuda_utils.cuh +++ b/llmc/cuda_utils.cuh @@ -51,30 +51,58 @@ struct alignas(16) Packed128 { // load a Packed128 from an aligned memory address template -__device__ Packed128 load128(const ElementType* address) { +__device__ Packed128 load128(const ElementType* __restrict__ address) { return Packed128{*reinterpret_cast(address)}; } // load a Packed128 from an aligned memory address with streaming cache hint template -__device__ Packed128 load128cs(const ElementType* address) { +__device__ Packed128 load128cs(const ElementType* __restrict__ address) { return Packed128{__ldcs(reinterpret_cast(address))}; } // store a Packed128 to an aligned memory address template -__device__ void store128(ElementType* target, Packed128 value) { +__device__ void store128(ElementType* __restrict__ target, Packed128 value) { *reinterpret_cast(target) = value.get_bits(); } // store a Packed128 to an aligned memory address with streaming cache hint template -__device__ void store128cs(ElementType* target, Packed128 value) { +__device__ void store128cs(ElementType* __restrict__ target, Packed128 value) { __stcs(reinterpret_cast(target), value.get_bits()); } // store a Packed128 to an aligned memory address while caching in L2 but bypassing L1 template -__device__ void store128cg(ElementType* target, Packed128 value) { +__device__ void store128cg(ElementType* __restrict__ target, Packed128 value) { __stcg(reinterpret_cast(target), value.get_bits()); } +// This helper is for when we want to copy from e.g. FP32 to BF16 +// so if want to load a f128 of 4 elements, and write those 4 elements to memory as 64-bit +// not needed in the case of loads, the compiler will automatically optimise away unused reads +template +__device__ void store128_same_length(ElementType* target, Packed128 value) { + int4 bits = value.get_bits(); + switch (sizeof(OriginalType) / sizeof(ElementType)) { + case 0: *reinterpret_cast(target) = bits; break; // smaller + case 1: *reinterpret_cast(target) = bits; break; // same size + case 2: *reinterpret_cast(target) = make_int2(bits.x, bits.y); break; + case 4: *reinterpret_cast(target) = bits.x; break; + default: break; //assert(false); + } +} + +// todo - can we unify this with non-cs function somehow? 
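// A hedged usage sketch for store128_same_length (and the _cs variant below): narrowing a
// 4-wide f128 of FP32 to BF16 and writing it with a single 8-byte store. It assumes the
// original element type is the explicit first template parameter, with the target element
// type deduced from the arguments:
//
//   f128 in4 = load128(src + idx);                  // 4 x FP32, one 16-byte load
//   Packed128<__nv_bfloat16> out4;
//   for (int k = 0; k < 4; k++) {
//       out4[k] = (__nv_bfloat16)in4[k];            // narrow each element to BF16
//   }
//   store128_same_length<float>(dst + idx, out4);   // sizeof(float)/sizeof(BF16) = 2 -> one int2 store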
+template +__device__ void store128_same_length_cs(ElementType* target, Packed128 value) { + int4 bits = value.get_bits(); + switch (sizeof(OriginalType) / sizeof(ElementType)) { + case 0: __stcs(reinterpret_cast(target), bits); break; // smaller + case 1: __stcs(reinterpret_cast(target), bits); break; // same size + case 2: __stcs(reinterpret_cast(target), make_int2(bits.x, bits.y)); break; + case 4: __stcs(reinterpret_cast(target), bits.x); break; + default: break; //assert(false); + } +} + // short-form typedefs typedef Packed128 f128; typedef Packed128 x128; @@ -107,11 +135,278 @@ DType dtype_of(float* f) { return DType::FP32; } DType dtype_of(nv_bfloat16 * f) { return DType::BF16; } DType dtype_of(half * f) { return DType::FP16; } +// ---------------------------------------------------------------------------- +// ... +template +struct TensorGPU { + ElementType* data_ptr; + float* scale_descale_ptr; + unsigned int* absmax_ptr; + size_t num_elements; + + template + T* as() { + return reinterpret_cast(data_ptr); + } + operator ElementType*() const { + return data_ptr; + } +}; + +// short-form typedefs +typedef TensorGPU tensorX; +typedef TensorGPU tensorFP32; +typedef TensorGPU tensorFP16; +typedef TensorGPU tensorBF16; + +template +struct tensor128 { +private: + Packed128 data128; + ElementType* data_ptr; + unsigned int *absmax_ptr; + float scale; + float descale; + float new_absmax = 0.0f; + bool wrote_data = false; + bool wrote_absmax = false; + +public: + bool scaling = (sizeof(ElementType) <= 4); // todo - fp8 only + static constexpr const size_t elements = sizeof(int4) / sizeof(ElementType); + + __device__ tensor128(TensorGPU tensor, bool disable_scaling=false) { + float2* __restrict__ ptr_restricted = (float2*)tensor.scale_descale_ptr; + float2 scale_descale = *ptr_restricted; + scale = scale_descale.x; + descale = scale_descale.y; + data_ptr = tensor.data_ptr; + if (disable_scaling) { + scaling = false; + } + } + + __device__ void load(size_t offset, bool cache_streaming=false) { + ElementType* addr = data_ptr + offset; + data128 = cache_streaming ? load128cs(addr) : load128(addr); + } + + __device__ void store(size_t offset, bool cache_streaming=false) { + if (cache_streaming) { + store128cs(data_ptr + offset, data128); + } else { + store128(data_ptr + offset, data128); + } + wrote_data = true; + } + + template + __device__ void store_same_length(size_t offset, bool cache_streaming=false) { + if (cache_streaming) { + store128_same_length_cs(data_ptr + offset, data128); + } else { + store128_same_length(data_ptr + offset, data128); + } + wrote_data = true; + } + + __device__ float get(int index) { + return (float)data128[index] * (scaling ? descale : 1.0f); + } + + __device__ void set(int index, float value) { + new_absmax = max(new_absmax, fabsf(value)); + data128[index] = (ElementType)(value * (scaling ? 
scale : 1.0f)); + } + + __device__ void update_absmax(int thread_id, int num_threads, bool exit=false, bool forced=false) { + if (!forced && !scaling) { + return; + } + wrote_absmax = true; + + // use native integer reductions as much as possible (supported on all GPUs with FP8) + // this might treat NaN/INF slightly differently but that is the least of our problems + unsigned int absmax_uint = *(unsigned int*)&new_absmax; + asm volatile("redux.sync.max.u32 %0, %0, 0xff;" : "+r"(absmax_uint)); + __shared__ unsigned int shared[32]; + + // lane_id must be obtained directly from the special register + // otherwise, the compiler does silly things related to the redux/atomicMax + unsigned int lane_id ; + asm volatile("mov.u32 %0, %laneid;" : "=r"(lane_id)); + unsigned int num_warps = num_threads >> 5; + unsigned int warp_id = thread_id & 31; + + // with this condition instead of lane_id == 0, we have shared[lane_id] both here and below + // this reduces the number of instructions for addressing + if (lane_id == warp_id) { + shared[lane_id] = absmax_uint; + } + + // sync can be after exit (dead threads don't count) but must be before return + // if this is the end of the kernel, the compiler puts a conditional EXIT right after BAR + // but this way the EXIT is right before the barrier which frees the warps slightly quicker + bool done = (warp_id != 0 || lane_id >= num_warps); + if (done && exit) asm volatile("exit;"); + __syncthreads(); + if (done && !exit) return; + + // one more warp reduction then global memory atomic + // we want as few global atomics as possible (i.e. 1 per threadblock) + absmax_uint = shared[lane_id]; + asm volatile("redux.sync.max.u32 %0, %0, 0xff;" : "+r"(absmax_uint)); + if (lane_id == 0) { + atomicMax(absmax_ptr, absmax_uint); + } + } + __device__ void update_absmax_1D(bool exit=false) { + update_absmax(threadIdx.x & 31, blockDim.x >> 5, exit); + } + __device__ void skip_absmax() { + wrote_absmax = true; + } + + template + __device__ void force_precision(bool stochastic=false, int microtensor_scale=false, + int zeroed_mantissa_bits=0, bool two_four_sparsity=false) { + for (int k = 0; k < elements; k++) { + // todo: fancy stuff + if (scaling || scale == 0.0f) { // already scaled + data128[k] = (ElementType)((ForcedType)(data128[k])); + } else { // need to scale & descale + float scaled_value = (float)data128[k] * scaling; + ForcedType converted_value = (ForcedType)scaled_value; + float descaled_value = (float)converted_value * descale; + data128[k] = (ElementType)descaled_value; + } + } + } + + __device__ ~tensor128() { + // this should ~always be optimised away by the compiler + assert(wrote_absmax || !scaling || !wrote_data); + } +}; + +template +__device__ tensor128 new_tensor128(TensorGPU tensor, bool disable_scaling=false) { + return tensor128(tensor, disable_scaling); +} + +template +__device__ tensor128 load_tensor128(TensorGPU tensor, size_t offset, + bool disable_scaling=false, bool cache_streaming = false) { + tensor128 t128(tensor, disable_scaling); + t128.load(offset, cache_streaming); + return t128; +} + +// ---------------------------------------------------------------------------- +// ... 
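// A sketch of the absmax bookkeeping declared below, assuming the layout implied by the
// TensorSpec -> TensorGPU conversion further down: gpu_tensor_absmax_memory is treated as
// MAX_ABSMAX_HISTORY rows of num_tensor_specs entries, so a tensor's slot for the current
// step is
//   absmax_slot = tensor_id + current_absmax_index * num_tensor_specs;
// and advancing to the next history entry would presumably be
//   current_absmax_index = (current_absmax_index + 1) % MAX_ABSMAX_HISTORY;
// (the index update itself is an assumption here, not something shown in this code).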
+ +constexpr size_t MAX_TENSORS = 16*1024; +constexpr size_t MAX_ABSMAX_HISTORY = 32; // todo - should make this a command line option +extern int num_tensor_specs; +extern int current_absmax_index; +extern void* gpu_tensor_scale_memory; +extern void* gpu_tensor_absmax_memory; + +enum TT : uint8_t { + PARAMETER=0, PARAMETER_GRAD, PARAMETER_MASTER, PARAMETER_OPT_M, PARAMETER_OPT_V, // 1 allocation each + ACTIVATIONS_MULTIUSE, // single buffer shared for activations, activation gradients, and scratch + DEFAULT, COUNT=DEFAULT, NUM_TYPES_PARAM=PARAMETER_OPT_V+1 +}; + +enum TFlags : uint8_t { + NONE=0, + REUSED_MEMORY=1, + GRADIENT=2, + TENSOR_2D=4, + BIAS=8, + LAYERNORM=16, + RESIDUAL=32, + EMBEDDING=64, + STATS=128 +}; + +typedef struct { + char* ptr; + size_t offset; // into base pointer + size_t num_elements; // per shard + int id; + short num_shards; + short remaining_layers; + DType data_type; + TT tensor_type; + int flags; + char name[16]; + + template + operator T*() const { + if (std::is_same::value && data_type != DType::FP32 || + std::is_same::value && data_type != DType::FP16 || + std::is_same::value && data_type != DType::BF16) { + printf("ERROR: Unexpected data type (%d) for tensor %s\n", (int)data_type, name); + exit(EXIT_FAILURE); + } + return reinterpret_cast(ptr); + } + + template + operator TensorGPU() const { + TensorGPU tensor; + int absmax_idx = id + (current_absmax_index * num_tensor_specs); + + tensor.num_elements = num_elements; + tensor.data_ptr = this->operator T*(); + tensor.scale_descale_ptr = reinterpret_cast(gpu_tensor_scale_memory) + id; + tensor.absmax_ptr = reinterpret_cast(gpu_tensor_absmax_memory) + absmax_idx; + + return tensor; + } +} TensorSpec; // ---------------------------------------------------------------------------- // Copy, cast functions +using elementwise_func_t = float (*) (float); +__device__ float nothing_elementwise(float x) { + return x; +} +template +__global__ void copy_advanced_kernel(TensorGPU in, TensorGPU out) { + constexpr size_t vec_size = 16 / ((sizeof(T1) < sizeof(T2)) ? sizeof(T2) : sizeof(T1)); + size_t adjusted_blockidx = reversed_order ? (gridDim.x - blockIdx.x - 1) : blockIdx.x; + size_t idx = (adjusted_blockidx * blockDim.x + threadIdx.x) * vec_size; + if (idx >= in.num_elements) { return; } + + auto inp128 = load_tensor128(in, idx, disable_scaling, true); + auto out128 = new_tensor128(out, disable_scaling); + for (int k = 0; k < vec_size; k++) { + float out_fp32 = elementwise_func(inp128.get(k)); + out128.set(k, out_fp32); + } + out128.store_same_length(idx); + out128.update_absmax(threadIdx.x, block_size, true); +} + +// todo - move to GELU etc. 
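// For reference, the elementwise function below is the usual tanh approximation of GELU used
// by GPT-2:
//   GELU(x) ~= 0.5 * x * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3) ))
// evaluated with the tanh.approx.f32 PTX instruction, with the final 0.5*x*(1+tanh) term
// refactored into half_x * tanh_out + half_x so it maps to one FMUL plus one FMA.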
+__device__ float gelu_forward_elementwise(float x) { + float cube = 0.044715f * x * x * x; + + float tanh_out; + float tanh_arg = sqrtf(2.0f / M_PI) * (x + cube); + asm ("tanh.approx.f32 %0,%1;" : "=f"(tanh_out) : "f"(tanh_arg)); + + // the following uses FMUL+FMA instead of FMUL+FADD+FMUL for "0.5f * x * (1.0f + tanh_out)" + float half_x = 0.5f * x; + return half_x * tanh_out + half_x; +} + // device functions and the kernel to cast data between types template __device__ Td cast_value(Ts val); diff --git a/llmc/layernorm.cuh b/llmc/layernorm.cuh index cd66dbf60..b0a8821bd 100644 --- a/llmc/layernorm.cuh +++ b/llmc/layernorm.cuh @@ -434,7 +434,7 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with // kernel launchers // similar to `fused_residual_forward5` -void layernorm_forward(floatX* out, float* mean, float* rstd, +void layernorm_forward(TensorGPU out, float* mean, float* rstd, floatX* inp, const floatX* weight, const floatX* bias, int N, int C, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); @@ -467,9 +467,8 @@ void residual_forward(floatX* out, const floatX* inp1, const floatX* inp2, int N cudaCheck(cudaGetLastError()); } -void fused_residual_forward5(floatX* residual, floatX* normed, float* mean, float* rstd, - const floatX* inp1, const floatX* inp2, - const floatX* weight, const floatX* bias, +void fused_residual_forward5(tensorX residual, tensorX normed, tensorFP32 mean, tensorFP32 rstd, + tensorX inp1, tensorX inp2, tensorX weight, tensorX bias, int N, int C, cudaStream_t stream=main_stream) { const int block_size = 256; int block_y = block_size / WARP_SIZE; diff --git a/train_gpt2.cu b/train_gpt2.cu index 66dba6ce7..ac3c72b10 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -89,24 +89,6 @@ constexpr const size_t IO_BUF_SIZE = 32 * 1024 * 1024; // ---------------------------------------------------------------------------- // GPT-2 model definition -enum TT : uint8_t { - PARAMETER=0, PARAMETER_GRAD, PARAMETER_MASTER, PARAMETER_OPT_M, PARAMETER_OPT_V, // 1 allocation each - ACTIVATIONS_MULTIUSE, // single buffer shared for activations, activation gradients, and scratch - DEFAULT, COUNT=DEFAULT, NUM_TYPES_PARAM=PARAMETER_OPT_V+1 -}; - -enum TFlags : uint8_t { - NONE=0, - REUSED_MEMORY=1, - GRADIENT=2, - TENSOR_2D=4, - BIAS=8, - LAYERNORM=16, - RESIDUAL=32, - EMBEDDING=64, - STATS=128 -}; - typedef struct { int wte, wpe, lnfw, lnfb; // not per layer int ln1w, ln1b, qkvw, qkvb, attprojw, attprojb, ln2w, ln2b, fcw, fcb, fcprojw, fcprojb; // per layer @@ -167,34 +149,12 @@ typedef struct { unsigned long long rng_state_last_update; // RNG before last gpt2_update() to re-round identically from master weights } GPT2; -// todo: need flags, subtypes (e.g. act gradient), etc... 
-typedef struct { - char* ptr; - size_t offset; // into base pointer - size_t num_elements; // per shard - int id; - short num_shards; - short remaining_layers; - DType data_type; - TT tensor_type; - int flags; - char name[16]; - - template - operator T*() const { - if (std::is_same::value && data_type != DType::FP32 || - std::is_same::value && data_type != DType::FP16 || - std::is_same::value && data_type != DType::BF16) { - printf("ERROR: Unexpected data type (%d) for tensor %s\n", (int)data_type, name); - exit(EXIT_FAILURE); - } - return reinterpret_cast(ptr); - } -} TensorSpec; +int num_tensor_specs = 0; +int current_absmax_index = 0; +void* gpu_tensor_scale_memory = NULL; +void* gpu_tensor_absmax_memory = NULL; -constexpr size_t MAX_TENSORS = 16*1024; TensorSpec tensor_specs[MAX_TENSORS] = {0}; -size_t num_tensor_specs = 0; TT current_tensor_type = TT::PARAMETER; size_t tensors_start[TT::COUNT] = {0}; size_t tensors_bytes[TT::COUNT] = {0}; @@ -473,7 +433,12 @@ void gpt2_allocate(GPT2 *model) { // ======================= // allocate_state stuff // ======================= - // allocate the space + // absmax/scaling/descaling buffers for FP8 & Friends + cudaMalloc(&gpu_tensor_scale_memory, sizeof(float) * num_tensor_specs); + cudaMemset(gpu_tensor_scale_memory, 0, sizeof(float) * num_tensor_specs); + cudaMalloc(&gpu_tensor_absmax_memory, sizeof(float) * num_tensor_specs * MAX_ABSMAX_HISTORY); + cudaMemset(gpu_tensor_absmax_memory, 0, sizeof(float) * num_tensor_specs * MAX_ABSMAX_HISTORY); + cudaCheck(cudaMalloc((void**)&model->inputs, B * T * sizeof(int))); cudaCheck(cudaMalloc((void**)&model->targets, B * T * sizeof(int))); cudaCheck(cudaMalloc(((void**)&model->accumulated_mean_loss), sizeof(float))); @@ -809,7 +774,7 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { for (; l < L; l++) { NvtxRange layer_range("Layer", l); - floatX* residual = (l == 0) ? ACT(encoded) : ACT_L(residual3, l-1); + tensorX residual = (l == 0) ? ACT(encoded) : ACT_L(residual3, l-1); matmul_forward_cublaslt(CUDNN_ENABLED ? ACT(qkvr) : MULTI(output_scratch), ACT(ln1), PARAM(qkvw), PARAM(qkvb), B*T, C, 3*C); #ifdef ENABLE_CUDNN @@ -914,8 +879,8 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int // now backward all the layers for (; l >= 0; l--) { NvtxRange layer_range("Layer", l); - floatX* residual = (l == 0) ? ACT(encoded) : ACT_L(residual3, l-1); - floatX* dresidual = (l == 0) ? AGRAD(encoded) : AGRAD_L(residual3, l-1); + tensorX residual = (l == 0) ? ACT(encoded) : ACT_L(residual3, l-1); + tensorX dresidual = (l == 0) ? AGRAD(encoded) : AGRAD_L(residual3, l-1); if(model->recompute >= 1) { // recompute >= 1 means we recompute gelu gelu_forward(ACT(fch_gelu), ACT(fch), B*T*4*C); @@ -973,7 +938,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int for (int i = 0; i < LAYERS_PER_ACTIVATION_CHECKPOINT; i++, l++) { // non-fused layernorm as we already (only!) have the residual // (for the original forward pass, residual of l-1 is fused with layernorm of l) - floatX* residual = (l == 0) ? ACT(encoded) : ACT_L(residual3, l-1); + tensorX residual = (l == 0) ? ACT(encoded) : ACT_L(residual3, l-1); layernorm_forward(ACT(ln1), ACT(ln1_mean), ACT(ln1_rstd), residual, PARAM(ln1w), PARAM(ln1b), B*T, C); matmul_forward_cublaslt(CUDNN_ENABLED ? 
ACT(qkvr) : MULTI(output_scratch), ACT(ln1), PARAM(qkvw), PARAM(qkvb), B*T, C, 3*C); @@ -1019,25 +984,6 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int } } -// Gets the offset of a specific tensor for a specific layer in the GPT2 model -// layer_id is ignored for weights that are not part of a transformer block -/* -ShardInfo gpt2_get_tensor_at_layer(const GPT2 *model, int layer_id, int param_tensor_id) { - // first offset our way to the parameter tensor start - ptrdiff_t offset = 0; - for (int i = 0; i < param_tensor_id; i++) { - offset += (ptrdiff_t)model->param_elements[i]; - } - size_t size = model->param_elements[param_tensor_id] ; - // if we are in the transformer block, we need to additionally offset by the layer id - if(2 <= param_tensor_id && param_tensor_id <= 13) { - size /= model->config.num_layers; - offset += (ptrdiff_t)(layer_id * size); - } - return {offset, size}; -} -*/ - float gpt2_calculate_grad_norm(GPT2 *model, MultiGpuConfig* multi_gpu_config) { NVTX_RANGE_FN(); floatX* grads_memory = (floatX*)model->params_memory[PARAMETER_GRAD]; From a70f3222a5e3c789a3e07a7922b0808bdd79f4c2 Mon Sep 17 00:00:00 2001 From: ademeure Date: Fri, 6 Sep 2024 02:12:30 +0000 Subject: [PATCH 09/27] WIP most things converted to TensorGPU, bit more encoder and a lot more matmul work to do... --- llmc/cuda_utils.cuh | 43 ++++-- llmc/encoder.cuh | 36 +++-- llmc/gelu.cuh | 93 +++++++----- llmc/layernorm.cuh | 335 ++++++++++++++++---------------------------- llmc/matmul.cuh | 33 ++--- train_gpt2.cu | 27 ++-- 6 files changed, 261 insertions(+), 306 deletions(-) diff --git a/llmc/cuda_utils.cuh b/llmc/cuda_utils.cuh index b0266d51d..73fa15b67 100644 --- a/llmc/cuda_utils.cuh +++ b/llmc/cuda_utils.cuh @@ -145,13 +145,25 @@ struct TensorGPU { size_t num_elements; template - T* as() { + __device__ __host__ T* as() { return reinterpret_cast(data_ptr); } - operator ElementType*() const { + __device__ __host__ operator ElementType*() const { return data_ptr; } + + __device__ __host__ ElementType& operator[](size_t index) { + return data_ptr[index]; + } + + __device__ __host__ const ElementType& operator[](size_t index) const { + return data_ptr[index]; + } + + __device__ __host__ int num_per_128() const { + return sizeof(int4) / sizeof(ElementType); + } }; // short-form typedefs @@ -160,6 +172,12 @@ typedef TensorGPU tensorFP32; typedef TensorGPU tensorFP16; typedef TensorGPU tensorBF16; +typedef TensorGPU tensorFP8e4; +typedef TensorGPU tensorFP8e5; + +extern TensorGPU null_tensorX; +extern TensorGPU null_tensorFP32; + template struct tensor128 { private: @@ -173,7 +191,7 @@ private: bool wrote_absmax = false; public: - bool scaling = (sizeof(ElementType) <= 4); // todo - fp8 only + bool scaling = (sizeof(ElementType) <= 1); // todo - fp8 only static constexpr const size_t elements = sizeof(int4) / sizeof(ElementType); __device__ tensor128(TensorGPU tensor, bool disable_scaling=false) { @@ -182,9 +200,11 @@ public: scale = scale_descale.x; descale = scale_descale.y; data_ptr = tensor.data_ptr; + absmax_ptr = tensor.absmax_ptr; if (disable_scaling) { scaling = false; } + scaling = false; } __device__ void load(size_t offset, bool cache_streaming=false) { @@ -211,6 +231,10 @@ public: wrote_data = true; } + __device__ Packed128 get128() { + return data128; + } + __device__ float get(int index) { return (float)data128[index] * (scaling ? descale : 1.0f); } @@ -220,9 +244,9 @@ public: data128[index] = (ElementType)(value * (scaling ? 
scale : 1.0f)); } - __device__ void update_absmax(int thread_id, int num_threads, bool exit=false, bool forced=false) { + __device__ bool update_absmax(int thread_id, int num_threads, bool exit=false, bool forced=false) { if (!forced && !scaling) { - return; + return false; // if we return true, we can skip __syncthreads() in some kernels } wrote_absmax = true; @@ -251,7 +275,7 @@ public: bool done = (warp_id != 0 || lane_id >= num_warps); if (done && exit) asm volatile("exit;"); __syncthreads(); - if (done && !exit) return; + if (done && !exit) return true; // one more warp reduction then global memory atomic // we want as few global atomics as possible (i.e. 1 per threadblock) @@ -260,6 +284,7 @@ public: if (lane_id == 0) { atomicMax(absmax_ptr, absmax_uint); } + return true; } __device__ void update_absmax_1D(bool exit=false) { update_absmax(threadIdx.x & 31, blockDim.x >> 5, exit); @@ -297,7 +322,7 @@ __device__ tensor128 new_tensor128(TensorGPU tensor, bool disable_scaling= template __device__ tensor128 load_tensor128(TensorGPU tensor, size_t offset, - bool disable_scaling=false, bool cache_streaming = false) { + bool cache_streaming = false, bool disable_scaling=false) { tensor128 t128(tensor, disable_scaling); t128.load(offset, cache_streaming); return t128; @@ -384,8 +409,8 @@ __global__ void copy_advanced_kernel(TensorGPU in, TensorGPU out) { size_t idx = (adjusted_blockidx * blockDim.x + threadIdx.x) * vec_size; if (idx >= in.num_elements) { return; } - auto inp128 = load_tensor128(in, idx, disable_scaling, true); - auto out128 = new_tensor128(out, disable_scaling); + auto inp128 = load_tensor128(in, idx, true, disable_scaling); + auto out128 = new_tensor128(out); for (int k = 0; k < vec_size; k++) { float out_fp32 = elementwise_func(inp128.get(k)); out128.set(k, out_fp32); diff --git a/llmc/encoder.cuh b/llmc/encoder.cuh index 5af09476c..a95688b53 100644 --- a/llmc/encoder.cuh +++ b/llmc/encoder.cuh @@ -16,8 +16,8 @@ In the backward pass, the gradients flow to both, handled by different kernels // ---------------------------------------------------------------------------- // CUDA kernels -__global__ void encoder_forward_kernel3(floatX* out, - const int* inp, const floatX* wte, const floatX* wpe, +__global__ void encoder_forward_kernel3(tensorX out, + const int* inp, const tensorX wte, const tensorX wpe, int B, int T, int C) { int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size; int N = B * T * C; @@ -27,25 +27,23 @@ __global__ void encoder_forward_kernel3(floatX* out, int b = bt / T; int t = bt % T; int c = idx % C; - int ix = inp[b * T + t]; - floatX* out_btc = out + b * T * C + t * C + c; - const floatX* wte_ix = wte + ix * C + c; - const floatX* wpe_tc = wpe + t * C + c; + auto out128 = new_tensor128(out); + auto wte128 = load_tensor128(wte, ix * C + c); + auto wpe128 = load_tensor128(wpe, t * C + c); x128 packed_out; - x128 wte128 = load128cs(wte_ix); - x128 wpe128 = load128cs(wpe_tc); for (int k = 0; k < x128::size; k++) { - packed_out[k] = (floatX)((float)wte128[k] + (float)wpe128[k]); + out128.set(k, wte128.get(k) + wpe128.get(k)); } - store128(out_btc, packed_out); + out128.store(b * T * C + t * C + c); + out128.update_absmax(threadIdx.x, blockDim.x, true); } template -__global__ void wte_backward_kernel(floatX* dwte, - const int4* bucket_info, const int* workload_indices, const floatX* dout, const int* inp, +__global__ void wte_backward_kernel(tensorX dwte, + const int4* bucket_info, const int* workload_indices, const tensorX dout, const int* inp, unsigned 
int seed, int B, int T, int C) { // In order to be deterministic, we preprocess the inputs on the cpu into "buckets" // Each bucket corresponds to (WARP_SIZE * x128::size) channels for a single vocabulary token @@ -116,8 +114,8 @@ __global__ void wte_backward_kernel(floatX* dwte, store128(dwte_ix, packed_in_out); } -__global__ void wpe_backward_kernel(floatX* dwpe, - const floatX* dout, const int* inp, +__global__ void wpe_backward_kernel(tensorX dwpe, + const tensorX dout, const int* inp, int B, int T, int C, unsigned int seed) { // Each thread handles x128::size "channel positions", e.g. 256 per warp for BF16 // For gpt2-124M BF16, C=768 and T=1024, so 3 warps per channel and 3072 warps in total @@ -154,8 +152,8 @@ __global__ void wpe_backward_kernel(floatX* dwpe, // ---------------------------------------------------------------------------- // kernel launchers -void encoder_forward(floatX* out, - const int* inp, const floatX* wte, const floatX* wpe, +void encoder_forward(tensorX out, + const int* inp, const tensorX wte, const tensorX wpe, int B, int T, int C, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); const int block_size = 256; @@ -166,9 +164,9 @@ void encoder_forward(floatX* out, } // Fully deterministic (see comments in wte_backward_kernel and wpe_backward_kernel for more details) -void encoder_backward(floatX* dwte, floatX* dwpe, floatX* scratch, // gpu outputs & scratch +void encoder_backward(tensorX dwte, tensorX dwpe, tensorX scratch, // gpu outputs & scratch int* workload_indices, int4* bucket_info, // cpu scratch buffers - const floatX* dout, const int* inp, const int* inputs_cpu, // cpu/gpu inputs + const tensorX dout, const int* inp, const int* inputs_cpu, // cpu/gpu inputs int B, int T, int C, unsigned int seed, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); @@ -222,7 +220,7 @@ void encoder_backward(floatX* dwte, floatX* dwpe, floatX* scratch, // gpu output // Step 3: Copy data from host to device (async until the last one to avoid synchronising CPU/GPU twice) // todo - could use CUDA events (even without streams) to avoid CPU/GPU synchronisation completely - int4* d_bucket_info = (int4*)scratch; + int4* d_bucket_info = (int4*)scratch.data_ptr; int* d_workload_indices = (int*)(scratch + B*T*num_c_groups * sizeof(int4)); cudaCheck(cudaMemcpyAsync(d_bucket_info, bucket_info, num_buckets * sizeof(int4), cudaMemcpyHostToDevice, stream)); cudaCheck(cudaMemcpyAsync(d_workload_indices, workload_indices, total_items * sizeof(int), cudaMemcpyHostToDevice, stream)); diff --git a/llmc/gelu.cuh b/llmc/gelu.cuh index 138daa40a..a940a3ea7 100644 --- a/llmc/gelu.cuh +++ b/llmc/gelu.cuh @@ -10,57 +10,84 @@ // CUDA kernels #define GELU_SCALING_FACTOR sqrtf(2.0f / M_PI) -__global__ void gelu_forward_kernel2(floatX* out, const floatX* inp) { - int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size; +__global__ void gelu_forward_kernel2(tensorFP8e4 out, tensorFP8e4 inp) { + int idx = (blockIdx.x * blockDim.x + threadIdx.x) * inp.num_per_128(); - x128 packed_out; - x128 packed_inp = load128cs(inp + idx); // load and do not keep in cache - for(int k = 0; k < packed_inp.size; ++k) { - float xi = (float)packed_inp[k]; + auto out128 = new_tensor128(out); + auto inp128 = load_tensor128(inp, idx, true); + for(int k = 0; k < inp.num_per_128(); ++k) { + float xi = inp128.get(k); float cube = 0.044715f * xi * xi * xi; - packed_out[k] = (floatX)(0.5f * xi * (1.0f + tanhf(GELU_SCALING_FACTOR * (xi + cube)))); + + float tanh_in_out = GELU_SCALING_FACTOR * (xi + cube); + #if 
!defined(PRECISE_GELU_TANH) && !defined(ENABLE_FP32) && __CUDA_ARCH__ >= 750 + asm ("tanh.approx.f32 %0,%1;" : "=f"(tanh_in_out) : "f"(tanh_in_out)); + #else + tanh_in_out = tanhf(tanh_in_out); + #endif + + // the following uses FMUL+FMA instead of FMUL+FADD+FMUL for "0.5f * x * (1.0f + tanh_out)" + float half_xi = 0.5f * xi; + out128.set(k, half_xi * tanh_in_out + half_xi); } - // store instead of storecs (without cache streaming) in case it is useful for the - // data to be in the cache for the next operation after this GeLU - store128(out + idx, packed_out); + out128.store_same_length(idx, false); + + // Update absmax + out128.update_absmax(threadIdx.x, blockDim.x, true); } -__global__ void gelu_backward_inplace_kernel(floatX* d_in_out, const floatX* inp) { - int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size; +//template +template +__global__ void gelu_backward_kernel(tensorFP8e5 dinp, tensorFP8e5 dout, TensorGPU inp) { + int idx = (blockIdx.x * blockDim.x + threadIdx.x) * dout.num_per_128(); - x128 packed_dinp; - x128 packed_inp = load128cs(inp + idx); - x128 packed_dout = load128(d_in_out + idx); - for (int k = 0; k < packed_inp.size; ++k) { - float x = (float)packed_inp[k]; + auto packed_dinp = new_tensor128(dinp); + auto packed_inp = load_tensor128(inp, idx, true); + auto packed_dout = load_tensor128(dout, idx); + for (int k = 0; k < dout.num_per_128(); ++k) { + float x = packed_inp.get(k); float cube = 0.044715f * x * x * x; - float tanh_arg = GELU_SCALING_FACTOR * (x + cube); - float tanh_out = tanhf(tanh_arg); - float coshf_out = coshf(tanh_arg); - float sech_out = 1.0f / (coshf_out * coshf_out); - float local_grad = 0.5f * (1.0f + tanh_out) + x * 0.5f * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x); - packed_dinp[k] = (floatX)(local_grad * (float)packed_dout[k]); + + float tanh_in_out = GELU_SCALING_FACTOR * (x + cube); + #if !defined(PRECISE_GELU_TANH) && !defined(ENABLE_FP32) && __CUDA_ARCH__ >= 750 + asm ("tanh.approx.f32 %0,%1;" : "=f"(tanh_in_out) : "f"(tanh_in_out)); + #else + tanh_in_out = tanhf(tanh_in_out); + #endif + + float sech_out = 1.0f - (tanh_in_out * tanh_in_out); + float local_grad = 0.5f * ((1.0f + tanh_in_out) + x * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x)); + float result = local_grad * (float)packed_dout.get(k); + packed_dinp.set(k, result); } - store128(d_in_out + idx, packed_dinp); + packed_dinp.store_same_length(idx, false); + + // Update absmax + packed_dinp.update_absmax(threadIdx.x, blockDim.x, true); } // ---------------------------------------------------------------------------- // kernel launchers -void gelu_forward(floatX* out, const floatX* inp, int N, cudaStream_t stream=main_stream) { +void gelu_forward(tensorX out, tensorX inp, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); - const int block_size = 512; - assert(N % (block_size * x128::size) == 0); - const int grid_size = CEIL_DIV(N, block_size * x128::size); + const int block_size = 256; + assert(out.num_per_128() == inp.num_per_128()); + assert(inp.num_elements % (block_size * inp.num_per_128()) == 0); + + const int grid_size = CEIL_DIV(inp.num_elements, block_size * inp.num_per_128()); gelu_forward_kernel2<<>>(out, inp); cudaCheck(cudaGetLastError()); } -void gelu_backward_inplace(floatX* d_in_out, const floatX* inp, const int N, cudaStream_t stream=main_stream) { +void gelu_backward(tensorX dinp, tensorX dout, tensorX inp, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); - const int block_size = 128; - assert(N % 
(block_size * x128::size) == 0); - const int grid_size = CEIL_DIV(N, block_size * x128::size); - gelu_backward_inplace_kernel<<>>(d_in_out, inp); + const int block_size = 512; + assert(dout.num_per_128() == inp.num_per_128()); + assert(inp.num_elements % (block_size * inp.num_per_128()) == 0); + assert(dout.num_elements == inp.num_elements && dout.num_elements == dinp.num_elements); + + const int grid_size = CEIL_DIV(inp.num_elements, block_size * inp.num_per_128()); + gelu_backward_kernel<<>>(dinp, dout, inp); cudaCheck(cudaGetLastError()); } diff --git a/llmc/layernorm.cuh b/llmc/layernorm.cuh index b0a8821bd..771d79452 100644 --- a/llmc/layernorm.cuh +++ b/llmc/layernorm.cuh @@ -17,89 +17,24 @@ E.g., the layernorms are connected to the residuals so we += in layernorm backwa // ---------------------------------------------------------------------------- // CUDA kernels -__global__ void layernorm_forward_kernel3(floatX* __restrict__ out, float* __restrict__ mean, float* __restrict__ rstd, - const floatX* __restrict__ inp, const floatX* __restrict__ weight, - const floatX* __restrict__ bias, int N, int C) { - int lane_id = threadIdx.x % WARP_SIZE; - int warp_id = threadIdx.x / WARP_SIZE; - int num_warps = blockDim.x / WARP_SIZE; - - int idx = blockIdx.x * num_warps + warp_id; - if(idx >= N) { return; } // guard - - // the row of input that this group of threads is responsible for - const floatX* x = inp + idx * C; - - // mean - float sum = 0.0f; - for (int i = lane_id; i < C; i += WARP_SIZE) { - sum += (float)x[i]; - } - sum = warpReduceSum(sum); - float m = sum / C; - if(lane_id == 0 && mean != nullptr) { - __stcs(mean + idx, m); - } - - // rstd - sum = 0.0f; - for (int i = lane_id; i < C; i += WARP_SIZE) { - float diff = (float)x[i] - m; - sum += diff * diff; - } - sum = warpReduceSum(sum); - float s = rsqrtf(sum / C + 1e-5f); - if(lane_id == 0 && rstd != nullptr) { - __stcs(rstd + idx, s); - } - - // final normalization and scaling by weight/bias - floatX* o = out + idx * C; - for (int c = lane_id; c < C; c += WARP_SIZE) { - // load and store using the .cs "streaming" hint to the compiler, - // indicating that this data will not be reused soon, and can be streamed through the caches - // this allows the threads to get more cache-hits for the (shared) weight and bias parameters - float n = s * ((float)__ldcs(x+c) - m); - __stcs(o+c, (floatX)(n * (float)weight[c] + (float)bias[c])); - } -} - -__global__ void layernorm_forward_kernel6(floatX* __restrict__ out, float* __restrict__ mean, float* __restrict__ rstd, - const floatX* __restrict__ inp, const floatX* __restrict__ weight, - const floatX* __restrict__ bias, int N, int C) { - assert(blockDim.x == WARP_SIZE); - - // load weights and biases into shared memory - // do this before we allow any threads to exit! 
+__global__ void layernorm_forward_kernel6(tensorFP8e4 out, tensorFP32 mean, tensorFP32 rstd, + tensorFP8e4 inp, tensorFP8e4 weight, + const tensorX bias, int N, int C) { + // Note that blockDim.x must be WARP_SIZE=32 but we don't want to pay the cost of assert() here + int idx = blockIdx.x * blockDim.y + threadIdx.y; // non-standard: threadIdx.x is used for c + if(idx >= N) { return; } + + // load/store128 sometimes generated multiple instructions with floatX, so keep it as x128 extern __shared__ char* params[]; - // load128/store128 sometimes generated multiple instructions when the types here were floatX*, so - // let's keep everything as x128 - x128* s_weight = reinterpret_cast(params); - x128* s_bias = reinterpret_cast(params) + (C / x128::size); - x128* s_in = reinterpret_cast(params) + ((2 + threadIdx.y) * C / x128::size); - - int sidx = (threadIdx.x + WARP_SIZE * threadIdx.y) * x128::size; - for(int i = sidx; i < C; i += blockDim.y * WARP_SIZE * x128::size) { - s_weight[i/x128::size] = load128(weight + i); - s_bias[i/x128::size] = load128(bias + i); - } - __syncthreads(); - - int idx = blockIdx.x * blockDim.y + threadIdx.y; - if(idx >= N) { return; } // guard - - // adjust pointers to current token - inp += idx * C; - out += idx * C; + x128* s_in = reinterpret_cast(params) + (threadIdx.y * C / x128::size); - const float eps = 1e-5f; float sum = 0.0f; for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) { - const x128 in_data = load128cs(inp + c); + auto inp128 = load_tensor128(inp, idx * C + c, true); for(int k = 0; k < x128::size; ++k) { - sum += (float)in_data[k]; + sum += inp128.get(k); } - s_in[c / x128::size] = in_data; + s_in[c / x128::size] = inp128.get128(); } sum = warpReduceSum(sum); @@ -114,74 +49,57 @@ __global__ void layernorm_forward_kernel6(floatX* __restrict__ out, float* __res } v = warpReduceSum(v) / C; + const float eps = 1e-5f; // todo - is this optimal / theoretically justified? 
float s = rsqrtf(v + eps); + auto out128 = new_tensor128(out); for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) { const x128 in_data = s_in[c / x128::size]; - const x128 w = s_weight[c / x128::size]; - const x128 b = s_bias[c / x128::size]; - x128 out_data; + auto w128 = load_tensor128(weight, c); + auto b128 = load_tensor128(bias, c); for(int k = 0; k < x128::size; ++k) { float n = s * ((float)in_data[k] - m); // normalized output - float o = n * (float)w[k] + (float)b[k]; // scale and shift it - out_data[k] = (floatX)o; + float o = n * w128.get(k) + b128.get(k); // scale and shift it + out128.set(k, o); } - - store128cs(out + c, out_data); + out128.store_same_length(idx * C + c); } // cache the mean and rstd for the backward pass later - if(threadIdx.x == 0 && mean != nullptr) { + if(threadIdx.x == 0) { // todo - add a way to pass equivalent of null for mean/rstd to avoid store __stcs(mean + idx, m); - } - // store the rstd, no need to cache it - if(threadIdx.x == 0 && rstd != nullptr) { __stcs(rstd + idx, s); } + // update absmax + out128.update_absmax(threadIdx.x + threadIdx.y * blockDim.x, blockDim.x * blockDim.y, true); } -__global__ void fused_residual_forward_kernel5(floatX* residual, floatX* normed, float* mean, float* rstd, - const floatX* inp1, const floatX* inp2, - const floatX* weight, const floatX* bias, +__global__ void fused_residual_forward_kernel5(tensorX residual_, tensorFP8e4 normed_, tensorFP32 mean, tensorFP32 rstd, + const tensorX inp1_, const tensorFP8e4 inp2_, + const tensorX weight, const tensorX bias, int N, int C) { - assert(blockDim.x == WARP_SIZE); - - // load weights and biases into shared memory - // do this before we allow any threads to exit! - extern __shared__ char* params[]; - // load128/store128 sometimes generated multiple instructions when the types here were floatX*, so - // let's keep everything as x128 - x128* s_weight = reinterpret_cast(params); - x128* s_bias = reinterpret_cast(params) + (C / x128::size); - x128* s_res = reinterpret_cast(params) + ((2 + threadIdx.y) * C / x128::size); - - int sidx = (threadIdx.x + WARP_SIZE * threadIdx.y) * x128::size; - for(int i = sidx; i < C; i += blockDim.y * WARP_SIZE * x128::size) { - s_weight[i/x128::size] = load128(weight + i); - s_bias[i/x128::size] = load128(bias + i); - } - __syncthreads(); - + // Note that blockDim.x must be WARP_SIZE=32 but we don't want to pay the cost of assert() here int idx = blockIdx.x * blockDim.y + threadIdx.y; if(idx > N) return; - // adjust pointers to current token - residual += C * idx; - normed += C * idx; - inp1 += C * idx; - inp2 += C * idx; + // load/store128 sometimes generated multiple instructions with floatX, so keep it as x128 + extern __shared__ char* params[]; + x128* s_res = reinterpret_cast(params) + (threadIdx.y * C / x128::size); + + auto residual128 = new_tensor128(residual_); + auto normed128 = new_tensor128(normed_); const float eps = 1e-5f; float sum = 0.0f; for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) { - const x128 in1 = load128cs(inp1 + c); - const x128 in2 = load128cs(inp2 + c); - x128 out; + auto inp1_128 = load_tensor128(inp1_, idx * C + c, true); + auto inp2_128 = load_tensor128(inp2_, idx * C + c, true); for(int k = 0; k < x128::size; ++k) { - out[k] = (float)in1[k] + (float)in2[k]; - sum += (float)out[k]; + float out = inp1_128.get(k) + inp2_128.get(k); + residual128.set(k, out); + sum += residual128.get(k); } - store128cs(residual + c, out); - s_res[c / x128::size] = out; + 
residual128.store_same_length(idx * C + c, false); + s_res[c / x128::size] = residual128.get128(); } sum = warpReduceSum(sum); @@ -200,43 +118,32 @@ __global__ void fused_residual_forward_kernel5(floatX* residual, floatX* normed, for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) { const x128 res = s_res[c / x128::size]; - const x128 w = s_weight[c / x128::size]; - const x128 b = s_bias[c / x128::size]; - x128 out; + auto w128 = load_tensor128(weight, c); + auto b128 = load_tensor128(bias, c); for(int k = 0; k < x128::size; ++k) { float n = s * ((float)res[k] - m); // normalized output - float o = n * (float)w[k] + (float)b[k]; // scale and shift it - out[k] = o; + float o = n * w128.get(k) + b128.get(k); // scale and shift it + normed128.set(k, o); } - - store128cs(normed + c, out); + normed128.store_same_length(idx * C + c, false); } // cache the mean and rstd for the backward pass later if(threadIdx.x == 0) { - mean[idx] = m; - rstd[idx] = s; + __stcs(mean + idx, m); + __stcs(rstd + idx, s); } -} - -__global__ void residual_forward_kernel(floatX* out, const floatX* inp1, const floatX* inp2) { - int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size; - x128 packed_out; - x128 packed_inp1 = load128cs(inp1 + idx); - x128 packed_inp2 = load128cs(inp2 + idx); - for (int k = 0; k < packed_inp1.size; k++) { - packed_out[k] = (floatX)((float)packed_inp1[k] + (float)packed_inp2[k]); - } - store128(out + idx, packed_out); + // Update absmax for both residual and normed tensors + residual128.update_absmax(threadIdx.x + threadIdx.y * blockDim.x, blockDim.x * blockDim.y, false); + normed128.update_absmax(threadIdx.x + threadIdx.y * blockDim.x, blockDim.x * blockDim.y, true); } template __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with only 1024 threads? - layernorm_backward_kernel10(floatX* dinp_new, floatX* dinp_old, floatX* dweight, floatX* dbias, float* scratch, - const floatX* dout, const floatX* inp, const floatX* weight, - const float* mean, const float* rstd, + layernorm_backward_kernel10(tensorX dinp_new, tensorX dinp_old, tensorX dweight, tensorX dbias, tensorFP32 scratch_, + tensorFP8e5 dout, tensorX inp, tensorX weight, tensorFP32 mean, tensorFP32 rstd, int BT, int C) { - int BLOCK_SIZE = blockDim.x; + int BLOCK_SIZE = blockDim.x; // todo - does it make any difference if this is hardcoded here? 
int warpsInBlock = BLOCK_SIZE / WARP_SIZE; //number of warps in block extern __shared__ float shared[]; @@ -264,23 +171,19 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with } __syncthreads(); - for (int bt = baseIdx; bt < BT; bt += warpsInGrid) { - const floatX* dout_bt = dout + bt * C; - const floatX* inp_bt = inp +bt * C; - floatX* dinp_bt = dinp_old + bt * C; - floatX* dinp_new_bt = dinp_new + bt * C; + auto dinp_new128 = new_tensor128(dinp_new); - // first: two reduce operations + for (int bt = baseIdx; bt < BT; bt += warpsInGrid) { float dnorm_mean = 0.0f; float dnorm_norm_mean = 0.0f; for (int i = warpThreadIdx * x128::size; i < C; i += WARP_SIZE * x128::size) { - x128 dout128_i = load128(dout_bt + i); - x128 inp128_i = load128(inp_bt + i); - x128 weight128_i = load128(weight + i); + auto dout128_i = load_tensor128(dout, bt * C + i); + auto inp128_i = load_tensor128(inp, bt * C + i); + auto weight128_i = load_tensor128(weight, i); for (int k = 0; k < x128::size; k++) { - float dnorm_i = (float)weight128_i[k] * (float)dout128_i[k]; + float dnorm_i = weight128_i.get(k) * dout128_i.get(k); dnorm_mean += dnorm_i; - dnorm_norm_mean += dnorm_i * (float)inp128_i[k]; + dnorm_norm_mean += dnorm_i * inp128_i.get(k); } } @@ -292,17 +195,17 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with for (int c = 0; c < iterations_C; c++) { int global_index = (warpThreadIdx * x128::size) + (c * C_per_iteration); - x128 dout128 = x128::zeros(); - x128 inp128 = x128::zeros(); - x128 dinp128 = x128::zeros(); - x128 weight128 = x128::zeros(); + auto dout128 = new_tensor128(dout); + auto inp128 = new_tensor128(inp); + auto dinp128 = new_tensor128(dinp_old); + auto weight128 = new_tensor128(weight); if(global_index < C) { - dout128 = load128cs(dout_bt + global_index); - inp128 = load128cs(inp_bt + global_index); - weight128 = load128(weight + global_index); + dout128 = load_tensor128(dout, bt * C + global_index, true); + inp128 = load_tensor128(inp, bt * C + global_index, true); + weight128 = load_tensor128(weight, global_index); if constexpr (!zero_dinp_old) { - dinp128 = load128(dinp_bt + global_index); + dinp128 = load_tensor128(dinp_old, bt * C + global_index); } } @@ -311,17 +214,17 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with f128 dweight_f; for(int i = 0; i < f128::size; ++i) { int x = o * f128::size + i; - float dout_i = (float)dout128[x]; - float norm_bti = ((float)inp128[x] - mean_bt) * rstd_bt; + float dout_i = dout128.get(x); + float norm_bti = (inp128.get(x) - mean_bt) * rstd_bt; dbias_f[i] = dout_i; dweight_f[i] = norm_bti * dout_i; float dval = 0.0f; - dval += (float) weight128[x] * (float)dout128[x]; // term 1 + dval += weight128.get(x) * dout128.get(x); // term 1 dval -= dnorm_mean; // term 2 dval -= norm_bti * dnorm_norm_mean; // term 3 dval *= rstd_bt; // final scale - dinp128[x] = (floatX) ((float) dinp128[x] + dval); + dinp_new128.set(x, dinp128.get(x) + dval); } if (warpId != 0) { @@ -356,15 +259,21 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with } } if(global_index < C) { - // cache in L2 as this is read by the next kernel, but bypass L1 to minimise thrashing - store128cg(dinp_new_bt + global_index, dinp128); + dinp_new128.store_same_length(bt * C + global_index, false); } } } + + // if we did actually update the absmax (returns true), we already did __syncthreads() here + if (!dinp_new128.update_absmax(threadIdx.x, BLOCK_SIZE, false)) { + 
//__syncthreads(); + } __syncthreads(); + // Each block writes its partial sum to global memory // The last block to finish becomes responsible for summing up all the partial sums // This is done by atomically incrementing a flag (cleared to 0 before launching the kernel) + float* scratch = (float*)scratch_; unsigned int* scratchFlag = (unsigned int*)(scratch); // Increment scratch pointer by a full cacheline so that everything remains cacheline aligned scratch += 32; @@ -413,19 +322,19 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with break; } - x128 dbias128 = load128(dbias + global_index); - x128 dweight128 = load128(dweight + global_index); + auto dbias128 = load_tensor128(dbias, global_index); + auto dweight128 = load_tensor128(dweight, global_index); for(int o = 0; o < x128::size / f128::size; ++o) { f128 s_db = load128(dbias_shared + global_index + o * f128::size); f128 s_dw = load128(dweight_shared + global_index + o * f128::size); for(int i = 0; i < f128::size; ++i) { int x = o * f128::size + i; - dbias128[x] = (floatX)(s_db[i] + (float)dbias128[x]); - dweight128[x] = (floatX)(s_dw[i] + (float)dweight128[x]); + dbias128.set(x, s_db[i] + dbias128.get(x)); + dweight128.set(x, s_dw[i] + dweight128.get(x)); } } - store128(dbias + global_index, dbias128); - store128(dweight + global_index, dweight128); + dbias128.store_same_length(global_index); + dweight128.store_same_length(global_index); } } } @@ -434,65 +343,55 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with // kernel launchers // similar to `fused_residual_forward5` -void layernorm_forward(TensorGPU out, float* mean, float* rstd, - floatX* inp, const floatX* weight, const floatX* bias, +void layernorm_forward(tensorX out, tensorFP32 mean, tensorFP32 rstd, + tensorX inp, const tensorX weight, const tensorX bias, int N, int C, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); - const int block_size = 256; + int block_size = 256; // hardcoded in kernel as well int block_y = block_size / WARP_SIZE; - const int grid_size = CEIL_DIV(N, block_y); - size_t smem = (2 + block_y) * C * sizeof(floatX); - - // in order to use more than 48 KiB of smem, need to call cudaFuncSetAttribute - // this may fail, in which case we fall back to the smem free implementation. - cudaCheck(cudaGetLastError()); + size_t smem = block_y * C * sizeof(floatX); auto status = cudaFuncSetAttribute(layernorm_forward_kernel6, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); - cudaCheck(cudaGetLastError()); - if (status == cudaSuccess) { - layernorm_forward_kernel6<<>>(out, mean, rstd, inp, weight, bias, N, C); - } else { - // fall back to the version without shared memory - const int grid_size_fb = CEIL_DIV(N * WARP_SIZE, block_size); - layernorm_forward_kernel3<<>>(out, mean, rstd, inp, weight, bias, N, C); + // todo - comment + retry to unify into one function? 
(failed when I tried due to kernel argument not sure why) + while (status != cudaSuccess) { + if (block_y == 1) { + printf("ERROR: not enough shared memory for layernorm_forward\n"); + exit(EXIT_FAILURE); + } + block_y /= 2, block_size /= 2; + smem = (2 + block_y) * C * sizeof(floatX); + status = cudaFuncSetAttribute(layernorm_forward_kernel6, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); } - cudaCheck(cudaGetLastError()); -} - -void residual_forward(floatX* out, const floatX* inp1, const floatX* inp2, int N, cudaStream_t stream=main_stream) { - NVTX_RANGE_FN(); - const int block_size = 256; - assert(N % (block_size * x128::size) == 0); - const int grid_size = CEIL_DIV(N, block_size * x128::size); - residual_forward_kernel<<>>(out, inp1, inp2); + int grid_size = CEIL_DIV(N, block_y); + layernorm_forward_kernel6<<>>(out, mean, rstd, inp, weight, bias, N, C); cudaCheck(cudaGetLastError()); } void fused_residual_forward5(tensorX residual, tensorX normed, tensorFP32 mean, tensorFP32 rstd, tensorX inp1, tensorX inp2, tensorX weight, tensorX bias, int N, int C, cudaStream_t stream=main_stream) { - const int block_size = 256; + NVTX_RANGE_FN(); + int block_size = 256; int block_y = block_size / WARP_SIZE; - const int grid_size = CEIL_DIV(N, block_y); size_t smem = (2 + block_y) * C * sizeof(floatX); - - // in order to use more than 48 KiB of smem, need to call cudaFuncSetAttribute - // this may fail, in which case we fall back to the smem free implementation. - cudaCheck(cudaGetLastError()); auto status = cudaFuncSetAttribute(fused_residual_forward_kernel5, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); - cudaCheck(cudaGetLastError()); - if(status == cudaSuccess) { - fused_residual_forward_kernel5<<>>(residual, normed, - mean, rstd, inp1, inp2, - weight, bias, N, C); - } else { - residual_forward(residual, inp1, inp2, N*C, stream); - layernorm_forward(normed, mean, rstd, residual, weight, bias, N, C, stream); + while (status != cudaSuccess) { + if (block_y == 1) { + printf("ERROR: not enough shared memory for fused_residual_forward\n"); + exit(EXIT_FAILURE); + } + block_y /= 2, block_size /= 2; + smem = (2 + block_y) * C * sizeof(floatX); + status = cudaFuncSetAttribute(fused_residual_forward5, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); } + int grid_size = CEIL_DIV(N, block_y); + fused_residual_forward_kernel5<<>>(residual, normed, + mean, rstd, inp1, inp2, + weight, bias, N, C); cudaCheck(cudaGetLastError()); } -void layernorm_backward(floatX* dinp_new, floatX* dinp_old, floatX* dweight, floatX* dbias, float* scratch, - const floatX* dout, const floatX* inp, const floatX* weight, const float* mean, const float* rstd, +void layernorm_backward(tensorX dinp_new, tensorX dinp_old, tensorX dweight, tensorX dbias, tensorFP32 scratch, + const tensorX dout, const tensorX inp, const tensorX weight, tensorFP32 mean, tensorFP32 rstd, int BT, int C, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); const int block_size = 512; @@ -502,7 +401,7 @@ void layernorm_backward(floatX* dinp_new, floatX* dinp_old, floatX* dweight, flo size_t shared_mem_size = (2 * rounded_C + 2 * (block_size - 32) * f128::size) * sizeof(float); cudaCheck(cudaMemsetAsync(scratch, 0, 1 * sizeof(float), stream)); // only need to reset the flag to 0 - if (dinp_old == nullptr) { + if (dinp_old == null_tensorX) { layernorm_backward_kernel10<<>>(dinp_new, dinp_old, dweight, dbias, scratch, dout, inp, weight, mean, rstd, BT, C); } else { layernorm_backward_kernel10<<>>(dinp_new, dinp_old, dweight, dbias, scratch, dout, 
inp, weight, mean, rstd, BT, C); diff --git a/llmc/matmul.cuh b/llmc/matmul.cuh index af3398e78..2f5e07061 100644 --- a/llmc/matmul.cuh +++ b/llmc/matmul.cuh @@ -13,8 +13,8 @@ Matrix Multiplication, with help from cuBLASLt // ---------------------------------------------------------------------------- // CUDA kernels -template -__global__ void matmul_backward_bias_kernel9(OutFloat* dbias, const floatX* dout, int BT, int OC, +template +__global__ void matmul_backward_bias_kernel9(TensorGPU dbias, tensorX dout, int BT, int OC, std::bool_constant) { constexpr const int bdx = 4; constexpr const int bdy = WARP_SIZE / bdx; @@ -227,29 +227,31 @@ void matmul_cublaslt(floatX* d, const floatX* a, const floatX* b, const floatX* cudaCheck(cudaGetLastError()); } +template // small wrapper around matmul_cublaslt for the forward pass (keeping historical order of arguments) -void matmul_forward_cublaslt(floatX* out, - floatX* inp, floatX* weight, floatX* bias, +void matmul_forward_cublaslt(tensorX out, + tensorX inp, tensorX weight, tensorX bias, int BT, int C, int OC, - floatX* pre_gelu=NULL, int gelu_fusion=1, cudaStream_t stream=main_stream) { + TensorGPU pre_gelu=null_tensorX, int gelu_fusion=1, cudaStream_t stream=main_stream) { // By default only fuse GELU for H100+ as cuBLAS seems to be inefficient for fused GELU on Ada/Ampere (?) - if (gelu_fusion < 1 && pre_gelu) { + if (gelu_fusion < 1 && pre_gelu != null_tensorX) { matmul_cublaslt(pre_gelu, weight, inp, bias, OC, BT, C, stream, true, false, 0, 0, 0, 0, false, NULL, false); - gelu_forward(out, pre_gelu, BT*OC, stream); + gelu_forward(out, pre_gelu, stream); } else { matmul_cublaslt(out, weight, inp, bias, OC, BT, C, stream, true, false, 0, 0, 0, 0, false, pre_gelu, false); } } -void matmul_backward(floatX* dinp, floatX* dweight, floatX* dbias, - floatX* dout, floatX* inp, floatX* weight, - float* dbias_buffer, +template +void matmul_backward(tensorX dinp, tensorX dweight, tensorX dbias, + tensorX dout, tensorX inp, tensorX weight, + tensorFP32 dbias_buffer, int BT, int C, int OC, - floatX* pre_gelu=NULL, int gelu_fusion=1, cudaStream_t stream=main_stream) { + TensorGPU pre_gelu=null_tensorX, int gelu_fusion=1, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); // backward to bias, if given, does a += - if (dbias != NULL) { + if (dbias != null_tensorX) { // Each warp is responsible for 8 * "x128::size" = 64 OCs at BF16 (OC must be a multiple of 64!) // Block size is 1024 | 768 threads (32|24 warps) and we reduce those values into 1 at the end @@ -272,16 +274,15 @@ void matmul_backward(floatX* dinp, floatX* dweight, floatX* dbias, reduce_add_sum_kernel<<>>(dbias, dbias_buffer, OC, grid_size_y); cudaCheck(cudaGetLastError()); } - dbias = NULL; // prevent dbias calculation from also being fused in matmul_cublaslt below (if we enabled fusion) } // backward to input, uses = in the backward pass (set the gradient) matmul_cublaslt(dinp, weight, dout, NULL, C, BT, OC, stream, false, false, 0, 0, 0, 0, false, - gelu_fusion >= 2 ? pre_gelu : NULL, true); + gelu_fusion >= 2 ? 
pre_gelu.data_ptr : NULL, true); // backward GELU (if it wasn't fused into the matmul above) - if (gelu_fusion < 2 && pre_gelu) { - gelu_backward_inplace(dinp, pre_gelu, BT*C, stream); + if (gelu_fusion < 2 && pre_gelu != null_tensorX) { + gelu_backward(dinp, dinp, pre_gelu, stream); } // backward to weight, uses += in the backward pass (accumulate the gradient) by setting alpha=one diff --git a/train_gpt2.cu b/train_gpt2.cu index ac3c72b10..9e17cc645 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -83,6 +83,7 @@ char filename_buffer[512]; // global vars containing information about the GPU this process is running on cudaDeviceProp deviceProp; // fills in common_start() cudaStream_t main_stream; +TensorGPU null_tensorX; // buffer size to use for device <-> disk io constexpr const size_t IO_BUF_SIZE = 32 * 1024 * 1024; @@ -593,7 +594,6 @@ void gpt3_set_hyperparameters(GPT2Config* config, const char* channels_str) { config->num_heads = channels / head_size; config->max_seq_len = 2048; // NOTE: GPT-3 uses context length of 2048 tokens, up from 1024 in GPT-2 } - void gpt_build_from_descriptor(GPT2 *model, const char* descriptor) { // The model descriptor can be: // - legacy format "dX", where X is number, e.g. "d12". This creates GPT-2 model with 12 layers. @@ -762,7 +762,7 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { if (!CUDNN_ENABLED && T != model->seq_len) { cudaCheck(cudaMemset(ACT_L(att, 0), 0, L * B * NH * T * T * sizeof(floatX))); } - // validate inputs, all indices must be in the range [0, V) + // validate inputs, all indices mucst be in the range [0, V) tokenCheck(inputs, B*T, V); // copy inputs/targets to the model (fully synchronous with the host for now) @@ -785,7 +785,7 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { matmul_forward_cublaslt(ACT(attproj), ACT(atty), PARAM(attprojw), PARAM(attprojb), B*T, C, C); fused_residual_forward5(ACT(residual2), ACT(ln2), ACT(ln2_mean), ACT(ln2_rstd), residual, ACT(attproj), PARAM(ln2w), PARAM(ln2b), B*T, C); - matmul_forward_cublaslt(ACT(fch_gelu), ACT(ln2), PARAM(fcw), PARAM(fcb), B*T, C, 4*C, ACT(fch), model->gelu_fusion); + matmul_forward_cublaslt(ACT(fch_gelu), ACT(ln2), PARAM(fcw), PARAM(fcb), B*T, C, 4*C, ACT(fch), model->gelu_fusion); matmul_forward_cublaslt(ACT(fcproj), ACT(fch_gelu), PARAM(fcprojw), PARAM(fcprojb), B*T, 4*C, C); if(l+1 != L) { @@ -797,7 +797,7 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { } } - matmul_forward_cublaslt(ACT(output), ACT(lnf), PARAM(wte), NULL, B*T, C, Vp); + matmul_forward_cublaslt(ACT(output), ACT(lnf), PARAM(wte), null_tensorX, B*T, C, Vp); } @@ -861,8 +861,8 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int fused_classifier(AGRAD(output), ACT(output), ACT(losses), dloss, model->targets, B*T, V, Vp, True); // todo - split output & doutput // re-use the output buffer of the forward pass as a scratchpad during backward pass + dedicated buffer - float* scratchF = MULTI_L(local_scratch, 0); - floatX* scratchX_HUGE = MULTI_L(output_scratch, 0); + tensorFP32 scratchF = MULTI_L(local_scratch, 0); + tensorX scratchX_HUGE = MULTI_L(output_scratch, 0); // backward pass: go in the reverse order of the forward pass, and call backward() functions @@ -871,9 +871,9 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int // technically that is a small, inline backward() pass of calculating // total, final loss as the mean over all losses over all (B,T) positions in the 
batch // next: backward the classifier matmul - matmul_backward(AGRAD(lnf), PGRAD(wte), NULL, AGRAD(output), ACT(lnf), PARAM(wte), NULL, B*T, C, Vp); + matmul_backward(AGRAD(lnf), PGRAD(wte), null_tensorX, AGRAD(output), ACT(lnf), PARAM(wte), scratchF, B*T, C, Vp); // backward the final layernorm - layernorm_backward(AGRAD_L(residual3, L-1), NULL, PGRAD(lnfw), PGRAD(lnfb), scratchF, AGRAD(lnf), ACT_L(residual3, L-1), + layernorm_backward(AGRAD_L(residual3, L-1), null_tensorX, PGRAD(lnfw), PGRAD(lnfb), scratchF, AGRAD(lnf), ACT_L(residual3, L-1), PARAM(lnfw), ACT(lnf_mean), ACT(lnf_rstd), B*T, C); // now backward all the layers @@ -883,9 +883,9 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int tensorX dresidual = (l == 0) ? AGRAD(encoded) : AGRAD_L(residual3, l-1); if(model->recompute >= 1) { // recompute >= 1 means we recompute gelu - gelu_forward(ACT(fch_gelu), ACT(fch), B*T*4*C); + gelu_forward(ACT(fch_gelu), ACT(fch)); } - matmul_backward(AGRAD(fch), PGRAD(fcprojw), PGRAD(fcprojb), AGRAD(residual3), ACT(fch_gelu), PARAM(fcprojw), scratchF, B*T, 4*C, C, ACT(fch), model->gelu_fusion); + matmul_backward(AGRAD(fch), PGRAD(fcprojw), PGRAD(fcprojb), AGRAD(residual3), ACT(fch_gelu), PARAM(fcprojw), scratchF, B*T, 4*C, C, ACT(fch), model->gelu_fusion); if(model->recompute >= 2) { // recompute >= 2 means we recompute layernorm layernorm_forward(ACT(ln2), ACT(ln2_mean), ACT(ln2_rstd), ACT(residual2), PARAM(ln2w), PARAM(ln2b), B*T, C); @@ -950,7 +950,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int matmul_forward_cublaslt(ACT(attproj), ACT(atty), PARAM(attprojw), PARAM(attprojb), B*T, C, C); fused_residual_forward5(ACT(residual2), ACT(ln2), ACT(ln2_mean), ACT(ln2_rstd), residual, ACT(attproj), PARAM(ln2w), PARAM(ln2b), B*T, C); - matmul_forward_cublaslt(ACT(fch_gelu), ACT(ln2), PARAM(fcw), PARAM(fcb), B*T, C, 4*C, ACT(fch), model->gelu_fusion); + matmul_forward_cublaslt(ACT(fch_gelu), ACT(ln2), PARAM(fcw), PARAM(fcb), B*T, C, 4*C, ACT(fch), model->gelu_fusion); matmul_forward_cublaslt(ACT(fcproj), ACT(fch_gelu), PARAM(fcprojw), PARAM(fcprojb), B*T, 4*C, C); } l = old_l; @@ -1222,6 +1222,11 @@ void common_start(bool override_enable_tf32 = true, bool print_device_info = tru printf("Device %d: %s\n", multi_gpu_config.local_device_idx, deviceProp.name); } + null_tensorX.data_ptr = nullptr; + null_tensorX.absmax_ptr = nullptr; + null_tensorX.scale_descale_ptr = nullptr; + null_tensorX.num_elements = 0; + // set up the cuda streams. atm everything is on the single main stream cudaCheck(cudaStreamCreate(&main_stream)); nvtxNameCudaStreamA(main_stream, "main stream"); From a864fe5b7dc7a042982693c8d16552f70a3a61ff Mon Sep 17 00:00:00 2001 From: ademeure Date: Sat, 7 Sep 2024 01:27:08 +0000 Subject: [PATCH 10/27] More TensorGPU integration + better stochastic rounding + better layernorm low shared memory fallback + ... 
--- llmc/adamw.cuh | 4 +- llmc/attention.cuh | 95 ++++++++++------ llmc/cuda_common.h | 1 + llmc/cuda_utils.cuh | 228 ++++++++++++++++++++++++++++---------- llmc/encoder.cuh | 39 +++---- llmc/fused_classifier.cuh | 62 ++++++----- llmc/gelu.cuh | 27 ++--- llmc/layernorm.cuh | 49 ++++---- train_gpt2.cu | 30 +++-- 9 files changed, 334 insertions(+), 201 deletions(-) diff --git a/llmc/adamw.cuh b/llmc/adamw.cuh index 84d64f391..bea72b0c4 100644 --- a/llmc/adamw.cuh +++ b/llmc/adamw.cuh @@ -47,7 +47,7 @@ __device__ void adamw_update(Tp* params_memory, float* master_params_memory, Tg* } template -__global__ void adamw_kernel3(Tp* params_memory, float* master_params_memory, Tg* grads_memory, float* m_memory, float* v_memory, size_t num_parameters, +__global__ void adamw_kernel3(TensorGPU params_memory, float* master_params_memory, TensorGPU grads_memory, float* m_memory, float* v_memory, size_t num_parameters, ptrdiff_t w_stride, ptrdiff_t g_stride, ptrdiff_t s_stride, float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay, float grad_scale, unsigned int seed) { @@ -72,7 +72,7 @@ __global__ void init_from_master_kernel(Tp* params_memory, float* master_params_ } template -void adamw_update(Tp* params_memory, float* master_params_memory, Tg* grads_memory, float* m_memory, float* v_memory, size_t num_parameters, +void adamw_update(TensorGPU params_memory, float* master_params_memory, TensorGPU grads_memory, float* m_memory, float* v_memory, size_t num_parameters, ptrdiff_t w_stride, ptrdiff_t g_stride, ptrdiff_t s_stride, int num_slices, float learning_rate, float beta1, float beta2, int t, float eps, float weight_decay, float grad_scale, unsigned int seed, cudaStream_t stream=main_stream) { // AdamW update diff --git a/llmc/attention.cuh b/llmc/attention.cuh index 3dc5cd52f..e65dcef21 100644 --- a/llmc/attention.cuh +++ b/llmc/attention.cuh @@ -12,11 +12,11 @@ Attention, as a fallback when we do not use the Flash Attention from cuDNN // inputs floatX, outputs FP32 (for current FP32-only activation path for this WIP) __global__ void permute_kernel(floatX* q, floatX* k, floatX* v, - const floatX* inp, + tensorX inp, int B, int N, int NH, int d) { // okay so now, this kernel wants Q,K,V to all be of shape (B, NH, N, d) // but instead, we have a single tensor QKV (inp) of shape (B, N, 3, NH, d) - int idx = blockIdx.x * blockDim.x + threadIdx.x; + int idx = (blockIdx.x * blockDim.x + threadIdx.x) * inp.num_per_128(); if (idx >= B * NH * N * d) { return; } // Q[b][nh_][n][d_] = inp[b][n][0][nh_][d_] @@ -27,15 +27,21 @@ __global__ void permute_kernel(floatX* q, floatX* k, floatX* v, int n = rest / d; int d_ = rest % d; int inp_idx = (b * N * 3 * NH * d) + (n * 3 * NH * d) + (0 * NH * d) + (nh_ * d) + d_; - q[idx] = __ldcs(&inp[inp_idx]); - k[idx] = __ldcs(&inp[inp_idx + NH * d]); - v[idx] = __ldcs(&inp[inp_idx + 2 * (NH * d)]); + + auto inp128_q = load_tensor128(inp, inp_idx, true); + auto inp128_k = load_tensor128(inp, inp_idx + NH * d, true); + auto inp128_v = load_tensor128(inp, inp_idx + 2 * (NH * d), true); + for (int i = 0; i < inp.num_per_128(); i++) { + q[idx+i] = inp128_q.get(i); + k[idx+i] = inp128_k.get(i); + v[idx+i] = inp128_v.get(i); + } } -__global__ void permute_kernel_backward(floatX* dinp, +__global__ void permute_kernel_backward(tensorX dinp, const floatX* dq, const floatX* dk, const floatX* dv, int B, int N, int NH, int d) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; + int idx = (blockIdx.x * blockDim.x + 
threadIdx.x) * dinp.num_per_128(); if (idx >= B * NH * N * d) { return; } int b = idx / (NH * N * d); @@ -49,12 +55,28 @@ __global__ void permute_kernel_backward(floatX* dinp, dinp[inp_idx] = dq[idx]; dinp[inp_idx + NH * d] = dk[idx]; dinp[inp_idx + 2 * (NH * d)] = dv[idx]; + + auto dinp128_q = new_tensor128(dinp); + auto dinp128_k = new_tensor128(dinp); + auto dinp128_v = new_tensor128(dinp); + for (int i = 0; i < dinp.num_per_128(); i++) { + dinp128_q.set(i, dq[idx+i]); + dinp128_k.set(i, dk[idx+i]); + dinp128_v.set(i, dv[idx+i]); + // to allow us to update the absmax only once + dinp128_k.add_value_stats(dk[idx+i], dinp128_k.get128()[i]); + dinp128_v.add_value_stats(dv[idx+i], dinp128_v.get128()[i]); + } + dinp128_q.store(inp_idx); + dinp128_k.store(inp_idx + NH * d); + dinp128_v.store(inp_idx + 2 * (NH * d)); + dinp128_q.update_absmax(threadIdx.x, blockDim.x, true); } -__global__ void unpermute_kernel(floatX* inp, floatX *out, int B, int N, int NH, int d) { +__global__ void unpermute_kernel(tensorX out, floatX* inp, int B, int N, int NH, int d) { // out has shape (B, nh, N, d) but we need to unpermute it to (B, N, nh, d) - int idx = (blockIdx.x * blockDim.x + threadIdx.x); + int idx = (blockIdx.x * blockDim.x + threadIdx.x) * out.num_per_128(); // out[b][n][nh_][d_] <- inp[b][nh_][n][d_] if (idx >= B * NH * N * d) { return; } @@ -65,11 +87,16 @@ __global__ void unpermute_kernel(floatX* inp, floatX *out, int B, int N, int NH, int n = rest / d; int d_ = rest % d; int other_idx = (b * NH * N * d) + (n * NH * d) + (nh_ * d) + d_; - out[other_idx] = __ldcs(&inp[idx]); + auto out128 = new_tensor128(out); + for (int i = 0; i < out.num_per_128(); i++) { + out128.set(i, __ldcs(&inp[idx + i])); + } + out128.store(other_idx); + out128.update_absmax(threadIdx.x, blockDim.x, true); } -__global__ void unpermute_kernel_backward(floatX* dinp, const floatX *dout, int B, int N, int NH, int d) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; +__global__ void unpermute_kernel_backward(floatX* dout_permuted, tensorX dout, int B, int N, int NH, int d) { + int idx = (blockIdx.x * blockDim.x + threadIdx.x) * dout.num_per_128(); if (idx >= B * NH * N * d) { return; } int b = idx / (NH * N * d); @@ -79,10 +106,13 @@ __global__ void unpermute_kernel_backward(floatX* dinp, const floatX *dout, int int n = rest / d; int d_ = rest % d; int other_idx = (b * NH * N * d) + (n * NH * d) + (nh_ * d) + d_; - dinp[idx] = (floatX)dout[other_idx]; + auto dout128 = load_tensor128(dout, other_idx); + for (int k = 0; k < dout128.elements; k++) { + dout_permuted[idx+k] = (floatX)dout128.get(k); + } } -__global__ void softmax_forward_kernel5(floatX* out, float inv_temperature, const floatX* inp, int N, int T) { +__global__ void softmax_forward_kernel5(floatX* out, float inv_temperature, floatX* inp, int N, int T) { // inp, out shape: (N, T, T), where N = B * NH // fuses the multiplication by scale inside attention // directly autoregressive, so we only compute the lower triangular part @@ -149,7 +179,7 @@ __global__ void softmax_forward_kernel5(floatX* out, float inv_temperature, cons } } -__global__ void softmax_autoregressive_backward_inplace_kernel(floatX* datt, const floatX* att, +__global__ void softmax_autoregressive_backward_inplace_kernel(floatX* datt, floatX* att, int B, int T, int C, float scale) { constexpr const int BlockSize = 256; constexpr int T_per_block = 4; @@ -192,8 +222,8 @@ __global__ void softmax_autoregressive_backward_inplace_kernel(floatX* datt, con // 
---------------------------------------------------------------------------- // kernel launchers -void attention_forward(floatX* out, floatX* qkvr, floatX* att, - floatX* inp, +void attention_forward(tensorX out, floatX* qkvr, floatX* att, + tensorX inp, int B, int T, int C, int NH, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); // Note: `inp` is not needed for backward pass, so we re-use it as a scratch buffer. @@ -211,11 +241,11 @@ void attention_forward(floatX* out, floatX* qkvr, floatX* att, k = qkvr + 1 * B * T * C; v = qkvr + 2 * B * T * C; int total_threads = B * NH * T * HS; - int num_blocks = CEIL_DIV(total_threads, block_size); + int num_blocks = CEIL_DIV(total_threads, block_size * inp.num_per_128()); permute_kernel<<>>(q, k, v, inp, B, T, NH, HS); floatX* preatt = inp; // reuse inp as scratch buffer - matmul_cublaslt(preatt, k, q, nullptr, T, T, HS, stream, true, false, B * NH, T * HS, T * HS, T * T); + matmul_cublaslt(tensorX::from(preatt), tensorX::from(k), tensorX::from(q), nullptr, T, T, HS, stream, true, false, B * NH, T * HS, T * HS, T * T); // multiply all elements of preatt elementwise by scale float scale = 1.f / sqrtf(HS); @@ -225,27 +255,26 @@ void attention_forward(floatX* out, floatX* qkvr, floatX* att, // new approach: first cuBLAS another batched matmul floatX* vaccum = inp; // y = att @ v # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs) - matmul_cublaslt(vaccum, v, att, nullptr, HS, T, T, stream, false, false, B * NH, T * HS, T * T, T * HS); + matmul_cublaslt(tensorX::from(vaccum), tensorX::from(v), tensorX::from(att), nullptr, HS, T, T, stream, false, false, B * NH, T * HS, T * T, T * HS); // now unpermute // y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side - num_blocks = CEIL_DIV(B * T * C, block_size); - unpermute_kernel<<>>(vaccum, out, B, T, NH, HS); + num_blocks = CEIL_DIV(B * T * C, block_size * out.num_per_128()); + unpermute_kernel<<>>(out, vaccum, B, T, NH, HS); cudaCheck(cudaGetLastError()); } // the sequence of transformations in this compound op is: // inp (B,T,3C) -> qkvr (B,T,3C) -> preatt (B,NH,T,T) -> att (B,NH,T,T) -> vaccum (B,T,C) -> out (B,T,C) -void attention_backward(floatX* dinp, floatX* dqkvr, floatX* datt, floatX* scratch, - const floatX* dout, - const floatX* qkvr, const floatX* att, +void attention_backward(tensorX dinp, floatX* dqkvr, floatX* datt, floatX* scratch, + tensorX dout, tensorX qkvr, floatX* att, int B, int T, int C, int NH, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); const int block_size = 256; const int HS = C / NH; // head size // unpack convenience pointers into q, k, v - const floatX *q, *k, *v; + floatX *q, *k, *v; q = qkvr + 0 * B * T * C; k = qkvr + 1 * B * T * C; v = qkvr + 2 * B * T * C; @@ -255,22 +284,22 @@ void attention_backward(floatX* dinp, floatX* dqkvr, floatX* datt, floatX* scrat dv = dqkvr + 2 * B * T * C; // backward through the unpermute operation - int num_blocks = CEIL_DIV(B * T * C, block_size); + int num_blocks = CEIL_DIV(B * T * C, block_size * dout.num_per_128()); unpermute_kernel_backward<<>>(scratch, dout, B, T, NH, HS); // backward into datt - matmul_cublaslt(datt, v, scratch, nullptr, T, T, HS, stream, true, false, B * NH, T * HS, T * HS, T * T); + matmul_cublaslt(tensorX::from(datt), tensorX::from(v), tensorX::from(scratch), nullptr, T, T, HS, stream, true, false, B * NH, T * HS, T * HS, T * T); // backward into dv - matmul_cublaslt(dv, scratch, att, nullptr, HS, T, T, stream, false, true, B * NH, T * HS, T * T, T * HS); + 
matmul_cublaslt(tensorX::from(dv), tensorX::from(scratch), tensorX::from(att), nullptr, HS, T, T, stream, false, true, B * NH, T * HS, T * T, T * HS); const float scale = 1.0f / sqrtf((float)HS); // backward into preatt. this is an in-place operation; datt turns into dpreatt here softmax_autoregressive_backward_inplace_kernel<<>>(datt, att, B, T, C, scale); - const floatX* dpreatt = datt; + floatX* dpreatt = datt; // backward into q - matmul_cublaslt(dq, k, dpreatt, nullptr, HS, T, T, stream, false, false, B * NH, T * HS, T * T, T * HS); + matmul_cublaslt(tensorX::from(dq), tensorX::from(k), tensorX::from(dpreatt), nullptr, HS, T, T, stream, false, false, B * NH, T * HS, T * T, T * HS); // backward into k - matmul_cublaslt(dk, q, dpreatt, nullptr, HS, T, T, stream, false, true, B * NH, T * HS, T * T, T * HS); + matmul_cublaslt(tensorX::from(dk), tensorX::from(q), tensorX::from(dpreatt), nullptr, HS, T, T, stream, false, true, B * NH, T * HS, T * T, T * HS); // backward into inp - num_blocks = CEIL_DIV(B * NH * T * HS, block_size); + num_blocks = CEIL_DIV(B * NH * T * HS, block_size * dinp.num_per_128()); permute_kernel_backward<<>>(dinp, dq, dk, dv, B, T, NH, HS); cudaCheck(cudaGetLastError()); } diff --git a/llmc/cuda_common.h b/llmc/cuda_common.h index 49c2b910d..c4fb2fd24 100644 --- a/llmc/cuda_common.h +++ b/llmc/cuda_common.h @@ -15,6 +15,7 @@ Common utilities for CUDA code. #include #include #include +#include #include "utils.h" diff --git a/llmc/cuda_utils.cuh b/llmc/cuda_utils.cuh index 73fa15b67..85a78ebf6 100644 --- a/llmc/cuda_utils.cuh +++ b/llmc/cuda_utils.cuh @@ -135,6 +135,111 @@ DType dtype_of(float* f) { return DType::FP32; } DType dtype_of(nv_bfloat16 * f) { return DType::BF16; } DType dtype_of(half * f) { return DType::FP16; } +// ---------------------------------------------------------------------------- +// Random Number Generation used in Stochastic Rounding (defined here as used by TensorGPU) + +// SquirrelNoise5 - Squirrel's Raw Noise utilities (version 5) +// This gives us a random number from threadIdx/blockIdx + a single seed for the entire GPU +// todo - possibly overkill and we don't need such high quality random numbers? 
(tbd)
+// http://eiserloh.net/noise/SquirrelNoise5.hpp
+__device__ __host__ unsigned int SquirrelNoise5(unsigned int positionX, unsigned int seed) {
+    constexpr unsigned int SQ5_BIT_NOISE1 = 0xd2a80a3f; // 11010010101010000000101000111111
+    constexpr unsigned int SQ5_BIT_NOISE2 = 0xa884f197; // 10101000100001001111000110010111
+    constexpr unsigned int SQ5_BIT_NOISE3 = 0x6C736F4B; // 01101100011100110110111101001011
+    constexpr unsigned int SQ5_BIT_NOISE4 = 0xB79F3ABB; // 10110111100111110011101010111011
+    constexpr unsigned int SQ5_BIT_NOISE5 = 0x1b56c4f5; // 00011011010101101100010011110101
+    unsigned int mangledBits = positionX;
+    mangledBits *= SQ5_BIT_NOISE1;
+    mangledBits += seed;
+    mangledBits ^= (mangledBits >> 9);
+    mangledBits += SQ5_BIT_NOISE2;
+    mangledBits ^= (mangledBits >> 11);
+    mangledBits *= SQ5_BIT_NOISE3;
+    mangledBits ^= (mangledBits >> 13);
+    mangledBits += SQ5_BIT_NOISE4;
+    mangledBits ^= (mangledBits >> 15);
+    mangledBits *= SQ5_BIT_NOISE5;
+    mangledBits ^= (mangledBits >> 17);
+    return mangledBits;
+}
+
+// rely on default values of 0 being optimised away for 1D/2D/3D (shorter than original code)
+__device__ __host__ unsigned int get_random_noise(unsigned int seed, unsigned int x,
+                                                  unsigned int y=0, unsigned int z=0, unsigned int t=0) {
+    constexpr unsigned int PRIME1 = 198491317u; // Large prime number with non-boring bits
+    constexpr unsigned int PRIME2 = 6542989u; // Large prime number with distinct and non-boring bits
+    constexpr unsigned int PRIME3 = 357239u; // Large prime number with distinct and non-boring bits
+    return SquirrelNoise5(x + (PRIME1 * y) + (PRIME2 * z) + (PRIME3 * t), seed);
+}
+
+// stochastic rounding (typically using Squirrel Noise above to go from a seed to a random number)
+// new algorithm that calculates distance from rounded up/down values to correctly handle denorms
+// (didn't matter with BF16 because denorms are so tiny they're irrelevant, unlike in FP8/FP16)
+template
+__device__ __forceinline__ void stochastic_rounding(float in, Ti *out, unsigned int seed, float prob_offset=0.0f) {
+    unsigned int random = noise ? get_random_noise(threadIdx.x, blockIdx.x * blockDim.x + blockIdx.y, seed) : seed;
+
+    // prob_offset allows rounding towards gradient more of the time (one paper recommends that)
+    // e.g. +0.3f ==> 65% chance up, 35% chance down
+    highp threshold_percentage = ((highp)random / (highp)0xFFFFFFFF) - prob_offset;
+
+    Ti rounded_down, rounded_up;
+    if constexpr (std::is_same::value) {
+        rounded_down = __float2half_rd(in);
+        rounded_up = __float2half_ru(in);
+    } else if constexpr (std::is_same::value) {
+        rounded_down = __float2bfloat16_rd(in);
+        rounded_up = __float2bfloat16_ru(in);
+    } else if constexpr (std::is_same::value) {
+        // CUDA doesn't have round down/up instructions for FP8 (in SW or HW) so we do it ourselves
+        // ARM-Intel-NVIDIA style FP8 E4M3 (different for AMD-Graphcore-Qualcomm format!)
+        Ti rounded = __nv_fp8_e4m3(in);
+        unsigned char rounded_bits = rounded.__x;
+        unsigned char absolute_bits = rounded_bits & 127;
+        unsigned char rounded_up_bits = absolute_bits + 1;
+        unsigned char rounded_down_bits = absolute_bits - 1;
+
+        // compiler likes the following code atm, but small changes may increase instructions by a lot
+        // as it may suddenly decide to use branches rather than predication...
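// [editor's aside - illustrative sketch, not part of the patch] For the ARM/Intel/NVIDIA-style
// FP8 E4M3 format assumed above (1 sign, 4 exponent, 3 mantissa bits, exponent bias 7, no
// infinities, magnitude 0x7F reserved for NaN), incrementing or decrementing the 7-bit magnitude
// steps to the adjacent representable value, which is what the bit manipulation below relies on.
// A standalone host-side decoder to sanity-check that property:
#include <cstdio>
#include <cmath>

float e4m3_to_float(unsigned char bits) {
    int sign     = (bits >> 7) & 1;
    int exponent = (bits >> 3) & 0xF;
    int mantissa = bits & 0x7;
    float value;
    if (exponent == 0xF && mantissa == 0x7) value = NAN;            // E4M3 has NaN but no +/-inf
    else if (exponent == 0) value = ldexpf((float)mantissa, -9);    // subnormal: (m/8) * 2^-6
    else value = ldexpf(1.0f + mantissa / 8.0f, exponent - 7);      // normal: (1+m/8) * 2^(e-7)
    return sign ? -value : value;
}

int main() {
    for (unsigned char m = 0x37; m <= 0x39; m++) {                  // neighbours of 1.0f (0x38)
        printf("magnitude 0x%02x -> %f\n", m, e4m3_to_float(m));    // prints 0.9375, 1.0, 1.125
    }
    return 0;
}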
+ if (absolute_bits >= 126) { // maximum normal value (+NaN) + rounded_up_bits = absolute_bits; + if (absolute_bits == 127) { // NaN (not always preserving sign) + rounded_down_bits = 127; + } + } else if (absolute_bits == 0) { // zero + rounded_down_bits = 0; + } else { + unsigned char mantissa_bits = absolute_bits & 7; + if (mantissa_bits == 7) { // maximum mantissa (already known non-NaN/non-max) + rounded_up_bits = (absolute_bits - mantissa_bits) + 8; // clear mantissa, add 1 to exponent + } else if (mantissa_bits == 0) { // minimum mantissa (already known non-zero) + rounded_down_bits = (absolute_bits + 7) - 8; // max mantissa, subtract 1 from exponent + } + } + if (in < 0) { // negative input: swap rounded up/down and add negative sign + unsigned char swap_tmp = rounded_down_bits | 128; + rounded_down_bits = rounded_up_bits | 128; + rounded_up_bits = swap_tmp; + } + + // rounding to nearest even already gave us 1 of the 2 rounded values surrounding the input + // we only need the other one (but no point skipping anything above given SIMT divergence) + rounded_down.__x = ((float)rounded <= in) ? rounded.__x : rounded_down_bits; + rounded_up.__x = ((float)rounded >= in) ? rounded.__x : rounded_up_bits; + } else if constexpr (std::is_same::value) { + assert(false); // todo + } else { + assert(false); + } + + highp diff = (highp)rounded_up - (highp)rounded_down; + highp lerp = ((highp)in - (highp)rounded_down) / diff; // division by 0 is OK as it means (up == down) anyway + *out = (lerp > threshold_percentage) ? rounded_up : rounded_down; +} +__device__ __forceinline__ void stochastic_rounding(float in, float *out, unsigned int random) { + *out = in; // dummy function for when floatX is float (FP32 mode) +} + // ---------------------------------------------------------------------------- // ... template @@ -144,6 +249,12 @@ struct TensorGPU { unsigned int* absmax_ptr; size_t num_elements; + static __device__ __host__ TensorGPU from(ElementType* ptr=nullptr) { + TensorGPU tmp = {0}; + tmp.data_ptr = ptr; + return tmp; + } + template __device__ __host__ T* as() { return reinterpret_cast(data_ptr); @@ -164,6 +275,25 @@ struct TensorGPU { __device__ __host__ int num_per_128() const { return sizeof(int4) / sizeof(ElementType); } + + __device__ __host__ float get_scalar(size_t index, bool disable_scaling=true) const { + ElementType* __restrict__ data_ptr_restricted = data_ptr; + float* __restrict__ scale_ptr_restricted = scale_descale_ptr; + + float value = (float)data_ptr_restricted[index]; + float descale = (scale_descale_ptr && !disable_scaling) ? scale_ptr_restricted[1] : 1.0f; + return value * descale; // [1] = descale + } + + __device__ __host__ ElementType set_scalar(size_t index, float value, bool disable_scaling=true) { + ElementType* __restrict__ data_ptr_restricted = data_ptr; + float* __restrict__ scale_ptr_restricted = scale_descale_ptr; + + float scale = (scale_descale_ptr && !disable_scaling) ? 
scale_ptr_restricted[0] : 1.0f; + ElementType output = (ElementType)(value * scale); + data_ptr_restricted[index] = output; + return output; + } }; // short-form typedefs @@ -191,7 +321,7 @@ private: bool wrote_absmax = false; public: - bool scaling = (sizeof(ElementType) <= 1); // todo - fp8 only + bool scaling = (sizeof(ElementType) == 33); // todo - fp8 only static constexpr const size_t elements = sizeof(int4) / sizeof(ElementType); __device__ tensor128(TensorGPU tensor, bool disable_scaling=false) { @@ -231,17 +361,54 @@ public: wrote_data = true; } - __device__ Packed128 get128() { + __device__ const Packed128& get128() const { + return data128; + } + + __device__ Packed128& get128() { return data128; } + // call this manually if e.g. you use set_scalar() to update the tensor + // todo - in the future, this could consider more than just absmax + __device__ void add_value_stats(float value, ElementType output) { + new_absmax = max(new_absmax, fabsf(value)); + } + __device__ float get(int index) { return (float)data128[index] * (scaling ? descale : 1.0f); } __device__ void set(int index, float value) { - new_absmax = max(new_absmax, fabsf(value)); data128[index] = (ElementType)(value * (scaling ? scale : 1.0f)); + add_value_stats(value, data128[index]); + } + + __device__ void set_stochastic(int index, float value, unsigned int random_number, + bool rotate_by_index=true, bool non_deterministic_rng=false) { + float scaled_value = value * (scaling ? scale : 1.0f); + + // rotate the random number by the index so we can cheaply reuse the same RNG + // obviously less good than having true per-index RNG, but should be good enough + // when rounding FP32 to FP8, most of the bits make extremely little difference anyway... + // x10 is used so that it never repeats for indices [0;15] with a minimum difference of 2 etc. + if (rotate_by_index) { + assert(index < 16); // >=16 would repeat and be extremely bad RNG + random_number = __funnelshift_l(random_number, random_number, index * 10); + } + // RNG without a seed from the host for quick testing, but obviously not deterministic! + #ifdef FORCE_NON_DETERMINISM + non_deterministic_rng = true; + #endif + if (non_deterministic_rng) { + unsigned int clock, laneid; + asm volatile("mov.u32 %0, %%clock;" : "=r"(clock)); + asm volatile("mov.u32 %0, %%laneid;" : "=r"(laneid)); + random_number = get_random_noise(clock, laneid, blockIdx.x * blockDim.x); + } + + stochastic_rounding(scaled_value, &data128[index], random_number); + add_value_stats(value, data128[index]); } __device__ bool update_absmax(int thread_id, int num_threads, bool exit=false, bool forced=false) { @@ -525,59 +692,4 @@ void global_sum_deterministic(float* result, const Float* values, int count, cud cudaCheck(cudaGetLastError()); } -// ---------------------------------------------------------------------------- -// Random Number Generation used in Stochastic Rounding - -// SquirrelNoise5 - Squirrel's Raw Noise utilities (version 5) -// This gives us a random number from threadIdx/blockIdx + a single seed for the entire GPU -// todo - possibly overkill and we don't need such high quality random numbers? 
(tbd) -// http://eiserloh.net/noise/SquirrelNoise5.hpp -__device__ __host__ constexpr unsigned int SquirrelNoise5(unsigned int positionX, unsigned int seed) -{ - constexpr unsigned int SQ5_BIT_NOISE1 = 0xd2a80a3f; // 11010010101010000000101000111111 - constexpr unsigned int SQ5_BIT_NOISE2 = 0xa884f197; // 10101000100001001111000110010111 - constexpr unsigned int SQ5_BIT_NOISE3 = 0x6C736F4B; // 01101100011100110110111101001011 - constexpr unsigned int SQ5_BIT_NOISE4 = 0xB79F3ABB; // 10110111100111110011101010111011 - constexpr unsigned int SQ5_BIT_NOISE5 = 0x1b56c4f5; // 00011011010101101100010011110101 - unsigned int mangledBits = positionX; - mangledBits *= SQ5_BIT_NOISE1; - mangledBits += seed; - mangledBits ^= (mangledBits >> 9); - mangledBits += SQ5_BIT_NOISE2; - mangledBits ^= (mangledBits >> 11); - mangledBits *= SQ5_BIT_NOISE3; - mangledBits ^= (mangledBits >> 13); - mangledBits += SQ5_BIT_NOISE4; - mangledBits ^= (mangledBits >> 15); - mangledBits *= SQ5_BIT_NOISE5; - mangledBits ^= (mangledBits >> 17); - return mangledBits; -} -__device__ __host__ constexpr unsigned int Get2dNoiseUint(int indexX, int indexY, unsigned int seed) -{ - constexpr unsigned int PRIME_NUMBER = 198491317u; // Large prime number with non-boring bits - unsigned int x = static_cast(indexX); - unsigned int y = static_cast(indexY); - - return SquirrelNoise5(x + (PRIME_NUMBER * y), seed); -} - -// stochastic rounding built on top of Squirel Noise above (with seed updated per step via xorshift) -__device__ __forceinline__ void stochastic_rounding(float in, __nv_bfloat16 *out, unsigned int seed) { - // todo - is this stochastic rounding *too good*? can we cut any corners? - // makes sure each thread gets a different random number - unsigned int random = Get2dNoiseUint(threadIdx.x, blockIdx.x * blockDim.x + blockIdx.y, seed); - unsigned int threshold = random & 0xFFFF; - unsigned int float_bits = __float_as_uint(in); - unsigned int rounded_bits = float_bits & 0x0000FFFF; - float_bits = (rounded_bits > threshold) ? (float_bits | 0xFFFF) : (float_bits & ~0xFFFF); - *out = __float2bfloat16_rn(__uint_as_float(float_bits)); -} -__device__ __forceinline__ void stochastic_rounding(float in, half *out, unsigned int random) { - *out = (float)in; // todo - implement this... 
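// [editor's aside - illustrative sketch, not part of the patch] Both the old BF16-only helper
// being removed here and the new generic stochastic_rounding() above rely on the same idea: the
// probability of rounding up equals the fractional distance to the lower representable
// neighbour, so the rounding error is zero in expectation. A host-side demo on a hypothetical
// grid of spacing 0.25 (standing in for the BF16/FP8 neighbours):
#include <cstdio>
#include <cmath>
#include <random>

float stochastic_round_to_grid(float in, float grid, std::mt19937& rng) {
    float down = floorf(in / grid) * grid;                   // lower representable neighbour
    float up = down + grid;                                  // upper representable neighbour
    float lerp = (in - down) / (up - down);                  // same role as `lerp` in the new helper
    float threshold = std::uniform_real_distribution<float>(0.0f, 1.0f)(rng);
    return (lerp > threshold) ? up : down;
}

int main() {
    std::mt19937 rng(1234);
    double sum = 0.0;
    const int iters = 100000;
    for (int i = 0; i < iters; i++) sum += stochastic_round_to_grid(0.3f, 0.25f, rng);
    printf("mean of rounded values: %f (expected ~0.3, not 0.25)\n", sum / iters);
    return 0;
}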
-} -__device__ __forceinline__ void stochastic_rounding(float in, float *out, unsigned int random) { - *out = in; // dummy function for when floatX is float (FP32 mode) -} - #endif \ No newline at end of file diff --git a/llmc/encoder.cuh b/llmc/encoder.cuh index a95688b53..2bcd6017e 100644 --- a/llmc/encoder.cuh +++ b/llmc/encoder.cuh @@ -33,7 +33,6 @@ __global__ void encoder_forward_kernel3(tensorX out, auto wte128 = load_tensor128(wte, ix * C + c); auto wpe128 = load_tensor128(wpe, t * C + c); - x128 packed_out; for (int k = 0; k < x128::size; k++) { out128.set(k, wte128.get(k) + wpe128.get(k)); } @@ -73,11 +72,9 @@ __global__ void wte_backward_kernel(tensorX dwte, for(int item = warp_id; item < bucket_size; item += BLOCK_SIZE/WARP_SIZE) { int bt = workload_indices[bucket_start_idx + item]; - - const floatX* dout_btc = dout + bt * C + c; - x128 packed_inp1 = load128cs(dout_btc); - for (int k = 0; k < packed_inp1.size; k++) { - accum[k] += (float)packed_inp1[k]; + auto dout128 = load_tensor128(dout, bt * C + c, true); + for (int k = 0; k < dout128.elements; k++) { + accum[k] += dout128.get(k); } } @@ -90,8 +87,7 @@ __global__ void wte_backward_kernel(tensorX dwte, } // Read dwte for warp 0 even if other warps are not finished yet to maximise latency tolerance - floatX* dwte_ix = dwte + bucket_ix * C + c; - x128 packed_in_out = load128(dwte_ix); + auto dwte128 = load_tensor128(dwte, bucket_ix * C + c, false, true); // note: threads which have returned are considered synchronised by CUDA so no risk of deadlock __syncthreads(); @@ -103,15 +99,15 @@ __global__ void wte_backward_kernel(tensorX dwte, } } - // Add the result to dwte and write back to global memory (read-modify-write) + // add the result to dwte and write back to global memory (read-modify-write) + // we use stochastic rounding to go from FP32 to BF16/whatever (the seed is deterministic) + // reusing same random value but shifting based on the index in set_stochastic ("good enough") + unsigned int random = get_random_noise(seed, threadIdx.x, bucket); for (unsigned int k = 0; k < x128::size; k++) { - // We use stochastic rounding to go from FP32 to BF16 - // The seed is deterministic and unique for each parameter to guarantee we have determinism AND - // to avoid **potential** issues with positionX int SquirrelNoise5 argument overflowing which is UB - // and that somehow messing the quality of random numbers - stochastic_rounding(accum[k] + (float)packed_in_out[k], &packed_in_out[k], seed + bucket * WARP_SIZE + threadIdx.x + k); + dwte128.set_stochastic(k, accum[k] + dwte128.get(k), random); } - store128(dwte_ix, packed_in_out); + dwte128.store(bucket_ix * C + c); + dwte128.update_absmax(threadIdx.x, blockDim.x, true); } __global__ void wpe_backward_kernel(tensorX dwpe, @@ -131,22 +127,23 @@ __global__ void wpe_backward_kernel(tensorX dwpe, float accum[x128::size] = {0.0f}; for (int b = 0; b < B; b++) { - x128 packed_dout = load128cs(dout + (b * T * C) + (t * C) + c); // will never be read again + auto dout128 = load_tensor128(dout, b * T * C + t * C + c, true); for (int k = 0; k < x128::size; k++) { - accum[k] += (float)packed_dout[k]; + accum[k] += dout128.get(k); } } - floatX* dwpe_tc = dwpe + (t * C) + c; - x128 packed_dwpe = load128(dwpe_tc); + auto dwpe128 = load_tensor128(dwpe, t * C + c); + unsigned int random = get_random_noise(seed, t, c); for (unsigned int k = 0; k < x128::size; k++) { // We use stochastic rounding to go from FP32 to BF16 // The seed is deterministic and unique for each parameter to guarantee we have 
determinism AND // to avoid **potential** issues with positionX int SquirrelNoise5 argument overflowing which is UB // and that somehow messing the quality of random numbers - stochastic_rounding(accum[k] + (float)packed_dwpe[k], &packed_dwpe[k], seed + idx + k); + dwpe128.set_stochastic(k, accum[k] + dwpe128.get(k), random); } - store128(dwpe_tc, packed_dwpe); + dwpe128.store(t * C + c); + dwpe128.update_absmax(threadIdx.x, blockDim.x, true); } // ---------------------------------------------------------------------------- diff --git a/llmc/fused_classifier.cuh b/llmc/fused_classifier.cuh index 8b29ca233..0a32b1229 100644 --- a/llmc/fused_classifier.cuh +++ b/llmc/fused_classifier.cuh @@ -16,23 +16,22 @@ struct SoftmaxParams { float Offset; }; -__device__ SoftmaxParams prepare_softmax_blockwide3(int64_t idx, const floatX* inp, int V, int P) { +__device__ SoftmaxParams prepare_softmax_blockwide3(int64_t idx, tensorX inp, int V, int P) { // same but not float4 // one row of inp, i.e. inp[idx, :] of shape (V,) - - const floatX* x = inp + idx * P; + int elements = inp.num_per_128(); float thread_maxval = -INFINITY; float thread_sumval = 0.0f; - int i = (V+x128::size-1)/x128::size + threadIdx.x - blockDim.x; + int i = (V+elements-1)/elements + threadIdx.x - blockDim.x; // special-case loop to handle the unaligned elements at the end of the array // this lets us skip the bounds check in the main loop below, which improves performance - while ((i+1)*x128::size > V) { - for(int k = 0; k < x128::size; ++k) { - if (i*x128::size+k >= V) { + while ((i+1)*elements > V) { + for(int k = 0; k < elements; ++k) { + if (i*elements+k >= V) { break; // bounds checking against real V (rather than padded P) } - float v = (float)x[i*x128::size+k]; + float v = inp.get_scalar(idx * P + i * elements + k); float old_maxval = thread_maxval; thread_maxval = fmaxf(thread_maxval, v); thread_sumval *= expf((old_maxval - thread_maxval)); @@ -43,9 +42,9 @@ __device__ SoftmaxParams prepare_softmax_blockwide3(int64_t idx, const floatX* i // main loop for the bulk of the iterations (no bounds checking required!) 
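// [editor's aside - illustrative sketch, not part of the patch] The thread_maxval/thread_sumval
// updates in the remainder loop above and the main loop below are the "online softmax" trick:
// whenever a larger maximum is found, the running sum of exponentials is rescaled by
// expf(old_max - new_max), so one pass over the logits yields both the offset and the scale.
// A standalone single-threaded version of the same recurrence:
#include <cstdio>
#include <cmath>

int main() {
    const float logits[5] = {1.0f, 3.0f, -2.0f, 0.5f, 4.0f};
    float maxval = -INFINITY, sumval = 0.0f;
    for (int i = 0; i < 5; i++) {
        float old_maxval = maxval;
        maxval = fmaxf(maxval, logits[i]);
        sumval *= expf(old_maxval - maxval);   // rescale the running sum if the max changed
        sumval += expf(logits[i] - maxval);
    }
    // same Offset/Scale the kernel derives: softmax(x_i) = expf(x_i - maxval) / sumval
    printf("offset=%f scale=%f softmax(logits[4])=%f\n",
           maxval, 1.0f / sumval, expf(logits[4] - maxval) / sumval);
    return 0;
}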
for (; i >= 0; i -= blockDim.x) { - x128 packed_x = load128(x + i * x128::size); // load and keep in cache until fused_classifier loop - for(int k = 0; k < x128::size; ++k) { - float v = (float)packed_x[k]; + auto inp128 = load_tensor128(inp, idx * P + i * elements); + for(int k = 0; k < elements; ++k) { + float v = inp128.get(k); float old_maxval = thread_maxval; thread_maxval = fmaxf(thread_maxval, v); thread_sumval *= expf((old_maxval - thread_maxval)); @@ -67,13 +66,14 @@ __device__ SoftmaxParams prepare_softmax_blockwide3(int64_t idx, const floatX* i // split both loops in "multiple-of-x128-size" and "bounds-checked remainder" parts template __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS) - fused_classifier_kernel5(floatX* dlogits, const floatX* logits, float* losses, floatX* probs, + fused_classifier_kernel5(tensorX dlogits, tensorX logits, float* losses, floatX* probs, const float dloss, const int* targets, int V, int P, std::bool_constant) { // note: idx is small enough that it easily fits into 32 bit; - // by making it a long here, we ensure that any offsets calculated with it (e.g., idx * P) + // by making it size_t here, we ensure that any offsets calculated with it (e.g., idx * P) // are done is 64 bit - int64_t idx = gridDim.x - (blockIdx.x+1); // reverse order for cache hits on matmul data + int elements = logits.num_per_128(); + size_t idx = gridDim.x - (blockIdx.x+1); // reverse order for cache hits on matmul data int ix = targets[idx]; // softmax (reading B * T * V, same logits read again below, hopefully still in cache) @@ -81,7 +81,7 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS) // calculate the probability needed for the loss and update (single-threaded) if(threadIdx.x == 0) { - float prob = expf((float)logits[idx * P + ix] - sp.Offset) * sp.Scale; + float prob = expf(logits.get_scalar(idx * P + ix) - sp.Offset) * sp.Scale; losses[idx] -= logf(prob); } @@ -93,43 +93,45 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS) // calculate the gradients directly, saves bandwidth from probs during training // but also supports writing probs for inference-only and debugging - const floatX* logits_vec = logits + idx * P; - for (int i = threadIdx.x; i < V/x128::size; i += blockDim.x) { + auto dlogits128 = new_tensor128(dlogits); + for (int i = threadIdx.x; i < V/elements; i += blockDim.x) { // this is the 2nd read of logits after the one in prepare_softmax2 // it will be overwritten by the logits gradients which is when we reduce cache persistence - x128 packed_logits_vec = load128(logits_vec + i * x128::size); // rely on cs of store128cs - x128 packed_probs; - for(int k = 0; k < x128::size; ++k) { - int element = i*x128::size + k; - float prob = expf((float)packed_logits_vec[k] - sp.Offset) * sp.Scale; + auto logits128 = load_tensor128(logits, idx * P + i * elements); + x128 packed_probs; // todo - unused but might be read on CPU in the future so not scaling (???) + for(int k = 0; k < elements; ++k) { + int element = i*elements + k; + float prob = expf(logits128.get(k) - sp.Offset) * sp.Scale; packed_probs[k] = (floatX)prob; float indicator = (element == ix) ? 
1.0f : 0.0f; - packed_logits_vec[k] = (floatX)((prob - indicator) * dloss); + dlogits128.set(k, (prob - indicator) * dloss); } if (WriteDLogits){ // reduce cache persistence for the overwritten logits // to maximise probability that logits remain in cache between prepare_softmax and here - store128cs(dlogits + idx * P + i * x128::size, packed_logits_vec); + dlogits128.store(idx * P + i * elements, true); } if (WriteProbs) { - store128(probs + idx * P + i * x128::size, packed_probs); + store128(probs + idx * P + i * elements, packed_probs); } } - // handle remaining elements after the last multiple of x128::size + // handle remaining elements after the last multiple of the number of elements // e.g. if V = 8003, and x128::size = 8, we need to handle the last 3 elements - int unaligned_start = V & ~(x128::size - 1); // round down to multiple of x128::size + int unaligned_start = V & ~(elements - 1); // round down to multiple of x128::size for (int i = threadIdx.x + unaligned_start; i < V; i++) { - float prob = expf((float)logits_vec[i] - sp.Offset) * sp.Scale; + float prob = expf(logits.get_scalar(idx * P + i) - sp.Offset) * sp.Scale; float indicator = (i == ix) ? 1.0f : 0.0f; float dlogit = (prob - indicator) * dloss; if (WriteDLogits){ - __stcs(dlogits + idx * P + i, (floatX)dlogit); + floatX dlogitX = dlogits.set_scalar(idx * P + i, dlogit); // write to memory + dlogits128.add_value_stats(dlogit, dlogitX); // add to absmax stats etc. } if (WriteProbs) { probs[idx * P + i] = (floatX)prob; } } + dlogits128.update_absmax(threadIdx.x, blockDim.x, true); } // ---------------------------------------------------------------------------- @@ -137,7 +139,7 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS) // replaces logits with logit gradients template -void fused_classifier(floatX* dlogits, const floatX* logits, float* losses, +void fused_classifier(tensorX dlogits, tensorX logits, tensorFP32 losses, const float dloss, const int* targets, int BT, int V, int P, std::bool_constant write_dlogits, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); diff --git a/llmc/gelu.cuh b/llmc/gelu.cuh index a940a3ea7..34cd9749f 100644 --- a/llmc/gelu.cuh +++ b/llmc/gelu.cuh @@ -31,8 +31,6 @@ __global__ void gelu_forward_kernel2(tensorFP8e4 out, tensorFP8e4 inp) { out128.set(k, half_xi * tanh_in_out + half_xi); } out128.store_same_length(idx, false); - - // Update absmax out128.update_absmax(threadIdx.x, blockDim.x, true); } @@ -41,11 +39,11 @@ template __global__ void gelu_backward_kernel(tensorFP8e5 dinp, tensorFP8e5 dout, TensorGPU inp) { int idx = (blockIdx.x * blockDim.x + threadIdx.x) * dout.num_per_128(); - auto packed_dinp = new_tensor128(dinp); - auto packed_inp = load_tensor128(inp, idx, true); - auto packed_dout = load_tensor128(dout, idx); + auto dinp128 = new_tensor128(dinp); + auto inp128 = load_tensor128(inp, idx, true); + auto dout128 = load_tensor128(dout, idx); for (int k = 0; k < dout.num_per_128(); ++k) { - float x = packed_inp.get(k); + float x = inp128.get(k); float cube = 0.044715f * x * x * x; float tanh_in_out = GELU_SCALING_FACTOR * (x + cube); @@ -57,13 +55,11 @@ __global__ void gelu_backward_kernel(tensorFP8e5 dinp, tensorFP8e5 dout, TensorG float sech_out = 1.0f - (tanh_in_out * tanh_in_out); float local_grad = 0.5f * ((1.0f + tanh_in_out) + x * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x)); - float result = local_grad * (float)packed_dout.get(k); - packed_dinp.set(k, result); + float result = local_grad * (float)dout128.get(k); + 
dinp128.set(k, result); } - packed_dinp.store_same_length(idx, false); - - // Update absmax - packed_dinp.update_absmax(threadIdx.x, blockDim.x, true); + dinp128.store_same_length(idx, false); + dinp128.update_absmax(threadIdx.x, blockDim.x, true); } // ---------------------------------------------------------------------------- @@ -72,7 +68,6 @@ __global__ void gelu_backward_kernel(tensorFP8e5 dinp, tensorFP8e5 dout, TensorG void gelu_forward(tensorX out, tensorX inp, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); const int block_size = 256; - assert(out.num_per_128() == inp.num_per_128()); assert(inp.num_elements % (block_size * inp.num_per_128()) == 0); const int grid_size = CEIL_DIV(inp.num_elements, block_size * inp.num_per_128()); @@ -82,11 +77,7 @@ void gelu_forward(tensorX out, tensorX inp, cudaStream_t stream=main_stream) { void gelu_backward(tensorX dinp, tensorX dout, tensorX inp, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); - const int block_size = 512; - assert(dout.num_per_128() == inp.num_per_128()); - assert(inp.num_elements % (block_size * inp.num_per_128()) == 0); - assert(dout.num_elements == inp.num_elements && dout.num_elements == dinp.num_elements); - + const int block_size = 256; const int grid_size = CEIL_DIV(inp.num_elements, block_size * inp.num_per_128()); gelu_backward_kernel<<>>(dinp, dout, inp); cudaCheck(cudaGetLastError()); diff --git a/llmc/layernorm.cuh b/llmc/layernorm.cuh index 771d79452..6f602754f 100644 --- a/llmc/layernorm.cuh +++ b/llmc/layernorm.cuh @@ -342,52 +342,43 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with // ---------------------------------------------------------------------------- // kernel launchers -// similar to `fused_residual_forward5` -void layernorm_forward(tensorX out, tensorFP32 mean, tensorFP32 rstd, - tensorX inp, const tensorX weight, const tensorX bias, - int N, int C, cudaStream_t stream=main_stream) { - NVTX_RANGE_FN(); - int block_size = 256; // hardcoded in kernel as well +// Helper function to set the block size based on available shared memory and launch the kernel +template +void launch_layernorm_kernel(KernelFunc kernel, int N, int C, cudaStream_t stream, Args... args) { + int block_size = 256; int block_y = block_size / WARP_SIZE; size_t smem = block_y * C * sizeof(floatX); - auto status = cudaFuncSetAttribute(layernorm_forward_kernel6, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); - // todo - comment + retry to unify into one function? (failed when I tried due to kernel argument not sure why) + auto status = cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); + + // if we don't have enough shared memory, try smaller block sizes down to 32 threads + // should fit on practically every modern GPU even for very large numbers of channels + // todo - do we want to manually set the shared memory vs L1 carveout as well? 
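// [editor's aside - worked example, not part of the patch] Assuming BF16 activations and
// GPT-2's C=768: block_size=256 gives block_y = 256/32 = 8 warps, so the layernorm forward
// kernel requests 8 * 768 * sizeof(__nv_bfloat16) = 12288 bytes (12 KiB) of dynamic shared
// memory, and the fused residual kernel's (2 + block_y) layout requests 10 * 768 * 2 = 15360
// bytes (15 KiB). Both are well under the 48 KiB every block can use without opting in, so the
// cudaFuncSetAttribute call and the halving loop below only matter for much larger channel counts.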
while (status != cudaSuccess) { if (block_y == 1) { - printf("ERROR: not enough shared memory for layernorm_forward\n"); + printf("ERROR: not enough shared memory for kernel\n"); exit(EXIT_FAILURE); } block_y /= 2, block_size /= 2; smem = (2 + block_y) * C * sizeof(floatX); - status = cudaFuncSetAttribute(layernorm_forward_kernel6, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); + status = cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); } int grid_size = CEIL_DIV(N, block_y); - layernorm_forward_kernel6<<>>(out, mean, rstd, inp, weight, bias, N, C); + kernel<<>>(args..., N, C); cudaCheck(cudaGetLastError()); } +void layernorm_forward(tensorX out, tensorFP32 mean, tensorFP32 rstd, + tensorX inp, const tensorX weight, const tensorX bias, + int N, int C, cudaStream_t stream=main_stream) { + NVTX_RANGE_FN(); + launch_layernorm_kernel(layernorm_forward_kernel6, N, C, stream, out, mean, rstd, inp, weight, bias); +} + void fused_residual_forward5(tensorX residual, tensorX normed, tensorFP32 mean, tensorFP32 rstd, tensorX inp1, tensorX inp2, tensorX weight, tensorX bias, int N, int C, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); - int block_size = 256; - int block_y = block_size / WARP_SIZE; - size_t smem = (2 + block_y) * C * sizeof(floatX); - auto status = cudaFuncSetAttribute(fused_residual_forward_kernel5, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); - while (status != cudaSuccess) { - if (block_y == 1) { - printf("ERROR: not enough shared memory for fused_residual_forward\n"); - exit(EXIT_FAILURE); - } - block_y /= 2, block_size /= 2; - smem = (2 + block_y) * C * sizeof(floatX); - status = cudaFuncSetAttribute(fused_residual_forward5, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); - } - int grid_size = CEIL_DIV(N, block_y); - fused_residual_forward_kernel5<<>>(residual, normed, - mean, rstd, inp1, inp2, - weight, bias, N, C); - cudaCheck(cudaGetLastError()); + launch_layernorm_kernel(fused_residual_forward_kernel5, N, C, stream, residual, normed, mean, rstd, inp1, inp2, weight, bias); } void layernorm_backward(tensorX dinp_new, tensorX dinp_old, tensorX dweight, tensorX dbias, tensorFP32 scratch, diff --git a/train_gpt2.cu b/train_gpt2.cu index 9e17cc645..077cde101 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -105,7 +105,7 @@ typedef struct { int btc; // (B, T, C) int local_scratch; // (B, T, C) int output_scratch; // huge - int output_scratch_fp32; // typically same buffer as above + int output_scratch_fp32; // same memory as FP32 } MultiuseTensors; typedef struct { @@ -126,7 +126,6 @@ typedef struct { size_t num_parameters; size_t num_parameters_bytes; - char* multiuse_memory = NULL; char* params_memory[NUM_TYPES_PARAM] = {0}; @@ -156,6 +155,7 @@ void* gpu_tensor_scale_memory = NULL; void* gpu_tensor_absmax_memory = NULL; TensorSpec tensor_specs[MAX_TENSORS] = {0}; +TensorSpec* tensor_specs_gpu = NULL; TT current_tensor_type = TT::PARAMETER; size_t tensors_start[TT::COUNT] = {0}; size_t tensors_bytes[TT::COUNT] = {0}; @@ -198,6 +198,9 @@ int add_tensor_spec(const char* name, size_t total_elements, size_t num_shards, spec->name, (int)base_spec.tensor_type, (int)spec->tensor_type); assert(false); } + if (flags & REUSED_MEMORY) { + base_spec.flags |= REUSED_MEMORY; + } assert(base_spec.tensor_type == spec->tensor_type); assert(new_tensor_bytes <= original_tensor_bytes); } else { @@ -395,7 +398,7 @@ void gpt2_allocate(GPT2 *model) { spec->qkvr = add_layer_specs(L, "qkvr", 3 * BTC, 1, dtype, model->multiuse.bt4c, GRADIENT); } - // 
allocate a single huge GPU buffer for all the tensors + // allocate a single huge GPU buffer for all the tensors of a given type cudaCheck(cudaMalloc(&model->multiuse_memory, tensors_bytes[ACTIVATIONS_MULTIUSE])); cudaCheck(cudaMemset(model->multiuse_memory, 0, tensors_bytes[ACTIVATIONS_MULTIUSE])); @@ -420,6 +423,10 @@ void gpt2_allocate(GPT2 *model) { } } + // we are finished creating the tensors specs and can copy them to the GPU (effectively read-only) + cudaMalloc(&tensor_specs_gpu, sizeof(TensorSpec) * num_tensor_specs); + cudaMemcpy(tensor_specs_gpu, tensor_specs, sizeof(TensorSpec) * num_tensor_specs, cudaMemcpyHostToDevice); + //initialise helper variables model->num_parameters = tensors_elements[TT::PARAMETER]; model->num_parameters_bytes = tensors_bytes[TT::PARAMETER]; @@ -776,11 +783,13 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { NvtxRange layer_range("Layer", l); tensorX residual = (l == 0) ? ACT(encoded) : ACT_L(residual3, l-1); - matmul_forward_cublaslt(CUDNN_ENABLED ? ACT(qkvr) : MULTI(output_scratch), ACT(ln1), PARAM(qkvw), PARAM(qkvb), B*T, C, 3*C); + tensorX qkvr = MULTI(output_scratch); // non-cudnn reuses tensor with different memory pre/post-permute + qkvr.data_ptr = CUDNN_ENABLED ? ACT(qkvr) : MULTI(output_scratch); + matmul_forward_cublaslt(qkvr, ACT(ln1), PARAM(qkvw), PARAM(qkvb), B*T, C, 3*C); #ifdef ENABLE_CUDNN attention_forward_cudnn(ACT(atty), ACT(att), ACT(qkvr), B, T, NH, C); #else - attention_forward(ACT(atty), ACT(qkvr), ACT(att), MULTI(output_scratch), B, T, C, NH); + attention_forward(ACT(atty), ACT(qkvr), ACT(att), qkvr, B, T, C, NH); #endif matmul_forward_cublaslt(ACT(attproj), ACT(atty), PARAM(attprojw), PARAM(attprojb), B*T, C, C); @@ -941,11 +950,13 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int tensorX residual = (l == 0) ? ACT(encoded) : ACT_L(residual3, l-1); layernorm_forward(ACT(ln1), ACT(ln1_mean), ACT(ln1_rstd), residual, PARAM(ln1w), PARAM(ln1b), B*T, C); - matmul_forward_cublaslt(CUDNN_ENABLED ? ACT(qkvr) : MULTI(output_scratch), ACT(ln1), PARAM(qkvw), PARAM(qkvb), B*T, C, 3*C); + tensorX qkvr = ACT(qkvr); + qkvr.data_ptr = CUDNN_ENABLED ? 
ACT(qkvr) : MULTI(output_scratch); + matmul_forward_cublaslt(qkvr, ACT(ln1), PARAM(qkvw), PARAM(qkvb), B*T, C, 3*C); #ifdef ENABLE_CUDNN attention_forward_cudnn(ACT(atty), ACT(att), ACT(qkvr), B, T, NH, C); #else - attention_forward(ACT(atty), ACT(qkvr), ACT(att), MULTI(output_scratch), B, T, C, NH); + attention_forward(ACT(atty), ACT(qkvr), ACT(att), qkvr, B, T, C, NH); #endif matmul_forward_cublaslt(ACT(attproj), ACT(atty), PARAM(attprojw), PARAM(attprojb), B*T, C, C); @@ -1064,7 +1075,6 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo TensorSpec opt_v_spec = tensor_specs[i + tensors_start[PARAMETER_OPT_V]]; // todo - adjust offset into params/grads when optimiser state is sharded - floatX* param_ptr = (floatX*)param_spec.ptr; float* master_ptr = NULL; if (model->params_memory[PARAMETER_MASTER] != NULL) { master_ptr = (float*)master_spec.ptr; @@ -1076,7 +1086,7 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo if(init_state && model->use_master_weights) { size_t grid_size = CEIL_DIV(shard_elements, 512); - copy_and_cast_kernel<<>>(master_ptr, param_ptr, shard_elements, shard_elements, tensor_elements); + copy_and_cast_kernel<<>>(master_ptr, ((tensorX)param_spec).data_ptr, shard_elements, shard_elements, tensor_elements); cudaCheck(cudaGetLastError()); } @@ -1089,7 +1099,7 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo assert(false); } else { // ok finally call the kernel to update the weights with AdamW - adamw_update(param_ptr, master_ptr, (floatX*)grad_spec.ptr, + adamw_update((tensorX)param_spec, master_ptr, (tensorX)grad_spec, (float*)opt_m_spec.ptr, (float*)opt_v_spec.ptr, shard_elements, tensor_elements, tensor_elements, shard_elements, num_layers, learning_rate, From c02382eb4d2bde667aff2296c55c42aa187c655b Mon Sep 17 00:00:00 2001 From: ademeure Date: Sat, 7 Sep 2024 19:10:22 +0000 Subject: [PATCH 11/27] WIP, new unified adam is working, scaling is crashing, should be easy to fix (famous last words) --- Makefile | 2 +- llmc/adamw.cuh | 184 ++++++++++++++++++++++++++------------------ llmc/cuda_utils.cuh | 64 +++++++++------ train_gpt2.cu | 110 +++++++++++++------------- 4 files changed, 202 insertions(+), 158 deletions(-) diff --git a/Makefile b/Makefile index 6fa511db4..b9c174dd0 100644 --- a/Makefile +++ b/Makefile @@ -269,7 +269,7 @@ $(NVCC_CUDNN): llmc/cudnn_att.cpp $(NVCC) -c $(NVCC_FLAGS) $(PFLAGS) $^ $(NVCC_INCLUDES) -o $@ train_gpt2cu: train_gpt2.cu $(NVCC_CUDNN) - $(NVCC) $(NVCC_FLAGS) $(PFLAGS) $^ $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE) + $(NVCC) $(NVCC_FLAGS) $(PFLAGS) -lineinfo $^ $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE) train_gpt2fp32cu: train_gpt2_fp32.cu $(NVCC) $(NVCC_FLAGS) $^ $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE) diff --git a/llmc/adamw.cuh b/llmc/adamw.cuh index bea72b0c4..ab437964b 100644 --- a/llmc/adamw.cuh +++ b/llmc/adamw.cuh @@ -15,84 +15,116 @@ __device__ float lerp(float start, float end, float weight) { return fma(weight, end, fma(-weight, start, start)); } -template -__device__ void adamw_update(Tp* params_memory, float* master_params_memory, Tg* grads_memory, float* m_memory, float* v_memory, size_t num_parameters, - float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay, - float grad_scale, unsigned int seed) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= num_parameters) { 
return; } // guard - - // get the gradient, m, and v for this parameter - float grad = grad_scale * (float)grads_memory[idx]; - float m = m_memory[idx]; - float v = v_memory[idx]; - // update the first moment (momentum) - m = lerp(grad, m, beta1); - m_memory[idx] = m; - // update the second moment (RMSprop) - v = lerp(grad * grad, v, beta2); - v_memory[idx] = v; - m /= beta1_correction; // m_hat - v /= beta2_correction; // v_hat - // fetch the old value of this parameter as a float, from either source - float old_param = (master_params_memory != NULL) ? master_params_memory[idx] : (float)params_memory[idx]; - // update this parameter - float param = old_param - (learning_rate * (m / (sqrtf(v) + eps) + weight_decay * old_param)); - // update our low precision version of the parameters using stochastic rounding - // this will be used in the next forward pass - stochastic_rounding(param, ¶ms_memory[idx], seed); - // write the full, float version of the param into our master copy, if we maintain one - // this will be used in the next update - if (master_params_memory != NULL) { master_params_memory[idx] = param; } -} +template +__global__ void adamw_full_update(TensorSpec* specs, unsigned int seed, + int num_params_tensors, size_t num_parameters, size_t num_opt_parameters, + float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, + float eps, float weight_decay, float grad_scale, int t, bool init_from_master_only=false) { + // ... + constexpr size_t block_size = 64; // 64 ==> 4KiB chunks with iteration_size=16 for FP32 opt/master + size_t iteration_size = 16; + assert(iteration_size <= 16); + size_t idx_blk = blockIdx.x * block_size * iteration_size; + size_t idx = idx_blk + (threadIdx.x * iteration_size); + size_t stride = gridDim.x * blockDim.x * iteration_size; -template -__global__ void adamw_kernel3(TensorGPU params_memory, float* master_params_memory, TensorGPU grads_memory, float* m_memory, float* v_memory, size_t num_parameters, - ptrdiff_t w_stride, ptrdiff_t g_stride, ptrdiff_t s_stride, - float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay, - float grad_scale, unsigned int seed) { - adamw_update(params_memory + blockIdx.y * w_stride, - master_params_memory ? master_params_memory + blockIdx.y * s_stride : NULL, - grads_memory + blockIdx.y * g_stride, - m_memory + blockIdx.y * s_stride, - v_memory + blockIdx.y * s_stride, - num_parameters, learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay, grad_scale, - seed - ); -} + int spec_id = 0; -template -__global__ void init_from_master_kernel(Tp* params_memory, float* master_params_memory, size_t num_parameters, - ptrdiff_t w_stride, ptrdiff_t s_stride, unsigned int seed) { - size_t idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= num_parameters) { return; } - params_memory += blockIdx.y * w_stride; // adjust for layer offset - master_params_memory += blockIdx.y * s_stride; - stochastic_rounding(master_params_memory[idx], ¶ms_memory[idx], seed); -} + TensorSpec* grad_specs = specs + num_params_tensors; + TensorSpec* opt_m_specs = specs + 2 * num_params_tensors; + TensorSpec* opt_v_specs = specs + 3 * num_params_tensors; + TensorSpec* master_specs = use_master_weights ? 
specs + 4 * num_params_tensors : opt_m_specs; -template -void adamw_update(TensorGPU params_memory, float* master_params_memory, TensorGPU grads_memory, float* m_memory, float* v_memory, size_t num_parameters, - ptrdiff_t w_stride, ptrdiff_t g_stride, ptrdiff_t s_stride, int num_slices, float learning_rate, float beta1, float beta2, int t, float eps, float weight_decay, - float grad_scale, unsigned int seed, cudaStream_t stream=main_stream) { - // AdamW update - int block_size = 512; - int num_blocks = CEIL_DIV(num_parameters, block_size); - float beta1_correction = 1.0f - powf(beta1, t); - float beta2_correction = 1.0f - powf(beta2, t); - adamw_kernel3<<>>(params_memory, master_params_memory, grads_memory, - m_memory, v_memory, num_parameters, w_stride, g_stride, s_stride, - learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay, - grad_scale, seed); - cudaCheck(cudaGetLastError()); -} + TensorSpec opt_spec = opt_v_specs[spec_id]; + size_t current_start = opt_spec.offset / sizeof(float); + size_t current_end = current_start + opt_spec.num_elements; + + while (idx < num_opt_parameters) { + // todo - do this part on thread 0 only? + while (idx >= current_end) { + spec_id++; + if (spec_id >= num_params_tensors) { + return; + } + opt_spec = opt_v_specs[spec_id]; + current_start = opt_spec.offset / sizeof(float); + current_end = current_start + opt_spec.num_elements; + } + + TensorGPU grad_tensor = grad_specs[spec_id]; + TensorGPU master_tensor = master_specs[spec_id]; + TensorGPU opt_m_tensor = opt_m_specs[spec_id]; + TensorGPU opt_v_tensor = opt_spec; + + auto out_master128 = new_tensor128(master_tensor, true); + auto out_opt_m128 = new_tensor128(opt_m_tensor, true); + auto out_opt_v128 = new_tensor128(opt_v_tensor, true); + + // todo - make it configurable whether weight decay applies to e.g. bias or not + float wd = (opt_spec.flags & TENSOR_2D) ? weight_decay : 0.0f; + + if (specs[spec_id].data_type == DType::BF16) { + TensorGPU<__nv_bfloat16> param_tensor = specs[spec_id]; + auto out_param128 = new_tensor128(param_tensor); + + __syncthreads(); // todo - hopefully results in better memory access patterns => TBC + while (idx < current_end) { + // always sizeof(param) <= sizeof(grad) <= sizeof(opt/master) + // todo - maybe not true, could have FP32 param and BF16 grad? + // todo - hack - currently assuming grad is always bfloat16 + unsigned int random = get_random_noise(seed, idx); + for (int i = 0; i < iteration_size; i += 16 / sizeof(__nv_bfloat16)) { + size_t offset = (idx - current_start) + i; + auto param128 = load_tensor128(param_tensor, offset); + auto grad128 = load_tensor128(grad_tensor, offset); + for (int j = 0; j < sizeof(float) / sizeof(__nv_bfloat16); j++) { + // todo - sparse(-ish) accesses, I don't like it. 
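// [editor's aside - illustrative sketch, not part of the patch] The per-element update further
// below is a standard AdamW step written with lerp(start, end, weight) = weight*end +
// (1-weight)*start. A scalar host-side reference for a single parameter (no stochastic rounding,
// no master weights, hypothetical names) for comparison:
#include <cstdio>
#include <cmath>

float adamw_step_reference(float param, float grad, float& m, float& v, int t, float lr,
                           float beta1, float beta2, float eps, float weight_decay) {
    m = beta1 * m + (1.0f - beta1) * grad;                 // first moment (momentum EMA)
    v = beta2 * v + (1.0f - beta2) * grad * grad;          // second moment (RMSprop-style EMA)
    float m_hat = m / (1.0f - powf(beta1, (float)t));      // bias corrections for step t
    float v_hat = v / (1.0f - powf(beta2, (float)t));
    return param - lr * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * param);
}

int main() {
    float m = 0.0f, v = 0.0f, p = 1.0f;
    for (int t = 1; t <= 3; t++) {
        p = adamw_step_reference(p, 0.1f, m, v, t, 1e-3f, 0.9f, 0.95f, 1e-8f, 0.1f);
        printf("step %d: param = %f\n", t, p);
    }
    return 0;
}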
+ auto opt_m128 = load_tensor128(opt_m_tensor, offset + j * f128::size, true); + auto opt_v128 = load_tensor128(opt_v_tensor, offset + j * f128::size, true); + // optimised away if we don't use it (and pointer will be equal to opt_m128) + auto master128 = load_tensor128(master_tensor, offset + j * f128::size, true); + + if (master_init_modes && init_from_master_only) { + for (int k = 0; k < f128::size; k++) { + float old_param = master128.get(k); + out_param128.set_stochastic(k + j*f128::size, old_param, random); + } + continue; + } + + for (int k = 0; k < f128::size; k++) { + float grad = grad128.get(k + j*f128::size); + float m = opt_m128.get(k); + float v = opt_v128.get(k); + m = lerp(grad, m, beta1); + v = lerp(grad * grad, v, beta2); + out_opt_m128.set(k, m); + out_opt_v128.set(k, v); + m /= beta1_correction; + v /= beta2_correction; -template -void init_from_master(Tp* params_memory, float* master_params_memory, size_t num_parameters, - ptrdiff_t w_stride, ptrdiff_t s_stride, int num_slices, unsigned int seed, cudaStream_t stream=main_stream) { - int block_size = 512; // must match block size of adamw_update so that RNG also matches - int num_blocks = CEIL_DIV(num_parameters, block_size); - init_from_master_kernel<<>> - (params_memory, master_params_memory, num_parameters, w_stride, s_stride, seed); - cudaCheck(cudaGetLastError()); + float old_param; + if (use_master_weights && !master_init_modes) { + old_param = master128.get(k); + } else { + old_param = param128.get(k + j*f128::size); + } + float param = old_param - (learning_rate * (m / (sqrtf(v) + eps) + wd * old_param)); + out_param128.set_stochastic(k + j*f128::size, param, random); + out_master128.set(k, param); + } + out_opt_m128.store(offset + j * f128::size); + out_opt_v128.store(offset + j * f128::size); + if constexpr (use_master_weights) { + out_master128.store(offset + j * f128::size); + } + } + out_param128.store(offset); + } + out_param128.update_absmax(threadIdx.x, block_size, false); + idx_blk += stride; + idx += stride; + } + } + } } diff --git a/llmc/cuda_utils.cuh b/llmc/cuda_utils.cuh index 85a78ebf6..4f266caf8 100644 --- a/llmc/cuda_utils.cuh +++ b/llmc/cuda_utils.cuh @@ -276,7 +276,7 @@ struct TensorGPU { return sizeof(int4) / sizeof(ElementType); } - __device__ __host__ float get_scalar(size_t index, bool disable_scaling=true) const { + __device__ __host__ float get_scalar(size_t index, bool disable_scaling=false) const { ElementType* __restrict__ data_ptr_restricted = data_ptr; float* __restrict__ scale_ptr_restricted = scale_descale_ptr; @@ -285,7 +285,7 @@ struct TensorGPU { return value * descale; // [1] = descale } - __device__ __host__ ElementType set_scalar(size_t index, float value, bool disable_scaling=true) { + __device__ __host__ ElementType set_scalar(size_t index, float value, bool disable_scaling=false) { ElementType* __restrict__ data_ptr_restricted = data_ptr; float* __restrict__ scale_ptr_restricted = scale_descale_ptr; @@ -306,35 +306,39 @@ typedef TensorGPU tensorFP8e4; typedef TensorGPU tensorFP8e5; extern TensorGPU null_tensorX; -extern TensorGPU null_tensorFP32; +extern TensorGPU null_tensorFP32; template struct tensor128 { private: Packed128 data128; ElementType* data_ptr; - unsigned int *absmax_ptr; - float scale; - float descale; + unsigned int *absmax_ptr = nullptr; + float scale = 1.0f; + float descale = 1.0f; float new_absmax = 0.0f; bool wrote_data = false; bool wrote_absmax = false; public: - bool scaling = (sizeof(ElementType) == 33); // todo - fp8 only + bool scaling = true; // 
todo - fp8 only static constexpr const size_t elements = sizeof(int4) / sizeof(ElementType); __device__ tensor128(TensorGPU tensor, bool disable_scaling=false) { - float2* __restrict__ ptr_restricted = (float2*)tensor.scale_descale_ptr; - float2 scale_descale = *ptr_restricted; - scale = scale_descale.x; - descale = scale_descale.y; data_ptr = tensor.data_ptr; - absmax_ptr = tensor.absmax_ptr; - if (disable_scaling) { + + if (!disable_scaling) { + float2* __restrict__ ptr_restricted = (float2*)tensor.scale_descale_ptr; + if (tensor.scale_descale_ptr == nullptr) { + printf("tensor.scale_descale_ptr: %p\n", tensor.scale_descale_ptr); + } + float2 scale_descale = *ptr_restricted; + scale = scale_descale.x; + descale = scale_descale.y; + absmax_ptr = tensor.absmax_ptr; + } else { scaling = false; } - scaling = false; } __device__ void load(size_t offset, bool cache_streaming=false) { @@ -416,6 +420,7 @@ public: return false; // if we return true, we can skip __syncthreads() in some kernels } wrote_absmax = true; + return false; // use native integer reductions as much as possible (supported on all GPUs with FP8) // this might treat NaN/INF slightly differently but that is the least of our problems @@ -502,13 +507,16 @@ constexpr size_t MAX_TENSORS = 16*1024; constexpr size_t MAX_ABSMAX_HISTORY = 32; // todo - should make this a command line option extern int num_tensor_specs; extern int current_absmax_index; -extern void* gpu_tensor_scale_memory; -extern void* gpu_tensor_absmax_memory; +extern float* gpu_scale_memory; +extern unsigned int* gpu_absmax_memory; + +__constant__ float* gpu_scale_memory_ptr; +__constant__ unsigned int* gpu_absmax_memory_ptr; enum TT : uint8_t { - PARAMETER=0, PARAMETER_GRAD, PARAMETER_MASTER, PARAMETER_OPT_M, PARAMETER_OPT_V, // 1 allocation each + PARAMETER=0, PARAMETER_GRAD, PARAMETER_OPT_M, PARAMETER_OPT_V, PARAMETER_MASTER, // 1 allocation each ACTIVATIONS_MULTIUSE, // single buffer shared for activations, activation gradients, and scratch - DEFAULT, COUNT=DEFAULT, NUM_TYPES_PARAM=PARAMETER_OPT_V+1 + DEFAULT, COUNT=DEFAULT, NUM_TYPES_PARAM=PARAMETER_MASTER+1 }; enum TFlags : uint8_t { @@ -536,25 +544,33 @@ typedef struct { char name[16]; template - operator T*() const { + __host__ __device__ operator T*() const { + // TODO !!! make it work device side! 
+ /* if (std::is_same::value && data_type != DType::FP32 || std::is_same::value && data_type != DType::FP16 || std::is_same::value && data_type != DType::BF16) { printf("ERROR: Unexpected data type (%d) for tensor %s\n", (int)data_type, name); exit(EXIT_FAILURE); } + */ return reinterpret_cast(ptr); } template - operator TensorGPU() const { + __device__ __host__ operator TensorGPU() const { TensorGPU tensor; - int absmax_idx = id + (current_absmax_index * num_tensor_specs); - tensor.num_elements = num_elements; tensor.data_ptr = this->operator T*(); - tensor.scale_descale_ptr = reinterpret_cast(gpu_tensor_scale_memory) + id; - tensor.absmax_ptr = reinterpret_cast(gpu_tensor_absmax_memory) + absmax_idx; + + #ifdef __CUDA_ARCH__ + printf("gpu_scale_memory_ptr: %p\n", gpu_scale_memory_ptr); + tensor.scale_descale_ptr = gpu_scale_memory_ptr + 2*id; + tensor.absmax_ptr = gpu_absmax_memory_ptr + id; + #else + tensor.scale_descale_ptr = gpu_scale_memory + 2*id; + tensor.absmax_ptr = gpu_absmax_memory + id; + #endif return tensor; } diff --git a/train_gpt2.cu b/train_gpt2.cu index 077cde101..88ee65cf0 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -83,7 +83,8 @@ char filename_buffer[512]; // global vars containing information about the GPU this process is running on cudaDeviceProp deviceProp; // fills in common_start() cudaStream_t main_stream; -TensorGPU null_tensorX; +TensorGPU null_tensorX = {0}; +TensorGPU null_tensorFP32 = {0}; // buffer size to use for device <-> disk io constexpr const size_t IO_BUF_SIZE = 32 * 1024 * 1024; @@ -149,18 +150,19 @@ typedef struct { unsigned long long rng_state_last_update; // RNG before last gpt2_update() to re-round identically from master weights } GPT2; -int num_tensor_specs = 0; -int current_absmax_index = 0; -void* gpu_tensor_scale_memory = NULL; -void* gpu_tensor_absmax_memory = NULL; - TensorSpec tensor_specs[MAX_TENSORS] = {0}; TensorSpec* tensor_specs_gpu = NULL; + TT current_tensor_type = TT::PARAMETER; size_t tensors_start[TT::COUNT] = {0}; size_t tensors_bytes[TT::COUNT] = {0}; size_t tensors_elements[TT::COUNT] = {0}; +int num_tensor_specs = 0; +int current_absmax_index = 0; +float* gpu_scale_memory = NULL; +unsigned int* gpu_absmax_memory = NULL; + TensorSpec get_tensor(int spec_index, TT tensor_type, int layer) { TensorSpec spec = tensor_specs[spec_index]; if (layer > 0 && spec.remaining_layers >= layer) { @@ -258,7 +260,7 @@ void gpt2_allocate(GPT2 *model) { int shards_grad = (multi_gpu_config.zero_stage >= 2) ? num_gpu : 1; // 1) parameters & optimizer state - for (int t = PARAMETER; t <= PARAMETER_OPT_V; t++) { + for (int t = PARAMETER; t <= PARAMETER_MASTER; t++) { DType dtype = (t <= PARAMETER_GRAD) ? DTYPE_FLOATX : DType::FP32; DType dtype_lowp = (t <= PARAMETER_GRAD) ? 
DTYPE_FLOATX : DType::FP32; // FP8 in the future @@ -418,7 +420,7 @@ void gpt2_allocate(GPT2 *model) { spec->ptr = model->multiuse_memory + spec->offset; break; default: - assert(spec->tensor_type <= PARAMETER_OPT_V); + assert(spec->tensor_type <= PARAMETER_MASTER); spec->ptr = model->params_memory[spec->tensor_type] + spec->offset; } } @@ -442,10 +444,22 @@ void gpt2_allocate(GPT2 *model) { // allocate_state stuff // ======================= // absmax/scaling/descaling buffers for FP8 & Friends - cudaMalloc(&gpu_tensor_scale_memory, sizeof(float) * num_tensor_specs); - cudaMemset(gpu_tensor_scale_memory, 0, sizeof(float) * num_tensor_specs); - cudaMalloc(&gpu_tensor_absmax_memory, sizeof(float) * num_tensor_specs * MAX_ABSMAX_HISTORY); - cudaMemset(gpu_tensor_absmax_memory, 0, sizeof(float) * num_tensor_specs * MAX_ABSMAX_HISTORY); + cudaMalloc(&gpu_absmax_memory, sizeof(unsigned int) * num_tensor_specs * MAX_ABSMAX_HISTORY); + cudaMemset(gpu_absmax_memory, 0, sizeof(unsigned int) * num_tensor_specs * MAX_ABSMAX_HISTORY); + + // Initialize gpu_scale_memory with 1.0f for all elements + size_t scale_memory_elements = 2 * num_tensor_specs; + cudaMalloc(&gpu_scale_memory, scale_memory_elements * sizeof(float)); + float* h_scale_memory = (float*)malloc(scale_memory_elements * sizeof(float)); + for (size_t i = 0; i < scale_memory_elements; ++i) { + h_scale_memory[i] = 1.0f; + } + cudaMemcpy(gpu_scale_memory, h_scale_memory, scale_memory_elements * sizeof(float), cudaMemcpyHostToDevice); + free(h_scale_memory); + + // copy to constant buffers + cudaMemcpyToSymbol(gpu_scale_memory_ptr, gpu_scale_memory, sizeof(float*)); + cudaMemcpyToSymbol(gpu_absmax_memory_ptr, gpu_absmax_memory, sizeof(unsigned int*)); cudaCheck(cudaMalloc((void**)&model->inputs, B * T * sizeof(int))); cudaCheck(cudaMalloc((void**)&model->targets, B * T * sizeof(int))); @@ -466,6 +480,7 @@ void gpt2_allocate(GPT2 *model) { size_t bytes_per_sequence = tensors_bytes[TT::ACTIVATIONS_MULTIUSE] / B; // pessimistic (output buffer etc.) 
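// [editor's aside - illustrative sketch, not part of the patch] Minimal standalone version of
// the __constant__-pointer pattern used a few lines above: the device pointer is published
// through a __constant__ symbol with cudaMemcpyToSymbol so kernels can reach it without taking
// it as an argument, while host code keeps using the regular global. All names are hypothetical;
// note that the source argument must be the address of the host-side pointer variable.
#include <cstdio>
#include <cuda_runtime.h>

__constant__ float* d_scales_ptr;   // device-visible copy of the pointer value
float* g_scales = nullptr;          // host-side global holding the same device pointer

__global__ void read_scale_kernel() {
    printf("scale[0] seen from device: %f\n", d_scales_ptr[0]);
}

int main() {
    cudaMalloc(&g_scales, 4 * sizeof(float));
    float one = 1.0f;
    cudaMemcpy(g_scales, &one, sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpyToSymbol(d_scales_ptr, &g_scales, sizeof(float*));   // copy the pointer, not the data
    read_scale_kernel<<<1, 1>>>();
    cudaDeviceSynchronize();
    cudaFree(g_scales);
    return 0;
}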
printf0("memory per sequence: %zu MiB\n", bytes_per_sequence / 1024 / 1024); printf0(" -> estimated maximum batch size: %zu\n", B + free / bytes_per_sequence); + cudaCheck(cudaGetLastError()); } void gpt2_init_common(GPT2 *model) { @@ -1064,49 +1079,35 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo // save RNG state at this point so we can round from master weights identically when restoring from a checkpoint model->rng_state_last_update = model->rng_state; - // todo: merge everything into 1 kernel call - for (int i = 0; i < tensors_start[PARAMETER_GRAD];) { - unsigned int seed = random_u32(&model->rng_state); - - TensorSpec param_spec = tensor_specs[i]; - TensorSpec grad_spec = tensor_specs[i + tensors_start[PARAMETER_GRAD]]; - TensorSpec master_spec = tensor_specs[i + tensors_start[PARAMETER_MASTER]]; - TensorSpec opt_m_spec = tensor_specs[i + tensors_start[PARAMETER_OPT_M]]; - TensorSpec opt_v_spec = tensor_specs[i + tensors_start[PARAMETER_OPT_V]]; - - // todo - adjust offset into params/grads when optimiser state is sharded - float* master_ptr = NULL; - if (model->params_memory[PARAMETER_MASTER] != NULL) { - master_ptr = (float*)master_spec.ptr; - } - - size_t tensor_elements = param_spec.num_elements; - size_t shard_elements = master_spec.num_elements; - int num_layers = param_spec.remaining_layers + 1; + float beta1_correction = 1.0f - powf(beta1, t); + float beta2_correction = 1.0f - powf(beta2, t); + unsigned int seed = random_u32(&model->rng_state); + int num_shards = tensor_specs[tensors_start[PARAMETER_OPT_M]].num_shards; - if(init_state && model->use_master_weights) { - size_t grid_size = CEIL_DIV(shard_elements, 512); - copy_and_cast_kernel<<>>(master_ptr, ((tensorX)param_spec).data_ptr, shard_elements, shard_elements, tensor_elements); - cudaCheck(cudaGetLastError()); - } - - // todo - make it configurable whether weight decay applies to e.g. bias or not - float wd = (param_spec.flags & TENSOR_2D) ? 
weight_decay : 0.0f; - - if (init_from_master_only) { - // when resuming training from a checkpoint with master weights (allows changing precision) - //init_from_master(param_ptr, master_ptr, shard.size, tensor.size, shard.size, num_layers, seed, main_stream); - assert(false); + const int block_size = 64; + const int grid_size = deviceProp.maxThreadsPerMultiProcessor * deviceProp.multiProcessorCount / block_size; + if (model->use_master_weights) { + if (init_state || init_from_master_only) { + // reads regular weights & writes to master+regular weights + // or init_from_master_only: reads master & write to regular weights as-is + adamw_full_update<<>>( + tensor_specs_gpu, seed, tensors_start[PARAMETER_GRAD], + model->num_parameters, model->num_parameters / num_shards, + learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay, grad_scale, t, + init_from_master_only); } else { - // ok finally call the kernel to update the weights with AdamW - adamw_update((tensorX)param_spec, master_ptr, (tensorX)grad_spec, - (float*)opt_m_spec.ptr, (float*)opt_v_spec.ptr, - shard_elements, tensor_elements, tensor_elements, shard_elements, num_layers, - learning_rate, - beta1, beta2, t, eps, wd, grad_scale, seed, main_stream); + // reads master weights & writes to master+regular weights + adamw_full_update<<>>( + tensor_specs_gpu, seed, tensors_start[PARAMETER_GRAD], + model->num_parameters, model->num_parameters / num_shards, + learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay, grad_scale, t); } - - i += num_layers; + } else { + // reads & writes regular weights only + adamw_full_update<<>>( + tensor_specs_gpu, seed, tensors_start[PARAMETER_GRAD], + model->num_parameters, model->num_parameters / num_shards, + learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay, grad_scale, t); } // AdamW update @@ -1232,11 +1233,6 @@ void common_start(bool override_enable_tf32 = true, bool print_device_info = tru printf("Device %d: %s\n", multi_gpu_config.local_device_idx, deviceProp.name); } - null_tensorX.data_ptr = nullptr; - null_tensorX.absmax_ptr = nullptr; - null_tensorX.scale_descale_ptr = nullptr; - null_tensorX.num_elements = 0; - // set up the cuda streams. atm everything is on the single main stream cudaCheck(cudaStreamCreate(&main_stream)); nvtxNameCudaStreamA(main_stream, "main stream"); From eedb4d04afe58a5f14a6aa5e2ebefd8bb8e9ad5e Mon Sep 17 00:00:00 2001 From: ademeure Date: Tue, 10 Sep 2024 19:31:51 +0000 Subject: [PATCH 12/27] WIP, seems to all kinda work (famous last words) - but cuBLAS doesn't support scaling for BF16... --- llmc/adamw.cuh | 6 +- llmc/attention.cuh | 20 ++++--- llmc/cuda_utils.cuh | 37 +++++++----- llmc/layernorm.cuh | 20 ++++--- llmc/matmul.cuh | 50 ++++++++++++---- train_gpt2.cu | 139 +++++++++++++++++++++++++------------------- 6 files changed, 172 insertions(+), 100 deletions(-) diff --git a/llmc/adamw.cuh b/llmc/adamw.cuh index ab437964b..006f7c72c 100644 --- a/llmc/adamw.cuh +++ b/llmc/adamw.cuh @@ -63,7 +63,9 @@ __global__ void adamw_full_update(TensorSpec* specs, unsigned int seed, // todo - make it configurable whether weight decay applies to e.g. bias or not float wd = (opt_spec.flags & TENSOR_2D) ? weight_decay : 0.0f; - if (specs[spec_id].data_type == DType::BF16) { + if (specs[spec_id].data_type == DType::BF16 || specs[spec_id].data_type == DType::FP16) { + // todo - this is actually "EQUAL FLOATX" right now, doesn't work for mix and match + // !!! 
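// For reference, the per-element math that adamw_full_update applies is the standard decoupled
// AdamW step. The scalar sketch below mirrors the kernel arguments (learning_rate, the beta
// corrections, eps, weight_decay, grad_scale) but is an illustration only, not the kernel itself:
#include <math.h>
__host__ __device__ inline float adamw_scalar_sketch(float param, float grad, float* m, float* v,
                                                     float learning_rate, float beta1, float beta2,
                                                     float beta1_correction, float beta2_correction,
                                                     float eps, float weight_decay, float grad_scale) {
    grad *= grad_scale;                               // undo gradient scaling
    *m = beta1 * *m + (1.0f - beta1) * grad;          // first moment (EMA of gradients)
    *v = beta2 * *v + (1.0f - beta2) * grad * grad;   // second moment (EMA of squared gradients)
    float m_hat = *m / beta1_correction;              // beta1_correction = 1 - beta1^t
    float v_hat = *v / beta2_correction;              // beta2_correction = 1 - beta2^t
    return param - learning_rate * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * param);
}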
TensorGPU<__nv_bfloat16> param_tensor = specs[spec_id]; auto out_param128 = new_tensor128(param_tensor); @@ -125,6 +127,8 @@ __global__ void adamw_full_update(TensorSpec* specs, unsigned int seed, idx_blk += stride; idx += stride; } + } else { + assert(false); // TODO } } } diff --git a/llmc/attention.cuh b/llmc/attention.cuh index e65dcef21..f91380ad3 100644 --- a/llmc/attention.cuh +++ b/llmc/attention.cuh @@ -70,7 +70,11 @@ __global__ void permute_kernel_backward(tensorX dinp, dinp128_q.store(inp_idx); dinp128_k.store(inp_idx + NH * d); dinp128_v.store(inp_idx + 2 * (NH * d)); - dinp128_q.update_absmax(threadIdx.x, blockDim.x, true); + + // todo - merge this into 1 update + dinp128_q.update_absmax(threadIdx.x, blockDim.x, false); + dinp128_k.update_absmax(threadIdx.x, blockDim.x, false); + dinp128_v.update_absmax(threadIdx.x, blockDim.x, true); } __global__ void unpermute_kernel(tensorX out, floatX* inp, int B, int N, int NH, int d) { @@ -245,7 +249,7 @@ void attention_forward(tensorX out, floatX* qkvr, floatX* att, permute_kernel<<>>(q, k, v, inp, B, T, NH, HS); floatX* preatt = inp; // reuse inp as scratch buffer - matmul_cublaslt(tensorX::from(preatt), tensorX::from(k), tensorX::from(q), nullptr, T, T, HS, stream, true, false, B * NH, T * HS, T * HS, T * T); + matmul_cublaslt(tensorX::from(preatt), tensorX::from(k), tensorX::from(q), null_tensorX, T, T, HS, stream, true, false, B * NH, T * HS, T * HS, T * T); // multiply all elements of preatt elementwise by scale float scale = 1.f / sqrtf(HS); @@ -255,7 +259,7 @@ void attention_forward(tensorX out, floatX* qkvr, floatX* att, // new approach: first cuBLAS another batched matmul floatX* vaccum = inp; // y = att @ v # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs) - matmul_cublaslt(tensorX::from(vaccum), tensorX::from(v), tensorX::from(att), nullptr, HS, T, T, stream, false, false, B * NH, T * HS, T * T, T * HS); + matmul_cublaslt(tensorX::from(vaccum), tensorX::from(v), tensorX::from(att), null_tensorX, HS, T, T, stream, false, false, B * NH, T * HS, T * T, T * HS); // now unpermute // y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side @@ -286,18 +290,20 @@ void attention_backward(tensorX dinp, floatX* dqkvr, floatX* datt, floatX* scrat // backward through the unpermute operation int num_blocks = CEIL_DIV(B * T * C, block_size * dout.num_per_128()); unpermute_kernel_backward<<>>(scratch, dout, B, T, NH, HS); + cudaCheck(cudaGetLastError()); // backward into datt - matmul_cublaslt(tensorX::from(datt), tensorX::from(v), tensorX::from(scratch), nullptr, T, T, HS, stream, true, false, B * NH, T * HS, T * HS, T * T); + matmul_cublaslt(tensorX::from(datt), tensorX::from(v), tensorX::from(scratch), null_tensorX, T, T, HS, stream, true, false, B * NH, T * HS, T * HS, T * T); // backward into dv - matmul_cublaslt(tensorX::from(dv), tensorX::from(scratch), tensorX::from(att), nullptr, HS, T, T, stream, false, true, B * NH, T * HS, T * T, T * HS); + matmul_cublaslt(tensorX::from(dv), tensorX::from(scratch), tensorX::from(att), null_tensorX, HS, T, T, stream, false, true, B * NH, T * HS, T * T, T * HS); const float scale = 1.0f / sqrtf((float)HS); // backward into preatt. 
this is an in-place operation; datt turns into dpreatt here softmax_autoregressive_backward_inplace_kernel<<>>(datt, att, B, T, C, scale); + cudaCheck(cudaGetLastError()); floatX* dpreatt = datt; // backward into q - matmul_cublaslt(tensorX::from(dq), tensorX::from(k), tensorX::from(dpreatt), nullptr, HS, T, T, stream, false, false, B * NH, T * HS, T * T, T * HS); + matmul_cublaslt(tensorX::from(dq), tensorX::from(k), tensorX::from(dpreatt), null_tensorX, HS, T, T, stream, false, false, B * NH, T * HS, T * T, T * HS); // backward into k - matmul_cublaslt(tensorX::from(dk), tensorX::from(q), tensorX::from(dpreatt), nullptr, HS, T, T, stream, false, true, B * NH, T * HS, T * T, T * HS); + matmul_cublaslt(tensorX::from(dk), tensorX::from(q), tensorX::from(dpreatt), null_tensorX, HS, T, T, stream, false, true, B * NH, T * HS, T * T, T * HS); // backward into inp num_blocks = CEIL_DIV(B * NH * T * HS, block_size * dinp.num_per_128()); permute_kernel_backward<<>>(dinp, dq, dk, dv, B, T, NH, HS); diff --git a/llmc/cuda_utils.cuh b/llmc/cuda_utils.cuh index 4f266caf8..531f903e4 100644 --- a/llmc/cuda_utils.cuh +++ b/llmc/cuda_utils.cuh @@ -5,6 +5,12 @@ #include "cuda_common.h" +struct TensorSpec; // Forward declaration + +__device__ __constant__ TensorSpec* tensor_specs_ptr; +__device__ __constant__ float* gpu_scale_memory_ptr; +__device__ __constant__ unsigned int* gpu_absmax_memory_ptr; + // ---------------------------------------------------------------------------- // Packed128 data structure that forces the compiler to use 128-bit loads/stores // in GPUs that support (the LDG.128 and STS.128 instructions) @@ -177,8 +183,12 @@ __device__ __host__ unsigned int get_random_noise(unsigned int seed, unsigned in // (didn't matter with BF16 because denorms are so tiny they're irrelevant, unlike in FP8/FP16) template __device__ __forceinline__ void stochastic_rounding(float in, Ti *out, unsigned int seed, float prob_offset=0.0f) { - unsigned int random = noise ? get_random_noise(threadIdx.x, blockIdx.x * blockDim.x + blockIdx.y, seed) : seed; + if constexpr (std::is_same::value) { + *out = in; + return; + } + unsigned int random = noise ? get_random_noise(threadIdx.x, blockIdx.x * blockDim.x + blockIdx.y, seed) : seed; // prob_offset allows rounding towards gradient more of the time (one paper recommends that) // e.g. +0.3f ==> 65% chance up, 35% chance down highp threshold_percentage = ((highp)random / (highp)0xFFFFFFFF) - prob_offset; @@ -236,15 +246,13 @@ __device__ __forceinline__ void stochastic_rounding(float in, Ti *out, unsigned highp lerp = ((highp)in - (highp)rounded_down) / diff; // division by 0 is OK as it means (up == down) anyway *out = (lerp > threshold_percentage) ? rounded_up : rounded_down; } -__device__ __forceinline__ void stochastic_rounding(float in, float *out, unsigned int random) { - *out = in; // dummy function for when floatX is float (FP32 mode) -} // ---------------------------------------------------------------------------- // ... 
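// Illustrative, self-contained sketch of the stochastic rounding idea used above: the probability
// of rounding up equals the fractional distance of the input between its two representable
// neighbours. This BF16-only helper relies on the CUDA round-down/round-up conversion intrinsics
// instead of the generic template machinery in this file; it is an example, not the production path.
#include <cuda_bf16.h>
__device__ __nv_bfloat16 stochastic_round_bf16_sketch(float in, unsigned int random) {
    __nv_bfloat16 down = __float2bfloat16_rd(in);             // nearest representable value <= in
    __nv_bfloat16 up   = __float2bfloat16_ru(in);             // nearest representable value >= in
    float threshold = (float)random / (float)0xFFFFFFFFu;     // uniform in [0, 1]
    float lerp = (in - __bfloat162float(down)) /
                 (__bfloat162float(up) - __bfloat162float(down)); // 0/0 only when up == down
    return (lerp > threshold) ? up : down;                    // NaN compares false, so ties return down (== up)
}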
template struct TensorGPU { ElementType* data_ptr; + int id; float* scale_descale_ptr; unsigned int* absmax_ptr; size_t num_elements; @@ -252,6 +260,7 @@ struct TensorGPU { static __device__ __host__ TensorGPU from(ElementType* ptr=nullptr) { TensorGPU tmp = {0}; tmp.data_ptr = ptr; + tmp.id = -1; return tmp; } @@ -319,18 +328,21 @@ private: float new_absmax = 0.0f; bool wrote_data = false; bool wrote_absmax = false; + int id = -1; public: bool scaling = true; // todo - fp8 only static constexpr const size_t elements = sizeof(int4) / sizeof(ElementType); + __device__ tensor128() { scaling = false; } __device__ tensor128(TensorGPU tensor, bool disable_scaling=false) { data_ptr = tensor.data_ptr; + id = tensor.id; if (!disable_scaling) { float2* __restrict__ ptr_restricted = (float2*)tensor.scale_descale_ptr; if (tensor.scale_descale_ptr == nullptr) { - printf("tensor.scale_descale_ptr: %p\n", tensor.scale_descale_ptr); + assert(false); } float2 scale_descale = *ptr_restricted; scale = scale_descale.x; @@ -483,7 +495,10 @@ public: __device__ ~tensor128() { // this should ~always be optimised away by the compiler - assert(wrote_absmax || !scaling || !wrote_data); + if (!wrote_absmax && scaling && wrote_data) { + printf("id: %d\n", id); + assert(false); + } } }; @@ -510,9 +525,6 @@ extern int current_absmax_index; extern float* gpu_scale_memory; extern unsigned int* gpu_absmax_memory; -__constant__ float* gpu_scale_memory_ptr; -__constant__ unsigned int* gpu_absmax_memory_ptr; - enum TT : uint8_t { PARAMETER=0, PARAMETER_GRAD, PARAMETER_OPT_M, PARAMETER_OPT_V, PARAMETER_MASTER, // 1 allocation each ACTIVATIONS_MULTIUSE, // single buffer shared for activations, activation gradients, and scratch @@ -530,8 +542,7 @@ enum TFlags : uint8_t { EMBEDDING=64, STATS=128 }; - -typedef struct { +struct TensorSpec { char* ptr; size_t offset; // into base pointer size_t num_elements; // per shard @@ -562,9 +573,9 @@ typedef struct { TensorGPU tensor; tensor.num_elements = num_elements; tensor.data_ptr = this->operator T*(); + tensor.id = id; #ifdef __CUDA_ARCH__ - printf("gpu_scale_memory_ptr: %p\n", gpu_scale_memory_ptr); tensor.scale_descale_ptr = gpu_scale_memory_ptr + 2*id; tensor.absmax_ptr = gpu_absmax_memory_ptr + id; #else @@ -574,7 +585,7 @@ typedef struct { return tensor; } -} TensorSpec; +}; // ---------------------------------------------------------------------------- // Copy, cast functions diff --git a/llmc/layernorm.cuh b/llmc/layernorm.cuh index 6f602754f..130320a4e 100644 --- a/llmc/layernorm.cuh +++ b/llmc/layernorm.cuh @@ -195,10 +195,10 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with for (int c = 0; c < iterations_C; c++) { int global_index = (warpThreadIdx * x128::size) + (c * C_per_iteration); - auto dout128 = new_tensor128(dout); - auto inp128 = new_tensor128(inp); - auto dinp128 = new_tensor128(dinp_old); - auto weight128 = new_tensor128(weight); + tensor128 dout128; + tensor128 inp128; + tensor128 weight128; + tensor128 dinp128; if(global_index < C) { dout128 = load_tensor128(dout, bt * C + global_index, true); @@ -316,6 +316,8 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with // convert from float/FP32 to floatX/BF16 for the final write // this is separate because it cannot use as many warps as the above (f128 vs x128) // todo - if we split this code into another kernel, we could maybe do it at the same time? 
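// Illustrative sketch (not part of the patch) of the contract that the tensor128 destructor check
// above enforces: any kernel that writes through a scaled tensor128 must fold its local absmax into
// the tensor's global absmax before the object goes out of scope. It uses the tensorX,
// new_tensor128 and load_tensor128 helpers declared in this file; the kernel itself is hypothetical.
__global__ void scale_by_two_sketch(tensorX out, tensorX in) {
    size_t idx = (blockIdx.x * (size_t)blockDim.x + threadIdx.x) * in.num_per_128();
    auto out128 = new_tensor128(out);
    if (idx < in.num_elements) {
        auto in128 = load_tensor128(in, idx, true);
        for (int k = 0; k < in.num_per_128(); k++) {
            out128.set(k, 2.0f * in128.get(k));   // set() also tracks the running local absmax
        }
        out128.store(idx);
    }
    // required whenever scaling is enabled and data was written, otherwise the destructor asserts
    out128.update_absmax(threadIdx.x, blockDim.x, true);
}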
+ auto dbias128_out = new_tensor128(dbias); + auto dweight128_out = new_tensor128(dweight); for (int c = warpId; c < iterations_C; c += warpsInBlock) { int global_index = (warpThreadIdx * x128::size) + (c * C_per_iteration); if (global_index >= C) { @@ -329,13 +331,15 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with f128 s_dw = load128(dweight_shared + global_index + o * f128::size); for(int i = 0; i < f128::size; ++i) { int x = o * f128::size + i; - dbias128.set(x, s_db[i] + dbias128.get(x)); - dweight128.set(x, s_dw[i] + dweight128.get(x)); + dbias128_out.set(x, s_db[i] + dbias128.get(x)); + dweight128_out.set(x, s_dw[i] + dweight128.get(x)); } } - dbias128.store_same_length(global_index); - dweight128.store_same_length(global_index); + dbias128_out.store_same_length(global_index); + dweight128_out.store_same_length(global_index); } + dbias128_out.update_absmax(threadIdx.x, BLOCK_SIZE, false); + dweight128_out.update_absmax(threadIdx.x, BLOCK_SIZE, false); } } diff --git a/llmc/matmul.cuh b/llmc/matmul.cuh index 2f5e07061..a338f0c57 100644 --- a/llmc/matmul.cuh +++ b/llmc/matmul.cuh @@ -106,17 +106,17 @@ __global__ void reduce_add_sum_kernel(floatX* dst, const float* src, size_t n, s // Wrapper around cublasLtMatmul that is meant to support everything we need in llm.c // https://docs.nvidia.com/cuda/cublas/#cublasltmatmul -void matmul_cublaslt(floatX* d, const floatX* a, const floatX* b, const floatX* bias, +void matmul_cublaslt(tensorX d, const tensorX a, const tensorX b, const tensorX bias, int m, int n, int k, cudaStream_t stream=0, bool transA=true, bool transB=false, int batch_count=0, size_t strideA=0, size_t strideB=0, size_t strideOut=0, - bool accumulate=false, floatX* pre_gelu=NULL, bool backward=false) + bool accumulate=false, tensorX pre_gelu=null_tensorX, bool backward=false) { NVTX_RANGE_FN(); - bool has_bias = (bias != NULL); - bool has_gelu = (pre_gelu != NULL); + bool has_bias = (bias.data_ptr != NULL); + bool has_gelu = (pre_gelu.data_ptr != NULL); // check alignment (some modes work unaligned but it always best to be aligned for performance) - if(((uintptr_t)a % 16) != 0 || ((uintptr_t)b % 16) != 0 || ((uintptr_t)d % 16) != 0 || ((uintptr_t)bias % 16) != 0) { + if(((uintptr_t)a.data_ptr % 16) != 0 || ((uintptr_t)b.data_ptr % 16) != 0 || ((uintptr_t)d.data_ptr % 16) != 0 || ((uintptr_t)bias.data_ptr % 16) != 0) { printf("All cuBLASLt pointers must be aligned!\n"); exit(EXIT_FAILURE); } @@ -176,12 +176,22 @@ void matmul_cublaslt(floatX* d, const floatX* a, const floatX* b, const floatX* if (has_gelu) { int64_t gelu_ld = m; // todo - is this affected by anything else? cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &gelu_ld, sizeof(gelu_ld))); - cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &pre_gelu, sizeof(pre_gelu))); + cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &pre_gelu.data_ptr, sizeof(pre_gelu.data_ptr))); if (backward) { assert(!has_bias); // we shouldn't have any backward matmuls that use both GELU and bias epilogue = CUBLASLT_EPILOGUE_DGELU; + if (pre_gelu.scale_descale_ptr) { // descale input + float* gelu_descale_ptr = pre_gelu.scale_descale_ptr + 1; + //cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_SCALE_POINTER, &gelu_descale_ptr, sizeof(float*))); + } } else { epilogue = has_bias ? 
CUBLASLT_EPILOGUE_GELU_AUX_BIAS : CUBLASLT_EPILOGUE_GELU_AUX; + if (pre_gelu.absmax_ptr) { + //cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_AMAX_POINTER, &pre_gelu.absmax_ptr, sizeof(pre_gelu.absmax_ptr))); + } + if (pre_gelu.scale_descale_ptr) { // scale output + //cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_SCALE_POINTER, &pre_gelu.scale_descale_ptr, sizeof(float*))); + } } } else if(has_bias){ epilogue = backward ? CUBLASLT_EPILOGUE_BGRADB : CUBLASLT_EPILOGUE_BIAS; @@ -194,7 +204,23 @@ void matmul_cublaslt(floatX* d, const floatX* a, const floatX* b, const floatX* // cuBLASLt requires bias in FP8 mode to be BF16... (sigh) cublasDataType_t bias_data_type = (sizeof(floatX) == 1) ? CUDA_R_16BF : CUBLAS_LOWP; // force BF16 bias for FP8 mode cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_data_type, sizeof(bias_data_type))); - cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias))); + cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias.data_ptr, sizeof(bias.data_ptr))); + } + + // scale factors + if (a.scale_descale_ptr) { + //float* a_descale_ptr = a.scale_descale_ptr + 1; + //cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_descale_ptr, sizeof(float*))); + } + if (b.scale_descale_ptr) { + //float* b_descale_ptr = b.scale_descale_ptr + 1; + //cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_descale_ptr, sizeof(float*))); + } + if (d.scale_descale_ptr) { + //cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &d.scale_descale_ptr, sizeof(float*))); + } + if (d.absmax_ptr) { + //cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &d.absmax_ptr, sizeof(float*))); } // set scale type to FP32 (needs to be FP16 if and only if using CUBLAS_COMPUTE_16F, so it's FP32 even for FP8!) @@ -235,7 +261,7 @@ void matmul_forward_cublaslt(tensorX out, TensorGPU pre_gelu=null_tensorX, int gelu_fusion=1, cudaStream_t stream=main_stream) { // By default only fuse GELU for H100+ as cuBLAS seems to be inefficient for fused GELU on Ada/Ampere (?) if (gelu_fusion < 1 && pre_gelu != null_tensorX) { - matmul_cublaslt(pre_gelu, weight, inp, bias, OC, BT, C, stream, true, false, 0, 0, 0, 0, false, NULL, false); + matmul_cublaslt(pre_gelu, weight, inp, bias, OC, BT, C, stream, true, false, 0, 0, 0, 0, false, null_tensorX, false); gelu_forward(out, pre_gelu, stream); } else { matmul_cublaslt(out, weight, inp, bias, OC, BT, C, stream, true, false, 0, 0, 0, 0, false, pre_gelu, false); @@ -277,8 +303,8 @@ void matmul_backward(tensorX dinp, tensorX dweight, tensorX dbias, } // backward to input, uses = in the backward pass (set the gradient) - matmul_cublaslt(dinp, weight, dout, NULL, C, BT, OC, stream, false, false, 0, 0, 0, 0, false, - gelu_fusion >= 2 ? pre_gelu.data_ptr : NULL, true); + matmul_cublaslt(dinp, weight, dout, null_tensorX, C, BT, OC, stream, false, false, 0, 0, 0, 0, false, + gelu_fusion >= 2 ? 
pre_gelu : null_tensorX, true); // backward GELU (if it wasn't fused into the matmul above) if (gelu_fusion < 2 && pre_gelu != null_tensorX) { @@ -286,6 +312,6 @@ void matmul_backward(tensorX dinp, tensorX dweight, tensorX dbias, } // backward to weight, uses += in the backward pass (accumulate the gradient) by setting alpha=one - matmul_cublaslt(dweight, inp, dout, NULL /*dbias*/, C, OC, BT, stream, false, true, 0, 0, 0, 0, - true /* accumulate */, NULL, true); + matmul_cublaslt(dweight, inp, dout, null_tensorX /*dbias*/, C, OC, BT, stream, false, true, 0, 0, 0, 0, + true /* accumulate */, null_tensorX, true); } diff --git a/train_gpt2.cu b/train_gpt2.cu index 88ee65cf0..1fec4c1ac 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -150,6 +150,9 @@ typedef struct { unsigned long long rng_state_last_update; // RNG before last gpt2_update() to re-round identically from master weights } GPT2; +GPT2 model; // todo - move back +bool backward = false; // todo - hack - REMOVE + TensorSpec tensor_specs[MAX_TENSORS] = {0}; TensorSpec* tensor_specs_gpu = NULL; @@ -163,6 +166,74 @@ int current_absmax_index = 0; float* gpu_scale_memory = NULL; unsigned int* gpu_absmax_memory = NULL; +// debug helper function +void print_tensor_elements(int tensor_id) { + return; + if (backward == false) return; + + printf("Printing tensor %d\n", tensor_id); + TensorSpec spec = tensor_specs[tensor_id]; + size_t num_elements = spec.num_elements; + const char* tensor_name = spec.name; + TT tensor_type = spec.tensor_type; + DType dtype = spec.data_type; + size_t element_size = sizeof_dtype(dtype); + + void* gpu_memory = (tensor_type == TT::ACTIVATIONS_MULTIUSE) ? model.multiuse_memory : model.params_memory[tensor_type]; + void* gpu_tensor = (void*)((char*)gpu_memory + tensor_specs[tensor_id].offset); + void* cpu_tensor = malloc(num_elements * element_size); + + printf("Printing tensor %s\n", tensor_name); + printf("GPU memory: %p\n", gpu_tensor); + printf("CPU memory: %p\n", cpu_tensor); + printf("Num elements: %zu\n", num_elements); + printf("Element size: %zu\n", element_size); + printf("Offset: %zu\n", tensor_specs[tensor_id].offset); + + cudaCheck(cudaMemcpy(cpu_tensor, gpu_tensor, num_elements * element_size, cudaMemcpyDeviceToHost)); + + printf("Did memcpy\n"); + + printf("First 4 of %s: ", tensor_name); + for (int i = 0; i < num_elements && i < 4; i++) { + if (dtype == DType::FP32) { + printf("%.16f ", ((float*)cpu_tensor)[i]); + } else if (dtype == DType::FP16) { + printf("%.16f ", (float)((__nv_half*)cpu_tensor)[i]); + } else if (dtype == DType::BF16) { + printf("%.16f ", (float)((__nv_bfloat16*)cpu_tensor)[i]); + } + } + printf("\n"); + + printf("Middle 4 of %s: ", tensor_name); + for (int i = (num_elements/2) + 4; i < num_elements && i < (num_elements/2 + 8); i++) { + if (dtype == DType::FP32) { + printf("%.16f ", ((float*)cpu_tensor)[i]); + } else if (dtype == DType::FP16) { + printf("%.16f ", (float)((__nv_half*)cpu_tensor)[i]); + } else if (dtype == DType::BF16) { + printf("%.16f ", (float)((__nv_bfloat16*)cpu_tensor)[i]); + } + } + printf("\n"); + + printf("Last 4 of %s: ", tensor_name); + for (int i = num_elements - 4; i < num_elements; i++) { + if (dtype == DType::FP32) { + printf("%.16f ", ((float*)cpu_tensor)[i]); + } else if (dtype == DType::FP16) { + printf("%.16f ", (float)((__nv_half*)cpu_tensor)[i]); + } else if (dtype == DType::BF16) { + printf("%.16f ", (float)((__nv_bfloat16*)cpu_tensor)[i]); + } + } + printf("\n"); + printf("\n"); + + free(cpu_tensor); +} + TensorSpec get_tensor(int spec_index, 
TT tensor_type, int layer) { TensorSpec spec = tensor_specs[spec_index]; if (layer > 0 && spec.remaining_layers >= layer) { @@ -172,6 +243,7 @@ TensorSpec get_tensor(int spec_index, TT tensor_type, int layer) { assert(false); } assert(spec.tensor_type == tensor_type || tensor_type == DEFAULT); + print_tensor_elements(spec_index); return spec; } @@ -179,7 +251,8 @@ int add_tensor_spec(const char* name, size_t total_elements, size_t num_shards, assert(num_tensor_specs < 16*1024); assert((total_elements % num_shards) == 0); TensorSpec* spec = &tensor_specs[num_tensor_specs]; - strncpy(spec->name, name, 16); + strncpy(spec->name, name, 15); + spec->name[15] = 0; spec->id = num_tensor_specs; spec->num_elements = total_elements / num_shards; @@ -426,7 +499,7 @@ void gpt2_allocate(GPT2 *model) { } // we are finished creating the tensors specs and can copy them to the GPU (effectively read-only) - cudaMalloc(&tensor_specs_gpu, sizeof(TensorSpec) * num_tensor_specs); + cudaMalloc((void**)&tensor_specs_gpu, sizeof(TensorSpec) * num_tensor_specs); cudaMemcpy(tensor_specs_gpu, tensor_specs, sizeof(TensorSpec) * num_tensor_specs, cudaMemcpyHostToDevice); //initialise helper variables @@ -458,8 +531,9 @@ void gpt2_allocate(GPT2 *model) { free(h_scale_memory); // copy to constant buffers - cudaMemcpyToSymbol(gpu_scale_memory_ptr, gpu_scale_memory, sizeof(float*)); - cudaMemcpyToSymbol(gpu_absmax_memory_ptr, gpu_absmax_memory, sizeof(unsigned int*)); + cudaMemcpyToSymbol(tensor_specs_ptr, &tensor_specs_gpu, sizeof(TensorSpec*)); + cudaMemcpyToSymbol(gpu_scale_memory_ptr, &gpu_scale_memory, sizeof(float*)); + cudaMemcpyToSymbol(gpu_absmax_memory_ptr, &gpu_absmax_memory, sizeof(unsigned int*)); cudaCheck(cudaMalloc((void**)&model->inputs, B * T * sizeof(int))); cudaCheck(cudaMalloc((void**)&model->targets, B * T * sizeof(int))); @@ -699,60 +773,6 @@ void gpt_build_from_descriptor(GPT2 *model, const char* descriptor) { free(params_memory_cpu); } -// debug helper function -void print_tensor_elements(GPT2 *model, int tensor_id) { - TensorSpec spec = tensor_specs[tensor_id]; - size_t num_elements = spec.num_elements; - const char* tensor_name = spec.name; - TT tensor_type = spec.tensor_type; - DType dtype = spec.data_type; - size_t element_size = sizeof_dtype(dtype); - - void* gpu_memory = (tensor_id == TT::ACTIVATIONS_MULTIUSE) ? 
model->multiuse_memory : model->params_memory[tensor_type]; - void* gpu_tensor = (void*)((char*)gpu_memory + tensor_specs[tensor_id].offset); - void* cpu_tensor = malloc(num_elements * element_size); - cudaCheck(cudaMemcpy(cpu_tensor, gpu_tensor, num_elements * element_size, cudaMemcpyDeviceToHost)); - - printf("First 4 of %s: ", tensor_name); - for (int i = 0; i < num_elements && i < 4; i++) { - if (dtype == DType::FP32) { - printf("%.16f ", ((float*)cpu_tensor)[i]); - } else if (dtype == DType::FP16) { - printf("%.16f ", (float)((__nv_half*)cpu_tensor)[i]); - } else if (dtype == DType::BF16) { - printf("%.16f ", (float)((__nv_bfloat16*)cpu_tensor)[i]); - } - } - printf("\n"); - - printf("Middle 4 of %s: ", tensor_name); - for (int i = (num_elements/2) + 4; i < num_elements && i < (num_elements/2 + 8); i++) { - if (dtype == DType::FP32) { - printf("%.16f ", ((float*)cpu_tensor)[i]); - } else if (dtype == DType::FP16) { - printf("%.16f ", (float)((__nv_half*)cpu_tensor)[i]); - } else if (dtype == DType::BF16) { - printf("%.16f ", (float)((__nv_bfloat16*)cpu_tensor)[i]); - } - } - printf("\n"); - - printf("Last 4 of %s: ", tensor_name); - for (int i = num_elements - 4; i < num_elements; i++) { - if (dtype == DType::FP32) { - printf("%.16f ", ((float*)cpu_tensor)[i]); - } else if (dtype == DType::FP16) { - printf("%.16f ", (float)((__nv_half*)cpu_tensor)[i]); - } else if (dtype == DType::BF16) { - printf("%.16f ", (float)((__nv_bfloat16*)cpu_tensor)[i]); - } - } - printf("\n"); - printf("\n"); - - free(cpu_tensor); -} - // Helper macros for accessing tensors #define TENSOR(x,layer) get_tensor(x, DEFAULT, layer) #define ACT_L(x,layer) get_tensor(model->acts.x, ACTIVATIONS_MULTIUSE, layer) @@ -856,6 +876,7 @@ float gpt2_validate(GPT2 *model, const int* inputs, const int* targets, size_t B void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int grad_accum_steps, int micro_step) { NVTX_RANGE_FN(); + backward = true; // todo - hack - REMOVE // convenience shortcuts (size_t instead of int so that pointer arithmetics don't overflow) const size_t B = model->batch_size; @@ -1623,7 +1644,7 @@ int main(int argc, char *argv[]) { } // build the GPT-2 model - GPT2 model; + // todo - add model declaration back here gpt2_init_common(&model); model.use_master_weights = use_master_weights; model.gelu_fusion = gelu_fusion; From b09dbc9b211f54f416d2cfcaad076e5552c6cea1 Mon Sep 17 00:00:00 2001 From: ademeure Date: Mon, 16 Sep 2024 10:54:35 +0000 Subject: [PATCH 13/27] fake FP8 kinda-sorta works --- llmc/adamw.cuh | 37 +- llmc/copy_and_fp8.h | 694 ++++++++++++++++++++++++++++++++++++++ llmc/cuda_common.h | 2 +- llmc/cuda_utils.cuh | 309 +++++++++-------- llmc/fused_classifier.cuh | 10 +- llmc/gelu.cuh | 2 +- llmc/layernorm.cuh | 3 +- llmc/matmul.cuh | 7 +- train_gpt2.cu | 182 +++++++--- 9 files changed, 1041 insertions(+), 205 deletions(-) create mode 100644 llmc/copy_and_fp8.h diff --git a/llmc/adamw.cuh b/llmc/adamw.cuh index 006f7c72c..e5f53a140 100644 --- a/llmc/adamw.cuh +++ b/llmc/adamw.cuh @@ -21,6 +21,8 @@ __global__ void adamw_full_update(TensorSpec* specs, unsigned int seed, float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay, float grad_scale, int t, bool init_from_master_only=false) { // ... 
+ __shared__ int shared_spec_id; + constexpr size_t block_size = 64; // 64 ==> 4KiB chunks with iteration_size=16 for FP32 opt/master size_t iteration_size = 16; assert(iteration_size <= 16); @@ -39,17 +41,30 @@ __global__ void adamw_full_update(TensorSpec* specs, unsigned int seed, size_t current_start = opt_spec.offset / sizeof(float); size_t current_end = current_start + opt_spec.num_elements; - while (idx < num_opt_parameters) { - // todo - do this part on thread 0 only? - while (idx >= current_end) { - spec_id++; - if (spec_id >= num_params_tensors) { - return; + while (true) { + // todo - performance analysis/optimisation! (impact of using step 0?) + if (threadIdx.x == 0) { + while (idx >= current_end) { + spec_id++; + if (spec_id >= num_params_tensors) { + shared_spec_id = -1; + return; + } + opt_spec = opt_v_specs[spec_id]; + current_start = opt_spec.offset / sizeof(float); + current_end = current_start + opt_spec.num_elements; } - opt_spec = opt_v_specs[spec_id]; - current_start = opt_spec.offset / sizeof(float); - current_end = current_start + opt_spec.num_elements; + shared_spec_id = spec_id; } + __syncthreads(); + spec_id = shared_spec_id; + if (spec_id == -1) { + return; + } + + opt_spec = opt_v_specs[spec_id]; + current_start = opt_spec.offset / sizeof(float); + current_end = current_start + opt_spec.num_elements; TensorGPU grad_tensor = grad_specs[spec_id]; TensorGPU master_tensor = master_specs[spec_id]; @@ -63,7 +78,7 @@ __global__ void adamw_full_update(TensorSpec* specs, unsigned int seed, // todo - make it configurable whether weight decay applies to e.g. bias or not float wd = (opt_spec.flags & TENSOR_2D) ? weight_decay : 0.0f; - if (specs[spec_id].data_type == DType::BF16 || specs[spec_id].data_type == DType::FP16) { + if (specs[spec_id].data_type == DType::BF16) { // todo - this is actually "EQUAL FLOATX" right now, doesn't work for mix and match // !!! TensorGPU<__nv_bfloat16> param_tensor = specs[spec_id]; @@ -123,10 +138,10 @@ __global__ void adamw_full_update(TensorSpec* specs, unsigned int seed, } out_param128.store(offset); } - out_param128.update_absmax(threadIdx.x, block_size, false); idx_blk += stride; idx += stride; } + out_param128.update_absmax(threadIdx.x, block_size, false); } else { assert(false); // TODO } diff --git a/llmc/copy_and_fp8.h b/llmc/copy_and_fp8.h new file mode 100644 index 000000000..4cdbb556f --- /dev/null +++ b/llmc/copy_and_fp8.h @@ -0,0 +1,694 @@ +/* +Helpers for FP8 including copy and transpose with format conversion, and absmax +See /dev/cuda/advanced_copy_transpose.cu for more information and options +*/ +#ifndef FP8_HELPERS_CUH +#define FP8_HELPERS_CUH + +#include +#include +#include "cuda_common.h" +#include "cuda_utils.cuh" + +// todo - tune these for performance (but should be close to optimal already) +#define ABSMAX_ITERATIONS_PER_THREAD 4 +#define TRANSPOSE_TILE_SIZE 64UL + +// ---------------------------------------------------------------------------- +// elementwise functions which can be applied as part of the copy/transpose +// for elementwise kernels that require metadata (e.g. layernorm forward with known mean/std), +// we could maybe store it in constant buffers rather than in yet-another-function-parameter... 
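// For reference, the GELU applied by the elementwise helpers defined just below is the usual tanh
// approximation: GELU(x) ~ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
// A plain-C equivalent of the PTX tanh.approx version (illustrative sketch: gelu_forward_reference
// is not part of this patch, and tanhf is the accurate math-library call, not the HW approximation):
__device__ __forceinline__ float gelu_forward_reference(float x) {
    float cube = 0.044715f * x * x * x;
    return 0.5f * x * (1.0f + tanhf(sqrtf(2.0f / M_PI) * (x + cube)));
}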
+using elementwise_func_t = float (*) (float); +__device__ float nothing_elementwise(float x) { + return x; +} +__device__ float gelu_forward_elementwise(float x) { + float cube = 0.044715f * x * x * x; + + float tanh_out; + float tanh_arg = sqrtf(2.0f / M_PI) * (x + cube); + asm ("tanh.approx.f32 %0,%1;" : "=f"(tanh_out) : "f"(tanh_arg)); + + // the following uses FMUL+FMA instead of FMUL+FADD+FMUL for "0.5f * x * (1.0f + tanh_out)" + float half_x = 0.5f * x; + return half_x * tanh_out + half_x; +} + +// ---------------------------------------------------------------------------- +// CUDA kernels + +// Same as copy_simple_kernel but with optional absmax and elementwise function options +// absmax is calculated before scaling but after the elementwise function +template +__global__ void copy_advanced_kernel(TensorGPU in, TensorGPU out) { + constexpr size_t vec_size = 16 / ((sizeof(T1) < sizeof(T2)) ? sizeof(T2) : sizeof(T1)); + size_t adjusted_blockidx = reversed_order ? (gridDim.x - blockIdx.x - 1) : blockIdx.x; + size_t idx = (adjusted_blockidx * blockDim.x + threadIdx.x) * vec_size; + if (idx >= in.num_elements) { return; } + + auto inp128 = load_tensor128(in, idx, true, disable_scaling); + auto out128 = new_tensor128(out); + for (int k = 0; k < vec_size; k++) { + float out_fp32 = elementwise_func(inp128.get(k)); + out128.set(k, out_fp32); + } + out128.store_same_length(idx); + out128.update_absmax(threadIdx.x, block_size, true); +} + +/* +// transpose + copy + format conversion (+ elementwise + absmax) kernel +template +__global__ void transpose_kernel(T1* __restrict__ transposed, T1* __restrict__ copy, const T2* __restrict__ input, + const float* __restrict__ descale_pointer=(float*)NULL, const float* __restrict__ scale_pointer=(float*)NULL, + unsigned int* absmax_output=(unsigned int*)NULL, const void** meta=NULL) +{ + constexpr size_t TILE_DIM_PADDED = TILE_DIM + 4/sizeof(T1); + __shared__ T1 tile[TILE_DIM][TILE_DIM_PADDED]; + int width = gridDim.x * TILE_DIM; + int height = gridDim.y * TILE_DIM; + + constexpr size_t T1_elements = 16 / sizeof(T1); + constexpr size_t T2_elements = 16 / sizeof(T2); + constexpr size_t copy_vectors = (sizeof(T1) >= sizeof(T2)) ? (sizeof(T1) / sizeof(T2)) : 1; + + float descale_factor = (scaling && descale_pointer) ? *descale_pointer : 1.0f; // never reciprocal + float scale_factor = (scaling && scale_pointer) ? *scale_pointer : 1.0f; + scale_factor = (reciprocal_scale && scale_factor != 0.0f) ? 
(1.0f / scale_factor) : scale_factor; + int x = blockIdx.x * TILE_DIM + (threadIdx.x * T2_elements); + int y = blockIdx.y * TILE_DIM + threadIdx.y; + uint absmax_uint = 0; + + #pragma unroll + for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) { + Packed128 in128 = load128cs(input + x + (y+j)*width); + Packed128 copy128[copy_vectors]; + for (int k = 0; k < in128.size; k++) { + T2 in = in128[k]; + float out_float = elementwise_func((float)in * descale_factor); + update_local_absmax(absmax_uint, out_float, absmax_factor); // optional absmax + + T1 out = (T1)(out_float * scale_factor); + copy128[k/T1_elements][k%T1_elements] = out; // optimised away by compiler if unused + } + + for (int o = 0; o < copy_vectors; o++) { + if constexpr (enable_copy) { + store_same_length(copy + x + (y+j)*width + o*T1_elements, copy128[o]); + } + + size_t tile_offset = (threadIdx.x * T2_elements) + (threadIdx.y+j)*TILE_DIM_PADDED + o*T1_elements; + int* one_bank = reinterpret_cast(&tile[0][0] + tile_offset); + for (int k = 0; k < 4; k++) { + one_bank[k] = *(int*)(©128[o][k*4/sizeof(T1)]); + } + //store_same_length(&tile[0][0] + tile_offset, copy128[o]); + } + } + + if constexpr (absmax_factor != 0) { + update_global_absmax(absmax_output, absmax_uint); + } else { + __syncthreads(); + } + + // reduce the number of threads for the write if T1_elements > T2_elements + // we want to keep all 32 threads in a warp active, so we try to eliminate in y dimension first + // so we create fake/adjusted tid.x/tid.y where "extra" threadIdx.x adds to the effective tid.y + constexpr size_t block_size_x = (TILE_DIM * sizeof(T2)) / 16; + constexpr size_t block_size_y = BLOCK_ROWS; + + constexpr size_t desired_ratio = (sizeof(T2) >= sizeof(T1)) ? (sizeof(T2) / sizeof(T1)) : 1; + constexpr size_t ratio = (desired_ratio <= block_size_y) ? 
desired_ratio : block_size_y; + constexpr size_t block_size_x_div_r = block_size_x / ratio; + constexpr size_t block_size_y_div_r = block_size_y / ratio; + + int adjusted_tid_x = threadIdx.x % block_size_x_div_r; + int adjusted_tid_y = (threadIdx.y * ratio) + (threadIdx.x / block_size_x_div_r); + if (threadIdx.y >= block_size_y_div_r) { return; } + + // if we cannot reduce block_size.y enough, also reduce x (hurting perf with partial warps) + if (ratio != desired_ratio && adjusted_tid_x >= TILE_DIM / T1_elements) { return; } + + // x/y for final write to global memory + x = blockIdx.y * TILE_DIM + adjusted_tid_x * T1_elements; + y = blockIdx.x * TILE_DIM + adjusted_tid_y; + + constexpr int in_parallel = 4/sizeof(T1); + + #pragma unroll + for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS * in_parallel) { + if ((j+adjusted_tid_y) * in_parallel >= TILE_DIM) { return; } + + // we need more instructions for the write than the read if T2_elements > T1_elements + #pragma unroll + for (int o = 0; o < copy_vectors; o++) { + Packed128 out128[in_parallel]; + #pragma unroll + for (int k = 0; k < Packed128::size; k++) { + int in32 = *(int*)(&tile[k + (adjusted_tid_x + o * blockDim.x) * Packed128::size][(adjusted_tid_y + j) * in_parallel]); + for (int p = 0; p < in_parallel; p++) { + out128[p][k] = ((T1*)&in32)[p]; + } + } + for (int p = 0; p < in_parallel; p++) { + store128(transposed + x + (o * blockDim.x * Packed128::size) + (y+p + j * in_parallel)*height, out128[p]); + } + } + } +} +*/ + +/* +template +__global__ void transpose_kernel_tensor(TensorGPU transposed, TensorGPU copy, TensorGPU input, int height) { + __shared__ T1 tile[TILE_DIM][TILE_DIM]; + int width = gridDim.x * TILE_DIM; + height = gridDim.y * TILE_DIM; + + constexpr bool disable_scaling = (sizeof(T1) == sizeof(T2)); // TODO - THIS IS WRONG - need to check types are identical, not just same size! + constexpr size_t T1_elements = 16 / sizeof(T1); + constexpr size_t T2_elements = 16 / sizeof(T2); + constexpr size_t copy_vectors = (sizeof(T1) >= sizeof(T2)) ? (sizeof(T1) / sizeof(T2)) : 1; + + int x = blockIdx.x * TILE_DIM + (threadIdx.x * T2_elements); + int y = blockIdx.y * TILE_DIM + threadIdx.y; + + tensor128 copy128 = new_tensor128(copy, disable_scaling); + + #pragma unroll + for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) { + auto in128 = load_tensor128(input, x + (y+j)*width, true, disable_scaling); + Packed128 in128 = load128cs(input + x + (y+j)*width); + Packed128 copy128[copy_vectors]; + for (int k = 0; k < in128.size; k++) { + float out_float = elementwise_func(in128.get(k)); + copy128.set(k % T1_elements, out_float * scale_factor); // optimised away by compiler if unused + + if (k+1 == out128.size) { + // ... 
+ + } + } + + for (int o = 0; o < copy_vectors; o++) { + if constexpr (enable_copy) { + store_same_length(copy + x + (y+j)*width + o*T1_elements, copy128[o]); + } + size_t tile_offset = (threadIdx.x * T2_elements) + (threadIdx.y+j)*TILE_DIM + o*T1_elements; + store_same_length(&tile[0][0] + tile_offset, copy128[o]); + } + } + + +} +*/ + + + + + +// transpose + copy + format conversion (+ elementwise + absmax) kernel +template +__global__ void transpose_kernel(T1* __restrict__ transposed, T1* __restrict__ copy, const T2* __restrict__ input, int height, + const float* __restrict__ descale_pointer=(float*)NULL, const float* __restrict__ scale_pointer=(float*)NULL, + unsigned int* absmax_output=(unsigned int*)NULL, const void** meta=NULL) +{ + /* + __shared__ T1 tile[TILE_DIM][TILE_DIM]; + int width = gridDim.x * TILE_DIM; + height = gridDim.y * TILE_DIM; + + constexpr size_t T1_elements = 16 / sizeof(T1); + constexpr size_t T2_elements = 16 / sizeof(T2); + constexpr size_t copy_vectors = (sizeof(T1) >= sizeof(T2)) ? (sizeof(T1) / sizeof(T2)) : 1; + + float descale_factor = (scaling && descale_pointer) ? *descale_pointer : 1.0f; // never reciprocal + float scale_factor = (scaling && scale_pointer) ? *scale_pointer : 1.0f; + scale_factor = (reciprocal_scale && scale_factor != 0.0f) ? (1.0f / scale_factor) : scale_factor; + int x = blockIdx.x * TILE_DIM + (threadIdx.x * T2_elements); + int y = blockIdx.y * TILE_DIM + threadIdx.y; + uint absmax_uint = 0; + + #pragma unroll + for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) { + Packed128 in128 = load128cs(input + x + (y+j)*width); + Packed128 copy128[copy_vectors]; + for (int k = 0; k < in128.size; k++) { + T2 in = in128[k]; + float out_float = elementwise_func((float)in * descale_factor); + update_local_absmax(absmax_uint, out_float, absmax_factor); // optional absmax + + T1 out = (T1)(out_float * scale_factor); + copy128[k/T1_elements][k%T1_elements] = out; // optimised away by compiler if unused + } + + for (int o = 0; o < copy_vectors; o++) { + if constexpr (enable_copy) { + store_same_length(copy + x + (y+j)*width + o*T1_elements, copy128[o]); + } + size_t tile_offset = (threadIdx.x * T2_elements) + (threadIdx.y+j)*TILE_DIM + o*T1_elements; + store_same_length(&tile[0][0] + tile_offset, copy128[o]); + } + } + + if constexpr (absmax_factor != 0) { + update_global_absmax(absmax_output, absmax_uint); + } else { + __syncthreads(); + } + + // reduce the number of threads for the write if T1_elements > T2_elements + // we want to keep all 32 threads in a warp active, so we try to eliminate in y dimension first + // so we create fake/adjusted tid.x/tid.y where "extra" threadIdx.x adds to the effective tid.y + constexpr size_t block_size_x = (TILE_DIM * sizeof(T2)) / 16; + constexpr size_t block_size_y = BLOCK_ROWS; + + constexpr size_t desired_ratio = (sizeof(T2) >= sizeof(T1)) ? (sizeof(T2) / sizeof(T1)) : 1; + constexpr size_t ratio = (desired_ratio <= block_size_y) ? 
desired_ratio : block_size_y; + constexpr size_t block_size_x_div_r = block_size_x / ratio; + constexpr size_t block_size_y_div_r = block_size_y / ratio; + + int adjusted_tid_x = threadIdx.x % block_size_x_div_r; + int adjusted_tid_y = (threadIdx.y * ratio) + (threadIdx.x / block_size_x_div_r); + if (threadIdx.y >= block_size_y_div_r) { return; } + + // if we cannot reduce block_size.y enough, also reduce x (hurting perf with partial warps) + if (ratio != desired_ratio && adjusted_tid_x >= TILE_DIM / T1_elements) { return; } + + // x/y for final write to global memory + x = blockIdx.y * TILE_DIM + adjusted_tid_x * T1_elements; + y = blockIdx.x * TILE_DIM + adjusted_tid_y; + + #pragma unroll + for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) { + // we need more instructions for the write than the read if T2_elements > T1_elements + #pragma unroll + for (int o = 0; o < copy_vectors; o++) { + Packed128 out128; + #pragma unroll + for (int k = 0; k < out128.size; k++) { + // these are tiny 8-bit loads with loads of bank conflicts for FP8 + // extremely hard to avoid and not a bottleneck when everything else is well optimised + out128[k] = tile[k + (adjusted_tid_x + o * blockDim.x) * out128.size][adjusted_tid_y + j]; + } + store128(transposed + x + (o * blockDim.x * out128.size) + (y+j)*height, out128); + } + } + */ +} + + +/* +// best I could come up with (without using TMA) - no bank conflicts, but 64B reads/writes not ideal +// Z_DIM=2 improves perf by ~2% partly by improving L2 hit rates for the writes as far as I can tell +template +__global__ void transpose_kernel(T1* __restrict__ transposed, T1* __restrict__ copy, const T2* __restrict__ input, int height, + const float* __restrict__ descale_pointer=(float*)NULL, const float* __restrict__ scale_pointer=(float*)NULL, + unsigned int* absmax_output=(unsigned int*)NULL, const void** meta=NULL) +{ + constexpr int in_parallel = 4/sizeof(T1); + + constexpr size_t TILE_DIM_PADDED = (TILE_DIM * 33) / 32; + __shared__ T1 tile[Z_DIM][TILE_DIM][TILE_DIM_PADDED]; + int w = gridDim.x * TILE_DIM; + + constexpr size_t T1_elements = 16 / sizeof(T1); + constexpr size_t T2_elements = 16 / sizeof(T2); + constexpr size_t copy_vectors = (sizeof(T1) >= sizeof(T2)) ? (sizeof(T1) / sizeof(T2)) : 1; + + float descale_factor = (scaling && descale_pointer) ? *descale_pointer : 1.0f; // never reciprocal + float scale_factor = (scaling && scale_pointer) ? *scale_pointer : 1.0f; + scale_factor = (reciprocal_scale && scale_factor != 0.0f) ? 
(1.0f / scale_factor) : scale_factor; + + int x = blockIdx.x * TILE_DIM + (threadIdx.x * T2_elements); + int y = blockIdx.y * TILE_DIM * Z_DIM + threadIdx.z * TILE_DIM + threadIdx.y; + + uint absmax_uint = 0; + if (y < height) { + #pragma unroll + for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) { + Packed128 copy128[copy_vectors]; + + int4 payload; + const int4* address = reinterpret_cast(input + x + (y+j)*w); + asm volatile("ld.global.L2::128B.v4.s32 {%0, %1, %2, %3}, [%4];" + : "=r"(payload.x), "=r"(payload.y), "=r"(payload.z), "=r"(payload.w) + : "l"(address)); + Packed128 in128(payload); + + #pragma unroll + for (int k = 0; k < in128.size; k++) { + T2 in = in128[k]; + float out_float = elementwise_func((float)in * descale_factor); + + T1 out = (T1)(out_float * scale_factor); + copy128[k/T1_elements][k%T1_elements] = out; // optimised away by compiler if unused + update_local_absmax(absmax_uint, out_float, absmax_factor); // optional absmax + } + + #pragma unroll + for (int o = 0; o < copy_vectors; o++) { + if constexpr (enable_copy) { + store_same_length(copy + x + (y+j)*w + o*T1_elements, copy128[o]); + } + + size_t offset_x = (threadIdx.x * T2_elements) + (o * T1_elements); + size_t offset_y = (threadIdx.y + j) * TILE_DIM; + offset_y += (offset_y / (128/sizeof(T1))) * in_parallel; + + int* one_bank = reinterpret_cast(&tile[threadIdx.z][0][0] + offset_x + offset_y); + #pragma unroll + for (int k = 0; k < 4; k++) { + one_bank[k] = *(int*)(©128[o][k*4/sizeof(T1)]); + } + } + } + } + + if constexpr (absmax_factor != 0) { + update_global_absmax(absmax_output, absmax_uint); + } else { + __syncthreads(); + } + + // reduce the number of threads for the write if T1_elements > T2_elements + // we want to keep all 32 threads in a warp active, so we try to eliminate in y dimension first + // so we create fake/adjusted tid.x/tid.y where "extra" threadIdx.x adds to the effective tid.y + constexpr size_t block_size_x = (TILE_DIM * sizeof(T2)) / 16; + constexpr size_t block_size_y = BLOCK_ROWS; + constexpr size_t desired_ratio = (sizeof(T2) >= sizeof(T1)) ? (sizeof(T2) / sizeof(T1)) : 1; + constexpr size_t ratio = (desired_ratio <= block_size_y) ? 
desired_ratio : block_size_y; + constexpr size_t block_size_x_div_r = block_size_x / ratio; + constexpr size_t block_size_y_div_r = block_size_y / ratio; + + int adjusted_tid_x = threadIdx.x % block_size_x_div_r; + int adjusted_tid_y = (threadIdx.y * ratio) + (threadIdx.x / block_size_x_div_r); + if (threadIdx.y >= block_size_y_div_r) { return; } + + // if we cannot reduce block_size.y enough, also reduce x (hurting perf with partial warps) + if (ratio != desired_ratio && adjusted_tid_x >= TILE_DIM / T1_elements) { return; } + + // x/y for final write to global memory + x = blockIdx.y * TILE_DIM * Z_DIM + threadIdx.z * TILE_DIM + adjusted_tid_x * T1_elements; + y = blockIdx.x * TILE_DIM + (adjusted_tid_y*in_parallel); + + if (x >= height) { return; } + + #pragma unroll + for (int j = 0; j < TILE_DIM / in_parallel; j += BLOCK_ROWS) { + if ((j+adjusted_tid_y) * in_parallel * ratio >= TILE_DIM) { return; } + + // we need more instructions for the write than the read if T2_elements > T1_elements + #pragma unroll + for (int o = 0; o < copy_vectors; o++) { + Packed128 out128[in_parallel]; + #pragma unroll + for (int k = 0; k < Packed128::size; k++) { + int offset_x = (adjusted_tid_y + j) * in_parallel; + int offset_y = ((adjusted_tid_x + o * blockDim.x) * Packed128::size + k) * TILE_DIM; + offset_y += (offset_y / (128/sizeof(T1))) * in_parallel; + + int in32 = *(int*)(&tile[threadIdx.z][0][0] + offset_x + offset_y); + for (int p = 0; p < in_parallel; p++) { + out128[p][k] = ((T1*)&in32)[p]; + } + } + #pragma unroll + for (int p = 0; p < in_parallel; p++) { + store128(transposed + x + (o * blockDim.x * Packed128::size) + (y+p + j * in_parallel) * height, out128[p]); + } + } + } +} +*/ + +// only calculate absmax of the input tensor (non-fused) +template +__global__ void update_absmax_kernel(TensorGPU inp) { + size_t idx = ((blockIdx.x * blockDim.x * ABSMAX_ITERATIONS_PER_THREAD) + threadIdx.x) * inp.num_per_128(); + auto max128 = new_tensor128(inp, disable_scaling); + if (idx < inp.num_elements) { + #pragma unroll + for (int i = 0; i < ABSMAX_ITERATIONS_PER_THREAD; i++) { + auto inp128 = load_tensor128(inp, idx, disable_scaling); + for(int k = 0; k < inp.num_per_128(); ++k) { + float value = inp128.get(k); + max128.add_value_stats(value); + } + idx += blockDim.x * inp.num_per_128(); + } + } + max128.update_absmax(threadIdx.x, blockDim.x, true, true); +} + +// ---------------------------------------------------------------------------- +// kernel launchers +/* +template +void copy_simple(T1 *copy, const T2 *input, size_t N, float* scale_pointer=NULL, const size_t block_size=512) { + size_t fewest_elements = min(Packed128::size, Packed128::size); + const dim3 grid_size(CEIL_DIV(N, block_size * fewest_elements)); + + if (scale_pointer) { + copy_simple_kernel<<>>(copy, input, N, scale_pointer); + } else { + copy_simple_kernel<<>>(copy, input, N); + } + cudaCheck(cudaGetLastError()); +} +*/ + +template +void copy_advanced(T1 *copy, const T2 *input, size_t N, float* descale_pointer=NULL, float* scale_pointer=NULL, void* absmax_output=NULL, /*bool memset_absmax=true,*/ cudaStream_t stream=0, const size_t block_size=512) { + size_t fewest_elements = min(Packed128::size, Packed128::size); + const dim3 grid_size(CEIL_DIV(N, block_size * fewest_elements)); + assert((N % fewest_elements) == 0); + + constexpr uint absmax_factor = 1; + unsigned int* absmax_uint = (unsigned int*)absmax_output; + + if (absmax_output) { + /*if (memset_absmax) { + cudaMemset(absmax_output, 0, sizeof(unsigned int)); + }*/ + if 
(scale_pointer || descale_pointer) { + copy_advanced_kernel<<>>(copy, input, N, descale_pointer, scale_pointer, absmax_uint); + } else { + copy_advanced_kernel<<>>(copy, input, N, NULL, NULL, absmax_uint); + } + } else { + if (scale_pointer || descale_pointer) { + copy_advanced_kernel<<>>(copy, input, N, descale_pointer, scale_pointer); + } else { + copy_advanced_kernel<<>>(copy, input, N); + } + } + cudaCheck(cudaGetLastError()); +} + +// only 2 important template parameters: write_absmax and elementwise_func +// (use copy_and_transpose() rather than enable_copy=true for clarity) +// slight inefficiency in that we don't optimise away scaling for kernels that don't need it (kernel checks for NULL) +template // advanced template options, usually don't need to be changed +void transpose(T1 *transposed, const T2 *input, size_t w, size_t h, float* descale_pointer=NULL, float* scale_pointer=NULL, void* absmax_output=NULL, + /*bool memset_absmax=true,*/ cudaStream_t stream=0, size_t block_size=128, T1 *copy=NULL) { // advanced parameters + assert((w % TRANSPOSE_TILE_SIZE) == 0 && (h % TRANSPOSE_TILE_SIZE) == 0); + cudaCheck(cudaGetLastError()); + constexpr int DIM_Z = 1; + block_size /= DIM_Z; + + size_t block_size_x = (TRANSPOSE_TILE_SIZE * sizeof(T2)) / 16; + size_t block_size_y = min(TRANSPOSE_TILE_SIZE, block_size / block_size_x); + dim3 grid_size(w / TRANSPOSE_TILE_SIZE, h / (TRANSPOSE_TILE_SIZE * DIM_Z)); + dim3 block_size_dim(block_size_x, block_size_y, DIM_Z); + + constexpr uint absmax_factor = write_absmax ? 1 : 0; + unsigned int* absmax_uint = (unsigned int*)absmax_output; + /*if (write_absmax && memset_absmax) { + cudaMemset(absmax_output, 0, sizeof(unsigned int)); + }*/ + + switch (block_size_y) { + case 64: transpose_kernel<64, TRANSPOSE_TILE_SIZE, reciprocal, enable_copy, true, absmax_factor, elementwise_func><<>>(transposed, copy, input, h, descale_pointer, scale_pointer, absmax_uint); break; + case 32: transpose_kernel<32, TRANSPOSE_TILE_SIZE, reciprocal, enable_copy, true, absmax_factor, elementwise_func><<>>(transposed, copy, input, h, descale_pointer, scale_pointer, absmax_uint); break; + case 16: transpose_kernel<16, TRANSPOSE_TILE_SIZE, reciprocal, enable_copy, true, absmax_factor, elementwise_func><<>>(transposed, copy, input, h, descale_pointer, scale_pointer, absmax_uint); break; + /*case 8: transpose_kernel<8, TRANSPOSE_TILE_SIZE, reciprocal, enable_copy, true, absmax_factor, elementwise_func><<>>(transposed, copy, input, h, descale_pointer, scale_pointer, absmax_uint,); break; + case 4: transpose_kernel<4, TRANSPOSE_TILE_SIZE, reciprocal, enable_copy, true, absmax_factor, elementwise_func><<>>(transposed, copy, input, h, descale_pointer, scale_pointer, absmax_uint); break; + case 2: transpose_kernel<2, TRANSPOSE_TILE_SIZE, reciprocal, enable_copy, true, absmax_factor, elementwise_func><<>>(transposed, copy, input, h, descale_pointer, scale_pointer, absmax_uint); break; + case 1: transpose_kernel<1, TRANSPOSE_TILE_SIZE, reciprocal, enable_copy, true, absmax_factor, elementwise_func><<>>(transposed, copy, input, h, descale_pointer, scale_pointer, absmax_uint); break;*/ + default: printf("Invalid block size (might be easy to add): %lu\n", block_size_y); exit(1); + } + cudaCheck(cudaGetLastError()); +} + +// wrapper so the parameters of the standard transpose function are less messy +template +void copy_and_transpose(T1 *transposed, T1 *copy, const T2 *input, size_t w, size_t h, float* descale_pointer=NULL, float* scale_pointer=NULL, unsigned int* absmax_output=NULL, /*bool 
memset_absmax=true,*/ cudaStream_t stream=0, const size_t block_size=256) { + transpose(transposed, input, w, h, descale_pointer, scale_pointer, absmax_output, /*memset_absmax,*/ stream, block_size, copy); +} + +template +void copy_or_transpose(bool transposing, T1 *output, const T2 *input, size_t w, size_t h, float* descale_pointer=NULL, float* scale_pointer=NULL, unsigned int* absmax_output=NULL, /*bool memset_absmax=true,*/ cudaStream_t stream=0, const size_t block_size=0) { + if (transposing) { + transpose(output, input, w, h, descale_pointer, scale_pointer, absmax_output, /*memset_absmax,*/ stream, block_size ? block_size : 256); + } else { + copy_advanced(output, input, w*h, descale_pointer, scale_pointer, absmax_output, /*memset_absmax,*/ stream, block_size ? block_size : 512); + } + cudaCheck(cudaGetLastError()); +} + +template +void update_absmax(TensorGPU inp, bool memset_absmax=false, cudaStream_t stream=main_stream, size_t max_block_size=512) { + size_t N = inp.num_elements; + if (N == 0 || inp.absmax_ptr == NULL) { return; } + + // find the largest block size that divides N + size_t block_size = max_block_size; + while ((N % (block_size * Packed128::size * ABSMAX_ITERATIONS_PER_THREAD)) != 0) { + block_size /= 2; + assert(block_size >= 32); // block size of 1 would be OK, but so inefficient we'd rather fail and debug I think + } + + const dim3 grid_size(CEIL_DIV(N, block_size * ABSMAX_ITERATIONS_PER_THREAD * Packed128::size)); + if (memset_absmax) { + cudaMemset(inp.absmax_ptr, 0, sizeof(unsigned int)); + } + update_absmax_kernel<<>>(inp); + cudaCheck(cudaGetLastError()); +} + +// ---------------------------------------------------------------------------- +// Scratch allocation for FP8 conversions etc. +// todo - consider alternatives (or at least move it somewhere else) + +#include +#include +#include + +class CudaScratchAllocator { +private: + struct Allocation { + void* ptr; + size_t size; + bool in_use; + + Allocation(void* p, size_t s) : ptr(p), size(s), in_use(false) {} + }; + + static std::vector allocations; + static size_t total_allocated; + +public: + template + static T* getMemory(size_t count, bool exact=false) { + size_t size = count * sizeof(T); + + // Find the smallest free allocation that fits the requested size + auto it = std::min_element(allocations.begin(), allocations.end(), + [size](const Allocation& a, const Allocation& b) { + return !a.in_use && a.size >= size && (b.in_use || b.size < size || a.size < b.size); + }); + + if (it != allocations.end() && !it->in_use && it->size >= size && (!exact || it->size == size)) { + it->in_use = true; + return reinterpret_cast(it->ptr); + } + + // If no suitable allocation found, create a new one + void* new_ptr; + cudaMalloc(&new_ptr, size); + allocations.emplace_back(new_ptr, size); + allocations.back().in_use = true; + total_allocated += size; + printf("Allocated CUDA scratch memory: %lu bytes (%p) ==> total allocated: %.1fGiB\n", size, new_ptr, total_allocated / (1024.0 * 1024.0 * 1024.0)); + return reinterpret_cast(new_ptr); + } + + template + static void releaseMemory(T* ptr) { + if (ptr == nullptr) { return; } + auto it = std::find_if(allocations.begin(), allocations.end(), + [ptr](const Allocation& a) { return a.ptr == (void*)ptr; }); + + if (it != allocations.end()) { + it->in_use = false; + } + } + + static void cleanup() { + for (const auto& alloc : allocations) { + cudaFree(alloc.ptr); + } + allocations.clear(); + } +}; +std::vector CudaScratchAllocator::allocations; +size_t 
CudaScratchAllocator::total_allocated = 0; + +// ---------------------------------------------------------------------------- +// Transposed Cache (for FP8 weights) + +#include + +// Custom hash function for std::pair +// todo - why did we need this? complained about default constructor issue? +struct PairHash { + std::size_t operator()(const std::pair& p) const { + return std::hash{}(p.first) ^ (std::hash{}(p.second) << 1); + } +}; + +class TransposedCache { +private: + struct CacheEntry { + void* ptr; + size_t size; + }; + + std::unordered_map, CacheEntry, PairHash> cache; + +public: + TransposedCache() = default; + + template + Tout* getTransposed(const T* original, const void* associatedTensor, size_t m, size_t k, bool compute=true, bool find_only=false, cudaStream_t stream=0) { + uint64_t key1 = reinterpret_cast(original); + uint64_t key2 = reinterpret_cast(associatedTensor); + auto key = std::make_pair(key1, key2); + size_t size = m * k * sizeof(T); + + auto it = cache.find(key); + if (it != cache.end() && it->second.size == size) { + return reinterpret_cast(it->second.ptr); + } + if (find_only) { + return nullptr; + } + + Tout* transposed = CudaScratchAllocator::getMemory(m * k, true); + if (compute) { + copy_or_transpose(true, transposed, original, m, k, nullptr, nullptr, nullptr, stream); + } + + cache[key] = {transposed, size}; + return transposed; + } + + void clearCache() { + for (const auto& entry : cache) { + CudaScratchAllocator::releaseMemory(entry.second.ptr); + } + cache.clear(); + } +}; +TransposedCache g_transposed_cache; + +#endif \ No newline at end of file diff --git a/llmc/cuda_common.h b/llmc/cuda_common.h index c4fb2fd24..c199ca129 100644 --- a/llmc/cuda_common.h +++ b/llmc/cuda_common.h @@ -47,7 +47,7 @@ extern cudaStream_t main_stream; // short-cuts for compile-time boolean values that can be used as function arguments constexpr std::bool_constant True; -constexpr std::bool_constant False; +constexpr std::bool_constant False; // ---------------------------------------------------------------------------- // Error checking diff --git a/llmc/cuda_utils.cuh b/llmc/cuda_utils.cuh index 531f903e4..44e095472 100644 --- a/llmc/cuda_utils.cuh +++ b/llmc/cuda_utils.cuh @@ -3,6 +3,8 @@ #ifndef CUDA_UTILS_CUH #define CUDA_UTILS_CUH +#define FAKE_FP8 + #include "cuda_common.h" struct TensorSpec; // Forward declaration @@ -11,6 +13,24 @@ __device__ __constant__ TensorSpec* tensor_specs_ptr; __device__ __constant__ float* gpu_scale_memory_ptr; __device__ __constant__ unsigned int* gpu_absmax_memory_ptr; +enum TT : uint8_t { + PARAMETER=0, PARAMETER_GRAD, PARAMETER_OPT_M, PARAMETER_OPT_V, PARAMETER_MASTER, // 1 allocation each + ACTIVATIONS_MULTIUSE, // single buffer shared for activations, activation gradients, and scratch + DEFAULT, COUNT=DEFAULT, NUM_TYPES_PARAM=PARAMETER_MASTER+1 +}; + +enum TFlags : uint8_t { + NONE=0, + REUSED_MEMORY=1, + GRADIENT=2, + TENSOR_2D=4, // used for matmul *outputs* only, not inputs (+weights) + BIAS=8, + LAYERNORM=16, + RESIDUAL=32, + EMBEDDING=64, + STATS=128 +}; + // ---------------------------------------------------------------------------- // Packed128 data structure that forces the compiler to use 128-bit loads/stores // in GPUs that support (the LDG.128 and STS.128 instructions) @@ -181,17 +201,16 @@ __device__ __host__ unsigned int get_random_noise(unsigned int seed, unsigned in // stochastic rounding (typicalling using Squirel Noise above to go from a seed to a random number) // new algorithm that calculates distance from rounded 
up/down values to correctly handle denorms // (didn't matter with BF16 because denorms are so tiny they're irrelevant, unlike in FP8/FP16) -template -__device__ __forceinline__ void stochastic_rounding(float in, Ti *out, unsigned int seed, float prob_offset=0.0f) { +template +__device__ void stochastic_rounding(float in, Ti &out, unsigned int random, float prob_offset=0.0f) { if constexpr (std::is_same::value) { - *out = in; - return; + out = in; + return; } - unsigned int random = noise ? get_random_noise(threadIdx.x, blockIdx.x * blockDim.x + blockIdx.y, seed) : seed; // prob_offset allows rounding towards gradient more of the time (one paper recommends that) // e.g. +0.3f ==> 65% chance up, 35% chance down - highp threshold_percentage = ((highp)random / (highp)0xFFFFFFFF) - prob_offset; + float threshold_percentage = ((float)random / (float)0xFFFFFFFF) - prob_offset; Ti rounded_down, rounded_up; if constexpr (std::is_same::value) { @@ -203,48 +222,64 @@ __device__ __forceinline__ void stochastic_rounding(float in, Ti *out, unsigned } else if constexpr (std::is_same::value) { // CUDA doesn't have round down/up instructions for FP8 (in SW or HW) so we do it ourselves // ARM-Intel-NVIDIA style FP8 E4M3 (different for AMD-Graphcore-Qualcomm format!) - Ti rounded = __nv_fp8_e4m3(in); - unsigned char rounded_bits = rounded.__x; - unsigned char absolute_bits = rounded_bits & 127; - unsigned char rounded_up_bits = absolute_bits + 1; - unsigned char rounded_down_bits = absolute_bits - 1; - - // compiler likes the following code atm, but small changes may increase instructions by a lot - // as it may suddenly decide to use branches rather than predication... - if (absolute_bits >= 126) { // maximum normal value (+NaN) - rounded_up_bits = absolute_bits; - if (absolute_bits == 127) { // NaN (not always preserving sign) - rounded_down_bits = 127; - } - } else if (absolute_bits == 0) { // zero - rounded_down_bits = 0; + float low = in; + float high = in; + + if (fabsf(in) < 0.0156f) { + low -= 0.000975f; + high += 0.000975f; } else { - unsigned char mantissa_bits = absolute_bits & 7; - if (mantissa_bits == 7) { // maximum mantissa (already known non-NaN/non-max) - rounded_up_bits = (absolute_bits - mantissa_bits) + 8; // clear mantissa, add 1 to exponent - } else if (mantissa_bits == 0) { // minimum mantissa (already known non-zero) - rounded_down_bits = (absolute_bits + 7) - 8; // max mantissa, subtract 1 from exponent + if (in > 0.0f) { + low *= (15.5f / 16.0f); + high *= (8.5f / 8.0f); + } else { + low *= (8.5f / 8.0f); + high *= (15.5f / 16.0f); } } - if (in < 0) { // negative input: swap rounded up/down and add negative sign - unsigned char swap_tmp = rounded_down_bits | 128; - rounded_down_bits = rounded_up_bits | 128; - rounded_up_bits = swap_tmp; - } - - // rounding to nearest even already gave us 1 of the 2 rounded values surrounding the input - // we only need the other one (but no point skipping anything above given SIMT divergence) - rounded_down.__x = ((float)rounded <= in) ? rounded.__x : rounded_down_bits; - rounded_up.__x = ((float)rounded >= in) ? rounded.__x : rounded_up_bits; - } else if constexpr (std::is_same::value) { - assert(false); // todo - } else { - assert(false); + rounded_up = (__nv_fp8_e4m3)high; + rounded_down = (__nv_fp8_e4m3)low; } - highp diff = (highp)rounded_up - (highp)rounded_down; - highp lerp = ((highp)in - (highp)rounded_down) / diff; // division by 0 is OK as it means (up == down) anyway - *out = (lerp > threshold_percentage) ? 
rounded_up : rounded_down; + float diff = (float)rounded_up - (float)rounded_down; + float lerp = (in - (float)rounded_down) / diff; // division by 0 is OK as it means (up == down) anyway + out = (lerp > threshold_percentage) ? rounded_up : rounded_down; +} + + +// ---------------------------------------------------------------------------- + +// todo - stochastic is bugged, spent hours debugging, no idea why backwards is so broken with it +__device__ float fake_fp8(bool faking, float input, float scale, float descale, bool mode_e5, bool stochastic=false) { +#ifdef FAKE_FP8 + unsigned int random_number; + if (false) { + unsigned int clock, laneid; + asm volatile("mov.u32 %0, %%clock;" : "=r"(clock)); + asm volatile("mov.u32 %0, %%laneid;" : "=r"(laneid)); + random_number = get_random_noise(clock, laneid, blockIdx.x * blockDim.x); + } + + if (faking && scale != 1.0f) { + assert(scale == 1.0f/descale || scale == 1.0f); + if (mode_e5) { + __nv_fp8_e5m2 value_fp8 = __nv_fp8_e5m2(input * scale); + if (false) { + //stochastic_rounding(input * scale, value_fp8, random_number); + } + return ((float)value_fp8) * descale; + + } else { + __nv_fp8_e4m3 value_fp8 = __nv_fp8_e4m3(input * scale); + if (stochastic) { + // BUGGED - spent 6+ hours debuggin this, and at this point, I genuinely suspect a compiler bug *sigh* + //stochastic_rounding(input * scale, value_fp8, random_number); + } + return ((float)value_fp8) * descale; + } + } +#endif + return input; } // ---------------------------------------------------------------------------- @@ -286,6 +321,10 @@ struct TensorGPU { } __device__ __host__ float get_scalar(size_t index, bool disable_scaling=false) const { + #ifdef FAKE_FP8 + disable_scaling = true; + #endif + ElementType* __restrict__ data_ptr_restricted = data_ptr; float* __restrict__ scale_ptr_restricted = scale_descale_ptr; @@ -295,6 +334,10 @@ struct TensorGPU { } __device__ __host__ ElementType set_scalar(size_t index, float value, bool disable_scaling=false) { + #ifdef FAKE_FP8 + disable_scaling = true; + #endif + ElementType* __restrict__ data_ptr_restricted = data_ptr; float* __restrict__ scale_ptr_restricted = scale_descale_ptr; @@ -329,6 +372,9 @@ private: bool wrote_data = false; bool wrote_absmax = false; int id = -1; + // fake fp8 mode + bool faking_fp8 = false; + bool mode_e5 = false; public: bool scaling = true; // todo - fp8 only @@ -339,18 +385,26 @@ public: data_ptr = tensor.data_ptr; id = tensor.id; - if (!disable_scaling) { - float2* __restrict__ ptr_restricted = (float2*)tensor.scale_descale_ptr; - if (tensor.scale_descale_ptr == nullptr) { - assert(false); +#ifdef FAKE_FP8 + if (!disable_scaling && id >= 0 && sizeof(ElementType) == 2 && tensor_specs_ptr[id].tensor_type != TT::PARAMETER_GRAD) { + if (!(tensor_specs_ptr[id].flags & TFlags::RESIDUAL) && !(tensor_specs_ptr[id].flags & TFlags::EMBEDDING)) { + faking_fp8 = true; + if ((tensor_specs_ptr[id].flags & TFlags::GRADIENT) && (tensor_specs_ptr[id].tensor_type == TT::ACTIVATIONS_MULTIUSE)) { + mode_e5 = true; + } } - float2 scale_descale = *ptr_restricted; - scale = scale_descale.x; - descale = scale_descale.y; - absmax_ptr = tensor.absmax_ptr; + } + scaling = false; // only do "fake" scaling +#endif + + if (!disable_scaling) { + const float* __restrict__ ptr_restricted = tensor.scale_descale_ptr; + scale = ptr_restricted[0]; + descale = ptr_restricted[1]; } else { scaling = false; } + absmax_ptr = tensor.absmax_ptr; } __device__ void load(size_t offset, bool cache_streaming=false) { @@ -387,16 +441,20 @@ public: // call 
this manually if e.g. you use set_scalar() to update the tensor // todo - in the future, this could consider more than just absmax - __device__ void add_value_stats(float value, ElementType output) { + __device__ void add_value_stats(float value, ElementType output=(ElementType)0.0f) { new_absmax = max(new_absmax, fabsf(value)); } __device__ float get(int index) { - return (float)data128[index] * (scaling ? descale : 1.0f); + float value = (float)data128[index] * (scaling ? descale : 1.0f); + value = fake_fp8(faking_fp8, value, scale, descale, mode_e5); + return value; } __device__ void set(int index, float value) { - data128[index] = (ElementType)(value * (scaling ? scale : 1.0f)); + float output = value * (scaling ? scale : 1.0f); + output = fake_fp8(faking_fp8, output, scale, descale, mode_e5); + data128[index] = (ElementType)(output); add_value_stats(value, data128[index]); } @@ -423,29 +481,47 @@ public: random_number = get_random_noise(clock, laneid, blockIdx.x * blockDim.x); } - stochastic_rounding(scaled_value, &data128[index], random_number); + stochastic_rounding(scaled_value, data128[index], random_number); add_value_stats(value, data128[index]); } __device__ bool update_absmax(int thread_id, int num_threads, bool exit=false, bool forced=false) { + #ifdef FAKE_FP8 + if (id < 0 || absmax_ptr == NULL) { + return false; + } + forced = true; + #endif + if (!forced && !scaling) { return false; // if we return true, we can skip __syncthreads() in some kernels } wrote_absmax = true; - return false; - - // use native integer reductions as much as possible (supported on all GPUs with FP8) - // this might treat NaN/INF slightly differently but that is the least of our problems - unsigned int absmax_uint = *(unsigned int*)&new_absmax; - asm volatile("redux.sync.max.u32 %0, %0, 0xff;" : "+r"(absmax_uint)); - __shared__ unsigned int shared[32]; // lane_id must be obtained directly from the special register // otherwise, the compiler does silly things related to the redux/atomicMax unsigned int lane_id ; asm volatile("mov.u32 %0, %laneid;" : "=r"(lane_id)); unsigned int num_warps = num_threads >> 5; - unsigned int warp_id = thread_id & 31; + unsigned int warp_id = thread_id >> 5; + + // use native integer reductions as much as possible (supported on all GPUs with FP8) + // this might treat NaN/INF slightly differently but that is the least of our problems + unsigned int absmax_uint = *(unsigned int*)&new_absmax; + __shared__ unsigned int shared[32]; + + + // slow path in case redux causes issues + /*shared[lane_id] = absmax_uint; + __syncwarp(); + for (int i = 0; i < 32; i++) { + absmax_uint = max(absmax_uint, shared[i]); + } + __syncwarp();*/ + asm volatile("redux.sync.max.u32 %0, %0, 0xff;" : "+r"(absmax_uint)); + + + // with this condition instead of lane_id == 0, we have shared[lane_id] both here and below // this reduces the number of instructions for addressing @@ -456,7 +532,7 @@ public: // sync can be after exit (dead threads don't count) but must be before return // if this is the end of the kernel, the compiler puts a conditional EXIT right after BAR // but this way the EXIT is right before the barrier which frees the warps slightly quicker - bool done = (warp_id != 0 || lane_id >= num_warps); + bool done = (warp_id != 0); if (done && exit) asm volatile("exit;"); __syncthreads(); if (done && !exit) return true; @@ -464,47 +540,57 @@ public: // one more warp reduction then global memory atomic // we want as few global atomics as possible (i.e. 
1 per threadblock) absmax_uint = shared[lane_id]; + if (lane_id >= num_warps) { + absmax_uint = 0; + } + + + // slow path in case redux causes issues + /*shared[lane_id] = absmax_uint; + __syncwarp(); + for (int i = 0; i < 32; i++) { + absmax_uint = max(absmax_uint, shared[i]); + } + __syncwarp();*/ asm volatile("redux.sync.max.u32 %0, %0, 0xff;" : "+r"(absmax_uint)); + + + if (lane_id == 0) { atomicMax(absmax_ptr, absmax_uint); } return true; } - __device__ void update_absmax_1D(bool exit=false) { - update_absmax(threadIdx.x & 31, blockDim.x >> 5, exit); + __device__ void update_absmax_auto(int dimensions=1, bool exit=false) { + if (dimensions == 1) { + update_absmax(threadIdx.x, blockDim.x, exit); + } else if (dimensions == 2) { + update_absmax(threadIdx.x + threadIdx.y * blockDim.x, blockDim.x * blockDim.y, exit); + } else if (dimensions == 3) { + update_absmax(threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y, + blockDim.x * blockDim.y * blockDim.z, exit); + } } __device__ void skip_absmax() { wrote_absmax = true; } - template - __device__ void force_precision(bool stochastic=false, int microtensor_scale=false, - int zeroed_mantissa_bits=0, bool two_four_sparsity=false) { - for (int k = 0; k < elements; k++) { - // todo: fancy stuff - if (scaling || scale == 0.0f) { // already scaled - data128[k] = (ElementType)((ForcedType)(data128[k])); - } else { // need to scale & descale - float scaled_value = (float)data128[k] * scaling; - ForcedType converted_value = (ForcedType)scaled_value; - float descaled_value = (float)converted_value * descale; - data128[k] = (ElementType)descaled_value; - } - } - } - __device__ ~tensor128() { // this should ~always be optimised away by the compiler if (!wrote_absmax && scaling && wrote_data) { - printf("id: %d\n", id); + //printf("id: %d\n", id); assert(false); } } }; -template +template __device__ tensor128 new_tensor128(TensorGPU tensor, bool disable_scaling=false) { - return tensor128(tensor, disable_scaling); + if constexpr (init) { + return tensor128(tensor, disable_scaling); + } else { + return tensor128(); + } } template @@ -525,23 +611,6 @@ extern int current_absmax_index; extern float* gpu_scale_memory; extern unsigned int* gpu_absmax_memory; -enum TT : uint8_t { - PARAMETER=0, PARAMETER_GRAD, PARAMETER_OPT_M, PARAMETER_OPT_V, PARAMETER_MASTER, // 1 allocation each - ACTIVATIONS_MULTIUSE, // single buffer shared for activations, activation gradients, and scratch - DEFAULT, COUNT=DEFAULT, NUM_TYPES_PARAM=PARAMETER_MASTER+1 -}; - -enum TFlags : uint8_t { - NONE=0, - REUSED_MEMORY=1, - GRADIENT=2, - TENSOR_2D=4, - BIAS=8, - LAYERNORM=16, - RESIDUAL=32, - EMBEDDING=64, - STATS=128 -}; struct TensorSpec { char* ptr; size_t offset; // into base pointer @@ -590,42 +659,6 @@ struct TensorSpec { // ---------------------------------------------------------------------------- // Copy, cast functions -using elementwise_func_t = float (*) (float); -__device__ float nothing_elementwise(float x) { - return x; -} -template -__global__ void copy_advanced_kernel(TensorGPU in, TensorGPU out) { - constexpr size_t vec_size = 16 / ((sizeof(T1) < sizeof(T2)) ? sizeof(T2) : sizeof(T1)); - size_t adjusted_blockidx = reversed_order ? 
(gridDim.x - blockIdx.x - 1) : blockIdx.x; - size_t idx = (adjusted_blockidx * blockDim.x + threadIdx.x) * vec_size; - if (idx >= in.num_elements) { return; } - - auto inp128 = load_tensor128(in, idx, true, disable_scaling); - auto out128 = new_tensor128(out); - for (int k = 0; k < vec_size; k++) { - float out_fp32 = elementwise_func(inp128.get(k)); - out128.set(k, out_fp32); - } - out128.store_same_length(idx); - out128.update_absmax(threadIdx.x, block_size, true); -} - -// todo - move to GELU etc. -__device__ float gelu_forward_elementwise(float x) { - float cube = 0.044715f * x * x * x; - - float tanh_out; - float tanh_arg = sqrtf(2.0f / M_PI) * (x + cube); - asm ("tanh.approx.f32 %0,%1;" : "=f"(tanh_out) : "f"(tanh_arg)); - - // the following uses FMUL+FMA instead of FMUL+FADD+FMUL for "0.5f * x * (1.0f + tanh_out)" - float half_x = 0.5f * x; - return half_x * tanh_out + half_x; -} - // device functions and the kernel to cast data between types template __device__ Td cast_value(Ts val); diff --git a/llmc/fused_classifier.cuh b/llmc/fused_classifier.cuh index 0a32b1229..f56d0c023 100644 --- a/llmc/fused_classifier.cuh +++ b/llmc/fused_classifier.cuh @@ -93,11 +93,11 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS) // calculate the gradients directly, saves bandwidth from probs during training // but also supports writing probs for inference-only and debugging - auto dlogits128 = new_tensor128(dlogits); + tensor128 dlogits128 = new_tensor128(dlogits, true); for (int i = threadIdx.x; i < V/elements; i += blockDim.x) { // this is the 2nd read of logits after the one in prepare_softmax2 // it will be overwritten by the logits gradients which is when we reduce cache persistence - auto logits128 = load_tensor128(logits, idx * P + i * elements); + auto logits128 = load_tensor128(logits, idx * P + i * elements, false, true); x128 packed_probs; // todo - unused but might be read on CPU in the future so not scaling (???) for(int k = 0; k < elements; ++k) { int element = i*elements + k; @@ -106,7 +106,7 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS) float indicator = (element == ix) ? 
1.0f : 0.0f; dlogits128.set(k, (prob - indicator) * dloss); } - if (WriteDLogits){ + if constexpr (WriteDLogits) { // reduce cache persistence for the overwritten logits // to maximise probability that logits remain in cache between prepare_softmax and here dlogits128.store(idx * P + i * elements, true); @@ -131,7 +131,9 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS) probs[idx * P + i] = (floatX)prob; } } - dlogits128.update_absmax(threadIdx.x, blockDim.x, true); + if constexpr (WriteDLogits) { + dlogits128.update_absmax(threadIdx.x, blockDim.x, true); + } } // ---------------------------------------------------------------------------- diff --git a/llmc/gelu.cuh b/llmc/gelu.cuh index 34cd9749f..d58825549 100644 --- a/llmc/gelu.cuh +++ b/llmc/gelu.cuh @@ -55,7 +55,7 @@ __global__ void gelu_backward_kernel(tensorFP8e5 dinp, tensorFP8e5 dout, TensorG float sech_out = 1.0f - (tanh_in_out * tanh_in_out); float local_grad = 0.5f * ((1.0f + tanh_in_out) + x * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x)); - float result = local_grad * (float)dout128.get(k); + float result = local_grad * dout128.get(k); dinp128.set(k, result); } dinp128.store_same_length(idx, false); diff --git a/llmc/layernorm.cuh b/llmc/layernorm.cuh index 130320a4e..f8b166622 100644 --- a/llmc/layernorm.cuh +++ b/llmc/layernorm.cuh @@ -266,9 +266,8 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with // if we did actually update the absmax (returns true), we already did __syncthreads() here if (!dinp_new128.update_absmax(threadIdx.x, BLOCK_SIZE, false)) { - //__syncthreads(); + __syncthreads(); } - __syncthreads(); // Each block writes its partial sum to global memory // The last block to finish becomes responsible for summing up all the partial sums diff --git a/llmc/matmul.cuh b/llmc/matmul.cuh index a338f0c57..d413b12c4 100644 --- a/llmc/matmul.cuh +++ b/llmc/matmul.cuh @@ -10,6 +10,9 @@ Matrix Multiplication, with help from cuBLASLt // GELU can be either fused (cublasLt) or non-fused (gelu.h) #include "gelu.cuh" +// todo - does this need to be included globally? +#include "copy_and_fp8.h" + // ---------------------------------------------------------------------------- // CUDA kernels @@ -181,7 +184,7 @@ void matmul_cublaslt(tensorX d, const tensorX a, const tensorX b, const tensorX assert(!has_bias); // we shouldn't have any backward matmuls that use both GELU and bias epilogue = CUBLASLT_EPILOGUE_DGELU; if (pre_gelu.scale_descale_ptr) { // descale input - float* gelu_descale_ptr = pre_gelu.scale_descale_ptr + 1; + //float* gelu_descale_ptr = pre_gelu.scale_descale_ptr + 1; //cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_SCALE_POINTER, &gelu_descale_ptr, sizeof(float*))); } } else { @@ -243,6 +246,8 @@ void matmul_cublaslt(tensorX d, const tensorX a, const tensorX b, const tensorX &alpha, a, ALayout, b, BLayout, &beta, d, CLayout, d, DLayout, &heuristic.algo, cublaslt_workspace, cublaslt_workspace_size, stream)); + update_absmax(d, false, stream); + // cleanups cublasCheck(cublasLtMatmulPreferenceDestroy(preference)); cublasCheck(cublasLtMatmulDescDestroy(operationDesc)); diff --git a/train_gpt2.cu b/train_gpt2.cu index 1fec4c1ac..80a4a31c2 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -2,7 +2,7 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage. 
*/ #define UNIQUE_TENSOR_MEMORY false -#define LAYERS_PER_ACTIVATION_CHECKPOINT 1 // 0 = disabled +#define LAYERS_PER_ACTIVATION_CHECKPOINT 0 // 0 = disabled #include #include @@ -88,6 +88,45 @@ TensorGPU null_tensorFP32 = {0}; // buffer size to use for device <-> disk io constexpr const size_t IO_BUF_SIZE = 32 * 1024 * 1024; +// todo - move this +__global__ void update_scale_descale_kernel(float* gpu_scale_memory, unsigned int* gpu_absmax_memory, int num_tensor_specs) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= num_tensor_specs) return; + + // Get the absmax value for this tensor + unsigned int absmax_uint = gpu_absmax_memory[tid]; + float absmax = __uint_as_float(absmax_uint); + + // Calculate scale and descale + if (absmax == 0.0f) { + absmax = 1.0f; + } + float scale = 1.0f / absmax; + float descale = absmax; + + if (!(tensor_specs_ptr[tid].flags & TFlags::RESIDUAL) && !(tensor_specs_ptr[tid].flags & TFlags::EMBEDDING) && absmax != 1.0f) { + if ((tensor_specs_ptr[tid].flags & TFlags::GRADIENT) && (tensor_specs_ptr[tid].tensor_type == TT::ACTIVATIONS_MULTIUSE)) { + // e5 + scale *= 32768.0f; + descale *= 1.0f/32768.0f; + } else { + // e4 + scale *= 256.0f; + descale *= (1.0f/256.0f); + } + } else { + scale = 1.0f; + descale = 1.0f; + } + + // todo: circular buffer + //gpu_absmax_memory[tid] = 0.0f; + + // Update gpu_scale_memory + gpu_scale_memory[tid * 2] = scale; + gpu_scale_memory[tid * 2 + 1] = descale; +} + // ---------------------------------------------------------------------------- // GPT-2 model definition @@ -388,89 +427,122 @@ void gpt2_allocate(GPT2 *model) { reuse_every_n = LAYERS_PER_ACTIVATION_CHECKPOINT; assert(!reuse_every_n || (L % reuse_every_n) == 0); - TENSOR_SPECS (encoded, 1, BTC, 0); - TENSOR_SPECS (lnf, 1, BTC, 0); - TENSOR_SPECS_FP32(lnf_mean, 1, BT, LAYERNORM | STATS); - TENSOR_SPECS_FP32(lnf_rstd, 1, BT, LAYERNORM | STATS); - TENSOR_SPECS_FP32(losses, 1, BT, 0); - - TENSOR_SPECS_FP32(ln1_mean, L, BT, LAYERNORM | STATS); - TENSOR_SPECS_FP32(ln1_rstd, L, BT, LAYERNORM | STATS); - TENSOR_SPECS (atty, L, BTC, 0); - TENSOR_SPECS (residual2, L, BTC, RESIDUAL); - TENSOR_SPECS_FP32(ln2_mean, L, BT, LAYERNORM | STATS); - TENSOR_SPECS_FP32(ln2_rstd, L, BT, LAYERNORM | STATS); - TENSOR_SPECS_LOWP(fch, L, 4 * BTC, 0); - TENSOR_SPECS (qkvr, L, 3 * BTC, 0); + TENSOR_SPECS (encoded, 1, BTC, EMBEDDING); + TENSOR_SPECS (qkvr, L, 3 * BTC, TENSOR_2D); #ifdef ENABLE_CUDNN TENSOR_SPECS_FP32(att, L, NH * B * T, 0); #else TENSOR_SPECS (att, L, NH * B * T * T, 0); #endif - - if (UNIQUE_TENSOR_MEMORY) { - TENSOR_SPECS (output, 1, output_size, 0); - TENSOR_SPECS_LOWP(fcproj, L, BTC, 0); - TENSOR_SPECS_LOWP(attproj, L, BTC, 0); - } else { - spec->output = add_tensor_spec("output", output_size, shards, dtype, model->multiuse.output_scratch, REUSED_MEMORY); - spec->fcproj = add_layer_specs(L, "fcproj", BTC, shards, dtype_lowp, model->multiuse.btc, REUSED_MEMORY); - spec->attproj = add_layer_specs(L, "attproj", BTC, shards, dtype_lowp, model->multiuse.btc, REUSED_MEMORY); - } + TENSOR_SPECS (atty, L, BTC, 0); + TENSOR_SPECS (residual2, L, BTC, RESIDUAL); + TENSOR_SPECS_LOWP(fch, L, 4 * BTC, TENSOR_2D); // optionally reuse the same activation buffer at each layer and re-compute the gelu during backward // very useful because we dramatically reduce VRAM usage, and may be able to fit larger batch size if (model->recompute < 1 || UNIQUE_TENSOR_MEMORY) { + TENSOR_SPECS(fch_gelu, L, 4 * BTC, 0); TENSOR_SPECS(ln1, L, BTC, LAYERNORM); TENSOR_SPECS(ln2, L, BTC, LAYERNORM); - 
TENSOR_SPECS(fch_gelu, L, 4 * BTC, 0); } else if (model->recompute < 2) { + spec->fch_gelu = add_layer_specs(L, "fch_gelu", 4 * BTC, shards, dtype_lowp, model->multiuse.output_scratch, REUSED_MEMORY); TENSOR_SPECS(ln1, L, BTC, LAYERNORM); TENSOR_SPECS(ln2, L, BTC, LAYERNORM); - spec->fch_gelu = add_layer_specs(L, "fch_gelu", 4 * BTC, shards, dtype_lowp, model->acts.output, REUSED_MEMORY); } else { + spec->fch_gelu = add_layer_specs(L, "fch_gelu", 4 * BTC, shards, dtype_lowp, model->multiuse.output_scratch, REUSED_MEMORY); spec->ln1 = add_layer_specs(L, "ln1", BTC, shards, dtype, model->multiuse.btc, LAYERNORM | REUSED_MEMORY); spec->ln2 = add_layer_specs(L, "ln2", BTC, shards, dtype, model->multiuse.btc, LAYERNORM | REUSED_MEMORY); - spec->fch_gelu = add_layer_specs(L, "fch_gelu", 4 * BTC, shards, dtype_lowp, model->acts.output, REUSED_MEMORY); } + TENSOR_SPECS_FP32(ln1_mean, L, BT, LAYERNORM | STATS); + TENSOR_SPECS_FP32(ln1_rstd, L, BT, LAYERNORM | STATS); + TENSOR_SPECS_FP32(ln2_mean, L, BT, LAYERNORM | STATS); + TENSOR_SPECS_FP32(ln2_rstd, L, BT, LAYERNORM | STATS); + + if (UNIQUE_TENSOR_MEMORY) { + TENSOR_SPECS_LOWP(attproj, L, BTC, TENSOR_2D); + TENSOR_SPECS_LOWP(fcproj, L, BTC, TENSOR_2D); + TENSOR_SPECS (output, 1, output_size, TENSOR_2D | EMBEDDING); + } else { + spec->attproj = add_layer_specs(L, "attproj", BTC, shards, dtype_lowp, model->multiuse.btc, REUSED_MEMORY | TENSOR_2D); + spec->fcproj = add_layer_specs(L, "fcproj", BTC, shards, dtype_lowp, model->multiuse.btc, REUSED_MEMORY | TENSOR_2D); + spec->output = add_tensor_spec("output", output_size, shards, dtype, model->multiuse.output_scratch, REUSED_MEMORY | EMBEDDING | TENSOR_2D); + } + + TENSOR_SPECS (lnf, 1, BTC, LAYERNORM); + TENSOR_SPECS_FP32(lnf_mean, 1, BT, LAYERNORM | STATS); + TENSOR_SPECS_FP32(lnf_rstd, 1, BT, LAYERNORM | STATS); + TENSOR_SPECS_FP32(losses, 1, BT, 0); + + + + if(model->recompute >= 1) { // recompute >= 1 means we recompute gelu + gelu_forward(ACT(fch_gelu), ACT(fch)); + } + matmul_backward(AGRAD(fch), PGRAD(fcprojw), PGRAD(fcprojb), AGRAD(residual3), ACT(fch_gelu), PARAM(fcprojw), scratchF, B*T, 4*C, C, ACT(fch), model->gelu_fusion); + + if(model->recompute >= 2) { // recompute >= 2 means we recompute layernorm + layernorm_forward(ACT(ln2), ACT(ln2_mean), ACT(ln2_rstd), ACT(residual2), PARAM(ln2w), PARAM(ln2b), B*T, C); + } + matmul_backward(AGRAD(ln2), PGRAD(fcw), PGRAD(fcb), AGRAD(fch), ACT(ln2), PARAM(fcw), scratchF, B*T, C, 4 * C); + layernorm_backward(AGRAD(residual2), AGRAD(residual3), PGRAD(ln2w), PGRAD(ln2b), scratchF, AGRAD(ln2), ACT(residual2), PARAM(ln2w), ACT(ln2_mean), ACT(ln2_rstd), B*T, C); + matmul_backward(AGRAD(atty), PGRAD(attprojw), PGRAD(attprojb), AGRAD(residual2), ACT(atty), PARAM(attprojw), scratchF, B*T, C, C); + + #ifdef ENABLE_CUDNN + attention_backward_cudnn(AGRAD(qkvr), AGRAD(atty), ACT(qkvr), ACT(atty), ACT(att), B, T, NH, C); + #else + // we need B x T x (4)C buffers. 
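+    // (editor's note, not part of the original patch) ACT(atty) is (B,T,C) and ACT(fch) is (B,T,4C);
+    // both were last read by the matmul backward calls above, so the non-cuDNN attention_backward
+    // below can safely overwrite them as scratch instead of allocating dedicated buffers.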
l_atty and l_fch aren't needed anymore at this point, so reuse their memory + floatX* buffer_a = ACT(atty); + floatX* buffer_b = ACT(fch); + attention_backward(AGRAD(qkvr), buffer_b, scratchX_HUGE, buffer_a, AGRAD(atty), ACT(qkvr), ACT(att), B, T, C, NH); + #endif + + if(model->recompute >= 2) { + layernorm_forward(ACT(ln1), ACT(ln1_mean), ACT(ln1_rstd), residual, PARAM(ln1w), PARAM(ln1b), B*T, C); + } + matmul_backward(AGRAD(ln1), PGRAD(qkvw), PGRAD(qkvb), AGRAD(qkvr), ACT(ln1), PARAM(qkvw), scratchF, B*T, C, 3 * C); + layernorm_backward(dresidual, AGRAD(residual2), PGRAD(ln1w), PGRAD(ln1b), scratchF, AGRAD(ln1), residual, PARAM(ln1w), ACT(ln1_mean), ACT(ln1_rstd), B*T, C); + + + // 4) activation gradients - // todo - specify subtype! + // note: TENSOR_2D are for the tensors written to by a matmul which are different here + // todo - is "LAYERNORM" applied logically here? do we care? reuse_every_n = 0; spec = &model->acts_grads; dtype_lowp = DTYPE_FLOATX; // todo FP8 shards = 1; if (UNIQUE_TENSOR_MEMORY) { - TENSOR_SPECS(encoded, 1, BTC, GRADIENT); - TENSOR_SPECS(output, 1, output_size, GRADIENT); - TENSOR_SPECS(lnf, 1, BTC, GRADIENT | LAYERNORM); - TENSOR_SPECS(ln1, L, BTC, GRADIENT | LAYERNORM); - TENSOR_SPECS(atty, L, BTC, GRADIENT); + TENSOR_SPECS(encoded, 1, BTC, GRADIENT | EMBEDDING); + TENSOR_SPECS(output, 1, output_size, GRADIENT | EMBEDDING); + TENSOR_SPECS(lnf, 1, BTC, GRADIENT | LAYERNORM | TENSOR_2D); + TENSOR_SPECS(ln1, L, BTC, GRADIENT | LAYERNORM | TENSOR_2D); + TENSOR_SPECS(atty, L, BTC, GRADIENT | TENSOR_2D); TENSOR_SPECS(residual2, L, BTC, GRADIENT | RESIDUAL); - TENSOR_SPECS(ln2, L, BTC, GRADIENT | LAYERNORM); - TENSOR_SPECS(fch, L, 4 * BTC, GRADIENT); - TENSOR_SPECS(fch_gelu, L, 4 * BTC, GRADIENT); + TENSOR_SPECS(ln2, L, BTC, GRADIENT | LAYERNORM | TENSOR_2D); + TENSOR_SPECS(fch, L, 4 * BTC, GRADIENT | TENSOR_2D); + TENSOR_SPECS(fch_gelu, L, 4 * BTC, GRADIENT | TENSOR_2D); TENSOR_SPECS(residual3, L, BTC, GRADIENT | RESIDUAL); TENSOR_SPECS(qkvr, L, 3 * BTC, GRADIENT); } else { - spec->output = add_layer_specs(1, "output", output_size, 1, dtype, model->multiuse.output_scratch, GRADIENT); + spec->output = add_layer_specs(1, "output", output_size, 1, dtype, model->multiuse.output_scratch, GRADIENT | EMBEDDING); int reused_btc = model->acts.residual3 + (L-1); // todo - check if this works with activation checkpointing - spec->ln1 = add_layer_specs(L, "ln1", BTC, 1, dtype, reused_btc, GRADIENT | LAYERNORM); - spec->ln2 = add_layer_specs(L, "ln2", BTC, 1, dtype, reused_btc, GRADIENT | LAYERNORM); - spec->atty = add_layer_specs(L, "atty", BTC, 1, dtype, reused_btc, GRADIENT); + spec->ln1 = add_layer_specs(L, "ln1", BTC, 1, dtype, reused_btc, GRADIENT | LAYERNORM | TENSOR_2D); + spec->ln2 = add_layer_specs(L, "ln2", BTC, 1, dtype, reused_btc, GRADIENT | LAYERNORM | TENSOR_2D); + spec->atty = add_layer_specs(L, "atty", BTC, 1, dtype, reused_btc, GRADIENT | TENSOR_2D); int reused_btc2 = model->acts.lnf; spec->residual2 = add_layer_specs(L, "residual2", BTC, 1, dtype, reused_btc2, GRADIENT | RESIDUAL); spec->residual3 = add_layer_specs(L, "residual3", BTC, 1, dtype, reused_btc2, GRADIENT | RESIDUAL); - spec->encoded = add_layer_specs(1, "encoded", BTC, 1, dtype, reused_btc2, GRADIENT); + spec->encoded = add_layer_specs(1, "encoded", BTC, 1, dtype, reused_btc2, GRADIENT | EMBEDDING); // (lnf doesn't need bt4c but it's free at this point unlike the other buffers) - spec->lnf = add_layer_specs(1, "lnf", BTC, 1, dtype, model->multiuse.bt4c, GRADIENT | LAYERNORM); - spec->fch_gelu = 
add_layer_specs(L, "fch_gelu", 4 * BTC, 1, dtype, model->multiuse.bt4c, GRADIENT); - spec->fch = add_layer_specs(L, "fch", 4 * BTC, 1, dtype, model->multiuse.bt4c, GRADIENT); - spec->qkvr = add_layer_specs(L, "qkvr", 3 * BTC, 1, dtype, model->multiuse.bt4c, GRADIENT); + spec->lnf = add_layer_specs(1, "lnf", BTC, 1, dtype, model->multiuse.bt4c, GRADIENT | LAYERNORM | TENSOR_2D); + spec->fch_gelu = add_layer_specs(L, "fch_gelu", 4 * BTC, 1, dtype, model->multiuse.bt4c, GRADIENT | TENSOR_2D); + spec->fch = add_layer_specs(L, "fch", 4 * BTC, 1, dtype, model->multiuse.bt4c, GRADIENT | TENSOR_2D); + spec->qkvr = add_layer_specs(L, "qkvr", 3 * BTC, 1, dtype, model->multiuse.bt4c, GRADIENT | TENSOR_2D); } // allocate a single huge GPU buffer for all the tensors of a given type @@ -818,7 +890,7 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { NvtxRange layer_range("Layer", l); tensorX residual = (l == 0) ? ACT(encoded) : ACT_L(residual3, l-1); - tensorX qkvr = MULTI(output_scratch); // non-cudnn reuses tensor with different memory pre/post-permute + tensorX qkvr = ACT(qkvr); // non-cudnn reuses tensor with different memory pre/post-permute qkvr.data_ptr = CUDNN_ENABLED ? ACT(qkvr) : MULTI(output_scratch); matmul_forward_cublaslt(qkvr, ACT(ln1), PARAM(qkvw), PARAM(qkvb), B*T, C, 3*C); #ifdef ENABLE_CUDNN @@ -1196,6 +1268,11 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo } */ + // todo - hack - update scale/descale from absmax + int absmax_block_size = 256; + int num_blocks = (num_tensor_specs + absmax_block_size - 1) / block_size; + update_scale_descale_kernel<<>>(gpu_scale_memory, gpu_absmax_memory, num_tensor_specs); + cudaCheck(cudaDeviceSynchronize()); } @@ -1779,6 +1856,9 @@ int main(int argc, char *argv[]) { // in any case, this must be true or we'd index beyond the model's wpe (position embedding table) assert(T <= model.config.max_seq_len); + // todo - hack - do this to update the absmax of all the weights + gpt2_update(&model, 0.0f, 0.9f, 0.95f, 1e-8f, 1.0f, 1.0f, 1, &multi_gpu_config); + // train cudaEvent_t start, end; cudaCheck(cudaEventCreate(&start)); @@ -1926,6 +2006,14 @@ int main(int argc, char *argv[]) { // clip the gradient norm to a maximum value float grad_clip = 1.0f; float grad_scale = (grad_norm > grad_clip) ? 
grad_clip / grad_norm : 1.0f; + + // todo - hack - because the 1st step is now kinda useless due to FP8 absmax scaling not being ready + // todo - ideally should rerun this step so we don't "waste" the data without training on it + if (step == 0) { + step_learning_rate = 0.0f; + weight_decay = 1.0f; + } + gpt2_update(&model, step_learning_rate, 0.9f, 0.95f, 1e-8f, weight_decay, grad_scale, step+1, &multi_gpu_config); } cudaCheck(cudaEventRecord(end)); From d2b3e82b30cd36c56ae85047c3767c36000de426 Mon Sep 17 00:00:00 2001 From: ademeure Date: Mon, 16 Sep 2024 10:56:04 +0000 Subject: [PATCH 14/27] compilation fix --- train_gpt2.cu | 32 -------------------------------- 1 file changed, 32 deletions(-) diff --git a/train_gpt2.cu b/train_gpt2.cu index 80a4a31c2..2809fa008 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -473,38 +473,6 @@ void gpt2_allocate(GPT2 *model) { TENSOR_SPECS_FP32(lnf_rstd, 1, BT, LAYERNORM | STATS); TENSOR_SPECS_FP32(losses, 1, BT, 0); - - - if(model->recompute >= 1) { // recompute >= 1 means we recompute gelu - gelu_forward(ACT(fch_gelu), ACT(fch)); - } - matmul_backward(AGRAD(fch), PGRAD(fcprojw), PGRAD(fcprojb), AGRAD(residual3), ACT(fch_gelu), PARAM(fcprojw), scratchF, B*T, 4*C, C, ACT(fch), model->gelu_fusion); - - if(model->recompute >= 2) { // recompute >= 2 means we recompute layernorm - layernorm_forward(ACT(ln2), ACT(ln2_mean), ACT(ln2_rstd), ACT(residual2), PARAM(ln2w), PARAM(ln2b), B*T, C); - } - matmul_backward(AGRAD(ln2), PGRAD(fcw), PGRAD(fcb), AGRAD(fch), ACT(ln2), PARAM(fcw), scratchF, B*T, C, 4 * C); - layernorm_backward(AGRAD(residual2), AGRAD(residual3), PGRAD(ln2w), PGRAD(ln2b), scratchF, AGRAD(ln2), ACT(residual2), PARAM(ln2w), ACT(ln2_mean), ACT(ln2_rstd), B*T, C); - matmul_backward(AGRAD(atty), PGRAD(attprojw), PGRAD(attprojb), AGRAD(residual2), ACT(atty), PARAM(attprojw), scratchF, B*T, C, C); - - #ifdef ENABLE_CUDNN - attention_backward_cudnn(AGRAD(qkvr), AGRAD(atty), ACT(qkvr), ACT(atty), ACT(att), B, T, NH, C); - #else - // we need B x T x (4)C buffers. l_atty and l_fch aren't needed anymore at this point, so reuse their memory - floatX* buffer_a = ACT(atty); - floatX* buffer_b = ACT(fch); - attention_backward(AGRAD(qkvr), buffer_b, scratchX_HUGE, buffer_a, AGRAD(atty), ACT(qkvr), ACT(att), B, T, C, NH); - #endif - - if(model->recompute >= 2) { - layernorm_forward(ACT(ln1), ACT(ln1_mean), ACT(ln1_rstd), residual, PARAM(ln1w), PARAM(ln1b), B*T, C); - } - matmul_backward(AGRAD(ln1), PGRAD(qkvw), PGRAD(qkvb), AGRAD(qkvr), ACT(ln1), PARAM(qkvw), scratchF, B*T, C, 3 * C); - layernorm_backward(dresidual, AGRAD(residual2), PGRAD(ln1w), PGRAD(ln1b), scratchF, AGRAD(ln1), residual, PARAM(ln1w), ACT(ln1_mean), ACT(ln1_rstd), B*T, C); - - - - // 4) activation gradients // note: TENSOR_2D are for the tensors written to by a matmul which are different here // todo - is "LAYERNORM" applied logically here? do we care? 
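[Editor's note, not part of the patch series] The patch above derives each tensor's FP8 scale/descale pair
from its recorded absmax in update_scale_descale_kernel, which is also why step 0 now runs with a zero
learning rate: the scales only become valid once a forward/backward pass has populated the absmax buffers.
A minimal host-side sketch of the same scaling rule (assumed helper name, not code from the patches):

    #include <utility>

    // Hedged sketch mirroring update_scale_descale_kernel: E4M3 tensors are scaled so that absmax maps to
    // ~256, E5M2 activation-gradient tensors to ~32768, and residual/embedding tensors stay unscaled.
    std::pair<float, float> scale_descale_from_absmax(float absmax, bool is_e5_gradient, bool skip_scaling) {
        if (skip_scaling || absmax == 0.0f) {
            return {1.0f, 1.0f};                     // unscaled (or no absmax recorded yet)
        }
        float target  = is_e5_gradient ? 32768.0f : 256.0f;
        float scale   = target / absmax;             // applied before the cast to FP8
        float descale = absmax / target;             // applied after the cast back to float
        return {scale, descale};
    }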
From b94c3b70bf1edeffb871b81c43cc61eb91331285 Mon Sep 17 00:00:00 2001 From: ademeure Date: Mon, 16 Sep 2024 15:33:41 +0000 Subject: [PATCH 15/27] moved tensor functionality into tensor.cuh, added transpose_simple kernel (no format conversion, no elementwise) --- llmc/copy_and_fp8.h | 436 +++----------------------------- llmc/cuda_utils.cuh | 430 ++----------------------------- llmc/tensor.cuh | 599 ++++++++++++++++++++++++++++++++++++++++++++ train_gpt2.cu | 327 +++++------------------- 4 files changed, 709 insertions(+), 1083 deletions(-) create mode 100644 llmc/tensor.cuh diff --git a/llmc/copy_and_fp8.h b/llmc/copy_and_fp8.h index 4cdbb556f..6092db07f 100644 --- a/llmc/copy_and_fp8.h +++ b/llmc/copy_and_fp8.h @@ -58,376 +58,48 @@ __global__ void copy_advanced_kernel(TensorGPU in, TensorGPU out) { out128.update_absmax(threadIdx.x, block_size, true); } -/* // transpose + copy + format conversion (+ elementwise + absmax) kernel -template -__global__ void transpose_kernel(T1* __restrict__ transposed, T1* __restrict__ copy, const T2* __restrict__ input, - const float* __restrict__ descale_pointer=(float*)NULL, const float* __restrict__ scale_pointer=(float*)NULL, - unsigned int* absmax_output=(unsigned int*)NULL, const void** meta=NULL) +template +__global__ void transpose_simple_kernel(T1* __restrict__ transposed, const T1* __restrict__ input, int height) { - constexpr size_t TILE_DIM_PADDED = TILE_DIM + 4/sizeof(T1); - __shared__ T1 tile[TILE_DIM][TILE_DIM_PADDED]; - int width = gridDim.x * TILE_DIM; - int height = gridDim.y * TILE_DIM; - - constexpr size_t T1_elements = 16 / sizeof(T1); - constexpr size_t T2_elements = 16 / sizeof(T2); - constexpr size_t copy_vectors = (sizeof(T1) >= sizeof(T2)) ? (sizeof(T1) / sizeof(T2)) : 1; - - float descale_factor = (scaling && descale_pointer) ? *descale_pointer : 1.0f; // never reciprocal - float scale_factor = (scaling && scale_pointer) ? *scale_pointer : 1.0f; - scale_factor = (reciprocal_scale && scale_factor != 0.0f) ? 
(1.0f / scale_factor) : scale_factor; - int x = blockIdx.x * TILE_DIM + (threadIdx.x * T2_elements); - int y = blockIdx.y * TILE_DIM + threadIdx.y; - uint absmax_uint = 0; - - #pragma unroll - for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) { - Packed128 in128 = load128cs(input + x + (y+j)*width); - Packed128 copy128[copy_vectors]; - for (int k = 0; k < in128.size; k++) { - T2 in = in128[k]; - float out_float = elementwise_func((float)in * descale_factor); - update_local_absmax(absmax_uint, out_float, absmax_factor); // optional absmax - - T1 out = (T1)(out_float * scale_factor); - copy128[k/T1_elements][k%T1_elements] = out; // optimised away by compiler if unused - } - - for (int o = 0; o < copy_vectors; o++) { - if constexpr (enable_copy) { - store_same_length(copy + x + (y+j)*width + o*T1_elements, copy128[o]); - } - - size_t tile_offset = (threadIdx.x * T2_elements) + (threadIdx.y+j)*TILE_DIM_PADDED + o*T1_elements; - int* one_bank = reinterpret_cast(&tile[0][0] + tile_offset); - for (int k = 0; k < 4; k++) { - one_bank[k] = *(int*)(©128[o][k*4/sizeof(T1)]); - } - //store_same_length(&tile[0][0] + tile_offset, copy128[o]); - } - } - - if constexpr (absmax_factor != 0) { - update_global_absmax(absmax_output, absmax_uint); - } else { - __syncthreads(); - } - - // reduce the number of threads for the write if T1_elements > T2_elements - // we want to keep all 32 threads in a warp active, so we try to eliminate in y dimension first - // so we create fake/adjusted tid.x/tid.y where "extra" threadIdx.x adds to the effective tid.y - constexpr size_t block_size_x = (TILE_DIM * sizeof(T2)) / 16; - constexpr size_t block_size_y = BLOCK_ROWS; - - constexpr size_t desired_ratio = (sizeof(T2) >= sizeof(T1)) ? (sizeof(T2) / sizeof(T1)) : 1; - constexpr size_t ratio = (desired_ratio <= block_size_y) ? 
desired_ratio : block_size_y; - constexpr size_t block_size_x_div_r = block_size_x / ratio; - constexpr size_t block_size_y_div_r = block_size_y / ratio; - - int adjusted_tid_x = threadIdx.x % block_size_x_div_r; - int adjusted_tid_y = (threadIdx.y * ratio) + (threadIdx.x / block_size_x_div_r); - if (threadIdx.y >= block_size_y_div_r) { return; } - - // if we cannot reduce block_size.y enough, also reduce x (hurting perf with partial warps) - if (ratio != desired_ratio && adjusted_tid_x >= TILE_DIM / T1_elements) { return; } - - // x/y for final write to global memory - x = blockIdx.y * TILE_DIM + adjusted_tid_x * T1_elements; - y = blockIdx.x * TILE_DIM + adjusted_tid_y; - - constexpr int in_parallel = 4/sizeof(T1); - - #pragma unroll - for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS * in_parallel) { - if ((j+adjusted_tid_y) * in_parallel >= TILE_DIM) { return; } - - // we need more instructions for the write than the read if T2_elements > T1_elements - #pragma unroll - for (int o = 0; o < copy_vectors; o++) { - Packed128 out128[in_parallel]; - #pragma unroll - for (int k = 0; k < Packed128::size; k++) { - int in32 = *(int*)(&tile[k + (adjusted_tid_x + o * blockDim.x) * Packed128::size][(adjusted_tid_y + j) * in_parallel]); - for (int p = 0; p < in_parallel; p++) { - out128[p][k] = ((T1*)&in32)[p]; - } - } - for (int p = 0; p < in_parallel; p++) { - store128(transposed + x + (o * blockDim.x * Packed128::size) + (y+p + j * in_parallel)*height, out128[p]); - } - } - } -} -*/ - -/* -template -__global__ void transpose_kernel_tensor(TensorGPU transposed, TensorGPU copy, TensorGPU input, int height) { __shared__ T1 tile[TILE_DIM][TILE_DIM]; int width = gridDim.x * TILE_DIM; height = gridDim.y * TILE_DIM; - constexpr bool disable_scaling = (sizeof(T1) == sizeof(T2)); // TODO - THIS IS WRONG - need to check types are identical, not just same size! - constexpr size_t T1_elements = 16 / sizeof(T1); - constexpr size_t T2_elements = 16 / sizeof(T2); - constexpr size_t copy_vectors = (sizeof(T1) >= sizeof(T2)) ? (sizeof(T1) / sizeof(T2)) : 1; - - int x = blockIdx.x * TILE_DIM + (threadIdx.x * T2_elements); + constexpr size_t elements = 16 / sizeof(T1); + int x = blockIdx.x * TILE_DIM + (threadIdx.x * elements); int y = blockIdx.y * TILE_DIM + threadIdx.y; - tensor128 copy128 = new_tensor128(copy, disable_scaling); - #pragma unroll for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) { - auto in128 = load_tensor128(input, x + (y+j)*width, true, disable_scaling); - Packed128 in128 = load128cs(input + x + (y+j)*width); - Packed128 copy128[copy_vectors]; - for (int k = 0; k < in128.size; k++) { - float out_float = elementwise_func(in128.get(k)); - copy128.set(k % T1_elements, out_float * scale_factor); // optimised away by compiler if unused - - if (k+1 == out128.size) { - // ... 
- - } - } - - for (int o = 0; o < copy_vectors; o++) { - if constexpr (enable_copy) { - store_same_length(copy + x + (y+j)*width + o*T1_elements, copy128[o]); - } - size_t tile_offset = (threadIdx.x * T2_elements) + (threadIdx.y+j)*TILE_DIM + o*T1_elements; - store_same_length(&tile[0][0] + tile_offset, copy128[o]); - } + Packed128 in128 = load128cs(input + x + (y+j)*width); + size_t tile_offset = (threadIdx.x * elements) + (threadIdx.y+j)*TILE_DIM; + store128(&tile[0][0] + tile_offset, in128); } + __syncthreads(); - -} -*/ - - - - - -// transpose + copy + format conversion (+ elementwise + absmax) kernel -template -__global__ void transpose_kernel(T1* __restrict__ transposed, T1* __restrict__ copy, const T2* __restrict__ input, int height, - const float* __restrict__ descale_pointer=(float*)NULL, const float* __restrict__ scale_pointer=(float*)NULL, - unsigned int* absmax_output=(unsigned int*)NULL, const void** meta=NULL) -{ - /* - __shared__ T1 tile[TILE_DIM][TILE_DIM]; - int width = gridDim.x * TILE_DIM; - height = gridDim.y * TILE_DIM; - - constexpr size_t T1_elements = 16 / sizeof(T1); - constexpr size_t T2_elements = 16 / sizeof(T2); - constexpr size_t copy_vectors = (sizeof(T1) >= sizeof(T2)) ? (sizeof(T1) / sizeof(T2)) : 1; - - float descale_factor = (scaling && descale_pointer) ? *descale_pointer : 1.0f; // never reciprocal - float scale_factor = (scaling && scale_pointer) ? *scale_pointer : 1.0f; - scale_factor = (reciprocal_scale && scale_factor != 0.0f) ? (1.0f / scale_factor) : scale_factor; - int x = blockIdx.x * TILE_DIM + (threadIdx.x * T2_elements); - int y = blockIdx.y * TILE_DIM + threadIdx.y; - uint absmax_uint = 0; - - #pragma unroll - for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) { - Packed128 in128 = load128cs(input + x + (y+j)*width); - Packed128 copy128[copy_vectors]; - for (int k = 0; k < in128.size; k++) { - T2 in = in128[k]; - float out_float = elementwise_func((float)in * descale_factor); - update_local_absmax(absmax_uint, out_float, absmax_factor); // optional absmax - - T1 out = (T1)(out_float * scale_factor); - copy128[k/T1_elements][k%T1_elements] = out; // optimised away by compiler if unused - } - - for (int o = 0; o < copy_vectors; o++) { - if constexpr (enable_copy) { - store_same_length(copy + x + (y+j)*width + o*T1_elements, copy128[o]); - } - size_t tile_offset = (threadIdx.x * T2_elements) + (threadIdx.y+j)*TILE_DIM + o*T1_elements; - store_same_length(&tile[0][0] + tile_offset, copy128[o]); - } - } - - if constexpr (absmax_factor != 0) { - update_global_absmax(absmax_output, absmax_uint); - } else { - __syncthreads(); - } - - // reduce the number of threads for the write if T1_elements > T2_elements - // we want to keep all 32 threads in a warp active, so we try to eliminate in y dimension first - // so we create fake/adjusted tid.x/tid.y where "extra" threadIdx.x adds to the effective tid.y - constexpr size_t block_size_x = (TILE_DIM * sizeof(T2)) / 16; + constexpr size_t block_size_x = (TILE_DIM * sizeof(T1)) / 16; constexpr size_t block_size_y = BLOCK_ROWS; - constexpr size_t desired_ratio = (sizeof(T2) >= sizeof(T1)) ? (sizeof(T2) / sizeof(T1)) : 1; - constexpr size_t ratio = (desired_ratio <= block_size_y) ? 
desired_ratio : block_size_y; - constexpr size_t block_size_x_div_r = block_size_x / ratio; - constexpr size_t block_size_y_div_r = block_size_y / ratio; - - int adjusted_tid_x = threadIdx.x % block_size_x_div_r; - int adjusted_tid_y = (threadIdx.y * ratio) + (threadIdx.x / block_size_x_div_r); - if (threadIdx.y >= block_size_y_div_r) { return; } - - // if we cannot reduce block_size.y enough, also reduce x (hurting perf with partial warps) - if (ratio != desired_ratio && adjusted_tid_x >= TILE_DIM / T1_elements) { return; } + int adjusted_tid_x = threadIdx.x % block_size_x; + int adjusted_tid_y = (threadIdx.y) + (threadIdx.x / block_size_y); // x/y for final write to global memory - x = blockIdx.y * TILE_DIM + adjusted_tid_x * T1_elements; + x = blockIdx.y * TILE_DIM + adjusted_tid_x * elements; y = blockIdx.x * TILE_DIM + adjusted_tid_y; #pragma unroll for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) { - // we need more instructions for the write than the read if T2_elements > T1_elements - #pragma unroll - for (int o = 0; o < copy_vectors; o++) { - Packed128 out128; - #pragma unroll - for (int k = 0; k < out128.size; k++) { - // these are tiny 8-bit loads with loads of bank conflicts for FP8 - // extremely hard to avoid and not a bottleneck when everything else is well optimised - out128[k] = tile[k + (adjusted_tid_x + o * blockDim.x) * out128.size][adjusted_tid_y + j]; - } - store128(transposed + x + (o * blockDim.x * out128.size) + (y+j)*height, out128); - } - } - */ -} - - -/* -// best I could come up with (without using TMA) - no bank conflicts, but 64B reads/writes not ideal -// Z_DIM=2 improves perf by ~2% partly by improving L2 hit rates for the writes as far as I can tell -template -__global__ void transpose_kernel(T1* __restrict__ transposed, T1* __restrict__ copy, const T2* __restrict__ input, int height, - const float* __restrict__ descale_pointer=(float*)NULL, const float* __restrict__ scale_pointer=(float*)NULL, - unsigned int* absmax_output=(unsigned int*)NULL, const void** meta=NULL) -{ - constexpr int in_parallel = 4/sizeof(T1); - - constexpr size_t TILE_DIM_PADDED = (TILE_DIM * 33) / 32; - __shared__ T1 tile[Z_DIM][TILE_DIM][TILE_DIM_PADDED]; - int w = gridDim.x * TILE_DIM; - - constexpr size_t T1_elements = 16 / sizeof(T1); - constexpr size_t T2_elements = 16 / sizeof(T2); - constexpr size_t copy_vectors = (sizeof(T1) >= sizeof(T2)) ? (sizeof(T1) / sizeof(T2)) : 1; - - float descale_factor = (scaling && descale_pointer) ? *descale_pointer : 1.0f; // never reciprocal - float scale_factor = (scaling && scale_pointer) ? *scale_pointer : 1.0f; - scale_factor = (reciprocal_scale && scale_factor != 0.0f) ? 
(1.0f / scale_factor) : scale_factor; - - int x = blockIdx.x * TILE_DIM + (threadIdx.x * T2_elements); - int y = blockIdx.y * TILE_DIM * Z_DIM + threadIdx.z * TILE_DIM + threadIdx.y; - - uint absmax_uint = 0; - if (y < height) { - #pragma unroll - for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) { - Packed128 copy128[copy_vectors]; - - int4 payload; - const int4* address = reinterpret_cast(input + x + (y+j)*w); - asm volatile("ld.global.L2::128B.v4.s32 {%0, %1, %2, %3}, [%4];" - : "=r"(payload.x), "=r"(payload.y), "=r"(payload.z), "=r"(payload.w) - : "l"(address)); - Packed128 in128(payload); - - #pragma unroll - for (int k = 0; k < in128.size; k++) { - T2 in = in128[k]; - float out_float = elementwise_func((float)in * descale_factor); - - T1 out = (T1)(out_float * scale_factor); - copy128[k/T1_elements][k%T1_elements] = out; // optimised away by compiler if unused - update_local_absmax(absmax_uint, out_float, absmax_factor); // optional absmax - } - - #pragma unroll - for (int o = 0; o < copy_vectors; o++) { - if constexpr (enable_copy) { - store_same_length(copy + x + (y+j)*w + o*T1_elements, copy128[o]); - } - - size_t offset_x = (threadIdx.x * T2_elements) + (o * T1_elements); - size_t offset_y = (threadIdx.y + j) * TILE_DIM; - offset_y += (offset_y / (128/sizeof(T1))) * in_parallel; - - int* one_bank = reinterpret_cast(&tile[threadIdx.z][0][0] + offset_x + offset_y); - #pragma unroll - for (int k = 0; k < 4; k++) { - one_bank[k] = *(int*)(©128[o][k*4/sizeof(T1)]); - } - } - } - } - - if constexpr (absmax_factor != 0) { - update_global_absmax(absmax_output, absmax_uint); - } else { - __syncthreads(); - } - - // reduce the number of threads for the write if T1_elements > T2_elements - // we want to keep all 32 threads in a warp active, so we try to eliminate in y dimension first - // so we create fake/adjusted tid.x/tid.y where "extra" threadIdx.x adds to the effective tid.y - constexpr size_t block_size_x = (TILE_DIM * sizeof(T2)) / 16; - constexpr size_t block_size_y = BLOCK_ROWS; - constexpr size_t desired_ratio = (sizeof(T2) >= sizeof(T1)) ? (sizeof(T2) / sizeof(T1)) : 1; - constexpr size_t ratio = (desired_ratio <= block_size_y) ? 
desired_ratio : block_size_y; - constexpr size_t block_size_x_div_r = block_size_x / ratio; - constexpr size_t block_size_y_div_r = block_size_y / ratio; - - int adjusted_tid_x = threadIdx.x % block_size_x_div_r; - int adjusted_tid_y = (threadIdx.y * ratio) + (threadIdx.x / block_size_x_div_r); - if (threadIdx.y >= block_size_y_div_r) { return; } - - // if we cannot reduce block_size.y enough, also reduce x (hurting perf with partial warps) - if (ratio != desired_ratio && adjusted_tid_x >= TILE_DIM / T1_elements) { return; } - - // x/y for final write to global memory - x = blockIdx.y * TILE_DIM * Z_DIM + threadIdx.z * TILE_DIM + adjusted_tid_x * T1_elements; - y = blockIdx.x * TILE_DIM + (adjusted_tid_y*in_parallel); - - if (x >= height) { return; } - - #pragma unroll - for (int j = 0; j < TILE_DIM / in_parallel; j += BLOCK_ROWS) { - if ((j+adjusted_tid_y) * in_parallel * ratio >= TILE_DIM) { return; } - - // we need more instructions for the write than the read if T2_elements > T1_elements + Packed128 out128; #pragma unroll - for (int o = 0; o < copy_vectors; o++) { - Packed128 out128[in_parallel]; - #pragma unroll - for (int k = 0; k < Packed128::size; k++) { - int offset_x = (adjusted_tid_y + j) * in_parallel; - int offset_y = ((adjusted_tid_x + o * blockDim.x) * Packed128::size + k) * TILE_DIM; - offset_y += (offset_y / (128/sizeof(T1))) * in_parallel; - - int in32 = *(int*)(&tile[threadIdx.z][0][0] + offset_x + offset_y); - for (int p = 0; p < in_parallel; p++) { - out128[p][k] = ((T1*)&in32)[p]; - } - } - #pragma unroll - for (int p = 0; p < in_parallel; p++) { - store128(transposed + x + (o * blockDim.x * Packed128::size) + (y+p + j * in_parallel) * height, out128[p]); - } + for (int k = 0; k < elements; k++) { + // these are tiny 8-bit loads with loads of bank conflicts for FP8 + // extremely hard to avoid and not a bottleneck when everything else is well optimised + out128[k] = tile[k + (adjusted_tid_x) * out128.size][adjusted_tid_y + j]; } + store128(transposed + x + out128.size + (y+j)*height, out128); } } -*/ // only calculate absmax of the input tensor (non-fused) template @@ -449,24 +121,9 @@ __global__ void update_absmax_kernel(TensorGPU inp) { } // ---------------------------------------------------------------------------- -// kernel launchers -/* -template -void copy_simple(T1 *copy, const T2 *input, size_t N, float* scale_pointer=NULL, const size_t block_size=512) { - size_t fewest_elements = min(Packed128::size, Packed128::size); - const dim3 grid_size(CEIL_DIV(N, block_size * fewest_elements)); - - if (scale_pointer) { - copy_simple_kernel<<>>(copy, input, N, scale_pointer); - } else { - copy_simple_kernel<<>>(copy, input, N); - } - cudaCheck(cudaGetLastError()); -} -*/ template -void copy_advanced(T1 *copy, const T2 *input, size_t N, float* descale_pointer=NULL, float* scale_pointer=NULL, void* absmax_output=NULL, /*bool memset_absmax=true,*/ cudaStream_t stream=0, const size_t block_size=512) { +void copy_advanced(TensorGPU *copy, TensorGPU *input, size_t N, float* descale_pointer=NULL, float* scale_pointer=NULL, void* absmax_output=NULL, /*bool memset_absmax=true,*/ cudaStream_t stream=0, const size_t block_size=512) { size_t fewest_elements = min(Packed128::size, Packed128::size); const dim3 grid_size(CEIL_DIV(N, block_size * fewest_elements)); assert((N % fewest_elements) == 0); @@ -474,6 +131,9 @@ void copy_advanced(T1 *copy, const T2 *input, size_t N, float* descale_pointer=N constexpr uint absmax_factor = 1; unsigned int* absmax_uint = (unsigned 
int*)absmax_output; + // todo - fix this function + assert(false); + if (absmax_output) { /*if (memset_absmax) { cudaMemset(absmax_output, 0, sizeof(unsigned int)); @@ -493,58 +153,25 @@ void copy_advanced(T1 *copy, const T2 *input, size_t N, float* descale_pointer=N cudaCheck(cudaGetLastError()); } -// only 2 important template parameters: write_absmax and elementwise_func -// (use copy_and_transpose() rather than enable_copy=true for clarity) -// slight inefficiency in that we don't optimise away scaling for kernels that don't need it (kernel checks for NULL) -template // advanced template options, usually don't need to be changed -void transpose(T1 *transposed, const T2 *input, size_t w, size_t h, float* descale_pointer=NULL, float* scale_pointer=NULL, void* absmax_output=NULL, - /*bool memset_absmax=true,*/ cudaStream_t stream=0, size_t block_size=128, T1 *copy=NULL) { // advanced parameters +template +void transpose_simple(TensorGPU transposed, TensorGPU input, size_t w, size_t h, cudaStream_t stream=0, size_t block_size=128) { assert((w % TRANSPOSE_TILE_SIZE) == 0 && (h % TRANSPOSE_TILE_SIZE) == 0); cudaCheck(cudaGetLastError()); - constexpr int DIM_Z = 1; - block_size /= DIM_Z; - size_t block_size_x = (TRANSPOSE_TILE_SIZE * sizeof(T2)) / 16; + size_t block_size_x = (TRANSPOSE_TILE_SIZE * sizeof(T1)) / 16; size_t block_size_y = min(TRANSPOSE_TILE_SIZE, block_size / block_size_x); - dim3 grid_size(w / TRANSPOSE_TILE_SIZE, h / (TRANSPOSE_TILE_SIZE * DIM_Z)); - dim3 block_size_dim(block_size_x, block_size_y, DIM_Z); - - constexpr uint absmax_factor = write_absmax ? 1 : 0; - unsigned int* absmax_uint = (unsigned int*)absmax_output; - /*if (write_absmax && memset_absmax) { - cudaMemset(absmax_output, 0, sizeof(unsigned int)); - }*/ + dim3 grid_size(w / TRANSPOSE_TILE_SIZE, h / (TRANSPOSE_TILE_SIZE)); + dim3 block_size_dim(block_size_x, block_size_y, 1); switch (block_size_y) { - case 64: transpose_kernel<64, TRANSPOSE_TILE_SIZE, reciprocal, enable_copy, true, absmax_factor, elementwise_func><<>>(transposed, copy, input, h, descale_pointer, scale_pointer, absmax_uint); break; - case 32: transpose_kernel<32, TRANSPOSE_TILE_SIZE, reciprocal, enable_copy, true, absmax_factor, elementwise_func><<>>(transposed, copy, input, h, descale_pointer, scale_pointer, absmax_uint); break; - case 16: transpose_kernel<16, TRANSPOSE_TILE_SIZE, reciprocal, enable_copy, true, absmax_factor, elementwise_func><<>>(transposed, copy, input, h, descale_pointer, scale_pointer, absmax_uint); break; - /*case 8: transpose_kernel<8, TRANSPOSE_TILE_SIZE, reciprocal, enable_copy, true, absmax_factor, elementwise_func><<>>(transposed, copy, input, h, descale_pointer, scale_pointer, absmax_uint,); break; - case 4: transpose_kernel<4, TRANSPOSE_TILE_SIZE, reciprocal, enable_copy, true, absmax_factor, elementwise_func><<>>(transposed, copy, input, h, descale_pointer, scale_pointer, absmax_uint); break; - case 2: transpose_kernel<2, TRANSPOSE_TILE_SIZE, reciprocal, enable_copy, true, absmax_factor, elementwise_func><<>>(transposed, copy, input, h, descale_pointer, scale_pointer, absmax_uint); break; - case 1: transpose_kernel<1, TRANSPOSE_TILE_SIZE, reciprocal, enable_copy, true, absmax_factor, elementwise_func><<>>(transposed, copy, input, h, descale_pointer, scale_pointer, absmax_uint); break;*/ + case 64: transpose_simple_kernel<64, TRANSPOSE_TILE_SIZE><<>>(transposed, input, h); break; + case 32: transpose_simple_kernel<32, TRANSPOSE_TILE_SIZE><<>>(transposed, input, h); break; + case 16: transpose_simple_kernel<16, 
TRANSPOSE_TILE_SIZE><<>>(transposed, input, h); break; default: printf("Invalid block size (might be easy to add): %lu\n", block_size_y); exit(1); } cudaCheck(cudaGetLastError()); } -// wrapper so the parameters of the standard transpose function are less messy -template -void copy_and_transpose(T1 *transposed, T1 *copy, const T2 *input, size_t w, size_t h, float* descale_pointer=NULL, float* scale_pointer=NULL, unsigned int* absmax_output=NULL, /*bool memset_absmax=true,*/ cudaStream_t stream=0, const size_t block_size=256) { - transpose(transposed, input, w, h, descale_pointer, scale_pointer, absmax_output, /*memset_absmax,*/ stream, block_size, copy); -} - -template -void copy_or_transpose(bool transposing, T1 *output, const T2 *input, size_t w, size_t h, float* descale_pointer=NULL, float* scale_pointer=NULL, unsigned int* absmax_output=NULL, /*bool memset_absmax=true,*/ cudaStream_t stream=0, const size_t block_size=0) { - if (transposing) { - transpose(output, input, w, h, descale_pointer, scale_pointer, absmax_output, /*memset_absmax,*/ stream, block_size ? block_size : 256); - } else { - copy_advanced(output, input, w*h, descale_pointer, scale_pointer, absmax_output, /*memset_absmax,*/ stream, block_size ? block_size : 512); - } - cudaCheck(cudaGetLastError()); -} - template void update_absmax(TensorGPU inp, bool memset_absmax=false, cudaStream_t stream=main_stream, size_t max_block_size=512) { size_t N = inp.num_elements; @@ -675,7 +302,8 @@ class TransposedCache { Tout* transposed = CudaScratchAllocator::getMemory(m * k, true); if (compute) { - copy_or_transpose(true, transposed, original, m, k, nullptr, nullptr, nullptr, stream); + // todo + //copy_or_transpose(true, transposed, original, m, k, nullptr, nullptr, nullptr, stream); } cache[key] = {transposed, size}; diff --git a/llmc/cuda_utils.cuh b/llmc/cuda_utils.cuh index 44e095472..9102b2541 100644 --- a/llmc/cuda_utils.cuh +++ b/llmc/cuda_utils.cuh @@ -2,35 +2,8 @@ #ifndef CUDA_UTILS_CUH #define CUDA_UTILS_CUH - -#define FAKE_FP8 - #include "cuda_common.h" -struct TensorSpec; // Forward declaration - -__device__ __constant__ TensorSpec* tensor_specs_ptr; -__device__ __constant__ float* gpu_scale_memory_ptr; -__device__ __constant__ unsigned int* gpu_absmax_memory_ptr; - -enum TT : uint8_t { - PARAMETER=0, PARAMETER_GRAD, PARAMETER_OPT_M, PARAMETER_OPT_V, PARAMETER_MASTER, // 1 allocation each - ACTIVATIONS_MULTIUSE, // single buffer shared for activations, activation gradients, and scratch - DEFAULT, COUNT=DEFAULT, NUM_TYPES_PARAM=PARAMETER_MASTER+1 -}; - -enum TFlags : uint8_t { - NONE=0, - REUSED_MEMORY=1, - GRADIENT=2, - TENSOR_2D=4, // used for matmul *outputs* only, not inputs (+weights) - BIAS=8, - LAYERNORM=16, - RESIDUAL=32, - EMBEDDING=64, - STATS=128 -}; - // ---------------------------------------------------------------------------- // Packed128 data structure that forces the compiler to use 128-bit loads/stores // in GPUs that support (the LDG.128 and STS.128 instructions) @@ -138,7 +111,7 @@ typedef Packed128 x128; // enumerator to indentify the datatype of a tensor. 
enum class DType : uint8_t { - FP32, FP16, BF16 + FP32, FP16, BF16, FP8E4M3, FP8E5M2 }; // Given a datatype enum, returns the underlying number of bytes @@ -151,6 +124,10 @@ size_t sizeof_dtype(DType type) { return sizeof(half); case DType::BF16: return sizeof(nv_bfloat16); + case DType::FP8E4M3: + return sizeof(__nv_fp8_e4m3); + case DType::FP8E5M2: + return sizeof(__nv_fp8_e5m2); default: // handle or get compiler warning fprintf(stderr, "Unknown datatype\n"); exit(EXIT_FAILURE); @@ -160,6 +137,8 @@ size_t sizeof_dtype(DType type) { DType dtype_of(float* f) { return DType::FP32; } DType dtype_of(nv_bfloat16 * f) { return DType::BF16; } DType dtype_of(half * f) { return DType::FP16; } +DType dtype_of(__nv_fp8_e4m3 * f) { return DType::FP8E4M3; } +DType dtype_of(__nv_fp8_e5m2 * f) { return DType::FP8E5M2; } // ---------------------------------------------------------------------------- // Random Number Generation used in Stochastic Rounding (defined here as used by TensorGPU) @@ -246,22 +225,20 @@ __device__ void stochastic_rounding(float in, Ti &out, unsigned int random, floa out = (lerp > threshold_percentage) ? rounded_up : rounded_down; } - // ---------------------------------------------------------------------------- - // todo - stochastic is bugged, spent hours debugging, no idea why backwards is so broken with it __device__ float fake_fp8(bool faking, float input, float scale, float descale, bool mode_e5, bool stochastic=false) { -#ifdef FAKE_FP8 unsigned int random_number; - if (false) { - unsigned int clock, laneid; - asm volatile("mov.u32 %0, %%clock;" : "=r"(clock)); - asm volatile("mov.u32 %0, %%laneid;" : "=r"(laneid)); - random_number = get_random_noise(clock, laneid, blockIdx.x * blockDim.x); - } - if (faking && scale != 1.0f) { assert(scale == 1.0f/descale || scale == 1.0f); + + if (stochastic) { + unsigned int clock, laneid; + asm volatile("mov.u32 %0, %%clock;" : "=r"(clock)); + asm volatile("mov.u32 %0, %%laneid;" : "=r"(laneid)); + random_number = get_random_noise(clock, laneid, blockIdx.x * blockDim.x); + } + if (mode_e5) { __nv_fp8_e5m2 value_fp8 = __nv_fp8_e5m2(input * scale); if (false) { @@ -273,389 +250,14 @@ __device__ float fake_fp8(bool faking, float input, float scale, float descale, __nv_fp8_e4m3 value_fp8 = __nv_fp8_e4m3(input * scale); if (stochastic) { // BUGGED - spent 6+ hours debuggin this, and at this point, I genuinely suspect a compiler bug *sigh* - //stochastic_rounding(input * scale, value_fp8, random_number); + stochastic_rounding(input * scale, value_fp8, random_number); } return ((float)value_fp8) * descale; } } -#endif return input; } -// ---------------------------------------------------------------------------- -// ... 
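fake_fp8() above is just a scale -> cast-to-FP8 -> cast-back -> descale round trip (with scale == 1.0f/descale, as asserted). As a minimal standalone sketch of the non-stochastic E4M3 path, illustrative only and not part of the patch:

    #include <cuda_fp8.h>

    // only the quantization error of __nv_fp8_e4m3 (max finite value 448) survives the round trip
    __device__ __forceinline__ float fake_quantize_e4m3(float x, float scale, float descale) {
        __nv_fp8_e4m3 q = __nv_fp8_e4m3(x * scale);   // cast into E4M3 at the scaled magnitude
        return ((float)q) * descale;                  // back to float at the original magnitude
    }

The E5M2 branch is identical apart from the type, trading mantissa bits for the wider exponent range used for gradients.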
-template -struct TensorGPU { - ElementType* data_ptr; - int id; - float* scale_descale_ptr; - unsigned int* absmax_ptr; - size_t num_elements; - - static __device__ __host__ TensorGPU from(ElementType* ptr=nullptr) { - TensorGPU tmp = {0}; - tmp.data_ptr = ptr; - tmp.id = -1; - return tmp; - } - - template - __device__ __host__ T* as() { - return reinterpret_cast(data_ptr); - } - - __device__ __host__ operator ElementType*() const { - return data_ptr; - } - - __device__ __host__ ElementType& operator[](size_t index) { - return data_ptr[index]; - } - - __device__ __host__ const ElementType& operator[](size_t index) const { - return data_ptr[index]; - } - - __device__ __host__ int num_per_128() const { - return sizeof(int4) / sizeof(ElementType); - } - - __device__ __host__ float get_scalar(size_t index, bool disable_scaling=false) const { - #ifdef FAKE_FP8 - disable_scaling = true; - #endif - - ElementType* __restrict__ data_ptr_restricted = data_ptr; - float* __restrict__ scale_ptr_restricted = scale_descale_ptr; - - float value = (float)data_ptr_restricted[index]; - float descale = (scale_descale_ptr && !disable_scaling) ? scale_ptr_restricted[1] : 1.0f; - return value * descale; // [1] = descale - } - - __device__ __host__ ElementType set_scalar(size_t index, float value, bool disable_scaling=false) { - #ifdef FAKE_FP8 - disable_scaling = true; - #endif - - ElementType* __restrict__ data_ptr_restricted = data_ptr; - float* __restrict__ scale_ptr_restricted = scale_descale_ptr; - - float scale = (scale_descale_ptr && !disable_scaling) ? scale_ptr_restricted[0] : 1.0f; - ElementType output = (ElementType)(value * scale); - data_ptr_restricted[index] = output; - return output; - } -}; - -// short-form typedefs -typedef TensorGPU tensorX; -typedef TensorGPU tensorFP32; -typedef TensorGPU tensorFP16; -typedef TensorGPU tensorBF16; - -typedef TensorGPU tensorFP8e4; -typedef TensorGPU tensorFP8e5; - -extern TensorGPU null_tensorX; -extern TensorGPU null_tensorFP32; - -template -struct tensor128 { -private: - Packed128 data128; - ElementType* data_ptr; - unsigned int *absmax_ptr = nullptr; - float scale = 1.0f; - float descale = 1.0f; - float new_absmax = 0.0f; - bool wrote_data = false; - bool wrote_absmax = false; - int id = -1; - // fake fp8 mode - bool faking_fp8 = false; - bool mode_e5 = false; - -public: - bool scaling = true; // todo - fp8 only - static constexpr const size_t elements = sizeof(int4) / sizeof(ElementType); - __device__ tensor128() { scaling = false; } - - __device__ tensor128(TensorGPU tensor, bool disable_scaling=false) { - data_ptr = tensor.data_ptr; - id = tensor.id; - -#ifdef FAKE_FP8 - if (!disable_scaling && id >= 0 && sizeof(ElementType) == 2 && tensor_specs_ptr[id].tensor_type != TT::PARAMETER_GRAD) { - if (!(tensor_specs_ptr[id].flags & TFlags::RESIDUAL) && !(tensor_specs_ptr[id].flags & TFlags::EMBEDDING)) { - faking_fp8 = true; - if ((tensor_specs_ptr[id].flags & TFlags::GRADIENT) && (tensor_specs_ptr[id].tensor_type == TT::ACTIVATIONS_MULTIUSE)) { - mode_e5 = true; - } - } - } - scaling = false; // only do "fake" scaling -#endif - - if (!disable_scaling) { - const float* __restrict__ ptr_restricted = tensor.scale_descale_ptr; - scale = ptr_restricted[0]; - descale = ptr_restricted[1]; - } else { - scaling = false; - } - absmax_ptr = tensor.absmax_ptr; - } - - __device__ void load(size_t offset, bool cache_streaming=false) { - ElementType* addr = data_ptr + offset; - data128 = cache_streaming ? 
load128cs(addr) : load128(addr); - } - - __device__ void store(size_t offset, bool cache_streaming=false) { - if (cache_streaming) { - store128cs(data_ptr + offset, data128); - } else { - store128(data_ptr + offset, data128); - } - wrote_data = true; - } - - template - __device__ void store_same_length(size_t offset, bool cache_streaming=false) { - if (cache_streaming) { - store128_same_length_cs(data_ptr + offset, data128); - } else { - store128_same_length(data_ptr + offset, data128); - } - wrote_data = true; - } - - __device__ const Packed128& get128() const { - return data128; - } - - __device__ Packed128& get128() { - return data128; - } - - // call this manually if e.g. you use set_scalar() to update the tensor - // todo - in the future, this could consider more than just absmax - __device__ void add_value_stats(float value, ElementType output=(ElementType)0.0f) { - new_absmax = max(new_absmax, fabsf(value)); - } - - __device__ float get(int index) { - float value = (float)data128[index] * (scaling ? descale : 1.0f); - value = fake_fp8(faking_fp8, value, scale, descale, mode_e5); - return value; - } - - __device__ void set(int index, float value) { - float output = value * (scaling ? scale : 1.0f); - output = fake_fp8(faking_fp8, output, scale, descale, mode_e5); - data128[index] = (ElementType)(output); - add_value_stats(value, data128[index]); - } - - __device__ void set_stochastic(int index, float value, unsigned int random_number, - bool rotate_by_index=true, bool non_deterministic_rng=false) { - float scaled_value = value * (scaling ? scale : 1.0f); - - // rotate the random number by the index so we can cheaply reuse the same RNG - // obviously less good than having true per-index RNG, but should be good enough - // when rounding FP32 to FP8, most of the bits make extremely little difference anyway... - // x10 is used so that it never repeats for indices [0;15] with a minimum difference of 2 etc. - if (rotate_by_index) { - assert(index < 16); // >=16 would repeat and be extremely bad RNG - random_number = __funnelshift_l(random_number, random_number, index * 10); - } - // RNG without a seed from the host for quick testing, but obviously not deterministic! 
- #ifdef FORCE_NON_DETERMINISM - non_deterministic_rng = true; - #endif - if (non_deterministic_rng) { - unsigned int clock, laneid; - asm volatile("mov.u32 %0, %%clock;" : "=r"(clock)); - asm volatile("mov.u32 %0, %%laneid;" : "=r"(laneid)); - random_number = get_random_noise(clock, laneid, blockIdx.x * blockDim.x); - } - - stochastic_rounding(scaled_value, data128[index], random_number); - add_value_stats(value, data128[index]); - } - - __device__ bool update_absmax(int thread_id, int num_threads, bool exit=false, bool forced=false) { - #ifdef FAKE_FP8 - if (id < 0 || absmax_ptr == NULL) { - return false; - } - forced = true; - #endif - - if (!forced && !scaling) { - return false; // if we return true, we can skip __syncthreads() in some kernels - } - wrote_absmax = true; - - // lane_id must be obtained directly from the special register - // otherwise, the compiler does silly things related to the redux/atomicMax - unsigned int lane_id ; - asm volatile("mov.u32 %0, %laneid;" : "=r"(lane_id)); - unsigned int num_warps = num_threads >> 5; - unsigned int warp_id = thread_id >> 5; - - // use native integer reductions as much as possible (supported on all GPUs with FP8) - // this might treat NaN/INF slightly differently but that is the least of our problems - unsigned int absmax_uint = *(unsigned int*)&new_absmax; - __shared__ unsigned int shared[32]; - - - // slow path in case redux causes issues - /*shared[lane_id] = absmax_uint; - __syncwarp(); - for (int i = 0; i < 32; i++) { - absmax_uint = max(absmax_uint, shared[i]); - } - __syncwarp();*/ - asm volatile("redux.sync.max.u32 %0, %0, 0xff;" : "+r"(absmax_uint)); - - - - - // with this condition instead of lane_id == 0, we have shared[lane_id] both here and below - // this reduces the number of instructions for addressing - if (lane_id == warp_id) { - shared[lane_id] = absmax_uint; - } - - // sync can be after exit (dead threads don't count) but must be before return - // if this is the end of the kernel, the compiler puts a conditional EXIT right after BAR - // but this way the EXIT is right before the barrier which frees the warps slightly quicker - bool done = (warp_id != 0); - if (done && exit) asm volatile("exit;"); - __syncthreads(); - if (done && !exit) return true; - - // one more warp reduction then global memory atomic - // we want as few global atomics as possible (i.e. 
1 per threadblock) - absmax_uint = shared[lane_id]; - if (lane_id >= num_warps) { - absmax_uint = 0; - } - - - // slow path in case redux causes issues - /*shared[lane_id] = absmax_uint; - __syncwarp(); - for (int i = 0; i < 32; i++) { - absmax_uint = max(absmax_uint, shared[i]); - } - __syncwarp();*/ - asm volatile("redux.sync.max.u32 %0, %0, 0xff;" : "+r"(absmax_uint)); - - - - if (lane_id == 0) { - atomicMax(absmax_ptr, absmax_uint); - } - return true; - } - __device__ void update_absmax_auto(int dimensions=1, bool exit=false) { - if (dimensions == 1) { - update_absmax(threadIdx.x, blockDim.x, exit); - } else if (dimensions == 2) { - update_absmax(threadIdx.x + threadIdx.y * blockDim.x, blockDim.x * blockDim.y, exit); - } else if (dimensions == 3) { - update_absmax(threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y, - blockDim.x * blockDim.y * blockDim.z, exit); - } - } - __device__ void skip_absmax() { - wrote_absmax = true; - } - - __device__ ~tensor128() { - // this should ~always be optimised away by the compiler - if (!wrote_absmax && scaling && wrote_data) { - //printf("id: %d\n", id); - assert(false); - } - } -}; - -template -__device__ tensor128 new_tensor128(TensorGPU tensor, bool disable_scaling=false) { - if constexpr (init) { - return tensor128(tensor, disable_scaling); - } else { - return tensor128(); - } -} - -template -__device__ tensor128 load_tensor128(TensorGPU tensor, size_t offset, - bool cache_streaming = false, bool disable_scaling=false) { - tensor128 t128(tensor, disable_scaling); - t128.load(offset, cache_streaming); - return t128; -} - -// ---------------------------------------------------------------------------- -// ... - -constexpr size_t MAX_TENSORS = 16*1024; -constexpr size_t MAX_ABSMAX_HISTORY = 32; // todo - should make this a command line option -extern int num_tensor_specs; -extern int current_absmax_index; -extern float* gpu_scale_memory; -extern unsigned int* gpu_absmax_memory; - -struct TensorSpec { - char* ptr; - size_t offset; // into base pointer - size_t num_elements; // per shard - int id; - short num_shards; - short remaining_layers; - DType data_type; - TT tensor_type; - int flags; - char name[16]; - - template - __host__ __device__ operator T*() const { - // TODO !!! make it work device side! - /* - if (std::is_same::value && data_type != DType::FP32 || - std::is_same::value && data_type != DType::FP16 || - std::is_same::value && data_type != DType::BF16) { - printf("ERROR: Unexpected data type (%d) for tensor %s\n", (int)data_type, name); - exit(EXIT_FAILURE); - } - */ - return reinterpret_cast(ptr); - } - - template - __device__ __host__ operator TensorGPU() const { - TensorGPU tensor; - tensor.num_elements = num_elements; - tensor.data_ptr = this->operator T*(); - tensor.id = id; - - #ifdef __CUDA_ARCH__ - tensor.scale_descale_ptr = gpu_scale_memory_ptr + 2*id; - tensor.absmax_ptr = gpu_absmax_memory_ptr + id; - #else - tensor.scale_descale_ptr = gpu_scale_memory + 2*id; - tensor.absmax_ptr = gpu_absmax_memory + id; - #endif - - return tensor; - } -}; - // ---------------------------------------------------------------------------- // Copy, cast functions diff --git a/llmc/tensor.cuh b/llmc/tensor.cuh new file mode 100644 index 000000000..ddee7f0e9 --- /dev/null +++ b/llmc/tensor.cuh @@ -0,0 +1,599 @@ +#ifndef TENSOR_CUH +#define TENSOR_CUH + +// ... +#define FAKE_FP8 +#define UNIQUE_TENSOR_MEMORY false +#define LAYERS_PER_ACTIVATION_CHECKPOINT 0 // 0 = disabled +// ... 
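LAYERS_PER_ACTIVATION_CHECKPOINT presumably drives the reuse_every_n_layers parameter of add_layer_specs() further down in this header (an assumption, the call sites are not in this hunk); the sharing it triggers is just a modulo over the first N per-layer specs. A rough sketch of that mapping, using a hypothetical helper name:

    // for layer l of a tensor group created with add_layer_specs(..., reuse_every_n_layers=N),
    // return the index of the earlier spec whose memory offset is reused (-1 = fresh offset)
    int reused_offset_source(int first_tensor_id, int l, int N) {
        if (N <= 0 || l < N) return -1;        // no reuse: this layer gets its own offset
        return first_tensor_id + (l % N);      // e.g. N=2: layers 2,4,... alias layer 0; 3,5,... alias layer 1
    }

With the value 0 above, this particular reuse path is disabled.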
+ +#include "cuda_common.h" +#include "cuda_utils.cuh" +#include + +// ---------------------------------------------------------------------------- + +enum TT : uint8_t { + PARAMETER=0, PARAMETER_GRAD, PARAMETER_OPT_M, PARAMETER_OPT_V, PARAMETER_MASTER, // 1 allocation each + MULTIUSE, // single allocation shared for activations, activation gradients, and scratch + DEFAULT, COUNT=DEFAULT, NUM_TYPES_PARAM=PARAMETER_MASTER+1 +}; + +enum TFlags : uint8_t { + NONE=0, + REUSED_MEMORY=1, + GRADIENT=2, + TENSOR_2D=4, // used for matmul *outputs* only, not inputs (+weights) + BIAS=8, + LAYERNORM=16, + RESIDUAL=32, + EMBEDDING=64, + STATS=128 +}; + +// ---------------------------------------------------------------------------- +// forward declarations & extern variables defined in the training file +struct TensorSpec; +constexpr size_t MAX_TENSORS = 16*1024; +constexpr size_t MAX_ABSMAX_HISTORY = 32; // todo - command line option + +extern TensorSpec tensor_specs[MAX_TENSORS]; +extern TensorSpec* tensor_specs_gpu; +extern size_t tensors_start[TT::COUNT]; +extern size_t tensors_bytes[TT::COUNT]; +extern size_t tensors_elements[TT::COUNT]; +extern int num_tensor_specs; + +extern TT current_tensor_type; +extern int current_absmax_index; +extern float* gpu_scale_memory; +extern unsigned int* gpu_absmax_memory; + +__device__ __constant__ TensorSpec* tensor_specs_ptr; +__device__ __constant__ float* gpu_scale_memory_ptr; +__device__ __constant__ unsigned int* gpu_absmax_memory_ptr; + +// ---------------------------------------------------------------------------- +// Helper macros for accessing tensors in the training file +#define TENSOR(x,layer) get_tensor(x, DEFAULT, layer) +#define ACT_L(x,layer) get_tensor(model->acts.x, MULTIUSE, layer) +#define MULTI_L(x,layer) get_tensor(model->multiuse.x, MULTIUSE, layer) +#define AGRAD_L(x,layer) get_tensor(model->acts_grads.x, MULTIUSE, layer) +#define PARAM_L(x,layer) get_tensor(model->params[PARAMETER].x, PARAMETER, layer) +#define PGRAD_L(x,layer) get_tensor(model->params[PARAMETER_GRAD].x, PARAMETER_GRAD, layer) +#define ACT(x) ACT_L(x,l) +#define MULTI(x) MULTI_L(x,l) +#define AGRAD(x) AGRAD_L(x,l) +#define PARAM(x) PARAM_L(x,l) +#define PGRAD(x) PGRAD_L(x,l) + +// ---------------------------------------------------------------------------- + +template +struct TensorGPU { + ElementType* data_ptr; + int id; + float* scale_descale_ptr; + unsigned int* absmax_ptr; + size_t num_elements; + + static __device__ __host__ TensorGPU from(ElementType* ptr=nullptr) { + TensorGPU tmp = {0}; + tmp.data_ptr = ptr; + tmp.id = -1; + return tmp; + } + + template + __device__ __host__ T* as() { + return reinterpret_cast(data_ptr); + } + + __device__ __host__ operator ElementType*() const { + return data_ptr; + } + + __device__ __host__ ElementType& operator[](size_t index) { + return data_ptr[index]; + } + + __device__ __host__ const ElementType& operator[](size_t index) const { + return data_ptr[index]; + } + + __device__ __host__ int num_per_128() const { + return sizeof(int4) / sizeof(ElementType); + } + + __device__ __host__ float get_scalar(size_t index, bool disable_scaling=false) const { + #ifdef FAKE_FP8 + disable_scaling = true; + #endif + + ElementType* __restrict__ data_ptr_restricted = data_ptr; + float* __restrict__ scale_ptr_restricted = scale_descale_ptr; + + float value = (float)data_ptr_restricted[index]; + float descale = (scale_descale_ptr && !disable_scaling) ? 
scale_ptr_restricted[1] : 1.0f; + return value * descale; // [1] = descale + } + + __device__ __host__ ElementType set_scalar(size_t index, float value, bool disable_scaling=false) { + #ifdef FAKE_FP8 + disable_scaling = true; + #endif + + ElementType* __restrict__ data_ptr_restricted = data_ptr; + float* __restrict__ scale_ptr_restricted = scale_descale_ptr; + + float scale = (scale_descale_ptr && !disable_scaling) ? scale_ptr_restricted[0] : 1.0f; + ElementType output = (ElementType)(value * scale); + data_ptr_restricted[index] = output; + return output; + } +}; + +typedef TensorGPU tensorX; +typedef TensorGPU tensorFP32; +typedef TensorGPU tensorFP16; +typedef TensorGPU tensorBF16; + +#ifdef ENABLE_FP8 +typedef TensorGPU<__nv_fp8_e4m3> tensorFP8e4; +typedef TensorGPU<__nv_fp8_e5m2> tensorFP8e5; +#else +typedef TensorGPU tensorFP8e4; +typedef TensorGPU tensorFP8e5; +#endif + +extern TensorGPU null_tensorX; +extern TensorGPU null_tensorFP32; + +// ---------------------------------------------------------------------------- + +struct TensorSpec { + char* ptr; + size_t offset; // into base pointer + size_t num_elements; // per shard + int id; + short num_shards; + short remaining_layers; + DType data_type; + TT tensor_type; + int flags; + char name[16]; + + template + __host__ __device__ operator T*() const { + // todo - sanity check DType matches T + return reinterpret_cast(ptr); + } + + template + __device__ __host__ operator TensorGPU() const { + TensorGPU tensor; + tensor.num_elements = num_elements; + tensor.data_ptr = this->operator T*(); + tensor.id = id; + + #ifdef __CUDA_ARCH__ + tensor.scale_descale_ptr = gpu_scale_memory_ptr + 2*id; + tensor.absmax_ptr = gpu_absmax_memory_ptr + id; + #else + tensor.scale_descale_ptr = gpu_scale_memory + 2*id; + tensor.absmax_ptr = gpu_absmax_memory + id; + #endif + + return tensor; + } +}; + +// ---------------------------------------------------------------------------- + +TensorSpec get_tensor(int spec_index, TT tensor_type, int layer) { + TensorSpec spec = tensor_specs[spec_index]; + if (layer > 0 && spec.remaining_layers >= layer) { + spec = tensor_specs[spec_index + layer]; + } else if (layer > 0 && spec.remaining_layers > 0) { + printf("ERROR: get_tensor() for %s layer %d but only %d layers remaining\n", spec.name, layer, spec.remaining_layers); + assert(false); + } + assert(spec.tensor_type == tensor_type || tensor_type == DEFAULT); + //print_tensor_elements(spec_index); + return spec; +} + +int add_tensor_spec(const char* name, size_t total_elements, size_t num_shards, DType data_type, int copy_offset_from=-1, int flags=TFlags::NONE, TT tensor_type=TT::DEFAULT) { + assert(num_tensor_specs < 16*1024); + assert((total_elements % num_shards) == 0); + TensorSpec* spec = &tensor_specs[num_tensor_specs]; + strncpy(spec->name, name, 15); + spec->name[15] = 0; + + spec->id = num_tensor_specs; + spec->num_elements = total_elements / num_shards; + spec->num_shards = num_shards; + spec->remaining_layers = 0; + spec->data_type = data_type; + spec->tensor_type = (tensor_type == TT::DEFAULT) ? 
current_tensor_type : tensor_type; + spec->flags = flags; + tensors_elements[spec->tensor_type] += spec->num_elements; + + if (copy_offset_from >= 0) { + TensorSpec base_spec = tensor_specs[copy_offset_from]; + spec->offset = base_spec.offset; + size_t original_tensor_bytes = base_spec.num_elements * sizeof_dtype(base_spec.data_type); + size_t new_tensor_bytes = spec->num_elements * sizeof_dtype(data_type); + if (base_spec.tensor_type != spec->tensor_type) { + printf("ERROR: tensor_type mismatch for %s: %d vs %d\n", + spec->name, (int)base_spec.tensor_type, (int)spec->tensor_type); + assert(false); + } + if (flags & REUSED_MEMORY) { + base_spec.flags |= REUSED_MEMORY; + } + assert(base_spec.tensor_type == spec->tensor_type); + assert(new_tensor_bytes <= original_tensor_bytes); + } else { + spec->offset = tensors_bytes[spec->tensor_type]; + tensors_bytes[spec->tensor_type] += spec->num_elements * sizeof_dtype(data_type); + if (tensors_start[spec->tensor_type] == 0 && spec->tensor_type != 0) { + tensors_start[spec->tensor_type] = num_tensor_specs; + } + } + return num_tensor_specs++; +} + +int add_layer_specs(int num_layers, const char* name, size_t total_elements, size_t num_shards, DType data_type, + int copy_offset_from=-1, int flags=TFlags::NONE, bool copy_per_layer=false, + int reuse_every_n_layers=0, TT tensor_type=TT::DEFAULT) { + int first_tensor_id = num_tensor_specs; + if (reuse_every_n_layers > 0 && num_layers > 1) { + flags |= REUSED_MEMORY; + } + for (int l = 0; l < num_layers; l++) { + char layer_name[16]; + assert(snprintf(layer_name, 16, "%s_%d", name, l) >= 0); + if (reuse_every_n_layers > 0 && l >= reuse_every_n_layers) { + copy_offset_from = first_tensor_id + (l % reuse_every_n_layers); + } + int spec = add_tensor_spec(num_layers > 1 ? 
layer_name : name, total_elements, num_shards, data_type, copy_offset_from, flags, tensor_type); + if (copy_per_layer) { + copy_offset_from++; + } + tensor_specs[spec].remaining_layers = num_layers - (l + 1); + } + return first_tensor_id; +} + +// debug helper function +void print_tensor_elements(int tensor_id) { + return; + + printf("Printing tensor %d\n", tensor_id); + TensorSpec spec = tensor_specs[tensor_id]; + size_t num_elements = spec.num_elements; + const char* tensor_name = spec.name; + TT tensor_type = spec.tensor_type; + DType dtype = spec.data_type; + size_t element_size = sizeof_dtype(dtype); + + void* gpu_memory = spec.ptr; + void* gpu_tensor = (void*)((char*)gpu_memory + tensor_specs[tensor_id].offset); + void* cpu_tensor = malloc(num_elements * element_size); + + printf("Printing tensor %s (tensor_type: %d, data_type: %d)\n", tensor_name, (int)tensor_type, (int)dtype); + printf("GPU memory: %p\n", gpu_tensor); + printf("CPU memory: %p\n", cpu_tensor); + printf("Num elements: %zu\n", num_elements); + printf("Element size: %zu\n", element_size); + printf("Offset: %zu\n", tensor_specs[tensor_id].offset); + + cudaCheck(cudaMemcpy(cpu_tensor, gpu_tensor, num_elements * element_size, cudaMemcpyDeviceToHost)); + + printf("Did memcpy\n"); + + printf("First 4 of %s: ", tensor_name); + for (int i = 0; i < num_elements && i < 4; i++) { + if (dtype == DType::FP32) { + printf("%.16f ", ((float*)cpu_tensor)[i]); + } else if (dtype == DType::FP16) { + printf("%.16f ", (float)((__nv_half*)cpu_tensor)[i]); + } else if (dtype == DType::BF16) { + printf("%.16f ", (float)((__nv_bfloat16*)cpu_tensor)[i]); + } + } + printf("\n"); + + printf("Middle 4 of %s: ", tensor_name); + for (int i = (num_elements/2) + 4; i < num_elements && i < (num_elements/2 + 8); i++) { + if (dtype == DType::FP32) { + printf("%.16f ", ((float*)cpu_tensor)[i]); + } else if (dtype == DType::FP16) { + printf("%.16f ", (float)((__nv_half*)cpu_tensor)[i]); + } else if (dtype == DType::BF16) { + printf("%.16f ", (float)((__nv_bfloat16*)cpu_tensor)[i]); + } + } + printf("\n"); + + printf("Last 4 of %s: ", tensor_name); + for (int i = num_elements - 4; i < num_elements; i++) { + if (dtype == DType::FP32) { + printf("%.16f ", ((float*)cpu_tensor)[i]); + } else if (dtype == DType::FP16) { + printf("%.16f ", (float)((__nv_half*)cpu_tensor)[i]); + } else if (dtype == DType::BF16) { + printf("%.16f ", (float)((__nv_bfloat16*)cpu_tensor)[i]); + } + } + printf("\n"); + printf("\n"); + + free(cpu_tensor); +} + +// ---------------------------------------------------------------------------- + +__global__ void update_scale_descale_kernel(float* gpu_scale_memory, unsigned int* gpu_absmax_memory, int num_tensor_specs) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= num_tensor_specs) return; + + // Get the absmax value for this tensor + unsigned int absmax_uint = gpu_absmax_memory[tid]; + float absmax = __uint_as_float(absmax_uint); + + // Calculate scale and descale + if (absmax == 0.0f) { + absmax = 1.0f; + } + float scale = 1.0f / absmax; + float descale = absmax; + + if (!(tensor_specs_ptr[tid].flags & TFlags::RESIDUAL) && !(tensor_specs_ptr[tid].flags & TFlags::EMBEDDING) && absmax != 1.0f) { + if ((tensor_specs_ptr[tid].flags & TFlags::GRADIENT) && (tensor_specs_ptr[tid].tensor_type == TT::MULTIUSE)) { + // e5 + scale *= 32768.0f; + descale *= 1.0f/32768.0f; + } else { + // e4 + //if (tensor_specs_ptr[tid].tensor_type != TT::PARAMETER) { + scale *= 256.0f; + descale *= (1.0f/256.0f); + //} + } + } else { + scale = 
1.0f; + descale = 1.0f; + } + + // todo: circular buffer + //gpu_absmax_memory[tid] = 0.0f; + + // Update gpu_scale_memory + gpu_scale_memory[tid * 2] = scale; + gpu_scale_memory[tid * 2 + 1] = descale; +} + +// ---------------------------------------------------------------------------- + +template +struct tensor128 { +private: + Packed128 data128; + ElementType* data_ptr; + unsigned int *absmax_ptr = nullptr; + float scale = 1.0f; + float descale = 1.0f; + float new_absmax = 0.0f; + bool wrote_data = false; + bool wrote_absmax = false; + int id = -1; + // fake fp8 mode + bool faking_fp8 = false; + bool mode_e5 = false; + +public: + bool scaling = (sizeof(ElementType) == 1); + static constexpr const size_t elements = sizeof(int4) / sizeof(ElementType); + __device__ tensor128() { scaling = false; } + + __device__ tensor128(TensorGPU tensor, bool disable_scaling=false) { + data_ptr = tensor.data_ptr; + id = tensor.id; + +#ifdef FAKE_FP8 + if (!disable_scaling && id >= 0 && sizeof(ElementType) == 2 && tensor_specs_ptr[id].tensor_type != TT::PARAMETER_GRAD) { + if ((tensor_specs_ptr[id].flags & (TFlags::RESIDUAL | TFlags::EMBEDDING | TFlags::BIAS)) == 0) { + faking_fp8 = true; + if ((tensor_specs_ptr[id].flags & TFlags::GRADIENT) && (tensor_specs_ptr[id].tensor_type == TT::MULTIUSE)) { + mode_e5 = true; + } + } + } + scaling = false; // only do "fake" scaling +#endif + + if (!disable_scaling) { + const float* __restrict__ ptr_restricted = tensor.scale_descale_ptr; + scale = ptr_restricted[0]; + descale = ptr_restricted[1]; + } else { + scaling = false; + } + absmax_ptr = tensor.absmax_ptr; + } + + __device__ void load(size_t offset, bool cache_streaming=false) { + ElementType* addr = data_ptr + offset; + data128 = cache_streaming ? load128cs(addr) : load128(addr); + } + + __device__ void store(size_t offset, bool cache_streaming=false) { + if (cache_streaming) { + store128cs(data_ptr + offset, data128); + } else { + store128(data_ptr + offset, data128); + } + wrote_data = true; + } + + template + __device__ void store_same_length(size_t offset, bool cache_streaming=false) { + if (cache_streaming) { + store128_same_length_cs(data_ptr + offset, data128); + } else { + store128_same_length(data_ptr + offset, data128); + } + wrote_data = true; + } + + __device__ const Packed128& get128() const { + return data128; + } + + __device__ Packed128& get128() { + return data128; + } + + // call this manually if e.g. you use set_scalar() to update the tensor + // todo - in the future, this could support more than just absmax + __device__ void add_value_stats(float value, ElementType output=(ElementType)0.0f) { + new_absmax = max(new_absmax, fabsf(value)); + } + + __device__ float get(int index) { + float value = (float)data128[index] * (scaling ? descale : 1.0f); + value = fake_fp8(faking_fp8, value, scale, descale, mode_e5); + return value; + } + + __device__ void set(int index, float value) { + float output = value * (scaling ? scale : 1.0f); + output = fake_fp8(faking_fp8, output, scale, descale, mode_e5); + data128[index] = (ElementType)(output); + add_value_stats(value, data128[index]); + } + + __device__ void set_stochastic(int index, float value, unsigned int random_number, + bool rotate_by_index=true, bool non_deterministic_rng=false) { + float scaled_value = value * (scaling ? 
scale : 1.0f); + + // rotate the random number by the index so we can cheaply reuse the same RNG + // obviously less good than having true per-index RNG, but should be good enough + // when rounding FP32 to FP8, most of the bits make extremely little difference anyway... + // x10 is used so that it never repeats for indices [0;15] with a minimum difference of 2 etc. + if (rotate_by_index) { + assert(index < 16); // >=16 would repeat and be extremely bad RNG + random_number = __funnelshift_l(random_number, random_number, index * 10); + } + // RNG without a seed from the host for quick testing, but obviously not deterministic! + #ifdef FORCE_NON_DETERMINISM + non_deterministic_rng = true; + #endif + if (non_deterministic_rng) { + unsigned int clock, laneid; + asm volatile("mov.u32 %0, %%clock;" : "=r"(clock)); + asm volatile("mov.u32 %0, %%laneid;" : "=r"(laneid)); + random_number = get_random_noise(clock, laneid, blockIdx.x * blockDim.x); + } + + stochastic_rounding(scaled_value, data128[index], random_number); + add_value_stats(value, data128[index]); + } + + // if update_absmax returns true, we can skip __syncthreads() in some kernels + __device__ bool update_absmax(int thread_id, int num_threads, bool exit=false, bool forced=false) { + #ifdef FAKE_FP8 + if (id < 0 || absmax_ptr == NULL || !faking_fp8) { + return false; + } + forced = true; + #endif + + if (!forced && !scaling) { + return false; + } + wrote_absmax = true; + + // lane_id must be obtained directly from the special register + // otherwise, the compiler does silly things related to the redux/atomicMax + unsigned int lane_id ; + asm volatile("mov.u32 %0, %laneid;" : "=r"(lane_id)); + unsigned int num_warps = num_threads >> 5; + unsigned int warp_id = thread_id >> 5; + + // use native integer reductions as much as possible (supported on all GPUs with FP8) + // this might treat NaN/INF slightly differently but that is the least of our problems + __shared__ unsigned int shared[32]; + unsigned int absmax_uint = *(unsigned int*)&new_absmax; + asm volatile("redux.sync.max.u32 %0, %0, 0xff;" : "+r"(absmax_uint)); + + // with this condition instead of lane_id == 0, we have shared[lane_id] both here and below + // this reduces the number of instructions for addressing + if (lane_id == warp_id) { + shared[lane_id] = absmax_uint; + } + + // sync can be after exit (dead threads don't count) but must be before return + // if this is the end of the kernel, the compiler puts a conditional EXIT right after BAR + // but this way the EXIT is right before the barrier which frees the warps slightly quicker + bool done = (warp_id != 0); + if (done && exit) asm volatile("exit;"); + __syncthreads(); + if (done && !exit) return true; + + // one more warp reduction then global memory atomic + // we want as few global atomics as possible (i.e. 
1 per threadblock) + absmax_uint = shared[lane_id]; + if (lane_id >= num_warps) { + absmax_uint = 0; + } + + asm volatile("redux.sync.max.u32 %0, %0, 0xff;" : "+r"(absmax_uint)); + if (lane_id == 0) { + atomicMax(absmax_ptr, absmax_uint); + } + return true; + } + __device__ void update_absmax_auto(int dimensions=1, bool exit=false) { + if (dimensions == 1) { + update_absmax(threadIdx.x, blockDim.x, exit); + } else if (dimensions == 2) { + update_absmax(threadIdx.x + threadIdx.y * blockDim.x, blockDim.x * blockDim.y, exit); + } else if (dimensions == 3) { + update_absmax(threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y, + blockDim.x * blockDim.y * blockDim.z, exit); + } + } + __device__ void skip_absmax() { + wrote_absmax = true; + } + + __device__ ~tensor128() { + // this should ~always be optimised away by the compiler + if (!wrote_absmax && scaling && wrote_data) { + //printf("id: %d\n", id); + assert(false); + } + } +}; + +template +__device__ tensor128 new_tensor128(TensorGPU tensor, bool disable_scaling=false) { + if constexpr (init) { + return tensor128(tensor, disable_scaling); + } else { + return tensor128(); + } +} + +template +__device__ tensor128 load_tensor128(TensorGPU tensor, size_t offset, + bool cache_streaming = false, bool disable_scaling=false) { + tensor128 t128(tensor, disable_scaling); + t128.load(offset, cache_streaming); + return t128; +} + +#endif // TENSOR_CUH diff --git a/train_gpt2.cu b/train_gpt2.cu index 2809fa008..cc200f0ef 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -1,9 +1,8 @@ +//#define ENABLE_FP8 + /* GPT-2 Transformer Neural Net training loop. See README.md for usage. */ -#define UNIQUE_TENSOR_MEMORY false -#define LAYERS_PER_ACTIVATION_CHECKPOINT 0 // 0 = disabled - #include #include #include @@ -42,6 +41,8 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage. // Packed128, f128, x128 // warpReduceSum, warpReduceMax, blockReduce, copy_and_cast_kernel #include "llmc/cuda_utils.cuh" +// ... todo ... +#include "llmc/tensor.cuh" // defines: CUBLAS_LOWP, cublasCheck, cublaslt_workspace_size, cublaslt_workspace // defines: cublas_compute, cublaslt_handle, cublas_handle #include "llmc/cublas_common.h" @@ -76,56 +77,29 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage. 
#include "llmc/zero.cuh" // ---------------------------------------------------------------------------- -// global vars for I/O -char filename_buffer[512]; - -// ---------------------------------------------------------------------------- -// global vars containing information about the GPU this process is running on +// global vars regarding GPU process and disk I/O cudaDeviceProp deviceProp; // fills in common_start() cudaStream_t main_stream; -TensorGPU null_tensorX = {0}; -TensorGPU null_tensorFP32 = {0}; -// buffer size to use for device <-> disk io -constexpr const size_t IO_BUF_SIZE = 32 * 1024 * 1024; - -// todo - move this -__global__ void update_scale_descale_kernel(float* gpu_scale_memory, unsigned int* gpu_absmax_memory, int num_tensor_specs) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid >= num_tensor_specs) return; - - // Get the absmax value for this tensor - unsigned int absmax_uint = gpu_absmax_memory[tid]; - float absmax = __uint_as_float(absmax_uint); +char filename_buffer[512]; +constexpr const size_t IO_BUF_SIZE = 32 * 1024 * 1024; // buffer for device <-> disk io - // Calculate scale and descale - if (absmax == 0.0f) { - absmax = 1.0f; - } - float scale = 1.0f / absmax; - float descale = absmax; - - if (!(tensor_specs_ptr[tid].flags & TFlags::RESIDUAL) && !(tensor_specs_ptr[tid].flags & TFlags::EMBEDDING) && absmax != 1.0f) { - if ((tensor_specs_ptr[tid].flags & TFlags::GRADIENT) && (tensor_specs_ptr[tid].tensor_type == TT::ACTIVATIONS_MULTIUSE)) { - // e5 - scale *= 32768.0f; - descale *= 1.0f/32768.0f; - } else { - // e4 - scale *= 256.0f; - descale *= (1.0f/256.0f); - } - } else { - scale = 1.0f; - descale = 1.0f; - } +// ---------------------------------------------------------------------------- +// global vars for tensors (declared as extern in tensor.cuh to be visible everywhere) +// todo - avoid global variables for this? 
+TensorSpec tensor_specs[MAX_TENSORS] = {0}; +TensorSpec* tensor_specs_gpu = NULL; +size_t tensors_start[TT::COUNT] = {0}; +size_t tensors_bytes[TT::COUNT] = {0}; +size_t tensors_elements[TT::COUNT] = {0}; +int num_tensor_specs = 0; - // todo: circular buffer - //gpu_absmax_memory[tid] = 0.0f; +TT current_tensor_type = TT::PARAMETER; +int current_absmax_index = 0; +float* gpu_scale_memory = NULL; +unsigned int* gpu_absmax_memory = NULL; - // Update gpu_scale_memory - gpu_scale_memory[tid * 2] = scale; - gpu_scale_memory[tid * 2 + 1] = descale; -} +TensorGPU null_tensorX = {0}; +TensorGPU null_tensorFP32 = {0}; // ---------------------------------------------------------------------------- // GPT-2 model definition @@ -166,8 +140,7 @@ typedef struct { size_t num_parameters; size_t num_parameters_bytes; - char* multiuse_memory = NULL; - char* params_memory[NUM_TYPES_PARAM] = {0}; + char* tensor_memory[TT::COUNT] = {0}; // other run state configuration int batch_size = 0; // the batch size (B) of current forward pass @@ -189,166 +162,6 @@ typedef struct { unsigned long long rng_state_last_update; // RNG before last gpt2_update() to re-round identically from master weights } GPT2; -GPT2 model; // todo - move back -bool backward = false; // todo - hack - REMOVE - -TensorSpec tensor_specs[MAX_TENSORS] = {0}; -TensorSpec* tensor_specs_gpu = NULL; - -TT current_tensor_type = TT::PARAMETER; -size_t tensors_start[TT::COUNT] = {0}; -size_t tensors_bytes[TT::COUNT] = {0}; -size_t tensors_elements[TT::COUNT] = {0}; - -int num_tensor_specs = 0; -int current_absmax_index = 0; -float* gpu_scale_memory = NULL; -unsigned int* gpu_absmax_memory = NULL; - -// debug helper function -void print_tensor_elements(int tensor_id) { - return; - if (backward == false) return; - - printf("Printing tensor %d\n", tensor_id); - TensorSpec spec = tensor_specs[tensor_id]; - size_t num_elements = spec.num_elements; - const char* tensor_name = spec.name; - TT tensor_type = spec.tensor_type; - DType dtype = spec.data_type; - size_t element_size = sizeof_dtype(dtype); - - void* gpu_memory = (tensor_type == TT::ACTIVATIONS_MULTIUSE) ? 
model.multiuse_memory : model.params_memory[tensor_type]; - void* gpu_tensor = (void*)((char*)gpu_memory + tensor_specs[tensor_id].offset); - void* cpu_tensor = malloc(num_elements * element_size); - - printf("Printing tensor %s\n", tensor_name); - printf("GPU memory: %p\n", gpu_tensor); - printf("CPU memory: %p\n", cpu_tensor); - printf("Num elements: %zu\n", num_elements); - printf("Element size: %zu\n", element_size); - printf("Offset: %zu\n", tensor_specs[tensor_id].offset); - - cudaCheck(cudaMemcpy(cpu_tensor, gpu_tensor, num_elements * element_size, cudaMemcpyDeviceToHost)); - - printf("Did memcpy\n"); - - printf("First 4 of %s: ", tensor_name); - for (int i = 0; i < num_elements && i < 4; i++) { - if (dtype == DType::FP32) { - printf("%.16f ", ((float*)cpu_tensor)[i]); - } else if (dtype == DType::FP16) { - printf("%.16f ", (float)((__nv_half*)cpu_tensor)[i]); - } else if (dtype == DType::BF16) { - printf("%.16f ", (float)((__nv_bfloat16*)cpu_tensor)[i]); - } - } - printf("\n"); - - printf("Middle 4 of %s: ", tensor_name); - for (int i = (num_elements/2) + 4; i < num_elements && i < (num_elements/2 + 8); i++) { - if (dtype == DType::FP32) { - printf("%.16f ", ((float*)cpu_tensor)[i]); - } else if (dtype == DType::FP16) { - printf("%.16f ", (float)((__nv_half*)cpu_tensor)[i]); - } else if (dtype == DType::BF16) { - printf("%.16f ", (float)((__nv_bfloat16*)cpu_tensor)[i]); - } - } - printf("\n"); - - printf("Last 4 of %s: ", tensor_name); - for (int i = num_elements - 4; i < num_elements; i++) { - if (dtype == DType::FP32) { - printf("%.16f ", ((float*)cpu_tensor)[i]); - } else if (dtype == DType::FP16) { - printf("%.16f ", (float)((__nv_half*)cpu_tensor)[i]); - } else if (dtype == DType::BF16) { - printf("%.16f ", (float)((__nv_bfloat16*)cpu_tensor)[i]); - } - } - printf("\n"); - printf("\n"); - - free(cpu_tensor); -} - -TensorSpec get_tensor(int spec_index, TT tensor_type, int layer) { - TensorSpec spec = tensor_specs[spec_index]; - if (layer > 0 && spec.remaining_layers >= layer) { - spec = tensor_specs[spec_index + layer]; - } else if (layer > 0 && spec.remaining_layers > 0) { - printf("ERROR: get_tensor() for %s layer %d but only %d layers remaining\n", spec.name, layer, spec.remaining_layers); - assert(false); - } - assert(spec.tensor_type == tensor_type || tensor_type == DEFAULT); - print_tensor_elements(spec_index); - return spec; -} - -int add_tensor_spec(const char* name, size_t total_elements, size_t num_shards, DType data_type, int copy_offset_from=-1, int flags=TFlags::NONE, TT tensor_type=TT::DEFAULT) { - assert(num_tensor_specs < 16*1024); - assert((total_elements % num_shards) == 0); - TensorSpec* spec = &tensor_specs[num_tensor_specs]; - strncpy(spec->name, name, 15); - spec->name[15] = 0; - - spec->id = num_tensor_specs; - spec->num_elements = total_elements / num_shards; - spec->num_shards = num_shards; - spec->remaining_layers = 0; - spec->data_type = data_type; - spec->tensor_type = (tensor_type == TT::DEFAULT) ? 
current_tensor_type : tensor_type; - spec->flags = flags; - tensors_elements[spec->tensor_type] += spec->num_elements; - - if (copy_offset_from >= 0) { - TensorSpec base_spec = tensor_specs[copy_offset_from]; - spec->offset = base_spec.offset; - size_t original_tensor_bytes = base_spec.num_elements * sizeof_dtype(base_spec.data_type); - size_t new_tensor_bytes = spec->num_elements * sizeof_dtype(data_type); - if (base_spec.tensor_type != spec->tensor_type) { - printf("ERROR: tensor_type mismatch for %s: %d vs %d\n", - spec->name, (int)base_spec.tensor_type, (int)spec->tensor_type); - assert(false); - } - if (flags & REUSED_MEMORY) { - base_spec.flags |= REUSED_MEMORY; - } - assert(base_spec.tensor_type == spec->tensor_type); - assert(new_tensor_bytes <= original_tensor_bytes); - } else { - spec->offset = tensors_bytes[spec->tensor_type]; - tensors_bytes[spec->tensor_type] += spec->num_elements * sizeof_dtype(data_type); - if (tensors_start[spec->tensor_type] == 0 && spec->tensor_type != 0) { - tensors_start[spec->tensor_type] = num_tensor_specs; - } - } - return num_tensor_specs++; -} - -int add_layer_specs(int num_layers, const char* name, size_t total_elements, size_t num_shards, DType data_type, - int copy_offset_from=-1, int flags=TFlags::NONE, bool copy_per_layer=false, - int reuse_every_n_layers=0, TT tensor_type=TT::DEFAULT) { - int first_tensor_id = num_tensor_specs; - if (reuse_every_n_layers > 0 && num_layers > 1) { - flags |= REUSED_MEMORY; - } - for (int l = 0; l < num_layers; l++) { - char layer_name[16]; - assert(snprintf(layer_name, 16, "%s_%d", name, l) >= 0); - if (reuse_every_n_layers > 0 && l >= reuse_every_n_layers) { - copy_offset_from = first_tensor_id + (l % reuse_every_n_layers); - } - int spec = add_tensor_spec(num_layers > 1 ? 
layer_name : name, total_elements, num_shards, data_type, copy_offset_from, flags, tensor_type); - if (copy_per_layer) { - copy_offset_from++; - } - tensor_specs[spec].remaining_layers = num_layers - (l + 1); - } - return first_tensor_id; -} - #define TENSOR_SPECS(name, layers, dim, flags) spec->name = add_layer_specs(layers, #name, dim, shards, dtype, -1, flags, false, reuse_every_n) #define TENSOR_SPECS_LOWP(name, layers, dim, flags) spec->name = add_layer_specs(layers, #name, dim, shards, dtype_lowp, -1, flags, false, reuse_every_n) #define TENSOR_SPECS_FP32(name, layers, dim, flags) spec->name = add_layer_specs(layers, #name, dim, shards, DType::FP32, -1, flags, false, reuse_every_n) @@ -402,7 +215,7 @@ void gpt2_allocate(GPT2 *model) { } // 2) multiuse & scratch tensors - current_tensor_type = ACTIVATIONS_MULTIUSE; + current_tensor_type = MULTIUSE; if (UNIQUE_TENSOR_MEMORY) { model->multiuse.bt4c = -1; model->multiuse.btc = -1; @@ -514,27 +327,27 @@ void gpt2_allocate(GPT2 *model) { } // allocate a single huge GPU buffer for all the tensors of a given type - cudaCheck(cudaMalloc(&model->multiuse_memory, tensors_bytes[ACTIVATIONS_MULTIUSE])); - cudaCheck(cudaMemset(model->multiuse_memory, 0, tensors_bytes[ACTIVATIONS_MULTIUSE])); - - cudaCheck(cudaMalloc(&model->params_memory[PARAMETER], tensors_bytes[PARAMETER])); - cudaCheck(cudaMalloc(&model->params_memory[PARAMETER_GRAD], tensors_bytes[PARAMETER_GRAD])); - cudaCheck(cudaMalloc(&model->params_memory[PARAMETER_OPT_M], tensors_bytes[PARAMETER_OPT_M])); - cudaCheck(cudaMalloc(&model->params_memory[PARAMETER_OPT_V], tensors_bytes[PARAMETER_OPT_V])); + cudaCheck(cudaMalloc(&model->tensor_memory[MULTIUSE], tensors_bytes[MULTIUSE])); + cudaCheck(cudaMalloc(&model->tensor_memory[PARAMETER], tensors_bytes[PARAMETER])); + cudaCheck(cudaMalloc(&model->tensor_memory[PARAMETER_GRAD], tensors_bytes[PARAMETER_GRAD])); + cudaCheck(cudaMalloc(&model->tensor_memory[PARAMETER_OPT_M], tensors_bytes[PARAMETER_OPT_M])); + cudaCheck(cudaMalloc(&model->tensor_memory[PARAMETER_OPT_V], tensors_bytes[PARAMETER_OPT_V])); if (model->use_master_weights) { - cudaCheck(cudaMalloc(&model->params_memory[PARAMETER_MASTER], tensors_bytes[PARAMETER_MASTER])); + cudaCheck(cudaMalloc(&model->tensor_memory[PARAMETER_MASTER], tensors_bytes[PARAMETER_MASTER])); } + // clear multiuse memory (better safe than sorry) + cudaCheck(cudaMemset(model->tensor_memory[MULTIUSE], 0, tensors_bytes[MULTIUSE])); // Set the ptr for each tensor spec based on type and offset for (size_t i = 0; i < num_tensor_specs; i++) { TensorSpec* spec = &tensor_specs[i]; switch (spec->tensor_type) { - case ACTIVATIONS_MULTIUSE: - spec->ptr = model->multiuse_memory + spec->offset; + case MULTIUSE: + spec->ptr = model->tensor_memory[MULTIUSE] + spec->offset; break; default: assert(spec->tensor_type <= PARAMETER_MASTER); - spec->ptr = model->params_memory[spec->tensor_type] + spec->offset; + spec->ptr = model->tensor_memory[spec->tensor_type] + spec->offset; } } @@ -551,7 +364,7 @@ void gpt2_allocate(GPT2 *model) { printf("number of master weight bytes: %zu MiB\n", tensors_bytes[TT::PARAMETER_MASTER] / (1024*1024)); printf("number of m bytes: %zu MiB\n", tensors_bytes[TT::PARAMETER_OPT_M] / (1024*1024)); printf("number of v bytes: %zu MiB\n", tensors_bytes[TT::PARAMETER_OPT_V] / (1024*1024)); - printf("number of act+actgrad+multiuse bytes: %zu MiB\n", tensors_bytes[TT::ACTIVATIONS_MULTIUSE] / (1024*1024)); + printf("number of act+actgrad+multiuse bytes: %zu MiB\n", tensors_bytes[TT::MULTIUSE] / 
(1024*1024)); // ======================= // allocate_state stuff @@ -560,7 +373,7 @@ void gpt2_allocate(GPT2 *model) { cudaMalloc(&gpu_absmax_memory, sizeof(unsigned int) * num_tensor_specs * MAX_ABSMAX_HISTORY); cudaMemset(gpu_absmax_memory, 0, sizeof(unsigned int) * num_tensor_specs * MAX_ABSMAX_HISTORY); - // Initialize gpu_scale_memory with 1.0f for all elements + // Initialize gpu_scale_memory with 1.0f for all elements (todo - could use cuMemsetD8 but runtime vs driver...) size_t scale_memory_elements = 2 * num_tensor_specs; cudaMalloc(&gpu_scale_memory, scale_memory_elements * sizeof(float)); float* h_scale_memory = (float*)malloc(scale_memory_elements * sizeof(float)); @@ -591,7 +404,7 @@ void gpt2_allocate(GPT2 *model) { printf0("device memory usage: %zd MiB / %zd MiB\n", (total-free) / 1024 / 1024, total / 1024 / 1024); // give an estimate of the maximum batch size - size_t bytes_per_sequence = tensors_bytes[TT::ACTIVATIONS_MULTIUSE] / B; // pessimistic (output buffer etc.) + size_t bytes_per_sequence = tensors_bytes[TT::MULTIUSE] / B; // pessimistic (output buffer etc.) printf0("memory per sequence: %zu MiB\n", bytes_per_sequence / 1024 / 1024); printf0(" -> estimated maximum batch size: %zu\n", B + free / bytes_per_sequence); cudaCheck(cudaGetLastError()); @@ -620,7 +433,7 @@ void gpt2_write_to_checkpoint(GPT2 *model, const char* checkpoint_path) { model_header[7] = model->config.padded_vocab_size; fwriteCheck(model_header, sizeof(int), 256, model_file); // write the parameters - device_to_file(model_file, model->params_memory, model->num_parameters_bytes, IO_BUF_SIZE); + device_to_file(model_file, model->tensor_memory, model->num_parameters_bytes, IO_BUF_SIZE); // close file, we're done fcloseCheck(model_file); } @@ -680,7 +493,7 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path, bool w // read in the parameters if weight_init is true if (weight_init) { - file_to_device(model->params_memory[PARAMETER], model_file, model->num_parameters_bytes, IO_BUF_SIZE); + file_to_device(model->tensor_memory[PARAMETER], model_file, model->num_parameters_bytes, IO_BUF_SIZE); } fcloseCheck(model_file); @@ -809,23 +622,10 @@ void gpt_build_from_descriptor(GPT2 *model, const char* descriptor) { } */ // copy them to GPU - cudaCheck(cudaMemcpy(model->params_memory[PARAMETER], params_memory_cpu, model->num_parameters_bytes, cudaMemcpyHostToDevice)); + cudaCheck(cudaMemcpy(model->tensor_memory[PARAMETER], params_memory_cpu, model->num_parameters_bytes, cudaMemcpyHostToDevice)); free(params_memory_cpu); } -// Helper macros for accessing tensors -#define TENSOR(x,layer) get_tensor(x, DEFAULT, layer) -#define ACT_L(x,layer) get_tensor(model->acts.x, ACTIVATIONS_MULTIUSE, layer) -#define MULTI_L(x,layer) get_tensor(model->multiuse.x, ACTIVATIONS_MULTIUSE, layer) -#define AGRAD_L(x,layer) get_tensor(model->acts_grads.x, ACTIVATIONS_MULTIUSE, layer) -#define PARAM_L(x,layer) get_tensor(model->params[PARAMETER].x, PARAMETER, layer) -#define PGRAD_L(x,layer) get_tensor(model->params[PARAMETER_GRAD].x, PARAMETER_GRAD, layer) -#define ACT(x) ACT_L(x,l) -#define MULTI(x) MULTI_L(x,l) -#define AGRAD(x) AGRAD_L(x,l) -#define PARAM(x) PARAM_L(x,l) -#define PGRAD(x) PGRAD_L(x,l) - // propagate inputs through the network to produce logits. 
void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { NVTX_RANGE_FN(); @@ -916,7 +716,6 @@ float gpt2_validate(GPT2 *model, const int* inputs, const int* targets, size_t B void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int grad_accum_steps, int micro_step) { NVTX_RANGE_FN(); - backward = true; // todo - hack - REMOVE // convenience shortcuts (size_t instead of int so that pointer arithmetics don't overflow) const size_t B = model->batch_size; @@ -935,7 +734,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int // 1) the losses accumulate += into acts.losses, reset here // 2) the gradients accumulate += into grads_memory, reset here cudaCheck(cudaMemsetAsync(ACT(losses), 0, B * T * sizeof(float), main_stream)); - cudaCheck(cudaMemsetAsync(model->params_memory[PARAMETER_GRAD], 0, tensors_bytes[PARAMETER_GRAD], main_stream)); + cudaCheck(cudaMemsetAsync(model->tensor_memory[PARAMETER_GRAD], 0, tensors_bytes[PARAMETER_GRAD], main_stream)); } // accumulate the losses inside acts.losses, and kick off the backward pass inside the fused classifier @@ -1073,7 +872,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int float gpt2_calculate_grad_norm(GPT2 *model, MultiGpuConfig* multi_gpu_config) { NVTX_RANGE_FN(); - floatX* grads_memory = (floatX*)model->params_memory[PARAMETER_GRAD]; + floatX* grads_memory = (floatX*)model->tensor_memory[PARAMETER_GRAD]; // repurposing this buffer (which isn't needed now) to write grad norm into it float* grad_norm_squared = MULTI_L(output_scratch_fp32, 0); @@ -1124,7 +923,7 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo // selectively weight decay some, but not all tensors :( // TODO: revisit and probably refactor this entire function NVTX_RANGE_FN(); - if(model->params_memory[PARAMETER] == nullptr || model->params_memory[PARAMETER_OPT_M] == nullptr || model->params_memory[PARAMETER_OPT_V] == nullptr) { + if(model->tensor_memory[PARAMETER] == nullptr || model->tensor_memory[PARAMETER_OPT_M] == nullptr || model->tensor_memory[PARAMETER_OPT_V] == nullptr) { fprintf(stderr, "Need to allocate optimizer state before update"); exit(EXIT_FAILURE); } @@ -1133,8 +932,8 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo if(init_state) { model->init_state = false; NvtxRange rng("InitOpt"); - cudaCheck(cudaMemset(model->params_memory[PARAMETER_OPT_M], 0, multi_gpu_config->shard_num_parameters * sizeof(float))); - cudaCheck(cudaMemset(model->params_memory[PARAMETER_OPT_V], 0, multi_gpu_config->shard_num_parameters * sizeof(float))); + cudaCheck(cudaMemset(model->tensor_memory[PARAMETER_OPT_M], 0, multi_gpu_config->shard_num_parameters * sizeof(float))); + cudaCheck(cudaMemset(model->tensor_memory[PARAMETER_OPT_V], 0, multi_gpu_config->shard_num_parameters * sizeof(float))); } // save RNG state at this point so we can round from master weights identically when restoring from a checkpoint @@ -1193,7 +992,7 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo // - the token embeddings are weight shared and participate in the final projection to logits // - the position embeddings actively participate at every forward/backward pass float wd = (i == 0 || i == 1 || i == 4 || i == 6 || i == 10 || i == 12) ? 
weight_decay : 0.0f; - floatX* param_ptr = (floatX*)model->params_memory + local_offset_full; + floatX* param_ptr = (floatX*)model->tensor_memory + local_offset_full; floatX* grad_ptr = (floatX*)model->grads_memory + local_offset_full; ptrdiff_t opt_state_offset = multi_gpu_config->zero_stage < 1 ? local_offset_full : local_offset_partial; @@ -1224,9 +1023,9 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo #if MULTI_GPU ncclCheck(ncclGroupStart()); for(int l = 0; l < num_layers; ++l) { - // gather updated shards of model->params_memory from each process + // gather updated shards of model->tensor_memory from each process ncclCheck(ncclAllGather(param_ptr + l * tensor.size, - (floatX*) model->params_memory + tensor.offset + l * tensor.size, + (floatX*) model->tensor_memory + tensor.offset + l * tensor.size, shard.size, ncclFloatX, multi_gpu_config->nccl_comm, multi_gpu_config->nccl_stream)); } @@ -1274,11 +1073,9 @@ float gpt2_estimate_mfu(GPT2 *model, int num_tokens, float dt) { } void gpt2_free(GPT2 *model) { - cudaFreeCheck(&model->multiuse_memory); - for (int i = 0; i < TT::NUM_TYPES_PARAM; i++) { - cudaFreeCheck(&model->params_memory[i]); + for (int i = 0; i < TT::COUNT; i++) { + cudaFreeCheck(&model->tensor_memory[i]); } - cudaFreeCheck(&model->inputs); cudaFreeCheck(&model->targets); cudaFreeCheck(&model->accumulated_mean_loss); @@ -1350,10 +1147,10 @@ void save_state(const char* filename, int step, GPT2* model, DataLoader* loader) // write AdamW m, v, and master_weights here (they are all float) size_t shard_num_parameters = multi_gpu_config.shard_num_parameters; - device_to_file(state_file, model->params_memory[PARAMETER_OPT_M], shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); - device_to_file(state_file, model->params_memory[PARAMETER_OPT_V], shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + device_to_file(state_file, model->tensor_memory[PARAMETER_OPT_M], shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + device_to_file(state_file, model->tensor_memory[PARAMETER_OPT_V], shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); if(model->use_master_weights) { - device_to_file(state_file, model->params_memory[PARAMETER_MASTER], shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + device_to_file(state_file, model->tensor_memory[PARAMETER_MASTER], shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); } // write dataloader state if we are using the Permuted version of it @@ -1394,13 +1191,13 @@ void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename } model->init_state = false; // we just got the state from file, no need to do first-touch init - assert(model->params_memory[PARAMETER_OPT_M] != nullptr); - assert(model->params_memory[PARAMETER_OPT_V] != nullptr); - file_to_device(model->params_memory[PARAMETER_OPT_M], state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); - file_to_device(model->params_memory[PARAMETER_OPT_V], state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + assert(model->tensor_memory[PARAMETER_OPT_M] != nullptr); + assert(model->tensor_memory[PARAMETER_OPT_V] != nullptr); + file_to_device(model->tensor_memory[PARAMETER_OPT_M], state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + file_to_device(model->tensor_memory[PARAMETER_OPT_V], state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); if(model->use_master_weights) { - 
assert(model->params_memory[PARAMETER_MASTER] != nullptr); - file_to_device(model->params_memory[PARAMETER_MASTER], state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + assert(model->tensor_memory[PARAMETER_MASTER] != nullptr); + file_to_device(model->tensor_memory[PARAMETER_MASTER], state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); // restore weights from the master weights using the RNG state before last weight update model->rng_state = model->rng_state_last_update; gpt2_update(model, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0, &multi_gpu_config, /* init_from_master_only*/ true); @@ -1689,7 +1486,7 @@ int main(int argc, char *argv[]) { } // build the GPT-2 model - // todo - add model declaration back here + GPT2 model; gpt2_init_common(&model); model.use_master_weights = use_master_weights; model.gelu_fusion = gelu_fusion; @@ -1896,7 +1693,7 @@ int main(int argc, char *argv[]) { // note this is still somewhat wasteful because we don't have a KV cache! gpt2_forward(&model, gen_tokens, 1, T); // get the V-dimensional vector probs[0, t-1, :] - floatX* logits = ((floatX*)&model.multiuse_memory[tensor_specs[model.acts.output].offset]) + (t - 1) * model.config.padded_vocab_size; + floatX* logits = ((floatX*)&model.tensor_memory[MULTIUSE][tensor_specs[model.acts.output].offset]) + (t - 1) * model.config.padded_vocab_size; // move probs back to CPU and sample (note we only move the first vocab_size logits, ignoring the padding) cudaCheck(cudaMemcpy(cpu_logits_raw, logits, model.config.vocab_size * sizeof(floatX), cudaMemcpyDeviceToHost)); // convert to FP32 into cpu_logits (this does nothing useful if floatX == float) From 2900ec9e89ce05bfc5d7516e63b6bdefd89a4cff Mon Sep 17 00:00:00 2001 From: ademeure Date: Mon, 16 Sep 2024 18:31:13 +0000 Subject: [PATCH 16/27] WIP FP8 forward (doesn't quite work yet - obviously...) 
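This commit routes the forward matmuls through cublasLtMatmul in FP8: forward activations and most 2D weights use e4m3 (the float8e4 typedef), gradients are meant to use e5m2 (float8e5), and every tensor carries a tracked absmax plus a {scale, descale} pair that the tensor128 load/store helpers and matmul_cublaslt_fp8 consume. Note the cuBLASLt FP8 constraints visible in the matmul code below: A is always transposed, the C layout stays BF16 for accumulation, and D gets its own scale and amax pointers when it is FP8.

The sketch below only illustrates the per-tensor delayed-scaling idea, assuming the conventional absmax-based formula; the kernel and parameter names are made up for this note, and the actual update_scale_descale_kernel in llmc/tensor.cuh applies its own multipliers on top of the tracked absmax.

// Illustrative sketch only (not the patch's kernel): derive a {scale, descale} pair per tensor
// from its recorded absmax, stored as two floats per tensor.
__global__ void sketch_update_scale_descale(const unsigned int* absmax_bits, // per-tensor absmax, as float bits
                                            float* scale_descale,            // 2 floats per tensor: {scale, descale}
                                            int num_tensors) {
    int t = blockIdx.x * blockDim.x + threadIdx.x;
    if (t >= num_tensors) return;
    float absmax = __uint_as_float(absmax_bits[t]);
    // 448 is the largest finite e4m3 value; fall back to 1.0 if no absmax has been recorded yet
    float scale = (absmax > 0.0f) ? 448.0f / absmax : 1.0f;
    scale_descale[2 * t + 0] = scale;        // applied before storing to FP8
    scale_descale[2 * t + 1] = 1.0f / scale; // applied after loading from FP8
}
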
--- llmc/adamw.cuh | 5 ++ llmc/cuda_common.h | 12 +++++ llmc/gelu.cuh | 4 +- llmc/layernorm.cuh | 50 ++++++++++-------- llmc/matmul.cuh | 123 ++++++++++++++++++++++++++++++++++++++++++--- llmc/tensor.cuh | 28 +++++++---- train_gpt2.cu | 89 ++++++++++++++++---------------- 7 files changed, 227 insertions(+), 84 deletions(-) diff --git a/llmc/adamw.cuh b/llmc/adamw.cuh index e5f53a140..82288b39f 100644 --- a/llmc/adamw.cuh +++ b/llmc/adamw.cuh @@ -142,6 +142,11 @@ __global__ void adamw_full_update(TensorSpec* specs, unsigned int seed, idx += stride; } out_param128.update_absmax(threadIdx.x, block_size, false); + } else if (specs[spec_id].data_type == DType::FP8E4M3) { + TensorGPU param_tensor = specs[spec_id]; + auto out_param128 = new_tensor128(param_tensor); + return; + // todo } else { assert(false); // TODO } diff --git a/llmc/cuda_common.h b/llmc/cuda_common.h index c199ca129..7e5d265e3 100644 --- a/llmc/cuda_common.h +++ b/llmc/cuda_common.h @@ -98,6 +98,18 @@ typedef __nv_bfloat16 floatX; #define DTYPE_FLOATX DType::BF16 #endif +#if defined(ENABLE_FP8) +typedef __nv_fp8_e4m3 float8e4; +typedef __nv_fp8_e5m2 float8e5; +#define DTYPE_FP8E4 DType::FP8E4M3 +#define DTYPE_FP8E5 DType::FP8E5M2 +#else +typedef floatX float8e4; +typedef floatX float8e5; +#define DTYPE_FP8E4 DTYPE_FLOATX +#define DTYPE_FP8E5 DTYPE_FLOATX +#endif + // ---------------------------------------------------------------------------- // Load and store with streaming cache hints // Older nvcc does not provide __ldcs and __stcs for bfloat16, despite these diff --git a/llmc/gelu.cuh b/llmc/gelu.cuh index d58825549..d334c2b03 100644 --- a/llmc/gelu.cuh +++ b/llmc/gelu.cuh @@ -65,7 +65,7 @@ __global__ void gelu_backward_kernel(tensorFP8e5 dinp, tensorFP8e5 dout, TensorG // ---------------------------------------------------------------------------- // kernel launchers -void gelu_forward(tensorX out, tensorX inp, cudaStream_t stream=main_stream) { +void gelu_forward(tensorFP8e4 out, tensorFP8e4 inp, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); const int block_size = 256; assert(inp.num_elements % (block_size * inp.num_per_128()) == 0); @@ -75,7 +75,7 @@ void gelu_forward(tensorX out, tensorX inp, cudaStream_t stream=main_stream) { cudaCheck(cudaGetLastError()); } -void gelu_backward(tensorX dinp, tensorX dout, tensorX inp, cudaStream_t stream=main_stream) { +void gelu_backward(tensorFP8e5 dinp, tensorFP8e5 dout, tensorFP8e4 inp, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); const int block_size = 256; const int grid_size = CEIL_DIV(inp.num_elements, block_size * inp.num_per_128()); diff --git a/llmc/layernorm.cuh b/llmc/layernorm.cuh index f8b166622..17d1b27b7 100644 --- a/llmc/layernorm.cuh +++ b/llmc/layernorm.cuh @@ -13,13 +13,15 @@ E.g., the layernorms are connected to the residuals so we += in layernorm backwa // llmc internal imports #include "cuda_common.h" #include "cuda_utils.cuh" +#include "tensor.cuh" // ---------------------------------------------------------------------------- // CUDA kernels -__global__ void layernorm_forward_kernel6(tensorFP8e4 out, tensorFP32 mean, tensorFP32 rstd, - tensorFP8e4 inp, tensorFP8e4 weight, - const tensorX bias, int N, int C) { +template +__global__ void layernorm_forward_kernel6(TensorGPU out, tensorFP32 mean, tensorFP32 rstd, + tensorX inp, tensorX weight, + tensorX bias, int N, int C) { // Note that blockDim.x must be WARP_SIZE=32 but we don't want to pay the cost of assert() here int idx = blockIdx.x * blockDim.y + threadIdx.y; // non-standard: threadIdx.x is 
used for c if(idx >= N) { return; } @@ -62,7 +64,7 @@ __global__ void layernorm_forward_kernel6(tensorFP8e4 out, tensorFP32 mean, tens float o = n * w128.get(k) + b128.get(k); // scale and shift it out128.set(k, o); } - out128.store_same_length(idx * C + c); + out128.template store_same_length(idx * C + c); } // cache the mean and rstd for the backward pass later if(threadIdx.x == 0) { // todo - add a way to pass equivalent of null for mean/rstd to avoid store @@ -73,8 +75,9 @@ __global__ void layernorm_forward_kernel6(tensorFP8e4 out, tensorFP32 mean, tens out128.update_absmax(threadIdx.x + threadIdx.y * blockDim.x, blockDim.x * blockDim.y, true); } -__global__ void fused_residual_forward_kernel5(tensorX residual_, tensorFP8e4 normed_, tensorFP32 mean, tensorFP32 rstd, - const tensorX inp1_, const tensorFP8e4 inp2_, +template +__global__ void fused_residual_forward_kernel5(tensorX residual, TensorGPU normed, tensorFP32 mean, tensorFP32 rstd, + const tensorX inp1, const TensorGPU inp2, const tensorX weight, const tensorX bias, int N, int C) { // Note that blockDim.x must be WARP_SIZE=32 but we don't want to pay the cost of assert() here @@ -85,20 +88,20 @@ __global__ void fused_residual_forward_kernel5(tensorX residual_, tensorFP8e4 no extern __shared__ char* params[]; x128* s_res = reinterpret_cast(params) + (threadIdx.y * C / x128::size); - auto residual128 = new_tensor128(residual_); - auto normed128 = new_tensor128(normed_); + auto residual128 = new_tensor128(residual); + auto normed128 = new_tensor128(normed); const float eps = 1e-5f; float sum = 0.0f; for(int c = threadIdx.x * x128::size; c < C; c += WARP_SIZE * x128::size) { - auto inp1_128 = load_tensor128(inp1_, idx * C + c, true); - auto inp2_128 = load_tensor128(inp2_, idx * C + c, true); + auto inp1_128 = load_tensor128(inp1, idx * C + c, true); + auto inp2_128 = load_tensor128(inp2, idx * C + c, true); for(int k = 0; k < x128::size; ++k) { float out = inp1_128.get(k) + inp2_128.get(k); residual128.set(k, out); sum += residual128.get(k); } - residual128.store_same_length(idx * C + c, false); + residual128.store(idx * C + c, false); s_res[c / x128::size] = residual128.get128(); } @@ -125,7 +128,7 @@ __global__ void fused_residual_forward_kernel5(tensorX residual_, tensorFP8e4 no float o = n * w128.get(k) + b128.get(k); // scale and shift it normed128.set(k, o); } - normed128.store_same_length(idx * C + c, false); + normed128.template store_same_length(idx * C + c, false); } // cache the mean and rstd for the backward pass later if(threadIdx.x == 0) { @@ -133,15 +136,15 @@ __global__ void fused_residual_forward_kernel5(tensorX residual_, tensorFP8e4 no __stcs(rstd + idx, s); } - // Update absmax for both residual and normed tensors + // Update absmax for residual and normed tensors (typically it will skip residual as it is not FP8) residual128.update_absmax(threadIdx.x + threadIdx.y * blockDim.x, blockDim.x * blockDim.y, false); normed128.update_absmax(threadIdx.x + threadIdx.y * blockDim.x, blockDim.x * blockDim.y, true); } -template +template __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with only 1024 threads? layernorm_backward_kernel10(tensorX dinp_new, tensorX dinp_old, tensorX dweight, tensorX dbias, tensorFP32 scratch_, - tensorFP8e5 dout, tensorX inp, tensorX weight, tensorFP32 mean, tensorFP32 rstd, + TensorGPU dout, tensorX inp, tensorX weight, tensorFP32 mean, tensorFP32 rstd, int BT, int C) { int BLOCK_SIZE = blockDim.x; // todo - does it make any difference if this is hardcoded here? 
int warpsInBlock = BLOCK_SIZE / WARP_SIZE; //number of warps in block @@ -195,7 +198,7 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with for (int c = 0; c < iterations_C; c++) { int global_index = (warpThreadIdx * x128::size) + (c * C_per_iteration); - tensor128 dout128; + tensor128 dout128; tensor128 inp128; tensor128 weight128; tensor128 dinp128; @@ -370,22 +373,25 @@ void launch_layernorm_kernel(KernelFunc kernel, int N, int C, cudaStream_t strea cudaCheck(cudaGetLastError()); } -void layernorm_forward(tensorX out, tensorFP32 mean, tensorFP32 rstd, +template +void layernorm_forward(TensorGPU out, tensorFP32 mean, tensorFP32 rstd, tensorX inp, const tensorX weight, const tensorX bias, int N, int C, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); - launch_layernorm_kernel(layernorm_forward_kernel6, N, C, stream, out, mean, rstd, inp, weight, bias); + launch_layernorm_kernel(layernorm_forward_kernel6, N, C, stream, out, mean, rstd, inp, weight, bias); } -void fused_residual_forward5(tensorX residual, tensorX normed, tensorFP32 mean, tensorFP32 rstd, - tensorX inp1, tensorX inp2, tensorX weight, tensorX bias, +template +void fused_residual_forward5(tensorX residual, TensorGPU normed, tensorFP32 mean, tensorFP32 rstd, + tensorX inp1, TensorGPU inp2, tensorX weight, tensorX bias, int N, int C, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); - launch_layernorm_kernel(fused_residual_forward_kernel5, N, C, stream, residual, normed, mean, rstd, inp1, inp2, weight, bias); + launch_layernorm_kernel(fused_residual_forward_kernel5, N, C, stream, residual, normed, mean, rstd, inp1, inp2, weight, bias); } +template void layernorm_backward(tensorX dinp_new, tensorX dinp_old, tensorX dweight, tensorX dbias, tensorFP32 scratch, - const tensorX dout, const tensorX inp, const tensorX weight, tensorFP32 mean, tensorFP32 rstd, + const TensorGPU dout, const tensorX inp, const tensorX weight, tensorFP32 mean, tensorFP32 rstd, int BT, int C, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); const int block_size = 512; diff --git a/llmc/matmul.cuh b/llmc/matmul.cuh index d413b12c4..a915fb0b5 100644 --- a/llmc/matmul.cuh +++ b/llmc/matmul.cuh @@ -258,18 +258,125 @@ void matmul_cublaslt(tensorX d, const tensorX a, const tensorX b, const tensorX cudaCheck(cudaGetLastError()); } -template +// Wrapper around cublasLtMatmul that is meant to support everything we need in llm.c +// https://docs.nvidia.com/cuda/cublas/#cublasltmatmul +template +void matmul_cublaslt_fp8(TensorGPU d, const TensorGPU a, const TensorGPU b, const tensorX bias, + int m, int n, int k, cudaStream_t stream=main_stream, + bool accumulate=false, bool backward=false) +{ + NVTX_RANGE_FN(); + if(((uintptr_t)a.data_ptr % 16) != 0 || ((uintptr_t)b.data_ptr % 16) != 0 || ((uintptr_t)d.data_ptr % 16) != 0 || ((uintptr_t)bias.data_ptr % 16) != 0) { + printf("All cuBLASLt pointers must be aligned!\n"); + exit(EXIT_FAILURE); + } + + // create the operation descriptor + cublasLtMatmulDesc_t operationDesc; + cublasCheck(cublasLtMatmulDescCreate(&operationDesc, cublas_compute, CUDA_R_32F)); + + cublasOperation_t opTranspose = CUBLAS_OP_T, opNoTranspose = CUBLAS_OP_N; + cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opTranspose, sizeof(opTranspose))); + cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opNoTranspose, sizeof(opNoTranspose))); + + // define matrix layouts + cublasLtMatrixLayout_t ALayout, BLayout, CLayout, DLayout; + 
cublasDataType_t typeA = std::is_same::value ? CUDA_R_8F_E4M3 : CUDA_R_8F_E5M2; + cublasDataType_t typeB = std::is_same::value ? CUDA_R_8F_E4M3 : CUDA_R_8F_E5M2; + cublasDataType_t typeD = std::is_same::value ? CUBLAS_LOWP : + (std::is_same::value ? CUDA_R_8F_E4M3 : CUDA_R_8F_E5M2); + + cublasCheck(cublasLtMatrixLayoutCreate(&ALayout, typeA, k, m, k)); // always transposed for FP8 + cublasCheck(cublasLtMatrixLayoutCreate(&BLayout, typeB, k, n, k)); // never transposed for FP8 + cublasCheck(cublasLtMatrixLayoutCreate(&CLayout, CUBLAS_LOWP, m, n, m)); // must be BF16 for accumulation in cuBLASLt + cublasCheck(cublasLtMatrixLayoutCreate(&DLayout, typeD, m, n, m)); + + // setup epilogue and associated pointers for bias + cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT; + if(bias.data_ptr != NULL) { + epilogue = backward ? CUBLASLT_EPILOGUE_BGRADB : CUBLASLT_EPILOGUE_BIAS; + cublasDataType_t bias_data_type = CUBLAS_LOWP; // BF16 bias for FP8 mode + cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_data_type, sizeof(bias_data_type))); + cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias.data_ptr, sizeof(bias.data_ptr))); + } + cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); + + // FP8 scale factors and absmax pointers + float* a_descale_ptr = a.scale_descale_ptr + 1; + float* b_descale_ptr = b.scale_descale_ptr + 1; + cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_descale_ptr, sizeof(float*))); + cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_descale_ptr, sizeof(float*))); + if (sizeof(Td) == 1) { + cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &d.scale_descale_ptr, sizeof(float*))); + cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &d.absmax_ptr, sizeof(float*))); + } + + // set scale type to FP32 (needs to be FP16 if and only if using CUBLAS_COMPUTE_16F, so it's FP32 even for FP8!) + cublasDataType_t scale_type = CUDA_R_32F; + cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type))); + + // create a preference handle with specified max workspace + cublasLtMatmulPreference_t preference; + cublasCheck(cublasLtMatmulPreferenceCreate(&preference)); + cublasCheck(cublasLtMatmulPreferenceSetAttribute(preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &cublaslt_workspace_size, sizeof(cublaslt_workspace_size))); + + // find a suitable algorithm (cached internally so shouldn't take much CPU time in practice) + int returnedResults = 0; + cublasLtMatmulHeuristicResult_t heuristic; + + cublasLtMatmulAlgoGetHeuristic(cublaslt_handle, operationDesc, ALayout, BLayout, CLayout, DLayout, + preference, 1, &heuristic, &returnedResults); + + if (returnedResults == 0) { + printf("No cuBLASLt FP8 algorithm: m: %d, n: %d, k: %d, bias: %d\n", n, m, k, (bias.data_ptr != NULL)); + exit(EXIT_FAILURE); + } + + // set whether to accumulate (i.e. D += C) or not - note this isn't considered in algorithm selection (?!) + const float alpha = 1.0f, beta = accumulate ? 
1.0f : 0.0f; + + // call the matmul + cublasCheck(cublasLtMatmul(cublaslt_handle, operationDesc, + &alpha, a, ALayout, b, BLayout, &beta, d, CLayout, d, DLayout, + &heuristic.algo, cublaslt_workspace, cublaslt_workspace_size, stream)); + + // cleanups + cublasCheck(cublasLtMatmulPreferenceDestroy(preference)); + cublasCheck(cublasLtMatmulDescDestroy(operationDesc)); + cublasCheck(cublasLtMatrixLayoutDestroy(ALayout)); + cublasCheck(cublasLtMatrixLayoutDestroy(BLayout)); + cublasCheck(cublasLtMatrixLayoutDestroy(CLayout)); + cublasCheck(cublasLtMatrixLayoutDestroy(DLayout)); + cudaCheck(cudaGetLastError()); +} + +template // small wrapper around matmul_cublaslt for the forward pass (keeping historical order of arguments) -void matmul_forward_cublaslt(tensorX out, - tensorX inp, tensorX weight, tensorX bias, +void matmul_forward_cublaslt(TensorGPU out, + TensorGPU inp, TensorGPU weight, tensorX bias, int BT, int C, int OC, - TensorGPU pre_gelu=null_tensorX, int gelu_fusion=1, cudaStream_t stream=main_stream) { + TensorGPU pre_gelu=TensorGPU(), int gelu_fusion=1, cudaStream_t stream=main_stream) { // By default only fuse GELU for H100+ as cuBLAS seems to be inefficient for fused GELU on Ada/Ampere (?) - if (gelu_fusion < 1 && pre_gelu != null_tensorX) { - matmul_cublaslt(pre_gelu, weight, inp, bias, OC, BT, C, stream, true, false, 0, 0, 0, 0, false, null_tensorX, false); - gelu_forward(out, pre_gelu, stream); + + if constexpr (sizeof(Tin) == 1) { + if (pre_gelu.enabled()) { + matmul_cublaslt_fp8(pre_gelu, weight, inp, bias, OC, BT, C, stream, false, false); + if constexpr (sizeof(Tout) == 1) { // todo - hack to avoid error for case we will never see + gelu_forward(out, pre_gelu, stream); + } + } else { + matmul_cublaslt_fp8(out, weight, inp, bias, OC, BT, C, stream, false, false); + } } else { - matmul_cublaslt(out, weight, inp, bias, OC, BT, C, stream, true, false, 0, 0, 0, 0, false, pre_gelu, false); + if (gelu_fusion < 1 && pre_gelu.enabled()) { + matmul_cublaslt(pre_gelu, weight, inp, bias, OC, BT, C, stream, true, false, 0, 0, 0, 0, false, null_tensorX, false); + if constexpr (sizeof(Tout) == sizeof(float8e4)) { + gelu_forward(out, pre_gelu, stream); // todo - same hack + } + } else { + matmul_cublaslt(out, weight, inp, bias, OC, BT, C, stream, true, false, 0, 0, 0, 0, false, pre_gelu, false); + } } } diff --git a/llmc/tensor.cuh b/llmc/tensor.cuh index ddee7f0e9..70cb7c9a4 100644 --- a/llmc/tensor.cuh +++ b/llmc/tensor.cuh @@ -2,8 +2,8 @@ #define TENSOR_CUH // ... -#define FAKE_FP8 -#define UNIQUE_TENSOR_MEMORY false +//#define FAKE_FP8 +#define UNIQUE_TENSOR_MEMORY true #define LAYERS_PER_ACTIVATION_CHECKPOINT 0 // 0 = disabled // ... 
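// Context for the three toggles above: when FAKE_FP8 is defined, get_scalar/set_scalar and the
// tensor128 constructor force disable_scaling, so the FP8-shaped code paths run without applying
// real per-tensor scales; commenting it out (as done here) turns real scaling on.
// UNIQUE_TENSOR_MEMORY=true gives every activation/gradient its own allocation instead of the
// REUSED_MEMORY aliases in gpt2_allocate (multiuse.btc / multiuse.bt4c / output_scratch), which
// costs VRAM but presumably simplifies debugging while FP8 is brought up.
// LAYERS_PER_ACTIVATION_CHECKPOINT=0 leaves activation checkpointing disabled.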
@@ -71,16 +71,22 @@ __device__ __constant__ unsigned int* gpu_absmax_memory_ptr; template struct TensorGPU { - ElementType* data_ptr; - int id; - float* scale_descale_ptr; - unsigned int* absmax_ptr; - size_t num_elements; + ElementType* data_ptr = NULL; + float* scale_descale_ptr = NULL; + unsigned int* absmax_ptr = NULL; + size_t num_elements = 0; + int id = -1; + + bool is_null() const { + return (data_ptr == NULL); + } + bool enabled() const { + return (absmax_ptr != NULL); + } static __device__ __host__ TensorGPU from(ElementType* ptr=nullptr) { - TensorGPU tmp = {0}; + TensorGPU tmp; tmp.data_ptr = ptr; - tmp.id = -1; return tmp; } @@ -146,8 +152,10 @@ typedef TensorGPU tensorFP8e4; typedef TensorGPU tensorFP8e5; #endif -extern TensorGPU null_tensorX; extern TensorGPU null_tensorFP32; +extern TensorGPU null_tensorX; +extern TensorGPU null_tensorFP8E4; +extern TensorGPU null_tensorFP8E5; // ---------------------------------------------------------------------------- diff --git a/train_gpt2.cu b/train_gpt2.cu index cc200f0ef..031088d14 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -1,4 +1,4 @@ -//#define ENABLE_FP8 +#define ENABLE_FP8 /* GPT-2 Transformer Neural Net training loop. See README.md for usage. @@ -98,8 +98,10 @@ int current_absmax_index = 0; float* gpu_scale_memory = NULL; unsigned int* gpu_absmax_memory = NULL; -TensorGPU null_tensorX = {0}; TensorGPU null_tensorFP32 = {0}; +TensorGPU null_tensorX = {0}; +TensorGPU null_tensorFP8E4 = {0}; +TensorGPU null_tensorFP8E5 = {0}; // ---------------------------------------------------------------------------- // GPT-2 model definition @@ -187,7 +189,7 @@ void gpt2_allocate(GPT2 *model) { // 1) parameters & optimizer state for (int t = PARAMETER; t <= PARAMETER_MASTER; t++) { DType dtype = (t <= PARAMETER_GRAD) ? DTYPE_FLOATX : DType::FP32; - DType dtype_lowp = (t <= PARAMETER_GRAD) ? DTYPE_FLOATX : DType::FP32; // FP8 in the future + DType dtype_lowp = (t == PARAMETER) ? DTYPE_FP8E4 : ((t == PARAMETER_GRAD) ? 
DTYPE_FLOATX : DType::FP32); current_tensor_type = (TT)t; ParameterTensors* spec = &model->params[t]; @@ -207,7 +209,7 @@ void gpt2_allocate(GPT2 *model) { TENSOR_SPECS (ln2w, L, C, LAYERNORM); TENSOR_SPECS (ln2b, L, C, LAYERNORM | BIAS); TENSOR_SPECS_LOWP(fcw, L, 4 * C * C, TENSOR_2D); - TENSOR_SPECS_LOWP(fcb, L, 4 * C, BIAS); + TENSOR_SPECS (fcb, L, 4 * C, BIAS); TENSOR_SPECS_LOWP(fcprojw, L, 4 * C * C, TENSOR_2D); TENSOR_SPECS (fcprojb, L, C, BIAS); TENSOR_SPECS (lnfw, 1, C, LAYERNORM); @@ -229,7 +231,7 @@ void gpt2_allocate(GPT2 *model) { // 3) activations ActivationTensors* spec = &model->acts; - DType dtype_lowp = DTYPE_FLOATX; // todo FP8 + DType dtype_lowp = DTYPE_FP8E4; // todo FP8 DType dtype = DTYPE_FLOATX; shards = 1; @@ -254,17 +256,17 @@ void gpt2_allocate(GPT2 *model) { // optionally reuse the same activation buffer at each layer and re-compute the gelu during backward // very useful because we dramatically reduce VRAM usage, and may be able to fit larger batch size if (model->recompute < 1 || UNIQUE_TENSOR_MEMORY) { - TENSOR_SPECS(fch_gelu, L, 4 * BTC, 0); - TENSOR_SPECS(ln1, L, BTC, LAYERNORM); - TENSOR_SPECS(ln2, L, BTC, LAYERNORM); + TENSOR_SPECS_LOWP(fch_gelu, L, 4 * BTC, 0); + TENSOR_SPECS_LOWP(ln1, L, BTC, LAYERNORM); + TENSOR_SPECS_LOWP(ln2, L, BTC, LAYERNORM); } else if (model->recompute < 2) { spec->fch_gelu = add_layer_specs(L, "fch_gelu", 4 * BTC, shards, dtype_lowp, model->multiuse.output_scratch, REUSED_MEMORY); - TENSOR_SPECS(ln1, L, BTC, LAYERNORM); - TENSOR_SPECS(ln2, L, BTC, LAYERNORM); + TENSOR_SPECS_LOWP(ln1, L, BTC, LAYERNORM); + TENSOR_SPECS_LOWP(ln2, L, BTC, LAYERNORM); } else { spec->fch_gelu = add_layer_specs(L, "fch_gelu", 4 * BTC, shards, dtype_lowp, model->multiuse.output_scratch, REUSED_MEMORY); - spec->ln1 = add_layer_specs(L, "ln1", BTC, shards, dtype, model->multiuse.btc, LAYERNORM | REUSED_MEMORY); - spec->ln2 = add_layer_specs(L, "ln2", BTC, shards, dtype, model->multiuse.btc, LAYERNORM | REUSED_MEMORY); + spec->ln1 = add_layer_specs(L, "ln1", BTC, shards, dtype_lowp, model->multiuse.btc, LAYERNORM | REUSED_MEMORY); + spec->ln2 = add_layer_specs(L, "ln2", BTC, shards, dtype_lowp, model->multiuse.btc, LAYERNORM | REUSED_MEMORY); } TENSOR_SPECS_FP32(ln1_mean, L, BT, LAYERNORM | STATS); TENSOR_SPECS_FP32(ln1_rstd, L, BT, LAYERNORM | STATS); @@ -291,27 +293,27 @@ void gpt2_allocate(GPT2 *model) { // todo - is "LAYERNORM" applied logically here? do we care? 
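// Dtype summary for this section of gpt2_allocate: the forward activations that feed FP8 matmuls
// (ln1, ln2, fch_gelu) now use e4m3 via dtype_lowp, the activation-gradient specs below switch
// dtype_lowp to e5m2, while the residual streams stay in floatX and the layernorm mean/rstd
// statistics remain FP32.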
reuse_every_n = 0; spec = &model->acts_grads; - dtype_lowp = DTYPE_FLOATX; // todo FP8 + dtype_lowp = DTYPE_FP8E5; shards = 1; if (UNIQUE_TENSOR_MEMORY) { - TENSOR_SPECS(encoded, 1, BTC, GRADIENT | EMBEDDING); - TENSOR_SPECS(output, 1, output_size, GRADIENT | EMBEDDING); - TENSOR_SPECS(lnf, 1, BTC, GRADIENT | LAYERNORM | TENSOR_2D); - TENSOR_SPECS(ln1, L, BTC, GRADIENT | LAYERNORM | TENSOR_2D); - TENSOR_SPECS(atty, L, BTC, GRADIENT | TENSOR_2D); - TENSOR_SPECS(residual2, L, BTC, GRADIENT | RESIDUAL); - TENSOR_SPECS(ln2, L, BTC, GRADIENT | LAYERNORM | TENSOR_2D); - TENSOR_SPECS(fch, L, 4 * BTC, GRADIENT | TENSOR_2D); - TENSOR_SPECS(fch_gelu, L, 4 * BTC, GRADIENT | TENSOR_2D); - TENSOR_SPECS(residual3, L, BTC, GRADIENT | RESIDUAL); - TENSOR_SPECS(qkvr, L, 3 * BTC, GRADIENT); + TENSOR_SPECS (encoded, 1, BTC, GRADIENT | EMBEDDING); + TENSOR_SPECS (output, 1, output_size, GRADIENT | EMBEDDING); + TENSOR_SPECS_LOWP(lnf, 1, BTC, GRADIENT | LAYERNORM | TENSOR_2D); + TENSOR_SPECS_LOWP(ln1, L, BTC, GRADIENT | LAYERNORM | TENSOR_2D); + TENSOR_SPECS (atty, L, BTC, GRADIENT | TENSOR_2D); + TENSOR_SPECS (residual2, L, BTC, GRADIENT | RESIDUAL); + TENSOR_SPECS_LOWP(ln2, L, BTC, GRADIENT | LAYERNORM | TENSOR_2D); + TENSOR_SPECS_LOWP(fch, L, 4 * BTC, GRADIENT | TENSOR_2D); + TENSOR_SPECS_LOWP(fch_gelu, L, 4 * BTC, GRADIENT | TENSOR_2D); + TENSOR_SPECS (residual3, L, BTC, GRADIENT | RESIDUAL); + TENSOR_SPECS (qkvr, L, 3 * BTC, GRADIENT); } else { spec->output = add_layer_specs(1, "output", output_size, 1, dtype, model->multiuse.output_scratch, GRADIENT | EMBEDDING); int reused_btc = model->acts.residual3 + (L-1); // todo - check if this works with activation checkpointing - spec->ln1 = add_layer_specs(L, "ln1", BTC, 1, dtype, reused_btc, GRADIENT | LAYERNORM | TENSOR_2D); - spec->ln2 = add_layer_specs(L, "ln2", BTC, 1, dtype, reused_btc, GRADIENT | LAYERNORM | TENSOR_2D); + spec->ln1 = add_layer_specs(L, "ln1", BTC, 1, dtype_lowp, reused_btc, GRADIENT | LAYERNORM | TENSOR_2D); + spec->ln2 = add_layer_specs(L, "ln2", BTC, 1, dtype_lowp, reused_btc, GRADIENT | LAYERNORM | TENSOR_2D); spec->atty = add_layer_specs(L, "atty", BTC, 1, dtype, reused_btc, GRADIENT | TENSOR_2D); int reused_btc2 = model->acts.lnf; @@ -320,9 +322,9 @@ void gpt2_allocate(GPT2 *model) { spec->encoded = add_layer_specs(1, "encoded", BTC, 1, dtype, reused_btc2, GRADIENT | EMBEDDING); // (lnf doesn't need bt4c but it's free at this point unlike the other buffers) - spec->lnf = add_layer_specs(1, "lnf", BTC, 1, dtype, model->multiuse.bt4c, GRADIENT | LAYERNORM | TENSOR_2D); - spec->fch_gelu = add_layer_specs(L, "fch_gelu", 4 * BTC, 1, dtype, model->multiuse.bt4c, GRADIENT | TENSOR_2D); - spec->fch = add_layer_specs(L, "fch", 4 * BTC, 1, dtype, model->multiuse.bt4c, GRADIENT | TENSOR_2D); + spec->lnf = add_layer_specs(1, "lnf", BTC, 1, dtype_lowp, model->multiuse.bt4c, GRADIENT | LAYERNORM | TENSOR_2D); + spec->fch_gelu = add_layer_specs(L, "fch_gelu", 4 * BTC, 1, dtype_lowp, model->multiuse.bt4c, GRADIENT | TENSOR_2D); + spec->fch = add_layer_specs(L, "fch", 4 * BTC, 1, dtype_lowp, model->multiuse.bt4c, GRADIENT | TENSOR_2D); spec->qkvr = add_layer_specs(L, "qkvr", 3 * BTC, 1, dtype, model->multiuse.bt4c, GRADIENT | TENSOR_2D); } @@ -652,7 +654,7 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { // start of forward pass with encoder (layer 0) int l = 0; encoder_forward(ACT(encoded), model->inputs, PARAM(wte), PARAM(wpe), B, T, C); - layernorm_forward(ACT(ln1), ACT(ln1_mean), ACT(ln1_rstd), ACT(encoded), PARAM(ln1w), 
PARAM(ln1b), B*T, C); + layernorm_forward(ACT(ln1), ACT(ln1_mean), ACT(ln1_rstd), ACT(encoded), PARAM(ln1w), PARAM(ln1b), B*T, C); for (; l < L; l++) { NvtxRange layer_range("Layer", l); @@ -660,28 +662,28 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { tensorX qkvr = ACT(qkvr); // non-cudnn reuses tensor with different memory pre/post-permute qkvr.data_ptr = CUDNN_ENABLED ? ACT(qkvr) : MULTI(output_scratch); - matmul_forward_cublaslt(qkvr, ACT(ln1), PARAM(qkvw), PARAM(qkvb), B*T, C, 3*C); + matmul_forward_cublaslt(qkvr, ACT(ln1), PARAM(qkvw), PARAM(qkvb), B*T, C, 3*C); #ifdef ENABLE_CUDNN attention_forward_cudnn(ACT(atty), ACT(att), ACT(qkvr), B, T, NH, C); #else attention_forward(ACT(atty), ACT(qkvr), ACT(att), qkvr, B, T, C, NH); #endif - matmul_forward_cublaslt(ACT(attproj), ACT(atty), PARAM(attprojw), PARAM(attprojb), B*T, C, C); - fused_residual_forward5(ACT(residual2), ACT(ln2), ACT(ln2_mean), ACT(ln2_rstd), residual, ACT(attproj), PARAM(ln2w), PARAM(ln2b), B*T, C); - matmul_forward_cublaslt(ACT(fch_gelu), ACT(ln2), PARAM(fcw), PARAM(fcb), B*T, C, 4*C, ACT(fch), model->gelu_fusion); - matmul_forward_cublaslt(ACT(fcproj), ACT(fch_gelu), PARAM(fcprojw), PARAM(fcprojb), B*T, 4*C, C); + matmul_forward_cublaslt(ACT(attproj), ACT(atty), PARAM(attprojw), PARAM(attprojb), B*T, C, C); + fused_residual_forward5(ACT(residual2), ACT(ln2), ACT(ln2_mean), ACT(ln2_rstd), residual, ACT(attproj), PARAM(ln2w), PARAM(ln2b), B*T, C); + matmul_forward_cublaslt(ACT(fch_gelu), ACT(ln2), PARAM(fcw), PARAM(fcb), B*T, C, 4*C, ACT(fch), model->gelu_fusion); + matmul_forward_cublaslt(ACT(fcproj), ACT(fch_gelu), PARAM(fcprojw), PARAM(fcprojb), B*T, 4*C, C); if(l+1 != L) { - fused_residual_forward5(ACT(residual3), ACT_L(ln1, l+1), ACT_L(ln1_mean, l+1), ACT_L(ln1_rstd, l+1), ACT(residual2), ACT(fcproj), - PARAM_L(ln1w, l+1), PARAM_L(ln1b, l+1), B*T, C); + fused_residual_forward5(ACT(residual3), ACT_L(ln1, l+1), ACT_L(ln1_mean, l+1), ACT_L(ln1_rstd, l+1), ACT(residual2), ACT(fcproj), + PARAM_L(ln1w, l+1), PARAM_L(ln1b, l+1), B*T, C); } else { - fused_residual_forward5(ACT(residual3), ACT(lnf), ACT(lnf_mean), ACT(lnf_rstd), ACT(residual2), ACT(fcproj), - PARAM(lnfw), PARAM(lnfb), B*T, C); + fused_residual_forward5(ACT(residual3), ACT(lnf), ACT(lnf_mean), ACT(lnf_rstd), ACT(residual2), ACT(fcproj), + PARAM(lnfw), PARAM(lnfb), B*T, C); } } - matmul_forward_cublaslt(ACT(output), ACT(lnf), PARAM(wte), null_tensorX, B*T, C, Vp); + matmul_forward_cublaslt(ACT(output), ACT(lnf), PARAM(wte), null_tensorX, B*T, C, Vp); } @@ -716,6 +718,8 @@ float gpt2_validate(GPT2 *model, const int* inputs, const int* targets, size_t B void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int grad_accum_steps, int micro_step) { NVTX_RANGE_FN(); + exit(0); +#if 0 // convenience shortcuts (size_t instead of int so that pointer arithmetics don't overflow) const size_t B = model->batch_size; @@ -757,8 +761,8 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int // next: backward the classifier matmul matmul_backward(AGRAD(lnf), PGRAD(wte), null_tensorX, AGRAD(output), ACT(lnf), PARAM(wte), scratchF, B*T, C, Vp); // backward the final layernorm - layernorm_backward(AGRAD_L(residual3, L-1), null_tensorX, PGRAD(lnfw), PGRAD(lnfb), scratchF, AGRAD(lnf), ACT_L(residual3, L-1), - PARAM(lnfw), ACT(lnf_mean), ACT(lnf_rstd), B*T, C); + layernorm_backward(AGRAD_L(residual3, L-1), null_tensorX, PGRAD(lnfw), PGRAD(lnfb), scratchF, AGRAD(lnf), ACT_L(residual3, L-1), + 
PARAM(lnfw), ACT(lnf_mean), ACT(lnf_rstd), B*T, C); // now backward all the layers for (; l >= 0; l--) { @@ -868,6 +872,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int } else { model->mean_loss = -1.f; // no loss available yet } +#endif } float gpt2_calculate_grad_norm(GPT2 *model, MultiGpuConfig* multi_gpu_config) { From 2e26a9c70862a344afdf0bd8a4c7580012d57601 Mon Sep 17 00:00:00 2001 From: ademeure Date: Tue, 17 Sep 2024 02:17:21 +0000 Subject: [PATCH 17/27] FP8 forward working again (!!!) + AdamW for FP8/BF16/FP32 --- llmc/adamw.cuh | 217 ++++++++++++++++++++++++-------------------- llmc/copy_and_fp8.h | 76 +++++----------- llmc/cuda_utils.cuh | 4 +- llmc/gelu.cuh | 4 +- llmc/layernorm.cuh | 18 ++-- llmc/tensor.cuh | 178 ++++++++++++++++++++---------------- train_gpt2.cu | 131 ++++++++++++++++++++------ 7 files changed, 355 insertions(+), 273 deletions(-) diff --git a/llmc/adamw.cuh b/llmc/adamw.cuh index 82288b39f..f33c773a5 100644 --- a/llmc/adamw.cuh +++ b/llmc/adamw.cuh @@ -15,23 +15,100 @@ __device__ float lerp(float start, float end, float weight) { return fma(weight, end, fma(-weight, start, start)); } +// always sizeof(param) <= sizeof(grad) <= sizeof(opt/master) <= sizeof(float) +template +__device__ size_t adamw_update_part(TensorGPU param_tensor, size_t idx, int spec_id, size_t current_start, size_t current_end, size_t stride, + TensorGPU grad_tensor, TensorGPU master_tensor, TensorGPU opt_m_tensor, TensorGPU opt_v_tensor, + unsigned int seed, int num_params_tensors, size_t num_parameters, size_t num_opt_parameters, + float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, + float eps, float wd, float grad_scale, int t) { + constexpr size_t block_size = 64; + auto out_master128 = new_tensor128(master_tensor, true); + auto out_opt_m128 = new_tensor128(opt_m_tensor, true); + auto out_opt_v128 = new_tensor128(opt_v_tensor, true); + auto out_param128 = new_tensor128(param_tensor); + + __syncthreads(); // todo - hopefully results in better memory access patterns => TBC + while (idx < current_end) { + unsigned int random = get_random_noise(seed, idx); + + tensor128 param128; + tensor128 grad128; + tensor128 opt_m128; + tensor128 opt_v128; + tensor128 master128; + + size_t offset = idx - current_start; + int next_idx[TT::NUM_TYPES_PARAM] = {0}; + int current_idx[TT::NUM_TYPES_PARAM] = {0}; + + #pragma unroll + for (int i = 0; i < 16; i += 4, offset += 4) { + if (current_idx[PARAMETER] == 0) param128 = load_tensor128(param_tensor, offset); + if (current_idx[PARAMETER_GRAD] == 0) grad128 = load_tensor128(grad_tensor, offset, false, true); + if (current_idx[PARAMETER_OPT_M] == 0) opt_m128 = load_tensor128(opt_m_tensor, offset, false,true); + if (current_idx[PARAMETER_OPT_V] == 0) opt_v128 = load_tensor128(opt_v_tensor, offset, false, true); + if (current_idx[PARAMETER_MASTER] == 0) master128 = load_tensor128(master_tensor, offset, false, true); + + for (int k = 0; k < 4; k++) { + float grad = grad128.get(current_idx[PARAMETER_GRAD] + k); + float m = opt_m128.get(current_idx[PARAMETER_OPT_M] + k); + float v = opt_v128.get(current_idx[PARAMETER_OPT_V] + k); + float master = master128.get(current_idx[PARAMETER_MASTER] + k); + + m = lerp(grad, m, beta1); + v = lerp(grad * grad, v, beta2); + out_opt_m128.set(current_idx[PARAMETER_OPT_M] + k, m); + out_opt_v128.set(current_idx[PARAMETER_OPT_V] + k, v); + m /= beta1_correction; + v /= beta2_correction; + + float old_param; + if (use_master_weights && 
!master_init_modes) { + old_param = master; + } else { + old_param = param128.get(current_idx[PARAMETER] + k); + } + float param = old_param - (learning_rate * (m / (sqrtf(v) + eps) + wd * old_param)); + out_param128.set_stochastic(current_idx[PARAMETER] + k, param, random); + float new_param = out_param128.get(current_idx[PARAMETER] + k); + out_master128.set(current_idx[PARAMETER_MASTER] + k, param); + } + next_idx[PARAMETER] = (i + 4) % (16 / sizeof(Tparam)); + next_idx[PARAMETER_GRAD] = (i + 4) % (16 / sizeof(Tgrad)); + next_idx[PARAMETER_OPT_M] = (i + 4) % (16 / sizeof(Tm)); + next_idx[PARAMETER_OPT_V] = (i + 4) % (16 / sizeof(Tv)); + next_idx[PARAMETER_MASTER] = (i + 4) % (16 / sizeof(Tmaster)); + + if (next_idx[PARAMETER] == 0) out_param128.store(offset - current_idx[PARAMETER]); + if (next_idx[PARAMETER_OPT_M] == 0) out_opt_m128.store(offset - current_idx[PARAMETER_OPT_M]); + if (next_idx[PARAMETER_OPT_V] == 0) out_opt_v128.store(offset - current_idx[PARAMETER_OPT_V]); + if constexpr (use_master_weights) { + if (next_idx[PARAMETER_MASTER] == 0) out_master128.store(offset - current_idx[PARAMETER_MASTER]); + } + + for (int n = 0; n < TT::NUM_TYPES_PARAM; n++) { + current_idx[n] = next_idx[n]; + } + } + idx += stride; + } + out_param128.update_absmax(threadIdx.x, block_size, false); + return idx; +} + template __global__ void adamw_full_update(TensorSpec* specs, unsigned int seed, int num_params_tensors, size_t num_parameters, size_t num_opt_parameters, float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, - float eps, float weight_decay, float grad_scale, int t, bool init_from_master_only=false) { + float eps, float weight_decay, float grad_scale, int t) { // ... - __shared__ int shared_spec_id; - constexpr size_t block_size = 64; // 64 ==> 4KiB chunks with iteration_size=16 for FP32 opt/master - size_t iteration_size = 16; - assert(iteration_size <= 16); - size_t idx_blk = blockIdx.x * block_size * iteration_size; - size_t idx = idx_blk + (threadIdx.x * iteration_size); + constexpr size_t iteration_size = 16; + size_t idx = (blockIdx.x * block_size * iteration_size) + (threadIdx.x * iteration_size); size_t stride = gridDim.x * blockDim.x * iteration_size; int spec_id = 0; - TensorSpec* grad_specs = specs + num_params_tensors; TensorSpec* opt_m_specs = specs + 2 * num_params_tensors; TensorSpec* opt_v_specs = specs + 3 * num_params_tensors; @@ -43,110 +120,54 @@ __global__ void adamw_full_update(TensorSpec* specs, unsigned int seed, while (true) { // todo - performance analysis/optimisation! (impact of using step 0?) - if (threadIdx.x == 0) { - while (idx >= current_end) { - spec_id++; - if (spec_id >= num_params_tensors) { - shared_spec_id = -1; - return; - } - opt_spec = opt_v_specs[spec_id]; - current_start = opt_spec.offset / sizeof(float); - current_end = current_start + opt_spec.num_elements; + while (idx >= current_end) { + spec_id++; + if (spec_id >= num_params_tensors) { + return; + } + opt_spec = opt_v_specs[spec_id]; + current_start = opt_spec.offset / sizeof(float); + current_end = current_start + opt_spec.num_elements; + + while (idx < current_start) { + idx += stride; } - shared_spec_id = spec_id; - } - __syncthreads(); - spec_id = shared_spec_id; - if (spec_id == -1) { - return; } opt_spec = opt_v_specs[spec_id]; current_start = opt_spec.offset / sizeof(float); current_end = current_start + opt_spec.num_elements; + float wd = (opt_spec.flags & TENSOR_2D) ? 
weight_decay : 0.0f; TensorGPU grad_tensor = grad_specs[spec_id]; TensorGPU master_tensor = master_specs[spec_id]; TensorGPU opt_m_tensor = opt_m_specs[spec_id]; TensorGPU opt_v_tensor = opt_spec; - auto out_master128 = new_tensor128(master_tensor, true); - auto out_opt_m128 = new_tensor128(opt_m_tensor, true); - auto out_opt_v128 = new_tensor128(opt_v_tensor, true); - - // todo - make it configurable whether weight decay applies to e.g. bias or not - float wd = (opt_spec.flags & TENSOR_2D) ? weight_decay : 0.0f; - - if (specs[spec_id].data_type == DType::BF16) { - // todo - this is actually "EQUAL FLOATX" right now, doesn't work for mix and match - // !!! + if (specs[spec_id].data_type == DType::FP32) { + TensorGPU param_tensor = specs[spec_id]; + idx = adamw_update_part( + param_tensor, idx, spec_id, current_start, current_end, stride, + grad_tensor, master_tensor, opt_m_tensor, opt_v_tensor, + seed, num_params_tensors, num_parameters, num_opt_parameters, + learning_rate, beta1, beta2, beta1_correction, beta2_correction, + eps, wd, grad_scale, t); + } else if (specs[spec_id].data_type == DType::BF16) { TensorGPU<__nv_bfloat16> param_tensor = specs[spec_id]; - auto out_param128 = new_tensor128(param_tensor); - - __syncthreads(); // todo - hopefully results in better memory access patterns => TBC - while (idx < current_end) { - // always sizeof(param) <= sizeof(grad) <= sizeof(opt/master) - // todo - maybe not true, could have FP32 param and BF16 grad? - // todo - hack - currently assuming grad is always bfloat16 - unsigned int random = get_random_noise(seed, idx); - for (int i = 0; i < iteration_size; i += 16 / sizeof(__nv_bfloat16)) { - size_t offset = (idx - current_start) + i; - auto param128 = load_tensor128(param_tensor, offset); - auto grad128 = load_tensor128(grad_tensor, offset); - for (int j = 0; j < sizeof(float) / sizeof(__nv_bfloat16); j++) { - // todo - sparse(-ish) accesses, I don't like it. 
- auto opt_m128 = load_tensor128(opt_m_tensor, offset + j * f128::size, true); - auto opt_v128 = load_tensor128(opt_v_tensor, offset + j * f128::size, true); - // optimised away if we don't use it (and pointer will be equal to opt_m128) - auto master128 = load_tensor128(master_tensor, offset + j * f128::size, true); - - if (master_init_modes && init_from_master_only) { - for (int k = 0; k < f128::size; k++) { - float old_param = master128.get(k); - out_param128.set_stochastic(k + j*f128::size, old_param, random); - } - continue; - } - - for (int k = 0; k < f128::size; k++) { - float grad = grad128.get(k + j*f128::size); - float m = opt_m128.get(k); - float v = opt_v128.get(k); - m = lerp(grad, m, beta1); - v = lerp(grad * grad, v, beta2); - out_opt_m128.set(k, m); - out_opt_v128.set(k, v); - m /= beta1_correction; - v /= beta2_correction; - - float old_param; - if (use_master_weights && !master_init_modes) { - old_param = master128.get(k); - } else { - old_param = param128.get(k + j*f128::size); - } - float param = old_param - (learning_rate * (m / (sqrtf(v) + eps) + wd * old_param)); - out_param128.set_stochastic(k + j*f128::size, param, random); - out_master128.set(k, param); - } - out_opt_m128.store(offset + j * f128::size); - out_opt_v128.store(offset + j * f128::size); - if constexpr (use_master_weights) { - out_master128.store(offset + j * f128::size); - } - } - out_param128.store(offset); - } - idx_blk += stride; - idx += stride; - } - out_param128.update_absmax(threadIdx.x, block_size, false); + idx = adamw_update_part( + param_tensor, idx, spec_id, current_start, current_end, stride, + grad_tensor, master_tensor, opt_m_tensor, opt_v_tensor, + seed, num_params_tensors, num_parameters, num_opt_parameters, + learning_rate, beta1, beta2, beta1_correction, beta2_correction, + eps, wd, grad_scale, t); } else if (specs[spec_id].data_type == DType::FP8E4M3) { - TensorGPU param_tensor = specs[spec_id]; - auto out_param128 = new_tensor128(param_tensor); - return; - // todo + TensorGPU<__nv_fp8_e4m3> param_tensor = specs[spec_id]; + idx = adamw_update_part( + param_tensor, idx, spec_id, current_start, current_end, stride, + grad_tensor, master_tensor, opt_m_tensor, opt_v_tensor, + seed, num_params_tensors, num_parameters, num_opt_parameters, + learning_rate, beta1, beta2, beta1_correction, beta2_correction, + eps, wd, grad_scale, t); } else { assert(false); // TODO } diff --git a/llmc/copy_and_fp8.h b/llmc/copy_and_fp8.h index 6092db07f..f1ebffe6c 100644 --- a/llmc/copy_and_fp8.h +++ b/llmc/copy_and_fp8.h @@ -11,7 +11,6 @@ See /dev/cuda/advanced_copy_transpose.cu for more information and options #include "cuda_utils.cuh" // todo - tune these for performance (but should be close to optimal already) -#define ABSMAX_ITERATIONS_PER_THREAD 4 #define TRANSPOSE_TILE_SIZE 64UL // ---------------------------------------------------------------------------- @@ -39,23 +38,23 @@ __device__ float gelu_forward_elementwise(float x) { // Same as copy_simple_kernel but with optional absmax and elementwise function options // absmax is calculated before scaling but after the elementwise function -template -__global__ void copy_advanced_kernel(TensorGPU in, TensorGPU out) { - constexpr size_t vec_size = 16 / ((sizeof(T1) < sizeof(T2)) ? sizeof(T2) : sizeof(T1)); + typename Tin=float, typename Tout=float> +__global__ void copy_advanced_kernel(TensorGPU out, TensorGPU in) { + constexpr size_t vec_size = 16 / ((sizeof(Tin) < sizeof(Tout)) ? sizeof(Tout) : sizeof(Tin)); size_t adjusted_blockidx = reversed_order ? 
(gridDim.x - blockIdx.x - 1) : blockIdx.x; size_t idx = (adjusted_blockidx * blockDim.x + threadIdx.x) * vec_size; - if (idx >= in.num_elements) { return; } + if (idx >= out.num_elements) { return; } auto inp128 = load_tensor128(in, idx, true, disable_scaling); - auto out128 = new_tensor128(out); + auto out128 = new_tensor128(out, disable_scaling); for (int k = 0; k < vec_size; k++) { float out_fp32 = elementwise_func(inp128.get(k)); out128.set(k, out_fp32); } - out128.store_same_length(idx); - out128.update_absmax(threadIdx.x, block_size, true); + out128.template store_same_length(idx); + out128.update_absmax(threadIdx.x, blockDim.x, true); } // transpose + copy + format conversion (+ elementwise + absmax) kernel @@ -104,17 +103,13 @@ __global__ void transpose_simple_kernel(T1* __restrict__ transposed, const T1* _ // only calculate absmax of the input tensor (non-fused) template __global__ void update_absmax_kernel(TensorGPU inp) { - size_t idx = ((blockIdx.x * blockDim.x * ABSMAX_ITERATIONS_PER_THREAD) + threadIdx.x) * inp.num_per_128(); - auto max128 = new_tensor128(inp, disable_scaling); + size_t idx = ((blockIdx.x * blockDim.x) + threadIdx.x) * inp.num_per_128(); + auto max128 = new_tensor128(inp); if (idx < inp.num_elements) { - #pragma unroll - for (int i = 0; i < ABSMAX_ITERATIONS_PER_THREAD; i++) { - auto inp128 = load_tensor128(inp, idx, disable_scaling); - for(int k = 0; k < inp.num_per_128(); ++k) { - float value = inp128.get(k); - max128.add_value_stats(value); - } - idx += blockDim.x * inp.num_per_128(); + auto inp128 = load_tensor128(inp, idx, disable_scaling); + for(int k = 0; k < inp.num_per_128(); ++k) { + float value = inp128.get(k); + max128.add_value_stats(value); } } max128.update_absmax(threadIdx.x, blockDim.x, true, true); @@ -122,34 +117,14 @@ __global__ void update_absmax_kernel(TensorGPU inp) { // ---------------------------------------------------------------------------- -template -void copy_advanced(TensorGPU *copy, TensorGPU *input, size_t N, float* descale_pointer=NULL, float* scale_pointer=NULL, void* absmax_output=NULL, /*bool memset_absmax=true,*/ cudaStream_t stream=0, const size_t block_size=512) { +template +void copy_advanced(TensorGPU out, TensorGPU in, cudaStream_t stream=0, const size_t block_size=512) { + size_t N = out.num_elements; size_t fewest_elements = min(Packed128::size, Packed128::size); - const dim3 grid_size(CEIL_DIV(N, block_size * fewest_elements)); assert((N % fewest_elements) == 0); - constexpr uint absmax_factor = 1; - unsigned int* absmax_uint = (unsigned int*)absmax_output; - - // todo - fix this function - assert(false); - - if (absmax_output) { - /*if (memset_absmax) { - cudaMemset(absmax_output, 0, sizeof(unsigned int)); - }*/ - if (scale_pointer || descale_pointer) { - copy_advanced_kernel<<>>(copy, input, N, descale_pointer, scale_pointer, absmax_uint); - } else { - copy_advanced_kernel<<>>(copy, input, N, NULL, NULL, absmax_uint); - } - } else { - if (scale_pointer || descale_pointer) { - copy_advanced_kernel<<>>(copy, input, N, descale_pointer, scale_pointer); - } else { - copy_advanced_kernel<<>>(copy, input, N); - } - } + const dim3 grid_size(CEIL_DIV(N, block_size * fewest_elements)); + copy_advanced_kernel<<>>(out, in); cudaCheck(cudaGetLastError()); } @@ -173,18 +148,13 @@ void transpose_simple(TensorGPU transposed, TensorGPU input, size_t w, s } template -void update_absmax(TensorGPU inp, bool memset_absmax=false, cudaStream_t stream=main_stream, size_t max_block_size=512) { +void update_absmax(TensorGPU inp, bool 
memset_absmax=true, cudaStream_t stream=main_stream) { size_t N = inp.num_elements; if (N == 0 || inp.absmax_ptr == NULL) { return; } + assert(N % inp.num_per_128() == 0); - // find the largest block size that divides N - size_t block_size = max_block_size; - while ((N % (block_size * Packed128::size * ABSMAX_ITERATIONS_PER_THREAD)) != 0) { - block_size /= 2; - assert(block_size >= 32); // block size of 1 would be OK, but so inefficient we'd rather fail and debug I think - } - - const dim3 grid_size(CEIL_DIV(N, block_size * ABSMAX_ITERATIONS_PER_THREAD * Packed128::size)); + size_t block_size = 512; + const dim3 grid_size(CEIL_DIV(N, block_size * Packed128::size)); if (memset_absmax) { cudaMemset(inp.absmax_ptr, 0, sizeof(unsigned int)); } diff --git a/llmc/cuda_utils.cuh b/llmc/cuda_utils.cuh index 9102b2541..ca16d2174 100644 --- a/llmc/cuda_utils.cuh +++ b/llmc/cuda_utils.cuh @@ -191,7 +191,7 @@ __device__ void stochastic_rounding(float in, Ti &out, unsigned int random, floa // e.g. +0.3f ==> 65% chance up, 35% chance down float threshold_percentage = ((float)random / (float)0xFFFFFFFF) - prob_offset; - Ti rounded_down, rounded_up; + Ti rounded_down = (Ti)0.0f, rounded_up = (Ti)0.0f; if constexpr (std::is_same::value) { rounded_down = __float2half_rd(in); rounded_up = __float2half_ru(in); @@ -218,6 +218,8 @@ __device__ void stochastic_rounding(float in, Ti &out, unsigned int random, floa } rounded_up = (__nv_fp8_e4m3)high; rounded_down = (__nv_fp8_e4m3)low; + } else { + assert(false); } float diff = (float)rounded_up - (float)rounded_down; diff --git a/llmc/gelu.cuh b/llmc/gelu.cuh index d334c2b03..acc931969 100644 --- a/llmc/gelu.cuh +++ b/llmc/gelu.cuh @@ -30,7 +30,7 @@ __global__ void gelu_forward_kernel2(tensorFP8e4 out, tensorFP8e4 inp) { float half_xi = 0.5f * xi; out128.set(k, half_xi * tanh_in_out + half_xi); } - out128.store_same_length(idx, false); + out128.store(idx, false); out128.update_absmax(threadIdx.x, blockDim.x, true); } @@ -58,7 +58,7 @@ __global__ void gelu_backward_kernel(tensorFP8e5 dinp, tensorFP8e5 dout, TensorG float result = local_grad * dout128.get(k); dinp128.set(k, result); } - dinp128.store_same_length(idx, false); + dinp128.store(idx, false); dinp128.update_absmax(threadIdx.x, blockDim.x, true); } diff --git a/llmc/layernorm.cuh b/llmc/layernorm.cuh index 17d1b27b7..37f5e00aa 100644 --- a/llmc/layernorm.cuh +++ b/llmc/layernorm.cuh @@ -75,9 +75,9 @@ __global__ void layernorm_forward_kernel6(TensorGPU out, tensorFP32 mean, ten out128.update_absmax(threadIdx.x + threadIdx.y * blockDim.x, blockDim.x * blockDim.y, true); } -template -__global__ void fused_residual_forward_kernel5(tensorX residual, TensorGPU normed, tensorFP32 mean, tensorFP32 rstd, - const tensorX inp1, const TensorGPU inp2, +template +__global__ void fused_residual_forward_kernel5(tensorX residual, TensorGPU normed, tensorFP32 mean, tensorFP32 rstd, + const tensorX inp1, const TensorGPU inp2, const tensorX weight, const tensorX bias, int N, int C) { // Note that blockDim.x must be WARP_SIZE=32 but we don't want to pay the cost of assert() here @@ -381,17 +381,17 @@ void layernorm_forward(TensorGPU out, tensorFP32 mean, tensorFP32 rstd, launch_layernorm_kernel(layernorm_forward_kernel6, N, C, stream, out, mean, rstd, inp, weight, bias); } -template -void fused_residual_forward5(tensorX residual, TensorGPU normed, tensorFP32 mean, tensorFP32 rstd, - tensorX inp1, TensorGPU inp2, tensorX weight, tensorX bias, +template +void fused_residual_forward5(tensorX residual, TensorGPU normed, tensorFP32 
mean, tensorFP32 rstd, + tensorX inp1, TensorGPU inp2, tensorX weight, tensorX bias, int N, int C, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); - launch_layernorm_kernel(fused_residual_forward_kernel5, N, C, stream, residual, normed, mean, rstd, inp1, inp2, weight, bias); + launch_layernorm_kernel(fused_residual_forward_kernel5, N, C, stream, residual, normed, mean, rstd, inp1, inp2, weight, bias); } -template +template void layernorm_backward(tensorX dinp_new, tensorX dinp_old, tensorX dweight, tensorX dbias, tensorFP32 scratch, - const TensorGPU dout, const tensorX inp, const tensorX weight, tensorFP32 mean, tensorFP32 rstd, + const TensorGPU dout, const tensorX inp, const tensorX weight, tensorFP32 mean, tensorFP32 rstd, int BT, int C, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); const int block_size = 512; diff --git a/llmc/tensor.cuh b/llmc/tensor.cuh index 70cb7c9a4..ad2605af1 100644 --- a/llmc/tensor.cuh +++ b/llmc/tensor.cuh @@ -77,6 +77,8 @@ struct TensorGPU { size_t num_elements = 0; int id = -1; + static constexpr bool no_scaling = (sizeof(ElementType) == 1); + bool is_null() const { return (data_ptr == NULL); } @@ -111,7 +113,7 @@ struct TensorGPU { return sizeof(int4) / sizeof(ElementType); } - __device__ __host__ float get_scalar(size_t index, bool disable_scaling=false) const { + __device__ __host__ float get_scalar(size_t index, bool disable_scaling=no_scaling) const { #ifdef FAKE_FP8 disable_scaling = true; #endif @@ -124,7 +126,7 @@ struct TensorGPU { return value * descale; // [1] = descale } - __device__ __host__ ElementType set_scalar(size_t index, float value, bool disable_scaling=false) { + __device__ __host__ ElementType set_scalar(size_t index, float value, bool disable_scaling=no_scaling) { #ifdef FAKE_FP8 disable_scaling = true; #endif @@ -198,6 +200,86 @@ struct TensorSpec { // ---------------------------------------------------------------------------- +// debug helper function +void print_tensor_elements(int tensor_id) { + return; + + printf("Printing tensor %d\n", tensor_id); + TensorSpec spec = tensor_specs[tensor_id]; + size_t num_elements = spec.num_elements; + const char* tensor_name = spec.name; + TT tensor_type = spec.tensor_type; + DType dtype = spec.data_type; + size_t element_size = sizeof_dtype(dtype); + + void* gpu_tensor = spec.ptr; + void* cpu_tensor = malloc(num_elements * element_size); + + printf("Printing tensor %s (tensor_type: %d, data_type: %d)\n", tensor_name, (int)tensor_type, (int)dtype); + printf("GPU memory: %p\n", gpu_tensor); + printf("CPU memory: %p\n", cpu_tensor); + printf("Num elements: %zu\n", num_elements); + printf("Element size: %zu\n", element_size); + printf("Offset: %zu\n", spec.offset); + + cudaCheck(cudaMemcpy(cpu_tensor, gpu_tensor, num_elements * element_size, cudaMemcpyDeviceToHost)); + + printf("Did memcpy\n"); + + printf("First 4 of %s: ", tensor_name); + for (int i = 0; i < num_elements && i < 4; i++) { + if (dtype == DType::FP32) { + printf("%.16f ", ((float*)cpu_tensor)[i]); + } else if (dtype == DType::FP16) { + printf("%.16f ", (float)((__nv_half*)cpu_tensor)[i]); + } else if (dtype == DType::BF16) { + printf("%.16f ", (float)((__nv_bfloat16*)cpu_tensor)[i]); + } else if (dtype == DType::FP8E4M3) { + printf("%.16f ", (float)((__nv_fp8_e4m3*)cpu_tensor)[i]); + } else if (dtype == DType::FP8E5M2) { + printf("%.16f ", (float)((__nv_fp8_e5m2*)cpu_tensor)[i]); + } + } + printf("\n"); + + printf("Middle 4 of %s: ", tensor_name); + for (int i = (num_elements/2) + 4; i < num_elements && i < 
(num_elements/2 + 8); i++) { + if (dtype == DType::FP32) { + printf("%.16f ", ((float*)cpu_tensor)[i]); + } else if (dtype == DType::FP16) { + printf("%.16f ", (float)((__nv_half*)cpu_tensor)[i]); + } else if (dtype == DType::BF16) { + printf("%.16f ", (float)((__nv_bfloat16*)cpu_tensor)[i]); + } else if (dtype == DType::FP8E4M3) { + printf("%.16f ", (float)((__nv_fp8_e4m3*)cpu_tensor)[i]); + } else if (dtype == DType::FP8E5M2) { + printf("%.16f ", (float)((__nv_fp8_e5m2*)cpu_tensor)[i]); + } + } + printf("\n"); + + printf("Last 4 of %s: ", tensor_name); + for (int i = num_elements - 4; i < num_elements; i++) { + if (dtype == DType::FP32) { + printf("%.16f ", ((float*)cpu_tensor)[i]); + } else if (dtype == DType::FP16) { + printf("%.16f ", (float)((__nv_half*)cpu_tensor)[i]); + } else if (dtype == DType::BF16) { + printf("%.16f ", (float)((__nv_bfloat16*)cpu_tensor)[i]); + } else if (dtype == DType::FP8E4M3) { + printf("%.16f ", (float)((__nv_fp8_e4m3*)cpu_tensor)[i]); + } else if (dtype == DType::FP8E5M2) { + printf("%.16f ", (float)((__nv_fp8_e5m2*)cpu_tensor)[i]); + } + } + printf("\n"); + printf("\n"); + + free(cpu_tensor); +} + +// ---------------------------------------------------------------------------- + TensorSpec get_tensor(int spec_index, TT tensor_type, int layer) { TensorSpec spec = tensor_specs[spec_index]; if (layer > 0 && spec.remaining_layers >= layer) { @@ -207,7 +289,7 @@ TensorSpec get_tensor(int spec_index, TT tensor_type, int layer) { assert(false); } assert(spec.tensor_type == tensor_type || tensor_type == DEFAULT); - //print_tensor_elements(spec_index); + print_tensor_elements(spec_index); return spec; } @@ -274,81 +356,14 @@ int add_layer_specs(int num_layers, const char* name, size_t total_elements, siz return first_tensor_id; } -// debug helper function -void print_tensor_elements(int tensor_id) { - return; - - printf("Printing tensor %d\n", tensor_id); - TensorSpec spec = tensor_specs[tensor_id]; - size_t num_elements = spec.num_elements; - const char* tensor_name = spec.name; - TT tensor_type = spec.tensor_type; - DType dtype = spec.data_type; - size_t element_size = sizeof_dtype(dtype); - - void* gpu_memory = spec.ptr; - void* gpu_tensor = (void*)((char*)gpu_memory + tensor_specs[tensor_id].offset); - void* cpu_tensor = malloc(num_elements * element_size); - - printf("Printing tensor %s (tensor_type: %d, data_type: %d)\n", tensor_name, (int)tensor_type, (int)dtype); - printf("GPU memory: %p\n", gpu_tensor); - printf("CPU memory: %p\n", cpu_tensor); - printf("Num elements: %zu\n", num_elements); - printf("Element size: %zu\n", element_size); - printf("Offset: %zu\n", tensor_specs[tensor_id].offset); - - cudaCheck(cudaMemcpy(cpu_tensor, gpu_tensor, num_elements * element_size, cudaMemcpyDeviceToHost)); - - printf("Did memcpy\n"); - - printf("First 4 of %s: ", tensor_name); - for (int i = 0; i < num_elements && i < 4; i++) { - if (dtype == DType::FP32) { - printf("%.16f ", ((float*)cpu_tensor)[i]); - } else if (dtype == DType::FP16) { - printf("%.16f ", (float)((__nv_half*)cpu_tensor)[i]); - } else if (dtype == DType::BF16) { - printf("%.16f ", (float)((__nv_bfloat16*)cpu_tensor)[i]); - } - } - printf("\n"); - - printf("Middle 4 of %s: ", tensor_name); - for (int i = (num_elements/2) + 4; i < num_elements && i < (num_elements/2 + 8); i++) { - if (dtype == DType::FP32) { - printf("%.16f ", ((float*)cpu_tensor)[i]); - } else if (dtype == DType::FP16) { - printf("%.16f ", (float)((__nv_half*)cpu_tensor)[i]); - } else if (dtype == DType::BF16) { - printf("%.16f ", 
(float)((__nv_bfloat16*)cpu_tensor)[i]); - } - } - printf("\n"); - - printf("Last 4 of %s: ", tensor_name); - for (int i = num_elements - 4; i < num_elements; i++) { - if (dtype == DType::FP32) { - printf("%.16f ", ((float*)cpu_tensor)[i]); - } else if (dtype == DType::FP16) { - printf("%.16f ", (float)((__nv_half*)cpu_tensor)[i]); - } else if (dtype == DType::BF16) { - printf("%.16f ", (float)((__nv_bfloat16*)cpu_tensor)[i]); - } - } - printf("\n"); - printf("\n"); - - free(cpu_tensor); -} - // ---------------------------------------------------------------------------- -__global__ void update_scale_descale_kernel(float* gpu_scale_memory, unsigned int* gpu_absmax_memory, int num_tensor_specs) { +__global__ void update_scale_descale_kernel(int num_tensor_specs) { int tid = blockIdx.x * blockDim.x + threadIdx.x; if (tid >= num_tensor_specs) return; // Get the absmax value for this tensor - unsigned int absmax_uint = gpu_absmax_memory[tid]; + unsigned int absmax_uint = gpu_absmax_memory_ptr[tid]; float absmax = __uint_as_float(absmax_uint); // Calculate scale and descale @@ -365,7 +380,7 @@ __global__ void update_scale_descale_kernel(float* gpu_scale_memory, unsigned in descale *= 1.0f/32768.0f; } else { // e4 - //if (tensor_specs_ptr[tid].tensor_type != TT::PARAMETER) { + //if (tensor_specs_ptr[tid].tensor_type != TT::PARAMETER || absmax >= 4.0f) { scale *= 256.0f; descale *= (1.0f/256.0f); //} @@ -375,12 +390,16 @@ __global__ void update_scale_descale_kernel(float* gpu_scale_memory, unsigned in descale = 1.0f; } + if (scale != 1.0f) { + //printf("%s: absmax: %f, scale: %f, descale: %f\n", tensor_specs_ptr[tid].name, absmax, scale, descale); + } + // todo: circular buffer //gpu_absmax_memory[tid] = 0.0f; // Update gpu_scale_memory - gpu_scale_memory[tid * 2] = scale; - gpu_scale_memory[tid * 2 + 1] = descale; + gpu_scale_memory_ptr[tid * 2] = scale; + gpu_scale_memory_ptr[tid * 2 + 1] = descale; } // ---------------------------------------------------------------------------- @@ -422,12 +441,11 @@ public: scaling = false; // only do "fake" scaling #endif - if (!disable_scaling) { + scaling = scaling && !disable_scaling; + if (scaling) { const float* __restrict__ ptr_restricted = tensor.scale_descale_ptr; scale = ptr_restricted[0]; descale = ptr_restricted[1]; - } else { - scaling = false; } absmax_ptr = tensor.absmax_ptr; } @@ -528,8 +546,8 @@ public: // otherwise, the compiler does silly things related to the redux/atomicMax unsigned int lane_id ; asm volatile("mov.u32 %0, %laneid;" : "=r"(lane_id)); - unsigned int num_warps = num_threads >> 5; - unsigned int warp_id = thread_id >> 5; + unsigned int num_warps = num_threads / WARP_SIZE; + unsigned int warp_id = thread_id / WARP_SIZE; // use native integer reductions as much as possible (supported on all GPUs with FP8) // this might treat NaN/INF slightly differently but that is the least of our problems diff --git a/train_gpt2.cu b/train_gpt2.cu index 031088d14..d6a5e259d 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -204,7 +204,7 @@ void gpt2_allocate(GPT2 *model) { TENSOR_SPECS (ln1b, L, C, LAYERNORM | BIAS); TENSOR_SPECS_LOWP(qkvw, L, 3 * C * C, TENSOR_2D); TENSOR_SPECS (qkvb, L, 3 * C, BIAS); - TENSOR_SPECS_LOWP(attprojw, L, C * C, TENSOR_2D); + TENSOR_SPECS (attprojw, L, C * C, TENSOR_2D); TENSOR_SPECS (attprojb, L, C, BIAS); TENSOR_SPECS (ln2w, L, C, LAYERNORM); TENSOR_SPECS (ln2b, L, C, LAYERNORM | BIAS); @@ -274,11 +274,11 @@ void gpt2_allocate(GPT2 *model) { TENSOR_SPECS_FP32(ln2_rstd, L, BT, LAYERNORM | STATS); if 
(UNIQUE_TENSOR_MEMORY) { - TENSOR_SPECS_LOWP(attproj, L, BTC, TENSOR_2D); + TENSOR_SPECS (attproj, L, BTC, TENSOR_2D); TENSOR_SPECS_LOWP(fcproj, L, BTC, TENSOR_2D); TENSOR_SPECS (output, 1, output_size, TENSOR_2D | EMBEDDING); } else { - spec->attproj = add_layer_specs(L, "attproj", BTC, shards, dtype_lowp, model->multiuse.btc, REUSED_MEMORY | TENSOR_2D); + spec->attproj = add_layer_specs(L, "attproj", BTC, shards, dtype, model->multiuse.btc, REUSED_MEMORY | TENSOR_2D); spec->fcproj = add_layer_specs(L, "fcproj", BTC, shards, dtype_lowp, model->multiuse.btc, REUSED_MEMORY | TENSOR_2D); spec->output = add_tensor_spec("output", output_size, shards, dtype, model->multiuse.output_scratch, REUSED_MEMORY | EMBEDDING | TENSOR_2D); } @@ -493,10 +493,48 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path, bool w gpt2_allocate(model); - // read in the parameters if weight_init is true if (weight_init) { - file_to_device(model->tensor_memory[PARAMETER], model_file, model->num_parameters_bytes, IO_BUF_SIZE); + fseek(model_file, 0, SEEK_END); + size_t checkpoint_bytes = ftell(model_file) - sizeof(model_header); + fseek(model_file, sizeof(model_header), SEEK_SET); + + if (checkpoint_bytes != model->num_parameters_bytes) { + assert(checkpoint_bytes <= tensors_bytes[MULTIUSE]); // hack - won't work if params size > activations size + file_to_device(model->tensor_memory[MULTIUSE], model_file, checkpoint_bytes, IO_BUF_SIZE); + + size_t offset = 0; + int num_param_tensors = tensors_start[PARAMETER_GRAD]; + for (int i = 0; i < num_param_tensors; i++) { + TensorGPU tensor = tensor_specs[i]; + tensor.data_ptr = (floatX*)(model->tensor_memory[MULTIUSE] + offset); + offset += tensor.num_elements * sizeof(floatX); + update_absmax(tensor); + } + + int absmax_block_size = 256; + int num_blocks = (num_param_tensors + absmax_block_size - 1) / absmax_block_size; + update_scale_descale_kernel<<>>(num_param_tensors); + + offset = 0; + for (int i = 0; i < num_param_tensors; i++) { + TensorGPU tensor_in = tensor_specs[i]; + tensor_in.data_ptr = (floatX*)(model->tensor_memory[MULTIUSE] + offset); + offset += tensor_in.num_elements * sizeof(floatX); + + switch (tensor_specs[i].data_type) { + case DType::FP32: copy_advanced((TensorGPU)tensor_specs[i], tensor_in); break; + case DType::BF16: copy_advanced((TensorGPU<__nv_bfloat16>)tensor_specs[i], tensor_in); break; + case DType::FP16: copy_advanced((TensorGPU)tensor_specs[i], tensor_in); break; + case DType::FP8E4M3: copy_advanced((TensorGPU<__nv_fp8_e4m3>)tensor_specs[i], tensor_in); break; + case DType::FP8E5M2: copy_advanced((TensorGPU<__nv_fp8_e5m2>)tensor_specs[i], tensor_in); break; + } + } + cudaMemset(model->tensor_memory[MULTIUSE], 0, checkpoint_bytes); + } else { + file_to_device(model->tensor_memory[PARAMETER], model_file, model->num_parameters_bytes, IO_BUF_SIZE); + } } + fcloseCheck(model_file); // only return from this function once we are certain the params are ready on the GPU @@ -669,8 +707,8 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { attention_forward(ACT(atty), ACT(qkvr), ACT(att), qkvr, B, T, C, NH); #endif - matmul_forward_cublaslt(ACT(attproj), ACT(atty), PARAM(attprojw), PARAM(attprojb), B*T, C, C); - fused_residual_forward5(ACT(residual2), ACT(ln2), ACT(ln2_mean), ACT(ln2_rstd), residual, ACT(attproj), PARAM(ln2w), PARAM(ln2b), B*T, C); + matmul_forward_cublaslt(ACT(attproj), ACT(atty), PARAM(attprojw), PARAM(attprojb), B*T, C, C); + fused_residual_forward5(ACT(residual2), ACT(ln2), ACT(ln2_mean), 
ACT(ln2_rstd), residual, ACT(attproj), PARAM(ln2w), PARAM(ln2b), B*T, C); matmul_forward_cublaslt(ACT(fch_gelu), ACT(ln2), PARAM(fcw), PARAM(fcb), B*T, C, 4*C, ACT(fch), model->gelu_fusion); matmul_forward_cublaslt(ACT(fcproj), ACT(fch_gelu), PARAM(fcprojw), PARAM(fcprojb), B*T, 4*C, C); @@ -718,8 +756,14 @@ float gpt2_validate(GPT2 *model, const int* inputs, const int* targets, size_t B void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int grad_accum_steps, int micro_step) { NVTX_RANGE_FN(); - exit(0); -#if 0 + + + + cudaCheck(cudaMemsetAsync(model->tensor_memory[PARAMETER_GRAD], 0, tensors_bytes[PARAMETER_GRAD], main_stream)); + + + + #if 0 // convenience shortcuts (size_t instead of int so that pointer arithmetics don't overflow) const size_t B = model->batch_size; @@ -776,10 +820,10 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int matmul_backward(AGRAD(fch), PGRAD(fcprojw), PGRAD(fcprojb), AGRAD(residual3), ACT(fch_gelu), PARAM(fcprojw), scratchF, B*T, 4*C, C, ACT(fch), model->gelu_fusion); if(model->recompute >= 2) { // recompute >= 2 means we recompute layernorm - layernorm_forward(ACT(ln2), ACT(ln2_mean), ACT(ln2_rstd), ACT(residual2), PARAM(ln2w), PARAM(ln2b), B*T, C); + layernorm_forward(ACT(ln2), ACT(ln2_mean), ACT(ln2_rstd), ACT(residual2), PARAM(ln2w), PARAM(ln2b), B*T, C); } matmul_backward(AGRAD(ln2), PGRAD(fcw), PGRAD(fcb), AGRAD(fch), ACT(ln2), PARAM(fcw), scratchF, B*T, C, 4 * C); - layernorm_backward(AGRAD(residual2), AGRAD(residual3), PGRAD(ln2w), PGRAD(ln2b), scratchF, AGRAD(ln2), ACT(residual2), PARAM(ln2w), ACT(ln2_mean), ACT(ln2_rstd), B*T, C); + layernorm_backward(AGRAD(residual2), AGRAD(residual3), PGRAD(ln2w), PGRAD(ln2b), scratchF, AGRAD(ln2), ACT(residual2), PARAM(ln2w), ACT(ln2_mean), ACT(ln2_rstd), B*T, C); matmul_backward(AGRAD(atty), PGRAD(attprojw), PGRAD(attprojb), AGRAD(residual2), ACT(atty), PARAM(attprojw), scratchF, B*T, C, C); #ifdef ENABLE_CUDNN @@ -792,10 +836,10 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int #endif if(model->recompute >= 2) { - layernorm_forward(ACT(ln1), ACT(ln1_mean), ACT(ln1_rstd), residual, PARAM(ln1w), PARAM(ln1b), B*T, C); + layernorm_forward(ACT(ln1), ACT(ln1_mean), ACT(ln1_rstd), residual, PARAM(ln1w), PARAM(ln1b), B*T, C); } matmul_backward(AGRAD(ln1), PGRAD(qkvw), PGRAD(qkvb), AGRAD(qkvr), ACT(ln1), PARAM(qkvw), scratchF, B*T, C, 3 * C); - layernorm_backward(dresidual, AGRAD(residual2), PGRAD(ln1w), PGRAD(ln1b), scratchF, AGRAD(ln1), residual, PARAM(ln1w), ACT(ln1_mean), ACT(ln1_rstd), B*T, C); + layernorm_backward(dresidual, AGRAD(residual2), PGRAD(ln1w), PGRAD(ln1b), scratchF, AGRAD(ln1), residual, PARAM(ln1w), ACT(ln1_mean), ACT(ln1_rstd), B*T, C); // Accumulate gradients from this layer in a background stream. if(last_step) { @@ -819,6 +863,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int } // Is it time to redo the forward pass from our activation checkpoints? + /* if (LAYERS_PER_ACTIVATION_CHECKPOINT && (l % max(1, LAYERS_PER_ACTIVATION_CHECKPOINT)) == 0 && l > 0) { int old_l = l; // forward pass time! 
@@ -845,7 +890,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int } l = old_l; } - + */ } encoder_backward(PGRAD(wte), PGRAD(wpe), scratchX_HUGE, model->workload_indices, model->bucket_info, @@ -920,7 +965,7 @@ float gpt2_calculate_grad_norm(GPT2 *model, MultiGpuConfig* multi_gpu_config) { } void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, float grad_scale, int t, - MultiGpuConfig* multi_gpu_config, bool init_from_master_only=false) { + MultiGpuConfig* multi_gpu_config) { // update the model parameters using the AdamW optimizer // keep in mind that optimizer sharding (ZeRO-1) assigns different parameters to different GPUs // so we may not be responsible for the entire parameter tensor @@ -941,6 +986,27 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo cudaCheck(cudaMemset(model->tensor_memory[PARAMETER_OPT_V], 0, multi_gpu_config->shard_num_parameters * sizeof(float))); } + int absmax_block_size = 256; + int absmax_num_blocks = (num_tensor_specs + absmax_block_size - 1) / absmax_block_size; + +/* + printf("--------------\n"); + update_scale_descale_kernel<<>>(num_tensor_specs); + + // copy all absmax to CPU and printf if != 1.0f + float* absmax_cpu = (float*)malloc(num_tensor_specs * sizeof(float)); + float* scale_cpu = (float*)malloc(num_tensor_specs * 2 * sizeof(float)); + cudaCheck(cudaMemcpy(absmax_cpu, gpu_absmax_memory, num_tensor_specs * sizeof(float), cudaMemcpyDeviceToHost)); + cudaCheck(cudaMemcpy(scale_cpu, gpu_scale_memory, num_tensor_specs * 2 * sizeof(float), cudaMemcpyDeviceToHost)); + for (int i = 0; i < num_tensor_specs; i++) { + if (scale_cpu[i*2] != 1.0f || absmax_cpu[i] != 0.0f) { + printf("scale[%d/%s] ==> %.10f ==> %.10f / %.10f\n", i, tensor_specs[i].name, absmax_cpu[i], scale_cpu[i*2], scale_cpu[i*2+1]); + } + } + + printf("==============\n"); +*/ + // save RNG state at this point so we can round from master weights identically when restoring from a checkpoint model->rng_state_last_update = model->rng_state; @@ -952,14 +1018,12 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo const int block_size = 64; const int grid_size = deviceProp.maxThreadsPerMultiProcessor * deviceProp.multiProcessorCount / block_size; if (model->use_master_weights) { - if (init_state || init_from_master_only) { + if (init_state) { // reads regular weights & writes to master+regular weights - // or init_from_master_only: reads master & write to regular weights as-is adamw_full_update<<>>( tensor_specs_gpu, seed, tensors_start[PARAMETER_GRAD], model->num_parameters, model->num_parameters / num_shards, - learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay, grad_scale, t, - init_from_master_only); + learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay, grad_scale, t); } else { // reads master weights & writes to master+regular weights adamw_full_update<<>>( @@ -1041,9 +1105,23 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo */ // todo - hack - update scale/descale from absmax - int absmax_block_size = 256; - int num_blocks = (num_tensor_specs + absmax_block_size - 1) / block_size; - update_scale_descale_kernel<<>>(gpu_scale_memory, gpu_absmax_memory, num_tensor_specs); + update_scale_descale_kernel<<>>(num_tensor_specs); + +/* + float* absmax_cpu2 = (float*)malloc(num_tensor_specs * sizeof(float)); + float* scale_cpu2 = 
(float*)malloc(num_tensor_specs * 2 * sizeof(float)); + cudaCheck(cudaMemcpy(absmax_cpu2, gpu_absmax_memory, num_tensor_specs * sizeof(float), cudaMemcpyDeviceToHost)); + cudaCheck(cudaMemcpy(scale_cpu2, gpu_scale_memory, num_tensor_specs * 2 * sizeof(float), cudaMemcpyDeviceToHost)); + for (int i = 0; i < num_tensor_specs; i++) { + if (scale_cpu[i*2] != scale_cpu2[i*2] || absmax_cpu[i] != absmax_cpu2[i]) { + printf("scale[%d/%s] ==> absmax: %f -> %f, scale: %f -> %f, descale: %f -> %f\n", i, tensor_specs[i].name, absmax_cpu[i], absmax_cpu2[i], scale_cpu[i*2], scale_cpu2[i*2], scale_cpu[i*2+1], scale_cpu2[i*2+1]); + } + } + free(scale_cpu); + free(absmax_cpu); + free(scale_cpu2); + free(absmax_cpu2); +*/ cudaCheck(cudaDeviceSynchronize()); } @@ -1203,10 +1281,6 @@ void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename if(model->use_master_weights) { assert(model->tensor_memory[PARAMETER_MASTER] != nullptr); file_to_device(model->tensor_memory[PARAMETER_MASTER], state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); - // restore weights from the master weights using the RNG state before last weight update - model->rng_state = model->rng_state_last_update; - gpt2_update(model, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0, &multi_gpu_config, /* init_from_master_only*/ true); - model->rng_state = *((unsigned long long*)&state_header[20]); // use final RNG state from checkpoint after this } // revive the DataLoader object and its state @@ -1626,9 +1700,6 @@ int main(int argc, char *argv[]) { // in any case, this must be true or we'd index beyond the model's wpe (position embedding table) assert(T <= model.config.max_seq_len); - // todo - hack - do this to update the absmax of all the weights - gpt2_update(&model, 0.0f, 0.9f, 0.95f, 1e-8f, 1.0f, 1.0f, 1, &multi_gpu_config); - // train cudaEvent_t start, end; cudaCheck(cudaEventCreate(&start)); @@ -1781,7 +1852,7 @@ int main(int argc, char *argv[]) { // todo - ideally should rerun this step so we don't "waste" the data without training on it if (step == 0) { step_learning_rate = 0.0f; - weight_decay = 1.0f; + weight_decay = 0.0f; } gpt2_update(&model, step_learning_rate, 0.9f, 0.95f, 1e-8f, weight_decay, grad_scale, step+1, &multi_gpu_config); From 41fd098b02d5206109e59412c962f3a957c3d619 Mon Sep 17 00:00:00 2001 From: ademeure Date: Tue, 17 Sep 2024 19:56:25 +0000 Subject: [PATCH 18/27] FP8 forward+backward+update (again)! FINALLY! 
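
This wires the FP8 path end to end: where tensors are in FP8, activations and weights
use e4m3 and gradients use e5m2, and the three big backward matmuls (fcproj, fc, qkv)
now go through matmul_backward_fp8, while attproj stays in BF16 since all of its
operands are BF16 anyway. Per-tensor scales are derived from the tracked absmax:
update_scales_from_absmax() runs at the end of gpt2_update(), so the descale applied to
parameters is effectively delayed by one step. The ad-hoc CudaScratchAllocator and
TransposedCache are dropped in favour of the pre-sized multiuse scratch buffers
(local_scratch / output_scratch). Roughly, the scaling recipe boils down to the sketch
below (an illustrative helper, not part of the diff; the real update_scale_descale_kernel
also special-cases zero absmax, residual/embedding tensors and first-step gradients):

    // maps a tracked absmax to a (scale, descale) pair; targets of 256 (e4m3) and
    // 32768 (e5m2) leave headroom below the format maxima of ~448 and ~57344
    void scale_from_absmax(float absmax, bool e5_gradient, float* scale, float* descale) {
        if (absmax == 0.0f) { *scale = 1.0f; *descale = 1.0f; return; }
        float target = e5_gradient ? 32768.0f : 256.0f;
        *scale   = target / absmax;   // multiplied in when converting to FP8
        *descale = absmax / target;   // multiplied in when reading FP8 back
    }
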
--- llmc/attention.cuh | 5 +- llmc/copy_and_fp8.h | 158 ++++---------------------------------------- llmc/gelu.cuh | 14 ++-- llmc/matmul.cuh | 139 +++++++++++++++++++++++--------------- llmc/tensor.cuh | 146 ++++++++++++++++++---------------------- profile_gpt2.cu | 1 - train_gpt2.cu | 117 ++++++++++++-------------------- 7 files changed, 215 insertions(+), 365 deletions(-) diff --git a/llmc/attention.cuh b/llmc/attention.cuh index f91380ad3..36dcb58ae 100644 --- a/llmc/attention.cuh +++ b/llmc/attention.cuh @@ -270,13 +270,16 @@ void attention_forward(tensorX out, floatX* qkvr, floatX* att, // the sequence of transformations in this compound op is: // inp (B,T,3C) -> qkvr (B,T,3C) -> preatt (B,NH,T,T) -> att (B,NH,T,T) -> vaccum (B,T,C) -> out (B,T,C) -void attention_backward(tensorX dinp, floatX* dqkvr, floatX* datt, floatX* scratch, +void attention_backward(tensorX dinp, floatX* dqkvr, floatX* datt, tensorX dout, tensorX qkvr, floatX* att, int B, int T, int C, int NH, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); const int block_size = 256; const int HS = C / NH; // head size + // now reusing dinp as scratch buffer (free before the final output and it's the right size) + floatX* scratch = dinp.data_ptr; + // unpack convenience pointers into q, k, v floatX *q, *k, *v; q = qkvr + 0 * B * T * C; diff --git a/llmc/copy_and_fp8.h b/llmc/copy_and_fp8.h index f1ebffe6c..94c8091b7 100644 --- a/llmc/copy_and_fp8.h +++ b/llmc/copy_and_fp8.h @@ -42,7 +42,7 @@ template __global__ void copy_advanced_kernel(TensorGPU out, TensorGPU in) { - constexpr size_t vec_size = 16 / ((sizeof(Tin) < sizeof(Tout)) ? sizeof(Tout) : sizeof(Tin)); + constexpr size_t vec_size = 16 / ((sizeof(Tin) >= sizeof(Tout)) ? sizeof(Tin) : sizeof(Tout)); size_t adjusted_blockidx = reversed_order ? 
(gridDim.x - blockIdx.x - 1) : blockIdx.x; size_t idx = (adjusted_blockidx * blockDim.x + threadIdx.x) * vec_size; if (idx >= out.num_elements) { return; } @@ -57,16 +57,15 @@ __global__ void copy_advanced_kernel(TensorGPU out, TensorGPU in) { out128.update_absmax(threadIdx.x, blockDim.x, true); } -// transpose + copy + format conversion (+ elementwise + absmax) kernel template -__global__ void transpose_simple_kernel(T1* __restrict__ transposed, const T1* __restrict__ input, int height) +__global__ void transpose_simple_kernel(T1* __restrict__ transposed, const T1* __restrict__ input) { + constexpr size_t elements = 16 / sizeof(T1); __shared__ T1 tile[TILE_DIM][TILE_DIM]; int width = gridDim.x * TILE_DIM; - height = gridDim.y * TILE_DIM; + int height = gridDim.y * TILE_DIM; - constexpr size_t elements = 16 / sizeof(T1); - int x = blockIdx.x * TILE_DIM + (threadIdx.x * elements); + int x = blockIdx.x * TILE_DIM + threadIdx.x * elements; int y = blockIdx.y * TILE_DIM + threadIdx.y; #pragma unroll @@ -77,15 +76,9 @@ __global__ void transpose_simple_kernel(T1* __restrict__ transposed, const T1* _ } __syncthreads(); - constexpr size_t block_size_x = (TILE_DIM * sizeof(T1)) / 16; - constexpr size_t block_size_y = BLOCK_ROWS; - - int adjusted_tid_x = threadIdx.x % block_size_x; - int adjusted_tid_y = (threadIdx.y) + (threadIdx.x / block_size_y); - // x/y for final write to global memory - x = blockIdx.y * TILE_DIM + adjusted_tid_x * elements; - y = blockIdx.x * TILE_DIM + adjusted_tid_y; + x = blockIdx.y * TILE_DIM + threadIdx.x * elements; + y = blockIdx.x * TILE_DIM + threadIdx.y; #pragma unroll for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) { @@ -94,9 +87,9 @@ __global__ void transpose_simple_kernel(T1* __restrict__ transposed, const T1* _ for (int k = 0; k < elements; k++) { // these are tiny 8-bit loads with loads of bank conflicts for FP8 // extremely hard to avoid and not a bottleneck when everything else is well optimised - out128[k] = tile[k + (adjusted_tid_x) * out128.size][adjusted_tid_y + j]; + out128[k] = tile[k + threadIdx.x * elements][threadIdx.y + j]; } - store128(transposed + x + out128.size + (y+j)*height, out128); + store128(transposed + x + (y+j)*height, out128); } } @@ -139,9 +132,9 @@ void transpose_simple(TensorGPU transposed, TensorGPU input, size_t w, s dim3 block_size_dim(block_size_x, block_size_y, 1); switch (block_size_y) { - case 64: transpose_simple_kernel<64, TRANSPOSE_TILE_SIZE><<>>(transposed, input, h); break; - case 32: transpose_simple_kernel<32, TRANSPOSE_TILE_SIZE><<>>(transposed, input, h); break; - case 16: transpose_simple_kernel<16, TRANSPOSE_TILE_SIZE><<>>(transposed, input, h); break; + case 64: transpose_simple_kernel<64, TRANSPOSE_TILE_SIZE><<>>((T1*)transposed, (T1*)input); break; + case 32: transpose_simple_kernel<32, TRANSPOSE_TILE_SIZE><<>>((T1*)transposed, (T1*)input); break; + case 16: transpose_simple_kernel<16, TRANSPOSE_TILE_SIZE><<>>((T1*)transposed, (T1*)input); break; default: printf("Invalid block size (might be easy to add): %lu\n", block_size_y); exit(1); } cudaCheck(cudaGetLastError()); @@ -162,131 +155,4 @@ void update_absmax(TensorGPU inp, bool memset_absmax=true, cudaStream_t strea cudaCheck(cudaGetLastError()); } -// ---------------------------------------------------------------------------- -// Scratch allocation for FP8 conversions etc. 
-// todo - consider alternatives (or at least move it somewhere else) - -#include -#include -#include - -class CudaScratchAllocator { -private: - struct Allocation { - void* ptr; - size_t size; - bool in_use; - - Allocation(void* p, size_t s) : ptr(p), size(s), in_use(false) {} - }; - - static std::vector allocations; - static size_t total_allocated; - -public: - template - static T* getMemory(size_t count, bool exact=false) { - size_t size = count * sizeof(T); - - // Find the smallest free allocation that fits the requested size - auto it = std::min_element(allocations.begin(), allocations.end(), - [size](const Allocation& a, const Allocation& b) { - return !a.in_use && a.size >= size && (b.in_use || b.size < size || a.size < b.size); - }); - - if (it != allocations.end() && !it->in_use && it->size >= size && (!exact || it->size == size)) { - it->in_use = true; - return reinterpret_cast(it->ptr); - } - - // If no suitable allocation found, create a new one - void* new_ptr; - cudaMalloc(&new_ptr, size); - allocations.emplace_back(new_ptr, size); - allocations.back().in_use = true; - total_allocated += size; - printf("Allocated CUDA scratch memory: %lu bytes (%p) ==> total allocated: %.1fGiB\n", size, new_ptr, total_allocated / (1024.0 * 1024.0 * 1024.0)); - return reinterpret_cast(new_ptr); - } - - template - static void releaseMemory(T* ptr) { - if (ptr == nullptr) { return; } - auto it = std::find_if(allocations.begin(), allocations.end(), - [ptr](const Allocation& a) { return a.ptr == (void*)ptr; }); - - if (it != allocations.end()) { - it->in_use = false; - } - } - - static void cleanup() { - for (const auto& alloc : allocations) { - cudaFree(alloc.ptr); - } - allocations.clear(); - } -}; -std::vector CudaScratchAllocator::allocations; -size_t CudaScratchAllocator::total_allocated = 0; - -// ---------------------------------------------------------------------------- -// Transposed Cache (for FP8 weights) - -#include - -// Custom hash function for std::pair -// todo - why did we need this? complained about default constructor issue? 
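// (needed because std::hash has no specialization for std::pair, so the unordered_map
//  below cannot instantiate its default hasher - that is the "default constructor" error)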
-struct PairHash { - std::size_t operator()(const std::pair& p) const { - return std::hash{}(p.first) ^ (std::hash{}(p.second) << 1); - } -}; - -class TransposedCache { -private: - struct CacheEntry { - void* ptr; - size_t size; - }; - - std::unordered_map, CacheEntry, PairHash> cache; - -public: - TransposedCache() = default; - - template - Tout* getTransposed(const T* original, const void* associatedTensor, size_t m, size_t k, bool compute=true, bool find_only=false, cudaStream_t stream=0) { - uint64_t key1 = reinterpret_cast(original); - uint64_t key2 = reinterpret_cast(associatedTensor); - auto key = std::make_pair(key1, key2); - size_t size = m * k * sizeof(T); - - auto it = cache.find(key); - if (it != cache.end() && it->second.size == size) { - return reinterpret_cast(it->second.ptr); - } - if (find_only) { - return nullptr; - } - - Tout* transposed = CudaScratchAllocator::getMemory(m * k, true); - if (compute) { - // todo - //copy_or_transpose(true, transposed, original, m, k, nullptr, nullptr, nullptr, stream); - } - - cache[key] = {transposed, size}; - return transposed; - } - - void clearCache() { - for (const auto& entry : cache) { - CudaScratchAllocator::releaseMemory(entry.second.ptr); - } - cache.clear(); - } -}; -TransposedCache g_transposed_cache; - #endif \ No newline at end of file diff --git a/llmc/gelu.cuh b/llmc/gelu.cuh index acc931969..bc56e67d8 100644 --- a/llmc/gelu.cuh +++ b/llmc/gelu.cuh @@ -10,7 +10,8 @@ // CUDA kernels #define GELU_SCALING_FACTOR sqrtf(2.0f / M_PI) -__global__ void gelu_forward_kernel2(tensorFP8e4 out, tensorFP8e4 inp) { +template +__global__ void gelu_forward_kernel2(TensorGPU out, TensorGPU inp) { int idx = (blockIdx.x * blockDim.x + threadIdx.x) * inp.num_per_128(); auto out128 = new_tensor128(out); @@ -35,8 +36,8 @@ __global__ void gelu_forward_kernel2(tensorFP8e4 out, tensorFP8e4 inp) { } //template -template -__global__ void gelu_backward_kernel(tensorFP8e5 dinp, tensorFP8e5 dout, TensorGPU inp) { +template +__global__ void gelu_backward_kernel(TensorGPU dinp, TensorGPU dout, TensorGPU inp) { int idx = (blockIdx.x * blockDim.x + threadIdx.x) * dout.num_per_128(); auto dinp128 = new_tensor128(dinp); @@ -64,8 +65,8 @@ __global__ void gelu_backward_kernel(tensorFP8e5 dinp, tensorFP8e5 dout, TensorG // ---------------------------------------------------------------------------- // kernel launchers - -void gelu_forward(tensorFP8e4 out, tensorFP8e4 inp, cudaStream_t stream=main_stream) { +template +void gelu_forward(TensorGPU out, TensorGPU inp, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); const int block_size = 256; assert(inp.num_elements % (block_size * inp.num_per_128()) == 0); @@ -75,7 +76,8 @@ void gelu_forward(tensorFP8e4 out, tensorFP8e4 inp, cudaStream_t stream=main_str cudaCheck(cudaGetLastError()); } -void gelu_backward(tensorFP8e5 dinp, tensorFP8e5 dout, tensorFP8e4 inp, cudaStream_t stream=main_stream) { +template +void gelu_backward(TensorGPU dinp, TensorGPU dout, TensorGPU inp, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); const int block_size = 256; const int grid_size = CEIL_DIV(inp.num_elements, block_size * inp.num_per_128()); diff --git a/llmc/matmul.cuh b/llmc/matmul.cuh index a915fb0b5..db376ebb8 100644 --- a/llmc/matmul.cuh +++ b/llmc/matmul.cuh @@ -16,9 +16,11 @@ Matrix Multiplication, with help from cuBLASLt // ---------------------------------------------------------------------------- // CUDA kernels -template -__global__ void matmul_backward_bias_kernel9(TensorGPU dbias, tensorX dout, int BT, int OC, 
+template +__global__ void matmul_backward_bias_kernel9(TensorGPU dbias, TensorGPU dout, int BT, int OC, std::bool_constant) { + // todo - this kernel is way more complicated than it needs to be + // (should look at my old PR to simplify it again after this) constexpr const int bdx = 4; constexpr const int bdy = WARP_SIZE / bdx; assert(blockDim.x == bdx); @@ -28,33 +30,33 @@ __global__ void matmul_backward_bias_kernel9(TensorGPU dbias, tensorX int warp_c = (int)threadIdx.y; int block_d = (int)threadIdx.z; - const int OC_per_warp = bdy * x128::size; // 64 at BF16 + const int OC_per_warp = bdy * Packed128::size; // 64 at BF16 - int local_oc = warp_c * x128::size; + int local_oc = warp_c * Packed128::size; int global_oc = blockIdx.x * OC_per_warp + local_oc; int local_bt = warp_d + bdx * block_d; int bt_per_block = bdx * blockDim.z; - float accumulators[x128::size]; - for (int k = 0; k < x128::size; k++) { + float accumulators[Packed128::size]; + for (int k = 0; k < Packed128::size; k++) { accumulators[k] = 0.0f; } if(global_oc < OC) { // sum up over all bt within registers for (int idx = blockIdx.y * bt_per_block + local_bt; idx < BT; idx += gridDim.y * bt_per_block) { - x128 packed_dout = load128(dout + global_oc + idx*OC); - for (int k = 0; k < x128::size; k++) { - accumulators[k] += (float)packed_dout[k]; + auto dout128 = load_tensor128(dout, global_oc + idx*OC); + for (int k = 0; k < Packed128::size; k++) { + accumulators[k] += dout128.get(k); } } } - __shared__ float sub_results[x128::size][WARP_SIZE][bdy]; + __shared__ float sub_results[Packed128::size][WARP_SIZE][bdy]; // reduce within-warp results - for (int k = 0; k < x128::size; k++) { + for (int k = 0; k < Packed128::size; k++) { float v = accumulators[k]; v += __shfl_down_sync(0xffffffff, v, 1, 4); v += __shfl_down_sync(0xffffffff, v, 2, 4); @@ -65,7 +67,7 @@ __global__ void matmul_backward_bias_kernel9(TensorGPU dbias, tensorX __syncthreads(); // block-wide reductions - for (int k = block_d; k < x128::size; k += blockDim.z) { + for (int k = block_d; k < Packed128::size; k += blockDim.z) { float a = 0.f; for (int r = warp_d; r < blockDim.z; r += bdx) { float v = sub_results[k][r][warp_c]; @@ -152,8 +154,7 @@ void matmul_cublaslt(tensorX d, const tensorX a, const tensorX b, const tensorX } else { cublasCheck(cublasLtMatrixLayoutCreate(&BLayout, CUBLAS_LOWP, k, n, k)); } - // cuBLASLt requires C in FP8 mode to be BF16 or FP32... (sigh) - cublasCheck(cublasLtMatrixLayoutCreate(&CLayout, (sizeof(floatX) == 1) ? CUDA_R_16BF : CUBLAS_LOWP, m, n, m)); + cublasCheck(cublasLtMatrixLayoutCreate(&CLayout, CUBLAS_LOWP, m, n, m)); cublasCheck(cublasLtMatrixLayoutCreate(&DLayout, CUBLAS_LOWP, m, n, m)); // Strided Batched GEMM (used for non-flash attention, equivalent to cublasGemmStridedBatchedEx) @@ -183,18 +184,8 @@ void matmul_cublaslt(tensorX d, const tensorX a, const tensorX b, const tensorX if (backward) { assert(!has_bias); // we shouldn't have any backward matmuls that use both GELU and bias epilogue = CUBLASLT_EPILOGUE_DGELU; - if (pre_gelu.scale_descale_ptr) { // descale input - //float* gelu_descale_ptr = pre_gelu.scale_descale_ptr + 1; - //cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_SCALE_POINTER, &gelu_descale_ptr, sizeof(float*))); - } } else { epilogue = has_bias ? 
CUBLASLT_EPILOGUE_GELU_AUX_BIAS : CUBLASLT_EPILOGUE_GELU_AUX; - if (pre_gelu.absmax_ptr) { - //cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_AMAX_POINTER, &pre_gelu.absmax_ptr, sizeof(pre_gelu.absmax_ptr))); - } - if (pre_gelu.scale_descale_ptr) { // scale output - //cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_SCALE_POINTER, &pre_gelu.scale_descale_ptr, sizeof(float*))); - } } } else if(has_bias){ epilogue = backward ? CUBLASLT_EPILOGUE_BGRADB : CUBLASLT_EPILOGUE_BIAS; @@ -210,21 +201,6 @@ void matmul_cublaslt(tensorX d, const tensorX a, const tensorX b, const tensorX cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias.data_ptr, sizeof(bias.data_ptr))); } - // scale factors - if (a.scale_descale_ptr) { - //float* a_descale_ptr = a.scale_descale_ptr + 1; - //cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_descale_ptr, sizeof(float*))); - } - if (b.scale_descale_ptr) { - //float* b_descale_ptr = b.scale_descale_ptr + 1; - //cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_descale_ptr, sizeof(float*))); - } - if (d.scale_descale_ptr) { - //cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &d.scale_descale_ptr, sizeof(float*))); - } - if (d.absmax_ptr) { - //cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &d.absmax_ptr, sizeof(float*))); - } // set scale type to FP32 (needs to be FP16 if and only if using CUBLAS_COMPUTE_16F, so it's FP32 even for FP8!) cublasDataType_t scale_type = CUDA_R_32F; @@ -246,8 +222,6 @@ void matmul_cublaslt(tensorX d, const tensorX a, const tensorX b, const tensorX &alpha, a, ALayout, b, BLayout, &beta, d, CLayout, d, DLayout, &heuristic.algo, cublaslt_workspace, cublaslt_workspace_size, stream)); - update_absmax(d, false, stream); - // cleanups cublasCheck(cublasLtMatmulPreferenceDestroy(preference)); cublasCheck(cublasLtMatmulDescDestroy(operationDesc)); @@ -311,7 +285,6 @@ void matmul_cublaslt_fp8(TensorGPU d, const TensorGPU a, const TensorGPU cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &d.absmax_ptr, sizeof(float*))); } - // set scale type to FP32 (needs to be FP16 if and only if using CUBLAS_COMPUTE_16F, so it's FP32 even for FP8!) 
cublasDataType_t scale_type = CUDA_R_32F; cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type))); @@ -324,7 +297,6 @@ void matmul_cublaslt_fp8(TensorGPU d, const TensorGPU a, const TensorGPU // find a suitable algorithm (cached internally so shouldn't take much CPU time in practice) int returnedResults = 0; cublasLtMatmulHeuristicResult_t heuristic; - cublasLtMatmulAlgoGetHeuristic(cublaslt_handle, operationDesc, ALayout, BLayout, CLayout, DLayout, preference, 1, &heuristic, &returnedResults); @@ -380,12 +352,8 @@ void matmul_forward_cublaslt(TensorGPU out, } } -template -void matmul_backward(tensorX dinp, tensorX dweight, tensorX dbias, - tensorX dout, tensorX inp, tensorX weight, - tensorFP32 dbias_buffer, - int BT, int C, int OC, - TensorGPU pre_gelu=null_tensorX, int gelu_fusion=1, cudaStream_t stream=main_stream) { +template +void matmul_backward_bias(tensorX dbias, TensorGPU dout, tensorFP32 scratch, int BT, int OC, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); // backward to bias, if given, does a += @@ -407,20 +375,83 @@ void matmul_backward(tensorX dinp, tensorX dweight, tensorX dbias, cudaCheck(cudaGetLastError()); } else { // kernel 9 overwrites temp buffer, so no need to memset - matmul_backward_bias_kernel9<<>>(dbias_buffer, dout, BT, OC, True); + matmul_backward_bias_kernel9<<>>(scratch, dout, BT, OC, True); cudaCheck(cudaGetLastError()); - reduce_add_sum_kernel<<>>(dbias, dbias_buffer, OC, grid_size_y); + reduce_add_sum_kernel<<>>(dbias, scratch, OC, grid_size_y); cudaCheck(cudaGetLastError()); } } +} + +template +void matmul_backward_fp8(tensorFP8e5 dinp, tensorX dweight, tensorX dbias, + TensorGPU dout, tensorFP8e4 inp, tensorFP8e4 weight, + tensorFP32 scratch1_big, tensorFP32 scratch2_huge, + int BT, int C, int OC, + tensorFP8e4 pre_gelu_activation=null_tensorFP8E4, cudaStream_t stream=main_stream) { +#ifndef ENABLE_FP8 + // FP8 is not enabled so we use the regular floatX matmul path + matmul_backward(dinp, dweight, dbias, dout, inp, weight, scratch1_big, BT, C, OC, pre_gelu_activation, 1, stream); +#else + NVTX_RANGE_FN(); + matmul_backward_bias(dbias, dout, scratch1_big, BT, OC, stream); + + // N.B.: Both scratch1 and scratch2 are guaranteed to be big enough for 4BTC and 4CC in FP8 + // IMPORTANT: inp is allowed to be the same buffer as scratch2_huge (e.g. for fch_gelu) + // ==> this MUST be done first and write to scratch1_big! + // transpose input + TensorGPU inp_fp8_transposed = inp; + inp_fp8_transposed.data_ptr = (float8e4*)scratch1_big.data_ptr; + transpose_simple(inp_fp8_transposed, inp, C, BT, stream); + + // convert dout to FP8e5 if it is not already, and transpose it + // the buffer is guaranteed to be at least twice as big as 4BTC, so we can split it in 2 + // todo - merge conversion and tranposition like we did before? 
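    // (cuBLASLt FP8 GEMMs only support the TN layout, which is why inp/dout/weight are
    //  materialised as explicitly transposed copies here instead of using transpose flags)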
+ TensorGPU dout_fp8; + if constexpr (std::is_same::value) { + dout_fp8 = dout; + } else { + dout_fp8 = *(TensorGPU*)&dout; + dout_fp8.data_ptr = (float8e5*)(scratch2_huge.data_ptr); + copy_advanced(dout_fp8, dout, stream); + } + TensorGPU dout_fp8_transposed = dout_fp8; + dout_fp8_transposed.data_ptr = (float8e5*)(scratch2_huge.data_ptr + (scratch2_huge.num_elements / 2)); + transpose_simple(dout_fp8_transposed, dout_fp8, OC, BT, stream); + + // GEMM 1: dweight, inp_fp8_transposed, dout_fp8_transposed + matmul_cublaslt_fp8(dweight, inp_fp8_transposed, dout_fp8_transposed, null_tensorX, C, OC, BT, stream, false, true); + + // transpose weight (todo: option to cache this / do it at optimizer time) + TensorGPU weight_fp8_transposed = weight; + weight_fp8_transposed.data_ptr = (float8e4*)scratch1_big.data_ptr; + transpose_simple(weight_fp8_transposed, weight, C, OC, stream); + + matmul_cublaslt_fp8(dinp, weight_fp8_transposed, dout_fp8, null_tensorX, C, BT, OC, stream, false, true); + + if (pre_gelu_activation.data_ptr) { + gelu_backward(dinp, dinp, pre_gelu_activation, stream); + } +#endif +} + + +template +void matmul_backward(tensorX dinp, tensorX dweight, tensorX dbias, + tensorX dout, tensorX inp, tensorX weight, + tensorFP32 dbias_scratch, + int BT, int C, int OC, + TensorGPU pre_gelu_activation=null_tensorX, int gelu_fusion=1, cudaStream_t stream=main_stream) { + NVTX_RANGE_FN(); + matmul_backward_bias(dbias, dout, dbias_scratch, BT, OC, stream); // backward to input, uses = in the backward pass (set the gradient) matmul_cublaslt(dinp, weight, dout, null_tensorX, C, BT, OC, stream, false, false, 0, 0, 0, 0, false, - gelu_fusion >= 2 ? pre_gelu : null_tensorX, true); + gelu_fusion >= 2 ? pre_gelu_activation : null_tensorX, true); // backward GELU (if it wasn't fused into the matmul above) - if (gelu_fusion < 2 && pre_gelu != null_tensorX) { - gelu_backward(dinp, dinp, pre_gelu, stream); + if (gelu_fusion < 2 && pre_gelu_activation.enabled()) { + gelu_backward(dinp, dinp, pre_gelu_activation, stream); } // backward to weight, uses += in the backward pass (accumulate the gradient) by setting alpha=one diff --git a/llmc/tensor.cuh b/llmc/tensor.cuh index ad2605af1..395099884 100644 --- a/llmc/tensor.cuh +++ b/llmc/tensor.cuh @@ -3,7 +3,7 @@ // ... //#define FAKE_FP8 -#define UNIQUE_TENSOR_MEMORY true +#define UNIQUE_TENSOR_MEMORY false #define LAYERS_PER_ACTIVATION_CHECKPOINT 0 // 0 = disabled // ... 
@@ -21,8 +21,8 @@ enum TT : uint8_t { enum TFlags : uint8_t { NONE=0, - REUSED_MEMORY=1, - GRADIENT=2, + GRADIENT=1, + REUSED_MEMORY=2, TENSOR_2D=4, // used for matmul *outputs* only, not inputs (+weights) BIAS=8, LAYERNORM=16, @@ -77,13 +77,12 @@ struct TensorGPU { size_t num_elements = 0; int id = -1; - static constexpr bool no_scaling = (sizeof(ElementType) == 1); - + static constexpr bool no_scaling = (sizeof(ElementType) != 1); bool is_null() const { return (data_ptr == NULL); } bool enabled() const { - return (absmax_ptr != NULL); + return (data_ptr != NULL); } static __device__ __host__ TensorGPU from(ElementType* ptr=nullptr) { @@ -202,8 +201,6 @@ struct TensorSpec { // debug helper function void print_tensor_elements(int tensor_id) { - return; - printf("Printing tensor %d\n", tensor_id); TensorSpec spec = tensor_specs[tensor_id]; size_t num_elements = spec.num_elements; @@ -215,65 +212,35 @@ void print_tensor_elements(int tensor_id) { void* gpu_tensor = spec.ptr; void* cpu_tensor = malloc(num_elements * element_size); - printf("Printing tensor %s (tensor_type: %d, data_type: %d)\n", tensor_name, (int)tensor_type, (int)dtype); + // Get scale from GPU + float scale, descale, absmax; + cudaMemcpy(&scale, &gpu_scale_memory[spec.id * 2], sizeof(float), cudaMemcpyDeviceToHost); + cudaMemcpy(&descale, &gpu_scale_memory[spec.id * 2 + 1], sizeof(float), cudaMemcpyDeviceToHost); + cudaMemcpy(&absmax, &gpu_absmax_memory[spec.id], sizeof(float), cudaMemcpyDeviceToHost); + + printf("Printing tensor %s (tensor_type: %d, data_type: %d, flags: %d)\n", tensor_name, (int)tensor_type, (int)dtype, spec.flags); printf("GPU memory: %p\n", gpu_tensor); printf("CPU memory: %p\n", cpu_tensor); printf("Num elements: %zu\n", num_elements); printf("Element size: %zu\n", element_size); printf("Offset: %zu\n", spec.offset); + printf("Scale: %f, Descale: %f, Absmax: %f\n", scale, descale, absmax); cudaCheck(cudaMemcpy(cpu_tensor, gpu_tensor, num_elements * element_size, cudaMemcpyDeviceToHost)); - printf("Did memcpy\n"); - - printf("First 4 of %s: ", tensor_name); - for (int i = 0; i < num_elements && i < 4; i++) { - if (dtype == DType::FP32) { - printf("%.16f ", ((float*)cpu_tensor)[i]); - } else if (dtype == DType::FP16) { - printf("%.16f ", (float)((__nv_half*)cpu_tensor)[i]); - } else if (dtype == DType::BF16) { - printf("%.16f ", (float)((__nv_bfloat16*)cpu_tensor)[i]); - } else if (dtype == DType::FP8E4M3) { - printf("%.16f ", (float)((__nv_fp8_e4m3*)cpu_tensor)[i]); - } else if (dtype == DType::FP8E5M2) { - printf("%.16f ", (float)((__nv_fp8_e5m2*)cpu_tensor)[i]); + printf("First 4 & Last 4 of %s:\n", tensor_name); + for (int i = 0; i < 8; i++) { + int idx = (i < 4) ? 
i : num_elements - 8 + i; + switch (dtype) { + case DType::FP32: printf("%.16f ", ((float*)cpu_tensor)[idx]); break; + case DType::FP16: printf("%.16f ", (float)((__nv_half*)cpu_tensor)[idx]); break; + case DType::BF16: printf("%.16f ", (float)((__nv_bfloat16*)cpu_tensor)[idx]); break; + case DType::FP8E4M3: printf("%.16f ", (float)((__nv_fp8_e4m3*)cpu_tensor)[idx]); break; + case DType::FP8E5M2: printf("%.16f ", (float)((__nv_fp8_e5m2*)cpu_tensor)[idx]); break; } + if (i == 3) printf("\n"); } - printf("\n"); - - printf("Middle 4 of %s: ", tensor_name); - for (int i = (num_elements/2) + 4; i < num_elements && i < (num_elements/2 + 8); i++) { - if (dtype == DType::FP32) { - printf("%.16f ", ((float*)cpu_tensor)[i]); - } else if (dtype == DType::FP16) { - printf("%.16f ", (float)((__nv_half*)cpu_tensor)[i]); - } else if (dtype == DType::BF16) { - printf("%.16f ", (float)((__nv_bfloat16*)cpu_tensor)[i]); - } else if (dtype == DType::FP8E4M3) { - printf("%.16f ", (float)((__nv_fp8_e4m3*)cpu_tensor)[i]); - } else if (dtype == DType::FP8E5M2) { - printf("%.16f ", (float)((__nv_fp8_e5m2*)cpu_tensor)[i]); - } - } - printf("\n"); - - printf("Last 4 of %s: ", tensor_name); - for (int i = num_elements - 4; i < num_elements; i++) { - if (dtype == DType::FP32) { - printf("%.16f ", ((float*)cpu_tensor)[i]); - } else if (dtype == DType::FP16) { - printf("%.16f ", (float)((__nv_half*)cpu_tensor)[i]); - } else if (dtype == DType::BF16) { - printf("%.16f ", (float)((__nv_bfloat16*)cpu_tensor)[i]); - } else if (dtype == DType::FP8E4M3) { - printf("%.16f ", (float)((__nv_fp8_e4m3*)cpu_tensor)[i]); - } else if (dtype == DType::FP8E5M2) { - printf("%.16f ", (float)((__nv_fp8_e5m2*)cpu_tensor)[i]); - } - } - printf("\n"); - printf("\n"); + printf("\n\n"); free(cpu_tensor); } @@ -289,7 +256,7 @@ TensorSpec get_tensor(int spec_index, TT tensor_type, int layer) { assert(false); } assert(spec.tensor_type == tensor_type || tensor_type == DEFAULT); - print_tensor_elements(spec_index); + //print_tensor_elements(spec.id); // enable for extreme debugging return spec; } @@ -315,13 +282,10 @@ int add_tensor_spec(const char* name, size_t total_elements, size_t num_shards, size_t original_tensor_bytes = base_spec.num_elements * sizeof_dtype(base_spec.data_type); size_t new_tensor_bytes = spec->num_elements * sizeof_dtype(data_type); if (base_spec.tensor_type != spec->tensor_type) { - printf("ERROR: tensor_type mismatch for %s: %d vs %d\n", - spec->name, (int)base_spec.tensor_type, (int)spec->tensor_type); + printf("ERROR: tensor_type for %s: %d vs %d\n", spec->name, (int)base_spec.tensor_type, (int)spec->tensor_type); assert(false); } - if (flags & REUSED_MEMORY) { - base_spec.flags |= REUSED_MEMORY; - } + base_spec.flags |= (flags & REUSED_MEMORY); assert(base_spec.tensor_type == spec->tensor_type); assert(new_tensor_bytes <= original_tensor_bytes); } else { @@ -367,39 +331,57 @@ __global__ void update_scale_descale_kernel(int num_tensor_specs) { float absmax = __uint_as_float(absmax_uint); // Calculate scale and descale - if (absmax == 0.0f) { - absmax = 1.0f; + float scale = 1.0f; + float descale = 1.0f; + if (absmax != 0.0f) { + scale = 1.0f / absmax; + descale = absmax; } - float scale = 1.0f / absmax; - float descale = absmax; - if (!(tensor_specs_ptr[tid].flags & TFlags::RESIDUAL) && !(tensor_specs_ptr[tid].flags & TFlags::EMBEDDING) && absmax != 1.0f) { - if ((tensor_specs_ptr[tid].flags & TFlags::GRADIENT) && (tensor_specs_ptr[tid].tensor_type == TT::MULTIUSE)) { - // e5 + if ((tensor_specs_ptr[tid].flags & 
TFlags::GRADIENT) && (tensor_specs_ptr[tid].tensor_type == TT::MULTIUSE)) { + // e5 + if (absmax != 0.0f) { scale *= 32768.0f; descale *= 1.0f/32768.0f; } else { - // e4 - //if (tensor_specs_ptr[tid].tensor_type != TT::PARAMETER || absmax >= 4.0f) { - scale *= 256.0f; - descale *= (1.0f/256.0f); - //} + // default so that things are not as bad for gradients on the first step + scale = 4096.0f; + descale = 1.0f/4096.0f; } } else { - scale = 1.0f; - descale = 1.0f; + // e4 + // todo - power benefit of making sure top bit of exponent is (nearly always) zero? + // this can be done simply by *not* multiplying here, so that the "maximum" is 1.0f + //if (tensor_specs_ptr[tid].tensor_type != TT::PARAMETER || absmax >= 4.0f) { + if (absmax != 0.0f) { + scale *= 256.0f; + descale *= (1.0f/256.0f); + } } - if (scale != 1.0f) { - //printf("%s: absmax: %f, scale: %f, descale: %f\n", tensor_specs_ptr[tid].name, absmax, scale, descale); + #ifdef FAKE_FP8 + // with real FP8, we rely on tensor128 not scaling when sizeof(T)>1, but that doesn't work with fake FP8 + // so we prevent scaling for the things we know we don't want to scale + // this might not match what we have in the real FP8 implementation, but allows for quick experimentation + if ((tensor_specs_ptr[tid].flags & TFlags::RESIDUAL) || (tensor_specs_ptr[tid].flags & TFlags::EMBEDDING)) { + scale = 1.0f; + descale = 1.0f; } - - // todo: circular buffer - //gpu_absmax_memory[tid] = 0.0f; + #endif // Update gpu_scale_memory + // todo: descale should be delayed by one step for parameters (see comment in gpt2_update). gpu_scale_memory_ptr[tid * 2] = scale; gpu_scale_memory_ptr[tid * 2 + 1] = descale; + + // todo: circular buffer !!! + //gpu_absmax_memory[tid] = 0.0f; +} + +void update_scales_from_absmax() { + int block_size = 256; + int num_blocks = CEIL_DIV(num_tensor_specs, block_size); + update_scale_descale_kernel<<>>(num_tensor_specs); } // ---------------------------------------------------------------------------- diff --git a/profile_gpt2.cu b/profile_gpt2.cu index fa5e528d7..940629af6 100644 --- a/profile_gpt2.cu +++ b/profile_gpt2.cu @@ -58,7 +58,6 @@ int main(int argc, char *argv[]) { model.config.num_layers = 1; set_zero_configs(&multi_gpu_config, 0, model.num_parameters); - gpt2_allocate_state(&model, B, T); // do a training step gpt2_forward(&model, x, B, T); gpt2_backward_and_reduce(&model, x, y, 1, 0); diff --git a/train_gpt2.cu b/train_gpt2.cu index d6a5e259d..f7920c0f9 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -119,9 +119,10 @@ typedef struct { typedef struct { int bt4c; // (B, T, 4*C) int btc; // (B, T, C) - int local_scratch; // (B, T, C) - int output_scratch; // huge - int output_scratch_fp32; // same memory as FP32 + int local_scratch; // big, see local_scratch_size below + int local_scratch_fp32; // same memory as above + int output_scratch; // huge, see output_size below + int output_scratch_fp32; // same memory as above } MultiuseTensors; typedef struct { @@ -176,10 +177,23 @@ void gpt2_allocate(GPT2 *model) { size_t B = model->batch_size; size_t T = model->seq_len; size_t NH = model->config.num_heads; - size_t output_size = B*T * max(4*C, max(NH*T, Vp)); size_t BTC = B*T*C; size_t BT = B*T; + // output is also used as a scratch buffer (floatX), needs to be big enough for: + // 1) Output: B*T*Vp (padded vocabulary size) + // 2) 4BTC (largest activation/grad tensor) + // 3) 4CC FP8 (largest parameter tensor, 2*C*C if floatX=BF16) + // 4) B*T*T*NH (non-cuDNN attention tensor) + size_t output_size = max(B*T * max(Vp, 
4*C), 4*C*C/sizeof(floatX)); + output_size = CUDNN_ENABLED ? output_size : max(output_size, B*T*T*NH); + // local scratch (floatX), must be big enough for: + // 1) BTC (in floatX) + // 2) 4BTC FP8 (transpose cache) + // 2) 4CC FP8 (largest parameter tensor in FP8) + // 3) 4BTC BF16 (non-cuDNN backwards scratch in floatX) + size_t local_scratch_size = max(CUDNN_ENABLED ? 4*BTC/sizeof(floatX) : 4*BTC, 4*C*C/sizeof(floatX)); + int reuse_every_n = 0; int shards = 1; int num_gpu = multi_gpu_config.num_processes; @@ -225,9 +239,10 @@ void gpt2_allocate(GPT2 *model) { model->multiuse.bt4c = add_tensor_spec("multiuse_bt4c", 4 * BTC, 1, DTYPE_FLOATX, -1, REUSED_MEMORY); model->multiuse.btc = add_tensor_spec("multiuse_btc", BTC, 1, DTYPE_FLOATX, -1, REUSED_MEMORY); } - model->multiuse.local_scratch = add_tensor_spec("local_scratch", BTC, 1, DType::FP32, -1, REUSED_MEMORY); // todo - is this avoidable (or oversized)? - model->multiuse.output_scratch = add_tensor_spec("output_fpx", output_size, 1, DTYPE_FLOATX, -1, REUSED_MEMORY); - model->multiuse.output_scratch_fp32 = add_tensor_spec("output_fp32", output_size / 2, 1, DType::FP32, model->multiuse.output_scratch, REUSED_MEMORY); + model->multiuse.local_scratch = add_tensor_spec("scratch_x", local_scratch_size, 1, DTYPE_FLOATX, -1, REUSED_MEMORY); + model->multiuse.local_scratch_fp32 = add_tensor_spec("scratch_32", local_scratch_size / sizeof(floatX), 1, DType::FP32, model->multiuse.local_scratch, REUSED_MEMORY); + model->multiuse.output_scratch = add_tensor_spec("out_scratch_x", output_size, 1, DTYPE_FLOATX, -1, REUSED_MEMORY); + model->multiuse.output_scratch_fp32 = add_tensor_spec("out_scratch_32", output_size / sizeof(floatX), 1, DType::FP32, model->multiuse.output_scratch, REUSED_MEMORY); // 3) activations ActivationTensors* spec = &model->acts; @@ -299,7 +314,7 @@ void gpt2_allocate(GPT2 *model) { if (UNIQUE_TENSOR_MEMORY) { TENSOR_SPECS (encoded, 1, BTC, GRADIENT | EMBEDDING); TENSOR_SPECS (output, 1, output_size, GRADIENT | EMBEDDING); - TENSOR_SPECS_LOWP(lnf, 1, BTC, GRADIENT | LAYERNORM | TENSOR_2D); + TENSOR_SPECS (lnf, 1, BTC, GRADIENT | LAYERNORM | TENSOR_2D); TENSOR_SPECS_LOWP(ln1, L, BTC, GRADIENT | LAYERNORM | TENSOR_2D); TENSOR_SPECS (atty, L, BTC, GRADIENT | TENSOR_2D); TENSOR_SPECS (residual2, L, BTC, GRADIENT | RESIDUAL); @@ -322,7 +337,7 @@ void gpt2_allocate(GPT2 *model) { spec->encoded = add_layer_specs(1, "encoded", BTC, 1, dtype, reused_btc2, GRADIENT | EMBEDDING); // (lnf doesn't need bt4c but it's free at this point unlike the other buffers) - spec->lnf = add_layer_specs(1, "lnf", BTC, 1, dtype_lowp, model->multiuse.bt4c, GRADIENT | LAYERNORM | TENSOR_2D); + spec->lnf = add_layer_specs(1, "lnf", BTC, 1, dtype, model->multiuse.bt4c, GRADIENT | LAYERNORM | TENSOR_2D); spec->fch_gelu = add_layer_specs(L, "fch_gelu", 4 * BTC, 1, dtype_lowp, model->multiuse.bt4c, GRADIENT | TENSOR_2D); spec->fch = add_layer_specs(L, "fch", 4 * BTC, 1, dtype_lowp, model->multiuse.bt4c, GRADIENT | TENSOR_2D); spec->qkvr = add_layer_specs(L, "qkvr", 3 * BTC, 1, dtype, model->multiuse.bt4c, GRADIENT | TENSOR_2D); @@ -511,9 +526,7 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path, bool w update_absmax(tensor); } - int absmax_block_size = 256; - int num_blocks = (num_param_tensors + absmax_block_size - 1) / absmax_block_size; - update_scale_descale_kernel<<>>(num_param_tensors); + update_scales_from_absmax(); offset = 0; for (int i = 0; i < num_param_tensors; i++) { @@ -702,7 +715,7 @@ void gpt2_forward(GPT2 *model, const int* 
inputs, size_t B, size_t T) { qkvr.data_ptr = CUDNN_ENABLED ? ACT(qkvr) : MULTI(output_scratch); matmul_forward_cublaslt(qkvr, ACT(ln1), PARAM(qkvw), PARAM(qkvb), B*T, C, 3*C); #ifdef ENABLE_CUDNN - attention_forward_cudnn(ACT(atty), ACT(att), ACT(qkvr), B, T, NH, C); + attention_forward_cudnn(ACT(atty), ACT(att), ACT(qkvr), B, T, NH, C, main_stream); #else attention_forward(ACT(atty), ACT(qkvr), ACT(att), qkvr, B, T, C, NH); #endif @@ -756,15 +769,6 @@ float gpt2_validate(GPT2 *model, const int* inputs, const int* targets, size_t B void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int grad_accum_steps, int micro_step) { NVTX_RANGE_FN(); - - - - cudaCheck(cudaMemsetAsync(model->tensor_memory[PARAMETER_GRAD], 0, tensors_bytes[PARAMETER_GRAD], main_stream)); - - - - #if 0 - // convenience shortcuts (size_t instead of int so that pointer arithmetics don't overflow) const size_t B = model->batch_size; const size_t T = model->seq_len; @@ -793,11 +797,12 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int fused_classifier(AGRAD(output), ACT(output), ACT(losses), dloss, model->targets, B*T, V, Vp, True); // todo - split output & doutput // re-use the output buffer of the forward pass as a scratchpad during backward pass + dedicated buffer - tensorFP32 scratchF = MULTI_L(local_scratch, 0); + tensorFP32 scratchF_HUGE = MULTI_L(output_scratch_fp32, 0); // Largest buffer imaginable (max of output & everything else) tensorX scratchX_HUGE = MULTI_L(output_scratch, 0); + tensorFP32 scratchF = MULTI_L(local_scratch_fp32, 0); // FP32 BTC with cuDNN, FP32 2*BTC without cuDNN (i.e. 4xBTC BF16) + tensorX scratchX = MULTI_L(local_scratch, 0); // backward pass: go in the reverse order of the forward pass, and call backward() functions - // we kick off the chain rule by filling in dlosses with 1.0f/(B*T) // this was done in the fused classifier kernel as last step of forward pass // technically that is a small, inline backward() pass of calculating @@ -815,30 +820,29 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int tensorX dresidual = (l == 0) ? AGRAD(encoded) : AGRAD_L(residual3, l-1); if(model->recompute >= 1) { // recompute >= 1 means we recompute gelu - gelu_forward(ACT(fch_gelu), ACT(fch)); + gelu_forward(ACT(fch_gelu), ACT(fch)); } - matmul_backward(AGRAD(fch), PGRAD(fcprojw), PGRAD(fcprojb), AGRAD(residual3), ACT(fch_gelu), PARAM(fcprojw), scratchF, B*T, 4*C, C, ACT(fch), model->gelu_fusion); + matmul_backward_fp8(AGRAD(fch), PGRAD(fcprojw), PGRAD(fcprojb), AGRAD(residual3), ACT(fch_gelu), PARAM(fcprojw), scratchF, scratchF_HUGE, B*T, 4*C, C, ACT(fch)); if(model->recompute >= 2) { // recompute >= 2 means we recompute layernorm layernorm_forward(ACT(ln2), ACT(ln2_mean), ACT(ln2_rstd), ACT(residual2), PARAM(ln2w), PARAM(ln2b), B*T, C); } - matmul_backward(AGRAD(ln2), PGRAD(fcw), PGRAD(fcb), AGRAD(fch), ACT(ln2), PARAM(fcw), scratchF, B*T, C, 4 * C); + matmul_backward_fp8(AGRAD(ln2), PGRAD(fcw), PGRAD(fcb), AGRAD(fch), ACT(ln2), PARAM(fcw), scratchF, scratchF_HUGE, B*T, C, 4 * C); layernorm_backward(AGRAD(residual2), AGRAD(residual3), PGRAD(ln2w), PGRAD(ln2b), scratchF, AGRAD(ln2), ACT(residual2), PARAM(ln2w), ACT(ln2_mean), ACT(ln2_rstd), B*T, C); - matmul_backward(AGRAD(atty), PGRAD(attprojw), PGRAD(attprojb), AGRAD(residual2), ACT(atty), PARAM(attprojw), scratchF, B*T, C, C); + + // AGRAD(atty) is BF16, AGRAD(residual2) is BF16, ACT(atty) is BF16, PARAM(attprojw) is BF16... ==> 100% BF16 ==> keep BF16 for now! 
+ matmul_backward(AGRAD(atty), PGRAD(attprojw), PGRAD(attprojb), AGRAD(residual2), ACT(atty), PARAM(attprojw), scratchF, B*T, C, C); #ifdef ENABLE_CUDNN - attention_backward_cudnn(AGRAD(qkvr), AGRAD(atty), ACT(qkvr), ACT(atty), ACT(att), B, T, NH, C); + attention_backward_cudnn(AGRAD(qkvr), AGRAD(atty), ACT(qkvr), ACT(atty), ACT(att), B, T, NH, C, main_stream); #else - // we need B x T x (4)C buffers. l_atty and l_fch aren't needed anymore at this point, so reuse their memory - floatX* buffer_a = ACT(atty); - floatX* buffer_b = ACT(fch); - attention_backward(AGRAD(qkvr), buffer_b, scratchX_HUGE, buffer_a, AGRAD(atty), ACT(qkvr), ACT(att), B, T, C, NH); + attention_backward(AGRAD(qkvr), scratchX, scratchX_HUGE, AGRAD(atty), ACT(qkvr), ACT(att), B, T, C, NH); #endif if(model->recompute >= 2) { layernorm_forward(ACT(ln1), ACT(ln1_mean), ACT(ln1_rstd), residual, PARAM(ln1w), PARAM(ln1b), B*T, C); } - matmul_backward(AGRAD(ln1), PGRAD(qkvw), PGRAD(qkvb), AGRAD(qkvr), ACT(ln1), PARAM(qkvw), scratchF, B*T, C, 3 * C); + matmul_backward_fp8(AGRAD(ln1), PGRAD(qkvw), PGRAD(qkvb), AGRAD(qkvr), ACT(ln1), PARAM(qkvw), scratchF, scratchF_HUGE, B*T, C, 3 * C); layernorm_backward(dresidual, AGRAD(residual2), PGRAD(ln1w), PGRAD(ln1b), scratchF, AGRAD(ln1), residual, PARAM(ln1w), ACT(ln1_mean), ACT(ln1_rstd), B*T, C); // Accumulate gradients from this layer in a background stream. @@ -917,7 +921,6 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int } else { model->mean_loss = -1.f; // no loss available yet } -#endif } float gpt2_calculate_grad_norm(GPT2 *model, MultiGpuConfig* multi_gpu_config) { @@ -981,32 +984,10 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo bool init_state = model->init_state; if(init_state) { model->init_state = false; - NvtxRange rng("InitOpt"); cudaCheck(cudaMemset(model->tensor_memory[PARAMETER_OPT_M], 0, multi_gpu_config->shard_num_parameters * sizeof(float))); cudaCheck(cudaMemset(model->tensor_memory[PARAMETER_OPT_V], 0, multi_gpu_config->shard_num_parameters * sizeof(float))); } - int absmax_block_size = 256; - int absmax_num_blocks = (num_tensor_specs + absmax_block_size - 1) / absmax_block_size; - -/* - printf("--------------\n"); - update_scale_descale_kernel<<>>(num_tensor_specs); - - // copy all absmax to CPU and printf if != 1.0f - float* absmax_cpu = (float*)malloc(num_tensor_specs * sizeof(float)); - float* scale_cpu = (float*)malloc(num_tensor_specs * 2 * sizeof(float)); - cudaCheck(cudaMemcpy(absmax_cpu, gpu_absmax_memory, num_tensor_specs * sizeof(float), cudaMemcpyDeviceToHost)); - cudaCheck(cudaMemcpy(scale_cpu, gpu_scale_memory, num_tensor_specs * 2 * sizeof(float), cudaMemcpyDeviceToHost)); - for (int i = 0; i < num_tensor_specs; i++) { - if (scale_cpu[i*2] != 1.0f || absmax_cpu[i] != 0.0f) { - printf("scale[%d/%s] ==> %.10f ==> %.10f / %.10f\n", i, tensor_specs[i].name, absmax_cpu[i], scale_cpu[i*2], scale_cpu[i*2+1]); - } - } - - printf("==============\n"); -*/ - // save RNG state at this point so we can round from master weights identically when restoring from a checkpoint model->rng_state_last_update = model->rng_state; @@ -1104,24 +1085,10 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo } */ - // todo - hack - update scale/descale from absmax - update_scale_descale_kernel<<>>(num_tensor_specs); - -/* - float* absmax_cpu2 = (float*)malloc(num_tensor_specs * sizeof(float)); - float* scale_cpu2 = (float*)malloc(num_tensor_specs * 2 * sizeof(float)); - 
cudaCheck(cudaMemcpy(absmax_cpu2, gpu_absmax_memory, num_tensor_specs * sizeof(float), cudaMemcpyDeviceToHost)); - cudaCheck(cudaMemcpy(scale_cpu2, gpu_scale_memory, num_tensor_specs * 2 * sizeof(float), cudaMemcpyDeviceToHost)); - for (int i = 0; i < num_tensor_specs; i++) { - if (scale_cpu[i*2] != scale_cpu2[i*2] || absmax_cpu[i] != absmax_cpu2[i]) { - printf("scale[%d/%s] ==> absmax: %f -> %f, scale: %f -> %f, descale: %f -> %f\n", i, tensor_specs[i].name, absmax_cpu[i], absmax_cpu2[i], scale_cpu[i*2], scale_cpu2[i*2], scale_cpu[i*2+1], scale_cpu2[i*2+1]); - } - } - free(scale_cpu); - free(absmax_cpu); - free(scale_cpu2); - free(absmax_cpu2); -*/ + // update FP8 scale & descale multipliers based on the absmax + // since we just updated the parameters with the old scale, + // the descale of parameters is "delayed" by one step. + update_scales_from_absmax(); cudaCheck(cudaDeviceSynchronize()); } From 0740c2fe3ffa335c1641fc017e54288d229d2c20 Mon Sep 17 00:00:00 2001 From: ademeure Date: Wed, 18 Sep 2024 02:07:02 +0000 Subject: [PATCH 19/27] 1st phase of cleanup --- llmc/adamw.cuh | 26 ++- llmc/cuda_common.h | 8 +- llmc/cuda_utils.cuh | 45 +--- llmc/layernorm.cuh | 10 +- llmc/matmul.cuh | 73 +++--- llmc/tensor.cuh | 134 +++++------ train_gpt2.cu | 544 ++++++++++++++++++++++---------------------- 7 files changed, 391 insertions(+), 449 deletions(-) diff --git a/llmc/adamw.cuh b/llmc/adamw.cuh index f33c773a5..763e0a814 100644 --- a/llmc/adamw.cuh +++ b/llmc/adamw.cuh @@ -16,7 +16,7 @@ __device__ float lerp(float start, float end, float weight) { } // always sizeof(param) <= sizeof(grad) <= sizeof(opt/master) <= sizeof(float) -template +template __device__ size_t adamw_update_part(TensorGPU param_tensor, size_t idx, int spec_id, size_t current_start, size_t current_end, size_t stride, TensorGPU grad_tensor, TensorGPU master_tensor, TensorGPU opt_m_tensor, TensorGPU opt_v_tensor, unsigned int seed, int num_params_tensors, size_t num_parameters, size_t num_opt_parameters, @@ -42,19 +42,20 @@ __device__ size_t adamw_update_part(TensorGPU param_tensor, size_t idx, int next_idx[TT::NUM_TYPES_PARAM] = {0}; int current_idx[TT::NUM_TYPES_PARAM] = {0}; + // this implementation has a stride causing sparse reads/writes and bank conflicts for non-FP8 + // todo - compare performance with a version that uses 128-bit for FP32, 64-bit for BF16, 32-bit for FP8 #pragma unroll for (int i = 0; i < 16; i += 4, offset += 4) { if (current_idx[PARAMETER] == 0) param128 = load_tensor128(param_tensor, offset); if (current_idx[PARAMETER_GRAD] == 0) grad128 = load_tensor128(grad_tensor, offset, false, true); if (current_idx[PARAMETER_OPT_M] == 0) opt_m128 = load_tensor128(opt_m_tensor, offset, false,true); if (current_idx[PARAMETER_OPT_V] == 0) opt_v128 = load_tensor128(opt_v_tensor, offset, false, true); - if (current_idx[PARAMETER_MASTER] == 0) master128 = load_tensor128(master_tensor, offset, false, true); + if (current_idx[PARAMETER_MASTER] == 0 && use_master_weights) master128 = load_tensor128(master_tensor, offset, false, true); for (int k = 0; k < 4; k++) { float grad = grad128.get(current_idx[PARAMETER_GRAD] + k); float m = opt_m128.get(current_idx[PARAMETER_OPT_M] + k); float v = opt_v128.get(current_idx[PARAMETER_OPT_V] + k); - float master = master128.get(current_idx[PARAMETER_MASTER] + k); m = lerp(grad, m, beta1); v = lerp(grad * grad, v, beta2); @@ -64,15 +65,18 @@ __device__ size_t adamw_update_part(TensorGPU param_tensor, size_t idx, v /= beta2_correction; float old_param; - if (use_master_weights && 
!master_init_modes) { - old_param = master; + if constexpr (use_master_weights) { + old_param = master128.get(current_idx[PARAMETER_MASTER] + k); } else { old_param = param128.get(current_idx[PARAMETER] + k); } + float param = old_param - (learning_rate * (m / (sqrtf(v) + eps) + wd * old_param)); out_param128.set_stochastic(current_idx[PARAMETER] + k, param, random); float new_param = out_param128.get(current_idx[PARAMETER] + k); - out_master128.set(current_idx[PARAMETER_MASTER] + k, param); + if constexpr (use_master_weights) { + out_master128.set(current_idx[PARAMETER_MASTER] + k, param); + } } next_idx[PARAMETER] = (i + 4) % (16 / sizeof(Tparam)); next_idx[PARAMETER_GRAD] = (i + 4) % (16 / sizeof(Tgrad)); @@ -97,7 +101,7 @@ __device__ size_t adamw_update_part(TensorGPU param_tensor, size_t idx, return idx; } -template +template __global__ void adamw_full_update(TensorSpec* specs, unsigned int seed, int num_params_tensors, size_t num_parameters, size_t num_opt_parameters, float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, @@ -146,7 +150,7 @@ __global__ void adamw_full_update(TensorSpec* specs, unsigned int seed, if (specs[spec_id].data_type == DType::FP32) { TensorGPU param_tensor = specs[spec_id]; - idx = adamw_update_part( + idx = adamw_update_part( param_tensor, idx, spec_id, current_start, current_end, stride, grad_tensor, master_tensor, opt_m_tensor, opt_v_tensor, seed, num_params_tensors, num_parameters, num_opt_parameters, @@ -154,7 +158,7 @@ __global__ void adamw_full_update(TensorSpec* specs, unsigned int seed, eps, wd, grad_scale, t); } else if (specs[spec_id].data_type == DType::BF16) { TensorGPU<__nv_bfloat16> param_tensor = specs[spec_id]; - idx = adamw_update_part( + idx = adamw_update_part( param_tensor, idx, spec_id, current_start, current_end, stride, grad_tensor, master_tensor, opt_m_tensor, opt_v_tensor, seed, num_params_tensors, num_parameters, num_opt_parameters, @@ -162,14 +166,14 @@ __global__ void adamw_full_update(TensorSpec* specs, unsigned int seed, eps, wd, grad_scale, t); } else if (specs[spec_id].data_type == DType::FP8E4M3) { TensorGPU<__nv_fp8_e4m3> param_tensor = specs[spec_id]; - idx = adamw_update_part( + idx = adamw_update_part( param_tensor, idx, spec_id, current_start, current_end, stride, grad_tensor, master_tensor, opt_m_tensor, opt_v_tensor, seed, num_params_tensors, num_parameters, num_opt_parameters, learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, wd, grad_scale, t); } else { - assert(false); // TODO + assert(false); // TODO (no FP16 to avoid compile time increase but it'd be trivial to add) } } } diff --git a/llmc/cuda_common.h b/llmc/cuda_common.h index 7e5d265e3..eedd923f0 100644 --- a/llmc/cuda_common.h +++ b/llmc/cuda_common.h @@ -99,13 +99,13 @@ typedef __nv_bfloat16 floatX; #endif #if defined(ENABLE_FP8) -typedef __nv_fp8_e4m3 float8e4; -typedef __nv_fp8_e5m2 float8e5; +typedef __nv_fp8_e4m3 float8; +typedef __nv_fp8_e5m2 grads8; #define DTYPE_FP8E4 DType::FP8E4M3 #define DTYPE_FP8E5 DType::FP8E5M2 #else -typedef floatX float8e4; -typedef floatX float8e5; +typedef floatX float8; +typedef floatX grads8; #define DTYPE_FP8E4 DTYPE_FLOATX #define DTYPE_FP8E5 DTYPE_FLOATX #endif diff --git a/llmc/cuda_utils.cuh b/llmc/cuda_utils.cuh index ca16d2174..8301143c2 100644 --- a/llmc/cuda_utils.cuh +++ b/llmc/cuda_utils.cuh @@ -201,6 +201,8 @@ __device__ void stochastic_rounding(float in, Ti &out, unsigned int random, floa } else if constexpr (std::is_same::value) { // CUDA doesn't 
have round down/up instructions for FP8 (in SW or HW) so we do it ourselves // ARM-Intel-NVIDIA style FP8 E4M3 (different for AMD-Graphcore-Qualcomm format!) + // tried this approach to avoid fake_fp8 bug (didn't help), keeping it for now... + // todo: compare perf & accuracy to bit shifting method (do exhaustive testing) float low = in; float high = in; @@ -228,12 +230,11 @@ __device__ void stochastic_rounding(float in, Ti &out, unsigned int random, floa } // ---------------------------------------------------------------------------- -// todo - stochastic is bugged, spent hours debugging, no idea why backwards is so broken with it __device__ float fake_fp8(bool faking, float input, float scale, float descale, bool mode_e5, bool stochastic=false) { +#ifdef FAKE_FP8 unsigned int random_number; if (faking && scale != 1.0f) { - assert(scale == 1.0f/descale || scale == 1.0f); - + assert(scale == 1.0f/descale || descale == 1.0f/scale || scale == 1.0f); if (stochastic) { unsigned int clock, laneid; asm volatile("mov.u32 %0, %%clock;" : "=r"(clock)); @@ -243,54 +244,20 @@ __device__ float fake_fp8(bool faking, float input, float scale, float descale, if (mode_e5) { __nv_fp8_e5m2 value_fp8 = __nv_fp8_e5m2(input * scale); - if (false) { - //stochastic_rounding(input * scale, value_fp8, random_number); - } return ((float)value_fp8) * descale; - } else { __nv_fp8_e4m3 value_fp8 = __nv_fp8_e4m3(input * scale); if (stochastic) { - // BUGGED - spent 6+ hours debuggin this, and at this point, I genuinely suspect a compiler bug *sigh* + // BUGGED(?) - spent 6+ hours debugging and I genuinely suspect a compiler bug *sigh* stochastic_rounding(input * scale, value_fp8, random_number); } return ((float)value_fp8) * descale; } } +#endif return input; } -// ---------------------------------------------------------------------------- -// Copy, cast functions - -// device functions and the kernel to cast data between types -template -__device__ Td cast_value(Ts val); - -template<> -__device__ float cast_value(float val) { - return val; -} - -template<> -__device__ float cast_value(half val) { - return __half2float(val); -} - -template<> -__device__ float cast_value(__nv_bfloat16 val) { - return __bfloat162float(val); -} - -template -__global__ void copy_and_cast_kernel(Td* dst, const Ts* src, size_t n, ptrdiff_t stride_dst, ptrdiff_t stride_src) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - // need to try grid stride looping for more perf later - if (idx < n) { - dst[idx + stride_dst * blockIdx.y] = cast_value(src[idx + stride_src * blockIdx.y]); - } -} - // ---------------------------------------------------------------------------- // Warp/Block communication primitives diff --git a/llmc/layernorm.cuh b/llmc/layernorm.cuh index 37f5e00aa..730c281fe 100644 --- a/llmc/layernorm.cuh +++ b/llmc/layernorm.cuh @@ -75,7 +75,7 @@ __global__ void layernorm_forward_kernel6(TensorGPU out, tensorFP32 mean, ten out128.update_absmax(threadIdx.x + threadIdx.y * blockDim.x, blockDim.x * blockDim.y, true); } -template +template __global__ void fused_residual_forward_kernel5(tensorX residual, TensorGPU normed, tensorFP32 mean, tensorFP32 rstd, const tensorX inp1, const TensorGPU inp2, const tensorX weight, const tensorX bias, @@ -141,7 +141,7 @@ __global__ void fused_residual_forward_kernel5(tensorX residual, TensorGPU normed128.update_absmax(threadIdx.x + threadIdx.y * blockDim.x, blockDim.x * blockDim.y, true); } -template +template __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with only 
1024 threads? layernorm_backward_kernel10(tensorX dinp_new, tensorX dinp_old, tensorX dweight, tensorX dbias, tensorFP32 scratch_, TensorGPU dout, tensorX inp, tensorX weight, tensorFP32 mean, tensorFP32 rstd, @@ -381,7 +381,7 @@ void layernorm_forward(TensorGPU out, tensorFP32 mean, tensorFP32 rstd, launch_layernorm_kernel(layernorm_forward_kernel6, N, C, stream, out, mean, rstd, inp, weight, bias); } -template +template void fused_residual_forward5(tensorX residual, TensorGPU normed, tensorFP32 mean, tensorFP32 rstd, tensorX inp1, TensorGPU inp2, tensorX weight, tensorX bias, int N, int C, cudaStream_t stream=main_stream) { @@ -389,7 +389,7 @@ void fused_residual_forward5(tensorX residual, TensorGPU normed, tensorFP3 launch_layernorm_kernel(fused_residual_forward_kernel5, N, C, stream, residual, normed, mean, rstd, inp1, inp2, weight, bias); } -template +template void layernorm_backward(tensorX dinp_new, tensorX dinp_old, tensorX dweight, tensorX dbias, tensorFP32 scratch, const TensorGPU dout, const tensorX inp, const tensorX weight, tensorFP32 mean, tensorFP32 rstd, int BT, int C, cudaStream_t stream=main_stream) { @@ -401,7 +401,7 @@ void layernorm_backward(tensorX dinp_new, tensorX dinp_old, tensorX dweight, ten size_t shared_mem_size = (2 * rounded_C + 2 * (block_size - 32) * f128::size) * sizeof(float); cudaCheck(cudaMemsetAsync(scratch, 0, 1 * sizeof(float), stream)); // only need to reset the flag to 0 - if (dinp_old == null_tensorX) { + if (dinp_old.is_null()) { layernorm_backward_kernel10<<>>(dinp_new, dinp_old, dweight, dbias, scratch, dout, inp, weight, mean, rstd, BT, C); } else { layernorm_backward_kernel10<<>>(dinp_new, dinp_old, dweight, dbias, scratch, dout, inp, weight, mean, rstd, BT, C); diff --git a/llmc/matmul.cuh b/llmc/matmul.cuh index db376ebb8..2c3c2de55 100644 --- a/llmc/matmul.cuh +++ b/llmc/matmul.cuh @@ -222,6 +222,10 @@ void matmul_cublaslt(tensorX d, const tensorX a, const tensorX b, const tensorX &alpha, a, ALayout, b, BLayout, &beta, d, CLayout, d, DLayout, &heuristic.algo, cublaslt_workspace, cublaslt_workspace_size, stream)); + #ifdef FAKE_FP8 + update_absmax(d, false); // fake FP8 requires the absmax to work + #endif + // cleanups cublasCheck(cublasLtMatmulPreferenceDestroy(preference)); cublasCheck(cublasLtMatmulDescDestroy(operationDesc)); @@ -232,9 +236,8 @@ void matmul_cublaslt(tensorX d, const tensorX a, const tensorX b, const tensorX cudaCheck(cudaGetLastError()); } -// Wrapper around cublasLtMatmul that is meant to support everything we need in llm.c -// https://docs.nvidia.com/cuda/cublas/#cublasltmatmul -template +#ifdef ENABLE_FP8 +template void matmul_cublaslt_fp8(TensorGPU d, const TensorGPU a, const TensorGPU b, const tensorX bias, int m, int n, int k, cudaStream_t stream=main_stream, bool accumulate=false, bool backward=false) @@ -255,10 +258,10 @@ void matmul_cublaslt_fp8(TensorGPU d, const TensorGPU a, const TensorGPU // define matrix layouts cublasLtMatrixLayout_t ALayout, BLayout, CLayout, DLayout; - cublasDataType_t typeA = std::is_same::value ? CUDA_R_8F_E4M3 : CUDA_R_8F_E5M2; - cublasDataType_t typeB = std::is_same::value ? CUDA_R_8F_E4M3 : CUDA_R_8F_E5M2; + cublasDataType_t typeA = std::is_same::value ? CUDA_R_8F_E4M3 : CUDA_R_8F_E5M2; + cublasDataType_t typeB = std::is_same::value ? CUDA_R_8F_E4M3 : CUDA_R_8F_E5M2; cublasDataType_t typeD = std::is_same::value ? CUBLAS_LOWP : - (std::is_same::value ? CUDA_R_8F_E4M3 : CUDA_R_8F_E5M2); + (std::is_same::value ? 
CUDA_R_8F_E4M3 : CUDA_R_8F_E5M2); cublasCheck(cublasLtMatrixLayoutCreate(&ALayout, typeA, k, m, k)); // always transposed for FP8 cublasCheck(cublasLtMatrixLayoutCreate(&BLayout, typeB, k, n, k)); // never transposed for FP8 @@ -322,30 +325,22 @@ void matmul_cublaslt_fp8(TensorGPU d, const TensorGPU a, const TensorGPU cublasCheck(cublasLtMatrixLayoutDestroy(DLayout)); cudaCheck(cudaGetLastError()); } +#endif -template +template // small wrapper around matmul_cublaslt for the forward pass (keeping historical order of arguments) -void matmul_forward_cublaslt(TensorGPU out, - TensorGPU inp, TensorGPU weight, tensorX bias, - int BT, int C, int OC, +void matmul_forward(TensorGPU out, + TensorGPU inp, TensorGPU weight, tensorX bias, int BT, int C, int OC, TensorGPU pre_gelu=TensorGPU(), int gelu_fusion=1, cudaStream_t stream=main_stream) { - // By default only fuse GELU for H100+ as cuBLAS seems to be inefficient for fused GELU on Ada/Ampere (?) - if constexpr (sizeof(Tin) == 1) { + matmul_cublaslt_fp8(pre_gelu.enabled() ? pre_gelu : out, weight, inp, bias, OC, BT, C, stream, false, false); if (pre_gelu.enabled()) { - matmul_cublaslt_fp8(pre_gelu, weight, inp, bias, OC, BT, C, stream, false, false); - if constexpr (sizeof(Tout) == 1) { // todo - hack to avoid error for case we will never see - gelu_forward(out, pre_gelu, stream); - } - } else { - matmul_cublaslt_fp8(out, weight, inp, bias, OC, BT, C, stream, false, false); + gelu_forward(out, pre_gelu, stream); } } else { - if (gelu_fusion < 1 && pre_gelu.enabled()) { + if (pre_gelu.enabled() && gelu_fusion < 1) { matmul_cublaslt(pre_gelu, weight, inp, bias, OC, BT, C, stream, true, false, 0, 0, 0, 0, false, null_tensorX, false); - if constexpr (sizeof(Tout) == sizeof(float8e4)) { - gelu_forward(out, pre_gelu, stream); // todo - same hack - } + gelu_forward(out, pre_gelu, stream); } else { matmul_cublaslt(out, weight, inp, bias, OC, BT, C, stream, true, false, 0, 0, 0, 0, false, pre_gelu, false); } @@ -383,12 +378,12 @@ void matmul_backward_bias(tensorX dbias, TensorGPU dout, tensorFP32 scrat } } -template +template void matmul_backward_fp8(tensorFP8e5 dinp, tensorX dweight, tensorX dbias, TensorGPU dout, tensorFP8e4 inp, tensorFP8e4 weight, tensorFP32 scratch1_big, tensorFP32 scratch2_huge, int BT, int C, int OC, - tensorFP8e4 pre_gelu_activation=null_tensorFP8E4, cudaStream_t stream=main_stream) { + tensorFP8e4 pre_gelu_activation=tensorFP8e4(), cudaStream_t stream=main_stream) { #ifndef ENABLE_FP8 // FP8 is not enabled so we use the regular floatX matmul path matmul_backward(dinp, dweight, dbias, dout, inp, weight, scratch1_big, BT, C, OC, pre_gelu_activation, 1, stream); @@ -400,48 +395,44 @@ void matmul_backward_fp8(tensorFP8e5 dinp, tensorX dweight, tensorX dbias, // IMPORTANT: inp is allowed to be the same buffer as scratch2_huge (e.g. for fch_gelu) // ==> this MUST be done first and write to scratch1_big! // transpose input - TensorGPU inp_fp8_transposed = inp; - inp_fp8_transposed.data_ptr = (float8e4*)scratch1_big.data_ptr; - transpose_simple(inp_fp8_transposed, inp, C, BT, stream); + TensorGPU inp_fp8_transposed = inp; + inp_fp8_transposed.data_ptr = (float8*)scratch1_big.data_ptr; + transpose_simple(inp_fp8_transposed, inp, C, BT, stream); // convert dout to FP8e5 if it is not already, and transpose it // the buffer is guaranteed to be at least twice as big as 4BTC, so we can split it in 2 // todo - merge conversion and tranposition like we did before? 
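The "merge conversion and transposition" todo is worth spelling out: in the code just below, dout is first converted to the e5m2 gradient type with copy_advanced and then transposed with transpose_simple, which costs an extra trip through global memory. A minimal sketch of a fused version follows; the kernel name, signature and the naive one-element-per-thread indexing are mine (the real transpose_simple / copy_advanced helpers in this patch are more involved), and it assumes the caller passes in the output tensor's FP8 scale.

    // naive fused transpose + dtype conversion, e.g. Tin = floatX, Tout = __nv_fp8_e5m2
    template <typename Tout, typename Tin>
    __global__ void transpose_convert_simple(Tout* out, const Tin* in, int rows, int cols, float scale) {
        int c = blockIdx.x * blockDim.x + threadIdx.x;   // column index into the input
        int r = blockIdx.y * blockDim.y + threadIdx.y;   // row index into the input
        if (r < rows && c < cols) {
            // out is (cols x rows): element (c, r) of the output is element (r, c) of the input,
            // cast to the output type with the output tensor's scale applied on the way
            out[(size_t)c * rows + r] = (Tout)((float)in[(size_t)r * cols + c] * scale);
        }
    }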
- TensorGPU dout_fp8; - if constexpr (std::is_same::value) { - dout_fp8 = dout; - } else { - dout_fp8 = *(TensorGPU*)&dout; - dout_fp8.data_ptr = (float8e5*)(scratch2_huge.data_ptr); + TensorGPU dout_fp8 = *(TensorGPU*)&dout; + if constexpr (std::is_same::value == false) { + dout_fp8.data_ptr = (grads8*)(scratch2_huge.data_ptr); copy_advanced(dout_fp8, dout, stream); } - TensorGPU dout_fp8_transposed = dout_fp8; - dout_fp8_transposed.data_ptr = (float8e5*)(scratch2_huge.data_ptr + (scratch2_huge.num_elements / 2)); + TensorGPU dout_fp8_transposed = dout_fp8; + dout_fp8_transposed.data_ptr = (grads8*)(scratch2_huge.data_ptr + (scratch2_huge.num_elements / 2)); transpose_simple(dout_fp8_transposed, dout_fp8, OC, BT, stream); // GEMM 1: dweight, inp_fp8_transposed, dout_fp8_transposed matmul_cublaslt_fp8(dweight, inp_fp8_transposed, dout_fp8_transposed, null_tensorX, C, OC, BT, stream, false, true); // transpose weight (todo: option to cache this / do it at optimizer time) - TensorGPU weight_fp8_transposed = weight; - weight_fp8_transposed.data_ptr = (float8e4*)scratch1_big.data_ptr; + TensorGPU weight_fp8_transposed = weight; + weight_fp8_transposed.data_ptr = (float8*)scratch1_big.data_ptr; transpose_simple(weight_fp8_transposed, weight, C, OC, stream); matmul_cublaslt_fp8(dinp, weight_fp8_transposed, dout_fp8, null_tensorX, C, BT, OC, stream, false, true); - if (pre_gelu_activation.data_ptr) { + if (pre_gelu_activation.enabled()) { gelu_backward(dinp, dinp, pre_gelu_activation, stream); } #endif } -template void matmul_backward(tensorX dinp, tensorX dweight, tensorX dbias, tensorX dout, tensorX inp, tensorX weight, tensorFP32 dbias_scratch, int BT, int C, int OC, - TensorGPU pre_gelu_activation=null_tensorX, int gelu_fusion=1, cudaStream_t stream=main_stream) { + tensorX pre_gelu_activation=null_tensorX, int gelu_fusion=1, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); matmul_backward_bias(dbias, dout, dbias_scratch, BT, OC, stream); @@ -450,7 +441,7 @@ void matmul_backward(tensorX dinp, tensorX dweight, tensorX dbias, gelu_fusion >= 2 ? pre_gelu_activation : null_tensorX, true); // backward GELU (if it wasn't fused into the matmul above) - if (gelu_fusion < 2 && pre_gelu_activation.enabled()) { + if ( pre_gelu_activation.enabled() && gelu_fusion < 2) { gelu_backward(dinp, dinp, pre_gelu_activation, stream); } diff --git a/llmc/tensor.cuh b/llmc/tensor.cuh index 395099884..8db76035e 100644 --- a/llmc/tensor.cuh +++ b/llmc/tensor.cuh @@ -23,7 +23,7 @@ enum TFlags : uint8_t { NONE=0, GRADIENT=1, REUSED_MEMORY=2, - TENSOR_2D=4, // used for matmul *outputs* only, not inputs (+weights) + TENSOR_2D=4, // used for matmul weights and activation outputs only (not inputs or gradients) BIAS=8, LAYERNORM=16, RESIDUAL=32, @@ -34,7 +34,7 @@ enum TFlags : uint8_t { // ---------------------------------------------------------------------------- // forward declarations & extern variables defined in the training file struct TensorSpec; -constexpr size_t MAX_TENSORS = 16*1024; +constexpr size_t MAX_TENSORS = 32768; // only increases CPU memory usage if unused constexpr size_t MAX_ABSMAX_HISTORY = 32; // todo - command line option extern TensorSpec tensor_specs[MAX_TENSORS]; @@ -44,8 +44,8 @@ extern size_t tensors_bytes[TT::COUNT]; extern size_t tensors_elements[TT::COUNT]; extern int num_tensor_specs; -extern TT current_tensor_type; -extern int current_absmax_index; +extern TT current_tensor_type; // todo - avoid having this somehow? +extern int current_absmax_index; // todo - move into model struct? 
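For orientation, the scaling state behind these globals (and gpu_scale_memory / gpu_absmax_memory just below) is kept per tensor spec: gpu_scale_memory appears to hold two floats per tensor, the scale at index [id*2] and the descale at [id*2+1], while gpu_absmax_memory holds MAX_ABSMAX_HISTORY unsigned-int absmax entries per tensor. The accessors below are illustrative only, not part of the patch, and assume that layout plus the __constant__ gpu_scale_memory_ptr mirror used in device code:

    __device__ inline float tensor_scale(int id)   { return gpu_scale_memory_ptr[2 * id + 0]; } // multiply by this when storing FP8
    __device__ inline float tensor_descale(int id) { return gpu_scale_memory_ptr[2 * id + 1]; } // multiply by this when loading FP8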
extern float* gpu_scale_memory; extern unsigned int* gpu_absmax_memory; @@ -61,11 +61,13 @@ __device__ __constant__ unsigned int* gpu_absmax_memory_ptr; #define AGRAD_L(x,layer) get_tensor(model->acts_grads.x, MULTIUSE, layer) #define PARAM_L(x,layer) get_tensor(model->params[PARAMETER].x, PARAMETER, layer) #define PGRAD_L(x,layer) get_tensor(model->params[PARAMETER_GRAD].x, PARAMETER_GRAD, layer) -#define ACT(x) ACT_L(x,l) -#define MULTI(x) MULTI_L(x,l) -#define AGRAD(x) AGRAD_L(x,l) -#define PARAM(x) PARAM_L(x,l) -#define PGRAD(x) PGRAD_L(x,l) +#define ACT(x) ACT_L(x,l) +#define MULTI(x) MULTI_L(x,l) +#define AGRAD(x) AGRAD_L(x,l) +#define PARAM(x) PARAM_L(x,l) +#define PGRAD(x) PGRAD_L(x,l) +#define ACT_0(x) ACT_L(x,0) +#define MULTI_0(x) MULTI_L(x,0) // ---------------------------------------------------------------------------- @@ -78,12 +80,8 @@ struct TensorGPU { int id = -1; static constexpr bool no_scaling = (sizeof(ElementType) != 1); - bool is_null() const { - return (data_ptr == NULL); - } - bool enabled() const { - return (data_ptr != NULL); - } + bool is_null() const { return (data_ptr == NULL); } + bool enabled() const { return (data_ptr != NULL); } static __device__ __host__ TensorGPU from(ElementType* ptr=nullptr) { TensorGPU tmp; @@ -116,7 +114,6 @@ struct TensorGPU { #ifdef FAKE_FP8 disable_scaling = true; #endif - ElementType* __restrict__ data_ptr_restricted = data_ptr; float* __restrict__ scale_ptr_restricted = scale_descale_ptr; @@ -129,7 +126,6 @@ struct TensorGPU { #ifdef FAKE_FP8 disable_scaling = true; #endif - ElementType* __restrict__ data_ptr_restricted = data_ptr; float* __restrict__ scale_ptr_restricted = scale_descale_ptr; @@ -144,7 +140,6 @@ typedef TensorGPU tensorX; typedef TensorGPU tensorFP32; typedef TensorGPU tensorFP16; typedef TensorGPU tensorBF16; - #ifdef ENABLE_FP8 typedef TensorGPU<__nv_fp8_e4m3> tensorFP8e4; typedef TensorGPU<__nv_fp8_e5m2> tensorFP8e5; @@ -152,25 +147,23 @@ typedef TensorGPU<__nv_fp8_e5m2> tensorFP8e5; typedef TensorGPU tensorFP8e4; typedef TensorGPU tensorFP8e5; #endif - -extern TensorGPU null_tensorFP32; extern TensorGPU null_tensorX; -extern TensorGPU null_tensorFP8E4; -extern TensorGPU null_tensorFP8E5; // ---------------------------------------------------------------------------- struct TensorSpec { + int id; char* ptr; + + char name[16]; + TT tensor_type; + DType data_type; + int flags; + size_t offset; // into base pointer size_t num_elements; // per shard - int id; short num_shards; short remaining_layers; - DType data_type; - TT tensor_type; - int flags; - char name[16]; template __host__ __device__ operator T*() const { @@ -199,9 +192,8 @@ struct TensorSpec { // ---------------------------------------------------------------------------- -// debug helper function +// debug helper function (enable in get_tensor() for extreme logging) void print_tensor_elements(int tensor_id) { - printf("Printing tensor %d\n", tensor_id); TensorSpec spec = tensor_specs[tensor_id]; size_t num_elements = spec.num_elements; const char* tensor_name = spec.name; @@ -264,30 +256,27 @@ int add_tensor_spec(const char* name, size_t total_elements, size_t num_shards, assert(num_tensor_specs < 16*1024); assert((total_elements % num_shards) == 0); TensorSpec* spec = &tensor_specs[num_tensor_specs]; + + spec->id = num_tensor_specs; strncpy(spec->name, name, 15); spec->name[15] = 0; + spec->tensor_type = (tensor_type == TT::DEFAULT) ? 
current_tensor_type : tensor_type; + spec->data_type = data_type; + spec->flags = flags; - spec->id = num_tensor_specs; spec->num_elements = total_elements / num_shards; spec->num_shards = num_shards; spec->remaining_layers = 0; - spec->data_type = data_type; - spec->tensor_type = (tensor_type == TT::DEFAULT) ? current_tensor_type : tensor_type; - spec->flags = flags; - tensors_elements[spec->tensor_type] += spec->num_elements; if (copy_offset_from >= 0) { TensorSpec base_spec = tensor_specs[copy_offset_from]; + base_spec.flags |= (flags & REUSED_MEMORY); spec->offset = base_spec.offset; + size_t original_tensor_bytes = base_spec.num_elements * sizeof_dtype(base_spec.data_type); size_t new_tensor_bytes = spec->num_elements * sizeof_dtype(data_type); - if (base_spec.tensor_type != spec->tensor_type) { - printf("ERROR: tensor_type for %s: %d vs %d\n", spec->name, (int)base_spec.tensor_type, (int)spec->tensor_type); - assert(false); - } - base_spec.flags |= (flags & REUSED_MEMORY); - assert(base_spec.tensor_type == spec->tensor_type); assert(new_tensor_bytes <= original_tensor_bytes); + assert(spec->tensor_type == base_spec.tensor_type); } else { spec->offset = tensors_bytes[spec->tensor_type]; tensors_bytes[spec->tensor_type] += spec->num_elements * sizeof_dtype(data_type); @@ -295,26 +284,26 @@ int add_tensor_spec(const char* name, size_t total_elements, size_t num_shards, tensors_start[spec->tensor_type] = num_tensor_specs; } } + + tensors_elements[spec->tensor_type] += spec->num_elements; return num_tensor_specs++; } int add_layer_specs(int num_layers, const char* name, size_t total_elements, size_t num_shards, DType data_type, - int copy_offset_from=-1, int flags=TFlags::NONE, bool copy_per_layer=false, - int reuse_every_n_layers=0, TT tensor_type=TT::DEFAULT) { + int copy_offset_from=-1, int flags=TFlags::NONE, int reuse_every_n_layers=0, + TT tensor_type=TT::DEFAULT) { int first_tensor_id = num_tensor_specs; if (reuse_every_n_layers > 0 && num_layers > 1) { flags |= REUSED_MEMORY; } for (int l = 0; l < num_layers; l++) { char layer_name[16]; - assert(snprintf(layer_name, 16, "%s_%d", name, l) >= 0); + assert(snprintf(layer_name, 15, "%s_%d", name, l) >= 0); if (reuse_every_n_layers > 0 && l >= reuse_every_n_layers) { copy_offset_from = first_tensor_id + (l % reuse_every_n_layers); } + int spec = add_tensor_spec(num_layers > 1 ? layer_name : name, total_elements, num_shards, data_type, copy_offset_from, flags, tensor_type); - if (copy_per_layer) { - copy_offset_from++; - } tensor_specs[spec].remaining_layers = num_layers - (l + 1); } return first_tensor_id; @@ -322,6 +311,7 @@ int add_layer_specs(int num_layers, const char* name, size_t total_elements, siz // ---------------------------------------------------------------------------- +// todo - should this be moved elsewhere? maybe to copy_and_fp8.h? 
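Before the scaling kernel that follows, a quick usage sketch for the spec helpers above. The tensor name and the choice of FP8 are hypothetical; the call shape follows the new add_layer_specs() argument order (copy_offset_from, flags, reuse_every_n_layers, tensor_type):

    // hypothetical per-layer activation: (B,T,C) elements in FP8, with every group of 4 layers
    // sharing one physical buffer (add_layer_specs sets REUSED_MEMORY when reuse_every_n_layers > 0)
    int my_act = add_layer_specs(L, "my_act", B*T*C, 1, DType::FP8E4M3,
                                 -1 /* copy_offset_from: no aliasing */, TFlags::NONE,
                                 4 /* reuse_every_n_layers */);
    // the device address of any spec ends up being tensor_memory[spec.tensor_type] + spec.offset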
__global__ void update_scale_descale_kernel(int num_tensor_specs) { int tid = blockIdx.x * blockDim.x + threadIdx.x; if (tid >= num_tensor_specs) return; @@ -359,16 +349,6 @@ __global__ void update_scale_descale_kernel(int num_tensor_specs) { } } - #ifdef FAKE_FP8 - // with real FP8, we rely on tensor128 not scaling when sizeof(T)>1, but that doesn't work with fake FP8 - // so we prevent scaling for the things we know we don't want to scale - // this might not match what we have in the real FP8 implementation, but allows for quick experimentation - if ((tensor_specs_ptr[tid].flags & TFlags::RESIDUAL) || (tensor_specs_ptr[tid].flags & TFlags::EMBEDDING)) { - scale = 1.0f; - descale = 1.0f; - } - #endif - // Update gpu_scale_memory // todo: descale should be delayed by one step for parameters (see comment in gpt2_update). gpu_scale_memory_ptr[tid * 2] = scale; @@ -398,7 +378,8 @@ private: bool wrote_data = false; bool wrote_absmax = false; int id = -1; - // fake fp8 mode + + // fake fp8 mode (ignored without FAKE_FP8 define) bool faking_fp8 = false; bool mode_e5 = false; @@ -438,31 +419,20 @@ public: } __device__ void store(size_t offset, bool cache_streaming=false) { - if (cache_streaming) { - store128cs(data_ptr + offset, data128); - } else { - store128(data_ptr + offset, data128); - } + if (cache_streaming) store128cs(data_ptr + offset, data128); + else store128(data_ptr + offset, data128); wrote_data = true; } template __device__ void store_same_length(size_t offset, bool cache_streaming=false) { - if (cache_streaming) { - store128_same_length_cs(data_ptr + offset, data128); - } else { - store128_same_length(data_ptr + offset, data128); - } + if (cache_streaming) store128_same_length_cs(data_ptr + offset, data128); + else store128_same_length(data_ptr + offset, data128); wrote_data = true; } - __device__ const Packed128& get128() const { - return data128; - } - - __device__ Packed128& get128() { - return data128; - } + __device__ const Packed128& get128() const { return data128; } + __device__ Packed128& get128() { return data128; } // call this manually if e.g. you use set_scalar() to update the tensor // todo - in the future, this could support more than just absmax @@ -484,7 +454,7 @@ public: } __device__ void set_stochastic(int index, float value, unsigned int random_number, - bool rotate_by_index=true, bool non_deterministic_rng=false) { + int rotate_by_index=10, bool non_deterministic_rng=false) { float scaled_value = value * (scaling ? scale : 1.0f); // rotate the random number by the index so we can cheaply reuse the same RNG @@ -493,9 +463,11 @@ public: // x10 is used so that it never repeats for indices [0;15] with a minimum difference of 2 etc. if (rotate_by_index) { assert(index < 16); // >=16 would repeat and be extremely bad RNG - random_number = __funnelshift_l(random_number, random_number, index * 10); + random_number = __funnelshift_l(random_number, random_number, index * rotate_by_index); } - // RNG without a seed from the host for quick testing, but obviously not deterministic! 
+ + // RNG without a seed from the host for quick testing, but obviously not deterministic + // can be forced to get slightly different runs from which you can calculate an average #ifdef FORCE_NON_DETERMINISM non_deterministic_rng = true; #endif @@ -510,10 +482,10 @@ public: add_value_stats(value, data128[index]); } - // if update_absmax returns true, we can skip __syncthreads() in some kernels + // return value: if true, we can skip __syncthreads() in the calling function as we have just done one __device__ bool update_absmax(int thread_id, int num_threads, bool exit=false, bool forced=false) { #ifdef FAKE_FP8 - if (id < 0 || absmax_ptr == NULL || !faking_fp8) { + if (absmax_ptr == NULL || !faking_fp8) { return false; } forced = true; @@ -564,7 +536,8 @@ public: } return true; } - __device__ void update_absmax_auto(int dimensions=1, bool exit=false) { + + __device__ void update_absmax(int dimensions=1, bool exit=false) { if (dimensions == 1) { update_absmax(threadIdx.x, blockDim.x, exit); } else if (dimensions == 2) { @@ -574,6 +547,7 @@ public: blockDim.x * blockDim.y * blockDim.z, exit); } } + __device__ void skip_absmax() { wrote_absmax = true; } diff --git a/train_gpt2.cu b/train_gpt2.cu index f7920c0f9..f5565c09d 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -39,7 +39,7 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage. #include "llmc/cuda_common.h" // defines: // Packed128, f128, x128 -// warpReduceSum, warpReduceMax, blockReduce, copy_and_cast_kernel +// warpReduceSum, warpReduceMax, blockReduce #include "llmc/cuda_utils.cuh" // ... todo ... #include "llmc/tensor.cuh" @@ -97,11 +97,7 @@ TT current_tensor_type = TT::PARAMETER; int current_absmax_index = 0; float* gpu_scale_memory = NULL; unsigned int* gpu_absmax_memory = NULL; - -TensorGPU null_tensorFP32 = {0}; TensorGPU null_tensorX = {0}; -TensorGPU null_tensorFP8E4 = {0}; -TensorGPU null_tensorFP8E5 = {0}; // ---------------------------------------------------------------------------- // GPT-2 model definition @@ -119,10 +115,8 @@ typedef struct { typedef struct { int bt4c; // (B, T, 4*C) int btc; // (B, T, C) - int local_scratch; // big, see local_scratch_size below - int local_scratch_fp32; // same memory as above - int output_scratch; // huge, see output_size below - int output_scratch_fp32; // same memory as above + int local_scratch, local_scratch_fp32; // big, see local_scratch_size below + int output_scratch, output_scratch_fp32; // huge, see output_size below } MultiuseTensors; typedef struct { @@ -153,7 +147,6 @@ typedef struct { float mean_loss = -1.0f; // after the last backward micro-batch, will be populated with mean loss across all GPUs and micro-steps float* accumulated_mean_loss = NULL; // GPU buffer used to accumulate loss across micro-steps float* cpu_losses = NULL; // CPU buffer to copy the losses to, allocated with cudaMallocHost - bool init_state = true; // set to true if master weights need to be initialized int use_master_weights = 1; // keep master weights copy in float for optim update? 0|1 int gelu_fusion = 0; // fuse gelu via cuBLASLt (0=none, 1=forward, 2=forward+backward) int recompute = 0; // recompute gelu | layernorm forward during model backward? 
0|1|2 @@ -165,9 +158,9 @@ typedef struct { unsigned long long rng_state_last_update; // RNG before last gpt2_update() to re-round identically from master weights } GPT2; -#define TENSOR_SPECS(name, layers, dim, flags) spec->name = add_layer_specs(layers, #name, dim, shards, dtype, -1, flags, false, reuse_every_n) -#define TENSOR_SPECS_LOWP(name, layers, dim, flags) spec->name = add_layer_specs(layers, #name, dim, shards, dtype_lowp, -1, flags, false, reuse_every_n) -#define TENSOR_SPECS_FP32(name, layers, dim, flags) spec->name = add_layer_specs(layers, #name, dim, shards, DType::FP32, -1, flags, false, reuse_every_n) +#define TENSOR_SPECS(name, layers, dim, flags) spec->name = add_layer_specs(layers, #name, dim, shards, dtype, -1, flags, reuse_every_n) +#define TENSOR_SPECS_LOWP(name, layers, dim, flags) spec->name = add_layer_specs(layers, #name, dim, shards, dtype_lowp, -1, flags, reuse_every_n) +#define TENSOR_SPECS_FP32(name, layers, dim, flags) spec->name = add_layer_specs(layers, #name, dim, shards, DType::FP32, -1, flags, reuse_every_n) void gpt2_allocate(GPT2 *model) { size_t Vp = model->config.padded_vocab_size; @@ -196,21 +189,19 @@ void gpt2_allocate(GPT2 *model) { int reuse_every_n = 0; int shards = 1; - int num_gpu = multi_gpu_config.num_processes; - int shards_opt = (multi_gpu_config.zero_stage >= 1) ? num_gpu : 1; - int shards_grad = (multi_gpu_config.zero_stage >= 2) ? num_gpu : 1; + int shards_opt = (multi_gpu_config.zero_stage >= 1) ? multi_gpu_config.num_processes : 1; + int shards_grad = (multi_gpu_config.zero_stage >= 2) ? multi_gpu_config.num_processes : 1; // 1) parameters & optimizer state for (int t = PARAMETER; t <= PARAMETER_MASTER; t++) { - DType dtype = (t <= PARAMETER_GRAD) ? DTYPE_FLOATX : DType::FP32; - DType dtype_lowp = (t == PARAMETER) ? DTYPE_FP8E4 : ((t == PARAMETER_GRAD) ? DTYPE_FLOATX : DType::FP32); + if (t == PARAMETER_MASTER && !model->use_master_weights) continue; current_tensor_type = (TT)t; ParameterTensors* spec = &model->params[t]; shards = (t == PARAMETER) ? 1 : (t == PARAMETER_GRAD) ? shards_grad : shards_opt; - if (t == PARAMETER_MASTER && !model->use_master_weights) { - continue; - } + + DType dtype = (t <= PARAMETER_GRAD) ? DTYPE_FLOATX : DType::FP32; + DType dtype_lowp = (t == PARAMETER) ? DTYPE_FP8E4 : ((t == PARAMETER_GRAD) ? 
DTYPE_FLOATX : DType::FP32); TENSOR_SPECS (wte, 1, Vp * C, TENSOR_2D | EMBEDDING); TENSOR_SPECS (wpe, 1, maxT * C, TENSOR_2D | EMBEDDING); @@ -230,43 +221,46 @@ void gpt2_allocate(GPT2 *model) { TENSOR_SPECS (lnfb, 1, C, LAYERNORM | BIAS); } + model->num_parameters_bytes = tensors_bytes[TT::PARAMETER]; + model->num_parameters = tensors_elements[TT::PARAMETER]; + // 2) multiuse & scratch tensors current_tensor_type = MULTIUSE; - if (UNIQUE_TENSOR_MEMORY) { - model->multiuse.bt4c = -1; - model->multiuse.btc = -1; - } else { + model->multiuse.bt4c = model->multiuse.btc = -1; + if (UNIQUE_TENSOR_MEMORY == false) { model->multiuse.bt4c = add_tensor_spec("multiuse_bt4c", 4 * BTC, 1, DTYPE_FLOATX, -1, REUSED_MEMORY); model->multiuse.btc = add_tensor_spec("multiuse_btc", BTC, 1, DTYPE_FLOATX, -1, REUSED_MEMORY); } - model->multiuse.local_scratch = add_tensor_spec("scratch_x", local_scratch_size, 1, DTYPE_FLOATX, -1, REUSED_MEMORY); + model->multiuse.local_scratch = add_tensor_spec("scratch_X", local_scratch_size, 1, DTYPE_FLOATX, -1, REUSED_MEMORY); + model->multiuse.output_scratch = add_tensor_spec("out_scratch_X", output_size, 1, DTYPE_FLOATX, -1, REUSED_MEMORY); + model->multiuse.local_scratch_fp32 = add_tensor_spec("scratch_32", local_scratch_size / sizeof(floatX), 1, DType::FP32, model->multiuse.local_scratch, REUSED_MEMORY); - model->multiuse.output_scratch = add_tensor_spec("out_scratch_x", output_size, 1, DTYPE_FLOATX, -1, REUSED_MEMORY); model->multiuse.output_scratch_fp32 = add_tensor_spec("out_scratch_32", output_size / sizeof(floatX), 1, DType::FP32, model->multiuse.output_scratch, REUSED_MEMORY); // 3) activations ActivationTensors* spec = &model->acts; - DType dtype_lowp = DTYPE_FP8E4; // todo FP8 + DType dtype_lowp = DTYPE_FP8E4; DType dtype = DTYPE_FLOATX; shards = 1; // with activation checkpointing, we keep every layer's residual3 for simplicity - // in theory, if we have e.g. 4 layers per checkpoint, we could have 1/4 as many residual3 + // in theory, with e.g. 4 layers per checkpoint, we'd have 1/4 as many residual3 // but that would complicate everything a lot for relatively little benefit... 
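To make the reuse scheme concrete, a worked example with made-up numbers; the mapping is just the copy_offset_from = first_tensor_id + (l % reuse_every_n_layers) logic in add_layer_specs():

    // With L = 12 and LAYERS_PER_ACTIVATION_CHECKPOINT = 4 (so reuse_every_n = 4):
    //   layers 0,4,8  -> physical buffer 0      layers 1,5,9  -> physical buffer 1
    //   layers 2,6,10 -> physical buffer 2      layers 3,7,11 -> physical buffer 3
    // Only 4 copies of each checkpointed activation are resident at once; whatever has been
    // overwritten is recomputed during backward. residual3 (declared next, before reuse_every_n
    // is set) opts out and keeps all L copies, as per the comment above.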
TENSOR_SPECS (residual3, L, BTC, RESIDUAL); reuse_every_n = LAYERS_PER_ACTIVATION_CHECKPOINT; - assert(!reuse_every_n || (L % reuse_every_n) == 0); + assert(!reuse_every_n || !(L % reuse_every_n)); TENSOR_SPECS (encoded, 1, BTC, EMBEDDING); TENSOR_SPECS (qkvr, L, 3 * BTC, TENSOR_2D); + TENSOR_SPECS (atty, L, BTC, 0); + TENSOR_SPECS (residual2, L, BTC, RESIDUAL); + TENSOR_SPECS_LOWP(fch, L, 4 * BTC, TENSOR_2D); + #ifdef ENABLE_CUDNN TENSOR_SPECS_FP32(att, L, NH * B * T, 0); #else TENSOR_SPECS (att, L, NH * B * T * T, 0); #endif - TENSOR_SPECS (atty, L, BTC, 0); - TENSOR_SPECS (residual2, L, BTC, RESIDUAL); - TENSOR_SPECS_LOWP(fch, L, 4 * BTC, TENSOR_2D); // optionally reuse the same activation buffer at each layer and re-compute the gelu during backward // very useful because we dramatically reduce VRAM usage, and may be able to fit larger batch size @@ -283,10 +277,6 @@ void gpt2_allocate(GPT2 *model) { spec->ln1 = add_layer_specs(L, "ln1", BTC, shards, dtype_lowp, model->multiuse.btc, LAYERNORM | REUSED_MEMORY); spec->ln2 = add_layer_specs(L, "ln2", BTC, shards, dtype_lowp, model->multiuse.btc, LAYERNORM | REUSED_MEMORY); } - TENSOR_SPECS_FP32(ln1_mean, L, BT, LAYERNORM | STATS); - TENSOR_SPECS_FP32(ln1_rstd, L, BT, LAYERNORM | STATS); - TENSOR_SPECS_FP32(ln2_mean, L, BT, LAYERNORM | STATS); - TENSOR_SPECS_FP32(ln2_rstd, L, BT, LAYERNORM | STATS); if (UNIQUE_TENSOR_MEMORY) { TENSOR_SPECS (attproj, L, BTC, TENSOR_2D); @@ -299,9 +289,13 @@ void gpt2_allocate(GPT2 *model) { } TENSOR_SPECS (lnf, 1, BTC, LAYERNORM); + TENSOR_SPECS_FP32(losses, 1, BT, 0); + TENSOR_SPECS_FP32(ln1_mean, L, BT, LAYERNORM | STATS); + TENSOR_SPECS_FP32(ln1_rstd, L, BT, LAYERNORM | STATS); + TENSOR_SPECS_FP32(ln2_mean, L, BT, LAYERNORM | STATS); + TENSOR_SPECS_FP32(ln2_rstd, L, BT, LAYERNORM | STATS); TENSOR_SPECS_FP32(lnf_mean, 1, BT, LAYERNORM | STATS); TENSOR_SPECS_FP32(lnf_rstd, 1, BT, LAYERNORM | STATS); - TENSOR_SPECS_FP32(losses, 1, BT, 0); // 4) activation gradients // note: TENSOR_2D are for the tensors written to by a matmul which are different here @@ -314,22 +308,22 @@ void gpt2_allocate(GPT2 *model) { if (UNIQUE_TENSOR_MEMORY) { TENSOR_SPECS (encoded, 1, BTC, GRADIENT | EMBEDDING); TENSOR_SPECS (output, 1, output_size, GRADIENT | EMBEDDING); - TENSOR_SPECS (lnf, 1, BTC, GRADIENT | LAYERNORM | TENSOR_2D); - TENSOR_SPECS_LOWP(ln1, L, BTC, GRADIENT | LAYERNORM | TENSOR_2D); - TENSOR_SPECS (atty, L, BTC, GRADIENT | TENSOR_2D); + TENSOR_SPECS (lnf, 1, BTC, GRADIENT | LAYERNORM); + TENSOR_SPECS_LOWP(ln1, L, BTC, GRADIENT | LAYERNORM); + TENSOR_SPECS (atty, L, BTC, GRADIENT); TENSOR_SPECS (residual2, L, BTC, GRADIENT | RESIDUAL); - TENSOR_SPECS_LOWP(ln2, L, BTC, GRADIENT | LAYERNORM | TENSOR_2D); - TENSOR_SPECS_LOWP(fch, L, 4 * BTC, GRADIENT | TENSOR_2D); - TENSOR_SPECS_LOWP(fch_gelu, L, 4 * BTC, GRADIENT | TENSOR_2D); + TENSOR_SPECS_LOWP(ln2, L, BTC, GRADIENT | LAYERNORM); + TENSOR_SPECS_LOWP(fch, L, 4 * BTC, GRADIENT); + TENSOR_SPECS_LOWP(fch_gelu, L, 4 * BTC, GRADIENT); TENSOR_SPECS (residual3, L, BTC, GRADIENT | RESIDUAL); TENSOR_SPECS (qkvr, L, 3 * BTC, GRADIENT); } else { spec->output = add_layer_specs(1, "output", output_size, 1, dtype, model->multiuse.output_scratch, GRADIENT | EMBEDDING); int reused_btc = model->acts.residual3 + (L-1); // todo - check if this works with activation checkpointing - spec->ln1 = add_layer_specs(L, "ln1", BTC, 1, dtype_lowp, reused_btc, GRADIENT | LAYERNORM | TENSOR_2D); - spec->ln2 = add_layer_specs(L, "ln2", BTC, 1, dtype_lowp, reused_btc, GRADIENT | LAYERNORM | 
TENSOR_2D); - spec->atty = add_layer_specs(L, "atty", BTC, 1, dtype, reused_btc, GRADIENT | TENSOR_2D); + spec->ln1 = add_layer_specs(L, "ln1", BTC, 1, dtype_lowp, reused_btc, GRADIENT | LAYERNORM); + spec->ln2 = add_layer_specs(L, "ln2", BTC, 1, dtype_lowp, reused_btc, GRADIENT | LAYERNORM); + spec->atty = add_layer_specs(L, "atty", BTC, 1, dtype, reused_btc, GRADIENT); int reused_btc2 = model->acts.lnf; spec->residual2 = add_layer_specs(L, "residual2", BTC, 1, dtype, reused_btc2, GRADIENT | RESIDUAL); @@ -337,74 +331,48 @@ void gpt2_allocate(GPT2 *model) { spec->encoded = add_layer_specs(1, "encoded", BTC, 1, dtype, reused_btc2, GRADIENT | EMBEDDING); // (lnf doesn't need bt4c but it's free at this point unlike the other buffers) - spec->lnf = add_layer_specs(1, "lnf", BTC, 1, dtype, model->multiuse.bt4c, GRADIENT | LAYERNORM | TENSOR_2D); - spec->fch_gelu = add_layer_specs(L, "fch_gelu", 4 * BTC, 1, dtype_lowp, model->multiuse.bt4c, GRADIENT | TENSOR_2D); - spec->fch = add_layer_specs(L, "fch", 4 * BTC, 1, dtype_lowp, model->multiuse.bt4c, GRADIENT | TENSOR_2D); - spec->qkvr = add_layer_specs(L, "qkvr", 3 * BTC, 1, dtype, model->multiuse.bt4c, GRADIENT | TENSOR_2D); + spec->lnf = add_layer_specs(1, "lnf", BTC, 1, dtype, model->multiuse.bt4c, GRADIENT | LAYERNORM); + spec->fch_gelu = add_layer_specs(L, "fch_gelu", 4 * BTC, 1, dtype_lowp, model->multiuse.bt4c, GRADIENT); + spec->fch = add_layer_specs(L, "fch", 4 * BTC, 1, dtype_lowp, model->multiuse.bt4c, GRADIENT); + spec->qkvr = add_layer_specs(L, "qkvr", 3 * BTC, 1, dtype, model->multiuse.bt4c, GRADIENT); } // allocate a single huge GPU buffer for all the tensors of a given type - cudaCheck(cudaMalloc(&model->tensor_memory[MULTIUSE], tensors_bytes[MULTIUSE])); - cudaCheck(cudaMalloc(&model->tensor_memory[PARAMETER], tensors_bytes[PARAMETER])); - cudaCheck(cudaMalloc(&model->tensor_memory[PARAMETER_GRAD], tensors_bytes[PARAMETER_GRAD])); - cudaCheck(cudaMalloc(&model->tensor_memory[PARAMETER_OPT_M], tensors_bytes[PARAMETER_OPT_M])); - cudaCheck(cudaMalloc(&model->tensor_memory[PARAMETER_OPT_V], tensors_bytes[PARAMETER_OPT_V])); - if (model->use_master_weights) { - cudaCheck(cudaMalloc(&model->tensor_memory[PARAMETER_MASTER], tensors_bytes[PARAMETER_MASTER])); + for (int i = 0; i < TT::COUNT; i++) { + if (i == PARAMETER_MASTER && !model->use_master_weights) continue; + cudaCheck(cudaMalloc(&model->tensor_memory[i], tensors_bytes[i])); } - // clear multiuse memory (better safe than sorry) - cudaCheck(cudaMemset(model->tensor_memory[MULTIUSE], 0, tensors_bytes[MULTIUSE])); - // Set the ptr for each tensor spec based on type and offset + // Set the GPU pointer for each tensor spec (so we don't need to know the base and the offset) for (size_t i = 0; i < num_tensor_specs; i++) { TensorSpec* spec = &tensor_specs[i]; - switch (spec->tensor_type) { - case MULTIUSE: - spec->ptr = model->tensor_memory[MULTIUSE] + spec->offset; - break; - default: - assert(spec->tensor_type <= PARAMETER_MASTER); - spec->ptr = model->tensor_memory[spec->tensor_type] + spec->offset; - } + spec->ptr = model->tensor_memory[spec->tensor_type] + spec->offset; } - // we are finished creating the tensors specs and can copy them to the GPU (effectively read-only) + // we are finished creating the tensors specs and copy them to the GPU (they are effectively read-only) cudaMalloc((void**)&tensor_specs_gpu, sizeof(TensorSpec) * num_tensor_specs); cudaMemcpy(tensor_specs_gpu, tensor_specs, sizeof(TensorSpec) * num_tensor_specs, cudaMemcpyHostToDevice); - //initialise helper 
variables - model->num_parameters = tensors_elements[TT::PARAMETER]; - model->num_parameters_bytes = tensors_bytes[TT::PARAMETER]; - printf("number of parameter bytes: %zu MiB\n", tensors_bytes[TT::PARAMETER] / (1024*1024)); printf("number of parameter gradient bytes: %zu MiB\n", tensors_bytes[TT::PARAMETER_GRAD] / (1024*1024)); - printf("number of master weight bytes: %zu MiB\n", tensors_bytes[TT::PARAMETER_MASTER] / (1024*1024)); printf("number of m bytes: %zu MiB\n", tensors_bytes[TT::PARAMETER_OPT_M] / (1024*1024)); printf("number of v bytes: %zu MiB\n", tensors_bytes[TT::PARAMETER_OPT_V] / (1024*1024)); + printf("number of master weight bytes: %zu MiB\n", tensors_bytes[TT::PARAMETER_MASTER] / (1024*1024)); printf("number of act+actgrad+multiuse bytes: %zu MiB\n", tensors_bytes[TT::MULTIUSE] / (1024*1024)); - // ======================= - // allocate_state stuff - // ======================= - // absmax/scaling/descaling buffers for FP8 & Friends + // absmax/scale/descale buffers for FP8 & Friends (scale is initialised via update_scales_from_absmax) + cudaMalloc(&gpu_scale_memory, 2 * num_tensor_specs * sizeof(float)); cudaMalloc(&gpu_absmax_memory, sizeof(unsigned int) * num_tensor_specs * MAX_ABSMAX_HISTORY); cudaMemset(gpu_absmax_memory, 0, sizeof(unsigned int) * num_tensor_specs * MAX_ABSMAX_HISTORY); - // Initialize gpu_scale_memory with 1.0f for all elements (todo - could use cuMemsetD8 but runtime vs driver...) - size_t scale_memory_elements = 2 * num_tensor_specs; - cudaMalloc(&gpu_scale_memory, scale_memory_elements * sizeof(float)); - float* h_scale_memory = (float*)malloc(scale_memory_elements * sizeof(float)); - for (size_t i = 0; i < scale_memory_elements; ++i) { - h_scale_memory[i] = 1.0f; - } - cudaMemcpy(gpu_scale_memory, h_scale_memory, scale_memory_elements * sizeof(float), cudaMemcpyHostToDevice); - free(h_scale_memory); - - // copy to constant buffers + // copy pointers to constant buffers for easy access on the GPU cudaMemcpyToSymbol(tensor_specs_ptr, &tensor_specs_gpu, sizeof(TensorSpec*)); cudaMemcpyToSymbol(gpu_scale_memory_ptr, &gpu_scale_memory, sizeof(float*)); cudaMemcpyToSymbol(gpu_absmax_memory_ptr, &gpu_absmax_memory, sizeof(unsigned int*)); + // ======================= + // allocate_state stuff + // ======================= cudaCheck(cudaMalloc((void**)&model->inputs, B * T * sizeof(int))); cudaCheck(cudaMalloc((void**)&model->targets, B * T * sizeof(int))); cudaCheck(cudaMalloc(((void**)&model->accumulated_mean_loss), sizeof(float))); @@ -416,11 +384,11 @@ void gpt2_allocate(GPT2 *model) { model->workload_indices = (int*)mallocCheck(sizeof(int) * model->batch_size * model->seq_len * num_c_groups); model->bucket_info = (int4*)mallocCheck(sizeof(int4) * model->batch_size * model->seq_len * num_c_groups); + // check available memory and give an estimate of the maximum batch size size_t free, total; cudaCheck(cudaMemGetInfo(&free, &total)); printf0("device memory usage: %zd MiB / %zd MiB\n", (total-free) / 1024 / 1024, total / 1024 / 1024); - // give an estimate of the maximum batch size size_t bytes_per_sequence = tensors_bytes[TT::MULTIUSE] / B; // pessimistic (output buffer etc.) 
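A quick check on the arithmetic behind the estimate printed just below, with made-up numbers:

    // e.g. a 20 GiB MULTIUSE arena at B = 16 gives bytes_per_sequence = 1.25 GiB; if 30 GiB of
    // device memory is still free, the estimated maximum batch size prints as 16 + 30/1.25 = 40.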
printf0("memory per sequence: %zu MiB\n", bytes_per_sequence / 1024 / 1024); printf0(" -> estimated maximum batch size: %zu\n", B + free / bytes_per_sequence); @@ -432,6 +400,89 @@ void gpt2_init_common(GPT2 *model) { model->rng_state = 13371337 + multi_gpu_config.process_rank; // used in stochastic rounding } +// take a GPU buffer with "num_parameters * sizeof(floatX)" in the required order +// and convert each individual tensor to its desired data type +template +void convert_fixed_parameters(GPT2* model, char* gpu_buffer, size_t fixed_size_bytes) { + size_t offset = 0; + int num_param_tensors = tensors_start[PARAMETER+1]; + + for (int i = 0; i < num_param_tensors; i++) { + TensorGPU tensor = tensor_specs[i]; + tensor.data_ptr = (Tin*)(gpu_buffer + offset); + offset += tensor.num_elements * sizeof(Tin); + update_absmax(tensor); + } + update_scales_from_absmax(); + + offset = 0; + for (int i = 0; i < num_param_tensors; i++) { + TensorGPU tensor_in = tensor_specs[i]; + tensor_in.data_ptr = (Tin*)(gpu_buffer + offset); + offset += tensor_in.num_elements * sizeof(Tin); + + switch (tensor_specs[i].data_type) { + case DType::FP32: copy_advanced((TensorGPU)tensor_specs[i], tensor_in); break; + case DType::FP16: copy_advanced((TensorGPU<__nv_half>)tensor_specs[i], tensor_in); break; + case DType::BF16: copy_advanced((TensorGPU<__nv_bfloat16>)tensor_specs[i], tensor_in); break; + case DType::FP8E4M3: copy_advanced((TensorGPU<__nv_fp8_e4m3>)tensor_specs[i], tensor_in); break; + } + if (model->use_master_weights) { + size_t master_start = tensors_start[PARAMETER_MASTER]; + TensorGPU master = tensor_specs[i+master_start]; + size_t shard_offset = master.num_elements * (multi_gpu_config.process_rank % tensor_specs[i+master_start].num_shards); + + tensor_in.data_ptr += shard_offset; + copy_advanced(master, tensor_in); + } + } + cudaMemset(gpu_buffer, 0, fixed_size_bytes); +} + +// to convert from variable precision parameters to a single precision (e.g. 
before checkpointing) +template +void convert_from_fixed_parameters(GPT2* model, char* gpu_buffer) { + size_t offset = 0; + for (int i = 0; i < tensors_start[PARAMETER+1]; i++) { + TensorGPU tensor_out = tensor_specs[i]; + tensor_out.data_ptr = (Tout*)(gpu_buffer + offset); + offset += tensor_out.num_elements * sizeof(Tout); + + switch (tensor_specs[i].data_type) { + case DType::FP32: copy_advanced(tensor_out, (TensorGPU)tensor_specs[i]); break; + case DType::FP16: copy_advanced(tensor_out, (TensorGPU<__nv_half>)tensor_specs[i]); break; + case DType::BF16: copy_advanced(tensor_out, (TensorGPU<__nv_bfloat16>)tensor_specs[i]); break; + case DType::FP8E4M3: copy_advanced(tensor_out, (TensorGPU<__nv_fp8_e4m3>)tensor_specs[i]); break; + } + } +} + + +// helper function to initialise sharded master weights from unsharded weights +template +void init_tensor_shard(TensorGPU out, int i) { + size_t shard_offset = out.num_elements * (multi_gpu_config.process_rank % tensor_specs[out.id].num_shards); + TensorGPU t = tensor_specs[i]; + t.num_elements = out.num_elements; + t.data_ptr += shard_offset; + copy_advanced(out, t); +} + +// initialise master weights based on the regular weights, taking into account sharding +void init_master_weights(GPT2 *model) { + int num_param_tensors = tensors_start[PARAMETER+1]; + int master_start = tensors_start[PARAMETER_MASTER]; // relies on there being the same number of parameter and master parameter tensors + for (int i = 0; i < num_param_tensors; i++) { + TensorGPU master = tensor_specs[i+master_start]; + switch (tensor_specs[i].data_type) { + case DType::FP32: init_tensor_shard(master, i); break; + case DType::FP16: init_tensor_shard<__nv_half>(master, i); break; + case DType::BF16: init_tensor_shard<__nv_bfloat16>(master, i); break; + case DType::FP8E4M3: init_tensor_shard<__nv_fp8_e4m3>(master, i); break; + } + } +} + void gpt2_write_to_checkpoint(GPT2 *model, const char* checkpoint_path) { // write the model to a checkpoint file printf0("Writing model to %s\n", checkpoint_path); @@ -450,12 +501,21 @@ void gpt2_write_to_checkpoint(GPT2 *model, const char* checkpoint_path) { model_header[7] = model->config.padded_vocab_size; fwriteCheck(model_header, sizeof(int), 256, model_file); // write the parameters - device_to_file(model_file, model->tensor_memory, model->num_parameters_bytes, IO_BUF_SIZE); + bool write_as_floatX = true; + if (write_as_floatX && model->num_parameters_bytes != model->num_parameters * sizeof(floatX)) { + // convert the parameters to floatX before writing them + assert(tensors_bytes[MULTIUSE] >= model->num_parameters * sizeof(floatX)); // todo - make this always work + convert_from_fixed_parameters(model, model->tensor_memory[MULTIUSE]); + device_to_file(model_file, model->tensor_memory[MULTIUSE], model->num_parameters * sizeof(floatX), IO_BUF_SIZE); + } else { + // just write the parameters as they are + device_to_file(model_file, model->tensor_memory[PARAMETER], model->num_parameters_bytes, IO_BUF_SIZE); + } // close file, we're done fcloseCheck(model_file); } -void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path, bool weight_init=true) { +void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) { // If weight_init is true, we will load the weights from this checkpoint .bin file // We sometimes want this to be false, if we are going to initialize these weights from // the master weights that are instead stored in the state .bin file. 
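A worked example of the sharding arithmetic used by init_tensor_shard() above, with made-up sizes; the formula is the shard_offset = out.num_elements * (process_rank % num_shards) line:

    // ZeRO stage 1 on 4 GPUs: a 1M-element parameter gets a master copy created with
    // num_shards = 4, so the master spec's num_elements is 250k (add_tensor_spec divides the
    // total by the shard count). Rank r then copies elements [250k*r, 250k*(r+1)) of the
    // regular parameter tensor into its FP32 master shard; gradients are only sharded like
    // this at ZeRO stage >= 2 (shards_grad in gpt2_allocate).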
@@ -483,19 +543,18 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path, bool w exit(EXIT_FAILURE); } - // check if the precision mode of the checkpoing matches the model precision - if (weight_init) { - if (PRECISION_MODE == PRECISION_BF16 && version != 5) { - fprintf(stderr, "Precision is configured as BF16 but model at %s is not.\n", checkpoint_path); - fprintf(stderr, "---> HINT: are you sure you're loading a _bf16.bin file?\n"); - exit(EXIT_FAILURE); - } - if (PRECISION_MODE == PRECISION_FP32 && version != 3) { - fprintf(stderr, "Precision is configured as FP32 but model at %s is not.\n", checkpoint_path); - fprintf(stderr, "---> HINT: to turn on FP32 you have to compile like: `make train_gpt2cu PRECISION=FP32`\n"); - fprintf(stderr, "---> HINT: are you sure you're loading a .bin file without any _bf16 in the name?\n"); - exit(EXIT_FAILURE); - } + // check if the precision mode of the checkpoint matches the model precision + // todo - we could support this (and FP16) fairly easily by modifying convert_fixed_parameters() a bit... + if (PRECISION_MODE == PRECISION_BF16 && version != 5) { + fprintf(stderr, "Precision is configured as BF16 but model at %s is not.\n", checkpoint_path); + fprintf(stderr, "---> HINT: are you sure you're loading a _bf16.bin file?\n"); + exit(EXIT_FAILURE); + } + if (PRECISION_MODE == PRECISION_FP32 && version != 3) { + fprintf(stderr, "Precision is configured as FP32 but model at %s is not.\n", checkpoint_path); + fprintf(stderr, "---> HINT: to turn on FP32 you have to compile like: `make train_gpt2cu PRECISION=FP32`\n"); + fprintf(stderr, "---> HINT: are you sure you're loading a .bin file without any _bf16 in the name?\n"); + exit(EXIT_FAILURE); } // read in hyperparameters @@ -506,51 +565,26 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path, bool w model->config.channels = model_header[6]; model->config.padded_vocab_size = model_header[7]; + // key line to allocate all of the GPU buffers for all of the tensos gpt2_allocate(model); - if (weight_init) { - fseek(model_file, 0, SEEK_END); - size_t checkpoint_bytes = ftell(model_file) - sizeof(model_header); - fseek(model_file, sizeof(model_header), SEEK_SET); - - if (checkpoint_bytes != model->num_parameters_bytes) { - assert(checkpoint_bytes <= tensors_bytes[MULTIUSE]); // hack - won't work if params size > activations size - file_to_device(model->tensor_memory[MULTIUSE], model_file, checkpoint_bytes, IO_BUF_SIZE); - - size_t offset = 0; - int num_param_tensors = tensors_start[PARAMETER_GRAD]; - for (int i = 0; i < num_param_tensors; i++) { - TensorGPU tensor = tensor_specs[i]; - tensor.data_ptr = (floatX*)(model->tensor_memory[MULTIUSE] + offset); - offset += tensor.num_elements * sizeof(floatX); - update_absmax(tensor); - } - - update_scales_from_absmax(); - - offset = 0; - for (int i = 0; i < num_param_tensors; i++) { - TensorGPU tensor_in = tensor_specs[i]; - tensor_in.data_ptr = (floatX*)(model->tensor_memory[MULTIUSE] + offset); - offset += tensor_in.num_elements * sizeof(floatX); - - switch (tensor_specs[i].data_type) { - case DType::FP32: copy_advanced((TensorGPU)tensor_specs[i], tensor_in); break; - case DType::BF16: copy_advanced((TensorGPU<__nv_bfloat16>)tensor_specs[i], tensor_in); break; - case DType::FP16: copy_advanced((TensorGPU)tensor_specs[i], tensor_in); break; - case DType::FP8E4M3: copy_advanced((TensorGPU<__nv_fp8_e4m3>)tensor_specs[i], tensor_in); break; - case DType::FP8E5M2: copy_advanced((TensorGPU<__nv_fp8_e5m2>)tensor_specs[i], 
tensor_in); break; - } - } - cudaMemset(model->tensor_memory[MULTIUSE], 0, checkpoint_bytes); - } else { - file_to_device(model->tensor_memory[PARAMETER], model_file, model->num_parameters_bytes, IO_BUF_SIZE); - } + // if the number of bytes in the checkpoint doesn't match the number of bytes allocated, + // we assume the checkpoint is all floatX but our model has different sizes for different tensors (e.g. FP8) + fseek(model_file, 0, SEEK_END); + size_t checkpoint_bytes = ftell(model_file) - sizeof(model_header); + fseek(model_file, sizeof(model_header), SEEK_SET); + + if (checkpoint_bytes != model->num_parameters_bytes) { + assert(checkpoint_bytes == model->num_parameters * sizeof(floatX)); + assert(checkpoint_bytes <= tensors_bytes[MULTIUSE]); // todo - won't work if params size > activations size + file_to_device(model->tensor_memory[MULTIUSE], model_file, checkpoint_bytes, IO_BUF_SIZE); + convert_fixed_parameters(model, model->tensor_memory[MULTIUSE], checkpoint_bytes); + } else { + file_to_device(model->tensor_memory[PARAMETER], model_file, model->num_parameters_bytes, IO_BUF_SIZE); } - fcloseCheck(model_file); - // only return from this function once we are certain the params are ready on the GPU + fcloseCheck(model_file); cudaCheck(cudaDeviceSynchronize()); } @@ -596,6 +630,7 @@ void gpt3_set_hyperparameters(GPT2Config* config, const char* channels_str) { config->num_heads = channels / head_size; config->max_seq_len = 2048; // NOTE: GPT-3 uses context length of 2048 tokens, up from 1024 in GPT-2 } + void gpt_build_from_descriptor(GPT2 *model, const char* descriptor) { // The model descriptor can be: // - legacy format "dX", where X is number, e.g. "d12". This creates GPT-2 model with 12 layers. @@ -626,56 +661,56 @@ void gpt_build_from_descriptor(GPT2 *model, const char* descriptor) { // NOTE: assuming all parameters are of the type floatX, could be relaxed later mt19937_state init_rng; manual_seed(&init_rng, 42); - floatX* params_memory_cpu = (floatX*)mallocCheck(model->num_parameters_bytes); - memset(params_memory_cpu, 0, model->num_parameters_bytes); + size_t fixed_size_bytes = model->num_parameters * sizeof(floatX); + floatX* params_memory_cpu = (floatX*)mallocCheck(fixed_size_bytes); + memset(params_memory_cpu, 0, fixed_size_bytes); + // fill in all the weights with random values - float residual_scale = 1.0f / sqrtf(2.0f * model->config.num_layers); // we have to init all these tensors exactly in the order that PyTorch initializes them // so that we can match them up and get correctness and exactly the same initial conditions - /* - size_t L = model->config.num_layers; + float residual_scale = 1.0f / sqrtf(2.0f * model->config.num_layers); size_t offset = 0; - for (int l = 0; l < L; l++) { - offset = 0; - for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { - // the layernorm parameters are all initialized to 1 - if (l == 0 && (i == 2 || i == 8 || i == 14)) { // only at l = 0 to init these just once - for (size_t j = 0; j < model->param_elements[i]; j++) { - params_memory_cpu[offset + j] = 1.0f; - } + + int num_param_tensors = tensors_start[PARAMETER+1]; + for (int i = 0; i < num_param_tensors; i++) { + TensorSpec tensor = tensor_specs[i]; + if ((tensor.flags & TFlags::LAYERNORM) && !(tensor.flags & BIAS)) { + for (size_t j = 0; j < tensor.num_elements; j++) { + params_memory_cpu[offset + j] = (floatX)1.0f; } - // weights tensors are handled here - if ((l == 0 && (i == 0 || i == 1)) // only at l = 0, init the wte and wpe tensors - || i == 4 || i == 6 || i == 10 || i == 12) { - 
size_t n = model->param_elements[i]; - size_t layer_offset = 0; - if (i == 0) {rer - // for wte tensor (padded vocab) override to init V instead of Vp rows - n = model->config.vocab_size * model->config.channels; - } - if (i == 4 || i == 6 || i == 10 || i == 12) { - // weight tensors, we are only initializing layer l - assert(n % L == 0); - n = n / L; - layer_offset = l * n; - } - // in GPT-2, the projections back into the residual stream are additionally - // scaled by 1/sqrt(2*L) for training stability - float scale = (i == 6 || i == 12) ? 0.02f * residual_scale : 0.02f; - // okay let's draw the random numbers and write them - float *fp32_buffer = (float*)mallocCheck(n * sizeof(float)); - normal_(fp32_buffer, n, 0.0f, scale, &init_rng); - for (size_t j = 0; j < n; j++) { - params_memory_cpu[offset + layer_offset + j] = (floatX)fp32_buffer[j]; - } - free(fp32_buffer); + } + if (tensor.flags & TENSOR_2D) { + size_t n = tensor.num_elements; + if (n == model->config.padded_vocab_size * model->config.channels) { + n = model->config.vocab_size * model->config.channels; } - offset += model->param_elements[i]; + + // in GPT-2, the projections back into the residual stream are additionally + // scaled by 1/sqrt(2*L) for training stability + float scale = 0.02f; + if (strstr(tensor.name, "proj") != NULL) { // always love a good strstr()... /s + scale *= residual_scale; + } + + float *fp32_buffer = (float*)mallocCheck(n * sizeof(float)); + normal_(fp32_buffer, n, 0.0f, scale, &init_rng); + for (size_t j = 0; j < n; j++) { + params_memory_cpu[offset + j] = (floatX)fp32_buffer[j]; + } + free(fp32_buffer); } + offset += tensor.num_elements; } - */ - // copy them to GPU - cudaCheck(cudaMemcpy(model->tensor_memory[PARAMETER], params_memory_cpu, model->num_parameters_bytes, cudaMemcpyHostToDevice)); + + // if the actual allocation doesn't match "params * sizeof(floatX)" we need to convert everything, otherwise just copy. 
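+    // (the convert path is assumed to mirror the checkpoint-loading code in gpt2_build_from_checkpoint() above:
+    // stage the floatX blob in the MULTIUSE buffer, let convert_fixed_parameters() refresh the per-tensor
+    // absmax/scale factors, then copy each tensor into its real data type at its final offset.)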
+ if (fixed_size_bytes != model->num_parameters_bytes) { + assert(tensors_bytes[MULTIUSE] >= model->num_parameters * sizeof(floatX)); // todo - make this always work + cudaMemcpy(model->tensor_memory[MULTIUSE], params_memory_cpu, fixed_size_bytes, cudaMemcpyHostToDevice); + convert_fixed_parameters(model, model->tensor_memory[MULTIUSE], fixed_size_bytes); + } else { + cudaCheck(cudaMemcpy(model->tensor_memory[PARAMETER], params_memory_cpu, model->num_parameters_bytes, cudaMemcpyHostToDevice)); + } + free(params_memory_cpu); } @@ -695,7 +730,7 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { } // unused parts of attention buffer must be zeroed for non-cuDNN path if (!CUDNN_ENABLED && T != model->seq_len) { - cudaCheck(cudaMemset(ACT_L(att, 0), 0, L * B * NH * T * T * sizeof(floatX))); + cudaCheck(cudaMemset(ACT_0(att), 0, L * B * NH * T * T * sizeof(floatX))); } // validate inputs, all indices mucst be in the range [0, V) tokenCheck(inputs, B*T, V); @@ -705,7 +740,7 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { // start of forward pass with encoder (layer 0) int l = 0; encoder_forward(ACT(encoded), model->inputs, PARAM(wte), PARAM(wpe), B, T, C); - layernorm_forward(ACT(ln1), ACT(ln1_mean), ACT(ln1_rstd), ACT(encoded), PARAM(ln1w), PARAM(ln1b), B*T, C); + layernorm_forward(ACT(ln1), ACT(ln1_mean), ACT(ln1_rstd), ACT(encoded), PARAM(ln1w), PARAM(ln1b), B*T, C); for (; l < L; l++) { NvtxRange layer_range("Layer", l); @@ -713,28 +748,29 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { tensorX qkvr = ACT(qkvr); // non-cudnn reuses tensor with different memory pre/post-permute qkvr.data_ptr = CUDNN_ENABLED ? ACT(qkvr) : MULTI(output_scratch); - matmul_forward_cublaslt(qkvr, ACT(ln1), PARAM(qkvw), PARAM(qkvb), B*T, C, 3*C); + matmul_forward(qkvr, ACT(ln1), PARAM(qkvw), PARAM(qkvb), B*T, C, 3*C); + #ifdef ENABLE_CUDNN attention_forward_cudnn(ACT(atty), ACT(att), ACT(qkvr), B, T, NH, C, main_stream); #else attention_forward(ACT(atty), ACT(qkvr), ACT(att), qkvr, B, T, C, NH); #endif - matmul_forward_cublaslt(ACT(attproj), ACT(atty), PARAM(attprojw), PARAM(attprojb), B*T, C, C); - fused_residual_forward5(ACT(residual2), ACT(ln2), ACT(ln2_mean), ACT(ln2_rstd), residual, ACT(attproj), PARAM(ln2w), PARAM(ln2b), B*T, C); - matmul_forward_cublaslt(ACT(fch_gelu), ACT(ln2), PARAM(fcw), PARAM(fcb), B*T, C, 4*C, ACT(fch), model->gelu_fusion); - matmul_forward_cublaslt(ACT(fcproj), ACT(fch_gelu), PARAM(fcprojw), PARAM(fcprojb), B*T, 4*C, C); + matmul_forward(ACT(attproj), ACT(atty), PARAM(attprojw), PARAM(attprojb), B*T, C, C); + fused_residual_forward5(ACT(residual2), ACT(ln2), ACT(ln2_mean), ACT(ln2_rstd), residual, ACT(attproj), PARAM(ln2w), PARAM(ln2b), B*T, C); + matmul_forward(ACT(fch_gelu), ACT(ln2), PARAM(fcw), PARAM(fcb), B*T, C, 4*C, ACT(fch), model->gelu_fusion); + matmul_forward(ACT(fcproj), ACT(fch_gelu), PARAM(fcprojw), PARAM(fcprojb), B*T, 4*C, C); if(l+1 != L) { - fused_residual_forward5(ACT(residual3), ACT_L(ln1, l+1), ACT_L(ln1_mean, l+1), ACT_L(ln1_rstd, l+1), ACT(residual2), ACT(fcproj), - PARAM_L(ln1w, l+1), PARAM_L(ln1b, l+1), B*T, C); + fused_residual_forward5(ACT(residual3), ACT_L(ln1, l+1), ACT_L(ln1_mean, l+1), ACT_L(ln1_rstd, l+1), ACT(residual2), ACT(fcproj), + PARAM_L(ln1w, l+1), PARAM_L(ln1b, l+1), B*T, C); } else { - fused_residual_forward5(ACT(residual3), ACT(lnf), ACT(lnf_mean), ACT(lnf_rstd), ACT(residual2), ACT(fcproj), - PARAM(lnfw), PARAM(lnfb), B*T, C); + 
fused_residual_forward5(ACT(residual3), ACT(lnf), ACT(lnf_mean), ACT(lnf_rstd), ACT(residual2), ACT(fcproj), + PARAM(lnfw), PARAM(lnfb), B*T, C); } } - matmul_forward_cublaslt(ACT(output), ACT(lnf), PARAM(wte), null_tensorX, B*T, C, Vp); + matmul_forward(ACT(output), ACT(lnf), PARAM(wte), null_tensorX, B*T, C, Vp); } @@ -754,11 +790,11 @@ float gpt2_validate(GPT2 *model, const int* inputs, const int* targets, size_t B // fused classifier: does the forward pass and first part of the backward pass const float dloss = 1.0f / (B * T); // results in the uniform average loss over all elements // note: we don't need to generate dlogits here - cudaCheck(cudaMemset(ACT_L(losses, 0), 0, B*T*sizeof(float))); + cudaCheck(cudaMemset(ACT_0(losses), 0, B*T*sizeof(float))); cudaCheck(cudaMemcpy(model->targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice)); tokenCheck(targets, B*T, V); // while the memcpy is underway, validate the targets - fused_classifier(ACT_L(output, 0), ACT_L(output, 0), ACT_L(losses, 0), dloss, model->targets, B*T, V, Vp, False); - cudaCheck(cudaMemcpy(model->cpu_losses, ACT_L(losses, 0), B * T * sizeof(float), cudaMemcpyDeviceToHost)); + fused_classifier(ACT_0(output), ACT_0(output), ACT_0(losses), dloss, model->targets, B*T, V, Vp, False); + cudaCheck(cudaMemcpy(model->cpu_losses, ACT_0(losses), B * T * sizeof(float), cudaMemcpyDeviceToHost)); for (int i = 0; i < B*T; i++) { mean_loss += model->cpu_losses[i]; } @@ -797,10 +833,10 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int fused_classifier(AGRAD(output), ACT(output), ACT(losses), dloss, model->targets, B*T, V, Vp, True); // todo - split output & doutput // re-use the output buffer of the forward pass as a scratchpad during backward pass + dedicated buffer - tensorFP32 scratchF_HUGE = MULTI_L(output_scratch_fp32, 0); // Largest buffer imaginable (max of output & everything else) - tensorX scratchX_HUGE = MULTI_L(output_scratch, 0); - tensorFP32 scratchF = MULTI_L(local_scratch_fp32, 0); // FP32 BTC with cuDNN, FP32 2*BTC without cuDNN (i.e. 4xBTC BF16) - tensorX scratchX = MULTI_L(local_scratch, 0); + tensorFP32 scratchF_HUGE = MULTI_0(output_scratch_fp32); // Largest buffer imaginable (max of output & everything else) + tensorX scratchX_HUGE = MULTI_0(output_scratch); + tensorFP32 scratchF = MULTI_0(local_scratch_fp32); // FP32 BTC with cuDNN, FP32 2*BTC without cuDNN (i.e. 4xBTC BF16) + tensorX scratchX = MULTI_0(local_scratch); // backward pass: go in the reverse order of the forward pass, and call backward() functions // we kick off the chain rule by filling in dlosses with 1.0f/(B*T) @@ -808,7 +844,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int // technically that is a small, inline backward() pass of calculating // total, final loss as the mean over all losses over all (B,T) positions in the batch // next: backward the classifier matmul - matmul_backward(AGRAD(lnf), PGRAD(wte), null_tensorX, AGRAD(output), ACT(lnf), PARAM(wte), scratchF, B*T, C, Vp); + matmul_backward(AGRAD(lnf), PGRAD(wte), null_tensorX, AGRAD(output), ACT(lnf), PARAM(wte), scratchF, B*T, C, Vp); // backward the final layernorm layernorm_backward(AGRAD_L(residual3, L-1), null_tensorX, PGRAD(lnfw), PGRAD(lnfb), scratchF, AGRAD(lnf), ACT_L(residual3, L-1), PARAM(lnfw), ACT(lnf_mean), ACT(lnf_rstd), B*T, C); @@ -820,19 +856,18 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int tensorX dresidual = (l == 0) ? 
AGRAD(encoded) : AGRAD_L(residual3, l-1); if(model->recompute >= 1) { // recompute >= 1 means we recompute gelu - gelu_forward(ACT(fch_gelu), ACT(fch)); + gelu_forward(ACT(fch_gelu), ACT(fch)); } matmul_backward_fp8(AGRAD(fch), PGRAD(fcprojw), PGRAD(fcprojb), AGRAD(residual3), ACT(fch_gelu), PARAM(fcprojw), scratchF, scratchF_HUGE, B*T, 4*C, C, ACT(fch)); if(model->recompute >= 2) { // recompute >= 2 means we recompute layernorm - layernorm_forward(ACT(ln2), ACT(ln2_mean), ACT(ln2_rstd), ACT(residual2), PARAM(ln2w), PARAM(ln2b), B*T, C); + layernorm_forward(ACT(ln2), ACT(ln2_mean), ACT(ln2_rstd), ACT(residual2), PARAM(ln2w), PARAM(ln2b), B*T, C); } - matmul_backward_fp8(AGRAD(ln2), PGRAD(fcw), PGRAD(fcb), AGRAD(fch), ACT(ln2), PARAM(fcw), scratchF, scratchF_HUGE, B*T, C, 4 * C); - layernorm_backward(AGRAD(residual2), AGRAD(residual3), PGRAD(ln2w), PGRAD(ln2b), scratchF, AGRAD(ln2), ACT(residual2), PARAM(ln2w), ACT(ln2_mean), ACT(ln2_rstd), B*T, C); + matmul_backward_fp8(AGRAD(ln2), PGRAD(fcw), PGRAD(fcb), AGRAD(fch), ACT(ln2), PARAM(fcw), scratchF, scratchF_HUGE, B*T, C, 4 * C); + layernorm_backward(AGRAD(residual2), AGRAD(residual3), PGRAD(ln2w), PGRAD(ln2b), scratchF, AGRAD(ln2), ACT(residual2), PARAM(ln2w), ACT(ln2_mean), ACT(ln2_rstd), B*T, C); // AGRAD(atty) is BF16, AGRAD(residual2) is BF16, ACT(atty) is BF16, PARAM(attprojw) is BF16... ==> 100% BF16 ==> keep BF16 for now! - matmul_backward(AGRAD(atty), PGRAD(attprojw), PGRAD(attprojb), AGRAD(residual2), ACT(atty), PARAM(attprojw), scratchF, B*T, C, C); - + matmul_backward(AGRAD(atty), PGRAD(attprojw), PGRAD(attprojb), AGRAD(residual2), ACT(atty), PARAM(attprojw), scratchF, B*T, C, C); #ifdef ENABLE_CUDNN attention_backward_cudnn(AGRAD(qkvr), AGRAD(atty), ACT(qkvr), ACT(atty), ACT(att), B, T, NH, C, main_stream); #else @@ -840,10 +875,10 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int #endif if(model->recompute >= 2) { - layernorm_forward(ACT(ln1), ACT(ln1_mean), ACT(ln1_rstd), residual, PARAM(ln1w), PARAM(ln1b), B*T, C); + layernorm_forward(ACT(ln1), ACT(ln1_mean), ACT(ln1_rstd), residual, PARAM(ln1w), PARAM(ln1b), B*T, C); } matmul_backward_fp8(AGRAD(ln1), PGRAD(qkvw), PGRAD(qkvb), AGRAD(qkvr), ACT(ln1), PARAM(qkvw), scratchF, scratchF_HUGE, B*T, C, 3 * C); - layernorm_backward(dresidual, AGRAD(residual2), PGRAD(ln1w), PGRAD(ln1b), scratchF, AGRAD(ln1), residual, PARAM(ln1w), ACT(ln1_mean), ACT(ln1_rstd), B*T, C); + layernorm_backward(dresidual, AGRAD(residual2), PGRAD(ln1w), PGRAD(ln1b), scratchF, AGRAD(ln1), residual, PARAM(ln1w), ACT(ln1_mean), ACT(ln1_rstd), B*T, C); // Accumulate gradients from this layer in a background stream. if(last_step) { @@ -880,17 +915,17 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int tensorX qkvr = ACT(qkvr); qkvr.data_ptr = CUDNN_ENABLED ? 
ACT(qkvr) : MULTI(output_scratch); - matmul_forward_cublaslt(qkvr, ACT(ln1), PARAM(qkvw), PARAM(qkvb), B*T, C, 3*C); + matmul_forward(qkvr, ACT(ln1), PARAM(qkvw), PARAM(qkvb), B*T, C, 3*C); #ifdef ENABLE_CUDNN attention_forward_cudnn(ACT(atty), ACT(att), ACT(qkvr), B, T, NH, C); #else attention_forward(ACT(atty), ACT(qkvr), ACT(att), qkvr, B, T, C, NH); #endif - matmul_forward_cublaslt(ACT(attproj), ACT(atty), PARAM(attprojw), PARAM(attprojb), B*T, C, C); + matmul_forward(ACT(attproj), ACT(atty), PARAM(attprojw), PARAM(attprojb), B*T, C, C); fused_residual_forward5(ACT(residual2), ACT(ln2), ACT(ln2_mean), ACT(ln2_rstd), residual, ACT(attproj), PARAM(ln2w), PARAM(ln2b), B*T, C); - matmul_forward_cublaslt(ACT(fch_gelu), ACT(ln2), PARAM(fcw), PARAM(fcb), B*T, C, 4*C, ACT(fch), model->gelu_fusion); - matmul_forward_cublaslt(ACT(fcproj), ACT(fch_gelu), PARAM(fcprojw), PARAM(fcprojb), B*T, 4*C, C); + matmul_forward(ACT(fch_gelu), ACT(ln2), PARAM(fcw), PARAM(fcb), B*T, C, 4*C, ACT(fch), model->gelu_fusion); + matmul_forward(ACT(fcproj), ACT(fch_gelu), PARAM(fcprojw), PARAM(fcprojb), B*T, 4*C, C); } l = old_l; } @@ -928,7 +963,7 @@ float gpt2_calculate_grad_norm(GPT2 *model, MultiGpuConfig* multi_gpu_config) { floatX* grads_memory = (floatX*)model->tensor_memory[PARAMETER_GRAD]; // repurposing this buffer (which isn't needed now) to write grad norm into it - float* grad_norm_squared = MULTI_L(output_scratch_fp32, 0); + float* grad_norm_squared = MULTI_0(output_scratch_fp32); float grad_norm_squared_cpu = 0.0f; int num_slices[2] = {1, model->config.num_layers}; @@ -981,13 +1016,6 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo exit(EXIT_FAILURE); } - bool init_state = model->init_state; - if(init_state) { - model->init_state = false; - cudaCheck(cudaMemset(model->tensor_memory[PARAMETER_OPT_M], 0, multi_gpu_config->shard_num_parameters * sizeof(float))); - cudaCheck(cudaMemset(model->tensor_memory[PARAMETER_OPT_V], 0, multi_gpu_config->shard_num_parameters * sizeof(float))); - } - // save RNG state at this point so we can round from master weights identically when restoring from a checkpoint model->rng_state_last_update = model->rng_state; @@ -999,23 +1027,13 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo const int block_size = 64; const int grid_size = deviceProp.maxThreadsPerMultiProcessor * deviceProp.multiProcessorCount / block_size; if (model->use_master_weights) { - if (init_state) { - // reads regular weights & writes to master+regular weights - adamw_full_update<<>>( - tensor_specs_gpu, seed, tensors_start[PARAMETER_GRAD], - model->num_parameters, model->num_parameters / num_shards, - learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay, grad_scale, t); - } else { - // reads master weights & writes to master+regular weights - adamw_full_update<<>>( - tensor_specs_gpu, seed, tensors_start[PARAMETER_GRAD], - model->num_parameters, model->num_parameters / num_shards, - learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay, grad_scale, t); - } + adamw_full_update<<>>( + tensor_specs_gpu, seed, tensors_start[PARAMETER+1], + model->num_parameters, model->num_parameters / num_shards, + learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay, grad_scale, t); } else { - // reads & writes regular weights only adamw_full_update<<>>( - tensor_specs_gpu, seed, tensors_start[PARAMETER_GRAD], + tensor_specs_gpu, seed, tensors_start[PARAMETER+1], 
model->num_parameters, model->num_parameters / num_shards, learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay, grad_scale, t); } @@ -1050,24 +1068,12 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo float* v_ptr = model->v_memory + opt_state_offset; float* master_ptr = nullptr; if (model->master_weights != nullptr) { master_ptr = model->master_weights + opt_state_offset; } - if(init_state && model->master_weights != nullptr ) { - size_t grid_size = CEIL_DIV(shard.size, 512); - copy_and_cast_kernel<<>>(master_ptr, param_ptr, shard.size, - shard.size, tensor.size); - cudaCheck(cudaGetLastError()); - } - - if (init_from_master_only) { - // when resuming training from a checkpoint with master weights (allows changing precision) - init_from_master(param_ptr, master_ptr, shard.size, tensor.size, shard.size, num_layers, seed, main_stream); - } else { - // ok finally call the kernel to update the weights with AdamW - adamw_update(param_ptr, master_ptr, grad_ptr, - m_ptr, v_ptr, - shard.size, tensor.size, tensor.size, shard.size, num_layers, - learning_rate, - beta1, beta2, t, eps, wd, grad_scale, seed, main_stream); - } + // ok finally call the kernel to update the weights with AdamW + adamw_update(param_ptr, master_ptr, grad_ptr, + m_ptr, v_ptr, + shard.size, tensor.size, tensor.size, shard.size, num_layers, + learning_rate, + beta1, beta2, t, eps, wd, grad_scale, seed, main_stream); if (multi_gpu_config->zero_stage == 1) { #if MULTI_GPU @@ -1240,13 +1246,9 @@ void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename exit(EXIT_FAILURE); } - model->init_state = false; // we just got the state from file, no need to do first-touch init - assert(model->tensor_memory[PARAMETER_OPT_M] != nullptr); - assert(model->tensor_memory[PARAMETER_OPT_V] != nullptr); file_to_device(model->tensor_memory[PARAMETER_OPT_M], state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); file_to_device(model->tensor_memory[PARAMETER_OPT_V], state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); if(model->use_master_weights) { - assert(model->tensor_memory[PARAMETER_MASTER] != nullptr); file_to_device(model->tensor_memory[PARAMETER_MASTER], state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); } @@ -1542,9 +1544,7 @@ int main(int argc, char *argv[]) { if (resuming == 1) { // if `-y 1` was set, then we are resuming from the latest checkpoint - // if we are using master weights, we'll init them later inside load_state() - bool weight_init = !use_master_weights; - gpt2_build_from_checkpoint(&model, filename_buffer, weight_init); + gpt2_build_from_checkpoint(&model, filename_buffer); } else if (ends_with_bin(load_filename)) { // otherwise, if this is a .bin file, we assume it's a model, let's init from it gpt2_build_from_checkpoint(&model, load_filename); @@ -1638,11 +1638,17 @@ int main(int argc, char *argv[]) { floatX* cpu_logits_raw = (floatX*)mallocCheck(model.config.vocab_size * sizeof(floatX)); float* cpu_logits = (float*)mallocCheck(model.config.vocab_size * sizeof(float)); - // if we found a checkpoint to resume from, load the optimization state + // if we found a checkpoint to resume from, load the optimization state (and initialize it otherwise) int step = 0; if (resuming == 1) { snprintf(filename_buffer, sizeof(filename_buffer), "%s/state_%08d_%05d.bin", output_log_dir, resume_max_step, multi_gpu_config.process_rank); load_state(&step, &model, &train_loader, 
filename_buffer); + } else { + cudaCheck(cudaMemset(model.tensor_memory[PARAMETER_OPT_M], 0, multi_gpu_config.shard_num_parameters * sizeof(float))); + cudaCheck(cudaMemset(model.tensor_memory[PARAMETER_OPT_V], 0, multi_gpu_config.shard_num_parameters * sizeof(float))); + if (model.use_master_weights) { + init_master_weights(&model); + } } // init an OutlierDetector the training loss From acd1058d871543f8dcba7fd6bd5b479df3ccc4ea Mon Sep 17 00:00:00 2001 From: ademeure Date: Wed, 18 Sep 2024 03:43:35 +0000 Subject: [PATCH 20/27] FP8 cleanup part 2 --- llmc/adamw.cuh | 3 +-- llmc/attention.cuh | 15 +++++------ llmc/copy_and_fp8.h | 2 +- llmc/cuda_common.h | 4 +-- llmc/encoder.cuh | 6 ++--- llmc/fused_classifier.cuh | 4 +-- llmc/gelu.cuh | 24 ++++++++++++------ llmc/layernorm.cuh | 32 +++++++++++------------ llmc/matmul.cuh | 39 +++++++++++++++------------- llmc/tensor.cuh | 45 ++++++++++++++++++++------------- train_gpt2.cu | 53 +++++++++++++++++++-------------------- 11 files changed, 121 insertions(+), 106 deletions(-) diff --git a/llmc/adamw.cuh b/llmc/adamw.cuh index 763e0a814..f4df77342 100644 --- a/llmc/adamw.cuh +++ b/llmc/adamw.cuh @@ -22,7 +22,6 @@ __device__ size_t adamw_update_part(TensorGPU param_tensor, size_t idx, unsigned int seed, int num_params_tensors, size_t num_parameters, size_t num_opt_parameters, float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float wd, float grad_scale, int t) { - constexpr size_t block_size = 64; auto out_master128 = new_tensor128(master_tensor, true); auto out_opt_m128 = new_tensor128(opt_m_tensor, true); auto out_opt_v128 = new_tensor128(opt_v_tensor, true); @@ -97,7 +96,7 @@ __device__ size_t adamw_update_part(TensorGPU param_tensor, size_t idx, } idx += stride; } - out_param128.update_absmax(threadIdx.x, block_size, false); + out_param128.update_absmax(1); return idx; } diff --git a/llmc/attention.cuh b/llmc/attention.cuh index 36dcb58ae..72cd9c545 100644 --- a/llmc/attention.cuh +++ b/llmc/attention.cuh @@ -63,18 +63,15 @@ __global__ void permute_kernel_backward(tensorX dinp, dinp128_q.set(i, dq[idx+i]); dinp128_k.set(i, dk[idx+i]); dinp128_v.set(i, dv[idx+i]); - // to allow us to update the absmax only once - dinp128_k.add_value_stats(dk[idx+i], dinp128_k.get128()[i]); - dinp128_v.add_value_stats(dv[idx+i], dinp128_v.get128()[i]); + + // to allow us to update the absmax only once for the q vector + dinp128_q.add_value_stats(dk[idx+i], dinp128_k.get128()[i]); + dinp128_q.add_value_stats(dv[idx+i], dinp128_v.get128()[i]); } dinp128_q.store(inp_idx); dinp128_k.store(inp_idx + NH * d); dinp128_v.store(inp_idx + 2 * (NH * d)); - - // todo - merge this into 1 update - dinp128_q.update_absmax(threadIdx.x, blockDim.x, false); - dinp128_k.update_absmax(threadIdx.x, blockDim.x, false); - dinp128_v.update_absmax(threadIdx.x, blockDim.x, true); + dinp128_q.update_absmax(1); } __global__ void unpermute_kernel(tensorX out, floatX* inp, int B, int N, int NH, int d) { @@ -96,7 +93,7 @@ __global__ void unpermute_kernel(tensorX out, floatX* inp, int B, int N, int NH, out128.set(i, __ldcs(&inp[idx + i])); } out128.store(other_idx); - out128.update_absmax(threadIdx.x, blockDim.x, true); + out128.update_absmax(1); } __global__ void unpermute_kernel_backward(floatX* dout_permuted, tensorX dout, int B, int N, int NH, int d) { diff --git a/llmc/copy_and_fp8.h b/llmc/copy_and_fp8.h index 94c8091b7..b263d7ef1 100644 --- a/llmc/copy_and_fp8.h +++ b/llmc/copy_and_fp8.h @@ -54,7 +54,7 @@ __global__ void 
copy_advanced_kernel(TensorGPU out, TensorGPU in) { out128.set(k, out_fp32); } out128.template store_same_length(idx); - out128.update_absmax(threadIdx.x, blockDim.x, true); + out128.update_absmax(1); } template diff --git a/llmc/cuda_common.h b/llmc/cuda_common.h index eedd923f0..44f19c608 100644 --- a/llmc/cuda_common.h +++ b/llmc/cuda_common.h @@ -100,12 +100,12 @@ typedef __nv_bfloat16 floatX; #if defined(ENABLE_FP8) typedef __nv_fp8_e4m3 float8; -typedef __nv_fp8_e5m2 grads8; +typedef __nv_fp8_e5m2 float8e5; #define DTYPE_FP8E4 DType::FP8E4M3 #define DTYPE_FP8E5 DType::FP8E5M2 #else typedef floatX float8; -typedef floatX grads8; +typedef floatX float8e5; #define DTYPE_FP8E4 DTYPE_FLOATX #define DTYPE_FP8E5 DTYPE_FLOATX #endif diff --git a/llmc/encoder.cuh b/llmc/encoder.cuh index 2bcd6017e..980165ca4 100644 --- a/llmc/encoder.cuh +++ b/llmc/encoder.cuh @@ -37,7 +37,7 @@ __global__ void encoder_forward_kernel3(tensorX out, out128.set(k, wte128.get(k) + wpe128.get(k)); } out128.store(b * T * C + t * C + c); - out128.update_absmax(threadIdx.x, blockDim.x, true); + out128.update_absmax(1); } template @@ -107,7 +107,7 @@ __global__ void wte_backward_kernel(tensorX dwte, dwte128.set_stochastic(k, accum[k] + dwte128.get(k), random); } dwte128.store(bucket_ix * C + c); - dwte128.update_absmax(threadIdx.x, blockDim.x, true); + dwte128.update_absmax(1); } __global__ void wpe_backward_kernel(tensorX dwpe, @@ -143,7 +143,7 @@ __global__ void wpe_backward_kernel(tensorX dwpe, dwpe128.set_stochastic(k, accum[k] + dwpe128.get(k), random); } dwpe128.store(t * C + c); - dwpe128.update_absmax(threadIdx.x, blockDim.x, true); + dwpe128.update_absmax(1); } // ---------------------------------------------------------------------------- diff --git a/llmc/fused_classifier.cuh b/llmc/fused_classifier.cuh index f56d0c023..8a0ac8962 100644 --- a/llmc/fused_classifier.cuh +++ b/llmc/fused_classifier.cuh @@ -132,7 +132,7 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS) } } if constexpr (WriteDLogits) { - dlogits128.update_absmax(threadIdx.x, blockDim.x, true); + dlogits128.update_absmax(1); } } @@ -141,7 +141,7 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS) // replaces logits with logit gradients template -void fused_classifier(tensorX dlogits, tensorX logits, tensorFP32 losses, +void fused_classifier(tensorX dlogits, tensorX logits, tensor32 losses, const float dloss, const int* targets, int BT, int V, int P, std::bool_constant write_dlogits, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); diff --git a/llmc/gelu.cuh b/llmc/gelu.cuh index bc56e67d8..1c0af0f11 100644 --- a/llmc/gelu.cuh +++ b/llmc/gelu.cuh @@ -10,7 +10,7 @@ // CUDA kernels #define GELU_SCALING_FACTOR sqrtf(2.0f / M_PI) -template +template __global__ void gelu_forward_kernel2(TensorGPU out, TensorGPU inp) { int idx = (blockIdx.x * blockDim.x + threadIdx.x) * inp.num_per_128(); @@ -32,12 +32,12 @@ __global__ void gelu_forward_kernel2(TensorGPU out, TensorGPU inp) { out128.set(k, half_xi * tanh_in_out + half_xi); } out128.store(idx, false); - out128.update_absmax(threadIdx.x, blockDim.x, true); + out128.update_absmax(1); } //template -template -__global__ void gelu_backward_kernel(TensorGPU dinp, TensorGPU dout, TensorGPU inp) { +template +__global__ void gelu_backward_kernel(TensorGPU dinp, TensorGPU dout, TensorGPU inp) { int idx = (blockIdx.x * blockDim.x + threadIdx.x) * dout.num_per_128(); auto dinp128 = new_tensor128(dinp); @@ -60,12 +60,12 @@ __global__ void gelu_backward_kernel(TensorGPU dinp, 
TensorGPU dout, Ten dinp128.set(k, result); } dinp128.store(idx, false); - dinp128.update_absmax(threadIdx.x, blockDim.x, true); + dinp128.update_absmax(1); } // ---------------------------------------------------------------------------- // kernel launchers -template +template void gelu_forward(TensorGPU out, TensorGPU inp, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); const int block_size = 256; @@ -76,11 +76,19 @@ void gelu_forward(TensorGPU out, TensorGPU inp, cudaStream_t stream= cudaCheck(cudaGetLastError()); } -template -void gelu_backward(TensorGPU dinp, TensorGPU dout, TensorGPU inp, cudaStream_t stream=main_stream) { +template +void gelu_backward(TensorGPU dinp, TensorGPU dout, TensorGPU inp, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); const int block_size = 256; const int grid_size = CEIL_DIV(inp.num_elements, block_size * inp.num_per_128()); gelu_backward_kernel<<>>(dinp, dout, inp); cudaCheck(cudaGetLastError()); } + +void gelu_forward_fp8(tensor8 out, tensor8 inp, cudaStream_t stream=main_stream) { + gelu_forward(out, inp, stream); +} + +void gelu_backward_fp8(tensor8e5 dinp, tensor8e5 dout, tensor8 inp, cudaStream_t stream=main_stream) { + gelu_backward(dinp, dout, inp, stream); +} \ No newline at end of file diff --git a/llmc/layernorm.cuh b/llmc/layernorm.cuh index 730c281fe..db6892074 100644 --- a/llmc/layernorm.cuh +++ b/llmc/layernorm.cuh @@ -19,7 +19,7 @@ E.g., the layernorms are connected to the residuals so we += in layernorm backwa // CUDA kernels template -__global__ void layernorm_forward_kernel6(TensorGPU out, tensorFP32 mean, tensorFP32 rstd, +__global__ void layernorm_forward_kernel6(TensorGPU out, tensor32 mean, tensor32 rstd, tensorX inp, tensorX weight, tensorX bias, int N, int C) { // Note that blockDim.x must be WARP_SIZE=32 but we don't want to pay the cost of assert() here @@ -72,11 +72,11 @@ __global__ void layernorm_forward_kernel6(TensorGPU out, tensorFP32 mean, ten __stcs(rstd + idx, s); } // update absmax - out128.update_absmax(threadIdx.x + threadIdx.y * blockDim.x, blockDim.x * blockDim.y, true); + out128.update_absmax(2); } template -__global__ void fused_residual_forward_kernel5(tensorX residual, TensorGPU normed, tensorFP32 mean, tensorFP32 rstd, +__global__ void fused_residual_forward_kernel5(tensorX residual, TensorGPU normed, tensor32 mean, tensor32 rstd, const tensorX inp1, const TensorGPU inp2, const tensorX weight, const tensorX bias, int N, int C) { @@ -137,14 +137,14 @@ __global__ void fused_residual_forward_kernel5(tensorX residual, TensorGPU } // Update absmax for residual and normed tensors (typically it will skip residual as it is not FP8) - residual128.update_absmax(threadIdx.x + threadIdx.y * blockDim.x, blockDim.x * blockDim.y, false); - normed128.update_absmax(threadIdx.x + threadIdx.y * blockDim.x, blockDim.x * blockDim.y, true); + residual128.update_absmax(2); + normed128.update_absmax(2); } -template +template __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with only 1024 threads? - layernorm_backward_kernel10(tensorX dinp_new, tensorX dinp_old, tensorX dweight, tensorX dbias, tensorFP32 scratch_, - TensorGPU dout, tensorX inp, tensorX weight, tensorFP32 mean, tensorFP32 rstd, + layernorm_backward_kernel10(tensorX dinp_new, tensorX dinp_old, tensorX dweight, tensorX dbias, tensor32 scratch_, + TensorGPU dout, tensorX inp, tensorX weight, tensor32 mean, tensor32 rstd, int BT, int C) { int BLOCK_SIZE = blockDim.x; // todo - does it make any difference if this is hardcoded here? 
int warpsInBlock = BLOCK_SIZE / WARP_SIZE; //number of warps in block @@ -268,7 +268,7 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with } // if we did actually update the absmax (returns true), we already did __syncthreads() here - if (!dinp_new128.update_absmax(threadIdx.x, BLOCK_SIZE, false)) { + if (!dinp_new128.update_absmax(1)) { __syncthreads(); } @@ -340,8 +340,8 @@ __global__ void __launch_bounds__(512, 2) // todo - any warnings on Turing with dbias128_out.store_same_length(global_index); dweight128_out.store_same_length(global_index); } - dbias128_out.update_absmax(threadIdx.x, BLOCK_SIZE, false); - dweight128_out.update_absmax(threadIdx.x, BLOCK_SIZE, false); + dbias128_out.update_absmax(1); + dweight128_out.update_absmax(1); } } @@ -374,7 +374,7 @@ void launch_layernorm_kernel(KernelFunc kernel, int N, int C, cudaStream_t strea } template -void layernorm_forward(TensorGPU out, tensorFP32 mean, tensorFP32 rstd, +void layernorm_forward(TensorGPU out, tensor32 mean, tensor32 rstd, tensorX inp, const tensorX weight, const tensorX bias, int N, int C, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); @@ -382,16 +382,16 @@ void layernorm_forward(TensorGPU out, tensorFP32 mean, tensorFP32 rstd, } template -void fused_residual_forward5(tensorX residual, TensorGPU normed, tensorFP32 mean, tensorFP32 rstd, +void fused_residual_forward5(tensorX residual, TensorGPU normed, tensor32 mean, tensor32 rstd, tensorX inp1, TensorGPU inp2, tensorX weight, tensorX bias, int N, int C, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); launch_layernorm_kernel(fused_residual_forward_kernel5, N, C, stream, residual, normed, mean, rstd, inp1, inp2, weight, bias); } -template -void layernorm_backward(tensorX dinp_new, tensorX dinp_old, tensorX dweight, tensorX dbias, tensorFP32 scratch, - const TensorGPU dout, const tensorX inp, const tensorX weight, tensorFP32 mean, tensorFP32 rstd, +template +void layernorm_backward(tensorX dinp_new, tensorX dinp_old, tensorX dweight, tensorX dbias, tensor32 scratch, + const TensorGPU dout, const tensorX inp, const tensorX weight, tensor32 mean, tensor32 rstd, int BT, int C, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); const int block_size = 512; diff --git a/llmc/matmul.cuh b/llmc/matmul.cuh index 2c3c2de55..8d3d2d8da 100644 --- a/llmc/matmul.cuh +++ b/llmc/matmul.cuh @@ -335,12 +335,12 @@ void matmul_forward(TensorGPU out, if constexpr (sizeof(Tin) == 1) { matmul_cublaslt_fp8(pre_gelu.enabled() ? 
pre_gelu : out, weight, inp, bias, OC, BT, C, stream, false, false); if (pre_gelu.enabled()) { - gelu_forward(out, pre_gelu, stream); + gelu_forward(out, pre_gelu, stream); } } else { if (pre_gelu.enabled() && gelu_fusion < 1) { matmul_cublaslt(pre_gelu, weight, inp, bias, OC, BT, C, stream, true, false, 0, 0, 0, 0, false, null_tensorX, false); - gelu_forward(out, pre_gelu, stream); + gelu_forward(out, pre_gelu, stream); } else { matmul_cublaslt(out, weight, inp, bias, OC, BT, C, stream, true, false, 0, 0, 0, 0, false, pre_gelu, false); } @@ -348,7 +348,7 @@ void matmul_forward(TensorGPU out, } template -void matmul_backward_bias(tensorX dbias, TensorGPU dout, tensorFP32 scratch, int BT, int OC, cudaStream_t stream=main_stream) { +void matmul_backward_bias(tensorX dbias, TensorGPU dout, tensor32 scratch, int BT, int OC, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); // backward to bias, if given, does a += @@ -378,12 +378,12 @@ void matmul_backward_bias(tensorX dbias, TensorGPU dout, tensorFP32 scrat } } -template -void matmul_backward_fp8(tensorFP8e5 dinp, tensorX dweight, tensorX dbias, - TensorGPU dout, tensorFP8e4 inp, tensorFP8e4 weight, - tensorFP32 scratch1_big, tensorFP32 scratch2_huge, +template +void matmul_backward_fp8(tensor8e5 dinp, tensorX dweight, tensorX dbias, + TensorGPU dout, tensor8 inp, tensor8 weight, + tensor32 scratch1_big, tensor32 scratch2_huge, int BT, int C, int OC, - tensorFP8e4 pre_gelu_activation=tensorFP8e4(), cudaStream_t stream=main_stream) { + tensor8 pre_gelu_activation=tensor8(), cudaStream_t stream=main_stream) { #ifndef ENABLE_FP8 // FP8 is not enabled so we use the regular floatX matmul path matmul_backward(dinp, dweight, dbias, dout, inp, weight, scratch1_big, BT, C, OC, pre_gelu_activation, 1, stream); @@ -395,34 +395,37 @@ void matmul_backward_fp8(tensorFP8e5 dinp, tensorX dweight, tensorX dbias, // IMPORTANT: inp is allowed to be the same buffer as scratch2_huge (e.g. for fch_gelu) // ==> this MUST be done first and write to scratch1_big! // transpose input - TensorGPU inp_fp8_transposed = inp; + tensor8 inp_fp8_transposed = inp; inp_fp8_transposed.data_ptr = (float8*)scratch1_big.data_ptr; transpose_simple(inp_fp8_transposed, inp, C, BT, stream); // convert dout to FP8e5 if it is not already, and transpose it // the buffer is guaranteed to be at least twice as big as 4BTC, so we can split it in 2 // todo - merge conversion and tranposition like we did before? 
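    // (flow summary, as far as this refactor goes: GEMM 1 below computes dweight from the transposed
    //  inp and transposed dout, GEMM 2 computes dinp from the transposed weight and dout, and if this
    //  matmul had a fused GELU input we finish with gelu_backward_fp8() on dinp.)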
- TensorGPU dout_fp8 = *(TensorGPU*)&dout; - if constexpr (std::is_same::value == false) { - dout_fp8.data_ptr = (grads8*)(scratch2_huge.data_ptr); + tensor8e5 dout_fp8 = *(tensor8e5*)&dout; + if constexpr (std::is_same::value == false) { + dout_fp8.data_ptr = (float8e5*)(scratch2_huge.data_ptr); copy_advanced(dout_fp8, dout, stream); } - TensorGPU dout_fp8_transposed = dout_fp8; - dout_fp8_transposed.data_ptr = (grads8*)(scratch2_huge.data_ptr + (scratch2_huge.num_elements / 2)); + tensor8e5 dout_fp8_transposed = dout_fp8; + dout_fp8_transposed.data_ptr = (float8e5*)(scratch2_huge.data_ptr + (scratch2_huge.num_elements / 2)); transpose_simple(dout_fp8_transposed, dout_fp8, OC, BT, stream); // GEMM 1: dweight, inp_fp8_transposed, dout_fp8_transposed matmul_cublaslt_fp8(dweight, inp_fp8_transposed, dout_fp8_transposed, null_tensorX, C, OC, BT, stream, false, true); // transpose weight (todo: option to cache this / do it at optimizer time) - TensorGPU weight_fp8_transposed = weight; + tensor8 weight_fp8_transposed = weight; weight_fp8_transposed.data_ptr = (float8*)scratch1_big.data_ptr; transpose_simple(weight_fp8_transposed, weight, C, OC, stream); + // GEMM 2: dinp, weight_fp8_transposed, dout_fp8 matmul_cublaslt_fp8(dinp, weight_fp8_transposed, dout_fp8, null_tensorX, C, BT, OC, stream, false, true); + // todo - need dinp and dinp_pre_gelu passed separately here, important for UNIQUE_TENSOR_MEMORY! + // todo - need to support BF16 for dinp into gelu_backwasrd() with FP8 out of gelu_backward()! if (pre_gelu_activation.enabled()) { - gelu_backward(dinp, dinp, pre_gelu_activation, stream); + gelu_backward_fp8(dinp, dinp, pre_gelu_activation, stream); } #endif } @@ -430,7 +433,7 @@ void matmul_backward_fp8(tensorFP8e5 dinp, tensorX dweight, tensorX dbias, void matmul_backward(tensorX dinp, tensorX dweight, tensorX dbias, tensorX dout, tensorX inp, tensorX weight, - tensorFP32 dbias_scratch, + tensor32 dbias_scratch, int BT, int C, int OC, tensorX pre_gelu_activation=null_tensorX, int gelu_fusion=1, cudaStream_t stream=main_stream) { NVTX_RANGE_FN(); @@ -442,7 +445,7 @@ void matmul_backward(tensorX dinp, tensorX dweight, tensorX dbias, // backward GELU (if it wasn't fused into the matmul above) if ( pre_gelu_activation.enabled() && gelu_fusion < 2) { - gelu_backward(dinp, dinp, pre_gelu_activation, stream); + gelu_backward(dinp, dinp, pre_gelu_activation, stream); } // backward to weight, uses += in the backward pass (accumulate the gradient) by setting alpha=one diff --git a/llmc/tensor.cuh b/llmc/tensor.cuh index 8db76035e..71f929399 100644 --- a/llmc/tensor.cuh +++ b/llmc/tensor.cuh @@ -54,7 +54,7 @@ __device__ __constant__ float* gpu_scale_memory_ptr; __device__ __constant__ unsigned int* gpu_absmax_memory_ptr; // ---------------------------------------------------------------------------- -// Helper macros for accessing tensors in the training file +// Helper macros for accessing tensors in the training loop #define TENSOR(x,layer) get_tensor(x, DEFAULT, layer) #define ACT_L(x,layer) get_tensor(model->acts.x, MULTIUSE, layer) #define MULTI_L(x,layer) get_tensor(model->multiuse.x, MULTIUSE, layer) @@ -73,13 +73,13 @@ __device__ __constant__ unsigned int* gpu_absmax_memory_ptr; template struct TensorGPU { + int id = -1; // TensorSpec index in tensor_specs[] array ElementType* data_ptr = NULL; float* scale_descale_ptr = NULL; unsigned int* absmax_ptr = NULL; size_t num_elements = 0; - int id = -1; - static constexpr bool no_scaling = (sizeof(ElementType) != 1); + static constexpr bool 
no_scaling = (sizeof(ElementType) != 1); // todo - this prevents scaling FP16 bool is_null() const { return (data_ptr == NULL); } bool enabled() const { return (data_ptr != NULL); } @@ -137,20 +137,23 @@ struct TensorGPU { }; typedef TensorGPU tensorX; -typedef TensorGPU tensorFP32; +typedef TensorGPU tensor32; typedef TensorGPU tensorFP16; typedef TensorGPU tensorBF16; #ifdef ENABLE_FP8 -typedef TensorGPU<__nv_fp8_e4m3> tensorFP8e4; -typedef TensorGPU<__nv_fp8_e5m2> tensorFP8e5; +typedef TensorGPU<__nv_fp8_e4m3> tensor8; +typedef TensorGPU<__nv_fp8_e5m2> tensor8e5; #else -typedef TensorGPU tensorFP8e4; -typedef TensorGPU tensorFP8e5; +typedef TensorGPU tensor8; +typedef TensorGPU tensor8e5; #endif extern TensorGPU null_tensorX; // ---------------------------------------------------------------------------- +// this is the "foundation" of the other tensor classes (TensorGPU and tensor128) +// they all implicitly refer to this (in tensor_specs[] and tensor_specs_gpu[] for now) with the id +// and these other classes are created by converting from this one (sometimes implicitly) struct TensorSpec { int id; char* ptr; @@ -252,6 +255,7 @@ TensorSpec get_tensor(int spec_index, TT tensor_type, int layer) { return spec; } +// this can only be called at initialisation time, once tensor_specs has been uploaded to the GPU, it is fixed in stone int add_tensor_spec(const char* name, size_t total_elements, size_t num_shards, DType data_type, int copy_offset_from=-1, int flags=TFlags::NONE, TT tensor_type=TT::DEFAULT) { assert(num_tensor_specs < 16*1024); assert((total_elements % num_shards) == 0); @@ -302,7 +306,6 @@ int add_layer_specs(int num_layers, const char* name, size_t total_elements, siz if (reuse_every_n_layers > 0 && l >= reuse_every_n_layers) { copy_offset_from = first_tensor_id + (l % reuse_every_n_layers); } - int spec = add_tensor_spec(num_layers > 1 ? layer_name : name, total_elements, num_shards, data_type, copy_offset_from, flags, tensor_type); tensor_specs[spec].remaining_layers = num_layers - (l + 1); } @@ -393,6 +396,8 @@ public: id = tensor.id; #ifdef FAKE_FP8 + // fake FP8 only applies to specific tensors to test expected training performance + // todo - expand this to support more unusual formats and test things like blockwise scaling(?) if (!disable_scaling && id >= 0 && sizeof(ElementType) == 2 && tensor_specs_ptr[id].tensor_type != TT::PARAMETER_GRAD) { if ((tensor_specs_ptr[id].flags & (TFlags::RESIDUAL | TFlags::EMBEDDING | TFlags::BIAS)) == 0) { faking_fp8 = true; @@ -440,9 +445,10 @@ public: new_absmax = max(new_absmax, fabsf(value)); } + // get and set automatically apply scaling/descaling for FP8 values __device__ float get(int index) { float value = (float)data128[index] * (scaling ? 
descale : 1.0f); - value = fake_fp8(faking_fp8, value, scale, descale, mode_e5); + value = fake_fp8(faking_fp8, value, scale, descale, mode_e5); // ignored without FAKE_FP8 return value; } @@ -537,15 +543,18 @@ public: return true; } - __device__ void update_absmax(int dimensions=1, bool exit=false) { - if (dimensions == 1) { - update_absmax(threadIdx.x, blockDim.x, exit); - } else if (dimensions == 2) { - update_absmax(threadIdx.x + threadIdx.y * blockDim.x, blockDim.x * blockDim.y, exit); - } else if (dimensions == 3) { - update_absmax(threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y, - blockDim.x * blockDim.y * blockDim.z, exit); + // helper function to avoid having to specify threadIdx/blockDim manually + __device__ bool update_absmax(int block_dimensions, bool exit=false) { + if (block_dimensions == 1) { + return update_absmax(threadIdx.x, blockDim.x, exit); + } else if (block_dimensions == 2) { + return update_absmax(threadIdx.x + threadIdx.y * blockDim.x, blockDim.x * blockDim.y, exit); + } else if (block_dimensions == 3) { + return update_absmax(threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y, + blockDim.x * blockDim.y * blockDim.z, exit); } + assert(false); + return false; } __device__ void skip_absmax() { diff --git a/train_gpt2.cu b/train_gpt2.cu index f5565c09d..3269b6f99 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -97,7 +97,7 @@ TT current_tensor_type = TT::PARAMETER; int current_absmax_index = 0; float* gpu_scale_memory = NULL; unsigned int* gpu_absmax_memory = NULL; -TensorGPU null_tensorX = {0}; +tensorX null_tensorX = {0}; // ---------------------------------------------------------------------------- // GPT-2 model definition @@ -441,7 +441,7 @@ void convert_fixed_parameters(GPT2* model, char* gpu_buffer, size_t fixed_size_b // to convert from variable precision parameters to a single precision (e.g. 
before checkpointing) template -void convert_from_fixed_parameters(GPT2* model, char* gpu_buffer) { +void convert_to_fixed_parameters(GPT2* model, char* gpu_buffer) { size_t offset = 0; for (int i = 0; i < tensors_start[PARAMETER+1]; i++) { TensorGPU tensor_out = tensor_specs[i]; @@ -505,7 +505,7 @@ void gpt2_write_to_checkpoint(GPT2 *model, const char* checkpoint_path) { if (write_as_floatX && model->num_parameters_bytes != model->num_parameters * sizeof(floatX)) { // convert the parameters to floatX before writing them assert(tensors_bytes[MULTIUSE] >= model->num_parameters * sizeof(floatX)); // todo - make this always work - convert_from_fixed_parameters(model, model->tensor_memory[MULTIUSE]); + convert_to_fixed_parameters(model, model->tensor_memory[MULTIUSE]); device_to_file(model_file, model->tensor_memory[MULTIUSE], model->num_parameters * sizeof(floatX), IO_BUF_SIZE); } else { // just write the parameters as they are @@ -833,9 +833,9 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int fused_classifier(AGRAD(output), ACT(output), ACT(losses), dloss, model->targets, B*T, V, Vp, True); // todo - split output & doutput // re-use the output buffer of the forward pass as a scratchpad during backward pass + dedicated buffer - tensorFP32 scratchF_HUGE = MULTI_0(output_scratch_fp32); // Largest buffer imaginable (max of output & everything else) + tensor32 scratchF_HUGE = MULTI_0(output_scratch_fp32); // Largest buffer imaginable (max of output & everything else) tensorX scratchX_HUGE = MULTI_0(output_scratch); - tensorFP32 scratchF = MULTI_0(local_scratch_fp32); // FP32 BTC with cuDNN, FP32 2*BTC without cuDNN (i.e. 4xBTC BF16) + tensor32 scratchF = MULTI_0(local_scratch_fp32); // FP32 BTC with cuDNN, FP32 2*BTC without cuDNN (i.e. 4xBTC BF16) tensorX scratchX = MULTI_0(local_scratch); // backward pass: go in the reverse order of the forward pass, and call backward() functions @@ -846,7 +846,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int // next: backward the classifier matmul matmul_backward(AGRAD(lnf), PGRAD(wte), null_tensorX, AGRAD(output), ACT(lnf), PARAM(wte), scratchF, B*T, C, Vp); // backward the final layernorm - layernorm_backward(AGRAD_L(residual3, L-1), null_tensorX, PGRAD(lnfw), PGRAD(lnfb), scratchF, AGRAD(lnf), ACT_L(residual3, L-1), + layernorm_backward(AGRAD_L(residual3, L-1), null_tensorX, PGRAD(lnfw), PGRAD(lnfb), scratchF, (tensorX)AGRAD(lnf), ACT_L(residual3, L-1), PARAM(lnfw), ACT(lnf_mean), ACT(lnf_rstd), B*T, C); // now backward all the layers @@ -856,15 +856,15 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int tensorX dresidual = (l == 0) ? 
AGRAD(encoded) : AGRAD_L(residual3, l-1); if(model->recompute >= 1) { // recompute >= 1 means we recompute gelu - gelu_forward(ACT(fch_gelu), ACT(fch)); + gelu_forward_fp8(ACT(fch_gelu), ACT(fch)); } - matmul_backward_fp8(AGRAD(fch), PGRAD(fcprojw), PGRAD(fcprojb), AGRAD(residual3), ACT(fch_gelu), PARAM(fcprojw), scratchF, scratchF_HUGE, B*T, 4*C, C, ACT(fch)); + matmul_backward_fp8(AGRAD(fch), PGRAD(fcprojw), PGRAD(fcprojb), (tensorX)AGRAD(residual3), ACT(fch_gelu), PARAM(fcprojw), scratchF, scratchF_HUGE, B*T, 4*C, C, ACT(fch)); if(model->recompute >= 2) { // recompute >= 2 means we recompute layernorm - layernorm_forward(ACT(ln2), ACT(ln2_mean), ACT(ln2_rstd), ACT(residual2), PARAM(ln2w), PARAM(ln2b), B*T, C); + layernorm_forward((tensor8)ACT(ln2), ACT(ln2_mean), ACT(ln2_rstd), ACT(residual2), PARAM(ln2w), PARAM(ln2b), B*T, C); } - matmul_backward_fp8(AGRAD(ln2), PGRAD(fcw), PGRAD(fcb), AGRAD(fch), ACT(ln2), PARAM(fcw), scratchF, scratchF_HUGE, B*T, C, 4 * C); - layernorm_backward(AGRAD(residual2), AGRAD(residual3), PGRAD(ln2w), PGRAD(ln2b), scratchF, AGRAD(ln2), ACT(residual2), PARAM(ln2w), ACT(ln2_mean), ACT(ln2_rstd), B*T, C); + matmul_backward_fp8(AGRAD(ln2), PGRAD(fcw), PGRAD(fcb), (tensor8e5)AGRAD(fch), ACT(ln2), PARAM(fcw), scratchF, scratchF_HUGE, B*T, C, 4 * C); + layernorm_backward(AGRAD(residual2), AGRAD(residual3), PGRAD(ln2w), PGRAD(ln2b), scratchF, (tensor8e5)AGRAD(ln2), ACT(residual2), PARAM(ln2w), ACT(ln2_mean), ACT(ln2_rstd), B*T, C); // AGRAD(atty) is BF16, AGRAD(residual2) is BF16, ACT(atty) is BF16, PARAM(attprojw) is BF16... ==> 100% BF16 ==> keep BF16 for now! matmul_backward(AGRAD(atty), PGRAD(attprojw), PGRAD(attprojb), AGRAD(residual2), ACT(atty), PARAM(attprojw), scratchF, B*T, C, C); @@ -875,10 +875,10 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int #endif if(model->recompute >= 2) { - layernorm_forward(ACT(ln1), ACT(ln1_mean), ACT(ln1_rstd), residual, PARAM(ln1w), PARAM(ln1b), B*T, C); + layernorm_forward((tensor8)ACT(ln1), ACT(ln1_mean), ACT(ln1_rstd), residual, PARAM(ln1w), PARAM(ln1b), B*T, C); } - matmul_backward_fp8(AGRAD(ln1), PGRAD(qkvw), PGRAD(qkvb), AGRAD(qkvr), ACT(ln1), PARAM(qkvw), scratchF, scratchF_HUGE, B*T, C, 3 * C); - layernorm_backward(dresidual, AGRAD(residual2), PGRAD(ln1w), PGRAD(ln1b), scratchF, AGRAD(ln1), residual, PARAM(ln1w), ACT(ln1_mean), ACT(ln1_rstd), B*T, C); + matmul_backward_fp8(AGRAD(ln1), PGRAD(qkvw), PGRAD(qkvb), (tensorX)AGRAD(qkvr), ACT(ln1), PARAM(qkvw), scratchF, scratchF_HUGE, B*T, C, 3 * C); + layernorm_backward(dresidual, AGRAD(residual2), PGRAD(ln1w), PGRAD(ln1b), scratchF, (tensor8e5)AGRAD(ln1), residual, PARAM(ln1w), ACT(ln1_mean), ACT(ln1_rstd), B*T, C); // Accumulate gradients from this layer in a background stream. if(last_step) { @@ -901,35 +901,34 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int multi_gpu_async_reduce_gradient(pointers, nelem, &multi_gpu_config, main_stream); } + // todo - this used to be bit-for-bit identical to not recomputing forward, why is it now different?! // Is it time to redo the forward pass from our activation checkpoints? - /* if (LAYERS_PER_ACTIVATION_CHECKPOINT && (l % max(1, LAYERS_PER_ACTIVATION_CHECKPOINT)) == 0 && l > 0) { - int old_l = l; - // forward pass time! + int backward_l = l; l -= LAYERS_PER_ACTIVATION_CHECKPOINT; for (int i = 0; i < LAYERS_PER_ACTIVATION_CHECKPOINT; i++, l++) { // non-fused layernorm as we already (only!) 
have the residual // (for the original forward pass, residual of l-1 is fused with layernorm of l) tensorX residual = (l == 0) ? ACT(encoded) : ACT_L(residual3, l-1); - layernorm_forward(ACT(ln1), ACT(ln1_mean), ACT(ln1_rstd), residual, PARAM(ln1w), PARAM(ln1b), B*T, C); + layernorm_forward(ACT(ln1), ACT(ln1_mean), ACT(ln1_rstd), residual, PARAM(ln1w), PARAM(ln1b), B*T, C); - tensorX qkvr = ACT(qkvr); + tensorX qkvr = ACT(qkvr); // non-cudnn reuses tensor with different memory pre/post-permute qkvr.data_ptr = CUDNN_ENABLED ? ACT(qkvr) : MULTI(output_scratch); - matmul_forward(qkvr, ACT(ln1), PARAM(qkvw), PARAM(qkvb), B*T, C, 3*C); + matmul_forward(qkvr, ACT(ln1), PARAM(qkvw), PARAM(qkvb), B*T, C, 3*C); + #ifdef ENABLE_CUDNN - attention_forward_cudnn(ACT(atty), ACT(att), ACT(qkvr), B, T, NH, C); + attention_forward_cudnn(ACT(atty), ACT(att), ACT(qkvr), B, T, NH, C, main_stream); #else attention_forward(ACT(atty), ACT(qkvr), ACT(att), qkvr, B, T, C, NH); #endif - matmul_forward(ACT(attproj), ACT(atty), PARAM(attprojw), PARAM(attprojb), B*T, C, C); - fused_residual_forward5(ACT(residual2), ACT(ln2), ACT(ln2_mean), ACT(ln2_rstd), residual, ACT(attproj), PARAM(ln2w), PARAM(ln2b), B*T, C); - matmul_forward(ACT(fch_gelu), ACT(ln2), PARAM(fcw), PARAM(fcb), B*T, C, 4*C, ACT(fch), model->gelu_fusion); - matmul_forward(ACT(fcproj), ACT(fch_gelu), PARAM(fcprojw), PARAM(fcprojb), B*T, 4*C, C); + matmul_forward(ACT(attproj), ACT(atty), PARAM(attprojw), PARAM(attprojb), B*T, C, C); + fused_residual_forward5(ACT(residual2), ACT(ln2), ACT(ln2_mean), ACT(ln2_rstd), residual, ACT(attproj), PARAM(ln2w), PARAM(ln2b), B*T, C); + matmul_forward(ACT(fch_gelu), ACT(ln2), PARAM(fcw), PARAM(fcb), B*T, C, 4*C, ACT(fch), model->gelu_fusion); + matmul_forward(ACT(fcproj), ACT(fch_gelu), PARAM(fcprojw), PARAM(fcprojb), B*T, 4*C, C); } - l = old_l; + l = backward_l; } - */ } encoder_backward(PGRAD(wte), PGRAD(wpe), scratchX_HUGE, model->workload_indices, model->bucket_info, From c68bb9f9a7270b044ed14f0e8d574e0ea70d928f Mon Sep 17 00:00:00 2001 From: ademeure Date: Wed, 18 Sep 2024 17:15:44 +0000 Subject: [PATCH 21/27] WIP multi-gpu and new global norm (doesn't work yet) --- llmc/adamw.cuh | 101 +++++++++++++++------------------- llmc/cuda_common.h | 2 + llmc/global_norm.cuh | 126 +++++++++++++++++++++---------------------- llmc/tensor.cuh | 31 ++++++----- train_gpt2.cu | 51 +++++------------- 5 files changed, 137 insertions(+), 174 deletions(-) diff --git a/llmc/adamw.cuh b/llmc/adamw.cuh index f4df77342..9a15e8377 100644 --- a/llmc/adamw.cuh +++ b/llmc/adamw.cuh @@ -15,19 +15,18 @@ __device__ float lerp(float start, float end, float weight) { return fma(weight, end, fma(-weight, start, start)); } -// always sizeof(param) <= sizeof(grad) <= sizeof(opt/master) <= sizeof(float) template -__device__ size_t adamw_update_part(TensorGPU param_tensor, size_t idx, int spec_id, size_t current_start, size_t current_end, size_t stride, +__device__ size_t adamw_update_part(TensorGPU param_tensor, + size_t idx, size_t current_start, size_t current_end, size_t stride, unsigned int seed, unsigned int shard_idx, TensorGPU grad_tensor, TensorGPU master_tensor, TensorGPU opt_m_tensor, TensorGPU opt_v_tensor, - unsigned int seed, int num_params_tensors, size_t num_parameters, size_t num_opt_parameters, float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float wd, float grad_scale, int t) { - auto out_master128 = new_tensor128(master_tensor, true); + auto out_master128 = 
new_tensor128(master_tensor, true); auto out_opt_m128 = new_tensor128(opt_m_tensor, true); auto out_opt_v128 = new_tensor128(opt_v_tensor, true); auto out_param128 = new_tensor128(param_tensor); - __syncthreads(); // todo - hopefully results in better memory access patterns => TBC + __syncthreads(); // todo - hopefully improves memory locality while (idx < current_end) { unsigned int random = get_random_noise(seed, idx); @@ -36,17 +35,20 @@ __device__ size_t adamw_update_part(TensorGPU param_tensor, size_t idx, tensor128 opt_m128; tensor128 opt_v128; tensor128 master128; - - size_t offset = idx - current_start; int next_idx[TT::NUM_TYPES_PARAM] = {0}; int current_idx[TT::NUM_TYPES_PARAM] = {0}; + // todo - assuming either DPP or ZeRO 1 now (sharded optimizer/master, unsharded gradients/parameters) + // offset is 32-bit (checked <= elements in add_tensor_spec) + unsigned int offset = idx - current_start; + unsigned int unsharded_offset = offset + shard_idx * opt_v_tensor.num_elements; + // this implementation has a stride causing sparse reads/writes and bank conflicts for non-FP8 - // todo - compare performance with a version that uses 128-bit for FP32, 64-bit for BF16, 32-bit for FP8 + // todo - compare performance with a version that uses 128-bit for FP32, 64-bit for BF16, 32-bit for FP8 (probably much faster) #pragma unroll - for (int i = 0; i < 16; i += 4, offset += 4) { - if (current_idx[PARAMETER] == 0) param128 = load_tensor128(param_tensor, offset); - if (current_idx[PARAMETER_GRAD] == 0) grad128 = load_tensor128(grad_tensor, offset, false, true); + for (int i = 0; i < 16; i += 4, offset += 4, unsharded_offset += 4) { + if (current_idx[PARAMETER] == 0) param128 = load_tensor128(param_tensor, unsharded_offset); + if (current_idx[PARAMETER_GRAD] == 0) grad128 = load_tensor128(grad_tensor, unsharded_offset, false, true); if (current_idx[PARAMETER_OPT_M] == 0) opt_m128 = load_tensor128(opt_m_tensor, offset, false,true); if (current_idx[PARAMETER_OPT_V] == 0) opt_v128 = load_tensor128(opt_v_tensor, offset, false, true); if (current_idx[PARAMETER_MASTER] == 0 && use_master_weights) master128 = load_tensor128(master_tensor, offset, false, true); @@ -71,11 +73,10 @@ __device__ size_t adamw_update_part(TensorGPU param_tensor, size_t idx, } float param = old_param - (learning_rate * (m / (sqrtf(v) + eps) + wd * old_param)); - out_param128.set_stochastic(current_idx[PARAMETER] + k, param, random); - float new_param = out_param128.get(current_idx[PARAMETER] + k); if constexpr (use_master_weights) { out_master128.set(current_idx[PARAMETER_MASTER] + k, param); } + out_param128.set_stochastic(current_idx[PARAMETER] + k, param, random); } next_idx[PARAMETER] = (i + 4) % (16 / sizeof(Tparam)); next_idx[PARAMETER_GRAD] = (i + 4) % (16 / sizeof(Tgrad)); @@ -83,7 +84,7 @@ __device__ size_t adamw_update_part(TensorGPU param_tensor, size_t idx, next_idx[PARAMETER_OPT_V] = (i + 4) % (16 / sizeof(Tv)); next_idx[PARAMETER_MASTER] = (i + 4) % (16 / sizeof(Tmaster)); - if (next_idx[PARAMETER] == 0) out_param128.store(offset - current_idx[PARAMETER]); + if (next_idx[PARAMETER] == 0) out_param128.store(unsharded_offset - current_idx[PARAMETER]); if (next_idx[PARAMETER_OPT_M] == 0) out_opt_m128.store(offset - current_idx[PARAMETER_OPT_M]); if (next_idx[PARAMETER_OPT_V] == 0) out_opt_v128.store(offset - current_idx[PARAMETER_OPT_V]); if constexpr (use_master_weights) { @@ -101,78 +102,60 @@ __device__ size_t adamw_update_part(TensorGPU param_tensor, size_t idx, } template -__global__ void 
adamw_full_update(TensorSpec* specs, unsigned int seed, - int num_params_tensors, size_t num_parameters, size_t num_opt_parameters, - float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, - float eps, float weight_decay, float grad_scale, int t) { +__global__ void adamw_update_everything(int num_params_tensors, unsigned int seed , unsigned int shard_idx, + float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, + float eps, float weight_decay, float grad_scale, int t) { // ... constexpr size_t block_size = 64; // 64 ==> 4KiB chunks with iteration_size=16 for FP32 opt/master constexpr size_t iteration_size = 16; size_t idx = (blockIdx.x * block_size * iteration_size) + (threadIdx.x * iteration_size); - size_t stride = gridDim.x * blockDim.x * iteration_size; + unsigned int stride = gridDim.x * blockDim.x * iteration_size; int spec_id = 0; - TensorSpec* grad_specs = specs + num_params_tensors; - TensorSpec* opt_m_specs = specs + 2 * num_params_tensors; - TensorSpec* opt_v_specs = specs + 3 * num_params_tensors; - TensorSpec* master_specs = use_master_weights ? specs + 4 * num_params_tensors : opt_m_specs; - - TensorSpec opt_spec = opt_v_specs[spec_id]; - size_t current_start = opt_spec.offset / sizeof(float); - size_t current_end = current_start + opt_spec.num_elements; + TensorSpec* opt_v_specs = tensor_specs_ptr + 3 * num_params_tensors; + TensorSpec opt_v_spec = opt_v_specs[spec_id]; + size_t current_start = opt_v_spec.element_start_end.x; + size_t current_end = opt_v_spec.element_start_end.y; while (true) { - // todo - performance analysis/optimisation! (impact of using step 0?) while (idx >= current_end) { spec_id++; if (spec_id >= num_params_tensors) { return; } - opt_spec = opt_v_specs[spec_id]; - current_start = opt_spec.offset / sizeof(float); - current_end = current_start + opt_spec.num_elements; - - while (idx < current_start) { - idx += stride; - } + opt_v_spec = opt_v_specs[spec_id]; + current_start = opt_v_spec.element_start_end.x; + current_end = opt_v_spec.element_start_end.y; } - opt_spec = opt_v_specs[spec_id]; - current_start = opt_spec.offset / sizeof(float); - current_end = current_start + opt_spec.num_elements; - float wd = (opt_spec.flags & TENSOR_2D) ? weight_decay : 0.0f; + TensorSpec param_spec = tensor_specs_ptr[spec_id]; + float wd = (param_spec.flags & TENSOR_2D) ? weight_decay : 0.0f; - TensorGPU grad_tensor = grad_specs[spec_id]; - TensorGPU master_tensor = master_specs[spec_id]; - TensorGPU opt_m_tensor = opt_m_specs[spec_id]; - TensorGPU opt_v_tensor = opt_spec; + TensorGPU master_tensor = use_master_weights ? 
tensor_specs_ptr[spec_id + 4*num_params_tensors] : opt_v_spec; + TensorGPU grad_tensor = tensor_specs_ptr[spec_id + 1*num_params_tensors]; + TensorGPU opt_m_tensor = tensor_specs_ptr[spec_id + 2*num_params_tensors];; + TensorGPU opt_v_tensor = opt_v_spec; - if (specs[spec_id].data_type == DType::FP32) { - TensorGPU param_tensor = specs[spec_id]; - idx = adamw_update_part( - param_tensor, idx, spec_id, current_start, current_end, stride, + if (param_spec.data_type == DType::FP32) { + idx = adamw_update_part((TensorGPU)param_spec, + idx, current_start, current_end, stride, seed, shard_idx, grad_tensor, master_tensor, opt_m_tensor, opt_v_tensor, - seed, num_params_tensors, num_parameters, num_opt_parameters, learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, wd, grad_scale, t); - } else if (specs[spec_id].data_type == DType::BF16) { - TensorGPU<__nv_bfloat16> param_tensor = specs[spec_id]; - idx = adamw_update_part( - param_tensor, idx, spec_id, current_start, current_end, stride, + } else if (param_spec.data_type == DType::BF16) { + idx = adamw_update_part((TensorGPU<__nv_bfloat16>)param_spec, + idx, current_start, current_end, stride, seed, shard_idx, grad_tensor, master_tensor, opt_m_tensor, opt_v_tensor, - seed, num_params_tensors, num_parameters, num_opt_parameters, learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, wd, grad_scale, t); - } else if (specs[spec_id].data_type == DType::FP8E4M3) { - TensorGPU<__nv_fp8_e4m3> param_tensor = specs[spec_id]; - idx = adamw_update_part( - param_tensor, idx, spec_id, current_start, current_end, stride, + } else if (param_spec.data_type == DType::FP8E4M3) { + idx = adamw_update_part((TensorGPU<__nv_fp8_e4m3>)param_spec, + idx, current_start, current_end, stride, seed, shard_idx, grad_tensor, master_tensor, opt_m_tensor, opt_v_tensor, - seed, num_params_tensors, num_parameters, num_opt_parameters, learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, wd, grad_scale, t); } else { - assert(false); // TODO (no FP16 to avoid compile time increase but it'd be trivial to add) + assert(false); // TODO (no FP16 to avoid compile time increase but trivial to add here) } } } diff --git a/llmc/cuda_common.h b/llmc/cuda_common.h index 44f19c608..545357ec8 100644 --- a/llmc/cuda_common.h +++ b/llmc/cuda_common.h @@ -38,8 +38,10 @@ extern cudaStream_t main_stream; // this needs to be defines rather than queried to be used for __launch_bounds__ #if __CUDA_ARCH__ == 800 || __CUDA_ARCH__ >= 900 #define MAX_1024_THREADS_BLOCKS 2 +#define MAX_WARPS 64 #else #define MAX_1024_THREADS_BLOCKS 1 +#define MAX_WARPS 48 #endif // convenience macro for calculating grid/block dimensions for kernels diff --git a/llmc/global_norm.cuh b/llmc/global_norm.cuh index 968171a81..a47006506 100644 --- a/llmc/global_norm.cuh +++ b/llmc/global_norm.cuh @@ -1,5 +1,7 @@ +// TODO - BUGGED - just committing my WIP, not sure why grad norm is zero, probably something silly! 
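The rewritten AdamW kernel above no longer receives per-tensor pointers; each thread walks the tensor specs until it finds the spec whose element range contains its current flat index, then dispatches on the parameter dtype. A minimal standalone sketch of that range search, using a hypothetical tensor_end array of cumulative end offsets in place of the real TensorSpec fields:

#include <cstddef>
#include <cstdio>

// Sketch: tensors laid out back-to-back in one flat element space; find the tensor
// containing a flat index by scanning cumulative end offsets (a plain array here,
// standing in for the per-spec end elements used by the kernel above).
int find_tensor_for_index(const size_t* tensor_end, int num_tensors, size_t idx) {
    int spec_id = 0;
    while (spec_id < num_tensors && idx >= tensor_end[spec_id]) {
        spec_id++;  // advance until idx falls before the end of this tensor
    }
    return (spec_id < num_tensors) ? spec_id : -1;  // -1: index past the last tensor
}

int main() {
    size_t tensor_end[3] = {100, 150, 350};  // tensors of 100, 50 and 200 elements
    std::printf("%d %d %d\n",
                find_tensor_for_index(tensor_end, 3, 0),    // 0
                find_tensor_for_index(tensor_end, 3, 120),  // 1
                find_tensor_for_index(tensor_end, 3, 349)); // 2
    return 0;
}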
+ /* -Global norm, used in gradient clipping +Global norm, used in gralldient clipping */ #include #include @@ -11,79 +13,71 @@ Global norm, used in gradient clipping // ---------------------------------------------------------------------------- // CUDA kernels -template -__device__ float global_norm_squared_for_range(const T* data, size_t count) { - size_t index = blockIdx.x * blockDim.x + threadIdx.x; - size_t grid_width = blockDim.x * gridDim.x; - float accumulator = 0.f; - for(size_t i = index; i < count; i += grid_width) { - accumulator += (float)data[i] * (float)data[i]; - } - // block-level reduce - return blockReduce(accumulator); -} +// currently assumes all gradients are the same type (simplified adamw_update_everything) +// ZeRO 1 should use shard_idx, while DPP and ZeRO 2/3 should simply set it to 0 +template +__global__ void __launch_bounds__(256, MAX_WARPS/8) global_norm_tensors_kernel(float* out, int num_params_tensors, unsigned int shard_idx) { + float grad_norm_accumulator = 0.f; -template -__global__ void global_norm_squared_kernel(float* out, const T* data, size_t count, ptrdiff_t stride) { - float block_sum = global_norm_squared_for_range(data + blockIdx.y * stride, count); - // each block accumulates its partial sum to out[out_index] - // we want to avoid using atomic add here so we combine this kernel with another kernel call - // that sums up the partial block sums - if(threadIdx.x == 0) { - size_t out_index = blockIdx.y * gridDim.x + blockIdx.x; - out[out_index] = out[out_index] + block_sum; - } -} + constexpr size_t block_size = 256; + constexpr size_t iteration_size = Packed128::size; + size_t idx = (blockIdx.x * block_size * iteration_size) + (threadIdx.x * iteration_size); + unsigned int stride = gridDim.x * blockDim.x * iteration_size; + + int spec_id = 0; + TensorSpec* grad_specs = tensor_specs_ptr + num_params_tensors; + TensorSpec* opt_v_specs = tensor_specs_ptr + 3 * num_params_tensors; + + TensorSpec opt_v_spec = opt_v_specs[spec_id]; + size_t current_start = opt_v_spec.element_start_end.x; + size_t current_end = opt_v_spec.element_start_end.y; + + while (true) { + while (idx >= current_end) { + // todo - check performance, misses probably okay if they reduce the tail effect + // (fastest block/SM "prefetches" for the slower ones) + // but tiny tensors back-to-back might be inefficient + spec_id++; + if (spec_id >= num_params_tensors) { + return; + } + opt_v_spec = opt_v_specs[spec_id]; + current_start = opt_v_spec.element_start_end.x; + current_end = opt_v_spec.element_start_end.y; + } + + // offset is 32-bit (checked <=4B elements in add_tensor_spec) + unsigned int offset = (idx - current_start) + (shard_idx * opt_v_spec.num_elements); + TensorGPU grad_tensor = grad_specs[spec_id]; -__global__ void global_norm_aggregate_kernel(float* out, size_t grid_size) { - size_t index = threadIdx.x; - // grab block sums from the previous kernel, use 0. as the neutral sum element - float block_sum = (index < grid_size) ? 
out[index] : 0.f; - float sum = blockReduce(block_sum); - if(threadIdx.x == 0) { - out[0] = sum; // out[0] ends up with the final norm squared + __syncthreads(); // todo - hopefully improves memory locality + while (idx < current_end) { + auto grad128 = load_tensor128(grad_tensor, offset, false, true); + for (int k = 0; k < grad_tensor.num_per_128(); k++) { + float grad = grad128.get(k); + grad_norm_accumulator += grad * grad; + } + idx += stride; + offset += stride; + } } + out[blockIdx.x] = blockReduce(grad_norm_accumulator);; } // ---------------------------------------------------------------------------- // kernel launcher -// Helper function determines the maximum number of block sums -int get_max_num_block_sums(int* num_slices_all, int numel) { - // NOTE: this needs to be kept in sync with `global_norm_squared` below. - const int block_size = 512; +template +void global_norm_tensors(float* out, int gpu_process_rank, cudaStream_t stream=main_stream) { + const int block_size = 256; const int grid_size = deviceProp.maxThreadsPerMultiProcessor * deviceProp.multiProcessorCount / block_size; - assert(grid_size > 0); - int max_num_block_sums = 0; - for (int i = 0; i < numel; i++) { - int num_slices = num_slices_all[i]; - const int gx = CEIL_DIV(grid_size, num_slices); - const int gy = num_slices; - max_num_block_sums = max(max_num_block_sums, gx * gy); - } - - return max_num_block_sums; -} - -template -void global_norm_squared(float* out, const T* values, size_t count, ptrdiff_t stride, int num_slices, int max_num_block_sums, bool reset, cudaStream_t stream=main_stream) { - const int block_size = 512; - // launch just enough blocks to fill the grid. deliberately no DIV_CEIL. - // having one block less than possible is a tiny performance hit, having - // one block too many is catastrophic, since it only can start once all the other - // blocks finish. anyway, I think cuda_threads_per_SM should be a multiple of 512 - // on all gpus, so the division really is going to be exact. 
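Both the removed global_norm_squared path and the new per-tensor kernel reduce in two stages: every block writes a single partial sum of squares, and a second pass adds the per-block values. A rough sketch of that shape, ignoring the tensor-spec iteration and assuming the blockReduce/warpReduceSum helpers from cuda_utils.cuh; block_partial_sums is an illustrative buffer name:

// Stage 1: every block reduces its grid-stride slice of the data to one float.
__global__ void partial_sum_squares_kernel(float* block_partial_sums,
                                           const float* data, size_t count) {
    float acc = 0.0f;
    for (size_t i = blockIdx.x * (size_t)blockDim.x + threadIdx.x; i < count;
         i += (size_t)gridDim.x * blockDim.x) {
        acc += data[i] * data[i];
    }
    float block_sum = blockReduce<warpReduceSum>(acc);
    if (threadIdx.x == 0) {
        block_partial_sums[blockIdx.x] = block_sum;  // exactly one value per block
    }
}
// Stage 2 (global_sum_deterministic above) then adds block_partial_sums[0..gridDim.x)
// in a single block, leaving the final squared norm in element 0.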
- const int grid_size = deviceProp.maxThreadsPerMultiProcessor * deviceProp.multiProcessorCount / block_size; - assert(grid_size > 0); // gives a better error than letting the call below fail - - const int gx = CEIL_DIV(grid_size, num_slices); - const int gy = num_slices; - assert(gx * gy < 1024); // we want to later accumulate the block sums in a single block + int num_params_tensors = tensors_start[PARAMETER+1]; + int num_shards_opt = tensor_specs[tensors_start[PARAMETER_OPT_M]].num_shards; + int num_shards_grad = tensor_specs[tensors_start[PARAMETER_GRAD]].num_shards; + int num_shards = num_shards_opt / num_shards_grad; // should work for both DPP and ZeRO 1/2/3 + int shard_idx = gpu_process_rank % num_shards; - if (reset) { - cudaCheck(cudaMemsetAsync(out, 0, max_num_block_sums * sizeof(float), stream)); - } - global_norm_squared_kernel<<>>(out, values, count, stride); - cudaCheck(cudaGetLastError()); -} + global_norm_tensors_kernel<<>>(out, num_params_tensors, shard_idx); + global_sum_deterministic(out, out, grid_size, stream); +} \ No newline at end of file diff --git a/llmc/tensor.cuh b/llmc/tensor.cuh index 71f929399..d366f1289 100644 --- a/llmc/tensor.cuh +++ b/llmc/tensor.cuh @@ -79,36 +79,35 @@ struct TensorGPU { unsigned int* absmax_ptr = NULL; size_t num_elements = 0; - static constexpr bool no_scaling = (sizeof(ElementType) != 1); // todo - this prevents scaling FP16 - bool is_null() const { return (data_ptr == NULL); } - bool enabled() const { return (data_ptr != NULL); } - static __device__ __host__ TensorGPU from(ElementType* ptr=nullptr) { TensorGPU tmp; tmp.data_ptr = ptr; return tmp; } - template __device__ __host__ T* as() { return reinterpret_cast(data_ptr); } - __device__ __host__ operator ElementType*() const { return data_ptr; } - __device__ __host__ ElementType& operator[](size_t index) { return data_ptr[index]; } - __device__ __host__ const ElementType& operator[](size_t index) const { return data_ptr[index]; } - __device__ __host__ int num_per_128() const { return sizeof(int4) / sizeof(ElementType); } + __device__ __host__ bool is_null() const { + return (data_ptr == NULL); + } + __device__ __host__ bool enabled() const { + return (data_ptr != NULL); + } + + static constexpr bool no_scaling = (sizeof(ElementType) != 1); // todo - this prevents scaling FP16 __device__ __host__ float get_scalar(size_t index, bool disable_scaling=no_scaling) const { #ifdef FAKE_FP8 @@ -155,8 +154,8 @@ extern TensorGPU null_tensorX; // they all implicitly refer to this (in tensor_specs[] and tensor_specs_gpu[] for now) with the id // and these other classes are created by converting from this one (sometimes implicitly) struct TensorSpec { - int id; char* ptr; + int id; char name[16]; TT tensor_type; @@ -168,6 +167,9 @@ struct TensorSpec { short num_shards; short remaining_layers; + // explicit as performance optimization for optimizer critical path + ulonglong2 element_start_end; + template __host__ __device__ operator T*() const { // todo - sanity check DType matches T @@ -257,7 +259,7 @@ TensorSpec get_tensor(int spec_index, TT tensor_type, int layer) { // this can only be called at initialisation time, once tensor_specs has been uploaded to the GPU, it is fixed in stone int add_tensor_spec(const char* name, size_t total_elements, size_t num_shards, DType data_type, int copy_offset_from=-1, int flags=TFlags::NONE, TT tensor_type=TT::DEFAULT) { - assert(num_tensor_specs < 16*1024); + assert(num_tensor_specs < MAX_TENSORS); assert((total_elements % num_shards) == 0); TensorSpec* 
spec = &tensor_specs[num_tensor_specs]; @@ -268,6 +270,10 @@ int add_tensor_spec(const char* name, size_t total_elements, size_t num_shards, spec->data_type = data_type; spec->flags = flags; + // parameter tensors must fit in a 32-bit unsigned integer (used as a performance optimisation in some kernels) + // todo - either 1) 32-bit everywhere (with a DEFINE?), 2) 64-bit everywhere despite the small performance impact + assert(total_elements < 4UL*1024*1024*1024 || spec->tensor_type == TT::MULTIUSE); + spec->num_elements = total_elements / num_shards; spec->num_shards = num_shards; spec->remaining_layers = 0; @@ -411,6 +417,7 @@ public: scaling = scaling && !disable_scaling; if (scaling) { + // using __restrict__ here should allow the compiler to cache/reuse this in loops etc. const float* __restrict__ ptr_restricted = tensor.scale_descale_ptr; scale = ptr_restricted[0]; descale = ptr_restricted[1]; @@ -525,7 +532,7 @@ public: // if this is the end of the kernel, the compiler puts a conditional EXIT right after BAR // but this way the EXIT is right before the barrier which frees the warps slightly quicker bool done = (warp_id != 0); - if (done && exit) asm volatile("exit;"); + if (done && exit) asm volatile("exit;"); // todo - does this help enough to be worth it? __syncthreads(); if (done && !exit) return true; diff --git a/train_gpt2.cu b/train_gpt2.cu index 3269b6f99..0189573cf 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -344,9 +344,12 @@ void gpt2_allocate(GPT2 *model) { } // Set the GPU pointer for each tensor spec (so we don't need to know the base and the offset) + // also specify 1st and end elements explicitly to optimise kernels iterating over the tensors for (size_t i = 0; i < num_tensor_specs; i++) { TensorSpec* spec = &tensor_specs[i]; spec->ptr = model->tensor_memory[spec->tensor_type] + spec->offset; + spec->element_start_end.x = spec->offset / sizeof_dtype(spec->data_type); + spec->element_start_end.y = spec->element_start_end.x + spec->num_elements; } // we are finished creating the tensors specs and copy them to the GPU (they are effectively read-only) @@ -959,44 +962,21 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int float gpt2_calculate_grad_norm(GPT2 *model, MultiGpuConfig* multi_gpu_config) { NVTX_RANGE_FN(); - floatX* grads_memory = (floatX*)model->tensor_memory[PARAMETER_GRAD]; - // repurposing this buffer (which isn't needed now) to write grad norm into it float* grad_norm_squared = MULTI_0(output_scratch_fp32); float grad_norm_squared_cpu = 0.0f; - int num_slices[2] = {1, model->config.num_layers}; - int max_num_block_sums = get_max_num_block_sums(num_slices, 2); - /*if (multi_gpu_config->zero_stage == 1) { - // because of the ncclReduceScatter() in backward, - // grads_memory only contains the averaged gradients at the local shards, - // so we only calculate the grad norm at the grads_memory belonging to the local shards - for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { - ShardInfo tensor = gpt2_get_tensor_at_layer(model, 0, i); - ShardInfo shard = multi_gpu_get_shard_offset(tensor.size, multi_gpu_config, 1); - ptrdiff_t offset = tensor.offset + shard.offset; - bool is_first_pass = (i == 0); - if((i < 2 || i > 13)) { - global_norm_squared(grad_norm_squared, grads_memory + offset, shard.size, 0, 1, - max_num_block_sums, is_first_pass, main_stream); - } else { - global_norm_squared(grad_norm_squared, grads_memory + offset, shard.size, tensor.size, model->config.num_layers, - max_num_block_sums, is_first_pass, main_stream); 
- } - } - global_sum_deterministic(grad_norm_squared, grad_norm_squared, max_num_block_sums, main_stream); + // automagically handles everything including sharding for ZeRO 1/2/3 + global_norm_tensors(grad_norm_squared, multi_gpu_config->process_rank, main_stream); + #if MULTI_GPU + if (multi_gpu_config->zero_stage >= 1) { // further sum the (partial) squared norm across all GPUs ncclCheck(ncclAllReduce(grad_norm_squared, grad_norm_squared, sizeof(float), ncclFloat, ncclSum, multi_gpu_config->nccl_comm, main_stream)); -#endif - } else*/ { - // in regular DDP, backward has averaged the gradients across all GPUs - // so each GPU can compute the squared norm over the whole grad vector, with no added comms needed - global_norm_squared(grad_norm_squared, grads_memory, model->num_parameters, 0, 1, max_num_block_sums, true, main_stream); - global_sum_deterministic(grad_norm_squared, grad_norm_squared, max_num_block_sums, main_stream); } - cudaCheck(cudaMemcpy(&grad_norm_squared_cpu, grad_norm_squared, sizeof(float), cudaMemcpyDeviceToHost)); +#endif + cudaCheck(cudaMemcpy(&grad_norm_squared_cpu, grad_norm_squared, sizeof(float), cudaMemcpyDeviceToHost)); float grad_norm_cpu = sqrtf(grad_norm_squared_cpu); return grad_norm_cpu; } @@ -1022,19 +1002,16 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo float beta2_correction = 1.0f - powf(beta2, t); unsigned int seed = random_u32(&model->rng_state); int num_shards = tensor_specs[tensors_start[PARAMETER_OPT_M]].num_shards; + int shard_idx = multi_gpu_config->process_rank % num_shards; // todo - currently assuming ZeRO 1 or DPP const int block_size = 64; const int grid_size = deviceProp.maxThreadsPerMultiProcessor * deviceProp.multiProcessorCount / block_size; if (model->use_master_weights) { - adamw_full_update<<>>( - tensor_specs_gpu, seed, tensors_start[PARAMETER+1], - model->num_parameters, model->num_parameters / num_shards, - learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay, grad_scale, t); + adamw_update_everything<<>>(tensors_start[PARAMETER+1], seed, shard_idx, + learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay, grad_scale, t); } else { - adamw_full_update<<>>( - tensor_specs_gpu, seed, tensors_start[PARAMETER+1], - model->num_parameters, model->num_parameters / num_shards, - learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay, grad_scale, t); + adamw_update_everything<<>>(tensors_start[PARAMETER+1], seed, shard_idx, + learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay, grad_scale, t); } // AdamW update From 41caf5f6f35a2f53da1980913ec54368f2dedf16 Mon Sep 17 00:00:00 2001 From: ademeure Date: Wed, 18 Sep 2024 18:35:40 +0000 Subject: [PATCH 22/27] fixed global norm (not bit identical to previous implementation but I think it's correct) --- llmc/global_norm.cuh | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/llmc/global_norm.cuh b/llmc/global_norm.cuh index a47006506..05e33d829 100644 --- a/llmc/global_norm.cuh +++ b/llmc/global_norm.cuh @@ -39,12 +39,15 @@ __global__ void __launch_bounds__(256, MAX_WARPS/8) global_norm_tensors_kernel(f // but tiny tensors back-to-back might be inefficient spec_id++; if (spec_id >= num_params_tensors) { - return; + break; } opt_v_spec = opt_v_specs[spec_id]; current_start = opt_v_spec.element_start_end.x; current_end = opt_v_spec.element_start_end.y; } + if (spec_id >= num_params_tensors) { + break; // goto would avoid 
this but I don't want to go to hell + } // offset is 32-bit (checked <=4B elements in add_tensor_spec) unsigned int offset = (idx - current_start) + (shard_idx * opt_v_spec.num_elements); @@ -61,7 +64,10 @@ __global__ void __launch_bounds__(256, MAX_WARPS/8) global_norm_tensors_kernel(f offset += stride; } } - out[blockIdx.x] = blockReduce(grad_norm_accumulator);; + float output = blockReduce(grad_norm_accumulator); + if (threadIdx.x == 0) { + out[blockIdx.x] = output; + } } // ---------------------------------------------------------------------------- From f769bdc731bc1c02f94b45d165e3c113acee4f73 Mon Sep 17 00:00:00 2001 From: ademeure Date: Wed, 18 Sep 2024 20:41:41 +0000 Subject: [PATCH 23/27] Optimized adam/global_norm using new "gpu_tensor_end_element" array --- llmc/adamw.cuh | 37 +++++++++++++++++++---------------- llmc/cuda_common.h | 11 +++++++---- llmc/global_norm.cuh | 46 ++++++++++++++++++++++---------------------- llmc/tensor.cuh | 13 +++++++------ train_gpt2.cu | 15 +++++++++++---- 5 files changed, 68 insertions(+), 54 deletions(-) diff --git a/llmc/adamw.cuh b/llmc/adamw.cuh index 9a15e8377..14428061d 100644 --- a/llmc/adamw.cuh +++ b/llmc/adamw.cuh @@ -111,30 +111,33 @@ __global__ void adamw_update_everything(int num_params_tensors, unsigned int see size_t idx = (blockIdx.x * block_size * iteration_size) + (threadIdx.x * iteration_size); unsigned int stride = gridDim.x * blockDim.x * iteration_size; - int spec_id = 0; - TensorSpec* opt_v_specs = tensor_specs_ptr + 3 * num_params_tensors; - TensorSpec opt_v_spec = opt_v_specs[spec_id]; - size_t current_start = opt_v_spec.element_start_end.x; - size_t current_end = opt_v_spec.element_start_end.y; + int opt_m_spec_id = 2 * num_params_tensors; + int last_opt_m_id = 3 * num_params_tensors - 1; + size_t current_end = tensor_end_element_ptr[opt_m_spec_id]; while (true) { while (idx >= current_end) { - spec_id++; - if (spec_id >= num_params_tensors) { - return; - } - opt_v_spec = opt_v_specs[spec_id]; - current_start = opt_v_spec.element_start_end.x; - current_end = opt_v_spec.element_start_end.y; + opt_m_spec_id++; + if (opt_m_spec_id > last_opt_m_id) break; + + #if __CUDA_ARCH__ < 800 + current_end = tensor_end_element_ptr[opt_m_spec_id]; + #else + // on A100+ we can prefetch 256B (32 end values) into the L2 + asm("ld.global.L1::evict_last.L2::256B.u64 {%0}, [%1];" : "=l"(current_end) : "l"(tensor_end_element_ptr + opt_m_spec_id)); + #endif } + if (opt_m_spec_id > last_opt_m_id) break; + int spec_id = opt_m_spec_id - 2 * num_params_tensors; + size_t current_start = tensor_specs_ptr[opt_m_spec_id].start_element; TensorSpec param_spec = tensor_specs_ptr[spec_id]; - float wd = (param_spec.flags & TENSOR_2D) ? weight_decay : 0.0f; - - TensorGPU master_tensor = use_master_weights ? tensor_specs_ptr[spec_id + 4*num_params_tensors] : opt_v_spec; TensorGPU grad_tensor = tensor_specs_ptr[spec_id + 1*num_params_tensors]; - TensorGPU opt_m_tensor = tensor_specs_ptr[spec_id + 2*num_params_tensors];; - TensorGPU opt_v_tensor = opt_v_spec; + TensorGPU opt_m_tensor = tensor_specs_ptr[spec_id + 2*num_params_tensors]; + TensorGPU opt_v_tensor = tensor_specs_ptr[spec_id + 3*num_params_tensors]; + TensorGPU master_tensor = use_master_weights ? tensor_specs_ptr[spec_id + 4*num_params_tensors] : opt_m_tensor; + + float wd = (param_spec.flags & TENSOR_2D) ? 
weight_decay : 0.0f; if (param_spec.data_type == DType::FP32) { idx = adamw_update_part((TensorGPU)param_spec, diff --git a/llmc/cuda_common.h b/llmc/cuda_common.h index 545357ec8..a8b4607a3 100644 --- a/llmc/cuda_common.h +++ b/llmc/cuda_common.h @@ -34,14 +34,17 @@ extern cudaStream_t main_stream; // Defining here like this possibly allows the compiler to optimize better #define WARP_SIZE 32U -// try to make sure that 2 blocks fit on A100/H100 to maximise latency tolerance +// optimise the number of blocks that fit to maximise latency tolerance // this needs to be defines rather than queried to be used for __launch_bounds__ -#if __CUDA_ARCH__ == 800 || __CUDA_ARCH__ >= 900 +#if __CUDA_ARCH__ >= 900 || __CUDA_ARCH__ == 800 || __CUDA_ARCH__ <= 700 #define MAX_1024_THREADS_BLOCKS 2 -#define MAX_WARPS 64 +#define MAX_THREADS 2048 // H100/A100/V100/Pascal/Maxwell(/Blackwell?) +#elif __CUDA_ARCH__ == 750 +#define MAX_1024_THREADS_BLOCKS 1 +#define MAX_THREADS 1024 // Turing #else #define MAX_1024_THREADS_BLOCKS 1 -#define MAX_WARPS 48 +#define MAX_THREADS 1536 // Consumer Ampere & Ada Lovelace #endif // convenience macro for calculating grid/block dimensions for kernels diff --git a/llmc/global_norm.cuh b/llmc/global_norm.cuh index 05e33d829..7fede4c91 100644 --- a/llmc/global_norm.cuh +++ b/llmc/global_norm.cuh @@ -16,42 +16,39 @@ Global norm, used in gralldient clipping // currently assumes all gradients are the same type (simplified adamw_update_everything) // ZeRO 1 should use shard_idx, while DPP and ZeRO 2/3 should simply set it to 0 template -__global__ void __launch_bounds__(256, MAX_WARPS/8) global_norm_tensors_kernel(float* out, int num_params_tensors, unsigned int shard_idx) { +__global__ void __launch_bounds__(256, MAX_THREADS/256) global_norm_tensors_kernel(float* out, int num_params_tensors, unsigned int shard_idx) { float grad_norm_accumulator = 0.f; constexpr size_t block_size = 256; constexpr size_t iteration_size = Packed128::size; - size_t idx = (blockIdx.x * block_size * iteration_size) + (threadIdx.x * iteration_size); + size_t idx = (blockIdx.x * block_size + threadIdx.x) * iteration_size; unsigned int stride = gridDim.x * blockDim.x * iteration_size; - int spec_id = 0; - TensorSpec* grad_specs = tensor_specs_ptr + num_params_tensors; - TensorSpec* opt_v_specs = tensor_specs_ptr + 3 * num_params_tensors; + int opt_m_spec_id = 2 * num_params_tensors; // opt_m is sharded with ZeRO 1/2/3 + int last_opt_m_id = 3 * num_params_tensors - 1; - TensorSpec opt_v_spec = opt_v_specs[spec_id]; - size_t current_start = opt_v_spec.element_start_end.x; - size_t current_end = opt_v_spec.element_start_end.y; + size_t current_end = tensor_end_element_ptr[opt_m_spec_id]; while (true) { while (idx >= current_end) { - // todo - check performance, misses probably okay if they reduce the tail effect - // (fastest block/SM "prefetches" for the slower ones) - // but tiny tensors back-to-back might be inefficient - spec_id++; - if (spec_id >= num_params_tensors) { - break; - } - opt_v_spec = opt_v_specs[spec_id]; - current_start = opt_v_spec.element_start_end.x; - current_end = opt_v_spec.element_start_end.y; - } - if (spec_id >= num_params_tensors) { - break; // goto would avoid this but I don't want to go to hell + opt_m_spec_id++; + if (opt_m_spec_id > last_opt_m_id) break; + + #if __CUDA_ARCH__ < 800 + current_end = tensor_end_element_ptr[opt_m_spec_id]; + #else + // on A100+ we can prefetch 256B (32 end values) into the L2 + asm("ld.global.L1::evict_last.L2::256B.u64 {%0}, [%1];" : 
"=l"(current_end) : "l"(tensor_end_element_ptr + opt_m_spec_id)); + #endif } + if (opt_m_spec_id > last_opt_m_id) break; // offset is 32-bit (checked <=4B elements in add_tensor_spec) - unsigned int offset = (idx - current_start) + (shard_idx * opt_v_spec.num_elements); - TensorGPU grad_tensor = grad_specs[spec_id]; + size_t current_start = tensor_specs_ptr[opt_m_spec_id].start_element; + unsigned int offset = (idx - current_start) + (shard_idx * tensor_specs_ptr[opt_m_spec_id].num_elements); + + int grad_spec_id = opt_m_spec_id - num_params_tensors; + TensorGPU grad_tensor = tensor_specs_ptr[grad_spec_id]; __syncthreads(); // todo - hopefully improves memory locality while (idx < current_end) { @@ -64,6 +61,7 @@ __global__ void __launch_bounds__(256, MAX_WARPS/8) global_norm_tensors_kernel(f offset += stride; } } + float output = blockReduce(grad_norm_accumulator); if (threadIdx.x == 0) { out[blockIdx.x] = output; @@ -85,5 +83,7 @@ void global_norm_tensors(float* out, int gpu_process_rank, cudaStream_t stream=m int shard_idx = gpu_process_rank % num_shards; global_norm_tensors_kernel<<>>(out, num_params_tensors, shard_idx); + cudaCheck(cudaGetLastError()); global_sum_deterministic(out, out, grid_size, stream); + cudaCheck(cudaGetLastError()); } \ No newline at end of file diff --git a/llmc/tensor.cuh b/llmc/tensor.cuh index d366f1289..55006926d 100644 --- a/llmc/tensor.cuh +++ b/llmc/tensor.cuh @@ -48,10 +48,13 @@ extern TT current_tensor_type; // todo - avoid having this somehow? extern int current_absmax_index; // todo - move into model struct? extern float* gpu_scale_memory; extern unsigned int* gpu_absmax_memory; +// end element of each tensor to optimise iterating through them in kernels +extern size_t* gpu_tensor_end_element; __device__ __constant__ TensorSpec* tensor_specs_ptr; __device__ __constant__ float* gpu_scale_memory_ptr; __device__ __constant__ unsigned int* gpu_absmax_memory_ptr; +__device__ __constant__ size_t* tensor_end_element_ptr; // ---------------------------------------------------------------------------- // Helper macros for accessing tensors in the training loop @@ -154,22 +157,19 @@ extern TensorGPU null_tensorX; // they all implicitly refer to this (in tensor_specs[] and tensor_specs_gpu[] for now) with the id // and these other classes are created by converting from this one (sometimes implicitly) struct TensorSpec { - char* ptr; int id; - + char* ptr; // = model->tensor_memory[tensor_type] + offset char name[16]; TT tensor_type; DType data_type; int flags; - size_t offset; // into base pointer + size_t offset; // into tensor type's base pointer + size_t start_element; // on this shard size_t num_elements; // per shard short num_shards; short remaining_layers; - // explicit as performance optimization for optimizer critical path - ulonglong2 element_start_end; - template __host__ __device__ operator T*() const { // todo - sanity check DType matches T @@ -274,6 +274,7 @@ int add_tensor_spec(const char* name, size_t total_elements, size_t num_shards, // todo - either 1) 32-bit everywhere (with a DEFINE?), 2) 64-bit everywhere despite the small performance impact assert(total_elements < 4UL*1024*1024*1024 || spec->tensor_type == TT::MULTIUSE); + spec->start_element = tensors_elements[spec->tensor_type]; spec->num_elements = total_elements / num_shards; spec->num_shards = num_shards; spec->remaining_layers = 0; diff --git a/train_gpt2.cu b/train_gpt2.cu index 0189573cf..29a0ae370 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -97,6 +97,8 @@ TT 
current_tensor_type = TT::PARAMETER; int current_absmax_index = 0; float* gpu_scale_memory = NULL; unsigned int* gpu_absmax_memory = NULL; +size_t* gpu_tensor_end_element = NULL; + tensorX null_tensorX = {0}; // ---------------------------------------------------------------------------- @@ -344,17 +346,22 @@ void gpt2_allocate(GPT2 *model) { } // Set the GPU pointer for each tensor spec (so we don't need to know the base and the offset) - // also specify 1st and end elements explicitly to optimise kernels iterating over the tensors + // also specify the end elements explicitly to optimise kernels iterating over the tensors + size_t* cpu_tensor_end_element = (size_t*)mallocCheck(sizeof(size_t) * num_tensor_specs + 256); for (size_t i = 0; i < num_tensor_specs; i++) { TensorSpec* spec = &tensor_specs[i]; spec->ptr = model->tensor_memory[spec->tensor_type] + spec->offset; - spec->element_start_end.x = spec->offset / sizeof_dtype(spec->data_type); - spec->element_start_end.y = spec->element_start_end.x + spec->num_elements; + cpu_tensor_end_element[i] = spec->start_element + spec->num_elements; } // we are finished creating the tensors specs and copy them to the GPU (they are effectively read-only) cudaMalloc((void**)&tensor_specs_gpu, sizeof(TensorSpec) * num_tensor_specs); cudaMemcpy(tensor_specs_gpu, tensor_specs, sizeof(TensorSpec) * num_tensor_specs, cudaMemcpyHostToDevice); + // also upload the "end element" array which we use to optimise iterating through tensors in our kernels + // extra 256B so that we can avoid bounds checking when prefetching etc. + cudaMalloc(&gpu_tensor_end_element, sizeof(size_t) * num_tensor_specs + 256); + cudaMemcpy(gpu_tensor_end_element, cpu_tensor_end_element, sizeof(size_t) * num_tensor_specs + 256, cudaMemcpyHostToDevice); + free(cpu_tensor_end_element); printf("number of parameter bytes: %zu MiB\n", tensors_bytes[TT::PARAMETER] / (1024*1024)); printf("number of parameter gradient bytes: %zu MiB\n", tensors_bytes[TT::PARAMETER_GRAD] / (1024*1024)); @@ -372,7 +379,7 @@ void gpt2_allocate(GPT2 *model) { cudaMemcpyToSymbol(tensor_specs_ptr, &tensor_specs_gpu, sizeof(TensorSpec*)); cudaMemcpyToSymbol(gpu_scale_memory_ptr, &gpu_scale_memory, sizeof(float*)); cudaMemcpyToSymbol(gpu_absmax_memory_ptr, &gpu_absmax_memory, sizeof(unsigned int*)); - + cudaMemcpyToSymbol(tensor_end_element_ptr, &gpu_tensor_end_element, sizeof(size_t*)); // ======================= // allocate_state stuff // ======================= From ac0dc6e4ba012c1e87ae353971ff7c27179be646 Mon Sep 17 00:00:00 2001 From: ademeure Date: Thu, 19 Sep 2024 00:52:57 +0000 Subject: [PATCH 24/27] more optimization and cleanup for global_norm (83% DRAM efficiency) and adamw (~70% due to sparse accesses/bank clashes) --- llmc/adamw.cuh | 25 ++++++++++---------- llmc/cuda_utils.cuh | 2 +- llmc/global_norm.cuh | 56 ++++++++++++++++++++++++-------------------- llmc/tensor.cuh | 4 ++-- train_gpt2.cu | 9 +++++-- 5 files changed, 53 insertions(+), 43 deletions(-) diff --git a/llmc/adamw.cuh b/llmc/adamw.cuh index 14428061d..91627877b 100644 --- a/llmc/adamw.cuh +++ b/llmc/adamw.cuh @@ -26,7 +26,7 @@ __device__ size_t adamw_update_part(TensorGPU param_tensor, auto out_opt_v128 = new_tensor128(opt_v_tensor, true); auto out_param128 = new_tensor128(param_tensor); - __syncthreads(); // todo - hopefully improves memory locality + __syncthreads(); // todo - this should improve memory locality while (idx < current_end) { unsigned int random = get_random_noise(seed, idx); @@ -102,32 +102,33 @@ __device__ size_t 
adamw_update_part(TensorGPU param_tensor, } template -__global__ void adamw_update_everything(int num_params_tensors, unsigned int seed , unsigned int shard_idx, +__global__ void adamw_update_everything(int num_params_tensors, int start_tensor, int last_tensor, unsigned int seed , unsigned int shard_idx, float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay, float grad_scale, int t) { // ... - constexpr size_t block_size = 64; // 64 ==> 4KiB chunks with iteration_size=16 for FP32 opt/master - constexpr size_t iteration_size = 16; + constexpr size_t block_size = 64; + constexpr size_t iteration_size = 16; // todo - this causes sparsit and bank clashes for FP32/BF16 loads/stores size_t idx = (blockIdx.x * block_size * iteration_size) + (threadIdx.x * iteration_size); unsigned int stride = gridDim.x * blockDim.x * iteration_size; int opt_m_spec_id = 2 * num_params_tensors; - int last_opt_m_id = 3 * num_params_tensors - 1; - size_t current_end = tensor_end_element_ptr[opt_m_spec_id]; + int last_opt_m_id = opt_m_spec_id + last_tensor; // opt_m is sharded with ZeRO 1 so use it as reference + opt_m_spec_id += start_tensor - 1; // -1 to compensate for the increment at the start of the loop below while (true) { - while (idx >= current_end) { + size_t current_end; + do { opt_m_spec_id++; - if (opt_m_spec_id > last_opt_m_id) break; + if (opt_m_spec_id > last_opt_m_id) return; // done! + // on A100+ we can prefetch 256B (32 values) into the L2, on older GPUs just use a regular load #if __CUDA_ARCH__ < 800 current_end = tensor_end_element_ptr[opt_m_spec_id]; #else - // on A100+ we can prefetch 256B (32 end values) into the L2 - asm("ld.global.L1::evict_last.L2::256B.u64 {%0}, [%1];" : "=l"(current_end) : "l"(tensor_end_element_ptr + opt_m_spec_id)); + asm("ld.global.L2::256B.u64 {%0}, [%1];" : "=l"(current_end) : "l"(tensor_end_element_ptr + opt_m_spec_id)); #endif - } - if (opt_m_spec_id > last_opt_m_id) break; + } while (idx >= current_end); + int spec_id = opt_m_spec_id - 2 * num_params_tensors; size_t current_start = tensor_specs_ptr[opt_m_spec_id].start_element; diff --git a/llmc/cuda_utils.cuh b/llmc/cuda_utils.cuh index 8301143c2..c261333bc 100644 --- a/llmc/cuda_utils.cuh +++ b/llmc/cuda_utils.cuh @@ -275,7 +275,7 @@ __device__ inline float warpReduceMax(float val) { } return val; } -// requires all 32 threads in the warp to be active, but should work for any block size +// requires all 32 threads in the warp to be active, but should work for any 1D(!) 
block size // uses non-dynamic shared memory so every call increases shared memory requirements by 128 bytes // the fact it's unique shared memory allows us to avoid an extra __syncthreads() call at the end // but if called inside a loop, the shared memory will be implicitly reused, so set final_sync to 1 diff --git a/llmc/global_norm.cuh b/llmc/global_norm.cuh index 7fede4c91..d274f47e9 100644 --- a/llmc/global_norm.cuh +++ b/llmc/global_norm.cuh @@ -13,56 +13,60 @@ Global norm, used in gralldient clipping // ---------------------------------------------------------------------------- // CUDA kernels -// currently assumes all gradients are the same type (simplified adamw_update_everything) -// ZeRO 1 should use shard_idx, while DPP and ZeRO 2/3 should simply set it to 0 -template -__global__ void __launch_bounds__(256, MAX_THREADS/256) global_norm_tensors_kernel(float* out, int num_params_tensors, unsigned int shard_idx) { - float grad_norm_accumulator = 0.f; - - constexpr size_t block_size = 256; - constexpr size_t iteration_size = Packed128::size; - size_t idx = (blockIdx.x * block_size + threadIdx.x) * iteration_size; - unsigned int stride = gridDim.x * blockDim.x * iteration_size; - - int opt_m_spec_id = 2 * num_params_tensors; // opt_m is sharded with ZeRO 1/2/3 - int last_opt_m_id = 3 * num_params_tensors - 1; - - size_t current_end = tensor_end_element_ptr[opt_m_spec_id]; +__device__ float global_norm_tensors_loop(size_t idx, unsigned int stride, int num_params_tensors, unsigned int shard_idx) { + float accumulator = 0.f; + int opt_m_spec_id = 2 * num_params_tensors - 1; // -1 as it gets incremented at the start of the loop below + int last_opt_m_id = 3 * num_params_tensors - 1; // opt_m is fully sharded with ZeRO 1 so we use it as a reference while (true) { - while (idx >= current_end) { + size_t current_end; + // optimized critical path loop to iterate over tensors: only 8 SASS instructions! 
+ // 3 SETP, 2 BRA, 1 IADD3, 1 IMAD, and of course 1 LDG.E.LTC256B.64 + do { opt_m_spec_id++; - if (opt_m_spec_id > last_opt_m_id) break; + if (opt_m_spec_id > last_opt_m_id) return accumulator; // return and write the result to memory + // on A100+ we can prefetch 256B (32 values) into the L2, on older GPUs just use a regular load + // (this improved DRAM utilization from ~81.5% to ~83.5% on my H100 PCIe) #if __CUDA_ARCH__ < 800 current_end = tensor_end_element_ptr[opt_m_spec_id]; #else - // on A100+ we can prefetch 256B (32 end values) into the L2 - asm("ld.global.L1::evict_last.L2::256B.u64 {%0}, [%1];" : "=l"(current_end) : "l"(tensor_end_element_ptr + opt_m_spec_id)); + asm("ld.global.L2::256B.u64 {%0}, [%1];" : "=l"(current_end) : "l"(tensor_end_element_ptr + opt_m_spec_id)); #endif - } - if (opt_m_spec_id > last_opt_m_id) break; + } while (idx >= current_end); - // offset is 32-bit (checked <=4B elements in add_tensor_spec) + // offset is 32-bit (we check parameters tensors have less than 4B elements in add_tensor_spec) size_t current_start = tensor_specs_ptr[opt_m_spec_id].start_element; unsigned int offset = (idx - current_start) + (shard_idx * tensor_specs_ptr[opt_m_spec_id].num_elements); int grad_spec_id = opt_m_spec_id - num_params_tensors; TensorGPU grad_tensor = tensor_specs_ptr[grad_spec_id]; - __syncthreads(); // todo - hopefully improves memory locality - while (idx < current_end) { + __syncthreads(); // todo - check that this does improve performance (better memory locality) + while (idx < current_end) { // todo - profile number of iterations and adding an inner loop auto grad128 = load_tensor128(grad_tensor, offset, false, true); for (int k = 0; k < grad_tensor.num_per_128(); k++) { float grad = grad128.get(k); - grad_norm_accumulator += grad * grad; + accumulator += grad * grad; } idx += stride; offset += stride; } } +} + +// currently assumes all gradients are the same type (simplified adamw_update_everything) +// ZeRO 1 should use shard_idx, while DPP and ZeRO 2/3 should simply set it to 0 +template +__global__ void __launch_bounds__(256, MAX_THREADS/256) global_norm_tensors_kernel(float* out, int num_params_tensors, unsigned int shard_idx) { + constexpr size_t block_size = 256; + constexpr size_t iteration_size = Packed128::size; + unsigned int stride = gridDim.x * blockDim.x * iteration_size; + size_t idx = (blockIdx.x * block_size + threadIdx.x) * iteration_size; + + float accumulator = global_norm_tensors_loop(idx, stride, num_params_tensors, shard_idx); - float output = blockReduce(grad_norm_accumulator); + float output = blockReduce(accumulator); if (threadIdx.x == 0) { out[blockIdx.x] = output; } diff --git a/llmc/tensor.cuh b/llmc/tensor.cuh index 55006926d..a4bd286a5 100644 --- a/llmc/tensor.cuh +++ b/llmc/tensor.cuh @@ -270,8 +270,8 @@ int add_tensor_spec(const char* name, size_t total_elements, size_t num_shards, spec->data_type = data_type; spec->flags = flags; - // parameter tensors must fit in a 32-bit unsigned integer (used as a performance optimisation in some kernels) - // todo - either 1) 32-bit everywhere (with a DEFINE?), 2) 64-bit everywhere despite the small performance impact + // parameter tensors must fit in a 32-bit unsigned integer (used as an optimisation in e.g. global_norm_tensors_loop) + // todo - either 1) 32-bit everywhere (with a DEFINE?), 2) 64-bit everywhere despite the small performance impact, 3) ? 
assert(total_elements < 4UL*1024*1024*1024 || spec->tensor_type == TT::MULTIUSE); spec->start_element = tensors_elements[spec->tensor_type]; diff --git a/train_gpt2.cu b/train_gpt2.cu index 29a0ae370..7c64e2149 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -1013,11 +1013,16 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo const int block_size = 64; const int grid_size = deviceProp.maxThreadsPerMultiProcessor * deviceProp.multiProcessorCount / block_size; + + int start_tensor = tensors_start[PARAMETER]; + int last_tensor = tensors_start[PARAMETER+1] - 1; + int num_tensors = last_tensor - start_tensor + 1; + if (model->use_master_weights) { - adamw_update_everything<<>>(tensors_start[PARAMETER+1], seed, shard_idx, + adamw_update_everything<<>>(num_tensors, start_tensor, last_tensor, seed, shard_idx, learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay, grad_scale, t); } else { - adamw_update_everything<<>>(tensors_start[PARAMETER+1], seed, shard_idx, + adamw_update_everything<<>>(num_tensors, start_tensor, last_tensor, seed, shard_idx, learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay, grad_scale, t); } From ef3053c4eae5f6ac7b2a08ff3c1bc9557698e14a Mon Sep 17 00:00:00 2001 From: ademeure Date: Thu, 19 Sep 2024 03:12:40 +0000 Subject: [PATCH 25/27] tentative multigpu ZeRO 1 with AllGather + absmax history window --- llmc/tensor.cuh | 64 ++++++++++++++++++++++++----------------- train_gpt2.cu | 76 ++++++++++++++----------------------------------- 2 files changed, 60 insertions(+), 80 deletions(-) diff --git a/llmc/tensor.cuh b/llmc/tensor.cuh index a4bd286a5..60cf78477 100644 --- a/llmc/tensor.cuh +++ b/llmc/tensor.cuh @@ -45,7 +45,7 @@ extern size_t tensors_elements[TT::COUNT]; extern int num_tensor_specs; extern TT current_tensor_type; // todo - avoid having this somehow? -extern int current_absmax_index; // todo - move into model struct? +extern int absmax_history_index; // todo - move into model struct? extern float* gpu_scale_memory; extern unsigned int* gpu_absmax_memory; // end element of each tensor to optimise iterating through them in kernels @@ -321,57 +321,69 @@ int add_layer_specs(int num_layers, const char* name, size_t total_elements, siz // ---------------------------------------------------------------------------- -// todo - should this be moved elsewhere? maybe to copy_and_fp8.h? -__global__ void update_scale_descale_kernel(int num_tensor_specs) { +// the 1st num_tensor_specs values are the absmax of the current/last step +// the next [MAX_ABSMAX_HISTORY * num_tensor_specs] values are the history from previous steps +__global__ void update_scale_descale_kernel(int num_tensor_specs, int absmax_history_index) { int tid = blockIdx.x * blockDim.x + threadIdx.x; if (tid >= num_tensor_specs) return; - // Get the absmax value for this tensor - unsigned int absmax_uint = gpu_absmax_memory_ptr[tid]; - float absmax = __uint_as_float(absmax_uint); + // copy current absmax to history then clear it + gpu_absmax_memory_ptr[tid + (absmax_history_index * num_tensor_specs)] = gpu_absmax_memory_ptr[tid]; + gpu_absmax_memory_ptr[tid] = 0; + float absmax = 0.0f; - // Calculate scale and descale - float scale = 1.0f; - float descale = 1.0f; - if (absmax != 0.0f) { - scale = 1.0f / absmax; - descale = absmax; + // get the maximum absmax from the history (todo - do we want to mitigate outliers here?) 
+ #pragma unroll + for (int i = 1; i <= MAX_ABSMAX_HISTORY; i++) { + absmax = max(absmax, __uint_as_float(gpu_absmax_memory_ptr[tid + (i * num_tensor_specs)])); + } + + // calculate scale based on the maximum absmax + float scale = (absmax != 0.0f) ? (1.0f / absmax) : 1.0f; + + // FP8 e4m3 vs e5m2 (the latter is currently only used for activation gradients) + bool use_e5m2 = (tensor_specs_ptr[tid].data_type == DType::FP8E5M2); + #ifdef FAKE_FP8 + if (tensor_specs_ptr[tid].flags & TFlags::GRADIENT && tensor_specs_ptr[tid].tensor_type == TT::MULTIUSE) { + use_e5m2 = true; } + #endif - if ((tensor_specs_ptr[tid].flags & TFlags::GRADIENT) && (tensor_specs_ptr[tid].tensor_type == TT::MULTIUSE)) { - // e5 + if (use_e5m2) { if (absmax != 0.0f) { scale *= 32768.0f; - descale *= 1.0f/32768.0f; } else { - // default so that things are not as bad for gradients on the first step + // hacky default to avoid extreme gradient underflow on the 1st step scale = 4096.0f; - descale = 1.0f/4096.0f; } - } else { - // e4 + } else if (tensor_specs_ptr[tid].data_type == DType::FP8E4M3) { // todo - power benefit of making sure top bit of exponent is (nearly always) zero? // this can be done simply by *not* multiplying here, so that the "maximum" is 1.0f + // we probably want some threshold for badly behaved parameters to use the full range //if (tensor_specs_ptr[tid].tensor_type != TT::PARAMETER || absmax >= 4.0f) { if (absmax != 0.0f) { scale *= 256.0f; - descale *= (1.0f/256.0f); } } - // Update gpu_scale_memory - // todo: descale should be delayed by one step for parameters (see comment in gpt2_update). + // update scale and descale memory + // descale must be delayed by one step for parameters (see comment in gpt2_update). gpu_scale_memory_ptr[tid * 2] = scale; - gpu_scale_memory_ptr[tid * 2 + 1] = descale; - // todo: circular buffer !!! 
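The replacement kernel above keeps a rolling window of per-tensor absmax values and derives the FP8 scale from the maximum over that window ("delayed scaling"), with the parameter descale lagging one step behind as noted in its comment. A simplified sketch of the e4m3 policy it implements, with hypothetical names and none of the e5m2/gradient special-casing:

#include <algorithm>

// Sketch of delayed scaling: the scale comes from the largest absmax seen over the
// last history_len steps, so one unusually quiet step cannot shrink the
// quantisation range. Names here are illustrative, not the real kernel's.
float scale_from_absmax_history(const float* absmax_history, int history_len) {
    float absmax = 0.0f;
    for (int i = 0; i < history_len; i++) {
        absmax = std::max(absmax, absmax_history[i]);
    }
    if (absmax == 0.0f) {
        return 1.0f;            // nothing recorded yet: leave values unscaled
    }
    return 256.0f / absmax;     // e4m3 path above: (1.0f / absmax) * 256.0f
}
// descale is simply 1.0f / scale; for parameters the kernel above keeps the previous
// step's descale, since the weights were just updated with the old scale.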
- //gpu_absmax_memory[tid] = 0.0f; + if (tensor_specs_ptr[tid].tensor_type == TT::PARAMETER) { + float old_scale = gpu_scale_memory_ptr[tid * 2]; + gpu_scale_memory_ptr[tid * 2 + 1] = 1.0f / old_scale; + } else { + gpu_scale_memory_ptr[tid * 2 + 1] = 1.0f / scale; + } } void update_scales_from_absmax() { int block_size = 256; int num_blocks = CEIL_DIV(num_tensor_specs, block_size); - update_scale_descale_kernel<<>>(num_tensor_specs); + + update_scale_descale_kernel<<>>(num_tensor_specs, absmax_history_index + 1); + absmax_history_index = (absmax_history_index + 1) % MAX_ABSMAX_HISTORY; } // ---------------------------------------------------------------------------- diff --git a/train_gpt2.cu b/train_gpt2.cu index 7c64e2149..96765a00f 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -94,7 +94,7 @@ size_t tensors_elements[TT::COUNT] = {0}; int num_tensor_specs = 0; TT current_tensor_type = TT::PARAMETER; -int current_absmax_index = 0; +int absmax_history_index = 0; float* gpu_scale_memory = NULL; unsigned int* gpu_absmax_memory = NULL; size_t* gpu_tensor_end_element = NULL; @@ -372,8 +372,8 @@ void gpt2_allocate(GPT2 *model) { // absmax/scale/descale buffers for FP8 & Friends (scale is initialised via update_scales_from_absmax) cudaMalloc(&gpu_scale_memory, 2 * num_tensor_specs * sizeof(float)); - cudaMalloc(&gpu_absmax_memory, sizeof(unsigned int) * num_tensor_specs * MAX_ABSMAX_HISTORY); - cudaMemset(gpu_absmax_memory, 0, sizeof(unsigned int) * num_tensor_specs * MAX_ABSMAX_HISTORY); + cudaMalloc(&gpu_absmax_memory, sizeof(unsigned int) * num_tensor_specs * (MAX_ABSMAX_HISTORY + 1)); + cudaMemset(gpu_absmax_memory, 0, sizeof(unsigned int) * num_tensor_specs * (MAX_ABSMAX_HISTORY + 1)); // copy pointers to constant buffers for easy access on the GPU cudaMemcpyToSymbol(tensor_specs_ptr, &tensor_specs_gpu, sizeof(TensorSpec*)); @@ -1026,65 +1026,33 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay, grad_scale, t); } - // AdamW update - // handle adamw for all the transformer blocks - /* - for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { - // generate a unique seed for each tensor - unsigned int seed = random_u32(&model->rng_state); - - int num_layers = model->config.num_layers; - if((i < 2 || i > 13)) { - num_layers = 1; - } - - ShardInfo tensor = gpt2_get_tensor_at_layer(model, 0, i); - ShardInfo shard = multi_gpu_get_shard_offset(tensor.size, multi_gpu_config, 1); - ptrdiff_t local_offset_full = tensor.offset + shard.offset; - ptrdiff_t local_offset_partial = tensor.offset / multi_gpu_config->num_processes; - - // we only want to weight decay the 2D tensors and leave all 1D tensors alone - // in particular this also decays the embedding weights, but this is ok: - // - the token embeddings are weight shared and participate in the final projection to logits - // - the position embeddings actively participate at every forward/backward pass - float wd = (i == 0 || i == 1 || i == 4 || i == 6 || i == 10 || i == 12) ? weight_decay : 0.0f; - floatX* param_ptr = (floatX*)model->tensor_memory + local_offset_full; - floatX* grad_ptr = (floatX*)model->grads_memory + local_offset_full; - - ptrdiff_t opt_state_offset = multi_gpu_config->zero_stage < 1 ? 
local_offset_full : local_offset_partial; - float* m_ptr = model->m_memory + opt_state_offset; - float* v_ptr = model->v_memory + opt_state_offset; - float* master_ptr = nullptr; - if (model->master_weights != nullptr) { master_ptr = model->master_weights + opt_state_offset; } - // ok finally call the kernel to update the weights with AdamW - adamw_update(param_ptr, master_ptr, grad_ptr, - m_ptr, v_ptr, - shard.size, tensor.size, tensor.size, shard.size, num_layers, - learning_rate, - beta1, beta2, t, eps, wd, grad_scale, seed, main_stream); - - if (multi_gpu_config->zero_stage == 1) { #if MULTI_GPU - ncclCheck(ncclGroupStart()); - for(int l = 0; l < num_layers; ++l) { - // gather updated shards of model->tensor_memory from each process - ncclCheck(ncclAllGather(param_ptr + l * tensor.size, - (floatX*) model->tensor_memory + tensor.offset + l * tensor.size, - shard.size, ncclFloatX, - multi_gpu_config->nccl_comm, multi_gpu_config->nccl_stream)); - } - ncclCheck(ncclGroupEnd()); -#endif + if (multi_gpu_config->zero_stage == 1) { + ncclCheck(ncclGroupStart()); + for (int id = 0; id < num_tensors; id++) { + TensorSpec param_tensor = tensor_specs[id]; + TensorSpec opt_tensor = tensor_specs[id + tensors_start[PARAMETER_OPT_M]]; + + size_t sendcount = opt_tensor.num_elements * sizeof_dtype(opt_tensor.data_type); + void* recvbuff = param_tensor.ptr; + void* sendbuff = param_tensor.ptr + (multi_gpu_config->process_rank * sendcount); + + ncclCheck(ncclAllGather(sendbuff, recvbuff, sendcount, ncclFloatX, + multi_gpu_config->nccl_comm, multi_gpu_config->nccl_stream)); } + ncclCheck(ncclGroupEnd()); } - */ + // combine the absmax of all the GPUs + ncclCheck(ncclAllReduce(gpu_absmax_memory, gpu_absmax_memory, num_tensors, ncclFloat, ncclMax, + multi_gpu_config->nccl_comm, multi_gpu_config->nccl_stream)); +#endif + // todo - smarter synchronization with double buffering etc... + cudaCheck(cudaDeviceSynchronize()); // update FP8 scale & descale multipliers based on the absmax // since we just updated the parameters with the old scale, // the descale of parameters is "delayed" by one step. update_scales_from_absmax(); - - cudaCheck(cudaDeviceSynchronize()); } float gpt2_estimate_mfu(GPT2 *model, int num_tokens, float dt) { From 0df701b2621a4ec8d907efebd44e8d07678fea9a Mon Sep 17 00:00:00 2001 From: ademeure Date: Thu, 19 Sep 2024 03:23:42 +0000 Subject: [PATCH 26/27] save absmax/scale/descale to state file (untested, does it work? let's find out... but not today!) 
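The ZeRO 1 path added just above (patch 25) follows the usual pattern: each rank runs AdamW only on its own parameter shard, the updated shards are all-gathered back into the full parameter buffer, and the per-tensor absmax values are merged with a max all-reduce. A hedged NCCL sketch of that exchange for one BF16 tensor; shard_elems and both function names are illustrative, not the real code, and error checking is omitted:

#include <nccl.h>
#include <cuda_bf16.h>

// Each rank updated only its own shard of "params" in-place, so gather every rank's
// shard back into the full buffer in rank order. ncclBfloat16 needs NCCL >= 2.10
// (the real code uses its ncclFloatX alias).
void allgather_updated_shard(__nv_bfloat16* params, size_t shard_elems,
                             int rank, ncclComm_t comm, cudaStream_t stream) {
    __nv_bfloat16* my_shard = params + (size_t)rank * shard_elems;
    ncclAllGather(my_shard, params, shard_elems, ncclBfloat16, comm, stream);
}

// absmax values are tracked per tensor on every rank, so they are merged with a max
// all-reduce rather than gathered.
void combine_absmax(float* absmax, int num_tensors, ncclComm_t comm, cudaStream_t stream) {
    ncclAllReduce(absmax, absmax, num_tensors, ncclFloat, ncclMax, comm, stream);
}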
--- train_gpt2.cu | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/train_gpt2.cu b/train_gpt2.cu index 96765a00f..d018884d6 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -1147,6 +1147,8 @@ void save_state(const char* filename, int step, GPT2* model, DataLoader* loader) state_header[3] = multi_gpu_config.process_rank; // rank of this process state_header[4] = model->use_master_weights; // whether we're using fp32 master weights state_header[5] = loader->should_shuffle; // shuffle state of the dataloader + state_header[6] = num_tensor_specs; // number of tensor specs (must match) + state_header[7] = MAX_ABSMAX_HISTORY; // size of the absmax history (0 = disabled or old version) // int main state, start at 10 to leave some padding state_header[10] = step; // step of the optimization // model rng state, start at 20 to leave some padding @@ -1173,6 +1175,11 @@ void save_state(const char* filename, int step, GPT2* model, DataLoader* loader) fwriteCheck(loader->intra_shard_indices, sizeof(int), loader->shard_num_samples, state_file); fwriteCheck(&loader->shuffle_rng, sizeof(mt19937_state), 1, state_file); } + + // write absmax history and scale/descale memory + device_to_file(state_file, gpu_absmax_memory, num_tensor_specs * sizeof(float) * (MAX_ABSMAX_HISTORY + 1), IO_BUF_SIZE, main_stream); + device_to_file(state_file, gpu_scale_memory, num_tensor_specs * sizeof(float) * 2, IO_BUF_SIZE, main_stream); + fcloseCheck(state_file); } @@ -1184,6 +1191,7 @@ void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename assert(state_header[1] == 1); // version number assert(state_header[2] == multi_gpu_config.num_processes); // number of processes assert(state_header[3] == multi_gpu_config.process_rank); // rank of this process + assert(state_header[6] == num_tensor_specs); // number of tensor specs int use_master_weights = state_header[4]; // whether we're using fp32 master weights int should_shuffle = state_header[5]; // shuffle state of the dataloader *step = state_header[10]; // step of the optimization @@ -1191,6 +1199,7 @@ void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename model->rng_state_last_update = *((unsigned long long*)&state_header[22]); // last gpt2_update size_t current_shard_idx = *((size_t*)&state_header[30]); // shard index size_t current_sample_idx = *((size_t*)&state_header[32]); // position in shard + bool restore_absmax_history = (state_header[7] == MAX_ABSMAX_HISTORY); // todo - restore even if not an exact match // read AdamW m, v, master_weights (they are all float) // allocate all the needed memory as necessary @@ -1230,6 +1239,11 @@ void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename } dataloader_resume(loader, current_shard_idx, current_sample_idx); + if (restore_absmax_history) { + file_to_device(gpu_absmax_memory, state_file, num_tensor_specs * sizeof(float) * (MAX_ABSMAX_HISTORY + 1), IO_BUF_SIZE, main_stream); + file_to_device(gpu_scale_memory, state_file, num_tensor_specs * sizeof(float) * 2, IO_BUF_SIZE, main_stream); + } + // all done, close state file fcloseCheck(state_file); } From b1827d1855f809b9abf7abfd0bd461e807f05d5a Mon Sep 17 00:00:00 2001 From: ademeure Date: Thu, 19 Sep 2024 04:20:53 +0000 Subject: [PATCH 27/27] one last bit of cleanup before travelling --- llmc/adamw.cuh | 2 +- llmc/copy_and_fp8.h | 3 +-- llmc/cuda_utils.cuh | 10 +++++----- llmc/global_norm.cuh | 2 +- llmc/matmul.cuh | 19 ++++++++----------- llmc/tensor.cuh | 45 
++++++++++++++++++++++---------------------- train_gpt2.cu | 23 +++++++++++----------- 7 files changed, 51 insertions(+), 53 deletions(-) diff --git a/llmc/adamw.cuh b/llmc/adamw.cuh index 91627877b..340bc492a 100644 --- a/llmc/adamw.cuh +++ b/llmc/adamw.cuh @@ -138,7 +138,7 @@ __global__ void adamw_update_everything(int num_params_tensors, int start_tensor TensorGPU opt_v_tensor = tensor_specs_ptr[spec_id + 3*num_params_tensors]; TensorGPU master_tensor = use_master_weights ? tensor_specs_ptr[spec_id + 4*num_params_tensors] : opt_m_tensor; - float wd = (param_spec.flags & TENSOR_2D) ? weight_decay : 0.0f; + float wd = (param_spec.tensor_flags & TENSOR_2D) ? weight_decay : 0.0f; if (param_spec.data_type == DType::FP32) { idx = adamw_update_part((TensorGPU)param_spec, diff --git a/llmc/copy_and_fp8.h b/llmc/copy_and_fp8.h index b263d7ef1..a784c0bd7 100644 --- a/llmc/copy_and_fp8.h +++ b/llmc/copy_and_fp8.h @@ -36,8 +36,7 @@ __device__ float gelu_forward_elementwise(float x) { // ---------------------------------------------------------------------------- // CUDA kernels -// Same as copy_simple_kernel but with optional absmax and elementwise function options -// absmax is calculated before scaling but after the elementwise function +// Advanced copy with optional format conversion, absmax, scaling and elementwise operation template diff --git a/llmc/cuda_utils.cuh b/llmc/cuda_utils.cuh index c261333bc..02e16d0eb 100644 --- a/llmc/cuda_utils.cuh +++ b/llmc/cuda_utils.cuh @@ -89,7 +89,7 @@ __device__ void store128_same_length(ElementType* target, Packed128 } } -// todo - can we unify this with non-cs function somehow? +// with streaming cache hint (low persistence in L1/L2 caches) template __device__ void store128_same_length_cs(ElementType* target, Packed128 value) { int4 bits = value.get_bits(); @@ -201,8 +201,8 @@ __device__ void stochastic_rounding(float in, Ti &out, unsigned int random, floa } else if constexpr (std::is_same::value) { // CUDA doesn't have round down/up instructions for FP8 (in SW or HW) so we do it ourselves // ARM-Intel-NVIDIA style FP8 E4M3 (different for AMD-Graphcore-Qualcomm format!) - // tried this approach to avoid fake_fp8 bug (didn't help), keeping it for now... - // todo: compare perf & accuracy to bit shifting method (do exhaustive testing) + // tried this approach to avoid bug with fake_fp8 (didn't help), keeping it for now... + // todo: check whether it properly matches the bit shifting method (do exhaustive testing!) float low = in; float high = in; @@ -230,8 +230,8 @@ __device__ void stochastic_rounding(float in, Ti &out, unsigned int random, floa } // ---------------------------------------------------------------------------- -__device__ float fake_fp8(bool faking, float input, float scale, float descale, bool mode_e5, bool stochastic=false) { -#ifdef FAKE_FP8 +__device__ float fake_low_precision(bool faking, float input, float scale, float descale, bool mode_e5, bool stochastic=false) { +#ifdef FAKE_LOW_PRECISION unsigned int random_number; if (faking && scale != 1.0f) { assert(scale == 1.0f/descale || descale == 1.0f/scale || scale == 1.0f); diff --git a/llmc/global_norm.cuh b/llmc/global_norm.cuh index d274f47e9..53a0a7490 100644 --- a/llmc/global_norm.cuh +++ b/llmc/global_norm.cuh @@ -1,7 +1,7 @@ // TODO - BUGGED - just committing my WIP, not sure why grad norm is zero, probably something silly! 
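The stochastic_rounding changes above handle FP8 e4m3 by constructing the round-down/round-up candidates in software; the underlying idea is easier to see for BF16, where the discarded low 16 bits of the FP32 encoding directly give the round-up probability. A minimal sketch of that simpler BF16 case (stochastic_round_bf16 is an illustrative helper, not the e4m3 path above, and Inf/NaN inputs are not handled):

#include <cuda_bf16.h>

// Keep the top 16 bits of the fp32 encoding and round with probability equal to the
// discarded fraction: adding a uniform 16-bit threshold before truncating does
// exactly that.
__device__ __nv_bfloat16 stochastic_round_bf16(float in, unsigned int random) {
    unsigned int bits = __float_as_uint(in);
    bits += (random & 0xFFFFu);   // a carry into bit 16 bumps to the next bf16 value (away from zero)
    bits &= 0xFFFF0000u;          // truncate to an exactly representable bf16 value
    return __float2bfloat16(__uint_as_float(bits));
}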
/* -Global norm, used in gralldient clipping +Global norm, used in gradient clipping */ #include #include diff --git a/llmc/matmul.cuh b/llmc/matmul.cuh index 8d3d2d8da..830797bed 100644 --- a/llmc/matmul.cuh +++ b/llmc/matmul.cuh @@ -117,9 +117,6 @@ void matmul_cublaslt(tensorX d, const tensorX a, const tensorX b, const tensorX bool accumulate=false, tensorX pre_gelu=null_tensorX, bool backward=false) { NVTX_RANGE_FN(); - bool has_bias = (bias.data_ptr != NULL); - bool has_gelu = (pre_gelu.data_ptr != NULL); - // check alignment (some modes work unaligned but it always best to be aligned for performance) if(((uintptr_t)a.data_ptr % 16) != 0 || ((uintptr_t)b.data_ptr % 16) != 0 || ((uintptr_t)d.data_ptr % 16) != 0 || ((uintptr_t)bias.data_ptr % 16) != 0) { printf("All cuBLASLt pointers must be aligned!\n"); @@ -177,24 +174,24 @@ void matmul_cublaslt(tensorX d, const tensorX a, const tensorX b, const tensorX // setup epilogue and associated pointers for bias & gelu cublasLtEpilogue_t epilogue; - if (has_gelu) { + if (pre_gelu.enabled()) { int64_t gelu_ld = m; // todo - is this affected by anything else? cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &gelu_ld, sizeof(gelu_ld))); cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &pre_gelu.data_ptr, sizeof(pre_gelu.data_ptr))); if (backward) { - assert(!has_bias); // we shouldn't have any backward matmuls that use both GELU and bias + assert(!bias.enabled()); // we shouldn't have any backward matmuls that use both GELU and bias epilogue = CUBLASLT_EPILOGUE_DGELU; } else { - epilogue = has_bias ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS : CUBLASLT_EPILOGUE_GELU_AUX; + epilogue = bias.enabled() ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS : CUBLASLT_EPILOGUE_GELU_AUX; } - } else if(has_bias){ + } else if(bias.enabled()){ epilogue = backward ? CUBLASLT_EPILOGUE_BGRADB : CUBLASLT_EPILOGUE_BIAS; } else { epilogue = CUBLASLT_EPILOGUE_DEFAULT; } cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); - if (has_bias) { + if (bias.enabled()) { // cuBLASLt requires bias in FP8 mode to be BF16... (sigh) cublasDataType_t bias_data_type = (sizeof(floatX) == 1) ? 
CUDA_R_16BF : CUBLAS_LOWP; // force BF16 bias for FP8 mode cublasCheck(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_data_type, sizeof(bias_data_type))); @@ -210,7 +207,7 @@ void matmul_cublaslt(tensorX d, const tensorX a, const tensorX b, const tensorX cublasLtMatmulAlgoGetHeuristic(cublaslt_handle, operationDesc, ALayout, BLayout, CLayout, DLayout, preference, 1, &heuristic, &returnedResults); if (returnedResults == 0) { - printf("No cuBLASLt algorithm: m: %d, n: %d, k: %d, bias: %d\n", n, m, k, has_bias); + printf("No cuBLASLt algorithm: m: %d, n: %d, k: %d, bias: %d\n", n, m, k, bias.enabled()); exit(EXIT_FAILURE); } @@ -222,8 +219,8 @@ void matmul_cublaslt(tensorX d, const tensorX a, const tensorX b, const tensorX &alpha, a, ALayout, b, BLayout, &beta, d, CLayout, d, DLayout, &heuristic.algo, cublaslt_workspace, cublaslt_workspace_size, stream)); - #ifdef FAKE_FP8 - update_absmax(d, false); // fake FP8 requires the absmax to work + #ifdef FAKE_LOW_PRECISION + update_absmax(d, false); // fake FP8 requires the absmax to work (cuBLAS can't do it for BF16) #endif // cleanups diff --git a/llmc/tensor.cuh b/llmc/tensor.cuh index 60cf78477..cb996f6a3 100644 --- a/llmc/tensor.cuh +++ b/llmc/tensor.cuh @@ -2,7 +2,7 @@ #define TENSOR_CUH // ... -//#define FAKE_FP8 +//#define FAKE_LOW_PRECISION #define UNIQUE_TENSOR_MEMORY false #define LAYERS_PER_ACTIVATION_CHECKPOINT 0 // 0 = disabled // ... @@ -113,7 +113,7 @@ struct TensorGPU { static constexpr bool no_scaling = (sizeof(ElementType) != 1); // todo - this prevents scaling FP16 __device__ __host__ float get_scalar(size_t index, bool disable_scaling=no_scaling) const { - #ifdef FAKE_FP8 + #ifdef FAKE_LOW_PRECISION disable_scaling = true; #endif ElementType* __restrict__ data_ptr_restricted = data_ptr; @@ -125,7 +125,7 @@ struct TensorGPU { } __device__ __host__ ElementType set_scalar(size_t index, float value, bool disable_scaling=no_scaling) { - #ifdef FAKE_FP8 + #ifdef FAKE_LOW_PRECISION disable_scaling = true; #endif ElementType* __restrict__ data_ptr_restricted = data_ptr; @@ -162,7 +162,7 @@ struct TensorSpec { char name[16]; TT tensor_type; DType data_type; - int flags; + short tensor_flags; size_t offset; // into tensor type's base pointer size_t start_element; // on this shard @@ -215,7 +215,7 @@ void print_tensor_elements(int tensor_id) { cudaMemcpy(&descale, &gpu_scale_memory[spec.id * 2 + 1], sizeof(float), cudaMemcpyDeviceToHost); cudaMemcpy(&absmax, &gpu_absmax_memory[spec.id], sizeof(float), cudaMemcpyDeviceToHost); - printf("Printing tensor %s (tensor_type: %d, data_type: %d, flags: %d)\n", tensor_name, (int)tensor_type, (int)dtype, spec.flags); + printf("Printing tensor %s (tensor_type: %d, data_type: %d, flags: %d)\n", tensor_name, (int)tensor_type, (int)dtype, spec.tensor_flags); printf("GPU memory: %p\n", gpu_tensor); printf("CPU memory: %p\n", cpu_tensor); printf("Num elements: %zu\n", num_elements); @@ -268,7 +268,7 @@ int add_tensor_spec(const char* name, size_t total_elements, size_t num_shards, spec->name[15] = 0; spec->tensor_type = (tensor_type == TT::DEFAULT) ? current_tensor_type : tensor_type; spec->data_type = data_type; - spec->flags = flags; + spec->tensor_flags = flags; // parameter tensors must fit in a 32-bit unsigned integer (used as an optimisation in e.g. global_norm_tensors_loop) // todo - either 1) 32-bit everywhere (with a DEFINE?), 2) 64-bit everywhere despite the small performance impact, 3) ? 
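As background for the (scale, descale) pairs read by print_tensor_elements() above and refreshed by update_scale_descale_kernel below: a common policy is to map the tracked absmax onto the FP8 format's maximum representable value and keep descale as the exact reciprocal, which is the relationship the assert in fake_low_precision() expects. The kernel below is only an illustrative sketch of that policy, not the patch's actual kernel: it ignores the absmax history and the tensor_specs_ptr bookkeeping, and the tensor_is_e5m2 input is a stand-in for the e4m3/e5m2 selection the real code derives from the tensor spec.

// Sketch only (not the patch's update_scale_descale_kernel): derive per-tensor
// scale/descale from the tracked absmax. Assumed layout matches
// print_tensor_elements(): scale at [id*2], descale at [id*2+1], absmax at [id].
__global__ void sketch_update_scales(int num_tensor_specs,
                                     const float* absmax_memory,
                                     float* scale_memory,
                                     const bool* tensor_is_e5m2) {
    int id = blockIdx.x * blockDim.x + threadIdx.x;
    if (id >= num_tensor_specs) return;

    float absmax = absmax_memory[id];
    // maximum finite values of the two FP8 formats: e4m3 = 448, e5m2 = 57344
    float fmt_max = tensor_is_e5m2[id] ? 57344.0f : 448.0f;

    // scale maps the observed dynamic range onto the FP8 representable range;
    // descale is its exact reciprocal, falling back to 1.0 if no absmax was recorded
    float scale = (absmax > 0.0f) ? (fmt_max / absmax) : 1.0f;
    scale_memory[id * 2 + 0] = scale;
    scale_memory[id * 2 + 1] = 1.0f / scale;
}

Reserving e5m2 for activation gradients (as the comment in update_scale_descale_kernel notes) trades mantissa precision for the extra dynamic range those tensors tend to need.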
@@ -281,7 +281,7 @@ int add_tensor_spec(const char* name, size_t total_elements, size_t num_shards, if (copy_offset_from >= 0) { TensorSpec base_spec = tensor_specs[copy_offset_from]; - base_spec.flags |= (flags & REUSED_MEMORY); + base_spec.tensor_flags |= (flags & REUSED_MEMORY); spec->offset = base_spec.offset; size_t original_tensor_bytes = base_spec.num_elements * sizeof_dtype(base_spec.data_type); @@ -343,8 +343,8 @@ __global__ void update_scale_descale_kernel(int num_tensor_specs, int absmax_his // FP8 e4m3 vs e5m2 (the latter is currently only used for activation gradients) bool use_e5m2 = (tensor_specs_ptr[tid].data_type == DType::FP8E5M2); - #ifdef FAKE_FP8 - if (tensor_specs_ptr[tid].flags & TFlags::GRADIENT && tensor_specs_ptr[tid].tensor_type == TT::MULTIUSE) { + #ifdef FAKE_LOW_PRECISION + if (tensor_specs_ptr[tid].tensor_flags & TFlags::GRADIENT && tensor_specs_ptr[tid].tensor_type == TT::MULTIUSE) { use_e5m2 = true; } #endif @@ -401,9 +401,9 @@ private: bool wrote_absmax = false; int id = -1; - // fake fp8 mode (ignored without FAKE_FP8 define) - bool faking_fp8 = false; - bool mode_e5 = false; + // fake fp8 mode (ignored without FAKE_LOW_PRECISION define) + bool faking_low_precision = false; + bool faking_mode_e5 = false; public: bool scaling = (sizeof(ElementType) == 1); @@ -414,14 +414,14 @@ public: data_ptr = tensor.data_ptr; id = tensor.id; -#ifdef FAKE_FP8 +#ifdef FAKE_LOW_PRECISION // fake FP8 only applies to specific tensors to test expected training performance // todo - expand this to support more unusual formats and test things like blockwise scaling(?) if (!disable_scaling && id >= 0 && sizeof(ElementType) == 2 && tensor_specs_ptr[id].tensor_type != TT::PARAMETER_GRAD) { - if ((tensor_specs_ptr[id].flags & (TFlags::RESIDUAL | TFlags::EMBEDDING | TFlags::BIAS)) == 0) { - faking_fp8 = true; - if ((tensor_specs_ptr[id].flags & TFlags::GRADIENT) && (tensor_specs_ptr[id].tensor_type == TT::MULTIUSE)) { - mode_e5 = true; + if ((tensor_specs_ptr[id].tensor_flags & (TFlags::RESIDUAL | TFlags::EMBEDDING | TFlags::BIAS)) == 0) { + faking_low_precision = true; + if ((tensor_specs_ptr[id].tensor_flags & TFlags::GRADIENT) && (tensor_specs_ptr[id].tensor_type == TT::MULTIUSE)) { + faking_mode_e5 = true; } } } @@ -465,16 +465,17 @@ public: new_absmax = max(new_absmax, fabsf(value)); } - // get and set automatically apply scaling/descaling for FP8 values + // get() and set() automatically apply scaling & descaling for FP8 values __device__ float get(int index) { float value = (float)data128[index] * (scaling ? descale : 1.0f); - value = fake_fp8(faking_fp8, value, scale, descale, mode_e5); // ignored without FAKE_FP8 + // used to simulate FP8 and below (just returns the input without FAKE_LOW_PRECISION) + value = fake_low_precision(faking_low_precision, value, scale, descale, faking_mode_e5); return value; } __device__ void set(int index, float value) { float output = value * (scaling ? 
scale : 1.0f); - output = fake_fp8(faking_fp8, output, scale, descale, mode_e5); + output = fake_low_precision(faking_low_precision, output, scale, descale, faking_mode_e5); data128[index] = (ElementType)(output); add_value_stats(value, data128[index]); } @@ -510,8 +511,8 @@ public: // return value: if true, we can skip __syncthreads() in the calling function as we have just done one __device__ bool update_absmax(int thread_id, int num_threads, bool exit=false, bool forced=false) { - #ifdef FAKE_FP8 - if (absmax_ptr == NULL || !faking_fp8) { + #ifdef FAKE_LOW_PRECISION + if (absmax_ptr == NULL || !faking_low_precision) { return false; } forced = true; diff --git a/train_gpt2.cu b/train_gpt2.cu index d018884d6..80af71867 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -1,4 +1,5 @@ -#define ENABLE_FP8 +#define ENABLE_FP8 // todo - makefile option +bool write_as_floatX = true; // todo - make command line option (and test it properly) /* GPT-2 Transformer Neural Net training loop. See README.md for usage. @@ -41,7 +42,7 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage. // Packed128, f128, x128 // warpReduceSum, warpReduceMax, blockReduce #include "llmc/cuda_utils.cuh" -// ... todo ... +// todo - document what tensor.cuh implements #include "llmc/tensor.cuh" // defines: CUBLAS_LOWP, cublasCheck, cublaslt_workspace_size, cublaslt_workspace // defines: cublas_compute, cublaslt_handle, cublas_handle @@ -77,7 +78,7 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage. #include "llmc/zero.cuh" // ---------------------------------------------------------------------------- -// global vars regarding GPU process and disk I/O +// global vars regarding the GPU process and disk I/O cudaDeviceProp deviceProp; // fills in common_start() cudaStream_t main_stream; char filename_buffer[512]; @@ -363,6 +364,7 @@ void gpt2_allocate(GPT2 *model) { cudaMemcpy(gpu_tensor_end_element, cpu_tensor_end_element, sizeof(size_t) * num_tensor_specs + 256, cudaMemcpyHostToDevice); free(cpu_tensor_end_element); + // todo - move this elsewhere so it's not in the middle of the parameter table... printf("number of parameter bytes: %zu MiB\n", tensors_bytes[TT::PARAMETER] / (1024*1024)); printf("number of parameter gradient bytes: %zu MiB\n", tensors_bytes[TT::PARAMETER_GRAD] / (1024*1024)); printf("number of m bytes: %zu MiB\n", tensors_bytes[TT::PARAMETER_OPT_M] / (1024*1024)); @@ -449,7 +451,8 @@ void convert_fixed_parameters(GPT2* model, char* gpu_buffer, size_t fixed_size_b cudaMemset(gpu_buffer, 0, fixed_size_bytes); } -// to convert from variable precision parameters to a single precision (e.g. before checkpointing) +// convert from variable precision parameters to a single precision (e.g. 
before checkpointing) +// todo template void convert_to_fixed_parameters(GPT2* model, char* gpu_buffer) { size_t offset = 0; @@ -467,7 +470,6 @@ void convert_to_fixed_parameters(GPT2* model, char* gpu_buffer) { } } - // helper function to initialise sharded master weights from unsharded weights template void init_tensor_shard(TensorGPU out, int i) { @@ -511,7 +513,6 @@ void gpt2_write_to_checkpoint(GPT2 *model, const char* checkpoint_path) { model_header[7] = model->config.padded_vocab_size; fwriteCheck(model_header, sizeof(int), 256, model_file); // write the parameters - bool write_as_floatX = true; if (write_as_floatX && model->num_parameters_bytes != model->num_parameters * sizeof(floatX)) { // convert the parameters to floatX before writing them assert(tensors_bytes[MULTIUSE] >= model->num_parameters * sizeof(floatX)); // todo - make this always work @@ -575,7 +576,7 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) { model->config.channels = model_header[6]; model->config.padded_vocab_size = model_header[7]; - // key line to allocate all of the GPU buffers for all of the tensos + // key line to allocate all of the GPU buffers for all of the tensors gpt2_allocate(model); // if the number of bytes in the checkpoint doesn't match the number of bytes allocated, @@ -684,12 +685,12 @@ void gpt_build_from_descriptor(GPT2 *model, const char* descriptor) { int num_param_tensors = tensors_start[PARAMETER+1]; for (int i = 0; i < num_param_tensors; i++) { TensorSpec tensor = tensor_specs[i]; - if ((tensor.flags & TFlags::LAYERNORM) && !(tensor.flags & BIAS)) { + if ((tensor.tensor_flags & TFlags::LAYERNORM) && !(tensor.tensor_flags & BIAS)) { for (size_t j = 0; j < tensor.num_elements; j++) { params_memory_cpu[offset + j] = (floatX)1.0f; } } - if (tensor.flags & TENSOR_2D) { + if (tensor.tensor_flags & TENSOR_2D) { size_t n = tensor.num_elements; if (n == model->config.padded_vocab_size * model->config.channels) { n = model->config.vocab_size * model->config.channels; @@ -698,7 +699,7 @@ void gpt_build_from_descriptor(GPT2 *model, const char* descriptor) { // in GPT-2, the projections back into the residual stream are additionally // scaled by 1/sqrt(2*L) for training stability float scale = 0.02f; - if (strstr(tensor.name, "proj") != NULL) { // always love a good strstr()... /s + if (strstr(tensor.name, "proj") != NULL) { // todo: yuck - use TFlags! scale *= residual_scale; } @@ -742,7 +743,7 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) { if (!CUDNN_ENABLED && T != model->seq_len) { cudaCheck(cudaMemset(ACT_0(att), 0, L * B * NH * T * T * sizeof(floatX))); } - // validate inputs, all indices mucst be in the range [0, V) + // validate inputs, all indices must be in the range [0, V) tokenCheck(inputs, B*T, V); // copy inputs/targets to the model (fully synchronous with the host for now)
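To make the descriptor-path initialization above concrete: layernorm scale tensors start at 1.0, biases stay at zero, and 2D (matmul) tensors get a normal init with stddev 0.02, additionally scaled by 1/sqrt(2*L) for the projections back into the residual stream. The sketch below restates that policy in plain host code; InitSpec and the RNG setup are stand-ins for the patch's tensor_specs and random helpers, and the padded-vocab special case for wte is omitted.

// Sketch only (not the patch's code): GPT-2 parameter init policy from gpt_build_from_descriptor().
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstring>
#include <random>
#include <vector>

struct InitSpec { const char* name; size_t num_elements; bool is_layernorm_weight; bool is_2d; };

void init_parameter(std::vector<float>& out, const InitSpec& t, int num_layers, std::mt19937& rng) {
    out.assign(t.num_elements, 0.0f);                  // biases (and layernorm biases) stay zero
    if (t.is_layernorm_weight) {
        std::fill(out.begin(), out.end(), 1.0f);       // layernorm scale weights start at 1.0
    } else if (t.is_2d) {
        float scale = 0.02f;                           // GPT-2 base init stddev
        if (strstr(t.name, "proj") != NULL) {          // projections back into the residual stream
            scale *= 1.0f / sqrtf(2.0f * num_layers);  // residual_scale = 1/sqrt(2*L)
        }
        std::normal_distribution<float> dist(0.0f, scale);
        for (float& v : out) { v = dist(rng); }
    }
}

The 1/sqrt(2*L) factor keeps the variance of the residual stream roughly constant across the 2*L residual additions per forward pass, which is why only the "proj" tensors get it.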