metal : add GGML_OP_REPEAT kernels (#7557)

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-05-27 12:10:19 +03:00 committed by GitHub
parent 62bfef5194
commit 1d8fca72ae
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 96 additions and 4 deletions

View file

@ -168,6 +168,53 @@ kernel void kernel_div(
}
}
template<typename T>
kernel void kernel_repeat(
device const char * src0,
device char * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;
const int64_t i03 = i3 % ne03;
const int64_t i02 = i2 % ne02;
const int64_t i01 = i1 % ne01;
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
device char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1 ;
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
const int i00 = i0 % ne00;
*((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00));
}
}
typedef decltype(kernel_repeat<float>) kernel_repeat_t;
template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
// assumption: src1 is a row
// broadcast src1 into src0
kernel void kernel_add_row(