ggml : fix mul_mat_id work size

This commit is contained in:
slaren 2024-01-08 03:51:15 +01:00
parent 5e879c9977
commit ac145fd2e3

5
ggml.c
View file

@ -16646,14 +16646,15 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
} break; } break;
case GGML_OP_MUL_MAT_ID: case GGML_OP_MUL_MAT_ID:
{ {
cur = 0;
const struct ggml_tensor * src0 = node->src[2]; const struct ggml_tensor * src0 = node->src[2];
const struct ggml_tensor * src1 = node->src[1]; const struct ggml_tensor * src1 = node->src[1];
const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type; const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type;
if (src1->type != vec_dot_type) { if (src1->type != vec_dot_type) {
cur = ggml_row_size(vec_dot_type, ggml_nelements(src1)); cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
} }
const int n_as = ggml_get_op_params_i32(node, 1); const int n_as = ggml_get_op_params_i32(node, 1);
cur = GGML_PAD(cur, sizeof(int64_t)); // align cur += GGML_PAD(cur, sizeof(int64_t)); // align
cur += n_as * sizeof(int64_t); // matrix_row_counts cur += n_as * sizeof(int64_t); // matrix_row_counts
cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows
} break; } break;