ggml : ggml_soft_max support F16/F32 mask/pos

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-04-22 13:46:23 +03:00
parent c11d05fec0
commit f725ca90fb
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
6 changed files with 105 additions and 34 deletions

View file

@ -352,6 +352,7 @@ kernel void kernel_sum_rows(
dst_row[0] = row_sum;
}
template<typename T>
kernel void kernel_soft_max(
device const char * src0,
device const char * src1,
@ -376,8 +377,8 @@ kernel void kernel_soft_max(
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
device const half * pmask = src1 != src0 ? (device const half *) src1 + i01*ne00 : nullptr;
device const half * ppos = src2 != src0 ? (device const half *) src2 : nullptr;
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
float slope = 0.0f;
@ -456,6 +457,7 @@ kernel void kernel_soft_max(
}
}
template<typename T>
kernel void kernel_soft_max_4(
device const char * src0,
device const char * src1,
@ -480,8 +482,8 @@ kernel void kernel_soft_max_4(
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
device const half4 * pmask = src1 != src0 ? (device const half4 *) src1 + i01*ne00/4 : nullptr;
device const half4 * ppos = src2 != src0 ? (device const half4 *) src2 : nullptr;
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
float slope = 0.0f;
@ -562,6 +564,14 @@ kernel void kernel_soft_max_4(
}
}
typedef decltype(kernel_soft_max<float>) kernel_soft_max_t;
typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;
template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max<half>;
template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max<float>;
template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
kernel void kernel_diag_mask_inf(
device const float * src0,
device float * dst,