metal : add more general support for ggml_get_rows + tests

This commit is contained in:
Georgi Gerganov 2023-12-09 14:18:42 +02:00
parent 9064b1ca05
commit 2cbcba829f
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
4 changed files with 78 additions and 25 deletions

6
ggml.c
View file

@ -10363,7 +10363,7 @@ static void ggml_compute_forward_get_rows_q(
dequantize_row_q(
(const void *) ((char *) src0->data + i02*nb02 + r*nb01),
(float *) ((char *) dst->data + i*dst->nb[1]), nc);
(float *) ((char *) dst->data + i*nb1), nc);
}
}
@ -10396,7 +10396,7 @@ static void ggml_compute_forward_get_rows_f16(
for (int j = 0; j < nc; ++j) {
ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i02*nb02 + r*nb01))[j];
((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v);
((float *) ((char *) dst->data + i*nb1))[j] = GGML_FP16_TO_FP32(v);
}
}
}
@ -10429,7 +10429,7 @@ static void ggml_compute_forward_get_rows_f32(
const int64_t i02 = i/ne10;
ggml_vec_cpy_f32(nc,
(float *) ((char *) dst->data + i*dst->nb[1]),
(float *) ((char *) dst->data + i*nb1),
(float *) ((char *) src0->data + i02*nb02 + r*nb01));
}
}