diff --git a/ggml-metal.m b/ggml-metal.m index 1c7930a18..1c1810da4 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -757,17 +757,17 @@ void ggml_metal_graph_compute( case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break; default: GGML_ASSERT(false && "MUL MAT-MAT not implemented"); } - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9]; - [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9]; + [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10]; [encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { @@ -945,11 +945,11 @@ void ggml_metal_graph_compute( const int nth = 256; [encoder setComputePipelineState:ctx->pipeline_norm]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setBytes:&eps length:sizeof( float) atIndex:4]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; + [encoder setBytes:&eps length:sizeof( float) atIndex:4]; [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0]; const int64_t nrows = ggml_nrows(src0); @@ -992,7 +992,9 @@ void ggml_metal_graph_compute( [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; [encoder setBytes:&m0 length:sizeof( float) atIndex:18]; + const int nth = 32; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; case GGML_OP_ROPE: @@ -1007,8 +1009,8 @@ void ggml_metal_graph_compute( memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); [encoder setComputePipelineState:ctx->pipeline_rope]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; @@ -1059,24 +1061,24 @@ void ggml_metal_graph_compute( default: GGML_ASSERT(false && "not implemented"); } - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; diff --git a/ggml.c b/ggml.c index c7cc72d48..6839d0321 100644 --- a/ggml.c +++ b/ggml.c @@ -3554,9 +3554,9 @@ inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; } inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; } -static const float GELU_COEF_A = 0.044715f; -static const float GELU_QUICK_COEF = -1.702f; -static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; +static const float GELU_COEF_A = 0.044715f; +static const float GELU_QUICK_COEF = -1.702f; +static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; inline static float ggml_gelu_f32(float x) { return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); diff --git a/llama.cpp b/llama.cpp index 38593ac01..a2ee80ab2 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2545,26 +2545,23 @@ static struct ggml_cgraph * llm_build_falcon( struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); ggml_set_name(KQV_merged, "KQV_merged"); - cur = ggml_cpy(ctx0, - KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + cur = ggml_cpy(ctx0, KQV_merged, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); ggml_set_name(cur, "KQV_merged_contiguous"); - cur = ggml_cpy(ctx0, - KQV_merged, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); - cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur); ggml_set_name(cur, "result_wo"); } - struct ggml_tensor * inpFF = attn_norm; - struct ggml_tensor * attn_out = ggml_cpy( - ctx0, cur, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + struct ggml_tensor * attn_out = cur; - cur = ggml_mul_mat(ctx0, model.layers[il].w3, inpFF); - cur = ggml_gelu(ctx0, cur); - cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur); + // feed forward + { + struct ggml_tensor * inpFF = attn_norm; + + cur = ggml_mul_mat(ctx0, model.layers[il].w3, inpFF); + cur = ggml_gelu(ctx0, cur); + cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur); + } cur = ggml_add(ctx0, cur, attn_out); cur = ggml_add(ctx0, cur, inpL);