diff --git a/examples/baby-llama/baby-llama.cpp b/examples/baby-llama/baby-llama.cpp index a99ece9a6..b5916f230 100644 --- a/examples/baby-llama/baby-llama.cpp +++ b/examples/baby-llama/baby-llama.cpp @@ -556,6 +556,14 @@ struct ggml_tensor * forward( struct ggml_tensor * kc = kv_self.k; struct ggml_tensor * vc = kv_self.v; + struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + { + int * data = (int *) KQ_pos->data; + for (int i = 0; i < N; ++i) { + data[i] = n_past + i; + } + } + // inpL shape [n_embd,N,1,1] struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens); for (int il = 0; il < n_layer; ++il) { @@ -583,8 +591,8 @@ struct ggml_tensor * forward( // wk shape [n_embd, n_embd, 1, 1] // Qcur shape [n_embd/n_head, n_head, N, 1] // Kcur shape [n_embd/n_head, n_head, N, 1] - struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0); - struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0); + struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0, 0); + struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0, 0); // store key and value to memory { @@ -810,9 +818,18 @@ struct ggml_tensor * forward_batch( struct ggml_tensor * kc = kv_self.k; struct ggml_tensor * vc = kv_self.v; + struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + { + int * data = (int *) KQ_pos->data; + for (int i = 0; i < N; ++i) { + data[i] = n_past + i; + } + } + // inpL shape [n_embd,N*n_batch,1] struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens); assert_shape_2d(inpL, n_embd, N*n_batch); + for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * inpSA = inpL; @@ -840,8 +857,8 @@ struct ggml_tensor * forward_batch( // wk shape [n_embd, n_embd, 1, 1] // Qcur shape [n_embd/n_head, n_head, N, n_batch] // Kcur shape [n_embd/n_head, n_head, N, n_batch] - struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0); - struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0); + struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0, 0); + struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0, 0); assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch); assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch); @@ -1100,6 +1117,14 @@ struct ggml_tensor * forward_lora( struct ggml_tensor * kc = kv_self.k; struct ggml_tensor * vc = kv_self.v; + struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + { + int * data = (int *) KQ_pos->data; + for (int i = 0; i < N; ++i) { + data[i] = n_past + i; + } + } + // inpL shape [n_embd,N,1,1] struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens); for (int il = 0; il < n_layer; ++il) { @@ -1133,7 +1158,7 @@ struct ggml_tensor * forward_lora( model->layers[il].wqb, cur)), n_embd/n_head, n_head, N), - n_past, n_rot, 0, 0); + KQ_pos, n_rot, 0, 0); struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, @@ -1142,7 +1167,7 @@ struct ggml_tensor * forward_lora( model->layers[il].wkb, cur)), n_embd/n_head, n_head, N), - n_past, n_rot, 0, 0); + KQ_pos, n_rot, 0, 0); // store key and value to memory { diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 947aa7ed3..d91566586 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -679,15 +679,23 @@ struct ggml_tensor * llama_build_train_graphs( } }; + // KQ_pos - contains the positions + struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N); + { + int * data = (int *) KQ_pos->data; + for (int i = 0; i < N; ++i) { + data[i] = n_past + i; + } + } + // rope has so much parameters that we make a custom function for it - auto rope = [ctx, n_rot, n_ctx, rope_freq_base, rope_freq_scale] + auto rope = [ctx, KQ_pos, n_rot, n_ctx, rope_freq_base, rope_freq_scale] (struct ggml_tensor * t) -> struct ggml_tensor * { // not capturing these, to silcence warnings - const int n_past = 0; const int rope_mode = 0; return ggml_rope_custom(ctx, - t, n_past, n_rot, rope_mode, n_ctx, + t, KQ_pos, n_rot, rope_mode, n_ctx, rope_freq_base, rope_freq_scale); }; diff --git a/ggml-metal.m b/ggml-metal.m index 1139ee311..231debcfd 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -100,7 +100,8 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32); GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32); GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32); - GGML_METAL_DECL_KERNEL(rope); + GGML_METAL_DECL_KERNEL(rope_f32); + GGML_METAL_DECL_KERNEL(rope_f16); GGML_METAL_DECL_KERNEL(alibi_f32); GGML_METAL_DECL_KERNEL(cpy_f32_f16); GGML_METAL_DECL_KERNEL(cpy_f32_f32); @@ -261,7 +262,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32); GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32); GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32); - GGML_METAL_ADD_KERNEL(rope); + GGML_METAL_ADD_KERNEL(rope_f32); + GGML_METAL_ADD_KERNEL(rope_f16); GGML_METAL_ADD_KERNEL(alibi_f32); GGML_METAL_ADD_KERNEL(cpy_f32_f16); GGML_METAL_ADD_KERNEL(cpy_f32_f32); @@ -335,7 +337,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32); GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32); GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32); - GGML_METAL_DEL_KERNEL(rope); + GGML_METAL_DEL_KERNEL(rope_f32); + GGML_METAL_DEL_KERNEL(rope_f16); GGML_METAL_DEL_KERNEL(alibi_f32); GGML_METAL_DEL_KERNEL(cpy_f32_f16); GGML_METAL_DEL_KERNEL(cpy_f32_f32); @@ -736,25 +739,59 @@ void ggml_metal_graph_compute( GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(ggml_is_contiguous(src1)); - // utilize float4 - GGML_ASSERT(ne00 % 4 == 0); - const int64_t nb = ne00/4; + bool bcast_row = false; - if (ggml_nelements(src1) == ne10) { + int64_t nb = ne00; + + if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) { // src1 is a row GGML_ASSERT(ne11 == 1); + + nb = ne00 / 4; [encoder setComputePipelineState:ctx->pipeline_add_row]; + + bcast_row = true; } else { [encoder setComputePipelineState:ctx->pipeline_add]; } [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&nb length:sizeof(nb) atIndex:3]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; + [encoder setBytes:&nb length:sizeof(nb) atIndex:27]; - const int64_t n = ggml_nelements(dst)/4; + if (bcast_row) { + const int64_t n = ggml_nelements(dst)/4; - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } else { + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } } break; case GGML_OP_MUL: { @@ -836,7 +873,7 @@ void ggml_metal_graph_compute( } break; case GGML_OP_SOFT_MAX: { - const int nth = 32; + const int nth = MIN(32, ne00); if (ne00%4 == 0) { [encoder setComputePipelineState:ctx->pipeline_soft_max_4]; @@ -1100,7 +1137,7 @@ void ggml_metal_graph_compute( float eps; memcpy(&eps, dst->op_params, sizeof(float)); - const int nth = 512; + const int nth = MIN(512, ne00); [encoder setComputePipelineState:ctx->pipeline_rms_norm]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -1119,7 +1156,7 @@ void ggml_metal_graph_compute( float eps; memcpy(&eps, dst->op_params, sizeof(float)); - const int nth = 256; + const int nth = MIN(256, ne00); [encoder setComputePipelineState:ctx->pipeline_norm]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -1137,6 +1174,8 @@ void ggml_metal_graph_compute( { GGML_ASSERT((src0t == GGML_TYPE_F32)); + const int nth = MIN(1024, ne00); + const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past); const int n_head = ((int32_t *) dst->op_params)[1]; float max_bias; @@ -1170,12 +1209,14 @@ void ggml_metal_graph_compute( [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; [encoder setBytes:&m0 length:sizeof( float) atIndex:18]; - const int nth = 32; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; case GGML_OP_ROPE: { + GGML_ASSERT(ne10 == ne02); + + const int nth = MIN(1024, ne00); + const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; @@ -1185,38 +1226,44 @@ void ggml_metal_graph_compute( memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); - [encoder setComputePipelineState:ctx->pipeline_rope]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&n_past length:sizeof( int) atIndex:18]; - [encoder setBytes:&n_dims length:sizeof( int) atIndex:19]; - [encoder setBytes:&mode length:sizeof( int) atIndex:20]; - [encoder setBytes:&freq_base length:sizeof(float) atIndex:21]; - [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22]; + switch (src0->type) { + case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break; + case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_rope_f16]; break; + default: GGML_ASSERT(false); + }; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&n_past length:sizeof( int) atIndex:19]; + [encoder setBytes:&n_dims length:sizeof( int) atIndex:20]; + [encoder setBytes:&mode length:sizeof( int) atIndex:21]; + [encoder setBytes:&freq_base length:sizeof(float) atIndex:22]; + [encoder setBytes:&freq_scale length:sizeof(float) atIndex:23]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; case GGML_OP_DUP: case GGML_OP_CPY: case GGML_OP_CONT: { - const int nth = 32; + const int nth = MIN(1024, ne00); switch (src0t) { case GGML_TYPE_F32: diff --git a/ggml-metal.metal b/ggml-metal.metal index 7f1c3d9ea..5e1af6a09 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -24,12 +24,59 @@ typedef struct { int8_t qs[QK8_0]; // quants } block_q8_0; +// general-purpose kernel for addition of two tensors +// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3 +// cons: not very efficient kernel void kernel_add( - device const float4 * src0, - device const float4 * src1, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] + src1[tpig]; + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant int64_t & nb00, + constant int64_t & nb01, + constant int64_t & nb02, + constant int64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant int64_t & nb10, + constant int64_t & nb11, + constant int64_t & nb12, + constant int64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant int64_t & nb0, + constant int64_t & nb1, + constant int64_t & nb2, + constant int64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig.z; + const int64_t i02 = tgpig.y; + const int64_t i01 = tgpig.x; + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00; + device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10; + device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0]; + + src0_ptr += ntg.x*nb00; + src1_ptr += ntg.x*nb10; + dst_ptr += ntg.x*nb0; + } } // assumption: src1 is a row @@ -38,7 +85,7 @@ kernel void kernel_add_row( device const float4 * src0, device const float4 * src1, device float4 * dst, - constant int64_t & nb, + constant int64_t & nb [[buffer(27)]], uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] + src1[tpig % nb]; } @@ -806,30 +853,61 @@ kernel void kernel_alibi_f32( } } +typedef void (rope_t)( + device const void * src0, + device const int32_t * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int & n_past, + constant int & n_dims, + constant int & mode, + constant float & freq_base, + constant float & freq_scale, + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg[[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]); + +template kernel void kernel_rope( - device const void * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int & n_past, - constant int & n_dims, - constant int & mode, - constant float & freq_base, - constant float & freq_scale, + device const void * src0, + device const int32_t * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int & n_past, + constant int & n_dims, + constant int & mode, + constant float & freq_base, + constant float & freq_scale, uint tiitg[[thread_index_in_threadgroup]], uint3 tptg[[threads_per_threadgroup]], uint3 tgpig[[threadgroup_position_in_grid]]) { @@ -839,7 +917,9 @@ kernel void kernel_rope( const bool is_neox = mode & 2; - const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); + device const int32_t * pos = src1; + + const int64_t p = pos[i2]; const float theta_0 = freq_scale * (float)p; const float inv_ndims = -1.f/n_dims; @@ -851,11 +931,11 @@ kernel void kernel_rope( const float cos_theta = cos(theta); const float sin_theta = sin(theta); - device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - const float x0 = src[0]; - const float x1 = src[1]; + const T x0 = src[0]; + const T x1 = src[1]; dst_data[0] = x0*cos_theta - x1*sin_theta; dst_data[1] = x0*sin_theta + x1*cos_theta; @@ -870,8 +950,8 @@ kernel void kernel_rope( const int64_t i0 = ib*n_dims + ic/2; - device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); const float x0 = src[0]; const float x1 = src[n_dims/2]; @@ -883,6 +963,9 @@ kernel void kernel_rope( } } +template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope; +template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope; + kernel void kernel_cpy_f16_f16( device const half * src0, device half * dst, @@ -1273,8 +1356,8 @@ kernel void kernel_mul_mat_q3_K_f32( float yl[32]; - const uint16_t kmask1 = 0x3030; - const uint16_t kmask2 = 0x0f0f; + //const uint16_t kmask1 = 0x3030; + //const uint16_t kmask2 = 0x0f0f; const int tid = tiisg/4; const int ix = tiisg%4; diff --git a/ggml.c b/ggml.c index a0be068d6..e4faafee6 100644 --- a/ggml.c +++ b/ggml.c @@ -6968,7 +6968,7 @@ struct ggml_tensor * ggml_soft_max_back_inplace( static struct ggml_tensor * ggml_rope_impl( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx, @@ -6977,7 +6977,10 @@ static struct ggml_tensor * ggml_rope_impl( float xpos_base, bool xpos_down, bool inplace) { - GGML_ASSERT(n_past >= 0); + GGML_ASSERT(ggml_is_vector(b)); + GGML_ASSERT(b->type == GGML_TYPE_I32); + GGML_ASSERT(a->ne[2] == b->ne[0]); + bool is_node = false; if (a->grad) { @@ -6986,7 +6989,7 @@ static struct ggml_tensor * ggml_rope_impl( struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - int32_t params[8] = { n_past, n_dims, mode, n_ctx }; + int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx }; memcpy(params + 4, &freq_base, sizeof(float)); memcpy(params + 5, &freq_scale, sizeof(float)); memcpy(params + 6, &xpos_base, sizeof(float)); @@ -6996,6 +6999,7 @@ static struct ggml_tensor * ggml_rope_impl( result->op = GGML_OP_ROPE; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; + result->src[1] = b; return result; } @@ -7003,55 +7007,55 @@ static struct ggml_tensor * ggml_rope_impl( struct ggml_tensor * ggml_rope( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false); + return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false); } struct ggml_tensor * ggml_rope_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true); + return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true); } struct ggml_tensor * ggml_rope_custom( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx, float freq_base, float freq_scale) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false); + return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false); } struct ggml_tensor * ggml_rope_custom_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx, float freq_base, float freq_scale) { - return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true); + return ggml_rope_impl(ctx, a, b, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true); } struct ggml_tensor * ggml_rope_xpos_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, float base, bool down) { - return ggml_rope_impl(ctx, a, n_past, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true); + return ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true); } // ggml_rope_back @@ -7059,7 +7063,7 @@ struct ggml_tensor * ggml_rope_xpos_inplace( struct ggml_tensor * ggml_rope_back( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx, @@ -7067,7 +7071,10 @@ struct ggml_tensor * ggml_rope_back( float freq_scale, float xpos_base, bool xpos_down) { - GGML_ASSERT(n_past >= 0); + GGML_ASSERT(ggml_is_vector(b)); + GGML_ASSERT(b->type == GGML_TYPE_I32); + GGML_ASSERT(a->ne[2] == b->ne[0]); + GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet"); bool is_node = false; @@ -7078,7 +7085,7 @@ struct ggml_tensor * ggml_rope_back( struct ggml_tensor * result = ggml_dup_tensor(ctx, a); - int32_t params[8] = { n_past, n_dims, mode, n_ctx }; + int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx }; memcpy(params + 4, &freq_base, sizeof(float)); memcpy(params + 5, &freq_scale, sizeof(float)); memcpy(params + 6, &xpos_base, sizeof(float)); @@ -7088,6 +7095,7 @@ struct ggml_tensor * ggml_rope_back( result->op = GGML_OP_ROPE_BACK; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; + result->src[1] = b; return result; } @@ -8798,8 +8806,6 @@ static void ggml_compute_forward_add_f32( #else ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr); #endif - // } - // } } } else { // src1 is not contiguous @@ -12623,8 +12629,8 @@ static void ggml_compute_forward_clamp( static void ggml_compute_forward_rope_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, + const struct ggml_tensor * src1, struct ggml_tensor * dst) { - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } @@ -12634,9 +12640,9 @@ static void ggml_compute_forward_rope_f32( // these two only relevant for xPos RoPE: float xpos_base; - bool xpos_down; + bool xpos_down; - const int n_past = ((int32_t *) dst->op_params)[0]; + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; const int n_ctx = ((int32_t *) dst->op_params)[3]; @@ -12645,8 +12651,6 @@ static void ggml_compute_forward_rope_f32( memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float)); memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool)); - assert(n_past >= 0); - GGML_TENSOR_UNARY_OP_LOCALS; //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); @@ -12677,9 +12681,11 @@ static void ggml_compute_forward_rope_f32( const bool is_neox = mode & 2; const bool is_glm = mode & 4; + const int32_t * pos = (const int32_t *) src1->data; + for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { - const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); + for (int64_t i2 = 0; i2 < ne2; i2++) { + const int64_t p = pos[i2]; for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; if (ir > ir1) break; @@ -12716,7 +12722,7 @@ static void ggml_compute_forward_rope_f32( const float cos_theta = cosf(theta); const float sin_theta = sinf(theta); // zeta scaling for xPos only: - float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f; + float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f; if (xpos_down) zeta = 1.0f / zeta; theta *= theta_scale; @@ -12761,8 +12767,8 @@ static void ggml_compute_forward_rope_f32( static void ggml_compute_forward_rope_f16( const struct ggml_compute_params * params, const struct ggml_tensor * src0, + const struct ggml_tensor * src1, struct ggml_tensor * dst) { - if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } @@ -12770,15 +12776,13 @@ static void ggml_compute_forward_rope_f16( float freq_base; float freq_scale; - const int n_past = ((int32_t *) dst->op_params)[0]; + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; const int n_ctx = ((int32_t *) dst->op_params)[3]; memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); - assert(n_past >= 0); - GGML_TENSOR_UNARY_OP_LOCALS; //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); @@ -12809,9 +12813,11 @@ static void ggml_compute_forward_rope_f16( const bool is_neox = mode & 2; const bool is_glm = mode & 4; + const int32_t * pos = (const int32_t *) src1->data; + for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { - const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); + for (int64_t i2 = 0; i2 < ne2; i2++) { + const int64_t p = pos[i2]; for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; if (ir > ir1) break; @@ -12890,15 +12896,16 @@ static void ggml_compute_forward_rope_f16( static void ggml_compute_forward_rope( const struct ggml_compute_params * params, const struct ggml_tensor * src0, + const struct ggml_tensor * src1, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F16: { - ggml_compute_forward_rope_f16(params, src0, dst); + ggml_compute_forward_rope_f16(params, src0, src1, dst); } break; case GGML_TYPE_F32: { - ggml_compute_forward_rope_f32(params, src0, dst); + ggml_compute_forward_rope_f32(params, src0, src1, dst); } break; default: { @@ -12912,6 +12919,7 @@ static void ggml_compute_forward_rope( static void ggml_compute_forward_rope_back_f32( const struct ggml_compute_params * params, const struct ggml_tensor * src0, + const struct ggml_tensor * src1, struct ggml_tensor * dst) { if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { @@ -12929,7 +12937,7 @@ static void ggml_compute_forward_rope_back_f32( float xpos_base; bool xpos_down; - const int n_past = ((int32_t *) dst->op_params)[0]; + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; const int n_ctx = ((int32_t *) dst->op_params)[3]; UNUSED(n_ctx); @@ -12938,8 +12946,6 @@ static void ggml_compute_forward_rope_back_f32( memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float)); memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool)); - assert(n_past >= 0); - GGML_TENSOR_UNARY_OP_LOCALS; //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); @@ -12966,9 +12972,11 @@ static void ggml_compute_forward_rope_back_f32( const bool is_neox = mode & 2; + const int32_t * pos = (const int32_t *) src1->data; + for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { - const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); + for (int64_t i2 = 0; i2 < ne2; i2++) { + const int64_t p = pos[i2]; for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; if (ir > ir1) break; @@ -12980,7 +12988,7 @@ static void ggml_compute_forward_rope_back_f32( const float cos_theta = cosf(theta); const float sin_theta = sinf(theta); // zeta scaling for xPos only: - float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f; + float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f; if (xpos_down) zeta = 1.0f / zeta; theta *= theta_scale; @@ -13023,6 +13031,7 @@ static void ggml_compute_forward_rope_back_f32( static void ggml_compute_forward_rope_back_f16( const struct ggml_compute_params * params, const struct ggml_tensor * src0, + const struct ggml_tensor * src1, struct ggml_tensor * dst) { if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { @@ -13033,12 +13042,10 @@ static void ggml_compute_forward_rope_back_f16( // dx = rope_back(dy, src1) // src0 is dy, src1 contains options - const int n_past = ((int32_t *) dst->op_params)[0]; + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; - assert(n_past >= 0); - GGML_TENSOR_UNARY_OP_LOCALS; //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); @@ -13065,9 +13072,11 @@ static void ggml_compute_forward_rope_back_f16( const bool is_neox = mode & 2; + const int32_t * pos = (const int32_t *) src1->data; + for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { - const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); + for (int64_t i2 = 0; i2 < ne2; i2++) { + const int64_t p = pos[i2]; for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; if (ir > ir1) break; @@ -13119,15 +13128,16 @@ static void ggml_compute_forward_rope_back_f16( static void ggml_compute_forward_rope_back( const struct ggml_compute_params * params, const struct ggml_tensor * src0, + const struct ggml_tensor * src1, struct ggml_tensor * dst) { switch (src0->type) { case GGML_TYPE_F16: { - ggml_compute_forward_rope_back_f16(params, src0, dst); + ggml_compute_forward_rope_back_f16(params, src0, src1, dst); } break; case GGML_TYPE_F32: { - ggml_compute_forward_rope_back_f32(params, src0, dst); + ggml_compute_forward_rope_back_f32(params, src0, src1, dst); } break; default: { @@ -15864,11 +15874,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_ROPE: { - ggml_compute_forward_rope(params, tensor->src[0], tensor); + ggml_compute_forward_rope(params, tensor->src[0], tensor->src[1], tensor); } break; case GGML_OP_ROPE_BACK: { - ggml_compute_forward_rope_back(params, tensor->src[0], tensor); + ggml_compute_forward_rope_back(params, tensor->src[0], tensor->src[1], tensor); } break; case GGML_OP_ALIBI: { @@ -16506,7 +16516,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { // necessary for llama if (src0->grad) { - const int n_past = ((int32_t *) tensor->op_params)[0]; + //const int n_past = ((int32_t *) tensor->op_params)[0]; const int n_dims = ((int32_t *) tensor->op_params)[1]; const int mode = ((int32_t *) tensor->op_params)[2]; const int n_ctx = ((int32_t *) tensor->op_params)[3]; @@ -16523,7 +16533,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src0->grad, ggml_rope_back(ctx, tensor->grad, - n_past, + src1, n_dims, mode, n_ctx, @@ -16537,7 +16547,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor case GGML_OP_ROPE_BACK: { if (src0->grad) { - const int n_past = ((int32_t *) tensor->op_params)[0]; + //const int n_past = ((int32_t *) tensor->op_params)[0]; const int n_dims = ((int32_t *) tensor->op_params)[1]; const int mode = ((int32_t *) tensor->op_params)[2]; const int n_ctx = ((int32_t *) tensor->op_params)[3]; @@ -16554,7 +16564,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src0->grad, ggml_rope_impl(ctx, tensor->grad, - n_past, + src1, n_dims, mode, n_ctx, diff --git a/ggml.h b/ggml.h index f45456876..e2bfa1ae4 100644 --- a/ggml.h +++ b/ggml.h @@ -1219,14 +1219,15 @@ extern "C" { struct ggml_tensor * b); // rotary position embedding - // if mode & 1 == 1, skip n_past elements + // if mode & 1 == 1, skip n_past elements (DEPRECATED) // if mode & 2 == 1, GPT-NeoX style // if mode & 4 == 1, ChatGLM style - // TODO: avoid creating a new tensor every time + // + // b is an int32 vector with size a->ne[2], it contains the positions GGML_API struct ggml_tensor * ggml_rope( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx); @@ -1235,7 +1236,7 @@ extern "C" { GGML_API struct ggml_tensor * ggml_rope_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx); @@ -1244,7 +1245,7 @@ extern "C" { GGML_API struct ggml_tensor * ggml_rope_custom( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx, @@ -1255,7 +1256,7 @@ extern "C" { GGML_API struct ggml_tensor * ggml_rope_custom_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx, @@ -1266,7 +1267,7 @@ extern "C" { GGML_API struct ggml_tensor * ggml_rope_xpos_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, float base, bool down); @@ -1276,7 +1277,7 @@ extern "C" { GGML_API struct ggml_tensor * ggml_rope_back( struct ggml_context * ctx, struct ggml_tensor * a, - int n_past, + struct ggml_tensor * b, int n_dims, int mode, int n_ctx, diff --git a/llama.cpp b/llama.cpp index 0cab18093..0737cb2bc 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2404,6 +2404,7 @@ static struct ggml_cgraph * llm_build_llama( } #endif // GGML_USE_CUBLAS + // KQ_scale struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); ggml_allocr_alloc(lctx.alloc, KQ_scale); if (!ggml_allocr_is_measure(lctx.alloc)) { @@ -2411,6 +2412,41 @@ static struct ggml_cgraph * llm_build_llama( } ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); + // KQ_mask + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, N, 1); + ggml_allocr_alloc(lctx.alloc, KQ_mask); + if (!ggml_allocr_is_measure(lctx.alloc)) { + float * data = (float *) KQ_mask->data; + memset(data, 0, ggml_nbytes(KQ_mask)); + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < N; ++j) { + for (int i = n_past + j + 1; i < n_past + N; ++i) { + data[h*(n_past + N)*N + j*(n_past + N) + i] = -INFINITY; + } + } + } + } + + // Q_pos - contains the positions + struct ggml_tensor * Q_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + ggml_allocr_alloc(lctx.alloc, Q_pos); + if (!ggml_allocr_is_measure(lctx.alloc)) { + int * data = (int *) Q_pos->data; + for (int i = 0; i < N; ++i) { + data[i] = n_past + i; + } + } + + struct ggml_tensor * K_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_past + N); + ggml_allocr_alloc(lctx.alloc, K_pos); + if (!ggml_allocr_is_measure(lctx.alloc)) { + int * data = (int *) K_pos->data; + for (int i = 0; i < n_past + N; ++i) { + data[i] = i; + } + } + for (int il = 0; il < n_layer; ++il) { ggml_format_name(inpL, "layer_inp_%d", il); @@ -2447,14 +2483,18 @@ static struct ggml_cgraph * llm_build_llama( offload_func_kq(tmpq); ggml_set_name(tmpq, "tmpq"); - struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); + // Note: we are not RoPE-ing K here + struct ggml_tensor * Kcur = tmpk; offload_func_kq(Kcur); ggml_set_name(Kcur, "Kcur"); - struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); + struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), Q_pos, n_embd_head, 0, 0, freq_base, freq_scale); offload_func_kq(Qcur); ggml_set_name(Qcur, "Qcur"); + struct ggml_tensor * ck; + struct ggml_tensor * cv; + // store key and value to memory { // compute the transposed [N, n_embd] V matrix @@ -2477,9 +2517,11 @@ static struct ggml_cgraph * llm_build_llama( offload_func_v(v); ggml_set_name(v, "v"); - // important: storing RoPE-ed version of K in the KV cache! - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); + ck = ggml_cpy(ctx0, Kcur, k); + cv = ggml_cpy(ctx0, Vcur, v); + + ggml_build_forward_expand(gf, ck); + ggml_build_forward_expand(gf, cv); } struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); @@ -2488,13 +2530,18 @@ static struct ggml_cgraph * llm_build_llama( struct ggml_tensor * K = ggml_view_3d(ctx0, kv_self.k, - n_embd_head, n_past + N, n_head_kv, - ggml_element_size(kv_self.k)*n_embd_gqa, + n_embd_head, n_head_kv, n_past + N, ggml_element_size(kv_self.k)*n_embd_head, + ggml_element_size(kv_self.k)*n_embd_gqa, ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); offload_func_kq(K); ggml_set_name(K, "K"); + // RoPE the K cache + K->src[1] = ck; // TODO: HACK!! + K = ggml_rope_custom(ctx0, K, K_pos, n_embd_head, 0, 0, freq_base, freq_scale); + K = ggml_permute(ctx0, K, 0, 2, 1, 3); + // K * Q struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); offload_func_kq(KQ); @@ -2502,17 +2549,18 @@ static struct ggml_cgraph * llm_build_llama( // KQ_scaled = KQ / sqrt(n_embd_head) // KQ_scaled shape [n_past + N, N, n_head, 1] - struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); + struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); offload_func_kq(KQ_scaled); ggml_set_name(KQ_scaled, "KQ_scaled"); // KQ_masked = mask_past(KQ_scaled) - struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask); + //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); offload_func_kq(KQ_masked); ggml_set_name(KQ_masked, "KQ_masked"); // KQ = soft_max(KQ_masked) - struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); offload_func_v(KQ_soft_max); ggml_set_name(KQ_soft_max, "KQ_soft_max"); @@ -2736,6 +2784,7 @@ static struct ggml_cgraph * llm_build_baichaun( } #endif // GGML_USE_CUBLAS + // KQ_scale struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); ggml_allocr_alloc(lctx.alloc, KQ_scale); if (!ggml_allocr_is_measure(lctx.alloc)) { @@ -2743,6 +2792,32 @@ static struct ggml_cgraph * llm_build_baichaun( } ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); + // KQ_mask + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, N, 1); + ggml_allocr_alloc(lctx.alloc, KQ_mask); + if (!ggml_allocr_is_measure(lctx.alloc)) { + float * data = (float *) KQ_mask->data; + memset(data, 0, ggml_nbytes(KQ_mask)); + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < N; ++j) { + for (int i = n_past + j + 1; i < n_past + N; ++i) { + data[h*(n_past + N)*N + j*(n_past + N) + i] = -INFINITY; + } + } + } + } + + // KQ_pos - contains the positions + struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + ggml_allocr_alloc(lctx.alloc, KQ_pos); + if (!ggml_allocr_is_measure(lctx.alloc)) { + int * data = (int *) KQ_pos->data; + for (int i = 0; i < N; ++i) { + data[i] = n_past + i; + } + } + for (int il = 0; il < n_layer; ++il) { ggml_format_name(inpL, "layer_inp_%d", il); @@ -2783,11 +2858,11 @@ static struct ggml_cgraph * llm_build_baichaun( struct ggml_tensor * Qcur; switch (model.type) { case MODEL_7B: - Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); - Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale); + Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale); + Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale); break; case MODEL_13B: - Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N); + Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N); Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N); break; default: @@ -2847,7 +2922,7 @@ static struct ggml_cgraph * llm_build_baichaun( // KQ_scaled = KQ / sqrt(n_embd_head) // KQ_scaled shape [n_past + N, N, n_head, 1] - struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); + struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); offload_func_kq(KQ_scaled); ggml_set_name(KQ_scaled, "KQ_scaled"); @@ -2856,24 +2931,26 @@ static struct ggml_cgraph * llm_build_baichaun( switch (model.type) { case MODEL_7B: - KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + KQ_masked = ggml_add(ctx0, KQ_scaled, KQ_mask); + //KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); break; case MODEL_13B: KQ_scaled_alibi =ggml_alibi(ctx0, KQ_scaled, n_past, n_head, 8); ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi"); - KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled_alibi, n_past); + KQ_masked = ggml_add(ctx0, KQ_scaled_alibi, KQ_mask); + //KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled_alibi, n_past); break; default: GGML_ASSERT(false); } // KQ_masked = mask_past(KQ_scaled) - // struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + // struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); // struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled_alibi, n_past); // offload_func_kq(KQ_masked); // ggml_set_name(KQ_masked, "KQ_masked"); // KQ = soft_max(KQ_masked) - struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); offload_func_v(KQ_soft_max); ggml_set_name(KQ_soft_max, "KQ_soft_max"); @@ -3096,6 +3173,7 @@ static struct ggml_cgraph * llm_build_falcon( } #endif // GGML_USE_CUBLAS + // KQ_scale struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); ggml_allocr_alloc(lctx.alloc, KQ_scale); if (!ggml_allocr_is_measure(lctx.alloc)) { @@ -3103,6 +3181,32 @@ static struct ggml_cgraph * llm_build_falcon( } ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); + // KQ_mask + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, N, 1); + ggml_allocr_alloc(lctx.alloc, KQ_mask); + if (!ggml_allocr_is_measure(lctx.alloc)) { + float * data = (float *) KQ_mask->data; + memset(data, 0, ggml_nbytes(KQ_mask)); + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < N; ++j) { + for (int i = n_past + j + 1; i < n_past + N; ++i) { + data[h*(n_past + N)*N + j*(n_past + N) + i] = -INFINITY; + } + } + } + } + + // KQ_pos - contains the positions + struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + ggml_allocr_alloc(lctx.alloc, KQ_pos); + if (!ggml_allocr_is_measure(lctx.alloc)) { + int * data = (int *) KQ_pos->data; + for (int i = 0; i < N; ++i) { + data[i] = n_past + i; + } + } + for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * attn_norm; @@ -3179,9 +3283,9 @@ static struct ggml_cgraph * llm_build_falcon( offload_func_v(tmpv); // using mode = 2 for neox mode - struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, tmpq, n_past, n_embd_head, 2, 0, freq_base, freq_scale); + struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, tmpq, KQ_pos, n_embd_head, 2, 0, freq_base, freq_scale); offload_func_kq(Qcur); - struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, tmpk, n_past, n_embd_head, 2, 0, freq_base, freq_scale); + struct ggml_tensor * Kcur = ggml_rope_custom(ctx0, tmpk, KQ_pos, n_embd_head, 2, 0, freq_base, freq_scale); offload_func_kq(Kcur); { @@ -3220,15 +3324,15 @@ static struct ggml_cgraph * llm_build_falcon( offload_func_kq(KQ); ggml_set_name(KQ, "KQ"); - struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); + struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); offload_func_kq(KQ_scaled); ggml_set_name(KQ_scaled, "KQ_scaled"); - struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); offload_func_kq(KQ_masked); ggml_set_name(KQ_masked, "KQ_masked"); - struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); offload_func_v(KQ_soft_max); ggml_set_name(KQ_soft_max, "KQ_soft_max"); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 916dc9d05..a19e1376e 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -37,6 +37,8 @@ llama_build_and_test_executable(test-llama-grammar.cpp) llama_build_and_test_executable(test-grad0.cpp) # SLOW # llama_build_and_test_executable(test-opt.cpp) # SLOW +llama_build_and_test_executable(test-rope.cpp) + # dummy executable - not installed get_filename_component(TEST_TARGET test-c.c NAME_WE) add_executable(${TEST_TARGET} test-c.c) diff --git a/tests/test-grad0.cpp b/tests/test-grad0.cpp index 468cde66a..7b0c0fcdb 100644 --- a/tests/test-grad0.cpp +++ b/tests/test-grad0.cpp @@ -1404,6 +1404,11 @@ int main(int argc, const char ** argv) { for (int n_past = 1; n_past < ne2[2]; ++n_past) { x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f); + struct ggml_tensor * p = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne2[2]); + for (int i = 0; i < ne2[2]; ++i) { + ((int32_t *) p->data)[i] = n_past + i; + } + ggml_set_param(ctx0, x[0]); const bool skip_past = (mode & 1); @@ -1415,7 +1420,7 @@ int main(int argc, const char ** argv) { continue; } - struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], n_past, n_rot, mode, 0)); + struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode, 0)); GGML_PRINT_DEBUG("rope f32: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode); check_gradient("rope f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY); @@ -1438,6 +1443,11 @@ int main(int argc, const char ** argv) { for (int n_past = 1; n_past < ne2[2]; ++n_past) { x[0] = get_random_tensor_f16(ctx0, ndims, ne2, -1.0f, 1.0f); + struct ggml_tensor * p = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne2[2]); + for (int i = 0; i < ne2[2]; ++i) { + ((int32_t *) p->data)[i] = n_past + i; + } + ggml_set_param(ctx0, x[0]); const bool skip_past = (mode & 1); @@ -1449,7 +1459,7 @@ int main(int argc, const char ** argv) { continue; } - struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], n_past, n_rot, mode, 0)); + struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode, 0)); GGML_PRINT_DEBUG("rope f16: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode); check_gradient("rope f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY); diff --git a/tests/test-rope.cpp b/tests/test-rope.cpp new file mode 100644 index 000000000..26c1f42dc --- /dev/null +++ b/tests/test-rope.cpp @@ -0,0 +1,221 @@ +#include "ggml.h" + +#include +#include +#include +#include +#include + +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + +#if defined(__GNUC__) +#pragma GCC diagnostic ignored "-Wdouble-promotion" +#endif + +#define MAX_NARGS 3 + +#undef MIN +#undef MAX +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + +#define GGML_SILU_FP16 + +// +// logging +// + +#if (GGML_DEBUG >= 1) +#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__) +#else +#define GGML_PRINT_DEBUG(...) +#endif + +#if (GGML_DEBUG >= 5) +#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__) +#else +#define GGML_PRINT_DEBUG_5(...) +#endif + +#if (GGML_DEBUG >= 10) +#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__) +#else +#define GGML_PRINT_DEBUG_10(...) +#endif + +#define GGML_PRINT(...) printf(__VA_ARGS__) + +static float frand(void) { + return (float)rand()/(float)RAND_MAX; +} + +static int irand(int n) { + if (n == 0) return 0; + return rand()%n; +} + +static void get_random_dims(int64_t * dims, int ndims) { + dims[0] = dims[1] = dims[2] = dims[3] = 1; + + for (int i = 0; i < ndims; i++) { + dims[i] = 1 + irand(4); + } +} + +static struct ggml_tensor * get_random_tensor_f32( + struct ggml_context * ctx0, + int ndims, + const int64_t ne[], + float fmin, + float fmax) { + struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_F32, ndims, ne); + + switch (ndims) { + case 1: + for (int i0 = 0; i0 < ne[0]; i0++) { + ((float *)result->data)[i0] = frand()*(fmax - fmin) + fmin; + } + break; + case 2: + for (int i1 = 0; i1 < ne[1]; i1++) { + for (int i0 = 0; i0 < ne[0]; i0++) { + ((float *)result->data)[i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin; + } + } + break; + case 3: + for (int i2 = 0; i2 < ne[2]; i2++) { + for (int i1 = 0; i1 < ne[1]; i1++) { + for (int i0 = 0; i0 < ne[0]; i0++) { + ((float *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin; + } + } + } + break; + case 4: + for (int i3 = 0; i3 < ne[3]; i3++) { + for (int i2 = 0; i2 < ne[2]; i2++) { + for (int i1 = 0; i1 < ne[1]; i1++) { + for (int i0 = 0; i0 < ne[0]; i0++) { + ((float *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin; + } + } + } + } + break; + default: + assert(false); + }; + + return result; +} + +static void ggml_graph_compute_helper(std::vector & buf, ggml_cgraph * graph, int n_threads) { + struct ggml_cplan plan = ggml_graph_plan(graph, n_threads); + + if (plan.work_size > 0) { + buf.resize(plan.work_size); + plan.work_data = buf.data(); + } + + ggml_graph_compute(graph, &plan); +} + +int main(int /*argc*/, const char ** /*argv*/) { + struct ggml_init_params params = { + /* .mem_size = */ 128*1024*1024, + /* .mem_buffer = */ NULL, + /* .no_alloc = */ false, + }; + + std::vector work_buffer; + + struct ggml_context * ctx0 = ggml_init(params); + + struct ggml_tensor * x; + + // rope f32 + for (int m = 0; m < 3; ++m) { + const int ndims = 4; + + const int64_t n_rot = 128; + const int64_t ne[4] = { 2*n_rot, 32, 73, 1 }; + + const int n_past_0 = 100; + const int n_past_2 = 33; + + struct ggml_tensor * p0 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]); + struct ggml_tensor * p1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]); + struct ggml_tensor * p2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]); + + for (int i = 0; i < ne[2]; ++i) { + ((int32_t *) p0->data)[i] = n_past_0 + i; + ((int32_t *) p1->data)[i] = n_past_2 - n_past_0; + ((int32_t *) p2->data)[i] = n_past_2 + i; + } + + // test mode 0, 2, 4 (standard, GPT-NeoX, GLM) + const int mode = m == 0 ? 0 : m == 1 ? 2 : 4; + + x = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); + + // 100, 101, 102, ..., 172 + struct ggml_tensor * r0 = ggml_rope(ctx0, x, p0, n_rot, mode, 1024); + // -67, -67, -67, ..., -67 + struct ggml_tensor * r1 = ggml_rope(ctx0, r0, p1, n_rot, mode, 1024); // "context swap", i.e. forget n_past_0 - n_past_2 tokens + + // 33, 34, 35, ..., 105 + struct ggml_tensor * r2 = ggml_rope(ctx0, x, p2, n_rot, mode, 1024); + + ggml_cgraph * gf = ggml_new_graph(ctx0); + + ggml_build_forward_expand(gf, r0); + ggml_build_forward_expand(gf, r1); + ggml_build_forward_expand(gf, r2); + + ggml_graph_compute_helper(work_buffer, gf, 4); + + // check that r1 and r2 are the same + { + double sum0 = 0.0f; + double sum1 = 0.0f; + double diff = 0.0f; + + const float * r1_data = (float *) r1->data; + const float * r2_data = (float *) r2->data; + + const int n_elements = ggml_nelements(r1); + + for (int i = 0; i < n_elements; ++i) { + sum0 += fabs(r1_data[i]); + sum1 += fabs(r2_data[i]); + diff += fabs(r1_data[i] - r2_data[i]); + //if (fabs(r1_data[i] - r2_data[i]) > 0.0001f) { + // printf("%d: %f %f\n", i, r1_data[i], r2_data[i]); + // printf("diff: %f\n", fabs(r1_data[i] - r2_data[i])); + //} + } + + //for (int i = 4096; i < 4096 + 128; ++i) { + // printf("%f %f\n", r1_data[i], r2_data[i]); + //} + + printf("mode: %d\n", mode); + printf("sum0: %f\n", sum0); + printf("sum1: %f\n", sum1); + printf("diff: %f\n", diff); + printf("rel err: %f\n", diff / sum0); + printf("rel err: %f\n", diff / sum1); + + GGML_ASSERT(diff / sum0 < 0.0001f); + GGML_ASSERT(diff / sum1 < 0.0001f); + } + } + + ggml_free(ctx0); + + return 0; +} +