diff --git a/ggml.c b/ggml.c
index f38e26921..539085ed0 100644
--- a/ggml.c
+++ b/ggml.c
@@ -3990,6 +3990,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "PERMUTE",
     "TRANSPOSE",
     "GET_ROWS",
+    "GET_ROWS_BACK",
     "DIAG_MASK_INF",
     "DIAG_MASK_ZERO",
     "SOFT_MAX",
@@ -4045,6 +4046,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "permute(x)",
     "transpose(x)",
     "get_rows(x)",
+    "get_rows_back(x)",
     "diag_mask_inf(x)",
     "diag_mask_zero(x)",
     "soft_max(x)",
@@ -6132,7 +6134,6 @@ struct ggml_tensor * ggml_get_rows(
     bool is_node = false;

     if (a->grad || b->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }

@@ -6148,6 +6149,32 @@ struct ggml_tensor * ggml_get_rows(
     return result;
 }

+// ggml_get_rows_back
+
+struct ggml_tensor * ggml_get_rows_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b) {
+    GGML_ASSERT(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32);
+
+    bool is_node = false;
+
+    if (a->grad || b->grad) {
+        is_node = true;
+    }
+
+    // TODO: implement non F32 return
+    //struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
+    struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0]);
+
+    result->op   = GGML_OP_GET_ROWS_BACK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+
+    return result;
+}
+
 // ggml_diag_mask_inf

 struct ggml_tensor * ggml_diag_mask_inf_impl(
@@ -10052,7 +10079,8 @@ static void ggml_compute_forward_get_rows_q(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
+              struct ggml_tensor * dst,
+              bool backward) {
     assert(params->ith == 0);

     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -10068,12 +10096,15 @@ static void ggml_compute_forward_get_rows_q(
     assert( dst->ne[1] == nr);
     assert(src0->nb[0] == GGML_TYPE_SIZE[type]);

+    const int b = backward ? 1 : 0;
+    const int f = backward ? 0 : 1;
+
     for (int i = 0; i < nr; ++i) {
         const int r = ((int32_t *) src1->data)[i];

         dequantize_row_q(
-                (const void *) ((char *) src0->data + r*src0->nb[1]),
-                     (float *) ((char *)  dst->data + i*dst->nb[1]), nc);
+                (const void *) ((char *) src0->data + (f*r + b*i)*src0->nb[1]),
+                     (float *) ((char *)  dst->data + (f*i + b*r)*dst->nb[1]), nc);
     }
 }

@@ -10081,7 +10112,8 @@ static void ggml_compute_forward_get_rows_f16(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
+              struct ggml_tensor * dst,
+              bool backward) {
     assert(params->ith == 0);

     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -10095,12 +10127,15 @@ static void ggml_compute_forward_get_rows_f16(
     assert( dst->ne[1] == nr);
     assert(src0->nb[0] == sizeof(ggml_fp16_t));

+    const int b = backward ? 1 : 0;
+    const int f = backward ? 0 : 1;
+
     for (int i = 0; i < nr; ++i) {
         const int r = ((int32_t *) src1->data)[i];

         for (int j = 0; j < nc; ++j) {
-            ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j];
-            ((float *) ((char *)  dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v);
+            ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + (f*r + b*i)*src0->nb[1]))[j];
+            ((float *) ((char *)  dst->data + (f*i + b*r)*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v);
         }
     }
 }
@@ -10109,7 +10144,8 @@ static void ggml_compute_forward_get_rows_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
-              struct ggml_tensor * dst) {
+              struct ggml_tensor * dst,
+              bool backward) {
     assert(params->ith == 0);

     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
@@ -10123,12 +10159,15 @@ static void ggml_compute_forward_get_rows_f32(
     assert( dst->ne[1] == nr);
     assert(src0->nb[0] == sizeof(float));

+    const int b = backward ? 1 : 0;
+    const int f = backward ? 0 : 1;
+
     for (int i = 0; i < nr; ++i) {
         const int r = ((int32_t *) src1->data)[i];

         ggml_vec_cpy_f32(nc,
-                (float *) ((char *)  dst->data + i*dst->nb[1]),
-                (float *) ((char *) src0->data + r*src0->nb[1]));
+                (float *) ((char *)  dst->data + (f*i + b*r)*dst->nb[1]),
+                (float *) ((char *) src0->data + (f*r + b*i)*src0->nb[1]));
     }
 }

@@ -10146,15 +10185,64 @@ static void ggml_compute_forward_get_rows(
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
             {
-                ggml_compute_forward_get_rows_q(params, src0, src1, dst);
+                ggml_compute_forward_get_rows_q(params, src0, src1, dst, false);
             } break;
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_get_rows_f16(params, src0, src1, dst);
+                ggml_compute_forward_get_rows_f16(params, src0, src1, dst, false);
             } break;
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_get_rows_f32(params, src0, src1, dst);
+                ggml_compute_forward_get_rows_f32(params, src0, src1, dst, false);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+
+    //static bool first = true;
+    //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
+    //if (first) {
+    //    first = false;
+    //} else {
+    //    for (int k = 0; k < dst->ne[1]; ++k) {
+    //        for (int j = 0; j < dst->ne[0]/16; ++j) {
+    //            for (int i = 0; i < 16; ++i) {
+    //                printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
+    //            }
+    //            printf("\n");
+    //        }
+    //        printf("\n");
+    //    }
+    //    printf("\n");
+    //    exit(0);
+    //}
+}
+
+// ggml_compute_forward_get_rows_back
+
+static void ggml_compute_forward_get_rows_back(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_2:
+        case GGML_TYPE_Q4_3:
+        case GGML_TYPE_Q8_0:
+            {
+                ggml_compute_forward_get_rows_q(params, src0, src1, dst, true);
+            } break;
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_get_rows_f16(params, src0, src1, dst, true);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_get_rows_f32(params, src0, src1, dst, true);
             } break;
         default:
             {
                 GGML_ASSERT(false);
             } break;
@@ -12351,6 +12439,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_get_rows(params, tensor->src0, tensor->src1, tensor);
             } break;
+        case GGML_OP_GET_ROWS_BACK:
+            {
+                ggml_compute_forward_get_rows_back(params, tensor->src0, tensor->src1, tensor);
+            } break;
         case GGML_OP_DIAG_MASK_INF:
             {
                 ggml_compute_forward_diag_mask_inf(params, tensor->src0, tensor->src1, tensor);
@@ -12787,7 +12879,28 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
         case GGML_OP_GET_ROWS:
             {
                 // necessary for llama (only for tokenizer)
-                GGML_ASSERT(false); // TODO: not implemented
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_add_impl(ctx, src0->grad,
+                            ggml_get_rows_back(ctx, tensor->grad, src1),
+                            inplace);
+                }
+                if (src1->grad) {
+                    // noop
+                }
+            } break;
+        case GGML_OP_GET_ROWS_BACK:
+            {
+                // necessary for llama (only for tokenizer)
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_add_impl(ctx, src0->grad,
+                            ggml_get_rows(ctx, tensor->grad, src1),
+                            inplace);
+                }
+                if (src1->grad) {
+                    // noop
+                }
             } break;
         case GGML_OP_DIAG_MASK_INF:
             {
@@ -13362,6 +13475,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
             case GGML_OP_PERMUTE:
             case GGML_OP_TRANSPOSE:
             case GGML_OP_GET_ROWS:
+            case GGML_OP_GET_ROWS_BACK:
             case GGML_OP_DIAG_MASK_INF:
                 {
                     node->n_tasks = 1;
diff --git a/ggml.h b/ggml.h
index 9d2ba48ea..1677ea533 100644
--- a/ggml.h
+++ b/ggml.h
@@ -284,6 +284,7 @@ extern "C" {
         GGML_OP_PERMUTE,
         GGML_OP_TRANSPOSE,
         GGML_OP_GET_ROWS,
+        GGML_OP_GET_ROWS_BACK,
         GGML_OP_DIAG_MASK_INF,
         GGML_OP_DIAG_MASK_ZERO,
         GGML_OP_SOFT_MAX,
@@ -694,6 +695,11 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);

+    GGML_API struct ggml_tensor * ggml_get_rows_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     // set elements above the diagonal to -INF
     GGML_API struct ggml_tensor * ggml_diag_mask_inf(
             struct ggml_context * ctx,
@@ -749,7 +755,6 @@ extern "C" {
     // rotary position embedding backward, i.e compute dx from dy
     GGML_API struct ggml_tensor * ggml_rope_back(
             struct ggml_context * ctx,
-            struct ggml_tensor  * x,
             struct ggml_tensor  * dy,
             int                   n_past,
             int                   n_dims,
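For context (not part of the patch): a minimal usage sketch of the new backward path once this diff is applied. The standalone test file below is hypothetical; it assumes the ggml API at this revision (ggml_set_param, ggml_build_forward / ggml_build_backward, ggml_graph_reset, ggml_graph_compute) and follows the pattern used by tests/test-grad0.c. One caveat grounded in the patch: ggml_get_rows_back shapes its result as a->ne[0] x b->ne[0], so the scattered gradient only matches src0->grad when the number of indices equals the number of rows in the source matrix; the sketch therefore uses a full permutation of row indices.

// get-rows-back-sketch.c -- hypothetical standalone check, not included in this patch
// cc -I. get-rows-back-sketch.c ggml.c -o sketch -lm -lpthread
#include "ggml.h"
#include <stdio.h>

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // a: 3 rows of 2 floats (the "embedding" matrix), w: fixed per-element weights,
    // b: row indices -- a permutation, so get_rows_back covers every row of a->grad
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2, 3);
    struct ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2, 3);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
    ggml_set_param(ctx, a);                      // request da
    ggml_set_f32(a, 1.0f);
    for (int i = 0; i < 6; ++i) { ((float *) w->data)[i] = (float)(i + 1); }
    ((int32_t *) b->data)[0] = 2;
    ((int32_t *) b->data)[1] = 0;
    ((int32_t *) b->data)[2] = 1;

    // f = sum(get_rows(a, b) * w)  =>  df/da[b[i]][j] = w[i][j]
    struct ggml_tensor * rows = ggml_get_rows(ctx, a, b);
    struct ggml_tensor * f    = ggml_sum(ctx, ggml_mul(ctx, rows, w));

    struct ggml_cgraph gf = ggml_build_forward (f);
    struct ggml_cgraph gb = ggml_build_backward(ctx, &gf, false);

    ggml_graph_compute(ctx, &gf);
    ggml_graph_reset  (&gf);
    ggml_set_f32(f->grad, 1.0f);                 // seed df/df = 1
    ggml_graph_compute(ctx, &gb);

    // expected: da row 2 == w row 0, da row 0 == w row 1, da row 1 == w row 2,
    // i.e. GET_ROWS_BACK scattered d(rows) back into the rows selected by b
    for (int r = 0; r < 3; ++r) {
        printf("da row %d: %.1f %.1f\n", r,
            ggml_get_f32_1d(a->grad, 2*r + 0),
            ggml_get_f32_1d(a->grad, 2*r + 1));
    }

    ggml_free(ctx);
    return 0;
}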