ggml_add: Add more checks
This commit is contained in:
parent
0a6d5ad7cc
commit
8d37db3cdf
1 changed files with 34 additions and 16 deletions
50
ggml.c
50
ggml.c
|
@ -5893,27 +5893,36 @@ static void ggml_compute_forward_add_f16_f32(
|
|||
const int n = ggml_nrows(src0);
|
||||
const int nc = src0->ne[0];
|
||||
|
||||
//const size_t nb00 = src0->nb[0];
|
||||
const size_t nb00 = src0->nb[0];
|
||||
const size_t nb01 = src0->nb[1];
|
||||
|
||||
const size_t nb10 = src1->nb[0];
|
||||
const size_t nb11 = src1->nb[1];
|
||||
|
||||
//const size_t nb0 = dst->nb[0];
|
||||
const size_t nb0 = dst->nb[0];
|
||||
const size_t nb1 = dst->nb[1];
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F16);
|
||||
|
||||
for (int j = ith; j < n; j += nth) {
|
||||
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
|
||||
ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
|
||||
for (int i = 0; i < nc; i++) {
|
||||
float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
|
||||
dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
|
||||
GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
|
||||
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
|
||||
|
||||
if (nb10 == sizeof(float)) {
|
||||
for (int j = ith; j < n; j += nth) {
|
||||
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
|
||||
ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
|
||||
for (int i = 0; i < nc; i++) {
|
||||
float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
|
||||
dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
// src1 is not contiguous
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_add_f16_f16(
|
||||
|
@ -5933,27 +5942,36 @@ static void ggml_compute_forward_add_f16_f16(
|
|||
const int n = ggml_nrows(src0);
|
||||
const int nc = src0->ne[0];
|
||||
|
||||
//const size_t nb00 = src0->nb[0];
|
||||
const size_t nb00 = src0->nb[0];
|
||||
const size_t nb01 = src0->nb[1];
|
||||
|
||||
const size_t nb10 = src1->nb[0];
|
||||
const size_t nb11 = src1->nb[1];
|
||||
|
||||
//const size_t nb0 = dst->nb[0];
|
||||
const size_t nb0 = dst->nb[0];
|
||||
const size_t nb1 = dst->nb[1];
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F16);
|
||||
|
||||
for (int j = ith; j < n; j += nth) {
|
||||
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
|
||||
ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
|
||||
for (int i = 0; i < nc; i++) {
|
||||
ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
|
||||
dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
|
||||
GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
|
||||
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
|
||||
|
||||
if (nb10 == sizeof(ggml_fp16_t)) {
|
||||
for (int j = ith; j < n; j += nth) {
|
||||
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
|
||||
ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
|
||||
for (int i = 0; i < nc; i++) {
|
||||
ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
|
||||
dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
// src1 is not contiguous
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_add_q_f32(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue