metal : add more general support for ggml_get_rows + tests

This commit is contained in:
Georgi Gerganov 2023-12-09 14:18:42 +02:00
parent 9064b1ca05
commit 2cbcba829f
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
4 changed files with 78 additions and 25 deletions

View file

@ -3223,14 +3223,16 @@ kernel void kernel_get_rows(
device float * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant uint64_t & nb1,
uint tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tptg[[threads_per_threadgroup]]) {
const int i = tgpig;
const int r = ((device int32_t *) src1)[i];
uint tptg [[threads_per_threadgroup]]) {
const int64_t i = tgpig;
const int64_t r = ((device int32_t *) src1)[i];
for (int ind = tiitg; ind < ne00/16; ind += tptg) {
for (int64_t ind = tiitg; ind < ne00/16; ind += tptg) {
float4x4 temp;
dequantize_func(
((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
@ -3238,6 +3240,52 @@ kernel void kernel_get_rows(
}
}
// Copies rows of a float source into dst, selecting each row through an index buffer.
//
// Each threadgroup handles one entry i of src1 (i = threadgroup position in the grid):
//   r   = src1[i]      row index to read
//   i02 = i / ne10     offset multiplier applied along nb02
// and copies ne00 floats from byte address (src0 + r*nb01 + i02*nb02) to byte address
// (dst + i*nb1). The threads of the group split the ne00 elements between them, each
// starting at its own index in the threadgroup and advancing by the threadgroup size.
//
// Parameters:
//   src0  source data, read as float
//   src1  32-bit row indices
//   dst   destination, written as float
//   ne00  number of elements copied per row
//   nb01  byte stride multiplied by the row index r
//   nb02  byte stride multiplied by i02
//   ne10  divisor used to derive i02 from i
//   nb1   byte stride between destination rows
kernel void kernel_get_rows_f32(
        device const  void * src0,
        device const   int * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant  uint64_t & nb1,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tiitg[[thread_index_in_threadgroup]],
        uint tptg [[threads_per_threadgroup]]) {
    const int64_t i   = tgpig;
    const int64_t r   = src1[i];
    const int64_t i02 = i/ne10;

    // Row base addresses do not depend on the element index, so compute them once.
    // The casts keep the const qualifier of the read-only source buffer.
    device const float * src_row = (device const float *) ((device const char *) src0 + r*nb01 + i02*nb02);
    device       float * dst_row = (device       float *) ((device       char *) dst  + i*nb1);

    // 64-bit index: ne00 is int64_t, so a 32-bit counter could overflow before reaching it.
    for (int64_t ind = tiitg; ind < ne00; ind += tptg) {
        dst_row[ind] = src_row[ind];
    }
}
// Copies rows of a half-precision source into a float dst, selecting each row through an
// index buffer. Each half value is converted to float by the assignment.
//
// Each threadgroup handles one entry i of src1 (i = threadgroup position in the grid):
//   r   = src1[i]      row index to read
//   i02 = i / ne10     offset multiplier applied along nb02
// and copies ne00 elements from byte address (src0 + r*nb01 + i02*nb02) to byte address
// (dst + i*nb1). The threads of the group split the ne00 elements between them, each
// starting at its own index in the threadgroup and advancing by the threadgroup size.
//
// Parameters:
//   src0  source data, read as half
//   src1  32-bit row indices
//   dst   destination, written as float
//   ne00  number of elements copied per row
//   nb01  byte stride multiplied by the row index r
//   nb02  byte stride multiplied by i02
//   ne10  divisor used to derive i02 from i
//   nb1   byte stride between destination rows
kernel void kernel_get_rows_f16(
        device const  void * src0,
        device const   int * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant  uint64_t & nb1,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tiitg[[thread_index_in_threadgroup]],
        uint tptg [[threads_per_threadgroup]]) {
    const int64_t i   = tgpig;
    const int64_t r   = src1[i];
    const int64_t i02 = i/ne10;

    // Row base addresses do not depend on the element index, so compute them once.
    // The casts keep the const qualifier of the read-only source buffer.
    device const half  * src_row = (device const half  *) ((device const char *) src0 + r*nb01 + i02*nb02);
    device       float * dst_row = (device       float *) ((device       char *) dst  + i*nb1);

    // 64-bit index: ne00 is int64_t, so a 32-bit counter could overflow before reaching it.
    for (int64_t ind = tiitg; ind < ne00; ind += tptg) {
        dst_row[ind] = src_row[ind];
    }
}
// Block (tile) sizes. NOTE(review): judging by the names and the notes below, these
// presumably size the M/N/K tiles of a simdgroup matrix-multiply kernel defined later in
// the file — confirm against the kernel that uses them.
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
#define BLOCK_SIZE_K 32
@ -3490,11 +3538,13 @@ typedef void (get_rows_t)(
device float * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant uint64_t & nb1,
uint, uint, uint);
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
// The generic f32/f16 instantiations are disabled: kernel_get_rows_f32 and
// kernel_get_rows_f16 are written out as standalone kernels earlier in this file.
//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
// Quantized types: each line instantiates the kernel_get_rows template for one block type
// with its dequantize function and exposes it under the given host_name.
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;