Fix inplace version of operators

Use inplace version when possible
Howard Su 2023-04-01 01:26:48 +08:00
parent bcf363cb53
commit 8febfc73af
3 changed files with 90 additions and 24 deletions
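For readers skimming the diff: the change follows one pattern throughout. Each affected operator gets a private _impl function with a bool inplace parameter that chooses between aliasing the input (ggml_view_tensor) and copying it (ggml_dup_tensor), and the public entry points become thin wrappers. A condensed sketch of that pattern follows; "ggml_foo" is a hypothetical operator name used only to illustrate the shape of the change, not a real ggml function.

// Condensed sketch of the pattern this commit applies to ggml_scale,
// ggml_diag_mask_inf, ggml_soft_max and ggml_rope.
static struct ggml_tensor * ggml_foo_impl(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        bool                  inplace) {
    bool is_node = false;

    if (a->grad) {
        is_node = true;
    }

    // inplace: result aliases a, so the op overwrites a's buffer at compute time
    // out-of-place: result gets its own copy of a
    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);

    result->op   = GGML_OP_NONE; // the real code sets the concrete op here
    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
    result->src0 = a;

    return result;
}

struct ggml_tensor * ggml_foo(struct ggml_context * ctx, struct ggml_tensor * a) {
    return ggml_foo_impl(ctx, a, false);
}

struct ggml_tensor * ggml_foo_inplace(struct ggml_context * ctx, struct ggml_tensor * a) {
    return ggml_foo_impl(ctx, a, true);
}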

ggml.c (73 changed lines)

@@ -4278,9 +4278,7 @@ struct ggml_tensor * ggml_scale_impl(
is_node = true;
}
// TODO: when implement backward, fix this:
//struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
result->op = GGML_OP_SCALE;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -4593,10 +4591,11 @@ struct ggml_tensor * ggml_get_rows(
// ggml_diag_mask_inf
struct ggml_tensor * ggml_diag_mask_inf(
struct ggml_tensor * ggml_diag_mask_inf_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past) {
int n_past,
bool inplace) {
bool is_node = false;
if (a->grad) {
@@ -4604,9 +4603,7 @@ struct ggml_tensor * ggml_diag_mask_inf(
is_node = true;
}
// TODO: when implement backward, fix this:
//struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
struct ggml_tensor * b = ggml_new_i32(ctx, n_past);
result->op = GGML_OP_DIAG_MASK_INF;
@@ -4617,11 +4614,26 @@ struct ggml_tensor * ggml_diag_mask_inf(
return result;
}
struct ggml_tensor * ggml_diag_mask_inf(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past) {
return ggml_diag_mask_inf_impl(ctx, a, n_past, false);
}
struct ggml_tensor * ggml_diag_mask_inf_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past) {
return ggml_diag_mask_inf_impl(ctx, a, n_past, true);
}
// ggml_soft_max
struct ggml_tensor * ggml_soft_max(
struct ggml_tensor * ggml_soft_max_impl(
struct ggml_context * ctx,
struct ggml_tensor * a) {
struct ggml_tensor * a,
bool inplace) {
bool is_node = false;
if (a->grad) {
@@ -4629,9 +4641,7 @@ struct ggml_tensor * ggml_soft_max(
is_node = true;
}
// TODO: when implement backward, fix this:
//struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
result->op = GGML_OP_SOFT_MAX;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -4641,14 +4651,26 @@ struct ggml_tensor * ggml_soft_max(
return result;
}
struct ggml_tensor * ggml_soft_max(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_soft_max_impl(ctx, a, false);
}
struct ggml_tensor * ggml_soft_max_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_soft_max_impl(ctx, a, true);
}
// ggml_rope
struct ggml_tensor * ggml_rope(
struct ggml_tensor * ggml_rope_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
int n_dims,
int mode) {
int mode,
bool inplace) {
GGML_ASSERT(n_past >= 0);
bool is_node = false;
@@ -4657,9 +4679,7 @@ struct ggml_tensor * ggml_rope(
is_node = true;
}
// TODO: when implement backward, fix this:
//struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
((int32_t *) b->data)[0] = n_past;
@@ -4673,6 +4693,23 @@ struct ggml_tensor * ggml_rope(
return result;
}
struct ggml_tensor * ggml_rope(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
int n_dims,
int mode) {
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, false);
}
struct ggml_tensor * ggml_rope_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
int n_dims,
int mode) {
return ggml_rope_impl(ctx, a, n_past, n_dims, mode, true);
}
// ggml_conv_1d_1s
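
On the caller side, the only difference between the two variants is where the result lands. A minimal usage sketch, assuming a context set up as in the library's own examples (the graph API names below match ggml.h of this period and may differ in later versions):

#include "ggml.h"

// Contrast of the out-of-place and in-place variants added above.
void soft_max_example(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);

    // out-of-place: y is a fresh tensor, x keeps its original data
    struct ggml_tensor * y = ggml_soft_max(ctx, x);

    // in-place: z is a view of x, so soft_max writes its result into x's buffer
    struct ggml_tensor * z = ggml_soft_max_inplace(ctx, x);

    struct ggml_cgraph gf = ggml_build_forward(z);
    ggml_graph_compute(ctx, &gf);
    (void) y; // y would need its own graph/compute call to be filled in

    ggml_free(ctx);
}

In llama_eval_internal (third file below), intermediates such as KQ_scaled and KQ_masked are consumed exactly once per layer, which is what makes the in-place forms safe to substitute there.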

ggml.h (33 changed lines)

@@ -470,27 +470,45 @@ struct ggml_tensor * ggml_repeat(
struct ggml_tensor * ggml_abs(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * ggml_abs_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * ggml_sgn(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * ggml_sgn_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * ggml_neg(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * ggml_neg_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * ggml_step(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * ggml_step_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * ggml_relu(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * ggml_relu_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
// TODO: double-check this computation is correct
struct ggml_tensor * ggml_gelu(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * ggml_gelu_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * ggml_silu(
struct ggml_context * ctx,
@@ -605,16 +623,22 @@ struct ggml_tensor * ggml_get_rows(
struct ggml_tensor * b);
// set elements above the diagonal to -INF
// in-place, returns view(a)
struct ggml_tensor * ggml_diag_mask_inf(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past);
struct ggml_tensor * ggml_diag_mask_inf_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past);
// in-place, returns view(a)
struct ggml_tensor * ggml_soft_max(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * ggml_soft_max_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
// rotary position embedding
// in-place, returns view(a)
@@ -626,7 +650,12 @@ struct ggml_tensor * ggml_rope(
int n_past,
int n_dims,
int mode);
struct ggml_tensor * ggml_rope_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
int n_dims,
int mode);
// padding = 1
// TODO: we don't support extra parameters for now
// that's why we are hard-coding the stride, padding, and dilation

llama.cpp (8 changed lines)

@@ -826,7 +826,7 @@ static bool llama_eval_internal(
// Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
struct ggml_tensor * Q =
ggml_permute(ctx0,
ggml_rope(ctx0,
ggml_rope_inplace(ctx0,
ggml_reshape_3d(ctx0,
Qcur,
n_embd/n_head, n_head, N),
@@ -836,7 +836,7 @@ static bool llama_eval_internal(
// K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
struct ggml_tensor * K =
ggml_permute(ctx0,
ggml_rope(ctx0,
ggml_rope_inplace(ctx0,
ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd),
n_embd/n_head, n_head, n_past + N),
@@ -853,10 +853,10 @@ static bool llama_eval_internal(
ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
// KQ_masked = mask_past(KQ_scaled)
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
// KQ = soft_max(KQ_masked)
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
struct ggml_tensor * V_trans =