metal : add kernel_get_rows_i32
ggml-ci
This commit is contained in:
parent
ab62fc3e55
commit
289313716f
2 changed files with 33 additions and 0 deletions
|
@@ -3829,6 +3829,35 @@ kernel void kernel_get_rows_f16(
|
|||
}
|
||||
}
|
||||
|
||||
// Row gather for 32-bit integer data.
//
// Each threadgroup handles one (i10, i11) position taken from its grid
// coordinates (tgpig.x, tgpig.y):
//   1. read an int32 row index `r` from src1 at byte offset i11*nb11 + i10*nb10
//   2. copy ne00 int32 elements from the src0 row at byte offset
//      r*nb01 + i11*nb02 to the dst row at byte offset i11*nb2 + i10*nb1
// The threads of the threadgroup share the copy: thread `tiitg` starts at
// element `tiitg` and advances by `tptg.x` elements per iteration.
//
// All nb* arguments are byte strides. `ne10` is accepted but not read here.
kernel void kernel_get_rows_i32(
        device const void * src0,
        device const char * src1,
        device   int32_t  * dst,
        constant  int64_t & ne00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant  int64_t & ne10,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb1,
        constant uint64_t & nb2,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiitg[[thread_index_in_threadgroup]],
        uint3 tptg [[threads_per_threadgroup]]) {
    const int64_t i10 = tgpig.x;
    const int64_t i11 = tgpig.y;

    // the source slice along dim 2 is selected by the same index as src1's dim 1
    const int64_t i02 = i11;

    // row index to gather, stored as int32 in src1
    device const int32_t * row_index = (device const int32_t *) (src1 + i11*nb11 + i10*nb10);
    const int64_t r = row_index[0];

    // both row base addresses are loop-invariant, so resolve them once
    device const int32_t * src_row = (device const int32_t *) ((device const char *) src0 + r*nb01 + i02*nb02);
    device       int32_t * dst_row = (device       int32_t *) ((device       char *) dst  + i11*nb2 + i10*nb1);

    int ind = tiitg;
    while (ind < ne00) {
        dst_row[ind] = src_row[ind];
        ind += tptg.x;
    }
}
|
||||
|
||||
|
||||
// Block-size constants. The code that consumes them is not visible in this section.
// NOTE(review): presumably tile dimensions for a simdgroup-matrix matmul kernel — confirm against the users below.
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
|
||||
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
|
||||
#define BLOCK_SIZE_K 32 // NOTE(review): presumably the step along the shared (K) dimension — confirm
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue