metal : use uint64_t for strides

This commit is contained in:
Georgi Gerganov 2023-12-31 12:07:58 +02:00
parent b14b5a9eb3
commit 4c054d98d4
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -59,26 +59,26 @@ kernel void kernel_add(
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant int64_t & nb00,
constant int64_t & nb01,
constant int64_t & nb02,
constant int64_t & nb03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant int64_t & ne13,
constant int64_t & nb10,
constant int64_t & nb11,
constant int64_t & nb12,
constant int64_t & nb13,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant int64_t & nb0,
constant int64_t & nb1,
constant int64_t & nb2,
constant int64_t & nb3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
constant int64_t & offs,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
@ -109,26 +109,26 @@ kernel void kernel_mul(
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant int64_t & nb00,
constant int64_t & nb01,
constant int64_t & nb02,
constant int64_t & nb03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant int64_t & ne13,
constant int64_t & nb10,
constant int64_t & nb11,
constant int64_t & nb12,
constant int64_t & nb13,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant int64_t & nb0,
constant int64_t & nb1,
constant int64_t & nb2,
constant int64_t & nb3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
@ -158,26 +158,26 @@ kernel void kernel_div(
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant int64_t & nb00,
constant int64_t & nb01,
constant int64_t & nb02,
constant int64_t & nb03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant int64_t & ne13,
constant int64_t & nb10,
constant int64_t & nb11,
constant int64_t & nb12,
constant int64_t & nb13,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant int64_t & nb0,
constant int64_t & nb1,
constant int64_t & nb2,
constant int64_t & nb3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
@ -205,7 +205,7 @@ kernel void kernel_add_row(
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
constant int64_t & nb [[buffer(28)]],
constant uint64_t & nb [[buffer(28)]],
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] + src1[tpig % nb];
}
@ -214,7 +214,7 @@ kernel void kernel_mul_row(
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
constant int64_t & nb [[buffer(28)]],
constant uint64_t & nb [[buffer(28)]],
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * src1[tpig % nb];
}
@ -223,7 +223,7 @@ kernel void kernel_div_row(
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
constant int64_t & nb [[buffer(28)]],
constant uint64_t & nb [[buffer(28)]],
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] / src1[tpig % nb];
}
@ -307,26 +307,26 @@ kernel void kernel_sum_rows(
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant int64_t & nb00,
constant int64_t & nb01,
constant int64_t & nb02,
constant int64_t & nb03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant int64_t & ne13,
constant int64_t & nb10,
constant int64_t & nb11,
constant int64_t & nb12,
constant int64_t & nb13,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant int64_t & nb0,
constant int64_t & nb1,
constant int64_t & nb2,
constant int64_t & nb3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
uint3 tpig[[thread_position_in_grid]]) {
int64_t i3 = tpig.z;
int64_t i2 = tpig.y;
@ -3777,12 +3777,12 @@ void kernel_mul_mm_impl(device const uchar * src0,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne02,
constant int64_t & nb01,
constant int64_t & nb02,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne12,
constant int64_t & nb10,
constant int64_t & nb11,
constant int64_t & nb12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
@ -3909,12 +3909,12 @@ kernel void kernel_mul_mm(device const uchar * src0,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne02,
constant int64_t & nb01,
constant int64_t & nb02,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne12,
constant int64_t & nb10,
constant int64_t & nb11,
constant int64_t & nb12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
@ -3950,19 +3950,19 @@ kernel void kernel_mul_mm_id(
device const uchar * ids,
device const uchar * src1,
device uchar * dst,
constant int64_t & nbi1,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne02,
constant int64_t & nb01,
constant int64_t & nb02,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne12,
constant int64_t & ne13,
constant int64_t & nb10,
constant int64_t & nb11,
constant int64_t & nb12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & nb1,
constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
@ -4055,12 +4055,12 @@ typedef void (mat_mm_t)(
device float * dst,
constant int64_t & ne00,
constant int64_t & ne02,
constant int64_t & nb01,
constant int64_t & nb02,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne12,
constant int64_t & nb10,
constant int64_t & nb11,
constant int64_t & nb12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
@ -4089,19 +4089,19 @@ typedef void (mat_mm_id_t)(
device const uchar * ids,
device const uchar * src1,
device uchar * dst,
constant int64_t & nbi1,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne02,
constant int64_t & nb01,
constant int64_t & nb02,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne12,
constant int64_t & ne13,
constant int64_t & nb10,
constant int64_t & nb11,
constant int64_t & nb12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & nb1,
constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
@ -4138,7 +4138,7 @@ kernel void kernel_mul_mv_id_f32_f32(
device const char * ids,
device const char * src1,
device uchar * dst,
constant int64_t & nbi1,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@ -4154,7 +4154,7 @@ kernel void kernel_mul_mv_id_f32_f32(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & nb1,
constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
@ -4207,7 +4207,7 @@ kernel void kernel_mul_mv_id_f16_f32(
device const char * ids,
device const char * src1,
device uchar * dst,
constant int64_t & nbi1,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@ -4223,7 +4223,7 @@ kernel void kernel_mul_mv_id_f16_f32(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & nb1,
constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
@ -4276,7 +4276,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
device const char * ids,
device const char * src1,
device uchar * dst,
constant int64_t & nbi1,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@ -4292,7 +4292,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & nb1,
constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
@ -4339,7 +4339,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
device const char * ids,
device const char * src1,
device uchar * dst,
constant int64_t & nbi1,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@ -4355,7 +4355,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & nb1,
constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
@ -4402,7 +4402,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
device const char * ids,
device const char * src1,
device uchar * dst,
constant int64_t & nbi1,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@ -4418,7 +4418,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & nb1,
constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
@ -4465,7 +4465,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
device const char * ids,
device const char * src1,
device uchar * dst,
constant int64_t & nbi1,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@ -4481,7 +4481,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & nb1,
constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
@ -4528,7 +4528,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
device const char * ids,
device const char * src1,
device uchar * dst,
constant int64_t & nbi1,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@ -4544,7 +4544,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & nb1,
constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
@ -4591,7 +4591,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
device const char * ids,
device const char * src1,
device uchar * dst,
constant int64_t & nbi1,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@ -4607,7 +4607,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & nb1,
constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
@ -4654,7 +4654,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
device const char * ids,
device const char * src1,
device uchar * dst,
constant int64_t & nbi1,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@ -4670,7 +4670,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & nb1,
constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
@ -4717,7 +4717,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
device const char * ids,
device const char * src1,
device uchar * dst,
constant int64_t & nbi1,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@ -4733,7 +4733,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & nb1,
constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
@ -4780,7 +4780,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
device const char * ids,
device const char * src1,
device uchar * dst,
constant int64_t & nbi1,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@ -4796,7 +4796,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & nb1,
constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
@ -4843,7 +4843,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
device const char * ids,
device const char * src1,
device uchar * dst,
constant int64_t & nbi1,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@ -4859,7 +4859,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & nb1,
constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,