Fix inplace version of operators

Use inplace version when possible
This commit is contained in:
Howard Su 2023-04-01 01:26:48 +08:00
parent bcf363cb53
commit 8febfc73af
3 changed files with 90 additions and 24 deletions

73
ggml.c
View file

@ -4278,9 +4278,7 @@ struct ggml_tensor * ggml_scale_impl(
is_node = true; is_node = true;
} }
// TODO: when implement backward, fix this: struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
//struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
result->op = GGML_OP_SCALE; result->op = GGML_OP_SCALE;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@ -4593,10 +4591,11 @@ struct ggml_tensor * ggml_get_rows(
// ggml_diag_mask_inf // ggml_diag_mask_inf
struct ggml_tensor * ggml_diag_mask_inf( struct ggml_tensor * ggml_diag_mask_inf_impl(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
int n_past) { int n_past,
bool inplace) {
bool is_node = false; bool is_node = false;
if (a->grad) { if (a->grad) {
@ -4604,9 +4603,7 @@ struct ggml_tensor * ggml_diag_mask_inf(
is_node = true; is_node = true;
} }
// TODO: when implement backward, fix this: struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
//struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
struct ggml_tensor * b = ggml_new_i32(ctx, n_past); struct ggml_tensor * b = ggml_new_i32(ctx, n_past);
result->op = GGML_OP_DIAG_MASK_INF; result->op = GGML_OP_DIAG_MASK_INF;
@ -4617,11 +4614,26 @@ struct ggml_tensor * ggml_diag_mask_inf(
return result; return result;
} }
struct ggml_tensor * ggml_diag_mask_inf(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past) {
ggml_diag_mask_inf_impl(ctx, a, n_past, false);
}
struct ggml_tensor * ggml_diag_mask_inf_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past) {
ggml_diag_mask_inf_impl(ctx, a, n_past, true);
}
// ggml_soft_max // ggml_soft_max
struct ggml_tensor * ggml_soft_max( struct ggml_tensor * ggml_soft_max_impl(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a) { struct ggml_tensor * a,
bool inplace) {
bool is_node = false; bool is_node = false;
if (a->grad) { if (a->grad) {
@ -4629,9 +4641,7 @@ struct ggml_tensor * ggml_soft_max(
is_node = true; is_node = true;
} }
// TODO: when implement backward, fix this: struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
//struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
result->op = GGML_OP_SOFT_MAX; result->op = GGML_OP_SOFT_MAX;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@ -4641,14 +4651,26 @@ struct ggml_tensor * ggml_soft_max(
return result; return result;
} }
struct ggml_tensor * ggml_soft_max(
struct ggml_context * ctx,
struct ggml_tensor * a) {
ggml_soft_max_impl(ctx, a, false);
}
struct ggml_tensor * ggml_soft_max_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a) {
ggml_soft_max_impl(ctx, a, true);
}
// ggml_rope // ggml_rope
struct ggml_tensor * ggml_rope( struct ggml_tensor * ggml_rope_impl(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
int n_past, int n_past,
int n_dims, int n_dims,
int mode) { int mode,
bool inplace) {
GGML_ASSERT(n_past >= 0); GGML_ASSERT(n_past >= 0);
bool is_node = false; bool is_node = false;
@ -4657,9 +4679,7 @@ struct ggml_tensor * ggml_rope(
is_node = true; is_node = true;
} }
// TODO: when implement backward, fix this: struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
//struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3); struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
((int32_t *) b->data)[0] = n_past; ((int32_t *) b->data)[0] = n_past;
@ -4673,6 +4693,23 @@ struct ggml_tensor * ggml_rope(
return result; return result;
} }
struct ggml_tensor * ggml_rope(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
int n_dims,
int mode){
ggml_rope_impl(ctx, a, n_past, n_dims, mode, false);
}
struct ggml_tensor * ggml_rope_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
int n_dims,
int mode){
ggml_rope_impl(ctx, a, n_past, n_dims, mode, true);
}
// ggml_conv_1d_1s // ggml_conv_1d_1s

33
ggml.h
View file

@ -470,27 +470,45 @@ struct ggml_tensor * ggml_repeat(
struct ggml_tensor * ggml_abs( struct ggml_tensor * ggml_abs(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
struct ggml_tensor * ggml_abs_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * ggml_sgn( struct ggml_tensor * ggml_sgn(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
struct ggml_tensor * ggml_sgn_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * ggml_neg( struct ggml_tensor * ggml_neg(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
struct ggml_tensor * ggml_neg_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * ggml_step( struct ggml_tensor * ggml_step(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
struct ggml_tensor * ggml_step_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * ggml_relu( struct ggml_tensor * ggml_relu(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
struct ggml_tensor * ggml_relu_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
// TODO: double-check this computation is correct // TODO: double-check this computation is correct
struct ggml_tensor * ggml_gelu( struct ggml_tensor * ggml_gelu(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
struct ggml_tensor * ggml_gelu_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
struct ggml_tensor * ggml_silu( struct ggml_tensor * ggml_silu(
struct ggml_context * ctx, struct ggml_context * ctx,
@ -605,16 +623,22 @@ struct ggml_tensor * ggml_get_rows(
struct ggml_tensor * b); struct ggml_tensor * b);
// set elements above the diagonal to -INF // set elements above the diagonal to -INF
// in-place, returns view(a)
struct ggml_tensor * ggml_diag_mask_inf( struct ggml_tensor * ggml_diag_mask_inf(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
int n_past); int n_past);
struct ggml_tensor * ggml_diag_mask_inf_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past);
// in-place, returns view(a) // in-place, returns view(a)
struct ggml_tensor * ggml_soft_max( struct ggml_tensor * ggml_soft_max(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
struct ggml_tensor * ggml_soft_max_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
// rotary position embedding // rotary position embedding
// in-place, returns view(a) // in-place, returns view(a)
@ -626,7 +650,12 @@ struct ggml_tensor * ggml_rope(
int n_past, int n_past,
int n_dims, int n_dims,
int mode); int mode);
struct ggml_tensor * ggml_rope_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
int n_dims,
int mode);
// padding = 1 // padding = 1
// TODO: we don't support extra parameters for now // TODO: we don't support extra parameters for now
// that's why we are hard-coding the stride, padding, and dilation // that's why we are hard-coding the stride, padding, and dilation

View file

@ -826,7 +826,7 @@ static bool llama_eval_internal(
// Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3) // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
struct ggml_tensor * Q = struct ggml_tensor * Q =
ggml_permute(ctx0, ggml_permute(ctx0,
ggml_rope(ctx0, ggml_rope_inplace(ctx0,
ggml_reshape_3d(ctx0, ggml_reshape_3d(ctx0,
Qcur, Qcur,
n_embd/n_head, n_head, N), n_embd/n_head, n_head, N),
@ -836,7 +836,7 @@ static bool llama_eval_internal(
// K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3) // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
struct ggml_tensor * K = struct ggml_tensor * K =
ggml_permute(ctx0, ggml_permute(ctx0,
ggml_rope(ctx0, ggml_rope_inplace(ctx0,
ggml_reshape_3d(ctx0, ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd), ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd),
n_embd/n_head, n_head, n_past + N), n_embd/n_head, n_head, n_past + N),
@ -853,10 +853,10 @@ static bool llama_eval_internal(
ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head))); ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
// KQ_masked = mask_past(KQ_scaled) // KQ_masked = mask_past(KQ_scaled)
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past); struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
// KQ = soft_max(KQ_masked) // KQ = soft_max(KQ_masked)
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
struct ggml_tensor * V_trans = struct ggml_tensor * V_trans =