From a69935baac062911a83831fc708a7b705f62cfa5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 29 May 2024 18:55:10 +0300 Subject: [PATCH] ggml : assert contiguousness --- ggml-cuda.cu | 2 +- ggml-kompute.cpp | 4 +++- ggml-metal.m | 8 +++++--- ggml-sycl.cpp | 2 +- ggml.c | 32 ++++++++++++++++++++++---------- ggml.h | 6 +++++- 6 files changed, 37 insertions(+), 17 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 1ccf311e5..1172f7b2f 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1870,7 +1870,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co } } #else - if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) { + if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) { // there is no broadcast and src0, src1 are contiguous across dims 2, 3 // use cublasGemmStridedBatchedEx CUBLAS_CHECK( diff --git a/ggml-kompute.cpp b/ggml-kompute.cpp index 6c6058b2a..ed59d2be6 100644 --- a/ggml-kompute.cpp +++ b/ggml-kompute.cpp @@ -1597,7 +1597,9 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml { GGML_ASSERT(ne00 == ne10); - // TODO: assert that dim2 and dim3 are contiguous + ggml_is_contiguous_2(src0); + ggml_is_contiguous_2(src1); + GGML_ASSERT(ne12 % ne02 == 0); GGML_ASSERT(ne13 % ne03 == 0); diff --git a/ggml-metal.m b/ggml-metal.m index c7fc069eb..a7e13bdcf 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1519,7 +1519,9 @@ static enum ggml_status ggml_metal_graph_compute( { GGML_ASSERT(ne00 == ne10); - // TODO: assert that dim2 and dim3 are contiguous + ggml_is_contiguous_2(src0); + ggml_is_contiguous_2(src1); + GGML_ASSERT(ne12 % ne02 == 0); GGML_ASSERT(ne13 % ne03 == 0); @@ -2187,7 +2189,7 @@ static enum ggml_status ggml_metal_graph_compute( case GGML_OP_RMS_NORM: { GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: only requires contiguous dim 1, 2, 3 + GGML_ASSERT(ggml_is_contiguous_1(src0)); float eps; memcpy(&eps, dst->op_params, sizeof(float)); @@ -2249,7 +2251,7 @@ static enum ggml_status ggml_metal_graph_compute( } break; case GGML_OP_NORM: { - GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: only requires contiguous dim 1, 2, 3 + GGML_ASSERT(ggml_is_contiguous_1(src0)); float eps; memcpy(&eps, dst->op_params, sizeof(float)); diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index a73448136..5cd97e4ff 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -15183,7 +15183,7 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0, const int64_t r2 = ne12/ne02; const int64_t r3 = ne13/ne03; - if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) { + if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) { // there is no broadcast and src0, src1 are contiguous across dims 2, 3 SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( *g_sycl_handles[g_main_device], oneapi::mkl::transpose::trans, diff --git a/ggml.c b/ggml.c index 7464ae397..b2b725f65 100644 --- a/ggml.c +++ b/ggml.c @@ -3221,7 +3221,11 @@ GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor) { tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; } -static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * tensor) { +GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor) { + return ggml_is_contiguous(tensor); +} + +GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); return @@ -3230,6 +3234,14 @@ static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * te tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; } +GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return + tensor->nb[0] == ggml_type_size(tensor->type) && + tensor->nb[3] == tensor->nb[2]*tensor->ne[2]; +} + GGML_CALL bool ggml_is_permuted(const struct ggml_tensor * tensor) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); @@ -11420,8 +11432,8 @@ static void ggml_compute_forward_gelu_f32( const struct ggml_tensor * src0 = dst->src[0]; - GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0)); - GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst)); + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); GGML_ASSERT(ggml_are_same_shape(src0, dst)); if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { @@ -11483,8 +11495,8 @@ static void ggml_compute_forward_gelu_quick_f32( const struct ggml_tensor * src0 = dst->src[0]; - GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0)); - GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst)); + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); GGML_ASSERT(ggml_are_same_shape(src0, dst)); if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { @@ -11546,8 +11558,8 @@ static void ggml_compute_forward_silu_f32( const struct ggml_tensor * src0 = dst->src[0]; - GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0)); - GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst)); + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); GGML_ASSERT(ggml_are_same_shape(src0, dst)); if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) { @@ -11658,9 +11670,9 @@ static void ggml_compute_forward_silu_back_f32( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * grad = dst->src[1]; - GGML_ASSERT(ggml_is_contiguous_except_dim_1(grad)); - GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0)); - GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst)); + GGML_ASSERT(ggml_is_contiguous_1(grad)); + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); GGML_ASSERT(ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_are_same_shape(src0, grad)); diff --git a/ggml.h b/ggml.h index f9deac7e8..f38699698 100644 --- a/ggml.h +++ b/ggml.h @@ -756,7 +756,6 @@ extern "C" { GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype); GGML_API GGML_CALL bool ggml_is_transposed(const struct ggml_tensor * tensor); - GGML_API GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor); GGML_API GGML_CALL bool ggml_is_permuted (const struct ggml_tensor * tensor); GGML_API GGML_CALL bool ggml_is_empty (const struct ggml_tensor * tensor); GGML_API bool ggml_is_scalar (const struct ggml_tensor * tensor); @@ -765,6 +764,11 @@ extern "C" { GGML_API bool ggml_is_3d (const struct ggml_tensor * tensor); GGML_API int ggml_n_dims (const struct ggml_tensor * tensor); // returns 1 for scalars + GGML_API GGML_CALL bool ggml_is_contiguous (const struct ggml_tensor * tensor); + GGML_API GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor); // same as ggml_is_contiguous() + GGML_API GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1 + GGML_API GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2 + GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1); GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);