diff --git a/ggml-metal.m b/ggml-metal.m
index 0953af6a4..89b17ce5e 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -45,13 +45,20 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(scale);
     GGML_METAL_DECL_KERNEL(silu);
     GGML_METAL_DECL_KERNEL(relu);
+    GGML_METAL_DECL_KERNEL(gelu);
     GGML_METAL_DECL_KERNEL(soft_max);
     GGML_METAL_DECL_KERNEL(diag_mask_inf);
     GGML_METAL_DECL_KERNEL(get_rows_f16);
     GGML_METAL_DECL_KERNEL(get_rows_q4_0);
+    GGML_METAL_DECL_KERNEL(get_rows_q2_k);
+    GGML_METAL_DECL_KERNEL(get_rows_q4_k);
+    GGML_METAL_DECL_KERNEL(get_rows_q6_k);
    GGML_METAL_DECL_KERNEL(rms_norm);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
     GGML_METAL_DECL_KERNEL(rope);
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f32);
@@ -99,7 +106,7 @@ struct ggml_metal_context * ggml_metal_init(void) {
         NSError * error = nil;

         //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
-        NSString * path = [[NSBundle mainBundle] pathForResource:@"ggml-metal" ofType:@"metal"];
+        NSString * path = @"./ggml-metal.metal";
         fprintf(stderr, "%s: loading '%s'\n", __func__, [path UTF8String]);

         NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
@@ -129,13 +136,20 @@ struct ggml_metal_context * ggml_metal_init(void) {
         GGML_METAL_ADD_KERNEL(scale);
         GGML_METAL_ADD_KERNEL(silu);
         GGML_METAL_ADD_KERNEL(relu);
+        GGML_METAL_ADD_KERNEL(gelu);
         GGML_METAL_ADD_KERNEL(soft_max);
         GGML_METAL_ADD_KERNEL(diag_mask_inf);
         GGML_METAL_ADD_KERNEL(get_rows_f16);
         GGML_METAL_ADD_KERNEL(get_rows_q4_0);
+        GGML_METAL_ADD_KERNEL(get_rows_q2_k);
+        GGML_METAL_ADD_KERNEL(get_rows_q4_k);
+        GGML_METAL_ADD_KERNEL(get_rows_q6_k);
         GGML_METAL_ADD_KERNEL(rms_norm);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
         GGML_METAL_ADD_KERNEL(rope);
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f32);
@@ -408,6 +422,20 @@ void ggml_metal_graph_compute(
                         [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;
+                case GGML_OP_GELU:
+                    {
+                        if (encoder == nil) {
+                            encoder = [command_buffer computeCommandEncoder];
+                        }
+
+                        [encoder setComputePipelineState:ctx->pipeline_gelu];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                        const int64_t n = ggml_nelements(dst);
+
+                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    } break;
                 case GGML_OP_SOFT_MAX:
                     {
                         if (encoder == nil) {
@@ -514,10 +542,41 @@ void ggml_metal_graph_compute(
                                         GGML_ASSERT(ne12 == 1);

                                         nth0 = 8;
-                                        nth1 = 4;
+                                        nth1 = 8;
                                         [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
                                     } break;
-                                default: GGML_ASSERT(false && "not implemented");
+                                case GGML_TYPE_Q2_K:
+                                    {
+                                        GGML_ASSERT(ne02 == 1);
+                                        GGML_ASSERT(ne12 == 1);
+
+                                        nth0 = 4;
+                                        nth1 = 16;
+                                        [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32];
+                                    } break;
+                                case GGML_TYPE_Q4_K:
+                                    {
+                                        GGML_ASSERT(ne02 == 1);
+                                        GGML_ASSERT(ne12 == 1);
+
+                                        nth0 = 4;
+                                        nth1 = 16;
+                                        [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32];
+                                    } break;
+                                case GGML_TYPE_Q6_K:
+                                    {
+                                        GGML_ASSERT(ne02 == 1);
+                                        GGML_ASSERT(ne12 == 1);
+
+                                        nth0 = 4;
+                                        nth1 = 16;
+                                        [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_k_f32];
+                                    } break;
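+                                // Editor's note: the k-quant cases above all use an
+                                // nth0 x nth1 = 4 x 16 = 64-thread threadgroup; the kernels in
+                                // ggml-metal.metal rely on this shape (tpitg.y in 0..15 splits each
+                                // 256-weight super-block, tpitg.x strides over a row's super-blocks).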
+                                default:
+                                    {
+                                        fprintf(stderr, "Asserting on type %d\n", (int)src0t);
+                                        GGML_ASSERT(false && "not implemented");
+                                    }
                             };
@@ -540,6 +599,15 @@ void ggml_metal_graph_compute(
                             if (src0t == GGML_TYPE_Q4_0) {
                                 [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
                                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                            } else if (src0t == GGML_TYPE_Q2_K) {
+                                [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
+                                [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                            } else if (src0t == GGML_TYPE_Q4_K) {
+                                [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
+                                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                            } else if (src0t == GGML_TYPE_Q6_K) {
+                                [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
+                                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             } else {
                                 [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
                                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -555,6 +623,9 @@ void ggml_metal_graph_compute(
                         switch (src0->type) {
                             case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16];  break;
                             case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
+                            case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
+                            case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
+                            case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;
                             default: GGML_ASSERT(false && "not implemented");
                         }
diff --git a/ggml-metal.metal b/ggml-metal.metal
index a359bebe2..745fe8ad3 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -81,6 +81,17 @@ kernel void kernel_relu(
     dst[tpig] = max(0.0f, src0[tpig]);
 }

+constant float GELU_COEF_A    = 0.044715f;
+constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+
+kernel void kernel_gelu(
+    device const float * src0,
+    device       float * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    float x = src0[tpig];
+    dst[tpig] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
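+// Editor's note: kernel_gelu uses the common tanh approximation
+//     GELU(x) ~= 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3)))
+// A minimal CPU-side reference for spot-checking the kernel output might
+// look like this (hypothetical helper, not part of this change):
+//
+//     static float gelu_ref(float x) {
+//         return 0.5f*x*(1.0f + tanhf(0.79788456f*x*(1.0f + 0.044715f*x*x)));
+//     }
+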
 kernel void kernel_soft_max(
     device const float * src0,
     device       float * dst,
@@ -267,6 +278,8 @@ kernel void kernel_mul_mat_q4_0_f32(
     uint2 tptg[[threads_per_threadgroup]]) {
     const int nb = ne00/QK4_0;

+    const int8_t m8 = 8;
+
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;

@@ -276,45 +289,65 @@ kernel void kernel_mul_mat_q4_0_f32(
     const uint nth = tptg.x*tptg.y;
     const uint ith = tptg.y*tpitg.x + tpitg.y;

-    sum[ith] = 0.0f;
+    const int ix = tpitg.y/4;      // 0 or 1
+    const int iy = tpitg.y - 4*ix; // 0...3

-    for (int i = tpitg.x; i < nb; i += tptg.x) {
-        device const uchar4 * x0p = (device const uchar4 *) (x + i)->qs;
-        device const float4 * y0p = (device const float4 *) (y + i*QK4_0);
+    const int first = 4 * iy;

-        const float d = (float)((x + i)->d);
+    float sumf = 0;

-        const uchar4 x0v = *(x0p + tpitg.y);
-        const float4 y0v = *(y0p + tpitg.y + 0);
-        const float4 y1v = *(y0p + tpitg.y + 4);
+    for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {

-        float acc = 0.0f;
+        const float d = (float)x[i].d;
+
+        device const uint8_t * xl = x[i].qs + first;
+        device const float   * yl = y + i * QK4_0 + first;
+
+        float2 acc = {0.0f, 0.0f};

         for (int j = 0; j < 4; ++j) {
-            const int x0 = x0v[j] & 0x0F;
-            const int x1 = x0v[j] >>   4;

-            const float y0 = y0v[j];
-            const float y1 = y1v[j];
+            acc[0] += yl[j+ 0] * ((int8_t)(xl[j] & 0xF) - m8);
+            acc[1] += yl[j+16] * ((int8_t)(xl[j] >>  4) - m8);

-            acc += (x0 - 8)*y0 + (x1 - 8)*y1;
         }

-        sum[ith] += acc*d;
+        sumf += d * (acc[0] + acc[1]);
     }

-    // accumulate the sum from all threads in the threadgroup
+    sum[ith] = sumf;
+
+    //
+    // Accumulate the sum from all threads in the threadgroup
+    // This version is slightly faster than the commented out one below,
+    // which I copy-pasted from ggerganov's q4_0 dot product for metal.
+    //
     threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = nth/2; i > 0; i /= 2) {
-        if (ith < i) {
-            sum[ith] += sum[ith + i];
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
     }
-
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
     if (ith == 0) {
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
         dst[r1*ne0 + r0] = sum[0];
     }
+
+    //// accumulate the sum from all threads in the threadgroup
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    //for (uint i = nth/2; i > 0; i /= 2) {
+    //    if (ith < i) {
+    //        sum[ith] += sum[ith + i];
+    //    }
+    //    threadgroup_barrier(mem_flags::mem_threadgroup);
+    //}

+    //if (ith == 0) {
+    //    dst[r1*ne0 + r0] = sum[0];
+    //}
 }

 kernel void kernel_mul_mat_f16_f32(
@@ -338,6 +371,7 @@ kernel void kernel_mul_mat_f16_f32(
     uint3 tpig[[thread_position_in_grid]],
     uint3 tpitg[[thread_position_in_threadgroup]],
     uint3 tptg[[threads_per_threadgroup]]) {
+
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
     const int64_t im = tgpig.z;
@@ -503,3 +537,474 @@ kernel void kernel_cpy_f32_f32(
         dst_data[i00] = src[0];
     }
 }
+
+//============================================ k-quants ======================================================
+
+#define QK_K 256
+
+typedef struct {
+    uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
+    uint8_t qs[QK_K/4];      // quants
+    half d;                  // super-block scale for quantized scales
+    half dmin;               // super-block scale for quantized mins
+} block_q2_k;
+
+typedef struct {
+    half d;                    // super-block scale for quantized scales
+    half dmin;                 // super-block scale for quantized mins
+    uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
+    uint8_t qs[QK_K/2];        // 4-bit quants
+} block_q4_k;
+
+typedef struct {
+    uint8_t ql[QK_K/2];      // quants, lower 4 bits
+    uint8_t qh[QK_K/4];      // quants, upper 2 bits
+    int8_t  scales[QK_K/16]; // scales, quantized with 8 bits
+    half d;                  // super-block scale
+} block_q6_k;
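+
+// Editor's note on the layouts above, with QK_K = 256 and 2-byte halfs:
+//   block_q2_k: 16 + 64 + 2 + 2   =  84 bytes -> 2.625  bits per weight
+//   block_q4_k: 2 + 2 + 12 + 128  = 144 bytes -> 4.5    bits per weight
+//   block_q6_k: 128 + 64 + 16 + 2 = 210 bytes -> 6.5625 bits per weight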
+
+static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
+    uchar4 r;
+    if (j < 4) {
+        r[0] = q[j+0] & 63; r[1] = q[j+4] & 63;
+        r[2] = q[j+1] & 63; r[3] = q[j+5] & 63;
+    } else {
+        r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
+        r[1] = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
+        r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
+        r[3] = (q[j+5] >>  4) | ((q[j+1] >> 6) << 4);
+    }
+    return r;
+}
+
+//========================================== dequantization =============================
+
+static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d = x[i].d;
+        const float min = x[i].dmin;
+
+        device const uint8_t * q = x[i].qs;
+
+        int is = 0;
+        float dl, ml;
+        for (int n = 0; n < QK_K; n += 128) {
+            int shift = 0;
+            for (int j = 0; j < 4; ++j) {
+
+                uint8_t sc = x[i].scales[is++];
+                dl = d * (sc & 0xF); ml = min * (sc >> 4);
+                for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
+
+                sc = x[i].scales[is++];
+                dl = d * (sc & 0xF); ml = min * (sc >> 4);
+                for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
+
+                shift += 2;
+            }
+            q += 32;
+        }
+
+    }
+}
+
+static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        const float d = x[i].d;
+        const float min = x[i].dmin;
+
+        device const uint8_t * q = x[i].qs;
+        device const uint8_t * scales = x[i].scales;
+
+        int is = 0;
+        for (int j = 0; j < QK_K; j += 64) {
+            const uchar4 sc = get_scale_min_k4(is, scales);
+            const float d1 = d * sc[0]; const float m1 = min * sc[1];
+            const float d2 = d * sc[2]; const float m2 = min * sc[3];
+            for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
+            for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >>  4) - m2;
+            q += 32; is += 2;
+        }
+
+    }
+}
+
+static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, int k) {
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+
+    for (int i = 0; i < nb; i++) {
+
+        device const uint8_t * ql = x[i].ql;
+        device const uint8_t * qh = x[i].qh;
+        device const int8_t  * sc = x[i].scales;
+
+        const float d = x[i].d;
+
+        for (int n = 0; n < QK_K; n += 128) {
+            for (int l = 0; l < 32; ++l) {
+                int is = l/16;
+                const int8_t q1 = (int8_t)((ql[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+                const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+                const int8_t q3 = (int8_t)((ql[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+                const int8_t q4 = (int8_t)((ql[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+                y[l +  0] = d * sc[is + 0] * q1;
+                y[l + 32] = d * sc[is + 2] * q2;
+                y[l + 64] = d * sc[is + 4] * q3;
+                y[l + 96] = d * sc[is + 6] * q4;
+            }
+            y  += 128;
+            ql += 64;
+            qh += 32;
+            sc += 8;
+        }
+    }
+}
+
+kernel void kernel_get_rows_q2_k(
+        device const  void * src0,
+        device const   int * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    dequantize_row_q2_k(
+            (device const block_q2_k *) ((device char *) src0 + r*nb01),
+                       (device float *) ((device char *)  dst + i*nb1), ne00);
+}
+
+kernel void kernel_get_rows_q4_k(
+        device const  void * src0,
+        device const   int * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    dequantize_row_q4_k(
+            (device const block_q4_k *) ((device char *) src0 + r*nb01),
+                       (device float *) ((device char *)  dst + i*nb1), ne00);
+}
+
+kernel void kernel_get_rows_q6_k(
+        device const  void * src0,
+        device const   int * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    dequantize_row_q6_k(
+            (device const block_q6_k *) ((device char *) src0 + r*nb01),
+                       (device float *) ((device char *)  dst + i*nb1), ne00);
+}
+
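+// Editor's note: ggml-metal.m launches the get_rows kernels above with one
+// thread per row index (threadgroup size 1), so each thread dequantizes a
+// full row of ne00 weights on its own.
+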
+//====================================== dot products =========================
+
+kernel void kernel_mul_mat_q2_k_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint2 tpig[[thread_position_in_grid]],  // we don't use this for now
+        uint2 tpitg[[thread_position_in_threadgroup]],
+        uint2 tptg[[threads_per_threadgroup]]) {
+
+    const int nb = ne00/QK_K;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    device const block_q2_k * x = (device const block_q2_k *) src0 + r0*nb;
+    device const float     * yy = (device const float      *) src1 + r1*ne10;
+
+    const int nth = tptg.x*tptg.y;
+    const int ith = tptg.y*tpitg.x + tpitg.y;
+
+    const int tid = tpitg.y;     // 0...15
+    const int il  = tid/4;       // 0...3
+    const int ir  = tid%4;       // 0...3
+    const int ip  = il/2;        // 0 or 1
+    const int shift1 = 4*(il%2); // 0 or 4
+    const int shift2 = shift1+2; // 2 or 6
+    const int n   = 8;
+    const int is  = 4*il + (n*ir)/16;
+
+    sum[ith] = 0.0f;
+
+    float sumf = 0;
+    for (int i = tpitg.x; i < nb; i += tptg.x) {
+
+        device const uint8_t * q = x[i].qs + 32*ip + n*ir;
+        device const uint8_t * scales = x[i].scales + is;
+
+        uint8_t d1 = scales[0] & 0xF;
+        uint8_t m1 = scales[0] >> 4;
+        uint8_t d2 = scales[2] & 0xF;
+        uint8_t m2 = scales[2] >> 4;
+
+        device const float * y = yy + i*QK_K + 64*il + n*ir;
+
+        const float dall = (float)x[i].d;
+        const float dmin = (float)x[i].dmin;
+
+        float4 s = {0.f, 0.f, 0.f, 0.f};
+        for (int l = 0; l < n; ++l) {
+            s[0] += y[l+ 0] * ((q[l] >> shift1) & 3); s[1] += y[l+ 0];
+            s[2] += y[l+32] * ((q[l] >> shift2) & 3); s[3] += y[l+32];
+        }
+        sumf += dall * (s[0] * d1 + s[2] * d2) - dmin * (s[1] * m1 + s[3] * m2);

+    }
+    sum[ith] = sumf;
+
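+    // Editor's note on the index math, as I read it: the 16 tpitg.y threads
+    // split each 256-weight super-block into slices of n = 8 weights;
+    // il/ir locate the slice, ip selects the low/high 128 weights of qs, and
+    // shift1/shift2 pick two of the four 2-bit quants packed in each byte,
+    // so a thread scores y[64*il + 8*ir + l] and the weights 32 positions later.
+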
+    //
+    // Accumulate the sum from all threads in the threadgroup
+    // This version is slightly faster than the commented out one below,
+    // which I copy-pasted from ggerganov's q4_0 dot product for metal.
+    //
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith == 0) {
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+        dst[r1*ne0 + r0] = sum[0];
+    }
+
+    //// accumulate the sum from all threads in the threadgroup
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    //for (uint i = nth/2; i > 0; i /= 2) {
+    //    if (ith < i) {
+    //        sum[ith] += sum[ith + i];
+    //    }
+    //    threadgroup_barrier(mem_flags::mem_threadgroup);
+    //}
+
+    //if (ith == 0) {
+    //    dst[r1*ne0 + r0] = sum[0];
+    //}
+}
+
+kernel void kernel_mul_mat_q4_k_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint2 tpig[[thread_position_in_grid]],  // we don't use this for now
+        uint2 tpitg[[thread_position_in_threadgroup]],
+        uint2 tptg[[threads_per_threadgroup]]) {
+
+    const int nb = ne00/QK_K;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
+    device const float     * yy = (device const float      *) src1 + r1*ne10;
+
+    const uint nth = tptg.x*tptg.y;
+    const uint ith = tptg.y*tpitg.x + tpitg.y;
+
+    const int tid = tpitg.y; // 0...15
+    const int il  = tid/4;   // 0...3
+    const int ir  = tid%4;   // 0...3
+    const int n   = 8;
+    const int is  = 2*il;
+
+    sum[ith] = 0.0f;
+
+    float sumf = 0;
+    for (int i = tpitg.x; i < nb; i += tptg.x) {
+
+        device const uint8_t * q = (x + i)->qs + 32*il + n*ir;
+        device const float   * y = yy + i*QK_K + 64*il + n*ir;
+        device const uint8_t * scales = (x + i)->scales;
+
+        const float dall = (float)((x + i)->d);
+        const float dmin = (float)((x + i)->dmin);
+
+        const uchar4 sc = get_scale_min_k4(is, scales);
+
+        float4 s = {0.f, 0.f, 0.f, 0.f};
+        for (int l = 0; l < n; ++l) {
+            s[0] += y[l+ 0] * (q[l] & 0xF); s[1] += y[l+ 0];
+            s[2] += y[l+32] * (q[l] >>  4); s[3] += y[l+32];
+        }
+        sumf += dall * (s[0] * sc[0] + s[2] * sc[2]) - dmin * (s[1] * sc[1] + s[3] * sc[3]);
+
+    }
+    sum[ith] = sumf;
+
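+    // Editor's note: get_scale_min_k4 unpacks (scale, min, scale, min) for two
+    // consecutive 32-weight groups, so sc[0]/sc[1] apply to y[l+0] and
+    // sc[2]/sc[3] to y[l+32]; the mins enter with a minus sign because q4_k
+    // stores w = d*sc*q - dmin*m.
+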
+    //
+    // Accumulate the sum from all threads in the threadgroup
+    // This version is slightly faster than the commented out one below,
+    // which I copy-pasted from ggerganov's q4_0 dot product for metal.
+    //
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith == 0) {
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+        dst[r1*ne0 + r0] = sum[0];
+    }
+
+    //// accumulate the sum from all threads in the threadgroup
+    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    //for (uint i = nth/2; i > 0; i /= 2) {
+    //    if (ith < i) {
+    //        sum[ith] += sum[ith + i];
+    //    }
+    //    threadgroup_barrier(mem_flags::mem_threadgroup);
+    //}
+
+    //if (ith == 0) {
+    //    dst[r1*ne0 + r0] = sum[0];
+    //}
+}
+
+kernel void kernel_mul_mat_q6_k_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        threadgroup float  * sum [[threadgroup(0)]],
+        uint2 tgpig[[threadgroup_position_in_grid]],
+        uint2 tpig[[thread_position_in_grid]],  // we don't use this for now
+        uint2 tpitg[[thread_position_in_threadgroup]],
+        uint2 tptg[[threads_per_threadgroup]]) {
+
+    const uint8_t kmask1 = 0x03;
+    const uint8_t kmask2 = 0x0C;
+    const uint8_t kmask3 = 0x30;
+    const uint8_t kmask4 = 0xC0;
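+    // Editor's note: each qh byte holds the upper 2 bits of four different
+    // weights (offsets 0, 32, 64, 96 within a 128-weight half); the masks
+    // above select those bit pairs and the shifts in the inner loop move
+    // them into bits 4-5 of the reconstructed 6-bit quant.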
+
+    const int nb = ne00/QK_K;
+
+    const int64_t r0 = tgpig.x;
+    const int64_t r1 = tgpig.y;
+
+    device const block_q6_k * x = (device const block_q6_k *) src0 + r0*nb;
+    device const float     * yy = (device const float      *) src1 + r1*ne10;
+
+    const uint nth = tptg.x*tptg.y;
+    const uint ith = tptg.y*tpitg.x + tpitg.y;
+
+    const int step = QK_K / tptg.y;     // we expect this to be 16
+    const int iqs  = step * tpitg.y;    // 0...240 in steps of 16
+    const int ip   = iqs / 128;         // 0 or 1
+    const int il   = (iqs - 128*ip)/16; // 0...7
+    const int n    = 4;
+    const int is   = 8*ip + (n*il)/16;
+
+    float sumf = 0;
+    for (int i = tpitg.x; i < nb; i += tptg.x) {
+
+        device const uint8_t * ql = x[i].ql + 64*ip + n*il;
+        device const uint8_t * qh = x[i].qh + 32*ip + n*il;
+        device const int8_t  * sc = x[i].scales + is;
+
+        device const float * y = yy + i * QK_K + 128*ip + n*il;
+
+        const float dall = x[i].d;
+
+        float4 sums = {0.f, 0.f, 0.f, 0.f};
+        for (int l = 0; l < n; ++l) {
+            sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+            sums[1] += y[l+32] * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+            sums[2] += y[l+64] * ((int8_t)((ql[l+ 0] >>  4) | ((qh[l] & kmask3) << 0)) - 32);
+            sums[3] += y[l+96] * ((int8_t)((ql[l+32] >>  4) | ((qh[l] & kmask4) >> 2)) - 32);
+        }
+
+        sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
+
+    }
+
+    sum[ith] = sumf;
+
+    //
+    // Accumulate the sum from all threads in the threadgroup
+    //
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%4 == 0) {
+        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith%16 == 0) {
+        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+    if (ith == 0) {
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
+        dst[r1*ne0 + r0] = sum[0];
+    }
+
+}
diff --git a/ggml.c b/ggml.c
index 9d4d3583a..3b72b80f3 100644
--- a/ggml.c
+++ b/ggml.c
@@ -14729,12 +14729,12 @@ static void ggml_graph_export_leaf(const struct ggml_tensor * tensor, FILE * fou
     const int64_t * ne = tensor->ne;
     const size_t  * nb = tensor->nb;

-    fprintf(fout, "%-6s %-12s %8d %8d %d %d %d %16zu %16zu %16zu %16zu %16p %32s\n",
+    fprintf(fout, "%-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n",
             ggml_type_name(tensor->type),
             ggml_op_name (tensor->op),
             tensor->n_dims,
-            (int) ne[0], (int) ne[1], (int) ne[2], (int) ne[3],
-            nb[0], nb[1], nb[2], nb[3],
+            ne[0], ne[1], ne[2], ne[3],
+            nb[0], nb[1], nb[2], nb[3],
             tensor->data,
             tensor->name);
 }
@@ -14743,13 +14743,13 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char
     const int64_t * ne = tensor->ne;
     const size_t  * nb = tensor->nb;

-    fprintf(fout, "%-6s %-6s %-12s %8d %d %d %d %d %16zu %16zu %16zu %16zu %8d %16p %32s\n",
+    fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %8d %16p %32s\n",
             arg,
             ggml_type_name(tensor->type),
             ggml_op_name (tensor->op),
             tensor->n_dims,
-            (int) ne[0], (int) ne[1], (int) ne[2], (int) ne[3],
-            nb[0], nb[1], nb[2], nb[3],
+            ne[0], ne[1], ne[2], ne[3],
+            nb[0], nb[1], nb[2], nb[3],
             tensor->n_tasks,
             tensor->data,
             tensor->name);
@@ -14772,11 +14772,11 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
         FILE * fout = stdout;

         fprintf(fout, "\n");
-        fprintf(fout, "%-16s %8x\n", "magic",   GGML_FILE_MAGIC);
-        fprintf(fout, "%-16s %8d\n", "version", GGML_FILE_VERSION);
-        fprintf(fout, "%-16s %8d\n", "leafs",   cgraph->n_leafs);
-        fprintf(fout, "%-16s %8d\n", "nodes",   cgraph->n_nodes);
-        fprintf(fout, "%-16s %8d\n", "eval",    (int) size_eval);
+        fprintf(fout, "%-16s %8x\n", "magic",   GGML_FILE_MAGIC);
+        fprintf(fout, "%-16s %8d\n", "version", GGML_FILE_VERSION);
+        fprintf(fout, "%-16s %8d\n", "leafs",   cgraph->n_leafs);
+        fprintf(fout, "%-16s %8d\n", "nodes",   cgraph->n_nodes);
+        fprintf(fout, "%-16s %" PRIu64 "\n", "eval", size_eval);

         // header
         fprintf(fout, "\n");
diff --git a/llama.cpp b/llama.cpp
index d80706446..1e2c9d767 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1028,6 +1028,14 @@ static void llama_model_load_internal(
         }
     }

+    #if defined(GGML_USE_CLBLAST)
+    if (file_version == LLAMA_FILE_VERSION_GGJT_V3) {
+        if (hparams.ftype >= LLAMA_FTYPE_MOSTLY_Q2_K && hparams.ftype <= LLAMA_FTYPE_MOSTLY_Q6_K) {
+            printf("\n===\nK-Quants are currently not supported with CLBlast!!!\nPlease select a q4_0, q4_1, q5_0 or q5_1 format instead!\n=====\n");
+        }
+    }
+    #endif
+
     if (vocab_only) {
         return;
     }