metal : normalize mat-vec kernel signatures

This commit is contained in:
Georgi Gerganov 2023-12-31 12:31:26 +02:00
parent ad7cf37fe8
commit 049a32fffa
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -920,14 +920,21 @@ kernel void kernel_mul_mv_q4_0_f32(
device const float * src1, device const float * src1,
device float * dst, device float * dst,
constant int64_t & ne00, constant int64_t & ne00,
constant int64_t & ne01[[buffer(4)]], constant int64_t & ne01,
constant int64_t & ne02[[buffer(5)]], constant int64_t & ne02,
constant int64_t & ne10[[buffer(9)]], constant uint64_t & nb00,
constant int64_t & ne12[[buffer(11)]], constant uint64_t & nb01,
constant int64_t & ne0 [[buffer(15)]], constant uint64_t & nb02,
constant int64_t & ne1 [[buffer(16)]], constant int64_t & ne10,
constant uint & r2 [[buffer(17)]], constant int64_t & ne11,
constant uint & r3 [[buffer(18)]], constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
@ -939,14 +946,21 @@ kernel void kernel_mul_mv_q4_1_f32(
device const float * src1, device const float * src1,
device float * dst, device float * dst,
constant int64_t & ne00, constant int64_t & ne00,
constant int64_t & ne01[[buffer(4)]], constant int64_t & ne01,
constant int64_t & ne02[[buffer(5)]], constant int64_t & ne02,
constant int64_t & ne10[[buffer(9)]], constant uint64_t & nb00,
constant int64_t & ne12[[buffer(11)]], constant uint64_t & nb01,
constant int64_t & ne0 [[buffer(15)]], constant uint64_t & nb02,
constant int64_t & ne1 [[buffer(16)]], constant int64_t & ne10,
constant uint & r2 [[buffer(17)]], constant int64_t & ne11,
constant uint & r3 [[buffer(18)]], constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
@ -958,14 +972,21 @@ kernel void kernel_mul_mv_q5_0_f32(
device const float * src1, device const float * src1,
device float * dst, device float * dst,
constant int64_t & ne00, constant int64_t & ne00,
constant int64_t & ne01[[buffer(4)]], constant int64_t & ne01,
constant int64_t & ne02[[buffer(5)]], constant int64_t & ne02,
constant int64_t & ne10[[buffer(9)]], constant uint64_t & nb00,
constant int64_t & ne12[[buffer(11)]], constant uint64_t & nb01,
constant int64_t & ne0 [[buffer(15)]], constant uint64_t & nb02,
constant int64_t & ne1 [[buffer(16)]], constant int64_t & ne10,
constant uint & r2 [[buffer(17)]], constant int64_t & ne11,
constant uint & r3 [[buffer(18)]], constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
@ -977,14 +998,21 @@ kernel void kernel_mul_mv_q5_1_f32(
device const float * src1, device const float * src1,
device float * dst, device float * dst,
constant int64_t & ne00, constant int64_t & ne00,
constant int64_t & ne01[[buffer(4)]], constant int64_t & ne01,
constant int64_t & ne02[[buffer(5)]], constant int64_t & ne02,
constant int64_t & ne10[[buffer(9)]], constant uint64_t & nb00,
constant int64_t & ne12[[buffer(11)]], constant uint64_t & nb01,
constant int64_t & ne0 [[buffer(15)]], constant uint64_t & nb02,
constant int64_t & ne1 [[buffer(16)]], constant int64_t & ne10,
constant uint & r2 [[buffer(17)]], constant int64_t & ne11,
constant uint & r3 [[buffer(18)]], constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
@ -1082,8 +1110,8 @@ kernel void kernel_mul_mv_q8_0_f32(
constant uint64_t & nb12, constant uint64_t & nb12,
constant int64_t & ne0, constant int64_t & ne0,
constant int64_t & ne1, constant int64_t & ne1,
constant uint & r2 [[buffer(17)]], constant uint & r2,
constant uint & r3 [[buffer(18)]], constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
@ -1189,8 +1217,8 @@ kernel void kernel_mul_mv_f32_f32(
constant uint64_t & nb12, constant uint64_t & nb12,
constant int64_t & ne0, constant int64_t & ne0,
constant int64_t & ne1, constant int64_t & ne1,
constant uint & r2 [[buffer(17)]], constant uint & r2,
constant uint & r3 [[buffer(18)]], constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) { uint tiisg[[thread_index_in_simdgroup]]) {
kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
@ -1216,8 +1244,8 @@ kernel void kernel_mul_mv_f16_f16(
constant uint64_t & nb12, constant uint64_t & nb12,
constant int64_t & ne0, constant int64_t & ne0,
constant int64_t & ne1, constant int64_t & ne1,
constant uint & r2 [[buffer(17)]], constant uint & r2,
constant uint & r3 [[buffer(18)]], constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) { uint tiisg[[thread_index_in_simdgroup]]) {
@ -1353,8 +1381,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
constant uint64_t & nb12, constant uint64_t & nb12,
constant int64_t & ne0, constant int64_t & ne0,
constant int64_t & ne1, constant int64_t & ne1,
constant uint & r2 [[buffer(17)]], constant uint & r2,
constant uint & r3 [[buffer(18)]], constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) { uint tiisg[[thread_index_in_simdgroup]]) {
kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
@ -1459,8 +1487,8 @@ kernel void kernel_mul_mv_f16_f32(
constant uint64_t & nb12, constant uint64_t & nb12,
constant int64_t & ne0, constant int64_t & ne0,
constant int64_t & ne1, constant int64_t & ne1,
constant uint & r2 [[buffer(17)]], constant uint & r2,
constant uint & r3 [[buffer(18)]], constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) { uint tiisg[[thread_index_in_simdgroup]]) {
kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg); kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
@ -1485,8 +1513,8 @@ kernel void kernel_mul_mv_f16_f32_l4(
constant uint64_t & nb12, constant uint64_t & nb12,
constant int64_t & ne0, constant int64_t & ne0,
constant int64_t & ne1, constant int64_t & ne1,
constant uint & r2 [[buffer(17)]], constant uint & r2,
constant uint & r3 [[buffer(18)]], constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) { uint tiisg[[thread_index_in_simdgroup]]) {
@ -2576,14 +2604,21 @@ kernel void kernel_mul_mv_q2_K_f32(
device const float * src1, device const float * src1,
device float * dst, device float * dst,
constant int64_t & ne00, constant int64_t & ne00,
constant int64_t & ne01[[buffer(4)]], constant int64_t & ne01,
constant int64_t & ne02[[buffer(5)]], constant int64_t & ne02,
constant int64_t & ne10[[buffer(9)]], constant uint64_t & nb00,
constant int64_t & ne12[[buffer(11)]], constant uint64_t & nb01,
constant int64_t & ne0 [[buffer(15)]], constant uint64_t & nb02,
constant int64_t & ne1 [[buffer(16)]], constant int64_t & ne10,
constant uint & r2 [[buffer(17)]], constant int64_t & ne11,
constant uint & r3 [[buffer(18)]], constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
@ -2833,14 +2868,21 @@ kernel void kernel_mul_mv_q3_K_f32(
device const float * src1, device const float * src1,
device float * dst, device float * dst,
constant int64_t & ne00, constant int64_t & ne00,
constant int64_t & ne01[[buffer(4)]], constant int64_t & ne01,
constant int64_t & ne02[[buffer(5)]], constant int64_t & ne02,
constant int64_t & ne10[[buffer(9)]], constant uint64_t & nb00,
constant int64_t & ne12[[buffer(11)]], constant uint64_t & nb01,
constant int64_t & ne0 [[buffer(15)]], constant uint64_t & nb02,
constant int64_t & ne1 [[buffer(16)]], constant int64_t & ne10,
constant uint & r2 [[buffer(17)]], constant int64_t & ne11,
constant uint & r3 [[buffer(18)]], constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
@ -3064,14 +3106,21 @@ kernel void kernel_mul_mv_q4_K_f32(
device const float * src1, device const float * src1,
device float * dst, device float * dst,
constant int64_t & ne00, constant int64_t & ne00,
constant int64_t & ne01[[buffer(4)]], constant int64_t & ne01,
constant int64_t & ne02[[buffer(5)]], constant int64_t & ne02,
constant int64_t & ne10[[buffer(9)]], constant uint64_t & nb00,
constant int64_t & ne12[[buffer(11)]], constant uint64_t & nb01,
constant int64_t & ne0 [[buffer(15)]], constant uint64_t & nb02,
constant int64_t & ne1 [[buffer(16)]], constant int64_t & ne10,
constant uint & r2 [[buffer(17)]], constant int64_t & ne11,
constant uint & r3 [[buffer(18)]], constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
@ -3263,14 +3312,21 @@ kernel void kernel_mul_mv_q5_K_f32(
device const float * src1, device const float * src1,
device float * dst, device float * dst,
constant int64_t & ne00, constant int64_t & ne00,
constant int64_t & ne01[[buffer(4)]], constant int64_t & ne01,
constant int64_t & ne02[[buffer(5)]], constant int64_t & ne02,
constant int64_t & ne10[[buffer(9)]], constant uint64_t & nb00,
constant int64_t & ne12[[buffer(11)]], constant uint64_t & nb01,
constant int64_t & ne0 [[buffer(15)]], constant uint64_t & nb02,
constant int64_t & ne1 [[buffer(16)]], constant int64_t & ne10,
constant uint & r2 [[buffer(17)]], constant int64_t & ne11,
constant uint & r3 [[buffer(18)]], constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
@ -3390,14 +3446,21 @@ kernel void kernel_mul_mv_q6_K_f32(
device const float * src1, device const float * src1,
device float * dst, device float * dst,
constant int64_t & ne00, constant int64_t & ne00,
constant int64_t & ne01[[buffer(4)]], constant int64_t & ne01,
constant int64_t & ne02[[buffer(5)]], constant int64_t & ne02,
constant int64_t & ne10[[buffer(9)]], constant uint64_t & nb00,
constant int64_t & ne12[[buffer(11)]], constant uint64_t & nb01,
constant int64_t & ne0 [[buffer(15)]], constant uint64_t & nb02,
constant int64_t & ne1 [[buffer(16)]], constant int64_t & ne10,
constant uint & r2 [[buffer(17)]], constant int64_t & ne11,
constant uint & r3 [[buffer(18)]], constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {