WIP on f16
This commit is contained in:
parent
32d0654dd0
commit
1b6fd5470b
1 changed files with 30 additions and 15 deletions
45
ggml.c
45
ggml.c
|
@ -4936,16 +4936,32 @@ static void ggml_compute_forward_dup_f16(
|
|||
int64_t i12 = 0;
|
||||
int64_t i13 = 0;
|
||||
|
||||
const int thread_num = params->ith;
|
||||
const int total_threads = params->nth;
|
||||
const int64_t regions = ne03 * ne02 * ne01 * ne00;
|
||||
|
||||
const int64_t regions_per_thread = (regions + total_threads - 1) / total_threads;
|
||||
|
||||
const int64_t thread_start_region = regions_per_thread * thread_num;
|
||||
const int64_t thread_stop_region = MIN(thread_start_region + regions_per_thread, regions);
|
||||
|
||||
int region_index = 0;
|
||||
|
||||
if (dst->type == GGML_TYPE_F16) {
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
||||
if (region_index > thread_stop_region) break;
|
||||
|
||||
memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t));
|
||||
if (region_index++ >= thread_start_region) {
|
||||
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
||||
|
||||
memcpy(dst_ptr, src0_ptr, sizeof(ggml_fp16_t));
|
||||
}
|
||||
|
||||
// Regardless, we have to keep the dst counters updated
|
||||
if (++i10 == ne00) {
|
||||
i10 = 0;
|
||||
if (++i11 == ne01) {
|
||||
|
@ -4967,11 +4983,17 @@ static void ggml_compute_forward_dup_f16(
|
|||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
||||
|
||||
*(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
|
||||
if (region_index > thread_stop_region) break;
|
||||
|
||||
if (region_index++ >= thread_start_region) {
|
||||
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
||||
|
||||
*(float *) dst_ptr = GGML_FP16_TO_FP32(*(const ggml_fp16_t *) src0_ptr);
|
||||
}
|
||||
|
||||
// Regardless, we have to keep the dst counters updated
|
||||
if (++i10 == ne00) {
|
||||
i10 = 0;
|
||||
if (++i11 == ne01) {
|
||||
|
@ -5029,7 +5051,6 @@ static void ggml_compute_forward_dup_f32(
|
|||
int64_t i12 = 0;
|
||||
int64_t i13 = 0;
|
||||
|
||||
|
||||
const int thread_num = params->ith;
|
||||
const int total_threads = params->nth;
|
||||
const int64_t regions = ne03 * ne02 * ne01 * ne00;
|
||||
|
@ -5046,11 +5067,9 @@ static void ggml_compute_forward_dup_f32(
|
|||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
|
||||
if (region_index > thread_stop_region) break;
|
||||
|
||||
|
||||
if (region_index >= thread_start_region) {
|
||||
if (region_index++ >= thread_start_region) {
|
||||
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
||||
|
||||
|
@ -5079,10 +5098,9 @@ static void ggml_compute_forward_dup_f32(
|
|||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
for (int64_t i01 = 0; i01 < ne01; i01++) {
|
||||
for (int64_t i00 = 0; i00 < ne00; i00++) {
|
||||
|
||||
if (region_index > thread_stop_region) break;
|
||||
|
||||
if (region_index >= thread_start_region) {
|
||||
if (region_index++ >= thread_start_region) {
|
||||
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
|
||||
|
||||
|
@ -7302,7 +7320,6 @@ static void ggml_compute_forward_rope_f32(
|
|||
|
||||
// row index used to determine which thread to use
|
||||
int ir = 0;
|
||||
int i = 0;
|
||||
|
||||
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
||||
for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
|
||||
|
@ -7312,7 +7329,6 @@ static void ggml_compute_forward_rope_f32(
|
|||
if (ir > ir1) break;
|
||||
|
||||
for (int i0 = 0; i0 < n_dims; i0 += 2) {
|
||||
++i;
|
||||
const float theta = powf(10000.0, ((float)-i0)/n_dims);
|
||||
|
||||
const float cos_theta = cosf(p*theta);
|
||||
|
@ -7330,7 +7346,6 @@ static void ggml_compute_forward_rope_f32(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_rope_f16(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue