ggml: Add op l2_norm
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
parent
19d3c8293b
commit
5445300758
16 changed files with 489 additions and 2 deletions
|
@ -454,6 +454,7 @@ extern "C" {
|
||||||
GGML_OP_RMS_NORM,
|
GGML_OP_RMS_NORM,
|
||||||
GGML_OP_RMS_NORM_BACK,
|
GGML_OP_RMS_NORM_BACK,
|
||||||
GGML_OP_GROUP_NORM,
|
GGML_OP_GROUP_NORM,
|
||||||
|
GGML_OP_L2_NORM,
|
||||||
|
|
||||||
GGML_OP_MUL_MAT,
|
GGML_OP_MUL_MAT,
|
||||||
GGML_OP_MUL_MAT_ID,
|
GGML_OP_MUL_MAT_ID,
|
||||||
|
@ -1095,6 +1096,18 @@ extern "C" {
|
||||||
int n_groups,
|
int n_groups,
|
||||||
float eps);
|
float eps);
|
||||||
|
|
||||||
|
// l2 normalize along rows
|
||||||
|
// used in rwkv v7
|
||||||
|
GGML_API struct ggml_tensor * ggml_l2_norm(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
float eps);
|
||||||
|
|
||||||
|
GGML_API struct ggml_tensor * ggml_l2_norm_inplace(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
float eps);
|
||||||
|
|
||||||
// a - x
|
// a - x
|
||||||
// b - dy
|
// b - dy
|
||||||
GGML_API struct ggml_tensor * ggml_rms_norm_back(
|
GGML_API struct ggml_tensor * ggml_rms_norm_back(
|
||||||
|
|
|
@ -7333,6 +7333,69 @@ static void ggml_compute_forward_group_norm(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ggml_compute_forward_l2_norm
|
||||||
|
|
||||||
|
static void ggml_compute_forward_l2_norm_f32(
|
||||||
|
const struct ggml_compute_params * params,
|
||||||
|
struct ggml_tensor * dst) {
|
||||||
|
|
||||||
|
const struct ggml_tensor * src0 = dst->src[0];
|
||||||
|
|
||||||
|
GGML_ASSERT(ggml_are_same_shape(src0, dst));
|
||||||
|
|
||||||
|
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||||
|
|
||||||
|
const int ith = params->ith;
|
||||||
|
const int nth = params->nth;
|
||||||
|
|
||||||
|
GGML_TENSOR_UNARY_OP_LOCALS
|
||||||
|
|
||||||
|
float eps;
|
||||||
|
memcpy(&eps, dst->op_params, sizeof(float));
|
||||||
|
|
||||||
|
GGML_ASSERT(eps >= 0.0f);
|
||||||
|
|
||||||
|
// TODO: optimize
|
||||||
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||||
|
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||||
|
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
|
||||||
|
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||||
|
|
||||||
|
ggml_float sum = 0.0;
|
||||||
|
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||||
|
sum += (ggml_float)(x[i00] * x[i00]);
|
||||||
|
}
|
||||||
|
|
||||||
|
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
|
||||||
|
|
||||||
|
memcpy(y, x, ne00 * sizeof(float));
|
||||||
|
|
||||||
|
const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
|
||||||
|
|
||||||
|
ggml_vec_scale_f32(ne00, y, scale);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_compute_forward_l2_norm(
|
||||||
|
const struct ggml_compute_params * params,
|
||||||
|
struct ggml_tensor * dst) {
|
||||||
|
|
||||||
|
const struct ggml_tensor * src0 = dst->src[0];
|
||||||
|
|
||||||
|
switch (src0->type) {
|
||||||
|
case GGML_TYPE_F32:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_l2_norm_f32(params, dst);
|
||||||
|
} break;
|
||||||
|
default:
|
||||||
|
{
|
||||||
|
GGML_ABORT("fatal error");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ggml_compute_forward_mul_mat
|
// ggml_compute_forward_mul_mat
|
||||||
|
|
||||||
static void ggml_compute_forward_mul_mat_one_chunk(
|
static void ggml_compute_forward_mul_mat_one_chunk(
|
||||||
|
@ -12823,6 +12886,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||||
{
|
{
|
||||||
ggml_compute_forward_group_norm(params, tensor);
|
ggml_compute_forward_group_norm(params, tensor);
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_L2_NORM:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_l2_norm(params, tensor);
|
||||||
|
} break;
|
||||||
case GGML_OP_MUL_MAT:
|
case GGML_OP_MUL_MAT:
|
||||||
{
|
{
|
||||||
ggml_compute_forward_mul_mat(params, tensor);
|
ggml_compute_forward_mul_mat(params, tensor);
|
||||||
|
@ -13235,6 +13302,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||||
case GGML_OP_NORM:
|
case GGML_OP_NORM:
|
||||||
case GGML_OP_RMS_NORM:
|
case GGML_OP_RMS_NORM:
|
||||||
case GGML_OP_RMS_NORM_BACK:
|
case GGML_OP_RMS_NORM_BACK:
|
||||||
|
case GGML_OP_L2_NORM:
|
||||||
case GGML_OP_GROUP_NORM:
|
case GGML_OP_GROUP_NORM:
|
||||||
case GGML_OP_CONCAT:
|
case GGML_OP_CONCAT:
|
||||||
case GGML_OP_MUL_MAT:
|
case GGML_OP_MUL_MAT:
|
||||||
|
|
|
@ -2191,6 +2191,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||||
case GGML_OP_GROUP_NORM:
|
case GGML_OP_GROUP_NORM:
|
||||||
ggml_cuda_op_group_norm(ctx, dst);
|
ggml_cuda_op_group_norm(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
case GGML_OP_L2_NORM:
|
||||||
|
ggml_cuda_op_l2_norm(ctx, dst);
|
||||||
|
break;
|
||||||
case GGML_OP_CONCAT:
|
case GGML_OP_CONCAT:
|
||||||
ggml_cuda_op_concat(ctx, dst);
|
ggml_cuda_op_concat(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
@ -3135,6 +3138,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||||
break;
|
break;
|
||||||
case GGML_OP_NORM:
|
case GGML_OP_NORM:
|
||||||
case GGML_OP_RMS_NORM:
|
case GGML_OP_RMS_NORM:
|
||||||
|
case GGML_OP_L2_NORM:
|
||||||
return true;
|
return true;
|
||||||
case GGML_OP_RMS_NORM_BACK:
|
case GGML_OP_RMS_NORM_BACK:
|
||||||
return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
|
return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
|
||||||
|
|
|
@ -201,6 +201,40 @@ static __global__ void rms_norm_back_f32(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <int block_size>
|
||||||
|
static __global__ void l2_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
|
||||||
|
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||||
|
const int tid = threadIdx.x;
|
||||||
|
|
||||||
|
float tmp = 0.0f; // partial sum for thread in warp
|
||||||
|
|
||||||
|
for (int col = tid; col < ncols; col += block_size) {
|
||||||
|
const float xi = x[row*ncols + col];
|
||||||
|
tmp += xi * xi;
|
||||||
|
}
|
||||||
|
|
||||||
|
// sum up partial sums
|
||||||
|
tmp = warp_reduce_sum(tmp);
|
||||||
|
if (block_size > WARP_SIZE) {
|
||||||
|
__shared__ float s_sum[32];
|
||||||
|
int warp_id = threadIdx.x / WARP_SIZE;
|
||||||
|
int lane_id = threadIdx.x % WARP_SIZE;
|
||||||
|
if (lane_id == 0) {
|
||||||
|
s_sum[warp_id] = tmp;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
tmp = s_sum[lane_id];
|
||||||
|
tmp = warp_reduce_sum(tmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
// from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
|
||||||
|
const float scale = rsqrtf(fmaxf(tmp, eps * eps));
|
||||||
|
|
||||||
|
for (int col = tid; col < ncols; col += block_size) {
|
||||||
|
dst[row*ncols + col] = scale * x[row*ncols + col];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static void norm_f32_cuda(
|
static void norm_f32_cuda(
|
||||||
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
|
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
|
||||||
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
|
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
|
||||||
|
@ -248,6 +282,17 @@ static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float *
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void l2_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
|
||||||
|
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||||
|
if (ncols < 1024) {
|
||||||
|
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||||
|
l2_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
|
||||||
|
} else {
|
||||||
|
const dim3 block_dims(1024, 1, 1);
|
||||||
|
l2_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
const ggml_tensor * src0 = dst->src[0];
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
const float * src0_d = (const float *) src0->data;
|
const float * src0_d = (const float *) src0->data;
|
||||||
|
@ -340,3 +385,18 @@ void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * d
|
||||||
|
|
||||||
rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
|
rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
|
const float * src0_d = (const float *)src0->data;
|
||||||
|
float * dst_d = (float *)dst->data;
|
||||||
|
cudaStream_t stream = ctx.stream();
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||||
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||||
|
const int64_t ne00 = src0->ne[0];
|
||||||
|
const int64_t nrows = ggml_nrows(src0);
|
||||||
|
float eps;
|
||||||
|
memcpy(&eps, dst->op_params, sizeof(float));
|
||||||
|
l2_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
|
||||||
|
}
|
||||||
|
|
|
@ -7,3 +7,5 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
||||||
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
|
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
|
@ -285,4 +285,11 @@ typedef struct {
|
||||||
float eps;
|
float eps;
|
||||||
} ggml_metal_kargs_rms_norm;
|
} ggml_metal_kargs_rms_norm;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
int32_t ne00;
|
||||||
|
int32_t ne00_4;
|
||||||
|
uint64_t nb01;
|
||||||
|
float eps;
|
||||||
|
} ggml_metal_kargs_l2_norm;
|
||||||
|
|
||||||
#endif // GGML_METAL_IMPL
|
#endif // GGML_METAL_IMPL
|
||||||
|
|
|
@ -177,6 +177,7 @@ enum ggml_metal_kernel_type {
|
||||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
|
||||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
||||||
GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
||||||
|
GGML_METAL_KERNEL_TYPE_L2_NORM,
|
||||||
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
||||||
GGML_METAL_KERNEL_TYPE_NORM,
|
GGML_METAL_KERNEL_TYPE_NORM,
|
||||||
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
|
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
|
||||||
|
@ -782,6 +783,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
|
||||||
|
@ -1210,6 +1212,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||||
case GGML_OP_GROUP_NORM:
|
case GGML_OP_GROUP_NORM:
|
||||||
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
|
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
|
||||||
case GGML_OP_RMS_NORM:
|
case GGML_OP_RMS_NORM:
|
||||||
|
case GGML_OP_L2_NORM:
|
||||||
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
|
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
|
||||||
case GGML_OP_ARGMAX:
|
case GGML_OP_ARGMAX:
|
||||||
return true;
|
return true;
|
||||||
|
@ -3052,6 +3055,42 @@ static void ggml_metal_encode_node(
|
||||||
|
|
||||||
const int64_t nrows = ggml_nrows(src0);
|
const int64_t nrows = ggml_nrows(src0);
|
||||||
|
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||||
|
} break;
|
||||||
|
case GGML_OP_L2_NORM:
|
||||||
|
{
|
||||||
|
GGML_ASSERT(ne00 % 4 == 0);
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
||||||
|
|
||||||
|
float eps;
|
||||||
|
memcpy(&eps, dst->op_params, sizeof(float));
|
||||||
|
|
||||||
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_L2_NORM].pipeline;
|
||||||
|
|
||||||
|
int nth = 32; // SIMD width
|
||||||
|
|
||||||
|
while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
||||||
|
nth *= 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
nth = MIN(nth, ne00/4);
|
||||||
|
|
||||||
|
ggml_metal_kargs_l2_norm args = {
|
||||||
|
/*.ne00 =*/ ne00,
|
||||||
|
/*.ne00_4 =*/ ne00/4,
|
||||||
|
/*.nb01 =*/ nb01,
|
||||||
|
/*.eps =*/ eps,
|
||||||
|
};
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
|
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||||
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||||
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
|
|
||||||
|
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
||||||
|
|
||||||
|
const int64_t nrows = ggml_nrows(src0);
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_GROUP_NORM:
|
case GGML_OP_GROUP_NORM:
|
||||||
|
|
|
@ -1534,6 +1534,49 @@ kernel void kernel_rms_norm(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kernel void kernel_l2_norm(
|
||||||
|
constant ggml_metal_kargs_l2_norm & args,
|
||||||
|
device const char * src0,
|
||||||
|
device char * dst,
|
||||||
|
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
||||||
|
uint tgpig[[threadgroup_position_in_grid]],
|
||||||
|
ushort tpitg[[thread_position_in_threadgroup]],
|
||||||
|
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||||
|
ushort tiisg[[thread_index_in_simdgroup]],
|
||||||
|
ushort ntg[[threads_per_threadgroup]]) {
|
||||||
|
if (sgitg == 0) {
|
||||||
|
shmem_f32[tiisg] = 0.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
|
||||||
|
|
||||||
|
float sumf = 0.0f;
|
||||||
|
|
||||||
|
// parallel sum
|
||||||
|
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
|
||||||
|
sumf += dot(x[i00], x[i00]);
|
||||||
|
}
|
||||||
|
sumf = simd_sum(sumf);
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
if (tiisg == 0) {
|
||||||
|
shmem_f32[sgitg] = sumf;
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
|
sumf = shmem_f32[tiisg];
|
||||||
|
sumf = simd_sum(sumf);
|
||||||
|
|
||||||
|
const float scale = 1.0f/sqrt(max(sumf, args.eps));
|
||||||
|
|
||||||
|
device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
|
||||||
|
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
|
||||||
|
y[i00] = x[i00] * scale;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
kernel void kernel_group_norm(
|
kernel void kernel_group_norm(
|
||||||
device const float * src0,
|
device const float * src0,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
|
|
|
@ -3270,6 +3270,12 @@ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * ds
|
||||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||||
|
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||||
|
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_l2_norm);
|
||||||
|
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||||
|
}
|
||||||
|
|
||||||
static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
|
||||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||||
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_group_norm);
|
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_group_norm);
|
||||||
|
@ -4034,6 +4040,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
|
||||||
case GGML_OP_RMS_NORM:
|
case GGML_OP_RMS_NORM:
|
||||||
ggml_sycl_rms_norm(ctx, dst);
|
ggml_sycl_rms_norm(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
case GGML_OP_L2_NORM:
|
||||||
|
ggml_sycl_l2_norm(ctx, dst);
|
||||||
|
break;
|
||||||
case GGML_OP_MUL_MAT:
|
case GGML_OP_MUL_MAT:
|
||||||
if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
|
if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
|
||||||
return false;
|
return false;
|
||||||
|
@ -4545,6 +4554,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||||
return true;
|
return true;
|
||||||
case GGML_OP_NORM:
|
case GGML_OP_NORM:
|
||||||
case GGML_OP_RMS_NORM:
|
case GGML_OP_RMS_NORM:
|
||||||
|
case GGML_OP_L2_NORM:
|
||||||
case GGML_OP_GROUP_NORM:
|
case GGML_OP_GROUP_NORM:
|
||||||
return ggml_is_contiguous(op->src[0]);
|
return ggml_is_contiguous(op->src[0]);
|
||||||
case GGML_OP_SCALE:
|
case GGML_OP_SCALE:
|
||||||
|
|
|
@ -180,6 +180,50 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void l2_norm_f32(const float* x, float* dst, const int ncols, const float eps,
|
||||||
|
const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
|
||||||
|
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
|
||||||
|
item_ct1.get_local_id(1);
|
||||||
|
const int tid = item_ct1.get_local_id(2);
|
||||||
|
const int nthreads = item_ct1.get_local_range(2);
|
||||||
|
const int nwarps = nthreads / WARP_SIZE;
|
||||||
|
float tmp = 0.0f; // partial sum for thread in warp
|
||||||
|
|
||||||
|
for (int col = tid; col < ncols; col += block_size) {
|
||||||
|
const float xi = x[row * ncols + col];
|
||||||
|
tmp += xi * xi;
|
||||||
|
}
|
||||||
|
|
||||||
|
// sum up partial sums
|
||||||
|
tmp = warp_reduce_sum(tmp, item_ct1);
|
||||||
|
if (block_size > WARP_SIZE) {
|
||||||
|
|
||||||
|
int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
|
||||||
|
int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
|
||||||
|
if (lane_id == 0) {
|
||||||
|
s_sum[warp_id] = tmp;
|
||||||
|
}
|
||||||
|
/*
|
||||||
|
DPCT1118:3: SYCL group functions and algorithms must be encountered in
|
||||||
|
converged control flow. You may need to adjust the code.
|
||||||
|
*/
|
||||||
|
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||||
|
size_t nreduce = nwarps / WARP_SIZE;
|
||||||
|
tmp = 0.f;
|
||||||
|
for (size_t i = 0; i < nreduce; i += 1)
|
||||||
|
{
|
||||||
|
tmp += s_sum[lane_id + i * WARP_SIZE];
|
||||||
|
}
|
||||||
|
tmp = warp_reduce_sum(tmp, item_ct1);
|
||||||
|
}
|
||||||
|
|
||||||
|
const float scale = sycl::rsqrt(sycl::max(tmp, eps * eps));
|
||||||
|
|
||||||
|
for (int col = tid; col < ncols; col += block_size) {
|
||||||
|
dst[row * ncols + col] = scale * x[row * ncols + col];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static void norm_f32_sycl(const float* x, float* dst, const int ncols,
|
static void norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||||
const int nrows, const float eps,
|
const int nrows, const float eps,
|
||||||
queue_ptr stream, int device) {
|
queue_ptr stream, int device) {
|
||||||
|
@ -311,6 +355,48 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||||
|
const int nrows, const float eps,
|
||||||
|
queue_ptr stream, int device) {
|
||||||
|
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||||
|
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
|
||||||
|
if (ncols < 1024) {
|
||||||
|
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
||||||
|
stream->submit([&](sycl::handler& cgh) {
|
||||||
|
cgh.parallel_for(
|
||||||
|
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
||||||
|
block_dims),
|
||||||
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
|
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
|
l2_norm_f32(x, dst, ncols, eps, item_ct1,
|
||||||
|
nullptr, WARP_SIZE);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
|
||||||
|
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
|
||||||
|
const sycl::range<3> block_dims(1, 1, work_group_size);
|
||||||
|
/*
|
||||||
|
DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
|
||||||
|
the limit. To get the device limit, query
|
||||||
|
info::device::max_work_group_size. Adjust the work-group size if needed.
|
||||||
|
*/
|
||||||
|
stream->submit([&](sycl::handler& cgh) {
|
||||||
|
sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
|
||||||
|
cgh);
|
||||||
|
cgh.parallel_for(
|
||||||
|
sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
|
||||||
|
block_dims),
|
||||||
|
[=](sycl::nd_item<3> item_ct1)
|
||||||
|
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||||
|
l2_norm_f32(x, dst, ncols, eps, item_ct1,
|
||||||
|
get_pointer(s_sum_acc_ct1), work_group_size);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1,
|
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1,
|
||||||
ggml_tensor* dst, const float* src0_dd,
|
ggml_tensor* dst, const float* src0_dd,
|
||||||
const float* src1_dd, float* dst_dd,
|
const float* src1_dd, float* dst_dd,
|
||||||
|
@ -376,3 +462,25 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* sr
|
||||||
(void)dst;
|
(void)dst;
|
||||||
(void)src1_dd;
|
(void)src1_dd;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
|
||||||
|
const ggml_tensor* src1, ggml_tensor* dst,
|
||||||
|
const float* src0_dd, const float* src1_dd,
|
||||||
|
float* dst_dd,
|
||||||
|
const queue_ptr& main_stream) {
|
||||||
|
|
||||||
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
|
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||||
|
|
||||||
|
const int64_t ne00 = src0->ne[0];
|
||||||
|
const int64_t nrows = ggml_nrows(src0);
|
||||||
|
|
||||||
|
float eps;
|
||||||
|
memcpy(&eps, dst->op_params, sizeof(float));
|
||||||
|
|
||||||
|
l2_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
|
||||||
|
|
||||||
|
(void)src1;
|
||||||
|
(void)dst;
|
||||||
|
(void)src1_dd;
|
||||||
|
}
|
||||||
|
|
|
@ -32,4 +32,10 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor*
|
||||||
float* dst_dd,
|
float* dst_dd,
|
||||||
const queue_ptr& main_stream);
|
const queue_ptr& main_stream);
|
||||||
|
|
||||||
|
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
|
||||||
|
const ggml_tensor* src1, ggml_tensor* dst,
|
||||||
|
const float* src0_dd, const float* src1_dd,
|
||||||
|
float* dst_dd,
|
||||||
|
const queue_ptr& main_stream);
|
||||||
|
|
||||||
#endif // GGML_SYCL_NORM_HPP
|
#endif // GGML_SYCL_NORM_HPP
|
||||||
|
|
|
@ -239,6 +239,7 @@ struct vk_device_struct {
|
||||||
vk_pipeline pipeline_norm_f32;
|
vk_pipeline pipeline_norm_f32;
|
||||||
vk_pipeline pipeline_group_norm_f32;
|
vk_pipeline pipeline_group_norm_f32;
|
||||||
vk_pipeline pipeline_rms_norm_f32;
|
vk_pipeline pipeline_rms_norm_f32;
|
||||||
|
vk_pipeline pipeline_l2_norm_f32;
|
||||||
vk_pipeline pipeline_gelu_f32;
|
vk_pipeline pipeline_gelu_f32;
|
||||||
vk_pipeline pipeline_gelu_quick_f32;
|
vk_pipeline pipeline_gelu_quick_f32;
|
||||||
vk_pipeline pipeline_silu_f32;
|
vk_pipeline pipeline_silu_f32;
|
||||||
|
@ -2069,6 +2070,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
||||||
|
@ -5203,6 +5205,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||||
return ctx->device->pipeline_rms_norm_f32;
|
return ctx->device->pipeline_rms_norm_f32;
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
case GGML_OP_L2_NORM:
|
||||||
|
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||||
|
return ctx->device->pipeline_l2_norm_f32;
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
case GGML_OP_UNARY:
|
case GGML_OP_UNARY:
|
||||||
switch (ggml_get_unary_op(dst)) {
|
switch (ggml_get_unary_op(dst)) {
|
||||||
case GGML_UNARY_OP_SILU:
|
case GGML_UNARY_OP_SILU:
|
||||||
|
@ -5542,6 +5549,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||||
switch (op) {
|
switch (op) {
|
||||||
case GGML_OP_NORM:
|
case GGML_OP_NORM:
|
||||||
case GGML_OP_RMS_NORM:
|
case GGML_OP_RMS_NORM:
|
||||||
|
case GGML_OP_L2_NORM:
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
case GGML_OP_SUM_ROWS:
|
case GGML_OP_SUM_ROWS:
|
||||||
{
|
{
|
||||||
|
@ -6058,6 +6066,11 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
|
||||||
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
|
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||||
|
float * op_params = (float *)dst->op_params;
|
||||||
|
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
|
||||||
|
}
|
||||||
|
|
||||||
static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||||
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
|
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
|
||||||
}
|
}
|
||||||
|
@ -7023,6 +7036,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
||||||
case GGML_OP_NORM:
|
case GGML_OP_NORM:
|
||||||
case GGML_OP_GROUP_NORM:
|
case GGML_OP_GROUP_NORM:
|
||||||
case GGML_OP_RMS_NORM:
|
case GGML_OP_RMS_NORM:
|
||||||
|
case GGML_OP_L2_NORM:
|
||||||
case GGML_OP_DIAG_MASK_INF:
|
case GGML_OP_DIAG_MASK_INF:
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
case GGML_OP_ROPE:
|
case GGML_OP_ROPE:
|
||||||
|
@ -7075,6 +7089,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
||||||
case GGML_OP_NORM:
|
case GGML_OP_NORM:
|
||||||
case GGML_OP_GROUP_NORM:
|
case GGML_OP_GROUP_NORM:
|
||||||
case GGML_OP_RMS_NORM:
|
case GGML_OP_RMS_NORM:
|
||||||
|
case GGML_OP_L2_NORM:
|
||||||
case GGML_OP_UNARY:
|
case GGML_OP_UNARY:
|
||||||
case GGML_OP_DIAG_MASK_INF:
|
case GGML_OP_DIAG_MASK_INF:
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
|
@ -7171,6 +7186,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
||||||
case GGML_OP_RMS_NORM:
|
case GGML_OP_RMS_NORM:
|
||||||
ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);
|
ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);
|
||||||
|
|
||||||
|
break;
|
||||||
|
case GGML_OP_L2_NORM:
|
||||||
|
ggml_vk_l2_norm(ctx, compute_ctx, src0, node, dryrun);
|
||||||
|
|
||||||
break;
|
break;
|
||||||
case GGML_OP_UNARY:
|
case GGML_OP_UNARY:
|
||||||
switch (ggml_get_unary_op(node)) {
|
switch (ggml_get_unary_op(node)) {
|
||||||
|
@ -7305,6 +7324,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
||||||
case GGML_OP_NORM:
|
case GGML_OP_NORM:
|
||||||
case GGML_OP_GROUP_NORM:
|
case GGML_OP_GROUP_NORM:
|
||||||
case GGML_OP_RMS_NORM:
|
case GGML_OP_RMS_NORM:
|
||||||
|
case GGML_OP_L2_NORM:
|
||||||
case GGML_OP_DIAG_MASK_INF:
|
case GGML_OP_DIAG_MASK_INF:
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
case GGML_OP_ROPE:
|
case GGML_OP_ROPE:
|
||||||
|
@ -8223,6 +8243,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||||
case GGML_OP_NORM:
|
case GGML_OP_NORM:
|
||||||
case GGML_OP_GROUP_NORM:
|
case GGML_OP_GROUP_NORM:
|
||||||
case GGML_OP_RMS_NORM:
|
case GGML_OP_RMS_NORM:
|
||||||
|
case GGML_OP_L2_NORM:
|
||||||
return ggml_is_contiguous(op->src[0]);
|
return ggml_is_contiguous(op->src[0]);
|
||||||
case GGML_OP_ADD:
|
case GGML_OP_ADD:
|
||||||
case GGML_OP_ACC:
|
case GGML_OP_ACC:
|
||||||
|
@ -8747,6 +8768,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
||||||
tensor_clone = ggml_group_norm(ggml_ctx, src0_clone, *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
|
tensor_clone = ggml_group_norm(ggml_ctx, src0_clone, *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
|
||||||
} else if (tensor->op == GGML_OP_RMS_NORM) {
|
} else if (tensor->op == GGML_OP_RMS_NORM) {
|
||||||
tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
|
tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
|
||||||
|
} else if (tensor->op == GGML_OP_L2_NORM) {
|
||||||
|
tensor_clone = ggml_l2_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
|
||||||
} else if (tensor->op == GGML_OP_SOFT_MAX) {
|
} else if (tensor->op == GGML_OP_SOFT_MAX) {
|
||||||
if (src1 != nullptr) {
|
if (src1 != nullptr) {
|
||||||
tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
|
tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
|
||||||
|
|
41
ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp
Normal file
41
ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp
Normal file
|
@ -0,0 +1,41 @@
|
||||||
|
#version 450
|
||||||
|
|
||||||
|
#include "generic_head.comp"
|
||||||
|
#include "types.comp"
|
||||||
|
|
||||||
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
|
#define BLOCK_SIZE 512
|
||||||
|
|
||||||
|
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||||
|
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||||
|
|
||||||
|
shared FLOAT_TYPE sum[BLOCK_SIZE];
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
|
||||||
|
const uint tid = gl_LocalInvocationID.x;
|
||||||
|
|
||||||
|
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
|
||||||
|
|
||||||
|
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
|
||||||
|
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
|
||||||
|
sum[tid] += xi * xi;
|
||||||
|
}
|
||||||
|
|
||||||
|
// sum up partial sums and write back result
|
||||||
|
barrier();
|
||||||
|
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
|
||||||
|
if (tid < s) {
|
||||||
|
sum[tid] += sum[tid + s];
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
|
|
||||||
|
const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1)));
|
||||||
|
|
||||||
|
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
|
||||||
|
data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
|
||||||
|
}
|
||||||
|
}
|
|
@ -418,6 +418,7 @@ void process_shaders() {
|
||||||
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||||
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||||
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||||
|
string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||||
|
|
||||||
string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||||
string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
|
string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
|
||||||
|
|
|
@ -925,6 +925,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||||
"RMS_NORM",
|
"RMS_NORM",
|
||||||
"RMS_NORM_BACK",
|
"RMS_NORM_BACK",
|
||||||
"GROUP_NORM",
|
"GROUP_NORM",
|
||||||
|
"L2_NORM",
|
||||||
|
|
||||||
"MUL_MAT",
|
"MUL_MAT",
|
||||||
"MUL_MAT_ID",
|
"MUL_MAT_ID",
|
||||||
|
@ -992,7 +993,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||||
"OPT_STEP_ADAMW",
|
"OPT_STEP_ADAMW",
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
|
static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
|
||||||
|
|
||||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||||
"none",
|
"none",
|
||||||
|
@ -1022,6 +1023,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||||
"rms_norm(x)",
|
"rms_norm(x)",
|
||||||
"rms_norm_back(x)",
|
"rms_norm_back(x)",
|
||||||
"group_norm(x)",
|
"group_norm(x)",
|
||||||
|
"l2_norm(x)",
|
||||||
|
|
||||||
"X*Y",
|
"X*Y",
|
||||||
"X[i]*Y",
|
"X[i]*Y",
|
||||||
|
@ -1089,7 +1091,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||||
"adamw(x)",
|
"adamw(x)",
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
|
static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
|
||||||
|
|
||||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||||
|
|
||||||
|
@ -2681,6 +2683,37 @@ struct ggml_tensor * ggml_group_norm_inplace(
|
||||||
return ggml_group_norm_impl(ctx, a, n_groups, eps, true);
|
return ggml_group_norm_impl(ctx, a, n_groups, eps, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ggml_l2_norm
|
||||||
|
|
||||||
|
static struct ggml_tensor * ggml_l2_norm_impl(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
float eps,
|
||||||
|
bool inplace) {
|
||||||
|
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||||
|
|
||||||
|
ggml_set_op_params_f32(result, 0, eps);
|
||||||
|
|
||||||
|
result->op = GGML_OP_L2_NORM;
|
||||||
|
result->src[0] = a;
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_l2_norm(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
float eps) {
|
||||||
|
return ggml_l2_norm_impl(ctx, a, eps, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_l2_norm_inplace(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
float eps) {
|
||||||
|
return ggml_l2_norm_impl(ctx, a, eps, true);
|
||||||
|
}
|
||||||
|
|
||||||
// ggml_mul_mat
|
// ggml_mul_mat
|
||||||
|
|
||||||
static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
|
static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
|
||||||
|
|
|
@ -2953,6 +2953,32 @@ struct test_group_norm : public test_case {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// GGML_OP_L2_NORM
|
||||||
|
struct test_l2_norm : public test_case {
|
||||||
|
const ggml_type type;
|
||||||
|
const std::array<int64_t, 4> ne;
|
||||||
|
const float eps;
|
||||||
|
|
||||||
|
std::string vars() override {
|
||||||
|
return VARS_TO_STR2(type, ne);
|
||||||
|
}
|
||||||
|
|
||||||
|
test_l2_norm(ggml_type type = GGML_TYPE_F32,
|
||||||
|
std::array<int64_t, 4> ne = {64, 64, 320, 1},
|
||||||
|
float eps = 1e-12f)
|
||||||
|
: type(type), ne(ne), eps(eps) {}
|
||||||
|
|
||||||
|
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||||
|
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||||
|
ggml_set_name(a, "a");
|
||||||
|
|
||||||
|
ggml_tensor * out = ggml_l2_norm(ctx, a, eps);
|
||||||
|
ggml_set_name(out, "out");
|
||||||
|
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// GGML_OP_ACC
|
// GGML_OP_ACC
|
||||||
struct test_acc : public test_case {
|
struct test_acc : public test_case {
|
||||||
const ggml_type type;
|
const ggml_type type;
|
||||||
|
@ -3984,8 +4010,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||||
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
|
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
|
||||||
}
|
}
|
||||||
test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
|
test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
|
||||||
|
test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));
|
||||||
|
|
||||||
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 1, 1}, {4, 1536, 1, 1}));
|
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 1, 1}, {4, 1536, 1, 1}));
|
||||||
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1}));
|
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1}));
|
||||||
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1}));
|
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1}));
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue