wip: correct tensors up to RoPE

This commit is contained in:
Phillip Kravtsov 2023-09-25 23:49:35 -07:00
parent 7cdc3eaa76
commit 4bcf412d86
3 changed files with 293 additions and 82 deletions

View file

@ -87,12 +87,16 @@ def main(args_in: list[str] | None = None) -> None:
gguf_writer.add_rope_dimension_count(hidden_size // head_count) gguf_writer.add_rope_dimension_count(hidden_size // head_count)
gguf_writer.add_head_count(head_count) gguf_writer.add_head_count(head_count)
gguf_writer.add_head_count_kv(head_count_kv) gguf_writer.add_head_count_kv(head_count_kv)
gguf_writer.add_rope_freq_base(hparams['rotary_emb_base'])
gguf_writer.add_layer_norm_eps(hparams['layernorm_epsilon'])
if True: if True:
tokens, scores, toktypes = handle_tokenizer(dir_model) tokens, scores, toktypes = handle_tokenizer(dir_model)
gguf_writer.add_tokenizer_model('llama') gguf_writer.add_tokenizer_model('llama')
gguf_writer.add_token_list(tokens) gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores) gguf_writer.add_token_scores(scores)
gguf_writer.add_token_types(toktypes) gguf_writer.add_token_types(toktypes)
gguf_writer.add_bos_token_id(71013)
gguf_writer.add_eos_token_id(71013)
tensor_map = gguf.get_tensor_name_map(arch, block_count) tensor_map = gguf.get_tensor_name_map(arch, block_count)
print(tensor_map) print(tensor_map)
tensors = {} tensors = {}
@ -105,15 +109,17 @@ def main(args_in: list[str] | None = None) -> None:
print(name) print(name)
# we don't need these # we don't need these
if name.endswith(".self_attention.rotary_emb.inv_freq"):
if name.endswith(".self_attention.rotary_emb.inv_freq"):
continue continue
old_dtype = data.dtype old_dtype = data.dtype
if 'layernorm.weight' in name: """
if 'layernorm.weight' in name or 'word_embeddings.weight' in name:
data = data.to(torch.float32) data = data.to(torch.float32)
else: else:
if data.dtype != torch.float16 and data.dtype != torch.float32: if data.dtype != torch.float16 and data.dtype != torch.float32:
data = data.to(torch.float16) data = data.to(torch.float32)
"""
data = data.to(torch.float32)
# check for nans # check for nans
if torch.isnan(data).any(): if torch.isnan(data).any():
print("WARNING: tensor '" + name + "' contains NaNs") print("WARNING: tensor '" + name + "' contains NaNs")

111
ggml.c
View file

@ -4290,6 +4290,65 @@ void ggml_print_objects(const struct ggml_context * ctx) {
GGML_PRINT("%s: --- end ---\n", __func__); GGML_PRINT("%s: --- end ---\n", __func__);
} }
static void ggml_print_tensor(const struct ggml_tensor * tensor) {
GGML_PRINT("Tensor (null): %s | rank %d | shape (", ggml_type_name(tensor->type), tensor->n_dims);
for (int i=0; i<tensor->n_dims; ++i) {
GGML_PRINT("%lld ", tensor->ne[i]);
}
GGML_PRINT(") | strides (");
for (int i=0; i<tensor->n_dims; ++i) {
GGML_PRINT("%lld ", tensor->nb[i]);
}
GGML_PRINT(")\n");
}
static void ggml_print_tensor_values(const struct ggml_tensor * tensor, int starts[], int dim, int nelts) {
GGML_ASSERT(tensor->type == GGML_TYPE_F32);
GGML_PRINT("printing values for %s[", tensor->name);
for (int i=0; i<tensor->n_dims; ++i) {
if (i!=dim) {
GGML_PRINT("%d", starts[i]);
} else {
if (starts[i] > 0) {
GGML_PRINT("%d:%d", starts[i], starts[i]+nelts);
} else {
GGML_PRINT(":%d", starts[i]+nelts);
}
}
if (i<tensor->n_dims-1) {
GGML_PRINT(",");
}
}
GGML_PRINT("]\n");
float *dataPtr = (float *) tensor->data;
// Compute the offset into data for starts
int offset = 0;
for (int j = 0; j < tensor->n_dims; j++) {
offset += (starts[j] * tensor->nb[j]) / sizeof(float); // Assuming nb[j] is in bytes, divide by sizeof(float) to get float offset.
}
dataPtr += offset;
for (int i = 0; i < nelts; i++) {
GGML_PRINT("%f ", *dataPtr);
dataPtr += tensor->nb[dim] / sizeof(float); // Increment by strides for the given dimension.
}
GGML_PRINT("\n");
/*
char * ptr = (char *)tensor->data;
for (int j=0; j<tensor->n_dims;j++) {
ptr += tensor->nb[j]*starts[j];
}
for (int i=0; i<nelts; i++) {
GGML_PRINT("%f ", (*((float *) ptr)));
ptr += tensor->nb[dim];
}
GGML_PRINT("\n");
*/
}
int64_t ggml_nelements(const struct ggml_tensor * tensor) { int64_t ggml_nelements(const struct ggml_tensor * tensor) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
@ -6162,6 +6221,7 @@ struct ggml_tensor * ggml_mul_mat(
const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] }; const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne); struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne);
GGML_PRINT("ggml_mul_mat result shape : (%lld, %lld, %lld, %lld)\n", ne[0], ne[1], ne[2], ne[3]);
result->op = GGML_OP_MUL_MAT; result->op = GGML_OP_MUL_MAT;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@ -8823,6 +8883,15 @@ static void ggml_compute_forward_add_f32(
} }
} }
} }
if ((strncmp(src0->name, "preadd", 6) == 0
|| strncmp(src0->name, "qkv_preadd", 10) == 0)
&& ith == 0) {
// print name
printf("\nadd outputs for %s\n", src0->name);
ggml_print_tensor(dst);
int starts[] = {0, 3, 0};
ggml_print_tensor_values(dst, starts, 0, 10);
}
} }
static void ggml_compute_forward_add_f16_f32( static void ggml_compute_forward_add_f16_f32(
@ -10804,6 +10873,18 @@ static void ggml_compute_forward_norm_f32(
} }
GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src0->nb[0] == sizeof(float));
// If the name starts with "layer_inputs", and we are on thread 0, print the tensor
if ((strncmp(src0->name, "layer_inputs", 12) == 0
|| strncmp(src0->name, "tmpq", 4) == 0)
&& params->ith == 0) {
GGML_PRINT("\nlayernorm inputs for %s\n", src0->name);
ggml_print_tensor(src0);
int starts[] = {0, 1, 0};
ggml_print_tensor_values(src0, starts, 0, 10);
for (int i=64; i<74; ++i) {
GGML_PRINT("%f ", ggml_get_f32_1d(src0, i));
}
}
const int ith = params->ith; const int ith = params->ith;
const int nth = params->nth; const int nth = params->nth;
@ -11227,8 +11308,25 @@ static void ggml_compute_forward_mul_mat(
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
int64_t t0 = ggml_perf_time_us(); int64_t t0 = ggml_perf_time_us();
UNUSED(t0); UNUSED(t0);
if (strncmp(src1->name, "KQ_soft_max", 11) == 0 && params->ith == 0
&& src1->ne[0] == src1->ne[1]) {
GGML_PRINT("\n KQ_softmax at mul mat time for %s\n", src1->name);
ggml_print_tensor(src1);
if (ggml_nelements(src1) >= 14) {
for (int i=0; i < src1->ne[0] * src1->ne[1]; ++i) {
if (i % src1->ne[1] == 0) {
GGML_PRINT("\n");
}
GGML_PRINT(" %f ", ((float *)src1->data)[i]);
}
GGML_PRINT("\n");
} else {
GGML_PRINT("Not enough elements to print\n");
}
}
GGML_TENSOR_BINARY_OP_LOCALS; GGML_TENSOR_BINARY_OP_LOCALS;
// If on thread 0, src1 starts with KQ_softmax, print
const int ith = params->ith; const int ith = params->ith;
const int nth = params->nth; const int nth = params->nth;
@ -12628,6 +12726,12 @@ static void ggml_compute_forward_rope_f32(
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return; return;
} }
if (strncmp(src0->name, "qrot", 4) == 0 && params->ith == 0) {
GGML_PRINT("\nValues at RoPE time for %s\n", src0->name);
ggml_print_tensor(src0);
int starts[] = {0, 0, 1, 0};
ggml_print_tensor_values(src0, starts, 1, 10);
}
float freq_base; float freq_base;
float freq_scale; float freq_scale;
@ -12756,6 +12860,13 @@ static void ggml_compute_forward_rope_f32(
} }
} }
} }
if (strncmp(src0->name, "qrot", 4) == 0 && params->ith == 0) {
GGML_PRINT("\n dest at RoPE time for %s\n", src0->name);
// print shape and strides
int starts[4] = {0,0,0,0};
ggml_print_tensor(dst);
ggml_print_tensor_values(dst, starts, 0, 10);
}
} }
static void ggml_compute_forward_rope_f16( static void ggml_compute_forward_rope_f16(

250
llama.cpp
View file

@ -2337,25 +2337,6 @@ static void llm_load_tensors(
const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT;
auto & layer = model.layers[i]; auto & layer = model.layers[i];
/*
input_layernorm.bias torch.Size([4096])
input_layernorm.weight torch.Size([4096])
mlp.dense_4h_to_h.bias torch.Size([4096])
mlp.dense_4h_to_h.weight torch.Size([4096, 16384])
mlp.dense_h_to_4h.bias torch.Size([16384])
mlp.dense_h_to_4h.weight torch.Size([16384, 4096])
post_attention_layernorm.bias torch.Size([4096])
post_attention_layernorm.weight torch.Size([4096])
self_attention.dense.bias torch.Size([4096])
self_attention.dense.weight torch.Size([4096, 4096])
self_attention.k_layernorm.bias torch.Size([64])
self_attention.k_layernorm.weight torch.Size([64])
self_attention.q_layernorm.bias torch.Size([64])
self_attention.q_layernorm.weight torch.Size([64])
self_attention.query_key_value.bias torch.Size([12288])
self_attention.query_key_value.weight torch.Size([12288, 4096])
self_attention.rotary_emb.inv_freq torch.Size([16])
*/
layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend); layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend);
layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split); layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
@ -3744,6 +3725,20 @@ static struct ggml_cgraph * llm_build_starcoder(
return gf; return gf;
} }
static void log_tensor(
ggml_tensor * a
) {
LLAMA_LOG_INFO("Shape of %s is ", a->name);
for (int i = 0; i < a->n_dims; ++i) {
LLAMA_LOG_INFO("%d", a->ne[i]);
if (i < a->n_dims - 1) {
LLAMA_LOG_INFO(",");
}
LLAMA_LOG_INFO(" ");
}
LLAMA_LOG_INFO("\n");
}
static struct ggml_cgraph * llm_build_adept( static struct ggml_cgraph * llm_build_adept(
llama_context & lctx, llama_context & lctx,
const llama_token * tokens, const llama_token * tokens,
@ -3760,7 +3755,7 @@ static struct ggml_cgraph * llm_build_adept(
GGML_ASSERT(!!kv_self.ctx); GGML_ASSERT(!!kv_self.ctx);
const int64_t n_embd = hparams.n_embd; const int64_t n_embd = hparams.n_embd;
const int64_t n_layer = hparams.n_layer; const int64_t n_layer = 1;
const int64_t n_ctx = hparams.n_ctx; const int64_t n_ctx = hparams.n_ctx;
const int64_t n_head_kv = hparams.n_head_kv; const int64_t n_head_kv = hparams.n_head_kv;
const int64_t n_head = hparams.n_head; const int64_t n_head = hparams.n_head;
@ -3785,18 +3780,28 @@ static struct ggml_cgraph * llm_build_adept(
struct ggml_tensor * inpL; struct ggml_tensor * inpL;
if (tokens) { if (tokens) {
struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
ggml_allocr_alloc(lctx.alloc, inp_tokens); ggml_allocr_alloc(lctx.alloc, inp_tokens);
if (!ggml_allocr_is_measure(lctx.alloc)) { if (!ggml_allocr_is_measure(lctx.alloc)) {
memcpy(inp_tokens->data, inp_tokens, N*ggml_element_size(inp_tokens)); memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
} }
ggml_set_name(inp_tokens, "inp_tokens");
LLAMA_LOG_INFO("Token ids:\n", __func__); LLAMA_LOG_INFO("Token ids:\n", __func__);
for (int i = 0; i < N; ++i) { for (int i = 0; i < N; ++i) {
LLAMA_LOG_INFO(" %d ", tokens[i]); LLAMA_LOG_INFO(" %d ", tokens[i]);
} }
ggml_set_name(inp_tokens, "inp_tokens"); LLAMA_LOG_INFO("\n", __func__);
inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
/*
LLAMA_LOG_INFO("\ninpL:\n", __func__);
if (ggml_nelements(model.tok_embeddings) >= 5) {
for (int i=0; i < 5; ++i) {
LLAMA_LOG_INFO(" %f ", ggml_get_f32_1d(model.tok_embeddings, i));
}
LLAMA_LOG_INFO("\n");
} else {
LLAMA_LOG_INFO("Not enough elements to print\n", __func__);
}
*/
} else { } else {
inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N); inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);
ggml_allocr_alloc(lctx.alloc, inpL); ggml_allocr_alloc(lctx.alloc, inpL);
@ -3804,7 +3809,6 @@ static struct ggml_cgraph * llm_build_adept(
memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL)); memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
} }
} }
// Log all of the token ids sequentially
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
ggml_allocr_alloc(lctx.alloc, KQ_scale); ggml_allocr_alloc(lctx.alloc, KQ_scale);
if (!ggml_allocr_is_measure(lctx.alloc)) { if (!ggml_allocr_is_measure(lctx.alloc)) {
@ -3813,63 +3817,159 @@ static struct ggml_cgraph * llm_build_adept(
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
//LLAMA_LOG_INFO("Entering n_layers loop\n", __func__); //LLAMA_LOG_INFO("Entering n_layers loop\n", __func__);
for (int il=0; il < n_layer; ++il) { for (int il=0; il < n_layer; ++il) {
struct ggml_tensor * attn_norm;
offload_func_t offload_func = llama_nop; offload_func_t offload_func = llama_nop;
// Input is (d_model, L)
// Attention // Attention
struct ggml_tensor * residual = inpL;
ggml_set_name(residual, format((char*)"layer_inputs_%d", il).c_str());
{ {
// input norming // input norming
attn_norm = ggml_norm(ctx0, inpL, hparams.f_norm_eps); cur = ggml_norm(ctx0, inpL, hparams.f_norm_eps);
attn_norm = ggml_add(ctx0, ggml_mul( cur = ggml_add(ctx0, ggml_mul(
ctx0, attn_norm, model.layers[il].attn_norm), ctx0, cur, model.layers[il].attn_norm),
model.layers[il].attn_norm_b); model.layers[il].attn_norm_b);
}
// QKV + bias ggml_set_name(cur, "cur");
cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm); {
// QKV
log_tensor(cur);
cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
// 3 * d_model, L
// or 2 * n_head_kv + n_embd_head, L
// + bias
ggml_format_name(cur, "qkv_preadd_%d", il);
cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
const size_t wsize = ggml_type_size(cur->type);
// Apply Q, K layernorm // Apply Q, K layernorm
// Where is the Q/K/V? it's in order. Hopefully...
// So q has offset 0.
// And split into heads
// -> (d_h, n_head, L)
const size_t wsize = ggml_type_size(cur->type);
GGML_ASSERT(n_head_kv == n_head);
LLAMA_LOG_INFO("N: %d\n", N);
ggml_set_name(cur, format("qkv_%d", il).c_str());
log_tensor(cur);
// cur is (3 * d_head * n_head, N)
struct ggml_tensor * tmpqkv = ggml_view_4d(
ctx0, cur, n_embd_head, 3, n_head, N,
/* nb1 = */ wsize * n_embd_head,
/* nb2 = */ wsize * n_embd_head * 3,
/* nb3 = */ wsize * n_embd_head * 3 * n_head,
/* offset = */ 0
);
// get it to (d_h, n_head, L, 3)
struct ggml_tensor * tmpqkv_perm = ggml_cont(ctx0, ggml_permute(ctx0, tmpqkv, 0, 3, 1, 2));
ggml_format_name(tmpqkv_perm, "tmpqkv_perm_%d", il);
log_tensor(tmpqkv_perm);
struct ggml_tensor * tmpq = ggml_cont( struct ggml_tensor * tmpq = ggml_cont(
ctx0, ggml_view_3d( ctx0,
ctx0, cur, n_embd_head, n_head, N, ggml_view_3d(
wsize * n_embd_head, ctx0, tmpqkv_perm, n_embd_head, n_head, N,
wsize * n_embd_head * (n_head + 2 * n_head_kv), /* nb1 = */ sizeof(float) * n_embd_head,
0 /* nb2 = */ sizeof(float) * n_embd_head * n_head,
/* offset = */ 0
) )
); );
struct ggml_tensor * tmpk = ggml_cont( struct ggml_tensor * tmpk = ggml_cont(
ctx0, ggml_view_3d( ctx0,
ctx0, cur, n_embd_head, n_head, N, ggml_view_3d(
wsize * n_embd_head, ctx0, tmpqkv_perm, n_embd_head, n_head, N,
wsize * n_embd_head * (n_head + 2 * n_head_kv), /* nb1 = */ sizeof(float) * n_embd_head,
wsize * n_embd_head * n_head /* nb2 = */ sizeof(float) * n_embd_head * n_head,
/* offset = */ sizeof(float) * n_embd_head * n_head * N
) )
); );
struct ggml_tensor * tmpv = ggml_cont(
ctx0,
ggml_view_3d(
ctx0, tmpqkv_perm, n_embd_head, n_head, N,
/* nb1 = */ sizeof(float) * n_embd_head,
/* nb2 = */ sizeof(float) * n_embd_head * n_head,
/* offset = */ sizeof(float) * n_embd_head * n_head * N * 2
)
);
ggml_set_name(tmpq, format("tmpq_%d", il).c_str());
tmpq = ggml_norm(ctx0, tmpq, hparams.f_norm_eps);
tmpq = ggml_mul(ctx0, tmpq, model.layers[il].attn_q_norm);
ggml_set_name(tmpq, format("preadd_%d", il).c_str());
tmpq = ggml_add(ctx0, tmpq, model.layers[il].attn_q_norm_b);
tmpk = ggml_norm(ctx0, tmpk, hparams.f_norm_eps); tmpk = ggml_norm(ctx0, tmpk, hparams.f_norm_eps);
tmpk = ggml_mul(ctx0, tmpk, model.layers[il].attn_k_norm); tmpk = ggml_mul(ctx0, tmpk, model.layers[il].attn_k_norm);
tmpk = ggml_add(ctx0, tmpk, model.layers[il].attn_k_norm_b); tmpk = ggml_add(ctx0, tmpk, model.layers[il].attn_k_norm_b);
ggml_set_name(tmpq, format("tmpq_%d", il).c_str());
tmpq = ggml_norm(ctx0, tmpq, hparams.f_norm_eps); ggml_set_name(tmpk, format("tmpk_%d", il).c_str());
tmpq = ggml_mul(ctx0, tmpq, model.layers[il].attn_q_norm); log_tensor(tmpq);
tmpq = ggml_add(ctx0, tmpq, model.layers[il].attn_q_norm_b); log_tensor(tmpk);
struct ggml_tensor * Qcur = ggml_rope_custom_inplace( const size_t n_rot = n_embd_head / 2;
ctx0, tmpq, n_past, n_embd_head, 2, 0, freq_base, freq_scale struct ggml_tensor * qrot = ggml_cont(ctx0, ggml_view_3d(
); ctx0, tmpq, n_rot, n_head, N,
struct ggml_tensor * Kcur = ggml_rope_custom_inplace( /* nb1 = */ wsize * n_embd_head,
ctx0, tmpk, n_past, n_embd_head, 2, 0, freq_base, freq_scale /* nb2 = */ wsize * n_embd_head * n_head,
); /* offset = */ 0
));
struct ggml_tensor * qpass = ggml_cont(ctx0, ggml_permute(ctx0, ggml_view_3d(
ctx0, tmpq, n_rot, n_head, N,
/* nb1 = */ wsize * n_rot,
/* nb2 = */ wsize * n_rot * n_head,
/* offset = */ (wsize * n_embd_head * n_head) / 2
), 2, 1, 0, 3));
ggml_set_name(qrot, format("qrot_%d", il).c_str());
ggml_set_name(qpass, format("qpass_%d", il).c_str());
log_tensor(qrot);
log_tensor(qpass);
struct ggml_tensor * tmpv = ggml_view_3d( struct ggml_tensor * krot = ggml_cont(ctx0, ggml_view_3d(
ctx0, cur, n_embd_head, n_head_kv, N, ctx0, tmpk, n_rot, n_head, N,
wsize * n_embd_head, /* nb1 = */ wsize * n_rot,
wsize * n_embd_head * (n_head + 2 * n_head_kv), /* nb2 = */ wsize * n_rot * n_head,
wsize * n_embd_head * (n_head + n_head_kv)); /* offset = */ 0
));
struct ggml_tensor * kpass = ggml_cont(ctx0,
ggml_permute(ctx0,
ggml_view_3d(
ctx0, tmpk, n_rot, n_head, N,
/* nb1 = */ wsize * n_rot,
/* nb2 = */ wsize * n_rot * n_head,
/* offset = */ (wsize * n_embd_head * n_head) / 2
), 2, 1, 0, 3));
ggml_set_name(krot, format("krot_%d", il).c_str());
ggml_set_name(kpass, format("kpass_%d", il).c_str());
log_tensor(krot);
log_tensor(kpass);
struct ggml_tensor * qrotated = ggml_cont(ctx0, ggml_permute(ctx0,
ggml_rope_custom_inplace(
ctx0, qrot, n_past, n_rot, 0, 0, freq_base, freq_scale
),
2, 1, 0, 3
));
struct ggml_tensor * krotated = ggml_cont(ctx0, ggml_permute(ctx0,
ggml_rope_custom_inplace(
ctx0, krot, n_past, n_rot, 0, 0, freq_base, freq_scale
),
2, 1, 0, 3
));
ggml_set_name(qrotated, format("qrotated_%d", il).c_str());
ggml_set_name(krotated, format("krotated_%d", il).c_str());
log_tensor(qrotated);
log_tensor(krotated);
struct ggml_tensor * Qcur = ggml_cont(ctx0,
ggml_permute(ctx0,
ggml_concat(ctx0, qrotated, qpass),
2, 1, 0, 3));
struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_permute(ctx0, ggml_concat(ctx0, krotated, kpass), 2, 1, 0, 3));
ggml_set_name(Qcur, format("Qcur_%d", il).c_str());
ggml_set_name(Kcur, format("Kcur_%d", il).c_str());
log_tensor(Qcur);
log_tensor(Kcur);
{ {
// Set kv cache elements?
struct ggml_tensor * Vcur = ggml_transpose( struct ggml_tensor * Vcur = ggml_transpose(
ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, N) ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, N)
); );
@ -3886,11 +3986,11 @@ static struct ggml_cgraph * llm_build_adept(
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
} }
//LLAMA_LOG_INFO("3889\n", __func__);
struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
ggml_set_name(Q, "Q"); ggml_set_name(Q, "Q");
log_tensor(Q);
// index into kv cache? // view kv cache?
struct ggml_tensor * K = struct ggml_tensor * K =
ggml_view_3d(ctx0, kv_self.k, ggml_view_3d(ctx0, kv_self.k,
n_embd_head, n_past + N, n_head_kv, n_embd_head, n_past + N, n_head_kv,
@ -3907,12 +4007,11 @@ static struct ggml_cgraph * llm_build_adept(
ggml_set_name(KQ_scaled, "KQ_scaled"); ggml_set_name(KQ_scaled, "KQ_scaled");
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
ggml_set_name(KQ_masked, "KQ_soft_max"); ggml_set_name(KQ_masked, "KQ_mask");
struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
ggml_set_name(KQ_soft_max, "KQ_soft_max"); ggml_set_name(KQ_soft_max, format("KQ_soft_max_%d", il).c_str());
//LLAMA_LOG_INFO("3915\n", __func__);
struct ggml_tensor * V = struct ggml_tensor * V =
ggml_view_3d(ctx0, kv_self.v, ggml_view_3d(ctx0, kv_self.v,
n_past + N, n_embd_head, n_head_kv, n_past + N, n_embd_head, n_head_kv,
@ -3932,15 +4031,19 @@ static struct ggml_cgraph * llm_build_adept(
cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur); cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
ggml_set_name(cur, "result_wo"); ggml_set_name(cur, "result_wo");
//LLAMA_LOG_INFO("EoWo\n", __func__);
} }
struct ggml_tensor * attn_out = cur; cur = ggml_add(ctx0, residual, cur);
residual = cur;
ggml_set_name(residual, "residual");
{ {
struct ggml_tensor * inpFF = attn_norm; struct ggml_tensor * inpFF = cur;
// Norm // Norm
{ {
cur = ggml_norm(ctx0, inpFF, hparams.f_norm_eps); cur = ggml_norm(ctx0, inpFF, hparams.f_norm_eps);
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ffn_norm), model.layers[il].ffn_norm_b); cur = ggml_add(ctx0,
ggml_mul(ctx0, cur, model.layers[il].ffn_norm),
model.layers[il].ffn_norm_b
);
} }
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w3, cur), model.layers[il].b3); cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w3, cur), model.layers[il].b3);
// Squared ReLU // Squared ReLU
@ -3948,31 +4051,22 @@ static struct ggml_cgraph * llm_build_adept(
cur = ggml_mul(ctx0, cur, cur); cur = ggml_mul(ctx0, cur, cur);
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w2, cur), model.layers[il].b2); cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w2, cur), model.layers[il].b2);
} }
cur = ggml_add(ctx0, cur, attn_out); cur = ggml_add(ctx0, cur, residual);
ggml_set_name(cur, "inpFF_+_attn_out"); ggml_set_name(cur, "inpFF_+_attn_out");
inpL = cur; inpL = cur;
//LLAMA_LOG_INFO("EoL\n", __func__);
} }
//LLAMA_LOG_INFO("Exited from n_layers loop\n", __func__);
cur = inpL; cur = inpL;
{ {
//LLAMA_LOG_INFO("norm\n", __func__);
cur = ggml_norm(ctx0, cur, hparams.f_norm_eps); cur = ggml_norm(ctx0, cur, hparams.f_norm_eps);
//LLAMA_LOG_INFO("ggml_norm\n", __func__);
cur = ggml_add(ctx0, cur = ggml_add(ctx0,
ggml_mul(ctx0, cur, model.output_norm), ggml_mul(ctx0, cur, model.output_norm),
model.output_norm_b); model.output_norm_b);
//LLAMA_LOG_INFO("result_norm\n", __func__);
ggml_set_name(cur, "result_norm"); ggml_set_name(cur, "result_norm");
} }
//LLAMA_LOG_INFO("matmul\n", __func__);
cur = ggml_mul_mat(ctx0, model.output, cur); cur = ggml_mul_mat(ctx0, model.output, cur);
ggml_set_name(cur, "result_output"); ggml_set_name(cur, "result_output");
//LLAMA_LOG_INFO("bf expand\n", __func__);
ggml_build_forward_expand(gf, cur); ggml_build_forward_expand(gf, cur);
//LLAMA_LOG_INFO("Freeing ctx0\n", __func__);
ggml_free(ctx0); ggml_free(ctx0);
//LLAMA_LOG_INFO("Exiting fun\n", __func__);
return gf; return gf;
} }