ggml : fix mul_mat_id work size
This commit is contained in:
parent
5e879c9977
commit
ac145fd2e3
1 changed files with 3 additions and 2 deletions
5
ggml.c
5
ggml.c
|
@ -16646,14 +16646,15 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_MUL_MAT_ID:
|
case GGML_OP_MUL_MAT_ID:
|
||||||
{
|
{
|
||||||
|
cur = 0;
|
||||||
const struct ggml_tensor * src0 = node->src[2];
|
const struct ggml_tensor * src0 = node->src[2];
|
||||||
const struct ggml_tensor * src1 = node->src[1];
|
const struct ggml_tensor * src1 = node->src[1];
|
||||||
const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type;
|
const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type;
|
||||||
if (src1->type != vec_dot_type) {
|
if (src1->type != vec_dot_type) {
|
||||||
cur = ggml_row_size(vec_dot_type, ggml_nelements(src1));
|
cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
|
||||||
}
|
}
|
||||||
const int n_as = ggml_get_op_params_i32(node, 1);
|
const int n_as = ggml_get_op_params_i32(node, 1);
|
||||||
cur = GGML_PAD(cur, sizeof(int64_t)); // align
|
cur += GGML_PAD(cur, sizeof(int64_t)); // align
|
||||||
cur += n_as * sizeof(int64_t); // matrix_row_counts
|
cur += n_as * sizeof(int64_t); // matrix_row_counts
|
||||||
cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows
|
cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows
|
||||||
} break;
|
} break;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue