ggml : mul mat wip
ggml-ci
This commit is contained in:
parent
5b2b2dc6ae
commit
0822d27613
1 changed files with 20 additions and 17 deletions
37
ggml.c
37
ggml.c
|
@ -10510,32 +10510,35 @@ static void ggml_compute_forward_mul_mat(
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// parallelize by src0 rows
|
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
|
||||||
const int64_t dr = (ne01 + nth - 1)/nth;
|
const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
|
||||||
|
|
||||||
const int64_t ir10 = dr*ith;
|
|
||||||
const int64_t ir11 = MIN(ir10 + dr, ne01);
|
|
||||||
|
|
||||||
// src1 rows
|
// src1 rows
|
||||||
const int64_t nr1 = ne11*ne12*ne13;
|
const int64_t nr1 = ne11*ne12*ne13;
|
||||||
|
|
||||||
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
|
/*int64_t ir010 = 0;*/
|
||||||
const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
|
/*int64_t ir011 = ne01;*/
|
||||||
|
|
||||||
|
// parallelize by src0 rows
|
||||||
|
const int64_t dr = (ne01 + nth - 1)/nth;
|
||||||
|
|
||||||
|
const int64_t ir010 = dr*ith;
|
||||||
|
const int64_t ir011 = MIN(ir010 + dr, ne01);
|
||||||
|
|
||||||
|
assert(ne12 % ne02 == 0);
|
||||||
|
assert(ne13 % ne03 == 0);
|
||||||
|
|
||||||
|
// broadcast factors
|
||||||
|
const int64_t r2 = ne12/ne02;
|
||||||
|
const int64_t r3 = ne13/ne03;
|
||||||
|
|
||||||
for (int64_t ir1 = 0; ir1 < nr1; ++ir1) {
|
for (int64_t ir1 = 0; ir1 < nr1; ++ir1) {
|
||||||
const int64_t i13 = (ir1/(ne12*ne11));
|
const int64_t i13 = (ir1/(ne12*ne11));
|
||||||
const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
|
const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
|
||||||
const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);
|
const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);
|
||||||
|
|
||||||
const int64_t ir0 = (ir1/ne11)%(ne02*ne03);
|
const int64_t i03 = i13/r3;
|
||||||
const int64_t i03 = (ir0/(ne02));
|
const int64_t i02 = i12/r2;
|
||||||
// Hack for "Falcon multi-query-attention key stutter" / alternative to ggml_repeat2.
|
|
||||||
// See https://github.com/ggerganov/llama.cpp/issues/1602#issuecomment-1606087470:
|
|
||||||
// GG: this is likely the correct way to broadcast, though need some more thought
|
|
||||||
// therefore leaving the comments to remind us for now
|
|
||||||
const int64_t i02 = (i12 / (ne12 / ne02));
|
|
||||||
// Original from PR/224 (and also essential/correct for non-broadcast matmuls in Falcon)
|
|
||||||
// const int64_t i02 = (ir0 - i03*ne02);
|
|
||||||
|
|
||||||
const int64_t i1 = i11;
|
const int64_t i1 = i11;
|
||||||
const int64_t i2 = i12;
|
const int64_t i2 = i12;
|
||||||
|
@ -10554,7 +10557,7 @@ static void ggml_compute_forward_mul_mat(
|
||||||
|
|
||||||
float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
|
float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
|
||||||
|
|
||||||
for (int64_t ir = ir10; ir < ir11; ++ir) {
|
for (int64_t ir = ir010; ir < ir011; ++ir) {
|
||||||
vec_dot(ne00, &dst_col[ir], src0_row + ir*nb01, src1_col);
|
vec_dot(ne00, &dst_col[ir], src0_row + ir*nb01, src1_col);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue