Fix in-place versions of operators
Use the in-place version when possible
parent bcf363cb53 · commit 8febfc73af
3 changed files with 90 additions and 24 deletions
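
Every operator touched here follows the same pattern: the body moves into an internal *_impl function that takes an explicit `inplace` flag, and thin public wrappers select between copying and in-place behaviour. Before this commit the `_impl` helpers ignored the flag and always returned a view of `a` (the proper conditional was commented out behind a TODO); the commit makes the conditional live and adds public *_inplace entry points. ggml_scale_impl already had the flag, which is why its hunk only fixes the allocation line. A condensed sketch of the pattern, using ggml_soft_max as the example (not verbatim from the diff below; the is_node bookkeeping and operand wiring are elided):

// One shared implementation per operator, with an explicit flag:
struct ggml_tensor * ggml_soft_max_impl(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        bool                  inplace) {
    // in-place: the result is a view aliasing a's buffer;
    // otherwise: a freshly allocated tensor of the same shape
    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a)
                                          : ggml_dup_tensor(ctx, a);
    result->op = GGML_OP_SOFT_MAX;
    return result;
}

// Thin public wrappers pick the behaviour:
struct ggml_tensor * ggml_soft_max(struct ggml_context * ctx, struct ggml_tensor * a) {
    return ggml_soft_max_impl(ctx, a, false);
}

struct ggml_tensor * ggml_soft_max_inplace(struct ggml_context * ctx, struct ggml_tensor * a) {
    return ggml_soft_max_impl(ctx, a, true);
}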
ggml.c | 73

@@ -4278,9 +4278,7 @@ struct ggml_tensor * ggml_scale_impl(
         is_node = true;
     }
 
-    // TODO: when implement backward, fix this:
-    //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
     result->op   = GGML_OP_SCALE;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -4593,10 +4591,11 @@ struct ggml_tensor * ggml_get_rows(
 
 // ggml_diag_mask_inf
 
-struct ggml_tensor * ggml_diag_mask_inf(
+struct ggml_tensor * ggml_diag_mask_inf_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        int                   n_past) {
+        int                   n_past,
+        bool                  inplace) {
     bool is_node = false;
 
     if (a->grad) {
@@ -4604,9 +4603,7 @@ struct ggml_tensor * ggml_diag_mask_inf(
         is_node = true;
     }
 
-    // TODO: when implement backward, fix this:
-    //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
     struct ggml_tensor * b = ggml_new_i32(ctx, n_past);
 
     result->op = GGML_OP_DIAG_MASK_INF;
@@ -4617,11 +4614,26 @@ struct ggml_tensor * ggml_diag_mask_inf(
     return result;
 }
 
+struct ggml_tensor * ggml_diag_mask_inf(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past) {
+    return ggml_diag_mask_inf_impl(ctx, a, n_past, false);
+}
+
+struct ggml_tensor * ggml_diag_mask_inf_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past) {
+    return ggml_diag_mask_inf_impl(ctx, a, n_past, true);
+}
+
 // ggml_soft_max
 
-struct ggml_tensor * ggml_soft_max(
+struct ggml_tensor * ggml_soft_max_impl(
         struct ggml_context * ctx,
-        struct ggml_tensor  * a) {
+        struct ggml_tensor  * a,
+        bool                  inplace) {
     bool is_node = false;
 
     if (a->grad) {
@@ -4629,9 +4641,7 @@ struct ggml_tensor * ggml_soft_max(
         is_node = true;
     }
 
-    // TODO: when implement backward, fix this:
-    //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
     result->op   = GGML_OP_SOFT_MAX;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -4641,14 +4651,26 @@ struct ggml_tensor * ggml_soft_max(
     return result;
 }
 
+struct ggml_tensor * ggml_soft_max(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_soft_max_impl(ctx, a, false);
+}
+struct ggml_tensor * ggml_soft_max_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_soft_max_impl(ctx, a, true);
+}
+
 // ggml_rope
 
-struct ggml_tensor * ggml_rope(
+struct ggml_tensor * ggml_rope_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         int                   n_past,
         int                   n_dims,
-        int                   mode) {
+        int                   mode,
+        bool                  inplace) {
     GGML_ASSERT(n_past >= 0);
     bool is_node = false;
 
@@ -4657,9 +4679,7 @@ struct ggml_tensor * ggml_rope(
         is_node = true;
     }
 
-    // TODO: when implement backward, fix this:
-    //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
     struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
     ((int32_t *) b->data)[0] = n_past;
@@ -4673,6 +4693,23 @@ struct ggml_tensor * ggml_rope(
 
     return result;
 }
+struct ggml_tensor * ggml_rope(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past,
+        int                   n_dims,
+        int                   mode) {
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, false);
+}
+
+struct ggml_tensor * ggml_rope_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past,
+        int                   n_dims,
+        int                   mode) {
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, true);
+}
 
 // ggml_conv_1d_1s
 
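For callers, the only difference between the two variants is whether the input tensor survives the computation. A hedged usage sketch, assuming the ggml API of this vintage (ggml_init / ggml_build_forward / ggml_graph_compute; check the calls against your tree):

#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    ggml_set_f32(a, 1.0f);

    // copying variant: `a` is left intact, the result owns fresh memory
    struct ggml_tensor * s0 = ggml_soft_max(ctx, a);

    // in-place variant: the result is a view of `a`, so computing it
    // overwrites a's data -- cheaper, but only safe when nothing else
    // still needs the original contents of `a`
    struct ggml_tensor * s1 = ggml_soft_max_inplace(ctx, a);

    struct ggml_cgraph gf = ggml_build_forward(s1);
    gf.n_threads = 1;
    ggml_graph_compute(ctx, &gf);

    (void) s0; // not computed here; shown only for contrast
    ggml_free(ctx);
    return 0;
}
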
ggml.h | 33

@@ -470,27 +470,45 @@ struct ggml_tensor * ggml_repeat(
 struct ggml_tensor * ggml_abs(
         struct ggml_context * ctx,
         struct ggml_tensor  * a);
+struct ggml_tensor * ggml_abs_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
 
 struct ggml_tensor * ggml_sgn(
         struct ggml_context * ctx,
         struct ggml_tensor  * a);
+struct ggml_tensor * ggml_sgn_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
 
 struct ggml_tensor * ggml_neg(
         struct ggml_context * ctx,
         struct ggml_tensor  * a);
+struct ggml_tensor * ggml_neg_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
 
 struct ggml_tensor * ggml_step(
         struct ggml_context * ctx,
         struct ggml_tensor  * a);
+struct ggml_tensor * ggml_step_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
 
 struct ggml_tensor * ggml_relu(
         struct ggml_context * ctx,
         struct ggml_tensor  * a);
+struct ggml_tensor * ggml_relu_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
 
 // TODO: double-check this computation is correct
 struct ggml_tensor * ggml_gelu(
         struct ggml_context * ctx,
         struct ggml_tensor  * a);
+struct ggml_tensor * ggml_gelu_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
 
 struct ggml_tensor * ggml_silu(
         struct ggml_context * ctx,
@@ -605,16 +623,22 @@ struct ggml_tensor * ggml_get_rows(
         struct ggml_tensor  * b);
 
 // set elements above the diagonal to -INF
-// in-place, returns view(a)
 struct ggml_tensor * ggml_diag_mask_inf(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         int                   n_past);
+struct ggml_tensor * ggml_diag_mask_inf_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past);
 
 // in-place, returns view(a)
 struct ggml_tensor * ggml_soft_max(
         struct ggml_context * ctx,
         struct ggml_tensor  * a);
+struct ggml_tensor * ggml_soft_max_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a);
 
 // rotary position embedding
 // in-place, returns view(a)
@@ -626,7 +650,12 @@ struct ggml_tensor * ggml_rope(
         int                   n_past,
         int                   n_dims,
         int                   mode);
-
+struct ggml_tensor * ggml_rope_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past,
+        int                   n_dims,
+        int                   mode);
 // padding = 1
 // TODO: we don't support extra parameters for now
 // that's why we are hard-coding the stride, padding, and dilation
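Note the header comment "in-place, returns view(a)": the *_inplace variants return a tensor that aliases a's buffer via ggml_view_tensor, so no new data is allocated. A small hedged illustration (the helper name check_alias is hypothetical):

#include <assert.h>
#include "ggml.h"

static void check_alias(struct ggml_context * ctx, struct ggml_tensor * a) {
    struct ggml_tensor * v = ggml_soft_max_inplace(ctx, a);
    assert(v->data == a->data);   // the "view" shares a's storage

    struct ggml_tensor * c = ggml_soft_max(ctx, a);
    assert(c->data != a->data);   // the copy gets its own buffer
}
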
llama.cpp | 8

@@ -826,7 +826,7 @@ static bool llama_eval_internal(
             // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
             struct ggml_tensor * Q =
                 ggml_permute(ctx0,
-                        ggml_rope(ctx0,
+                        ggml_rope_inplace(ctx0,
                             ggml_reshape_3d(ctx0,
                                 Qcur,
                                 n_embd/n_head, n_head, N),
@@ -836,7 +836,7 @@ static bool llama_eval_internal(
             // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
             struct ggml_tensor * K =
                 ggml_permute(ctx0,
-                        ggml_rope(ctx0,
+                        ggml_rope_inplace(ctx0,
                             ggml_reshape_3d(ctx0,
                                 ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd),
                                 n_embd/n_head, n_head, n_past + N),
@@ -853,10 +853,10 @@ static bool llama_eval_internal(
                         ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
 
             // KQ_masked = mask_past(KQ_scaled)
-            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
+            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
 
             // KQ = soft_max(KQ_masked)
-            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
+            struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
 
             // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
             struct ggml_tensor * V_trans =
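The call-site substitutions above are safe because each intermediate in the attention block is consumed exactly once: KQ_scaled is read only by the masking step and KQ_masked only by the softmax, so overwriting their buffers loses nothing and avoids an allocation per operator in the per-eval context. In outline (t0/t1/t2 are illustrative names, not from the source; the scaling step is omitted):

// KQ chain from llama_eval_internal, with illustrative names:
struct ggml_tensor * t0 = ggml_mul_mat(ctx0, K, Q);                      // KQ
struct ggml_tensor * t1 = ggml_diag_mask_inf_inplace(ctx0, t0, n_past);  // view of t0
struct ggml_tensor * t2 = ggml_soft_max_inplace(ctx0, t1);               // view of t1
// t0, t1 and t2 share one buffer; this is only correct because t0 and t1
// are never read again after the step that consumes them.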