metal : add more general support for ggml_get_rows + tests
This commit is contained in:
parent
9064b1ca05
commit
2cbcba829f
4 changed files with 78 additions and 25 deletions
|
@ -3223,14 +3223,16 @@ kernel void kernel_get_rows(
|
|||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne10,
|
||||
constant uint64_t & nb1,
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint tptg[[threads_per_threadgroup]]) {
|
||||
const int i = tgpig;
|
||||
const int r = ((device int32_t *) src1)[i];
|
||||
uint tptg [[threads_per_threadgroup]]) {
|
||||
const int64_t i = tgpig;
|
||||
const int64_t r = ((device int32_t *) src1)[i];
|
||||
|
||||
for (int ind = tiitg; ind < ne00/16; ind += tptg) {
|
||||
for (int64_t ind = tiitg; ind < ne00/16; ind += tptg) {
|
||||
float4x4 temp;
|
||||
dequantize_func(
|
||||
((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
|
||||
|
@ -3238,6 +3240,52 @@ kernel void kernel_get_rows(
|
|||
}
|
||||
}
|
||||
|
||||
// Gathers rows of 32-bit floats from src0 into dst.
//
// Each threadgroup handles one output row i (= tgpig). The source row index r
// is read from src1[i], and i02 = i/ne10 selects the nb02-strided slice of
// src0 the row is taken from. The threads of the threadgroup cooperatively
// copy the ne00 elements of the row, each thread starting at tiitg and
// advancing by tptg.
//
//   src0 : source data, addressed in bytes via nb01 (per row) and nb02 (per slice)
//   src1 : 32-bit integer row indices, one per output row
//   dst  : float output, addressed in bytes via nb1 (per output row)
//   ne00 : number of elements copied per row
//   ne10 : divisor used to derive the slice index i02 from i
//
// NOTE(review): no bounds check is performed on r; the caller is presumed to
// supply valid row indices - confirm against the host-side dispatch.
kernel void kernel_get_rows_f32(
        device const  void * src0,
        device const   int * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant  uint64_t & nb1,
        uint                 tgpig[[threadgroup_position_in_grid]],
        uint                 tiitg[[thread_index_in_threadgroup]],
        uint                 tptg [[threads_per_threadgroup]]) {
    const int64_t i = tgpig;
    const int64_t r = src1[i];

    const int64_t i02 = i/ne10;

    // The row base pointers do not depend on the loop index, so compute them once.
    // src0 stays const-qualified: this kernel only reads from it.
    device const float * src_row = (device const float *) ((device const char *) src0 + r*nb01 + i02*nb02);
    device       float * dst_row = (device       float *) ((device       char *) dst  + i*nb1);

    // 64-bit index to match the type of ne00 (and the templated kernel_get_rows loop).
    for (int64_t ind = tiitg; ind < ne00; ind += tptg) {
        dst_row[ind] = src_row[ind];
    }
}
|
||||
|
||||
// Gathers rows of 16-bit floats (half) from src0 into dst, widening each
// element to 32-bit float on store.
//
// Each threadgroup handles one output row i (= tgpig). The source row index r
// is read from src1[i], and i02 = i/ne10 selects the nb02-strided slice of
// src0 the row is taken from. The threads of the threadgroup cooperatively
// copy the ne00 elements of the row, each thread starting at tiitg and
// advancing by tptg.
//
//   src0 : source data (read as half), addressed in bytes via nb01 (per row) and nb02 (per slice)
//   src1 : 32-bit integer row indices, one per output row
//   dst  : float output, addressed in bytes via nb1 (per output row)
//   ne00 : number of elements copied per row
//   ne10 : divisor used to derive the slice index i02 from i
//
// NOTE(review): no bounds check is performed on r; the caller is presumed to
// supply valid row indices - confirm against the host-side dispatch.
kernel void kernel_get_rows_f16(
        device const  void * src0,
        device const   int * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant  uint64_t & nb1,
        uint                 tgpig[[threadgroup_position_in_grid]],
        uint                 tiitg[[thread_index_in_threadgroup]],
        uint                 tptg [[threads_per_threadgroup]]) {
    const int64_t i = tgpig;
    const int64_t r = src1[i];

    const int64_t i02 = i/ne10;

    // The row base pointers do not depend on the loop index, so compute them once.
    // src0 stays const-qualified: this kernel only reads from it.
    device const half  * src_row = (device const half  *) ((device const char *) src0 + r*nb01 + i02*nb02);
    device       float * dst_row = (device       float *) ((device       char *) dst  + i*nb1);

    // 64-bit index to match the type of ne00 (and the templated kernel_get_rows loop).
    for (int64_t ind = tiitg; ind < ne00; ind += tptg) {
        // half -> float conversion happens implicitly on assignment
        dst_row[ind] = src_row[ind];
    }
}
|
||||
|
||||
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
|
||||
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
|
||||
#define BLOCK_SIZE_K 32
|
||||
|
@ -3490,11 +3538,13 @@ typedef void (get_rows_t)(
|
|||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne10,
|
||||
constant uint64_t & nb1,
|
||||
uint, uint, uint);
|
||||
|
||||
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
|
||||
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
||||
//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
|
||||
//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
||||
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue