ggml : add n_as argument to ggml_mul_mat_id
This commit is contained in:
parent
7372b62271
commit
ee8fb399aa
6 changed files with 17 additions and 14 deletions
14
ggml.c
14
ggml.c
|
@ -4076,12 +4076,11 @@ struct ggml_tensor * ggml_mul_mat(
|
|||
struct ggml_tensor * ggml_mul_mat_id(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * as[],
|
||||
int n_as,
|
||||
struct ggml_tensor * ids,
|
||||
int id,
|
||||
struct ggml_tensor * b) {
|
||||
|
||||
int64_t n_as = ids->ne[0];
|
||||
|
||||
GGML_ASSERT(ids->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1);
|
||||
GGML_ASSERT(ids->ne[1] == b->ne[1]);
|
||||
|
@ -4099,6 +4098,7 @@ struct ggml_tensor * ggml_mul_mat_id(
|
|||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(as[0]->n_dims, b->n_dims), ne);
|
||||
|
||||
ggml_set_op_params_i32(result, 0, id);
|
||||
ggml_set_op_params_i32(result, 1, n_as);
|
||||
|
||||
result->op = GGML_OP_MUL_MAT_ID;
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
|
@ -4106,8 +4106,7 @@ struct ggml_tensor * ggml_mul_mat_id(
|
|||
result->src[1] = b;
|
||||
|
||||
// TODO: n_as is the selected experts, but it should be the total number of experts
|
||||
//for (int64_t i = 0; i < n_as; i++) {
|
||||
for (int64_t i = 0; i < 8; i++) {
|
||||
for (int i = 0; i < n_as; i++) {
|
||||
struct ggml_tensor * a = as[i];
|
||||
GGML_ASSERT(ggml_are_same_shape(as[0], a));
|
||||
GGML_ASSERT(ggml_can_mul_mat(a, b));
|
||||
|
@ -9757,14 +9756,13 @@ static void ggml_compute_forward_mul_mat_id(
|
|||
}
|
||||
|
||||
const struct ggml_tensor * ids = src0;
|
||||
const int id = ggml_get_op_params_i32(dst, 0);
|
||||
const int id = ggml_get_op_params_i32(dst, 0);
|
||||
const int n_as = ggml_get_op_params_i32(dst, 1);
|
||||
|
||||
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
|
||||
const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]);
|
||||
|
||||
// TODO: this assert seems wrong?
|
||||
//printf("row_id = %d, ids->ne[0] = %d, id = %d\n", row_id, ids->ne[0], id);
|
||||
//GGML_ASSERT(row_id >= 0 && row_id < ids->ne[0]);
|
||||
GGML_ASSERT(row_id >= 0 && row_id < n_as);
|
||||
|
||||
const struct ggml_tensor * src0_row = dst->src[row_id + 2];
|
||||
ggml_compute_forward_mul_mat(params, src0_row, src1, dst, i01, 1);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue