add some new ops, fix some operators, and add batch support to certain operators (ggml/747)

* cuda: fix group_norm

* cuda: add batch inference support for ggml_pad/ggml_upscale

* add ggml_arange

* add ggml_timestep_embedding (a usage sketch for the two new ops follows the commit message)

* update ggml_arange/ggml_timestep_embedding tests

* cuda: fix im2col

* add ggml_arange/ggml_timestep_embedding support for the Metal backend

* fix some bugs

* fix some bugs

* Update ggml.h

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml-cuda.cu

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml-metal.m

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml-metal.m

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml-metal.metal

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* modify according to the review comments

* ggml : fix compile warnings + code style

* ggml : normalize compute_forward calls + fix seg fault in debug

* minor

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: slaren <slarengh@gmail.com>
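
For context, a minimal sketch of how the two new ops are called through the public ggml API. The ggml_arange/ggml_timestep_embedding signatures below match the ggml.h additions from this commit; everything else (buffer size, shapes, thread count) is illustrative only:

    #include "ggml.h"

    // build and run a tiny graph that exercises the two new ops (CPU backend)
    static void demo_new_ops(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,  // illustrative scratch size
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        // 1-D F32 tensor holding [0, 1, 2, ..., 999]
        struct ggml_tensor * steps = ggml_arange(ctx, 0.0f, 1000.0f, 1.0f);

        // sinusoidal embedding: one row of length dim per timestep
        struct ggml_tensor * emb = ggml_timestep_embedding(ctx, steps, /*dim =*/ 320, /*max_period =*/ 10000);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, emb);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 4);

        // emb->data now holds the embeddings; CUDA/Metal run the same ops through their backends
        ggml_free(ctx);
    }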
commit 7d43c585dc (parent 82f3e668ad)
Author: leejet, 2024-03-03 20:23:52 +08:00
Committed by: Georgi Gerganov
Signed with GPG key ID BF970631944C16B7 (no matching public key in database)
6 changed files with 550 additions and 52 deletions
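
For reference before the Metal changes below: the new timestep-embedding op follows the standard sinusoidal formulation used by diffusion models. A plain C sketch of the per-timestep computation (illustrative reference only, not the actual ggml kernel):

    #include <math.h>

    // reference for a single timestep t, writing dst[0 .. dim-1]; dim assumed even here
    // (ggml handles odd dim by padding with one extra zero element)
    static void timestep_embedding_ref(float t, float * dst, int dim, int max_period) {
        const int half = dim / 2;
        for (int i = 0; i < half; ++i) {
            const float freq = expf(-logf((float) max_period) * i / half);  // geometric frequency ladder
            dst[i]        = cosf(t * freq);  // first half: cosines
            dst[half + i] = sinf(t * freq);  // second half: sines
        }
    }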


@@ -163,6 +163,8 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
GGML_METAL_KERNEL_TYPE_PAD_F32,
GGML_METAL_KERNEL_TYPE_ARANGE_F32,
GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
@@ -569,6 +571,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
@@ -697,6 +701,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
return false;
case GGML_OP_UPSCALE:
case GGML_OP_PAD:
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_ARGSORT:
case GGML_OP_LEAKY_RELU:
return true;
@@ -1091,7 +1097,8 @@ static bool ggml_metal_graph_compute(
{
GGML_ASSERT(ggml_is_contiguous(src0));
const float scale = *(const float *) dst->op_params;
float scale;
memcpy(&scale, dst->op_params, sizeof(scale));
int64_t n = ggml_nelements(dst);
@@ -1250,11 +1257,15 @@ static bool ggml_metal_graph_compute(
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
}
const float scale = ((float *) dst->op_params)[0];
const float max_bias = ((float *) dst->op_params)[1];
float scale;
float max_bias;
memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
const int64_t nrows_x = ggml_nrows(src0);
const int64_t nrows_y = src0->ne[1];
const uint32_t n_head_kv = nrows_x/nrows_y;
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
@@ -2086,6 +2097,7 @@ static bool ggml_metal_graph_compute(
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_head = ((int32_t *) dst->op_params)[1];
float max_bias;
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
@@ -2300,6 +2312,50 @@ static bool ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_ARANGE:
{
GGML_ASSERT(dst->type == GGML_TYPE_F32);
float start;
float step;
memcpy(&start, ((int32_t *) dst->op_params) + 0, sizeof(float));
memcpy(&step, ((int32_t *) dst->op_params) + 2, sizeof(float));
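// (not shown in this diff) ggml_arange stores start, stop and step as three consecutive floats in op_params;
// stop (index 1) is not read here because ne0 already encodes the number of elements to generate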
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_dst offset:offs_dst atIndex:0];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1];
[encoder setBytes:&start length:sizeof(start) atIndex:2];
[encoder setBytes:&step length:sizeof(step) atIndex:3];
const int nth = MIN(1024, ne0);
[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_TIMESTEP_EMBEDDING:
{
GGML_ASSERT(src0->type == GGML_TYPE_F32);
const int dim = dst->op_params[0];
const int max_period = dst->op_params[1];
const int half = dim / 2;
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2];
[encoder setBytes:&dim length:sizeof(dim) atIndex:3];
[encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
const int nth = MIN(1024, half);
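// dispatch: one threadgroup per timestep (ne00 of them), with up to 1024 threads striding over the 'half' frequencies of the embedding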
[encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_ARGSORT:
{
GGML_ASSERT(src0->type == GGML_TYPE_F32);