ggml : assert contiguousness
parent 5db268c9d8
commit a69935baac

6 changed files with 37 additions and 17 deletions
@@ -1870,7 +1870,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         }
     }
 #else
-    if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
+    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
         CUBLAS_CHECK(
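For orientation (not part of the diff): in ggml, ne[i] is the element count and nb[i] the byte stride of dimension i. A minimal sketch of what the replaced condition tested, and what the helper adds on top of it:

#include "ggml.h"
#include <stdbool.h>

// What the removed inline check verified: dims 2 and 3 are densely packed,
// i.e. stepping past all of dim 2 lands exactly on the next dim-3 slice,
// so all matrices of the batch sit at a fixed stride from each other.
static bool dims_2_3_packed(const struct ggml_tensor * t) {
    return t->nb[2]*t->ne[2] == t->nb[3];
}

// ggml_is_contiguous_2() checks the same nb[3] relation and additionally
// requires nb[0] == ggml_type_size(t->type), i.e. dim 0 is not permuted.
// Per the comment in the hunk, this fixed-stride layout is what allows a
// single cublasGemmStridedBatchedEx call instead of a pointer-array GEMM.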
@@ -1597,7 +1597,9 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                 {
                     GGML_ASSERT(ne00 == ne10);

-                    // TODO: assert that dim2 and dim3 are contiguous
+                    ggml_is_contiguous_2(src0);
+                    ggml_is_contiguous_2(src1);
+
                     GGML_ASSERT(ne12 % ne02 == 0);
                     GGML_ASSERT(ne13 % ne03 == 0);

@@ -1519,7 +1519,9 @@ static enum ggml_status ggml_metal_graph_compute(
                 {
                     GGML_ASSERT(ne00 == ne10);

-                    // TODO: assert that dim2 and dim3 are contiguous
+                    ggml_is_contiguous_2(src0);
+                    ggml_is_contiguous_2(src1);
+
                     GGML_ASSERT(ne12 % ne02 == 0);
                     GGML_ASSERT(ne13 % ne03 == 0);

@@ -2187,7 +2189,7 @@ static enum ggml_status ggml_metal_graph_compute(
             case GGML_OP_RMS_NORM:
                 {
                     GGML_ASSERT(ne00 % 4 == 0);
-                    GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: only requires contiguous dim 1, 2, 3
+                    GGML_ASSERT(ggml_is_contiguous_1(src0));

                     float eps;
                     memcpy(&eps, dst->op_params, sizeof(float));
@@ -2249,7 +2251,7 @@ static enum ggml_status ggml_metal_graph_compute(
                 } break;
             case GGML_OP_NORM:
                 {
-                    GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: only requires contiguous dim 1, 2, 3
+                    GGML_ASSERT(ggml_is_contiguous_1(src0));

                     float eps;
                     memcpy(&eps, dst->op_params, sizeof(float));
@@ -15183,7 +15183,7 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;

-    if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
+    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
        // there is no broadcast and src0, src1 are contiguous across dims 2, 3
        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
            *g_sycl_handles[g_main_device], oneapi::mkl::transpose::trans,
ggml.c (32 changed lines)
@@ -3221,7 +3221,11 @@ GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }

-static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * tensor) {
+GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor) {
+    return ggml_is_contiguous(tensor);
+}
+
+GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");

     return
@@ -3230,6 +3234,14 @@ static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * te
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }

+GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return
+        tensor->nb[0] == ggml_type_size(tensor->type) &&
+        tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
+}
+
 GGML_CALL bool ggml_is_permuted(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");

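A usage sketch (not part of the commit; assumes only the public ggml API at this revision): a row-sliced view is precisely the case that passes the looser predicates while failing full contiguity.

#include "ggml.h"
#include <stdio.h>

int main(void) {
    struct ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ false };
    struct ggml_context * ctx = ggml_init(params);

    // dense 4x3 f32 tensor: nb = {4, 16, 48, 48}
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 3);

    // view keeping the first 2 elements of each row: rows are no longer
    // packed (nb[1] stays 16 while nb[0]*ne[0] is 8), but dims >= 1 and
    // dims >= 2 still line up
    struct ggml_tensor * v = ggml_view_2d(ctx, a, 2, 3, a->nb[1], 0);

    printf("a: %d %d %d\n", ggml_is_contiguous(a), ggml_is_contiguous_1(a), ggml_is_contiguous_2(a)); // 1 1 1
    printf("v: %d %d %d\n", ggml_is_contiguous(v), ggml_is_contiguous_1(v), ggml_is_contiguous_2(v)); // 0 1 1

    ggml_free(ctx);
    return 0;
}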
@@ -11420,8 +11432,8 @@ static void ggml_compute_forward_gelu_f32(

     const struct ggml_tensor * src0 = dst->src[0];

-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));

     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11483,8 +11495,8 @@ static void ggml_compute_forward_gelu_quick_f32(

     const struct ggml_tensor * src0 = dst->src[0];

-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));

     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11546,8 +11558,8 @@ static void ggml_compute_forward_silu_f32(

     const struct ggml_tensor * src0 = dst->src[0];

-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));

     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11658,9 +11670,9 @@ static void ggml_compute_forward_silu_back_f32(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * grad = dst->src[1];

-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(grad));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+    GGML_ASSERT(ggml_is_contiguous_1(grad));
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_are_same_shape(src0, grad));

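The relaxation above is safe because these kernels only ever walk rows through nb[1]. A simplified sketch of that access pattern (modeled on the shape of ggml_compute_forward_gelu_f32; the helper name and the inlined tanh-based GELU are illustrative):

#include "ggml.h"
#include <math.h>

// Row-wise traversal: each row is a dense float array (nb[0] == sizeof(float)),
// and indexing row i as i*nb[1] covers dims 1..3 only when
// nb[2] == nb[1]*ne[1] and nb[3] == nb[2]*ne[2] -- exactly what
// ggml_is_contiguous_1() asserts. Full contiguity (packed nb[1]) is not needed.
static void gelu_rows(struct ggml_tensor * dst, const struct ggml_tensor * src0) {
    const int     nc = (int) src0->ne[0];  // elements per row
    const int64_t nr = ggml_nrows(src0);   // total rows across dims 1..3

    for (int64_t i = 0; i < nr; i++) {
        const float * x = (const float *)((const char *) src0->data + i*src0->nb[1]);
        float       * y = (float       *)((char       *)  dst->data + i* dst->nb[1]);
        for (int j = 0; j < nc; j++) {
            y[j] = 0.5f*x[j]*(1.0f + tanhf(0.79788456f*(x[j] + 0.044715f*x[j]*x[j]*x[j])));
        }
    }
}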
ggml.h (6 changed lines)
@@ -756,7 +756,6 @@ extern "C" {
     GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype);

     GGML_API GGML_CALL bool ggml_is_transposed(const struct ggml_tensor * tensor);
-    GGML_API GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor);
     GGML_API GGML_CALL bool ggml_is_permuted (const struct ggml_tensor * tensor);
     GGML_API GGML_CALL bool ggml_is_empty   (const struct ggml_tensor * tensor);
     GGML_API           bool ggml_is_scalar  (const struct ggml_tensor * tensor);
@@ -765,6 +764,11 @@ extern "C" {
     GGML_API           bool ggml_is_3d      (const struct ggml_tensor * tensor);
     GGML_API           int  ggml_n_dims     (const struct ggml_tensor * tensor); // returns 1 for scalars

+    GGML_API GGML_CALL bool ggml_is_contiguous  (const struct ggml_tensor * tensor);
+    GGML_API GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor); // same as ggml_is_contiguous()
+    GGML_API GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1
+    GGML_API GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2
+
     GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
     GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);

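As a closing usage note (hypothetical helper, not in the tree): the suffix marks the lowest dimension that must still be densely packed, so a caller picks the weakest predicate its kernel needs:

#include "ggml.h"
#include <stdbool.h>

// Illustrative only:
//   whole-tensor copies           -> ggml_is_contiguous()   (== ggml_is_contiguous_0())
//   row-wise kernels (gelu, silu) -> ggml_is_contiguous_1()
//   2D-slice batched GEMM         -> ggml_is_contiguous_2()
static bool can_use_strided_batched_gemm(const struct ggml_tensor * src0,
                                         const struct ggml_tensor * src1) {
    return ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1);
}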