wip: correct tensors up to RoPE
parent 7cdc3eaa76 · commit 4bcf412d86
3 changed files with 293 additions and 82 deletions
@@ -87,12 +87,16 @@ def main(args_in: list[str] | None = None) -> None:
    gguf_writer.add_rope_dimension_count(hidden_size // head_count)
    gguf_writer.add_head_count(head_count)
    gguf_writer.add_head_count_kv(head_count_kv)
    gguf_writer.add_rope_freq_base(hparams['rotary_emb_base'])
    gguf_writer.add_layer_norm_eps(hparams['layernorm_epsilon'])
    if True:
        tokens, scores, toktypes = handle_tokenizer(dir_model)
        gguf_writer.add_tokenizer_model('llama')
        gguf_writer.add_token_list(tokens)
        gguf_writer.add_token_scores(scores)
        gguf_writer.add_token_types(toktypes)
        gguf_writer.add_bos_token_id(71013)
        gguf_writer.add_eos_token_id(71013)
    tensor_map = gguf.get_tensor_name_map(arch, block_count)
    print(tensor_map)
    tensors = {}
@@ -105,15 +109,17 @@ def main(args_in: list[str] | None = None) -> None:
        print(name)

        # we don't need these
        if name.endswith(".self_attention.rotary_emb.inv_freq"):
            continue
        old_dtype = data.dtype
-        if 'layernorm.weight' in name:
+        """
        if 'layernorm.weight' in name or 'word_embeddings.weight' in name:
            data = data.to(torch.float32)
        else:
            if data.dtype != torch.float16 and data.dtype != torch.float32:
                data = data.to(torch.float16)
                data = data.to(torch.float32)
+        """
+        data = data.to(torch.float32)
        # check for nans
        if torch.isnan(data).any():
            print("WARNING: tensor '" + name + "' contains NaNs")
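Note on the conversion path: while the tensors are being verified, everything is force-cast to float32 and each tensor is scanned for NaNs before it is written out. For reference, the same scan on a raw f32 buffer on the C side would look like this (a hypothetical helper, not part of this commit):

    #include <math.h>
    #include <stddef.h>

    // Return 1 if any of the n floats is NaN (the C-side analogue of torch.isnan(data).any()).
    static int buffer_has_nan(const float * data, size_t n) {
        for (size_t i = 0; i < n; i++) {
            if (isnan(data[i])) {
                return 1;
            }
        }
        return 0;
    }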
ggml.c (111 changed lines)
@@ -4290,6 +4290,65 @@ void ggml_print_objects(const struct ggml_context * ctx) {
    GGML_PRINT("%s: --- end ---\n", __func__);
}

+// Debug helper: print a tensor's type, rank, shape (ne, in elements) and strides (nb, in bytes).
+static void ggml_print_tensor(const struct ggml_tensor * tensor) {
+    GGML_PRINT("Tensor (null): %s | rank %d | shape (", ggml_type_name(tensor->type), tensor->n_dims);
+    for (int i = 0; i < tensor->n_dims; ++i) {
+        GGML_PRINT("%lld ", tensor->ne[i]);
+    }
+    GGML_PRINT(") | strides (");
+    for (int i = 0; i < tensor->n_dims; ++i) {
+        GGML_PRINT("%lld ", tensor->nb[i]);
+    }
+    GGML_PRINT(")\n");
+}
+
+// Debug helper: print `nelts` consecutive f32 values along dimension `dim`,
+// starting at the multi-index given by `starts`.
+static void ggml_print_tensor_values(const struct ggml_tensor * tensor, int starts[], int dim, int nelts) {
+    GGML_ASSERT(tensor->type == GGML_TYPE_F32);
+    GGML_PRINT("printing values for %s[", tensor->name);
+    for (int i = 0; i < tensor->n_dims; ++i) {
+        if (i != dim) {
+            GGML_PRINT("%d", starts[i]);
+        } else {
+            if (starts[i] > 0) {
+                GGML_PRINT("%d:%d", starts[i], starts[i] + nelts);
+            } else {
+                GGML_PRINT(":%d", starts[i] + nelts);
+            }
+        }
+        if (i < tensor->n_dims - 1) {
+            GGML_PRINT(",");
+        }
+    }
+    GGML_PRINT("]\n");
+
+    float * dataPtr = (float *) tensor->data;
+
+    // Compute the offset into data for starts
+    int offset = 0;
+    for (int j = 0; j < tensor->n_dims; j++) {
+        offset += (starts[j] * tensor->nb[j]) / sizeof(float); // nb[j] is in bytes; divide by sizeof(float) to get a float offset
+    }
+    dataPtr += offset;
+
+    for (int i = 0; i < nelts; i++) {
+        GGML_PRINT("%f ", *dataPtr);
+        dataPtr += tensor->nb[dim] / sizeof(float); // advance by the stride of the printed dimension
+    }
+    GGML_PRINT("\n");
+    /*
+    char * ptr = (char *) tensor->data;
+    for (int j = 0; j < tensor->n_dims; j++) {
+        ptr += tensor->nb[j] * starts[j];
+    }
+    for (int i = 0; i < nelts; i++) {
+        GGML_PRINT("%f ", (*((float *) ptr)));
+        ptr += tensor->nb[dim];
+    }
+    GGML_PRINT("\n");
+    */
+}
+
int64_t ggml_nelements(const struct ggml_tensor * tensor) {
    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");

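The offset arithmetic in ggml_print_tensor_values leans on ggml's layout convention: ne[] holds extents in elements, nb[] holds strides in bytes, and element (i0, i1, i2, i3) lives at data + i0*nb[0] + i1*nb[1] + i2*nb[2] + i3*nb[3]. A self-contained sketch of that convention (simplified struct, not ggml's real one):

    #include <stdint.h>
    #include <stddef.h>

    // Simplified stand-in for a ggml-style tensor header.
    struct toy_tensor {
        int64_t ne[4];  // extents, in elements
        size_t  nb[4];  // strides, in bytes; nb[0] == sizeof(float) for contiguous f32
        void  * data;
    };

    // Fetch one f32 element by multi-index, walking byte strides exactly as the helper above does.
    static float toy_get_f32(const struct toy_tensor * t, int i0, int i1, int i2, int i3) {
        const char * p = (const char *) t->data
            + i0 * t->nb[0] + i1 * t->nb[1] + i2 * t->nb[2] + i3 * t->nb[3];
        return *(const float *) p;
    }

Dividing nb[j] by sizeof(float), as the new helper does, only works for dense f32 tensors whose strides are float-aligned; the commented-out char-pointer variant kept at the bottom of the function is the fully general form.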
@@ -6162,6 +6221,7 @@ struct ggml_tensor * ggml_mul_mat(

    const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] };
    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne);
+    GGML_PRINT("ggml_mul_mat result shape : (%lld, %lld, %lld, %lld)\n", ne[0], ne[1], ne[2], ne[3]);

    result->op   = GGML_OP_MUL_MAT;
    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
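For reading these prints: ggml puts the contraction dimension in ne[0], so for a = (k, m) and b = (k, n) the result shape is {a->ne[1], b->ne[1], ...} = (m, n). A tiny worked example under that convention (the concrete numbers are illustrative):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        // ggml_mul_mat shape rule: ne = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] }
        int64_t a_ne[4] = { 4096, 12288, 1, 1 };  // e.g. a fused QKV weight: 4096 in, 12288 out
        int64_t b_ne[4] = { 4096, 7, 1, 1 };      // activations for 7 tokens
        int64_t ne[4]   = { a_ne[1], b_ne[1], b_ne[2], b_ne[3] };
        printf("(%lld, %lld, %lld, %lld)\n",      // prints (12288, 7, 1, 1)
               (long long) ne[0], (long long) ne[1], (long long) ne[2], (long long) ne[3]);
        return 0;
    }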
@@ -8823,6 +8883,15 @@ static void ggml_compute_forward_add_f32(
            }
        }
    }
+    // Debug hook: dump the first few outputs of the q-layernorm / QKV bias adds (thread 0 only).
+    if ((strncmp(src0->name, "preadd", 6) == 0
+         || strncmp(src0->name, "qkv_preadd", 10) == 0)
+        && ith == 0) {
+        printf("\nadd outputs for %s\n", src0->name);
+        ggml_print_tensor(dst);
+        int starts[] = {0, 3, 0};
+        ggml_print_tensor_values(dst, starts, 0, 10);
+    }
}

static void ggml_compute_forward_add_f16_f32(
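This is the pattern every hook below follows: compute kernels run once per worker thread, so each print is gated on thread index 0 plus a strncmp prefix match on the tensor name, to get exactly one copy of the output. Stripped to its essentials (a hypothetical helper, outside ggml):

    #include <stdio.h>
    #include <string.h>

    // Print a debug line once per tensor: only the thread with index 0 reports,
    // and only for tensors whose name starts with the given prefix.
    static void debug_hook(const char * tensor_name, int thread_idx, const char * prefix) {
        if (thread_idx == 0 && strncmp(tensor_name, prefix, strlen(prefix)) == 0) {
            printf("debug: %s\n", tensor_name);
        }
    }

One caveat: prefix matching is loose, so a mask tensor named "KQ_soft_max" would also trip the "KQ_soft_max" hook below, which is why the mask gets renamed to "KQ_mask" in the llama.cpp part of this commit.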
@@ -10804,6 +10873,18 @@ static void ggml_compute_forward_norm_f32(
    }

    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    // Debug hook: if the name starts with "layer_inputs" or "tmpq", and we are on thread 0, print the tensor.
+    if ((strncmp(src0->name, "layer_inputs", 12) == 0
+         || strncmp(src0->name, "tmpq", 4) == 0)
+        && params->ith == 0) {
+        GGML_PRINT("\nlayernorm inputs for %s\n", src0->name);
+        ggml_print_tensor(src0);
+        int starts[] = {0, 1, 0};
+        ggml_print_tensor_values(src0, starts, 0, 10);
+        for (int i = 64; i < 74; ++i) {
+            GGML_PRINT("%f ", ggml_get_f32_1d(src0, i));
+        }
+    }

    const int ith = params->ith;
    const int nth = params->nth;
@@ -11227,8 +11308,25 @@ static void ggml_compute_forward_mul_mat(
        struct ggml_tensor * dst) {
    int64_t t0 = ggml_perf_time_us();
    UNUSED(t0);
+    // Debug hook: on thread 0, if src1 starts with "KQ_soft_max" (and is square,
+    // i.e. n_past == 0), print the full attention matrix before it multiplies V.
+    if (strncmp(src1->name, "KQ_soft_max", 11) == 0 && params->ith == 0
+        && src1->ne[0] == src1->ne[1]) {
+        GGML_PRINT("\n KQ_softmax at mul mat time for %s\n", src1->name);
+        ggml_print_tensor(src1);
+        if (ggml_nelements(src1) >= 14) {
+            for (int i = 0; i < src1->ne[0] * src1->ne[1]; ++i) {
+                if (i % src1->ne[1] == 0) {
+                    GGML_PRINT("\n");
+                }
+                GGML_PRINT(" %f ", ((float *) src1->data)[i]);
+            }
+            GGML_PRINT("\n");
+        } else {
+            GGML_PRINT("Not enough elements to print\n");
+        }
+    }

    GGML_TENSOR_BINARY_OP_LOCALS;

    const int ith = params->ith;
    const int nth = params->nth;
@@ -12628,6 +12726,12 @@ static void ggml_compute_forward_rope_f32(
    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
        return;
    }
+    // Debug hook: print the rotary slice of Q as it enters RoPE (thread 0 only).
+    if (strncmp(src0->name, "qrot", 4) == 0 && params->ith == 0) {
+        GGML_PRINT("\nValues at RoPE time for %s\n", src0->name);
+        ggml_print_tensor(src0);
+        int starts[] = {0, 0, 1, 0};
+        ggml_print_tensor_values(src0, starts, 1, 10);
+    }

    float freq_base;
    float freq_scale;
@@ -12756,6 +12860,13 @@ static void ggml_compute_forward_rope_f32(
            }
        }
    }
+    // Debug hook: print the shape, strides, and first values of the rotated output (thread 0 only).
+    if (strncmp(src0->name, "qrot", 4) == 0 && params->ith == 0) {
+        GGML_PRINT("\n dest at RoPE time for %s\n", src0->name);
+        int starts[4] = {0, 0, 0, 0};
+        ggml_print_tensor(dst);
+        ggml_print_tensor_values(dst, starts, 0, 10);
+    }
}

static void ggml_compute_forward_rope_f16(
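These RoPE hooks watch the qrot_* tensors produced by llm_build_adept below, where only the first n_rot = n_embd_head / 2 dimensions of each head are rotated and the remaining half passes through untouched. A minimal per-head sketch of that partial rotation, assuming ggml's pairwise (2i, 2i+1) scheme at position pos:

    #include <math.h>

    // Rotate the first n_rot dims of one head vector x[0..n_dims) in place;
    // dims n_rot..n_dims are the pass-through half and are left untouched.
    static void rope_partial(float * x, int n_dims, int n_rot, int pos, float freq_base) {
        (void) n_dims; // only the first n_rot dims are touched
        for (int i = 0; i < n_rot; i += 2) {
            float theta = pos * powf(freq_base, -(float) i / n_rot);
            float c = cosf(theta);
            float s = sinf(theta);
            float x0 = x[i];
            float x1 = x[i + 1];
            x[i]     = x0 * c - x1 * s;
            x[i + 1] = x0 * s + x1 * c;
        }
    }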
llama.cpp (250 changed lines)
@@ -2337,25 +2337,6 @@ static void llm_load_tensors(
        const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
        const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT;
        auto & layer = model.layers[i];
-        /*
-        input_layernorm.bias                  torch.Size([4096])
-        input_layernorm.weight                torch.Size([4096])
-        mlp.dense_4h_to_h.bias                torch.Size([4096])
-        mlp.dense_4h_to_h.weight              torch.Size([4096, 16384])
-        mlp.dense_h_to_4h.bias                torch.Size([16384])
-        mlp.dense_h_to_4h.weight              torch.Size([16384, 4096])
-        post_attention_layernorm.bias         torch.Size([4096])
-        post_attention_layernorm.weight       torch.Size([4096])
-        self_attention.dense.bias             torch.Size([4096])
-        self_attention.dense.weight           torch.Size([4096, 4096])
-        self_attention.k_layernorm.bias       torch.Size([64])
-        self_attention.k_layernorm.weight     torch.Size([64])
-        self_attention.q_layernorm.bias       torch.Size([64])
-        self_attention.q_layernorm.weight     torch.Size([64])
-        self_attention.query_key_value.bias   torch.Size([12288])
-        self_attention.query_key_value.weight torch.Size([12288, 4096])
-        self_attention.rotary_emb.inv_freq    torch.Size([16])
-        */
        layer.attn_norm   = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
        layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias",   i), {n_embd}, backend);
        layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
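On the sizes: wqkv is created as {n_embd, n_embd + 2*n_embd_gqa}, and since this model has no grouped-query sharing (n_head_kv == n_head, so n_embd_gqa == n_embd) that comes out to 3*n_embd = 12288 rows over n_embd = 4096 inputs, matching the query_key_value.weight shape in the comment block being deleted above. A quick arithmetic check:

    #include <assert.h>

    int main(void) {
        const int n_embd      = 4096;
        const int n_head      = 64;
        const int n_head_kv   = 64;                    // no GQA: same as n_head
        const int n_embd_head = n_embd / n_head;       // 64, matching the q/k layernorm sizes
        const int n_embd_gqa  = n_embd_head * n_head_kv;
        assert(n_embd + 2 * n_embd_gqa == 12288);      // query_key_value.weight rows
        return 0;
    }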
@@ -3744,6 +3725,20 @@ static struct ggml_cgraph * llm_build_starcoder(
    return gf;
}

+// Log a tensor's name and shape, e.g. "Shape of tmpq_0 is 64, 64, 7".
+static void log_tensor(
+    ggml_tensor * a
+) {
+    LLAMA_LOG_INFO("Shape of %s is ", a->name);
+    for (int i = 0; i < a->n_dims; ++i) {
+        LLAMA_LOG_INFO("%lld", (long long) a->ne[i]); // ne[] is int64_t, so plain %d would be wrong here
+        if (i < a->n_dims - 1) {
+            LLAMA_LOG_INFO(",");
+        }
+        LLAMA_LOG_INFO(" ");
+    }
+    LLAMA_LOG_INFO("\n");
+}
+
static struct ggml_cgraph * llm_build_adept(
    llama_context & lctx,
    const llama_token * tokens,
@@ -3760,7 +3755,7 @@ static struct ggml_cgraph * llm_build_adept(
    GGML_ASSERT(!!kv_self.ctx);

    const int64_t n_embd    = hparams.n_embd;
-    const int64_t n_layer   = hparams.n_layer;
+    const int64_t n_layer   = 1; // WIP: run a single layer while debugging
    const int64_t n_ctx     = hparams.n_ctx;
    const int64_t n_head_kv = hparams.n_head_kv;
    const int64_t n_head    = hparams.n_head;
@@ -3785,18 +3780,28 @@ static struct ggml_cgraph * llm_build_adept(
    struct ggml_tensor * inpL;
    if (tokens) {
        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);

        ggml_allocr_alloc(lctx.alloc, inp_tokens);
        if (!ggml_allocr_is_measure(lctx.alloc)) {
-            memcpy(inp_tokens->data, inp_tokens, N*ggml_element_size(inp_tokens));
+            memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens)); // fix: copy from `tokens`, not from the tensor itself
        }
-        ggml_set_name(inp_tokens, "inp_tokens");
+        LLAMA_LOG_INFO("Token ids:\n", __func__);
+        for (int i = 0; i < N; ++i) {
+            LLAMA_LOG_INFO(" %d ", tokens[i]);
+        }
+        ggml_set_name(inp_tokens, "inp_tokens");
+
+        LLAMA_LOG_INFO("\n", __func__);
        inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
+        /*
+        LLAMA_LOG_INFO("\ninpL:\n", __func__);
+        if (ggml_nelements(model.tok_embeddings) >= 5) {
+            for (int i = 0; i < 5; ++i) {
+                LLAMA_LOG_INFO(" %f ", ggml_get_f32_1d(model.tok_embeddings, i));
+            }
+            LLAMA_LOG_INFO("\n");
+        } else {
+            LLAMA_LOG_INFO("Not enough elements to print\n", __func__);
+        }
+        */
    } else {
        inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);
        ggml_allocr_alloc(lctx.alloc, inpL);
@@ -3804,7 +3809,6 @@ static struct ggml_cgraph * llm_build_adept(
            memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
        }
    }
-    // Log all of the token ids sequentially
    struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
    ggml_allocr_alloc(lctx.alloc, KQ_scale);
    if (!ggml_allocr_is_measure(lctx.alloc)) {
@@ -3813,63 +3817,159 @@ static struct ggml_cgraph * llm_build_adept(
    ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
+    //LLAMA_LOG_INFO("Entering n_layers loop\n", __func__);
    for (int il = 0; il < n_layer; ++il) {
-        struct ggml_tensor * attn_norm;
        offload_func_t offload_func = llama_nop;
+        // Input is (d_model, L)
+        // Attention
+        struct ggml_tensor * residual = inpL;
+        ggml_set_name(residual, format((char *) "layer_inputs_%d", il).c_str());
        {
            // input norming
-            attn_norm = ggml_norm(ctx0, inpL, hparams.f_norm_eps);
-            attn_norm = ggml_add(ctx0, ggml_mul(
-                ctx0, attn_norm, model.layers[il].attn_norm),
+            cur = ggml_norm(ctx0, inpL, hparams.f_norm_eps);
+            cur = ggml_add(ctx0, ggml_mul(
+                ctx0, cur, model.layers[il].attn_norm),
                model.layers[il].attn_norm_b);

-            // QKV + bias
-            cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm);
        }
+        ggml_set_name(cur, "cur");
        {
+            // QKV
+            log_tensor(cur);
+            cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+            // 3 * d_model, L
+            // or 2 * n_head_kv + n_embd_head, L
+            // + bias
+            ggml_format_name(cur, "qkv_preadd_%d", il);
            cur = ggml_add(ctx0, cur, model.layers[il].bqkv);

-            const size_t wsize = ggml_type_size(cur->type);
            // Apply Q, K layernorm
+            // Where is the Q/K/V? it's in order. Hopefully...
+            // So q has offset 0.
+            // And split into heads
+            // -> (d_h, n_head, L)
+            const size_t wsize = ggml_type_size(cur->type);
+            GGML_ASSERT(n_head_kv == n_head);
+            LLAMA_LOG_INFO("N: %d\n", N);
+            ggml_set_name(cur, format("qkv_%d", il).c_str());
+            log_tensor(cur);
+
+            // cur is (3 * d_head * n_head, N)
+            struct ggml_tensor * tmpqkv = ggml_view_4d(
+                ctx0, cur, n_embd_head, 3, n_head, N,
+                /* nb1 = */ wsize * n_embd_head,
+                /* nb2 = */ wsize * n_embd_head * 3,
+                /* nb3 = */ wsize * n_embd_head * 3 * n_head,
+                /* offset = */ 0
+            );
+            // get it to (d_h, n_head, L, 3)
+            struct ggml_tensor * tmpqkv_perm = ggml_cont(ctx0, ggml_permute(ctx0, tmpqkv, 0, 3, 1, 2));
+            ggml_format_name(tmpqkv_perm, "tmpqkv_perm_%d", il);
+            log_tensor(tmpqkv_perm);
            struct ggml_tensor * tmpq = ggml_cont(
-                ctx0, ggml_view_3d(
-                    ctx0, cur, n_embd_head, n_head, N,
-                    wsize * n_embd_head,
-                    wsize * n_embd_head * (n_head + 2 * n_head_kv),
-                    0
+                ctx0,
+                ggml_view_3d(
+                    ctx0, tmpqkv_perm, n_embd_head, n_head, N,
+                    /* nb1 = */ sizeof(float) * n_embd_head,
+                    /* nb2 = */ sizeof(float) * n_embd_head * n_head,
+                    /* offset = */ 0
                )
            );
            struct ggml_tensor * tmpk = ggml_cont(
-                ctx0, ggml_view_3d(
-                    ctx0, cur, n_embd_head, n_head, N,
-                    wsize * n_embd_head,
-                    wsize * n_embd_head * (n_head + 2 * n_head_kv),
-                    wsize * n_embd_head * n_head
+                ctx0,
+                ggml_view_3d(
+                    ctx0, tmpqkv_perm, n_embd_head, n_head, N,
+                    /* nb1 = */ sizeof(float) * n_embd_head,
+                    /* nb2 = */ sizeof(float) * n_embd_head * n_head,
+                    /* offset = */ sizeof(float) * n_embd_head * n_head * N
                )
            );
+            struct ggml_tensor * tmpv = ggml_cont(
+                ctx0,
+                ggml_view_3d(
+                    ctx0, tmpqkv_perm, n_embd_head, n_head, N,
+                    /* nb1 = */ sizeof(float) * n_embd_head,
+                    /* nb2 = */ sizeof(float) * n_embd_head * n_head,
+                    /* offset = */ sizeof(float) * n_embd_head * n_head * N * 2
+                )
+            );
+            ggml_set_name(tmpq, format("tmpq_%d", il).c_str());
+            tmpq = ggml_norm(ctx0, tmpq, hparams.f_norm_eps);
+            tmpq = ggml_mul(ctx0, tmpq, model.layers[il].attn_q_norm);
+            ggml_set_name(tmpq, format("preadd_%d", il).c_str());
+            tmpq = ggml_add(ctx0, tmpq, model.layers[il].attn_q_norm_b);

            tmpk = ggml_norm(ctx0, tmpk, hparams.f_norm_eps);
            tmpk = ggml_mul(ctx0, tmpk, model.layers[il].attn_k_norm);
            tmpk = ggml_add(ctx0, tmpk, model.layers[il].attn_k_norm_b);

-            tmpq = ggml_norm(ctx0, tmpq, hparams.f_norm_eps);
-            tmpq = ggml_mul(ctx0, tmpq, model.layers[il].attn_q_norm);
-            tmpq = ggml_add(ctx0, tmpq, model.layers[il].attn_q_norm_b);
+            ggml_set_name(tmpq, format("tmpq_%d", il).c_str());
+            ggml_set_name(tmpk, format("tmpk_%d", il).c_str());
+            log_tensor(tmpq);
+            log_tensor(tmpk);

-            struct ggml_tensor * Qcur = ggml_rope_custom_inplace(
-                ctx0, tmpq, n_past, n_embd_head, 2, 0, freq_base, freq_scale
-            );
-            struct ggml_tensor * Kcur = ggml_rope_custom_inplace(
-                ctx0, tmpk, n_past, n_embd_head, 2, 0, freq_base, freq_scale
-            );
+            const size_t n_rot = n_embd_head / 2;
+            struct ggml_tensor * qrot = ggml_cont(ctx0, ggml_view_3d(
+                ctx0, tmpq, n_rot, n_head, N,
+                /* nb1 = */ wsize * n_embd_head,
+                /* nb2 = */ wsize * n_embd_head * n_head,
+                /* offset = */ 0
+            ));
+            struct ggml_tensor * qpass = ggml_cont(ctx0, ggml_permute(ctx0, ggml_view_3d(
+                ctx0, tmpq, n_rot, n_head, N,
+                /* nb1 = */ wsize * n_rot,
+                /* nb2 = */ wsize * n_rot * n_head,
+                /* offset = */ (wsize * n_embd_head * n_head) / 2
+            ), 2, 1, 0, 3));
+            ggml_set_name(qrot, format("qrot_%d", il).c_str());
+            ggml_set_name(qpass, format("qpass_%d", il).c_str());
+            log_tensor(qrot);
+            log_tensor(qpass);

-            struct ggml_tensor * tmpv = ggml_view_3d(
-                ctx0, cur, n_embd_head, n_head_kv, N,
-                wsize * n_embd_head,
-                wsize * n_embd_head * (n_head + 2 * n_head_kv),
-                wsize * n_embd_head * (n_head + n_head_kv));
+            struct ggml_tensor * krot = ggml_cont(ctx0, ggml_view_3d(
+                ctx0, tmpk, n_rot, n_head, N,
+                /* nb1 = */ wsize * n_rot,
+                /* nb2 = */ wsize * n_rot * n_head,
+                /* offset = */ 0
+            ));
+            struct ggml_tensor * kpass = ggml_cont(ctx0,
+                ggml_permute(ctx0,
+                    ggml_view_3d(
+                        ctx0, tmpk, n_rot, n_head, N,
+                        /* nb1 = */ wsize * n_rot,
+                        /* nb2 = */ wsize * n_rot * n_head,
+                        /* offset = */ (wsize * n_embd_head * n_head) / 2
+                    ), 2, 1, 0, 3));
+            ggml_set_name(krot, format("krot_%d", il).c_str());
+            ggml_set_name(kpass, format("kpass_%d", il).c_str());
+            log_tensor(krot);
+            log_tensor(kpass);

+            struct ggml_tensor * qrotated = ggml_cont(ctx0, ggml_permute(ctx0,
+                ggml_rope_custom_inplace(
+                    ctx0, qrot, n_past, n_rot, 0, 0, freq_base, freq_scale
+                ),
+                2, 1, 0, 3
+            ));
+            struct ggml_tensor * krotated = ggml_cont(ctx0, ggml_permute(ctx0,
+                ggml_rope_custom_inplace(
+                    ctx0, krot, n_past, n_rot, 0, 0, freq_base, freq_scale
+                ),
+                2, 1, 0, 3
+            ));
+            ggml_set_name(qrotated, format("qrotated_%d", il).c_str());
+            ggml_set_name(krotated, format("krotated_%d", il).c_str());
+            log_tensor(qrotated);
+            log_tensor(krotated);
+            struct ggml_tensor * Qcur = ggml_cont(ctx0,
+                ggml_permute(ctx0,
+                    ggml_concat(ctx0, qrotated, qpass),
+                    2, 1, 0, 3));
+            struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_permute(ctx0, ggml_concat(ctx0, krotated, kpass), 2, 1, 0, 3));
+            ggml_set_name(Qcur, format("Qcur_%d", il).c_str());
+            ggml_set_name(Kcur, format("Kcur_%d", il).c_str());
+            log_tensor(Qcur);
+            log_tensor(Kcur);

            {
                // Set kv cache elements?
                struct ggml_tensor * Vcur = ggml_transpose(
                    ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, N)
                );
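What the view_4d + permute + view_3d sequence above is decoding: the fused QKV output interleaves per head, so the innermost axis is [q_h, k_h, v_h] blocks of n_embd_head each, head by head, rather than concatenating all of Q, then K, then V (which is what the replaced view_3d offsets assumed). The same split on a plain contiguous buffer, as a sketch:

    #include <stddef.h>

    // Split one token's fused QKV row, laid out as [q_h0 k_h0 v_h0 q_h1 k_h1 v_h1 ...],
    // into separate q, k, v buffers of n_head * d_head floats each.
    static void split_qkv_interleaved(const float * qkv, float * q, float * k, float * v,
                                      int n_head, int d_head) {
        for (int h = 0; h < n_head; h++) {
            const float * blk = qkv + (size_t) h * 3 * d_head;
            for (int i = 0; i < d_head; i++) {
                q[(size_t) h * d_head + i] = blk[i];
                k[(size_t) h * d_head + i] = blk[d_head + i];
                v[(size_t) h * d_head + i] = blk[2 * d_head + i];
            }
        }
    }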
@@ -3886,11 +3986,11 @@ static struct ggml_cgraph * llm_build_adept(
            ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
            ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
        }
-        //LLAMA_LOG_INFO("3889\n", __func__);
        struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
        ggml_set_name(Q, "Q");
+        log_tensor(Q);

        // index into kv cache?
        // view kv cache?
        struct ggml_tensor * K =
            ggml_view_3d(ctx0, kv_self.k,
                n_embd_head, n_past + N, n_head_kv,
@@ -3907,12 +4007,11 @@ static struct ggml_cgraph * llm_build_adept(
        ggml_set_name(KQ_scaled, "KQ_scaled");

        struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
-        ggml_set_name(KQ_masked, "KQ_soft_max");
+        ggml_set_name(KQ_masked, "KQ_mask"); // was mis-named "KQ_soft_max"

        struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
-        ggml_set_name(KQ_soft_max, "KQ_soft_max");
+        ggml_set_name(KQ_soft_max, format("KQ_soft_max_%d", il).c_str());

-        //LLAMA_LOG_INFO("3915\n", __func__);
        struct ggml_tensor * V =
            ggml_view_3d(ctx0, kv_self.v,
                n_past + N, n_embd_head, n_head_kv,
@@ -3932,15 +4031,19 @@ static struct ggml_cgraph * llm_build_adept(

            cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
            ggml_set_name(cur, "result_wo");
-            //LLAMA_LOG_INFO("EoWo\n", __func__);
        }
-        struct ggml_tensor * attn_out = cur;
+        cur = ggml_add(ctx0, residual, cur);
+        residual = cur;
+        ggml_set_name(residual, "residual");
        {
-            struct ggml_tensor * inpFF = attn_norm;
+            struct ggml_tensor * inpFF = cur;
            // Norm
            {
                cur = ggml_norm(ctx0, inpFF, hparams.f_norm_eps);
-                cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ffn_norm), model.layers[il].ffn_norm_b);
+                cur = ggml_add(ctx0,
+                    ggml_mul(ctx0, cur, model.layers[il].ffn_norm),
+                    model.layers[il].ffn_norm_b
+                );
            }
            cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w3, cur), model.layers[il].b3);
            // Squared ReLU
@@ -3948,31 +4051,22 @@ static struct ggml_cgraph * llm_build_adept(
            cur = ggml_mul(ctx0, cur, cur);
            cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w2, cur), model.layers[il].b2);
        }
-        cur = ggml_add(ctx0, cur, attn_out);
+        cur = ggml_add(ctx0, cur, residual);
        ggml_set_name(cur, "inpFF_+_attn_out");
        inpL = cur;
-        //LLAMA_LOG_INFO("EoL\n", __func__);
    }
-    //LLAMA_LOG_INFO("Exited from n_layers loop\n", __func__);
    cur = inpL;
    {
-        //LLAMA_LOG_INFO("norm\n", __func__);
        cur = ggml_norm(ctx0, cur, hparams.f_norm_eps);
-        //LLAMA_LOG_INFO("ggml_norm\n", __func__);
        cur = ggml_add(ctx0,
            ggml_mul(ctx0, cur, model.output_norm),
            model.output_norm_b);
-        //LLAMA_LOG_INFO("result_norm\n", __func__);
        ggml_set_name(cur, "result_norm");
    }
-    //LLAMA_LOG_INFO("matmul\n", __func__);
    cur = ggml_mul_mat(ctx0, model.output, cur);
    ggml_set_name(cur, "result_output");
-    //LLAMA_LOG_INFO("bf expand\n", __func__);
    ggml_build_forward_expand(gf, cur);
-    //LLAMA_LOG_INFO("Freeing ctx0\n", __func__);
    ggml_free(ctx0);
-    //LLAMA_LOG_INFO("Exiting fun\n", __func__);
    return gf;
}
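Finally, the FFN: the "// Squared ReLU" comment plus the ggml_mul(ctx0, cur, cur) line point at a squared-ReLU activation, so the block computes w2 * relu(w1 x + b1)^2 + b2. The element-wise form on its own, as a sketch:

    // Squared ReLU, applied element-wise between the FFN's two matmuls:
    // zero for x <= 0, x*x otherwise.
    static float relu_squared(float x) {
        float r = x > 0.0f ? x : 0.0f;
        return r * r;
    }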