ggml : ggml_soft_max support F16/F32 mask/pos
ggml-ci
This commit is contained in:
parent
c11d05fec0
commit
f725ca90fb
6 changed files with 105 additions and 34 deletions
|
@ -1,7 +1,17 @@
|
|||
#include "softmax.cuh"
|
||||
|
||||
template <bool vals_smem, int ncols_template, int block_size_template>
|
||||
static __global__ void soft_max_f32(const float * x, const half * mask, const half * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
|
||||
template <typename T>
|
||||
static __device__ __forceinline__ float t2f32(T val) {
|
||||
return (float) val;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ float __forceinline__ t2f32<half>(half val) {
|
||||
return __half2float(val);
|
||||
}
|
||||
|
||||
template <bool vals_smem, int ncols_template, int block_size_template, typename T>
|
||||
static __global__ void soft_max_f32(const float * x, const T * mask, const T * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
|
||||
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
|
@ -43,7 +53,7 @@ static __global__ void soft_max_f32(const float * x, const half * mask, const ha
|
|||
const int ix = rowx*ncols + col;
|
||||
const int iy = rowy*ncols + col;
|
||||
|
||||
const float val = x[ix]*scale + (mask ? __half2float(mask[iy]) : 0.0f) + (pos ? slope*__half2float(pos[col]) : 0.0f);
|
||||
const float val = x[ix]*scale + (mask ? t2f32(mask[iy]) : 0.0f) + (pos ? slope*t2f32(pos[col]) : 0.0f);
|
||||
|
||||
vals[col] = val;
|
||||
max_val = max(max_val, val);
|
||||
|
@ -114,7 +124,8 @@ static __global__ void soft_max_f32(const float * x, const half * mask, const ha
|
|||
}
|
||||
}
|
||||
|
||||
static void soft_max_f32_cuda(const float * x, const half * mask, const half * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
|
||||
template<typename T>
|
||||
static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
|
||||
int nth = WARP_SIZE;
|
||||
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
||||
const dim3 block_dims(nth, 1, 1);
|
||||
|
@ -167,15 +178,19 @@ static void soft_max_f32_cuda(const float * x, const half * mask, const half * p
|
|||
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
const ggml_tensor * src2 = dst->src[2];
|
||||
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
const half * src1_d = src1 ? (const half *)src1->data : nullptr;
|
||||
const void * src1_d = src1 ? (const void *)src1->data : nullptr;
|
||||
|
||||
float * dst_d = (float *)dst->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16); // src1 contains mask and it is optional
|
||||
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
|
||||
GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t nrows_x = ggml_nrows(src0);
|
||||
|
@ -188,14 +203,25 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
|
||||
|
||||
// positions tensor
|
||||
half * src2_dd = nullptr;
|
||||
void * src2_d = nullptr;
|
||||
|
||||
ggml_tensor * src2 = dst->src[2];
|
||||
const bool use_src2 = src2 != nullptr;
|
||||
|
||||
if (use_src2) {
|
||||
src2_dd = (half *)src2->data;
|
||||
src2_d = (void *)src2->data;
|
||||
}
|
||||
|
||||
soft_max_f32_cuda(src0_d, src1_d, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
|
||||
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
|
||||
|
||||
if (use_f16) {
|
||||
const half * src1_dd = (const half *)src1_d;
|
||||
const half * src2_dd = (const half *)src2_d;
|
||||
|
||||
soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
|
||||
} else {
|
||||
const float * src1_dd = (const float *)src1_d;
|
||||
const float * src2_dd = (const float *)src2_d;
|
||||
|
||||
soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
|
||||
}
|
||||
}
|
||||
|
|
29
ggml-metal.m
29
ggml-metal.m
|
@ -46,8 +46,10 @@ enum ggml_metal_kernel_type {
|
|||
GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
|
||||
GGML_METAL_KERNEL_TYPE_SILU,
|
||||
GGML_METAL_KERNEL_TYPE_SILU_4,
|
||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX,
|
||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
|
||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
|
||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
|
||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
|
||||
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,
|
||||
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
|
||||
GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
|
||||
GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
|
||||
|
@ -492,8 +494,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4, soft_max_f32_4, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
|
||||
|
@ -1346,22 +1350,33 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
} break;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
{
|
||||
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32);
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
|
||||
|
||||
if (ne00%4 == 0) {
|
||||
while (nth < ne00/4 && nth < 256) {
|
||||
nth *= 2;
|
||||
}
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline;
|
||||
if (use_f16) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4].pipeline;
|
||||
} else {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
|
||||
}
|
||||
} else {
|
||||
while (nth < ne00 && nth < 1024) {
|
||||
nth *= 2;
|
||||
}
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
|
||||
if (use_f16) {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16].pipeline;
|
||||
} else {
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32].pipeline;
|
||||
}
|
||||
}
|
||||
|
||||
float scale;
|
||||
|
|
|
@ -352,6 +352,7 @@ kernel void kernel_sum_rows(
|
|||
dst_row[0] = row_sum;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_soft_max(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
|
@ -376,8 +377,8 @@ kernel void kernel_soft_max(
|
|||
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
||||
|
||||
device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
device const half * pmask = src1 != src0 ? (device const half *) src1 + i01*ne00 : nullptr;
|
||||
device const half * ppos = src2 != src0 ? (device const half *) src2 : nullptr;
|
||||
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
|
||||
device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
|
||||
device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
|
||||
float slope = 0.0f;
|
||||
|
@ -456,6 +457,7 @@ kernel void kernel_soft_max(
|
|||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_soft_max_4(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
|
@ -480,8 +482,8 @@ kernel void kernel_soft_max_4(
|
|||
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
||||
|
||||
device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
||||
device const half4 * pmask = src1 != src0 ? (device const half4 *) src1 + i01*ne00/4 : nullptr;
|
||||
device const half4 * ppos = src2 != src0 ? (device const half4 *) src2 : nullptr;
|
||||
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
|
||||
device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
|
||||
device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
||||
|
||||
float slope = 0.0f;
|
||||
|
@ -562,6 +564,14 @@ kernel void kernel_soft_max_4(
|
|||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_soft_max<float>) kernel_soft_max_t;
|
||||
typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;
|
||||
|
||||
template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max<half>;
|
||||
template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max<float>;
|
||||
template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
|
||||
template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
|
||||
|
||||
kernel void kernel_diag_mask_inf(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
|
|
36
ggml.c
36
ggml.c
|
@ -5473,7 +5473,7 @@ static struct ggml_tensor * ggml_soft_max_impl(
|
|||
GGML_ASSERT(ggml_is_contiguous(a));
|
||||
|
||||
if (mask) {
|
||||
GGML_ASSERT(mask->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(mask));
|
||||
GGML_ASSERT(ggml_is_matrix(mask));
|
||||
GGML_ASSERT(mask->ne[1] >= a->ne[1]);
|
||||
|
@ -5481,10 +5481,14 @@ static struct ggml_tensor * ggml_soft_max_impl(
|
|||
|
||||
if (pos) {
|
||||
GGML_ASSERT(ggml_is_vector(pos));
|
||||
GGML_ASSERT(pos->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(pos->type == GGML_TYPE_F16 || pos->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(pos->ne[0] == a->ne[0]);
|
||||
}
|
||||
|
||||
if (pos && mask) {
|
||||
GGML_ASSERT(pos->type == mask->type);
|
||||
}
|
||||
|
||||
if (max_bias > 0.0f) {
|
||||
GGML_ASSERT(pos);
|
||||
}
|
||||
|
@ -12410,20 +12414,30 @@ static void ggml_compute_forward_soft_max_f32(
|
|||
float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
|
||||
|
||||
// when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
|
||||
ggml_fp16_t * pos = src2 ? (ggml_fp16_t *) src2->data : src0->data;
|
||||
ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data;
|
||||
float * pos_f32 = src2 ? (float *) src2->data : src0->data;
|
||||
|
||||
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
|
||||
|
||||
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||
float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
|
||||
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
|
||||
|
||||
// broadcast the mask across rows
|
||||
ggml_fp16_t * mp = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
|
||||
ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
|
||||
float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
|
||||
|
||||
ggml_vec_cpy_f32 (nc, wp, sp);
|
||||
ggml_vec_scale_f32(nc, wp, scale);
|
||||
if (mp) {
|
||||
if (mp_f32) {
|
||||
if (use_f16) {
|
||||
for (int i = 0; i < nc; ++i) {
|
||||
wp[i] += GGML_FP16_TO_FP32(mp[i]);
|
||||
wp[i] += GGML_FP16_TO_FP32(mp_f16[i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < nc; ++i) {
|
||||
wp[i] += mp_f32[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -12432,8 +12446,14 @@ static void ggml_compute_forward_soft_max_f32(
|
|||
const uint32_t h = (i1/ne01)%ne02; // head
|
||||
const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
|
||||
|
||||
for (int i = 0; i < nc; i++) {
|
||||
wp[i] = wp[i] + slope*ggml_fp16_to_fp32(pos[i]);
|
||||
if (use_f16) {
|
||||
for (int i = 0; i < nc; ++i) {
|
||||
wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[i]);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < nc; ++i) {
|
||||
wp[i] += slope*pos_f32[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -6710,14 +6710,14 @@ struct llm_build_context {
|
|||
}
|
||||
cb(lctx.inp_KQ_mask, "KQ_mask", -1);
|
||||
ggml_set_input(lctx.inp_KQ_mask);
|
||||
return ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16);
|
||||
return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
|
||||
}
|
||||
|
||||
struct ggml_tensor * build_inp_KQ_pos() {
|
||||
lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv);
|
||||
cb(lctx.inp_KQ_pos, "KQ_pos", -1);
|
||||
ggml_set_input(lctx.inp_KQ_pos);
|
||||
return ggml_cast(ctx0, lctx.inp_KQ_pos, GGML_TYPE_F16);
|
||||
return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_pos, GGML_TYPE_F16) : lctx.inp_KQ_pos;
|
||||
}
|
||||
|
||||
struct ggml_tensor * build_inp_mean() {
|
||||
|
|
|
@ -1120,11 +1120,11 @@ struct test_soft_max : public test_case {
|
|||
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||
ggml_tensor * mask = nullptr;
|
||||
if (this->mask) {
|
||||
mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, ne[0], ne[1]);
|
||||
mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne[0], ne[1]);
|
||||
}
|
||||
ggml_tensor * pos = nullptr;
|
||||
if (max_bias > 0.0f) {
|
||||
pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, ne[0]);
|
||||
pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[0]);
|
||||
}
|
||||
ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, pos, scale, max_bias);
|
||||
return out;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue