ggml : ggml_soft_max support F16/F32 mask/pos
ggml-ci
This commit is contained in:
parent
c11d05fec0
commit
f725ca90fb
6 changed files with 105 additions and 34 deletions
|
@@ -352,6 +352,7 @@ kernel void kernel_sum_rows(
|
|||
dst_row[0] = row_sum;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_soft_max(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
|
@@ -376,8 +377,8 @@ kernel void kernel_soft_max(
|
|||
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
||||
|
||||
device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
device const half * pmask = src1 != src0 ? (device const half *) src1 + i01*ne00 : nullptr;
|
||||
device const half * ppos = src2 != src0 ? (device const half *) src2 : nullptr;
|
||||
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
|
||||
device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
|
||||
device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||
|
||||
float slope = 0.0f;
|
||||
|
@@ -456,6 +457,7 @@ kernel void kernel_soft_max(
|
|||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_soft_max_4(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
|
@@ -480,8 +482,8 @@ kernel void kernel_soft_max_4(
|
|||
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
|
||||
|
||||
device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
||||
device const half4 * pmask = src1 != src0 ? (device const half4 *) src1 + i01*ne00/4 : nullptr;
|
||||
device const half4 * ppos = src2 != src0 ? (device const half4 *) src2 : nullptr;
|
||||
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
|
||||
device const T * ppos = src2 != src0 ? (device const T *) src2 : nullptr;
|
||||
device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
|
||||
|
||||
float slope = 0.0f;
|
||||
|
@@ -562,6 +564,14 @@ kernel void kernel_soft_max_4(
|
|||
}
|
||||
}
|
||||
|
||||
// Function-type aliases taken from the float / float4 instantiations of the
// templated soft-max kernels. Each alias is reused below as the declared type
// of every explicit instantiation of the corresponding template.
// NOTE(review): sharing one alias across half and float instantiations assumes
// the kernel signature does not depend on T (the visible parameters are
// `device const char *`) -- confirm against the full parameter list.
typedef decltype(kernel_soft_max<float>) kernel_soft_max_t;
|
||||
typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;
|
||||
|
||||
// Explicit instantiations, each bound to a host-visible kernel name via
// [[host_name(...)]]. The *_f16 names instantiate with half / half4 and the
// *_f32 names with float / float4; in the kernels, T is the element type used
// to read the mask (src1) and pos (src2) buffers.
template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max<half>;
|
||||
template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max<float>;
|
||||
template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
|
||||
template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
|
||||
|
||||
kernel void kernel_diag_mask_inf(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue