store row groups in wdata and calculate only once in GGML_TASK_INIT
ggml-ci
This commit is contained in:
parent
3de63cf103
commit
7afb69b8f5
1 changed files with 41 additions and 36 deletions
75
ggml.c
75
ggml.c
|
@ -9835,9 +9835,22 @@ static void ggml_compute_forward_mul_mat_id(
|
||||||
const int64_t r2 = ne12/ne02;
|
const int64_t r2 = ne12/ne02;
|
||||||
const int64_t r3 = ne13/ne03;
|
const int64_t r3 = ne13/ne03;
|
||||||
|
|
||||||
|
// row groups
|
||||||
|
const int id = ggml_get_op_params_i32(dst, 0);
|
||||||
|
const int n_as = ggml_get_op_params_i32(dst, 1);
|
||||||
|
|
||||||
|
char * wdata_src1_end = (src1->type == vec_dot_type) ?
|
||||||
|
(char *) params->wdata :
|
||||||
|
(char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
|
||||||
|
|
||||||
|
int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
|
||||||
|
int64_t * matrix_rows = matrix_row_counts + n_as; // [n_as][ne11]
|
||||||
|
|
||||||
|
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne11 + (i1)]
|
||||||
|
|
||||||
if (params->type == GGML_TASK_INIT) {
|
if (params->type == GGML_TASK_INIT) {
|
||||||
if (src1->type != vec_dot_type) {
|
|
||||||
char * wdata = params->wdata;
|
char * wdata = params->wdata;
|
||||||
|
if (src1->type != vec_dot_type) {
|
||||||
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
|
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
|
||||||
|
|
||||||
assert(params->wsize >= ne11*ne12*ne13*row_size);
|
assert(params->wsize >= ne11*ne12*ne13*row_size);
|
||||||
|
@ -9853,6 +9866,19 @@ static void ggml_compute_forward_mul_mat_id(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// initialize matrix_row_counts
|
||||||
|
GGML_ASSERT(wdata == wdata_src1_end);
|
||||||
|
memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
|
||||||
|
|
||||||
|
// group rows by src0 matrix
|
||||||
|
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
|
||||||
|
const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]);
|
||||||
|
|
||||||
|
GGML_ASSERT(row_id >= 0 && row_id < n_as);
|
||||||
|
MMID_MATRIX_ROW(row_id, matrix_row_counts[row_id]) = i01;
|
||||||
|
matrix_row_counts[row_id] += 1;
|
||||||
|
}
|
||||||
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -9860,24 +9886,6 @@ static void ggml_compute_forward_mul_mat_id(
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int id = ggml_get_op_params_i32(dst, 0);
|
|
||||||
const int n_as = ggml_get_op_params_i32(dst, 1);
|
|
||||||
|
|
||||||
// group rows by src0 matrix
|
|
||||||
// TODO: allocate in wdata
|
|
||||||
#define MMID_MAX_BATCH 512
|
|
||||||
int matrix_row_counts[GGML_MAX_SRC-2] = {0}; // number of rows for each matrix
|
|
||||||
int matrix_rows[GGML_MAX_SRC-2][MMID_MAX_BATCH]; // row indices for each matrix
|
|
||||||
|
|
||||||
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
|
|
||||||
const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]);
|
|
||||||
|
|
||||||
GGML_ASSERT(row_id >= 0 && row_id < n_as);
|
|
||||||
|
|
||||||
matrix_rows[row_id][matrix_row_counts[row_id]] = i01;
|
|
||||||
matrix_row_counts[row_id] += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// compute each matrix multiplication in sequence
|
// compute each matrix multiplication in sequence
|
||||||
for (int cur_a = 0; cur_a < n_as; ++cur_a) {
|
for (int cur_a = 0; cur_a < n_as; ++cur_a) {
|
||||||
const int64_t cne1 = matrix_row_counts[cur_a];
|
const int64_t cne1 = matrix_row_counts[cur_a];
|
||||||
|
@ -9934,10 +9942,10 @@ static void ggml_compute_forward_mul_mat_id(
|
||||||
for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
|
for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
|
||||||
for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
|
for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
|
||||||
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
|
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
|
||||||
const int64_t i13 = (ir1/(ne12*cne1)); // TODO: remove, src1 is always a matrix
|
const int64_t i13 = (ir1/(ne12*cne1)); // Note: currently, src1 is always a matrix
|
||||||
const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1;
|
const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1;
|
||||||
const int64_t _i11 = (ir1 - i13*ne12*cne1 - i12*cne1);
|
const int64_t _i11 = (ir1 - i13*ne12*cne1 - i12*cne1);
|
||||||
const int64_t i11 = matrix_rows[cur_a][_i11];
|
const int64_t i11 = MMID_MATRIX_ROW(cur_a, _i11);
|
||||||
|
|
||||||
// broadcast src0 into src1
|
// broadcast src0 into src1
|
||||||
const int64_t i03 = i13/r3;
|
const int64_t i03 = i13/r3;
|
||||||
|
@ -9972,6 +9980,8 @@ static void ggml_compute_forward_mul_mat_id(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#undef MMID_MATRIX_ROW
|
||||||
}
|
}
|
||||||
|
|
||||||
// ggml_compute_forward_out_prod
|
// ggml_compute_forward_out_prod
|
||||||
|
@ -16139,7 +16149,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_MUL_MAT_ID:
|
case GGML_OP_MUL_MAT_ID:
|
||||||
{
|
{
|
||||||
// FIXME: blas
|
|
||||||
n_tasks = n_threads;
|
n_tasks = n_threads;
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_OUT_PROD:
|
case GGML_OP_OUT_PROD:
|
||||||
|
@ -16473,20 +16482,16 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_MUL_MAT_ID:
|
case GGML_OP_MUL_MAT_ID:
|
||||||
{
|
{
|
||||||
const struct ggml_tensor * a = node->src[2];
|
const struct ggml_tensor * src0 = node->src[2];
|
||||||
const struct ggml_tensor * b = node->src[1];
|
const struct ggml_tensor * src1 = node->src[1];
|
||||||
const enum ggml_type vec_dot_type = type_traits[a->type].vec_dot_type;
|
const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type;
|
||||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
if (src1->type != vec_dot_type) {
|
||||||
if (ggml_compute_forward_mul_mat_use_blas(a, b, node)) {
|
cur = ggml_row_size(vec_dot_type, ggml_nelements(src1));
|
||||||
if (a->type != GGML_TYPE_F32) {
|
|
||||||
// here we need memory just for single 2D matrix from src0
|
|
||||||
cur = ggml_type_size(GGML_TYPE_F32)*(a->ne[0]*a->ne[1]);
|
|
||||||
}
|
|
||||||
} else
|
|
||||||
#endif
|
|
||||||
if (b->type != vec_dot_type) {
|
|
||||||
cur = ggml_row_size(vec_dot_type, ggml_nelements(b));
|
|
||||||
}
|
}
|
||||||
|
const int n_as = ggml_get_op_params_i32(node, 1);
|
||||||
|
cur = GGML_PAD(cur, sizeof(int64_t)); // align
|
||||||
|
cur += n_as * sizeof(int64_t); // matrix_row_counts
|
||||||
|
cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_OUT_PROD:
|
case GGML_OP_OUT_PROD:
|
||||||
{
|
{
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue