This commit is contained in:
slaren 2024-04-02 16:08:55 +02:00
parent 6875369909
commit 6f33852f3d
2 changed files with 7 additions and 17 deletions

View file

@ -2209,7 +2209,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
fprintf(stderr, "%s: %s failed\n", __func__, ggml_op_desc(dst));
GGML_ASSERT(false);
CUDA_CHECK(err);
}
return true;

22
ggml.c
View file

@ -4573,6 +4573,8 @@ void ggml_mul_mat_set_prec(
// ggml_mul_mat_id
// NOTE: id will be removed in the future and instead all the experts listed in ids will be computed
// this will allow computing all the used experts in a single matrix multiplication
struct ggml_tensor * ggml_mul_mat_id(
struct ggml_context * ctx,
struct ggml_tensor * as,
@ -4581,12 +4583,11 @@ struct ggml_tensor * ggml_mul_mat_id(
struct ggml_tensor * b) {
GGML_ASSERT(ids->type == GGML_TYPE_I32);
GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1);
GGML_ASSERT(ids->ne[1] == b->ne[1]);
GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d
GGML_ASSERT(ids->ne[1] == b->ne[1]); // must have an expert per b row
GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]);
//GGML_ASSERT(n_as > 0 && n_as <= GGML_MAX_SRC - 2);
GGML_ASSERT(id >= 0 && id < ids->ne[0]);
// TODO: restore checks
GGML_ASSERT(id >= 0 && id < ids->ne[0]); // valid id
GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat
bool is_node = false;
@ -4605,14 +4606,6 @@ struct ggml_tensor * ggml_mul_mat_id(
result->src[1] = b;
result->src[2] = ids;
//for (int i = 0; i < n_as; i++) {
// struct ggml_tensor * a = as[i];
// GGML_ASSERT(ggml_are_same_shape(as[0], a));
// GGML_ASSERT(ggml_can_mul_mat(a, b));
// GGML_ASSERT(!ggml_is_transposed(a));
// result->src[i + 2] = a;
//}
return result;
}
@ -10980,9 +10973,6 @@ static void ggml_compute_forward_mul_mat_id(
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);
// broadcast factors
//const int64_t r2 = ne12/ne02;
//const int64_t r3 = ne13/ne03;
// broadcast is not supported with mmid
assert(ne12 == 1);
assert(ne13 == 1);