add GGML_ASSERT to catch ggml_rope and back value errors

This commit is contained in:
xaedes 2023-05-07 01:47:14 +02:00
parent 561fbe0d1b
commit e91b83b899
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

20
ggml.c
View file

@ -11430,8 +11430,8 @@ static void ggml_compute_forward_rope_f32(
const struct ggml_tensor * src0, const struct ggml_tensor * src0,
const struct ggml_tensor * src1, const struct ggml_tensor * src1,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
assert(src1->type == GGML_TYPE_I32); GGML_ASSERT(src1->type == GGML_TYPE_I32);
assert(ggml_nelements(src1) == 3); GGML_ASSERT(ggml_nelements(src1) == 3);
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return; return;
@ -11454,12 +11454,16 @@ static void ggml_compute_forward_rope_f32(
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
//printf("n_past = %d, ne2 = %d\n", n_past, ne2); //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
assert(nb0 == sizeof(float)); GGML_ASSERT(nb0 == sizeof(float));
const int ith = params->ith; const int ith = params->ith;
const int nth = params->nth; const int nth = params->nth;
const int nr = ggml_nrows(src0); const int nr = ggml_nrows(src0);
const int nc = src0->ne[0];
GGML_ASSERT(n_dims <= nc);
GGML_ASSERT(n_dims % 2 == 0);
// rows per thread // rows per thread
const int dr = (nr + nth - 1)/nth; const int dr = (nr + nth - 1)/nth;
@ -11520,8 +11524,8 @@ static void ggml_compute_forward_rope_f16(
const struct ggml_tensor * src0, const struct ggml_tensor * src0,
const struct ggml_tensor * src1, const struct ggml_tensor * src1,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
assert(src1->type == GGML_TYPE_I32); GGML_ASSERT(src1->type == GGML_TYPE_I32);
assert(ggml_nelements(src1) == 3); GGML_ASSERT(ggml_nelements(src1) == 3);
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return; return;
@ -11544,12 +11548,16 @@ static void ggml_compute_forward_rope_f16(
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
//printf("n_past = %d, ne2 = %d\n", n_past, ne2); //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
assert(nb0 == sizeof(ggml_fp16_t)); GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
const int ith = params->ith; const int ith = params->ith;
const int nth = params->nth; const int nth = params->nth;
const int nr = ggml_nrows(src0); const int nr = ggml_nrows(src0);
const int nc = src0->ne[0];
GGML_ASSERT(n_dims <= nc);
GGML_ASSERT(n_dims % 2 == 0);
// rows per thread // rows per thread
const int dr = (nr + nth - 1)/nth; const int dr = (nr + nth - 1)/nth;