Impl for 32

This commit is contained in:
Will Beddow 2023-04-06 23:46:22 -04:00
parent cc9cee8e9e
commit 32d0654dd0

46
ggml.c
View file

@ -4997,7 +4997,6 @@ static void ggml_compute_forward_dup_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {
GGML_ASSERT(params->ith == 0);
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@ -5030,16 +5029,35 @@ static void ggml_compute_forward_dup_f32(
int64_t i12 = 0;
int64_t i13 = 0;
const int thread_num = params->ith;
const int total_threads = params->nth;
const int64_t regions = ne03 * ne02 * ne01 * ne00;
const int64_t regions_per_thread = (regions + total_threads - 1) / total_threads;
const int64_t thread_start_region = regions_per_thread * thread_num;
const int64_t thread_stop_region = MIN(thread_start_region + regions_per_thread, regions);
int region_index = 0;
if (dst->type == GGML_TYPE_F32) {
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = 0; i01 < ne01; i01++) {
for (int64_t i00 = 0; i00 < ne00; i00++) {
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
memcpy(dst_ptr, src0_ptr, sizeof(float));
if (region_index > thread_stop_region) break;
if (region_index >= thread_start_region) {
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
memcpy(dst_ptr, src0_ptr, sizeof(float));
}
// Regardless, we have to keep the dst counters updated
if (++i10 == dst->ne[0]) {
i10 = 0;
if (++i11 == dst->ne[1]) {
@ -5061,11 +5079,17 @@ static void ggml_compute_forward_dup_f32(
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = 0; i01 < ne01; i01++) {
for (int64_t i00 = 0; i00 < ne00; i00++) {
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
*(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
if (region_index > thread_stop_region) break;
if (region_index >= thread_start_region) {
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
*(ggml_fp16_t *) dst_ptr = GGML_FP32_TO_FP16(*(const float *) src0_ptr);
}
// Regardless, we have to keep the dst counters updated
if (++i10 == dst->ne[0]) {
i10 = 0;
if (++i11 == dst->ne[1]) {
@ -7278,6 +7302,7 @@ static void ggml_compute_forward_rope_f32(
// row index used to determine which thread to use
int ir = 0;
int i = 0;
for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
@ -7287,6 +7312,7 @@ static void ggml_compute_forward_rope_f32(
if (ir > ir1) break;
for (int i0 = 0; i0 < n_dims; i0 += 2) {
++i;
const float theta = powf(10000.0, ((float)-i0)/n_dims);
const float cos_theta = cosf(p*theta);
@ -7304,6 +7330,7 @@ static void ggml_compute_forward_rope_f32(
}
}
}
}
static void ggml_compute_forward_rope_f16(
@ -9441,7 +9468,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
{
node->n_tasks = n_threads;
} break;
case GGML_OP_CPY:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
@ -9451,10 +9477,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
{
node->n_tasks = 1;
} break;
case GGML_OP_CPY:
case GGML_OP_SOFT_MAX:
{
node->n_tasks = n_threads;
} break;
case GGML_OP_ROPE:
{
node->n_tasks = n_threads;