wipwipwiwpip
This commit is contained in:
parent
fc59407efe
commit
ddc59e8e0a
4 changed files with 132 additions and 1 deletions
|
@ -2698,6 +2698,29 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|||
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
|
||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
|
||||
|
||||
kernel void kernel_ssm_conv_f32(
|
||||
device const float * src0,
|
||||
device const float * src1,
|
||||
device const float * src2,
|
||||
device const int32_t * src3,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne11,
|
||||
constant int64_t & ne20,
|
||||
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb21,
|
||||
constant uint64_t & nb22,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
}
|
||||
|
||||
kernel void kernel_cpy_f16_f16(
|
||||
device const half * src0,
|
||||
device half * dst,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue