ggml : fix YARN + add tests + add asserts (#7617)

* tests : add rope tests

ggml-ci

* ggml : fixes (hopefully)

ggml-ci

* tests : add non-cont tests

ggml-ci

* cuda : add asserts for rope/norm + fix DS2

ggml-ci

* ggml : assert contiguousness

* tests : reduce RoPE tests

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-05-29 20:17:31 +03:00 committed by GitHub
parent cce3dcffc5
commit fb76ec31a9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 167 additions and 104 deletions

View file

@ -1519,7 +1519,9 @@ static enum ggml_status ggml_metal_graph_compute(
{
GGML_ASSERT(ne00 == ne10);
// TODO: assert that dim2 and dim3 are contiguous
ggml_is_contiguous_2(src0);
ggml_is_contiguous_2(src1);
GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);
@ -2187,6 +2189,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_OP_RMS_NORM:
{
GGML_ASSERT(ne00 % 4 == 0);
GGML_ASSERT(ggml_is_contiguous_1(src0));
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
@ -2214,6 +2217,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_OP_GROUP_NORM:
{
GGML_ASSERT(ne00 % 4 == 0);
GGML_ASSERT(ggml_is_contiguous(src0));
//float eps;
//memcpy(&eps, dst->op_params, sizeof(float));
@ -2247,6 +2251,8 @@ static enum ggml_status ggml_metal_graph_compute(
} break;
case GGML_OP_NORM:
{
GGML_ASSERT(ggml_is_contiguous_1(src0));
float eps;
memcpy(&eps, dst->op_params, sizeof(float));