diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 5bd8d9c8b..f42dd1c00 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -454,6 +454,7 @@ extern "C" { GGML_OP_RMS_NORM, GGML_OP_RMS_NORM_BACK, GGML_OP_GROUP_NORM, + GGML_OP_L2_NORM, GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID, @@ -1095,6 +1096,18 @@ extern "C" { int n_groups, float eps); + // l2 normalize along rows + // used in rwkv v7 + GGML_API struct ggml_tensor * ggml_l2_norm( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps); + + GGML_API struct ggml_tensor * ggml_l2_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps); + // a - x // b - dy GGML_API struct ggml_tensor * ggml_rms_norm_back( diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index fdb430a43..356502cb3 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -7333,6 +7333,69 @@ static void ggml_compute_forward_group_norm( } } +// ggml_compute_forward_l2_norm + +static void ggml_compute_forward_l2_norm_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + GGML_ASSERT(eps >= 0.0f); + + // TODO: optimize + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ith; i01 < ne01; i01 += nth) { + const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + + ggml_float sum = 0.0; + for (int64_t i00 = 0; i00 < ne00; i00++) { + sum += (ggml_float)(x[i00] * x[i00]); + } + + float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + memcpy(y, x, ne00 * sizeof(float)); + + const float scale = 1.0f/fmaxf(sqrtf(sum), eps); + + ggml_vec_scale_f32(ne00, y, scale); + } + } + } +} + +static void ggml_compute_forward_l2_norm( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_l2_norm_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_mul_mat static void ggml_compute_forward_mul_mat_one_chunk( @@ -12823,6 +12886,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_group_norm(params, tensor); } break; + case GGML_OP_L2_NORM: + { + ggml_compute_forward_l2_norm(params, tensor); + } break; case GGML_OP_MUL_MAT: { ggml_compute_forward_mul_mat(params, tensor); @@ -13235,6 +13302,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_NORM: case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM_BACK: + case GGML_OP_L2_NORM: case GGML_OP_GROUP_NORM: case GGML_OP_CONCAT: case GGML_OP_MUL_MAT: diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 4dbaefdba..694081a89 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2191,6 +2191,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_GROUP_NORM: ggml_cuda_op_group_norm(ctx, dst); break; + case GGML_OP_L2_NORM: + ggml_cuda_op_l2_norm(ctx, dst); + break; case GGML_OP_CONCAT: ggml_cuda_op_concat(ctx, dst); break; @@ -3135,6 +3138,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g break; case GGML_OP_NORM: case GGML_OP_RMS_NORM: + case GGML_OP_L2_NORM: return true; case GGML_OP_RMS_NORM_BACK: return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0; diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index f127616ed..2b3b441b5 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -201,6 +201,40 @@ static __global__ void rms_norm_back_f32( } } +template +static __global__ void l2_norm_f32(const float * x, float * dst, const int ncols, const float eps) { + const int row = blockIdx.x*blockDim.y + threadIdx.y; + const int tid = threadIdx.x; + + float tmp = 0.0f; // partial sum for thread in warp + + for (int col = tid; col < ncols; col += block_size) { + const float xi = x[row*ncols + col]; + tmp += xi * xi; + } + + // sum up partial sums + tmp = warp_reduce_sum(tmp); + if (block_size > WARP_SIZE) { + __shared__ float s_sum[32]; + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = tmp; + } + __syncthreads(); + tmp = s_sum[lane_id]; + tmp = warp_reduce_sum(tmp); + } + + // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html + const float scale = rsqrtf(fmaxf(tmp, eps * eps)); + + for (int col = tid; col < ncols; col += block_size) { + dst[row*ncols + col] = scale * x[row*ncols + col]; + } +} + static void norm_f32_cuda( const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples, const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) { @@ -248,6 +282,17 @@ static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float * } } +static void l2_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { + GGML_ASSERT(ncols % WARP_SIZE == 0); + if (ncols < 1024) { + const dim3 block_dims(WARP_SIZE, 1, 1); + l2_norm_f32<<>>(x, dst, ncols, eps); + } else { + const dim3 block_dims(1024, 1, 1); + l2_norm_f32<1024><<>>(x, dst, ncols, eps); + } +} + void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *) src0->data; @@ -340,3 +385,18 @@ void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * d rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream); } + +void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + const int64_t ne00 = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + l2_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream); +} diff --git a/ggml/src/ggml-cuda/norm.cuh b/ggml/src/ggml-cuda/norm.cuh index d63d34380..706a5660a 100644 --- a/ggml/src/ggml-cuda/norm.cuh +++ b/ggml/src/ggml-cuda/norm.cuh @@ -7,3 +7,5 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index e3dc25f16..d71d49298 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -285,4 +285,11 @@ typedef struct { float eps; } ggml_metal_kargs_rms_norm; +typedef struct { + int32_t ne00; + int32_t ne00_4; + uint64_t nb01; + float eps; +} ggml_metal_kargs_l2_norm; + #endif // GGML_METAL_IMPL diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 944d90af3..f6c427f7f 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -177,6 +177,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, GGML_METAL_KERNEL_TYPE_RMS_NORM, + GGML_METAL_KERNEL_TYPE_L2_NORM, GGML_METAL_KERNEL_TYPE_GROUP_NORM, GGML_METAL_KERNEL_TYPE_NORM, GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, @@ -782,6 +783,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true); @@ -1210,6 +1212,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_OP_GROUP_NORM: return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]); case GGML_OP_RMS_NORM: + case GGML_OP_L2_NORM: return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); case GGML_OP_ARGMAX: return true; @@ -3052,6 +3055,42 @@ static void ggml_metal_encode_node( const int64_t nrows = ggml_nrows(src0); + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case GGML_OP_L2_NORM: + { + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ggml_is_contiguous_1(src0)); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_L2_NORM].pipeline; + + int nth = 32; // SIMD width + + while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { + nth *= 2; + } + + nth = MIN(nth, ne00/4); + + ggml_metal_kargs_l2_norm args = { + /*.ne00 =*/ ne00, + /*.ne00_4 =*/ ne00/4, + /*.nb01 =*/ nb01, + /*.eps =*/ eps, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + + const int64_t nrows = ggml_nrows(src0); + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; case GGML_OP_GROUP_NORM: diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 44f04c909..f394d743c 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1534,6 +1534,49 @@ kernel void kernel_rms_norm( } } +kernel void kernel_l2_norm( + constant ggml_metal_kargs_l2_norm & args, + device const char * src0, + device char * dst, + threadgroup float * shmem_f32 [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + ushort tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort ntg[[threads_per_threadgroup]]) { + if (sgitg == 0) { + shmem_f32[tiisg] = 0.0f; + } + + device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01); + + float sumf = 0.0f; + + // parallel sum + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + sumf += dot(x[i00], x[i00]); + } + sumf = simd_sum(sumf); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tiisg == 0) { + shmem_f32[sgitg] = sumf; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sumf = shmem_f32[tiisg]; + sumf = simd_sum(sumf); + + const float scale = 1.0f/sqrt(max(sumf, args.eps)); + + device float4 * y = (device float4 *) dst + tgpig*args.ne00_4; + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + y[i00] = x[i00] * scale; + } +} + kernel void kernel_group_norm( device const float * src0, device float * dst, diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 3d24d2165..7398e8623 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3270,6 +3270,12 @@ static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * ds GGML_SYCL_DEBUG("call %s done\n", __func__); } +static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_SYCL_DEBUG("call %s\n", __func__); + ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_l2_norm); + GGML_SYCL_DEBUG("call %s done\n", __func__); +} + static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_SYCL_DEBUG("call %s\n", __func__); ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_group_norm); @@ -4034,6 +4040,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens case GGML_OP_RMS_NORM: ggml_sycl_rms_norm(ctx, dst); break; + case GGML_OP_L2_NORM: + ggml_sycl_l2_norm(ctx, dst); + break; case GGML_OP_MUL_MAT: if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) { return false; @@ -4545,6 +4554,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g return true; case GGML_OP_NORM: case GGML_OP_RMS_NORM: + case GGML_OP_L2_NORM: case GGML_OP_GROUP_NORM: return ggml_is_contiguous(op->src[0]); case GGML_OP_SCALE: diff --git a/ggml/src/ggml-sycl/norm.cpp b/ggml/src/ggml-sycl/norm.cpp index 9cf2be155..6439db21b 100644 --- a/ggml/src/ggml-sycl/norm.cpp +++ b/ggml/src/ggml-sycl/norm.cpp @@ -180,6 +180,50 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa } } +static void l2_norm_f32(const float* x, float* dst, const int ncols, const float eps, + const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) { + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + const int tid = item_ct1.get_local_id(2); + const int nthreads = item_ct1.get_local_range(2); + const int nwarps = nthreads / WARP_SIZE; + float tmp = 0.0f; // partial sum for thread in warp + + for (int col = tid; col < ncols; col += block_size) { + const float xi = x[row * ncols + col]; + tmp += xi * xi; + } + + // sum up partial sums + tmp = warp_reduce_sum(tmp, item_ct1); + if (block_size > WARP_SIZE) { + + int warp_id = item_ct1.get_local_id(2) / WARP_SIZE; + int lane_id = item_ct1.get_local_id(2) % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = tmp; + } + /* + DPCT1118:3: SYCL group functions and algorithms must be encountered in + converged control flow. You may need to adjust the code. + */ + item_ct1.barrier(sycl::access::fence_space::local_space); + size_t nreduce = nwarps / WARP_SIZE; + tmp = 0.f; + for (size_t i = 0; i < nreduce; i += 1) + { + tmp += s_sum[lane_id + i * WARP_SIZE]; + } + tmp = warp_reduce_sum(tmp, item_ct1); + } + + const float scale = sycl::rsqrt(sycl::max(tmp, eps * eps)); + + for (int col = tid; col < ncols; col += block_size) { + dst[row * ncols + col] = scale * x[row * ncols + col]; + } +} + static void norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const float eps, queue_ptr stream, int device) { @@ -311,6 +355,48 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, } } +static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols, + const int nrows, const float eps, + queue_ptr stream, int device) { + GGML_ASSERT(ncols % WARP_SIZE == 0); + // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE); + if (ncols < 1024) { + const sycl::range<3> block_dims(1, 1, WARP_SIZE); + stream->submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, + block_dims), + [=](sycl::nd_item<3> item_ct1) + [[intel::reqd_sub_group_size(WARP_SIZE)]] { + l2_norm_f32(x, dst, ncols, eps, item_ct1, + nullptr, WARP_SIZE); + }); + }); + } + else { + const int work_group_size = ggml_sycl_info().max_work_group_sizes[device]; + assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0); + const sycl::range<3> block_dims(1, 1, work_group_size); + /* + DPCT1049:19: The work-group size passed to the SYCL kernel may exceed + the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if needed. + */ + stream->submit([&](sycl::handler& cgh) { + sycl::local_accessor s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE), + cgh); + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, + block_dims), + [=](sycl::nd_item<3> item_ct1) + [[intel::reqd_sub_group_size(WARP_SIZE)]] { + l2_norm_f32(x, dst, ncols, eps, item_ct1, + get_pointer(s_sum_acc_ct1), work_group_size); + }); + }); + } +} + void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, const ggml_tensor* src1, ggml_tensor* dst, const float* src0_dd, const float* src1_dd, float* dst_dd, @@ -376,3 +462,25 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* sr (void)dst; (void)src1_dd; } + +void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, + const ggml_tensor* src1, ggml_tensor* dst, + const float* src0_dd, const float* src1_dd, + float* dst_dd, + const queue_ptr& main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + l2_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device); + + (void)src1; + (void)dst; + (void)src1_dd; +} diff --git a/ggml/src/ggml-sycl/norm.hpp b/ggml/src/ggml-sycl/norm.hpp index a9ad9156f..11e91680c 100644 --- a/ggml/src/ggml-sycl/norm.hpp +++ b/ggml/src/ggml-sycl/norm.hpp @@ -32,4 +32,10 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* float* dst_dd, const queue_ptr& main_stream); +void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, + const ggml_tensor* src1, ggml_tensor* dst, + const float* src0_dd, const float* src1_dd, + float* dst_dd, + const queue_ptr& main_stream); + #endif // GGML_SYCL_NORM_HPP diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d32ba4efb..3b2d49242 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -239,6 +239,7 @@ struct vk_device_struct { vk_pipeline pipeline_norm_f32; vk_pipeline pipeline_group_norm_f32; vk_pipeline pipeline_rms_norm_f32; + vk_pipeline pipeline_l2_norm_f32; vk_pipeline pipeline_gelu_f32; vk_pipeline pipeline_gelu_quick_f32; vk_pipeline pipeline_silu_f32; @@ -2069,6 +2070,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -5203,6 +5205,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_rms_norm_f32; } return nullptr; + case GGML_OP_L2_NORM: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_l2_norm_f32; + } + return nullptr; case GGML_OP_UNARY: switch (ggml_get_unary_op(dst)) { case GGML_UNARY_OP_SILU: @@ -5542,6 +5549,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co switch (op) { case GGML_OP_NORM: case GGML_OP_RMS_NORM: + case GGML_OP_L2_NORM: case GGML_OP_SOFT_MAX: case GGML_OP_SUM_ROWS: { @@ -6058,6 +6066,11 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); } +static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); +} + static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); } @@ -7023,6 +7036,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_NORM: case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: + case GGML_OP_L2_NORM: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: case GGML_OP_ROPE: @@ -7075,6 +7089,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_NORM: case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: + case GGML_OP_L2_NORM: case GGML_OP_UNARY: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: @@ -7171,6 +7186,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_RMS_NORM: ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_L2_NORM: + ggml_vk_l2_norm(ctx, compute_ctx, src0, node, dryrun); + break; case GGML_OP_UNARY: switch (ggml_get_unary_op(node)) { @@ -7305,6 +7324,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_OP_NORM: case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: + case GGML_OP_L2_NORM: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: case GGML_OP_ROPE: @@ -8223,6 +8243,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_NORM: case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: + case GGML_OP_L2_NORM: return ggml_is_contiguous(op->src[0]); case GGML_OP_ADD: case GGML_OP_ACC: @@ -8747,6 +8768,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { tensor_clone = ggml_group_norm(ggml_ctx, src0_clone, *(int *)tensor->op_params, ((float *)tensor->op_params)[1]); } else if (tensor->op == GGML_OP_RMS_NORM) { tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params); + } else if (tensor->op == GGML_OP_L2_NORM) { + tensor_clone = ggml_l2_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params); } else if (tensor->op == GGML_OP_SOFT_MAX) { if (src1 != nullptr) { tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp new file mode 100644 index 000000000..deba8c398 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp @@ -0,0 +1,41 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 512 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +shared FLOAT_TYPE sum[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]); + sum[tid] += xi * xi; + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + sum[tid] += sum[tid + s]; + } + barrier(); + } + + const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1))); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col])); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 77e7e1148..c31590bd9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -418,6 +418,7 @@ void process_shaders() { string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 3b4861542..4f3578a3f 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -925,6 +925,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "RMS_NORM", "RMS_NORM_BACK", "GROUP_NORM", + "L2_NORM", "MUL_MAT", "MUL_MAT_ID", @@ -992,7 +993,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "OPT_STEP_ADAMW", }; -static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83"); +static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1022,6 +1023,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "rms_norm(x)", "rms_norm_back(x)", "group_norm(x)", + "l2_norm(x)", "X*Y", "X[i]*Y", @@ -1089,7 +1091,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "adamw(x)", }; -static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83"); +static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -2681,6 +2683,37 @@ struct ggml_tensor * ggml_group_norm_inplace( return ggml_group_norm_impl(ctx, a, n_groups, eps, true); } +// ggml_l2_norm + +static struct ggml_tensor * ggml_l2_norm_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps, + bool inplace) { + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + ggml_set_op_params_f32(result, 0, eps); + + result->op = GGML_OP_L2_NORM; + result->src[0] = a; + + return result; +} + +struct ggml_tensor * ggml_l2_norm( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps) { + return ggml_l2_norm_impl(ctx, a, eps, false); +} + +struct ggml_tensor * ggml_l2_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps) { + return ggml_l2_norm_impl(ctx, a, eps, true); +} + // ggml_mul_mat static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 1bfd41254..13846caf6 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2953,6 +2953,32 @@ struct test_group_norm : public test_case { } }; +// GGML_OP_L2_NORM +struct test_l2_norm : public test_case { + const ggml_type type; + const std::array ne; + const float eps; + + std::string vars() override { + return VARS_TO_STR2(type, ne); + } + + test_l2_norm(ggml_type type = GGML_TYPE_F32, + std::array ne = {64, 64, 320, 1}, + float eps = 1e-12f) + : type(type), ne(ne), eps(eps) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_name(a, "a"); + + ggml_tensor * out = ggml_l2_norm(ctx, a, eps); + ggml_set_name(out, "out"); + + return out; + } +}; + // GGML_OP_ACC struct test_acc : public test_case { const ggml_type type; @@ -3984,8 +4010,11 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, v, eps)); } test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps)); + test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps)); } + test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f)); + test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 1, 1}, {4, 1536, 1, 1})); test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1})); test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1}));