Fix inplace version of operators
Use inplace version when possible
This commit is contained in:
parent
bcf363cb53
commit
8febfc73af
3 changed files with 90 additions and 24 deletions
73
ggml.c
73
ggml.c
|
@ -4278,9 +4278,7 @@ struct ggml_tensor * ggml_scale_impl(
|
|||
is_node = true;
|
||||
}
|
||||
|
||||
// TODO: when implement backward, fix this:
|
||||
//struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
|
||||
result->op = GGML_OP_SCALE;
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
|
@ -4593,10 +4591,11 @@ struct ggml_tensor * ggml_get_rows(
|
|||
|
||||
// ggml_diag_mask_inf
|
||||
|
||||
struct ggml_tensor * ggml_diag_mask_inf(
|
||||
struct ggml_tensor * ggml_diag_mask_inf_impl(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int n_past) {
|
||||
int n_past,
|
||||
bool inplace) {
|
||||
bool is_node = false;
|
||||
|
||||
if (a->grad) {
|
||||
|
@ -4604,9 +4603,7 @@ struct ggml_tensor * ggml_diag_mask_inf(
|
|||
is_node = true;
|
||||
}
|
||||
|
||||
// TODO: when implement backward, fix this:
|
||||
//struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
struct ggml_tensor * b = ggml_new_i32(ctx, n_past);
|
||||
|
||||
result->op = GGML_OP_DIAG_MASK_INF;
|
||||
|
@ -4617,11 +4614,26 @@ struct ggml_tensor * ggml_diag_mask_inf(
|
|||
return result;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_diag_mask_inf(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int n_past) {
|
||||
ggml_diag_mask_inf_impl(ctx, a, n_past, false);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_diag_mask_inf_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int n_past) {
|
||||
ggml_diag_mask_inf_impl(ctx, a, n_past, true);
|
||||
}
|
||||
|
||||
// ggml_soft_max
|
||||
|
||||
struct ggml_tensor * ggml_soft_max(
|
||||
struct ggml_tensor * ggml_soft_max_impl(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a) {
|
||||
struct ggml_tensor * a,
|
||||
bool inplace) {
|
||||
bool is_node = false;
|
||||
|
||||
if (a->grad) {
|
||||
|
@ -4629,9 +4641,7 @@ struct ggml_tensor * ggml_soft_max(
|
|||
is_node = true;
|
||||
}
|
||||
|
||||
// TODO: when implement backward, fix this:
|
||||
//struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
|
||||
result->op = GGML_OP_SOFT_MAX;
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
|
@ -4641,14 +4651,26 @@ struct ggml_tensor * ggml_soft_max(
|
|||
return result;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_soft_max(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a) {
|
||||
ggml_soft_max_impl(ctx, a, false);
|
||||
}
|
||||
struct ggml_tensor * ggml_soft_max_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a) {
|
||||
ggml_soft_max_impl(ctx, a, true);
|
||||
}
|
||||
|
||||
// ggml_rope
|
||||
|
||||
struct ggml_tensor * ggml_rope(
|
||||
struct ggml_tensor * ggml_rope_impl(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int n_past,
|
||||
int n_dims,
|
||||
int mode) {
|
||||
int mode,
|
||||
bool inplace) {
|
||||
GGML_ASSERT(n_past >= 0);
|
||||
bool is_node = false;
|
||||
|
||||
|
@ -4657,9 +4679,7 @@ struct ggml_tensor * ggml_rope(
|
|||
is_node = true;
|
||||
}
|
||||
|
||||
// TODO: when implement backward, fix this:
|
||||
//struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
|
||||
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
|
||||
struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
|
||||
((int32_t *) b->data)[0] = n_past;
|
||||
|
@ -4673,6 +4693,23 @@ struct ggml_tensor * ggml_rope(
|
|||
|
||||
return result;
|
||||
}
|
||||
struct ggml_tensor * ggml_rope(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int n_past,
|
||||
int n_dims,
|
||||
int mode){
|
||||
ggml_rope_impl(ctx, a, n_past, n_dims, mode, false);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_rope_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int n_past,
|
||||
int n_dims,
|
||||
int mode){
|
||||
ggml_rope_impl(ctx, a, n_past, n_dims, mode, true);
|
||||
}
|
||||
|
||||
// ggml_conv_1d_1s
|
||||
|
||||
|
|
33
ggml.h
33
ggml.h
|
@ -470,27 +470,45 @@ struct ggml_tensor * ggml_repeat(
|
|||
struct ggml_tensor * ggml_abs(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
struct ggml_tensor * ggml_abs_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
struct ggml_tensor * ggml_sgn(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
struct ggml_tensor * ggml_sgn_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
struct ggml_tensor * ggml_neg(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
struct ggml_tensor * ggml_neg_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
struct ggml_tensor * ggml_step(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
struct ggml_tensor * ggml_step_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
struct ggml_tensor * ggml_relu(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
struct ggml_tensor * ggml_relu_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
// TODO: double-check this computation is correct
|
||||
struct ggml_tensor * ggml_gelu(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
struct ggml_tensor * ggml_gelu_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
struct ggml_tensor * ggml_silu(
|
||||
struct ggml_context * ctx,
|
||||
|
@ -605,16 +623,22 @@ struct ggml_tensor * ggml_get_rows(
|
|||
struct ggml_tensor * b);
|
||||
|
||||
// set elements above the diagonal to -INF
|
||||
// in-place, returns view(a)
|
||||
struct ggml_tensor * ggml_diag_mask_inf(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int n_past);
|
||||
struct ggml_tensor * ggml_diag_mask_inf_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int n_past);
|
||||
|
||||
// in-place, returns view(a)
|
||||
struct ggml_tensor * ggml_soft_max(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
struct ggml_tensor * ggml_soft_max_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a);
|
||||
|
||||
// rotary position embedding
|
||||
// in-place, returns view(a)
|
||||
|
@ -626,7 +650,12 @@ struct ggml_tensor * ggml_rope(
|
|||
int n_past,
|
||||
int n_dims,
|
||||
int mode);
|
||||
|
||||
struct ggml_tensor * ggml_rope_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int n_past,
|
||||
int n_dims,
|
||||
int mode);
|
||||
// padding = 1
|
||||
// TODO: we don't support extra parameters for now
|
||||
// that's why we are hard-coding the stride, padding, and dilation
|
||||
|
|
|
@ -826,7 +826,7 @@ static bool llama_eval_internal(
|
|||
// Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
|
||||
struct ggml_tensor * Q =
|
||||
ggml_permute(ctx0,
|
||||
ggml_rope(ctx0,
|
||||
ggml_rope_inplace(ctx0,
|
||||
ggml_reshape_3d(ctx0,
|
||||
Qcur,
|
||||
n_embd/n_head, n_head, N),
|
||||
|
@ -836,7 +836,7 @@ static bool llama_eval_internal(
|
|||
// K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
|
||||
struct ggml_tensor * K =
|
||||
ggml_permute(ctx0,
|
||||
ggml_rope(ctx0,
|
||||
ggml_rope_inplace(ctx0,
|
||||
ggml_reshape_3d(ctx0,
|
||||
ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd),
|
||||
n_embd/n_head, n_head, n_past + N),
|
||||
|
@ -853,10 +853,10 @@ static bool llama_eval_internal(
|
|||
ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
|
||||
|
||||
// KQ_masked = mask_past(KQ_scaled)
|
||||
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
|
||||
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
|
||||
|
||||
// KQ = soft_max(KQ_masked)
|
||||
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
|
||||
struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
|
||||
|
||||
// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
|
||||
struct ggml_tensor * V_trans =
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue