cuda : add asserts for rope/norm + fix DS2
ggml-ci
This commit is contained in:
parent 1e41f2fc7e
commit 5db268c9d8

5 changed files with 51 additions and 14 deletions
@@ -2886,7 +2886,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_CONT:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
+            return true;
         case GGML_OP_ROPE:
+            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_IM2COL:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM_ROWS:
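Note on the new GGML_OP_ROPE branch: ggml treats a tensor as contiguous when its byte strides describe a densely packed row-major layout. A minimal sketch of that condition against the public tensor fields (the authoritative check is ggml_is_contiguous() in ggml.c; this illustration may differ from it in detail):

    #include "ggml.h"

    // Sketch: contiguous means every stride equals the packed size of the
    // dimensions below it. nb[0] is the element (or quant block) size,
    // nb[1] the packed row size, and so on up to nb[3].
    static bool is_contiguous_sketch(const struct ggml_tensor * t) {
        return t->nb[0] == ggml_type_size(t->type)                    &&
               t->nb[1] == t->nb[0]*t->ne[0]/ggml_blck_size(t->type)  &&
               t->nb[2] == t->nb[1]*t->ne[1]                          &&
               t->nb[3] == t->nb[2]*t->ne[2];
    }

A view that slices columns out of a wider tensor keeps the parent's row stride, so nb[1] no longer matches the packed row size and the check fails; this is the situation the DeepSeek-V2 graph below works around with ggml_cont().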
@@ -170,6 +170,8 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
@@ -188,6 +190,8 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
@@ -202,6 +206,8 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
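With these asserts in place, the CUDA norm kernels reject non-contiguous inputs explicitly instead of reading through wrong strides, so graph-building code has to pack such inputs first. A minimal sketch of the calling pattern, assuming ctx0 is an existing ggml context and x is some non-contiguous 2D view (both are placeholder names; this mirrors the kv_compressed fix in llama.cpp further below):

    // Make the view dense before normalizing so that the kernel's
    // GGML_ASSERT(ggml_is_contiguous(src0)) holds.
    struct ggml_tensor * x_cont = ggml_cont(ctx0, x);
    struct ggml_tensor * x_norm = ggml_rms_norm(ctx0, x_cont, 1e-6f);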
@@ -251,6 +251,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
+    GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
     GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
     GGML_ASSERT(src0->type == dst->type);
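This runtime assert pairs with the ggml_backend_cuda_supports_op() change in the first hunk: a scheduler that consults operator support keeps a non-contiguous RoPE off the CUDA backend rather than tripping the assert. A rough sketch of that query, assuming cuda_backend and rope_node are a backend handle and a GGML_OP_ROPE graph node obtained elsewhere (placeholder names):

    // The CUDA backend now reports RoPE on a non-contiguous src0 as
    // unsupported, so the node can be placed on another backend or its
    // input made contiguous with ggml_cont(), as llama.cpp does below.
    if (!ggml_backend_supports_op(cuda_backend, rope_node)) {
        // fall back to another backend or repair the graph here
    }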
@@ -2187,6 +2187,7 @@ static enum ggml_status ggml_metal_graph_compute(
         case GGML_OP_RMS_NORM:
             {
                 GGML_ASSERT(ne00 % 4 == 0);
+                GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: only requires contiguous dim 1, 2, 3
 
                 float eps;
                 memcpy(&eps, dst->op_params, sizeof(float));
@@ -2214,6 +2215,7 @@ static enum ggml_status ggml_metal_graph_compute(
         case GGML_OP_GROUP_NORM:
             {
                 GGML_ASSERT(ne00 % 4 == 0);
+                GGML_ASSERT(ggml_is_contiguous(src0));
 
                 //float eps;
                 //memcpy(&eps, dst->op_params, sizeof(float));
@@ -2247,6 +2249,8 @@ static enum ggml_status ggml_metal_graph_compute(
             } break;
         case GGML_OP_NORM:
             {
+                GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: only requires contiguous dim 1, 2, 3
+
                 float eps;
                 memcpy(&eps, dst->op_params, sizeof(float));
 
llama.cpp
@@ -11187,46 +11187,69 @@ struct llm_build_context {
                 }
 
                 // split into {n_head * n_embd_head_qk_nope, n_tokens}
-                struct ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, ggml_element_size(q) * hparams.n_embd_head_k, ggml_element_size(q) * hparams.n_embd_head_k * n_head, 0);
+                struct ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
+                        ggml_row_size(q->type, hparams.n_embd_head_k),
+                        ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                        0);
                 cb(q_nope, "q_nope", il);
 
                 // and {n_head * n_embd_head_qk_rope, n_tokens}
-                struct ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, ggml_element_size(q) * hparams.n_embd_head_k, ggml_element_size(q) * hparams.n_embd_head_k * n_head, ggml_element_size(q) * n_embd_head_qk_nope);
+                struct ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
+                        ggml_row_size(q->type, hparams.n_embd_head_k),
+                        ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                        ggml_row_size(q->type, n_embd_head_qk_nope));
                 cb(q_pe, "q_pe", il);
 
                 // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens}
-                struct ggml_tensor * compressed_kv_pe = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
-                cb(compressed_kv_pe, "compressed_kv_pe", il);
+                struct ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
+                cb(kv_pe_compresseed, "kv_pe_compresseed", il);
 
                 // split into {kv_lora_rank, n_tokens}
-                struct ggml_tensor * compressed_kv = ggml_view_2d(ctx0, compressed_kv_pe, kv_lora_rank, n_tokens, compressed_kv_pe->nb[1], 0);
-                cb(compressed_kv, "compressed_kv", il);
+                struct ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens,
+                        kv_pe_compresseed->nb[1],
+                        0);
+                cb(kv_compressed, "kv_compressed", il);
 
                 // and {n_embd_head_qk_rope, n_tokens}
-                struct ggml_tensor * k_pe = ggml_view_2d(ctx0, compressed_kv_pe, n_embd_head_qk_rope, n_tokens, compressed_kv_pe->nb[1], ggml_element_size(compressed_kv_pe)*kv_lora_rank);
+                struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens,
+                        kv_pe_compresseed->nb[1],
+                        kv_pe_compresseed->nb[1],
+                        ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
                 cb(k_pe, "k_pe", il);
 
-                compressed_kv = llm_build_norm(ctx0, compressed_kv, hparams,
+                kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
+                kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
                         model.layers[il].attn_kv_a_norm, NULL,
                         LLM_NORM_RMS, cb, il);
-                cb(compressed_kv, "compressed_kv", il);
+                cb(kv_compressed, "kv_compressed", il);
 
                 // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
-                struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, compressed_kv);
+                struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
                 cb(kv, "kv", il);
 
                 // split into {n_head * n_embd_head_qk_nope, n_tokens}
-                struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, ggml_element_size(kv) * (n_embd_head_qk_nope + hparams.n_embd_head_v), ggml_element_size(kv) * n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v), 0);
+                struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
+                        ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
+                        ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                        0);
                 cb(k_nope, "k_nope", il);
 
                 // and {n_head * n_embd_head_v, n_tokens}
-                struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, ggml_element_size(kv) * (n_embd_head_qk_nope + hparams.n_embd_head_v), ggml_element_size(kv) * n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v), ggml_element_size(kv) * n_embd_head_qk_nope);
+                struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
+                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                        ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
+                        ggml_row_size(kv->type, (n_embd_head_qk_nope)));
                 cb(v_states, "v_states", il);
 
                 v_states = ggml_cont(ctx0, v_states);
                 cb(v_states, "v_states", il);
 
-                v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, ggml_element_size(kv) * hparams.n_embd_head_v * n_head, 0);
+                v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
+                    ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
+                    0);
                 cb(v_states, "v_states", il);
 
+                q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
                 q_pe = ggml_rope_ext(
                     ctx0, q_pe, inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
@@ -11235,8 +11258,9 @@ struct llm_build_context {
                 cb(q_pe, "q_pe", il);
 
                 // shared RoPE key
+                k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
                 k_pe = ggml_rope_ext(
-                    ctx0, ggml_view_3d(ctx0, k_pe, n_embd_head_qk_rope, 1, n_tokens, k_pe->nb[0], k_pe->nb[1], 0), inp_pos, nullptr,
+                    ctx0, k_pe, inp_pos, nullptr,
                     n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor_scaled, beta_fast, beta_slow
                 );
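Besides the renames and the ggml_cont() calls, the view strides above switch from ggml_element_size() arithmetic to ggml_row_size(). The activations here are F32, so the two agree; ggml_row_size() is the more robust idiom because it returns the packed byte size of a whole row and therefore also stays correct for block-quantized types. A small illustration of the F32 case (numbers chosen only as an example):

    #include "ggml.h"
    #include <assert.h>

    // ggml_row_size(type, ne) == ne / block_size * type_size.
    // For F32 the block size is 1 and the type size is 4 bytes, so a
    // 4096-element row is 4096 * 4 bytes, matching the stride the old
    // ggml_element_size()-based expressions produced.
    void row_size_example(void) {
        assert(ggml_row_size(GGML_TYPE_F32, 4096) == 4096*sizeof(float));
    }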