ggml : assert contiguousness

Georgi Gerganov 2024-05-29 18:55:10 +03:00
parent 5db268c9d8
commit a69935baac
6 changed files with 37 additions and 17 deletions


@@ -1870,7 +1870,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         }
     }
 #else
-    if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
+    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
         CUBLAS_CHECK(
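For reference, the manual stride comparison removed here is folded into the new helper defined in the ggml.c hunk further down. Condensed (the static_assert omitted), it reads roughly as follows; note that, unlike the old inline check, it also requires a densely packed dim 0:

```c
// Condensed from the ggml.c hunk below: "contiguous for dims >= 2",
// i.e. dim 0 is packed and there is no gap between dims 2 and 3.
GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) {
    return
        tensor->nb[0] == ggml_type_size(tensor->type) &&
        tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
}
```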


@@ -1597,7 +1597,9 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
             {
                 GGML_ASSERT(ne00 == ne10);
 
-                // TODO: assert that dim2 and dim3 are contiguous
+                ggml_is_contiguous_2(src0);
+                ggml_is_contiguous_2(src1);
+
                 GGML_ASSERT(ne12 % ne02 == 0);
                 GGML_ASSERT(ne13 % ne03 == 0);


@@ -1519,7 +1519,9 @@ static enum ggml_status ggml_metal_graph_compute(
             {
                 GGML_ASSERT(ne00 == ne10);
 
-                // TODO: assert that dim2 and dim3 are contiguous
+                ggml_is_contiguous_2(src0);
+                ggml_is_contiguous_2(src1);
+
                 GGML_ASSERT(ne12 % ne02 == 0);
                 GGML_ASSERT(ne13 % ne03 == 0);
@@ -2187,7 +2189,7 @@ static enum ggml_status ggml_metal_graph_compute(
         case GGML_OP_RMS_NORM:
             {
                 GGML_ASSERT(ne00 % 4 == 0);
-                GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: only requires contiguous dim 1, 2, 3
+                GGML_ASSERT(ggml_is_contiguous_1(src0));
 
                 float eps;
                 memcpy(&eps, dst->op_params, sizeof(float));
@@ -2249,7 +2251,7 @@ static enum ggml_status ggml_metal_graph_compute(
             } break;
         case GGML_OP_NORM:
             {
-                GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: only requires contiguous dim 1, 2, 3
+                GGML_ASSERT(ggml_is_contiguous_1(src0));
 
                 float eps;
                 memcpy(&eps, dst->op_params, sizeof(float));


@@ -15183,7 +15183,7 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;
 
-    if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
+    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
             *g_sycl_handles[g_main_device], oneapi::mkl::transpose::trans,

ggml.c

@@ -3221,7 +3221,11 @@ GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
 
-static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * tensor) {
+GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor) {
+    return ggml_is_contiguous(tensor);
+}
+
+GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return
@@ -3230,6 +3234,14 @@ static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * te
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
 
+GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return
+        tensor->nb[0] == ggml_type_size(tensor->type) &&
+        tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
+}
+
 GGML_CALL bool ggml_is_permuted(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
@@ -11420,8 +11432,8 @@ static void ggml_compute_forward_gelu_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11483,8 +11495,8 @@ static void ggml_compute_forward_gelu_quick_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11546,8 +11558,8 @@ static void ggml_compute_forward_silu_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11658,9 +11670,9 @@ static void ggml_compute_forward_silu_back_f32(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * grad = dst->src[1];
 
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(grad));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+    GGML_ASSERT(ggml_is_contiguous_1(grad));
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_are_same_shape(src0, grad));
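As a quick illustration of what the three predicates report, here is a hypothetical standalone check (not part of this commit); the tensor names and sizes are arbitrary:

```c
// Hypothetical example: compare ggml_is_contiguous / _1 / _2 on a packed
// tensor, a row-padded view, and a transposed view.
#include <stdio.h>
#include "ggml.h"

static void report(const char * name, const struct ggml_tensor * t) {
    printf("%-6s contiguous=%d contiguous_1=%d contiguous_2=%d\n",
           name, ggml_is_contiguous(t), ggml_is_contiguous_1(t), ggml_is_contiguous_2(t));
}

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // densely packed 8x4 f32 tensor: every predicate should hold
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 4);

    // 4x4 view that keeps the parent's row stride: rows are padded, so full
    // contiguity should fail while dims >= 1 (and >= 2) stay contiguous
    struct ggml_tensor * v = ggml_view_2d(ctx, a, 4, 4, a->nb[1], 0);

    // transposed view: dim 0 is no longer densely packed, so all three fail
    struct ggml_tensor * t = ggml_transpose(ctx, a);

    report("a", a);
    report("view", v);
    report("t", t);

    ggml_free(ctx);
    return 0;
}
```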

ggml.h

@@ -756,7 +756,6 @@ extern "C" {
     GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype);
 
     GGML_API GGML_CALL bool ggml_is_transposed(const struct ggml_tensor * tensor);
-    GGML_API GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor);
     GGML_API GGML_CALL bool ggml_is_permuted (const struct ggml_tensor * tensor);
     GGML_API GGML_CALL bool ggml_is_empty    (const struct ggml_tensor * tensor);
     GGML_API           bool ggml_is_scalar   (const struct ggml_tensor * tensor);
@@ -765,6 +764,11 @@ extern "C" {
     GGML_API           bool ggml_is_3d       (const struct ggml_tensor * tensor);
     GGML_API           int  ggml_n_dims      (const struct ggml_tensor * tensor); // returns 1 for scalars
 
+    GGML_API GGML_CALL bool ggml_is_contiguous  (const struct ggml_tensor * tensor);
+    GGML_API GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor); // same as ggml_is_contiguous()
+    GGML_API GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1
+    GGML_API GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2
+
     GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
     GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
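The header comments summarize the intent of the new levels; in op code the pattern this commit switches to is simply asserting the weakest level a kernel actually needs. A minimal sketch, with a hypothetical op name:

```c
// Sketch only (hypothetical op): a row-wise kernel tolerates padded rows,
// so it asserts ggml_is_contiguous_1() rather than full contiguity.
static void ggml_compute_forward_my_unary_f32(const struct ggml_tensor * src0, struct ggml_tensor * dst) {
    GGML_ASSERT(ggml_is_contiguous_1(src0));
    GGML_ASSERT(ggml_is_contiguous_1(dst));
    GGML_ASSERT(ggml_are_same_shape(src0, dst));
    // ... process ne[0]-element rows, stepping by nb[1]/nb[2]/nb[3] ...
}
```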