metal : add more general support for ggml_get_rows + tests
This commit is contained in:
parent
9064b1ca05
commit
2cbcba829f
4 changed files with 78 additions and 25 deletions
6
ggml.c
6
ggml.c
|
@@ -10363,7 +10363,7 @@ static void ggml_compute_forward_get_rows_q(
|
|||
|
||||
dequantize_row_q(
|
||||
(const void *) ((char *) src0->data + i02*nb02 + r*nb01),
|
||||
- (float *) ((char *) dst->data + i*dst->nb[1]), nc);
|
||||
+ (float *) ((char *) dst->data + i*nb1), nc);
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -10396,7 +10396,7 @@ static void ggml_compute_forward_get_rows_f16(
|
|||
|
||||
for (int j = 0; j < nc; ++j) {
|
||||
ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i02*nb02 + r*nb01))[j];
|
||||
- ((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v);
|
||||
+ ((float *) ((char *) dst->data + i*nb1))[j] = GGML_FP16_TO_FP32(v);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -10429,7 +10429,7 @@ static void ggml_compute_forward_get_rows_f32(
|
|||
const int64_t i02 = i/ne10;
|
||||
|
||||
ggml_vec_cpy_f32(nc,
|
||||
- (float *) ((char *) dst->data + i*dst->nb[1]),
|
||||
+ (float *) ((char *) dst->data + i*nb1),
|
||||
(float *) ((char *) src0->data + i02*nb02 + r*nb01));
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue