multi-threaded model initialization
ngc92 committed Aug 12, 2024
1 parent 6e6a528 commit 011f59c
Showing 2 changed files with 51 additions and 37 deletions.
1 change: 1 addition & 0 deletions Makefile
@@ -180,6 +180,7 @@ else
 # Check for OpenMP support in GCC or Clang on Linux
 ifeq ($(shell echo | $(CC) -fopenmp -x c -E - > /dev/null 2>&1; echo $$?), 0)
   CFLAGS += -fopenmp -DOMP
+  NVCC_FLAGS += -Xcompiler -fopenmp -DOMP
   LDLIBS += -lgomp
   $(info ✓ OpenMP found)
 else
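Why the Makefile change matters: -Xcompiler forwards -fopenmp to the host compiler, so #pragma omp directives in the host-side code of a .cu file (like the parallel initialization loop added below) actually take effect under nvcc; without it they are silently ignored. A minimal smoke test, assuming a hypothetical omp_check.cu and gcc/clang as the host compiler:

// omp_check.cu -- hypothetical smoke test, not part of this commit
// build: nvcc -Xcompiler -fopenmp omp_check.cu -o omp_check -lgomp
#include <cstdio>
#include <omp.h>

int main() {
    // with -Xcompiler -fopenmp this loop runs multi-threaded;
    // without it the pragma is ignored and everything runs on thread 0
    #pragma omp parallel for
    for (int i = 0; i < 8; i++) {
        printf("iteration %d on thread %d\n", i, omp_get_thread_num());
    }
    return 0;
}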
87 changes: 50 additions & 37 deletions train_gpt2.cu
@@ -504,6 +504,47 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path, bool w
     cudaCheck(cudaDeviceSynchronize());
 }
 
+void gpt2_init_layer(GPT2 *model, int l, mt19937_state* rng, floatX* params) {
+    int offset = 0;
+    size_t L = model->config.num_layers;
+    float residual_scale = 1.0f / sqrtf(2.0f * L);
+    for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {
+        // the layernorm parameters are all initialized to 1
+        if (l == 0 && (i == 2 || i == 8 || i == 14)) { // only at l = 0 to init these just once
+            for (size_t j = 0; j < model->param_elements[i]; j++) {
+                params[offset + j] = 1.0f;
+            }
+        }
+        // weights tensors are handled here
+        if ((l == 0 && (i == 0 || i == 1)) // only at l = 0, init the wte and wpe tensors
+            || i == 4 || i == 6 || i == 10 || i == 12) {
+            size_t n = model->param_elements[i];
+            size_t layer_offset = 0;
+            if (i == 0) {
+                // for wte tensor (padded vocab) override to init V instead of Vp rows
+                n = model->config.vocab_size * model->config.channels;
+            }
+            if (i == 4 || i == 6 || i == 10 || i == 12) {
+                // weight tensors, we are only initializing layer l
+                assert(n % L == 0);
+                n = n / L;
+                layer_offset = l * n;
+            }
+            // in GPT-2, the projections back into the residual stream are additionally
+            // scaled by 1/sqrt(2*L) for training stability
+            float scale = (i == 6 || i == 12) ? 0.02f * residual_scale : 0.02f;
+            // okay let's draw the random numbers and write them
+            float *fp32_buffer = (float*)mallocCheck(n * sizeof(float));
+            normal_(fp32_buffer, n, 0.0f, scale, rng);
+            for (size_t j = 0; j < n; j++) {
+                params[offset + layer_offset + j] = (floatX)fp32_buffer[j];
+            }
+            free(fp32_buffer);
+        }
+        offset += model->param_elements[i];
+    }
+}
+
 void gpt2_set_hyperparameters(GPT2Config* config, const char* depth_str) {
     int depth = atoi(depth_str);
     assert(depth > 0); // atoi returns 0 if not a number
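A note to make the index tests in gpt2_init_layer easier to read: i indexes llm.c's parameter-tensor table, so the magic numbers correspond to tensor names. The sketch below spells out that mapping; the names assume the usual ParameterTensors ordering (wte, wpe, ln1w, ...) defined elsewhere in train_gpt2.cu, and that file is the source of truth:

// init_map.c -- illustrative sketch only, not part of this commit; tensor
// names assume llm.c's ParameterTensors ordering as defined elsewhere in
// train_gpt2.cu
#include <stdio.h>

int main(void) {
    const char *names[16] = {
        "wte", "wpe", "ln1w", "ln1b", "qkvw", "qkvb", "attprojw", "attprojb",
        "ln2w", "ln2b", "fcw", "fcb", "fcprojw", "fcprojb", "lnfw", "lnfb"
    };
    for (int i = 0; i < 16; i++) {
        const char *init =
            (i == 2 || i == 8 || i == 14) ? "ones (layernorm gain, done once at l == 0)" :
            (i == 6 || i == 12)           ? "normal(0, 0.02/sqrt(2L)) (residual projection)" :
            (i == 0 || i == 1)            ? "normal(0, 0.02) (embeddings, done once at l == 0)" :
            (i == 4 || i == 10)           ? "normal(0, 0.02) (per-layer weight)" :
                                            "not touched by gpt2_init_layer (biases)";
        printf("%2d %-8s -> %s\n", i, names[i], init);
    }
    return 0;
}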
@@ -584,44 +625,16 @@ void gpt_build_from_descriptor(GPT2 *model, const char* descriptor) {
     // we have to init all these tensors exactly in the order that PyTorch initializes them
     // so that we can match them up and get correctness and exactly the same initial conditions
     size_t L = model->config.num_layers;
-    size_t offset = 0;
+
+    // create a local rng for each layer, so we get determinism independent of the number of threads
+    std::vector<mt19937_state> layer_rng(L);
     for (int l = 0; l < L; l++) {
-        offset = 0;
-        for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {
-            // the layernorm parameters are all initialized to 1
-            if (l == 0 && (i == 2 || i == 8 || i == 14)) { // only at l = 0 to init these just once
-                for (size_t j = 0; j < model->param_elements[i]; j++) {
-                    params_memory_cpu[offset + j] = 1.0f;
-                }
-            }
-            // weights tensors are handled here
-            if ((l == 0 && (i == 0 || i == 1)) // only at l = 0, init the wte and wpe tensors
-                || i == 4 || i == 6 || i == 10 || i == 12) {
-                size_t n = model->param_elements[i];
-                size_t layer_offset = 0;
-                if (i == 0) {
-                    // for wte tensor (padded vocab) override to init V instead of Vp rows
-                    n = model->config.vocab_size * model->config.channels;
-                }
-                if (i == 4 || i == 6 || i == 10 || i == 12) {
-                    // weight tensors, we are only initializing layer l
-                    assert(n % L == 0);
-                    n = n / L;
-                    layer_offset = l * n;
-                }
-                // in GPT-2, the projections back into the residual stream are additionally
-                // scaled by 1/sqrt(2*L) for training stability
-                float scale = (i == 6 || i == 12) ? 0.02f * residual_scale : 0.02f;
-                // okay let's draw the random numbers and write them
-                float *fp32_buffer = (float*)mallocCheck(n * sizeof(float));
-                normal_(fp32_buffer, n, 0.0f, scale, &init_rng);
-                for (size_t j = 0; j < n; j++) {
-                    params_memory_cpu[offset + layer_offset + j] = (floatX)fp32_buffer[j];
-                }
-                free(fp32_buffer);
-            }
-            offset += model->param_elements[i];
-        }
+        manual_seed(&layer_rng[l], randint32(&init_rng));
     }
 
+    #pragma omp parallel for
+    for (int l = 0; l < L; l++) {
+        gpt2_init_layer(model, l, &layer_rng[l], params_memory_cpu);
+    }
+
     // copy them to GPU
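The structure above is what buys determinism: all order-dependent randomness (drawing the per-layer seeds from init_rng) stays serial, while the parallel loop only reads each layer's private generator, so the resulting weights are identical for any OMP_NUM_THREADS. A standalone sketch of the same two-phase pattern, using std::mt19937 in place of the repo's mt19937_state / manual_seed / randint32 helpers:

// deterministic parallel init, sketched with std::mt19937 instead of the
// repo's PyTorch-compatible RNG helpers
#include <cstdio>
#include <random>
#include <vector>

int main() {
    const int L = 4;
    std::mt19937 master_rng(42);

    // serial phase: derive one seed per layer from the master stream,
    // so the seeds do not depend on how the work is later scheduled
    std::vector<std::mt19937> layer_rng;
    for (int l = 0; l < L; l++) {
        layer_rng.emplace_back(master_rng());
    }

    // parallel phase: each layer draws only from its own generator,
    // so the values are identical for any number of threads
    std::vector<float> first_draw(L);
    #pragma omp parallel for
    for (int l = 0; l < L; l++) {
        std::normal_distribution<float> dist(0.0f, 0.02f);
        first_draw[l] = dist(layer_rng[l]);
    }

    for (int l = 0; l < L; l++) {
        printf("layer %d: %f\n", l, first_draw[l]);
    }
    return 0;
}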
