diff --git a/ggml.c b/ggml.c index e88b9fbe4..a822cc9b4 100644 --- a/ggml.c +++ b/ggml.c @@ -5893,27 +5893,36 @@ static void ggml_compute_forward_add_f16_f32( const int n = ggml_nrows(src0); const int nc = src0->ne[0]; - //const size_t nb00 = src0->nb[0]; + const size_t nb00 = src0->nb[0]; const size_t nb01 = src0->nb[1]; const size_t nb10 = src1->nb[0]; const size_t nb11 = src1->nb[1]; - //const size_t nb0 = dst->nb[0]; + const size_t nb0 = dst->nb[0]; const size_t nb1 = dst->nb[1]; GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F16); - for (int j = ith; j < n; j += nth) { - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1); - ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01); - for (int i = 0; i < nc; i++) { - float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10); - dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr); + GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + + if (nb10 == sizeof(float)) { + for (int j = ith; j < n; j += nth) { + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1); + ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01); + for (int i = 0; i < nc; i++) { + float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10); + dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr); + } } } + else { + // src1 is not contiguous + GGML_ASSERT(false); + } } static void ggml_compute_forward_add_f16_f16( @@ -5933,27 +5942,36 @@ static void ggml_compute_forward_add_f16_f16( const int n = ggml_nrows(src0); const int nc = src0->ne[0]; - //const size_t nb00 = src0->nb[0]; + const size_t nb00 = src0->nb[0]; const size_t nb01 = src0->nb[1]; const size_t nb10 = src1->nb[0]; const size_t nb11 = src1->nb[1]; - //const size_t nb0 = dst->nb[0]; + const size_t nb0 = dst->nb[0]; const size_t nb1 = dst->nb[1]; GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F16); GGML_ASSERT(dst->type == GGML_TYPE_F16); - for (int j = ith; j < n; j += nth) { - ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1); - ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01); - for (int i = 0; i < nc; i++) { - ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10); - dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr)); + GGML_ASSERT( nb0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + + if (nb10 == sizeof(ggml_fp16_t)) { + for (int j = ith; j < n; j += nth) { + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1); + ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01); + for (int i = 0; i < nc; i++) { + ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10); + dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr)); + } } } + else { + // src1 is not contiguous + GGML_ASSERT(false); + } } static void ggml_compute_forward_add_q_f32(