ggml : fix ggml_get_rows to take into account ne02 / ne11
ggerganov committed Dec 9, 2023
1 parent ee8fb39 commit 9064b1c
Showing 1 changed file with 41 additions and 22 deletions.
ggml.c
@@ -10342,20 +10342,27 @@ static void ggml_compute_forward_get_rows_q(
         return;
     }
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nelements(src1);
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1);
+
     const enum ggml_type type = src0->type;
     ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
 
-    assert( dst->ne[0] == nc);
+    assert(ne0 == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == ggml_type_size(type));
     assert(ggml_nrows(dst) == nr);
-    assert(src0->nb[0] == ggml_type_size(type));
 
-    for (int i = 0; i < nr; ++i) {
-        const int r = ((int32_t *) src1->data)[i];
+    // TODO: multi-thread
+    for (int64_t i = 0; i < nr; ++i) {
+        const int64_t r = ((int32_t *) src1->data)[i];
+
+        const int64_t i02 = i/ne10;
 
         dequantize_row_q(
-                (const void *) ((char *) src0->data + r*src0->nb[1]),
+                (const void *) ((char *) src0->data + i02*nb02 + r*nb01),
                 (float *) ((char *) dst->data + i*dst->nb[1]), nc);
     }
 }
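The shape of the fix: src0 is treated as a batch of ne02 matrices and src1 as ne11 rows of ne10 indices, with the new assert(ne02 == ne11) requiring one index row per matrix; destination row i then reads row r = src1[i] from matrix i02 = i/ne10. A standalone toy sketch of that indexing (sizes and values are illustrative, plain arrays stand in for ggml tensors):

```c
#include <stdint.h>
#include <stdio.h>

// Toy model of the fixed indexing: src0 holds ne02 matrices of
// ne01 rows x ne00 floats, src1 holds ne11 x ne10 row ids, and
// ne02 == ne11 as the new asserts require. Sizes are made up.
int main(void) {
    enum { ne00 = 4, ne01 = 3, ne02 = 2, ne10 = 2, ne11 = ne02 };

    float src0[ne02][ne01][ne00];
    for (int m = 0; m < ne02; ++m)
        for (int r = 0; r < ne01; ++r)
            for (int c = 0; c < ne00; ++c)
                src0[m][r][c] = 100.0f*m + 10.0f*r + c;

    const int32_t ids[ne11*ne10] = { 2, 0, 1, 1 }; // each id < ne01

    for (int64_t i = 0; i < ne11*ne10; ++i) {
        const int64_t r   = ids[i];
        const int64_t i02 = i/ne10; // which matrix of the batch

        const float * row = src0[i02][r];
        printf("dst[%lld] <- src0[%lld][%lld] = %.0f %.0f %.0f %.0f\n",
               (long long) i, (long long) i02, (long long) r,
               row[0], row[1], row[2], row[3]);
    }
    return 0;
}
```

Here ids 2 and 0 resolve against the first matrix, ids 1 and 1 against the second; before this commit all four would have been read from the first matrix.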
@@ -10371,19 +10378,25 @@ static void ggml_compute_forward_get_rows_f16(
         return;
     }
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nelements(src1);
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1);
 
-    assert( dst->ne[0] == nc);
+    assert(ne0 == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == sizeof(ggml_fp16_t));
     assert(ggml_nrows(dst) == nr);
-    assert(src0->nb[0] == sizeof(ggml_fp16_t));
 
-    for (int i = 0; i < nr; ++i) {
-        const int r = ((int32_t *) src1->data)[i];
+    // TODO: multi-thread
+    for (int64_t i = 0; i < nr; ++i) {
+        const int64_t r = ((int32_t *) src1->data)[i];
+
+        const int64_t i02 = i/ne10;
 
         for (int j = 0; j < nc; ++j) {
-            ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j];
-            ((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v);
+            ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i02*nb02 + r*nb01))[j];
+            ((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v);
         }
     }
 }
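The source pointer arithmetic is the same in all three type paths: ggml's stride fields nb01 (row) and nb02 (matrix) are byte counts, so the source row starts i02*nb02 + r*nb01 bytes past the base pointer. A sketch of that offset computation as a helper (get_row_ptr is hypothetical, not a ggml function):

```c
#include <stdint.h>
#include <stddef.h>

// Hypothetical helper, not part of ggml: address of row r in matrix
// i02 of a 3D tensor whose row stride nb01 and matrix stride nb02
// are byte counts, mirroring the i02*nb02 + r*nb01 expression above.
static inline const void * get_row_ptr(const void * base,
                                       int64_t i02, int64_t r,
                                       size_t nb01, size_t nb02) {
    return (const char *) base + i02*nb02 + r*nb01;
}
```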
@@ -10399,19 +10412,25 @@ static void ggml_compute_forward_get_rows_f32(
         return;
     }
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nelements(src1);
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ggml_nelements(src1);
 
-    assert( dst->ne[0] == nc);
+    assert(ne0 == nc);
+    assert(ne02 == ne11);
+    assert(nb00 == sizeof(float));
     assert(ggml_nrows(dst) == nr);
-    assert(src0->nb[0] == sizeof(float));
 
-    for (int i = 0; i < nr; ++i) {
-        const int r = ((int32_t *) src1->data)[i];
+    // TODO: multi-thread
+    for (int64_t i = 0; i < nr; ++i) {
+        const int64_t r = ((int32_t *) src1->data)[i];
+
+        const int64_t i02 = i/ne10;
 
         ggml_vec_cpy_f32(nc,
                 (float *) ((char *) dst->data + i*dst->nb[1]),
-                (float *) ((char *) src0->data + r*src0->nb[1]));
+                (float *) ((char *) src0->data + i02*nb02 + r*nb01));
     }
 }
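For context, a sketch of the batched call these kernels now handle, assuming the ggml graph API of this vintage (ggml_new_graph, ggml_build_forward_expand, ggml_graph_compute_with_ctx); sizes, fill values, and the printed check are illustrative only:

```c
#include <string.h>
#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // src0: ne02 = 2 matrices of ne01 = 8 rows x ne00 = 4 floats
    struct ggml_tensor * src0 = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 4, 8, 2);
    // src1: ne10 = 3 row ids per matrix, ne11 = 2 == ne02
    struct ggml_tensor * src1 = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 3, 2);

    for (int64_t k = 0; k < ggml_nelements(src0); ++k) {
        ((float *) src0->data)[k] = (float) k;
    }
    const int32_t ids[6] = { 7, 0, 3, 1, 1, 5 }; // each id < ne01 = 8
    memcpy(src1->data, ids, sizeof(ids));

    struct ggml_tensor * dst = ggml_get_rows(ctx, src0, src1);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, dst);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    // dst row 0 takes row 7 of matrix 0: expect 28 29 30 31
    const float * d = (const float *) dst->data;
    printf("dst[0] = %.0f %.0f %.0f %.0f\n", d[0], d[1], d[2], d[3]);

    ggml_free(ctx);
    return 0;
}
```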