diff --git a/convert-adept-st-to-gguf.py b/convert-persimmon-st-to-gguf.py similarity index 78% rename from convert-adept-st-to-gguf.py rename to convert-persimmon-st-to-gguf.py index 1a6eda8a1..ee0d2b1d8 100644 --- a/convert-adept-st-to-gguf.py +++ b/convert-persimmon-st-to-gguf.py @@ -19,7 +19,7 @@ def file_is_safetensors(path: Path) -> bool: return False return struct.unpack(' None: - parser = argparse.ArgumentParser(description="Convert an Adept model (e.g. Persimmon 8b) to a GGML compatible file") + parser = argparse.ArgumentParser(description="Convert a Persimmon model from Adept (e.g. Persimmon 8b chat) to a GGML compatible file") parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model") - parser.add_argument("--outtype", choices=["f32"], help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)") + parser.add_argument("--outtype", choices=["f32"], help="currently only support fp32") parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input") - parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)") + parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.safetensors)") parser.add_argument("--vocabtype", choices=["spm", "bpe"], help="vocab format (default: spm)", default="spm") args = parser.parse_args(args_in) assert file_is_safetensors(args.model), 'Error: model file is not a SafeTensors file' - model = lazy_load_safetensors_file(open(args.model, 'rb'), args.model) dir_model = args.model.parent with open(dir_model / 'config.json', 'r') as f: hparams = json.load(f) pprint(hparams) - arch = gguf.MODEL_ARCH.ADEPT + arch = gguf.MODEL_ARCH.PERSIMMON gguf_writer = gguf.GGUFWriter(args.outfile, gguf.MODEL_ARCH_NAMES[arch]) block_count = hparams['num_layers'] @@ -90,7 +89,7 @@ def main(args_in: list[str] | None = None) -> None: gguf_writer.add_rope_freq_base(hparams['rotary_emb_base']) gguf_writer.add_layer_norm_eps(hparams['layernorm_epsilon']) if True: - tokens, scores, toktypes = handle_tokenizer(dir_model) + tokens, scores, toktypes = get_tokenizer_info(dir_model) gguf_writer.add_tokenizer_model('llama') gguf_writer.add_token_list(tokens) gguf_writer.add_token_scores(scores) @@ -103,32 +102,13 @@ def main(args_in: list[str] | None = None) -> None: with safe_open(args.model, framework="pt") as f: for k in f.keys(): tensors[k] = f.get_tensor(k) - print(len(tensors.keys())) for name in tensors.keys(): data = tensors[name] - print(name) - - # we don't need these if name.endswith(".self_attention.rotary_emb.inv_freq"): continue old_dtype = data.dtype - """ - if 'layernorm.weight' in name or 'word_embeddings.weight' in name: - data = data.to(torch.float32) - else: - if data.dtype != torch.float16 and data.dtype != torch.float32: - data = data.to(torch.float32) - """ - data = data.to(torch.float32) - # check for nans - if torch.isnan(data).any(): - print("WARNING: tensor '" + name + "' contains NaNs") - sys.exit() - if torch.isinf(data).any(): - print("WARNING: tensor '" + name + "' contains infinities") - sys.exit() - - data = data.squeeze().numpy() + # TODO: FP16 conversion produces garbage outputs. (Q8_0 does not, so..?) + data = data.to(torch.float32).squeeze().numpy() new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias")) if new_name is None: print("Can not map tensor '" + name + "'") diff --git a/ggml.c b/ggml.c index a1afd037f..7e4099dcd 100644 --- a/ggml.c +++ b/ggml.c @@ -4304,49 +4304,34 @@ static void ggml_print_tensor(const struct ggml_tensor * tensor) { static void ggml_print_tensor_values(const struct ggml_tensor * tensor, int starts[], int dim, int nelts) { GGML_ASSERT(tensor->type == GGML_TYPE_F32); - GGML_PRINT("printing values for %s[", tensor->name); + GGML_PRINT("Printing values for tensor %s[", tensor->name); for (int i=0; in_dims; ++i) { - if (i!=dim) { - GGML_PRINT("%d", starts[i]); - } else { - if (starts[i] > 0) { + GGML_ASSERT(starts[i] >= 0); + if (i == dim) { + if (starts[i] > 0) { GGML_PRINT("%d:%d", starts[i], starts[i]+nelts); } else { GGML_PRINT(":%d", starts[i]+nelts); } + } else { + GGML_PRINT("%d", starts[i]); } if (in_dims-1) { GGML_PRINT(","); } } GGML_PRINT("]\n"); - - float *dataPtr = (float *) tensor->data; - - // Compute the offset into data for starts + float *data_ptr = (float *) tensor->data; int offset = 0; for (int j = 0; j < tensor->n_dims; j++) { - offset += (starts[j] * tensor->nb[j]) / sizeof(float); // Assuming nb[j] is in bytes, divide by sizeof(float) to get float offset. + offset += (starts[j] * tensor->nb[j]) / ggml_type_size(GGML_TYPE_F32); } - - dataPtr += offset; - + data_ptr += offset; for (int i = 0; i < nelts; i++) { - GGML_PRINT("%f ", *dataPtr); - dataPtr += tensor->nb[dim] / sizeof(float); // Increment by strides for the given dimension. + GGML_PRINT("%f ", *data_ptr); + data_ptr += tensor->nb[dim] / ggml_type_size(GGML_TYPE_F32); } GGML_PRINT("\n"); - /* - char * ptr = (char *)tensor->data; - for (int j=0; jn_dims;j++) { - ptr += tensor->nb[j]*starts[j]; - } - for (int i=0; inb[dim]; - } - GGML_PRINT("\n"); - */ } int64_t ggml_nelements(const struct ggml_tensor * tensor) { @@ -8883,14 +8868,14 @@ static void ggml_compute_forward_add_f32( } } } - if ( - strncmp(src0->name, "printme", 7) == 0 + if ((strncmp(src0->name, "printme", 7) == 0 + ||strncmp(src1->name, "printme", 7) == 0) && params->ith == 0) { GGML_PRINT("\noutputs of add: %s + %s\n", src0->name, src1->name); ggml_print_tensor(src0); ggml_print_tensor(src1); ggml_print_tensor(dst); - int starts[] = {0, 1, 0}; + int starts[] = {0, 0, 0}; ggml_print_tensor_values(dst, starts, 0, 10); } } @@ -10879,11 +10864,8 @@ static void ggml_compute_forward_norm_f32( && params->ith == 0) { GGML_PRINT("\nlayernorm inputs for %s\n", src0->name); ggml_print_tensor(src0); - int starts[] = {0, 1, 0}; + int starts[] = {0, 0, 0}; ggml_print_tensor_values(src0, starts, 0, 10); - for (int i=64; i<74; ++i) { - GGML_PRINT("%f ", ggml_get_f32_1d(src0, i)); - } } const int ith = params->ith; @@ -11313,15 +11295,14 @@ static void ggml_compute_forward_mul_mat( && params->ith == 0) { GGML_PRINT("\nInputs to matmul: %s\n", src1->name); ggml_print_tensor(src1); - /* + size_t offset = 0;//(src1->ne[0] * src1->ne[1]) for (int i=0; i < src1->ne[0] * src1->ne[1]; ++i) { if (i % src1->ne[0] == 0) { GGML_PRINT("\n"); } - GGML_PRINT(" %f ", ((float *)src1->data)[i + (src1->ne[0] * src1->ne[1])]); + GGML_PRINT(" %f ", ((float *)src1->data)[i + offset]); } GGML_PRINT("\n"); - */ } GGML_TENSOR_BINARY_OP_LOCALS; diff --git a/gguf-py/gguf/gguf.py b/gguf-py/gguf/gguf.py index 93a397109..8a1fc9316 100644 --- a/gguf-py/gguf/gguf.py +++ b/gguf-py/gguf/gguf.py @@ -85,7 +85,7 @@ class MODEL_ARCH(IntEnum): GPTNEOX : int = auto() MPT : int = auto() STARCODER : int = auto() - ADEPT : int = auto() + PERSIMMON : int = auto() class MODEL_TENSOR(IntEnum): @@ -119,7 +119,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.GPTNEOX: "gptneox", MODEL_ARCH.MPT: "mpt", MODEL_ARCH.STARCODER: "starcoder", - MODEL_ARCH.ADEPT: "adept", + MODEL_ARCH.PERSIMMON: "persimmon", } MODEL_TENSOR_NAMES: dict[MODEL_ARCH, dict[MODEL_TENSOR, str]] = { @@ -189,7 +189,7 @@ MODEL_TENSOR_NAMES: dict[MODEL_ARCH, dict[MODEL_TENSOR, str]] = { MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", }, - MODEL_ARCH.ADEPT: { + MODEL_ARCH.PERSIMMON: { MODEL_TENSOR.TOKEN_EMBD: "token_embd", MODEL_TENSOR.OUTPUT: "output", MODEL_TENSOR.OUTPUT_NORM: "output_norm", @@ -219,7 +219,7 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.ATTN_ROT_EMBD, ], - MODEL_ARCH.ADEPT: [ + MODEL_ARCH.PERSIMMON: [ MODEL_TENSOR.ROPE_FREQS, ] } diff --git a/llama.cpp b/llama.cpp index 66cef8b59..1b155e5b7 100644 --- a/llama.cpp +++ b/llama.cpp @@ -162,7 +162,7 @@ enum llm_arch { LLM_ARCH_GPTNEOX, LLM_ARCH_MPT, LLM_ARCH_STARCODER, - LLM_ARCH_ADEPT, + LLM_ARCH_PERSIMMON, LLM_ARCH_UNKNOWN, }; @@ -175,7 +175,7 @@ static std::map LLM_ARCH_NAMES = { { LLM_ARCH_MPT, "mpt" }, { LLM_ARCH_BAICHUAN, "baichuan" }, { LLM_ARCH_STARCODER, "starcoder" }, - { LLM_ARCH_ADEPT, "adept" }, + { LLM_ARCH_PERSIMMON, "persimmon" }, }; enum llm_kv { @@ -378,7 +378,7 @@ static std::map> LLM_TENSOR_NAMES = }, }, { - LLM_ARCH_ADEPT, + LLM_ARCH_PERSIMMON, { { LLM_TENSOR_TOKEN_EMBD, "token_embd"}, { LLM_TENSOR_OUTPUT_NORM, "output_norm"}, @@ -2323,7 +2323,7 @@ static void llm_load_tensors( } } } break; - case LLM_ARCH_ADEPT: + case LLM_ARCH_PERSIMMON: { model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, GGML_BACKEND_CPU); @@ -3739,7 +3739,7 @@ static void log_tensor( LLAMA_LOG_INFO("\n"); } -static struct ggml_cgraph * llm_build_adept( +static struct ggml_cgraph * llm_build_persimmon( llama_context & lctx, const llama_token * tokens, const float * embd, @@ -3756,6 +3756,7 @@ static struct ggml_cgraph * llm_build_adept( const int64_t n_embd = hparams.n_embd; const int64_t n_layer = hparams.n_layer; + //const int64_t n_layer = 1; const int64_t n_ctx = hparams.n_ctx; const int64_t n_head_kv = hparams.n_head_kv; const int64_t n_head = hparams.n_head; @@ -3811,105 +3812,74 @@ static struct ggml_cgraph * llm_build_adept( // Input is (d_model, L) // Attention struct ggml_tensor * residual = ggml_dup(ctx0, inpL); - ggml_set_name(residual, format((char*)"layer_inputs_%d", il).c_str()); + //ggml_format_name(inpL, "printme_layer_inputs_%d", il); { // input norming cur = ggml_norm(ctx0, inpL, hparams.f_norm_eps); - cur = ggml_add(ctx0, ggml_mul( - ctx0, cur, model.layers[il].attn_norm), - model.layers[il].attn_norm_b); + cur = ggml_mul( + ctx0, cur, model.layers[il].attn_norm); + //ggml_format_name(cur, "printme_normed_%d", il); + cur = ggml_add(ctx0, cur, model.layers[il].attn_norm_b); } ggml_set_name(cur, "cur"); { // QKV //log_tensor(cur); cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur); - // 3 * d_model, L - // or 2 * n_head_kv + n_embd_head, L - // + bias ggml_format_name(cur, "qkv_preadd_%d", il); cur = ggml_add(ctx0, cur, model.layers[il].bqkv); // Apply Q, K layernorm - // Where is the Q/K/V? it's in order. Hopefully... - // So q has offset 0. - // And split into heads - // -> (d_h, n_head, L) - const size_t wsize = ggml_type_size(cur->type); + // split qkv GGML_ASSERT(n_head_kv == n_head); - //LLAMA_LOG_INFO("N: %d\n", N); ggml_set_name(cur, format("qkv_%d", il).c_str()); - //log_tensor(cur); - - // cur is (3 * d_head * n_head, N) - struct ggml_tensor * tmpqkv = ggml_view_4d( - ctx0, cur, n_embd_head, 3, n_head, N, - /* nb1 = */ wsize * n_embd_head, - /* nb2 = */ wsize * n_embd_head * 3, - /* nb3 = */ wsize * n_embd_head * 3 * n_head, - /* offset = */ 0 - ); + struct ggml_tensor * tmpqkv = ggml_reshape_4d(ctx0, cur, n_embd_head, 3, n_head, N); // get it to (d_h, n_head, L, 3) struct ggml_tensor * tmpqkv_perm = ggml_cont(ctx0, ggml_permute(ctx0, tmpqkv, 0, 3, 1, 2)); ggml_format_name(tmpqkv_perm, "tmpqkv_perm_%d", il); - //log_tensor(tmpqkv_perm); - struct ggml_tensor * tmpq = ggml_cont( - ctx0, - ggml_view_3d( + struct ggml_tensor * tmpq = ggml_view_3d( ctx0, tmpqkv_perm, n_embd_head, n_head, N, - /* nb1 = */ sizeof(float) * n_embd_head, - /* nb2 = */ sizeof(float) * n_embd_head * n_head, + /* nb1 = */ ggml_element_size(tmpqkv_perm) * n_embd_head, + /* nb2 = */ ggml_element_size(tmpqkv_perm) * n_embd_head * n_head, /* offset = */ 0 - ) - ); - struct ggml_tensor * tmpk = ggml_cont( - ctx0, - ggml_view_3d( + ); + struct ggml_tensor * tmpk = ggml_view_3d( ctx0, tmpqkv_perm, n_embd_head, n_head, N, - /* nb1 = */ sizeof(float) * n_embd_head, - /* nb2 = */ sizeof(float) * n_embd_head * n_head, - /* offset = */ sizeof(float) * n_embd_head * n_head * N - ) - ); - struct ggml_tensor * tmpv = ggml_cont( - ctx0, - ggml_view_3d( + /* nb1 = */ ggml_element_size(tmpqkv_perm) * n_embd_head, + /* nb2 = */ ggml_element_size(tmpqkv_perm) * n_embd_head * n_head, + /* offset = */ ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * N + ); + + struct ggml_tensor * tmpv = ggml_view_3d( ctx0, tmpqkv_perm, n_embd_head, n_head, N, - /* nb1 = */ sizeof(float) * n_embd_head, - /* nb2 = */ sizeof(float) * n_embd_head * n_head, - /* offset = */ sizeof(float) * n_embd_head * n_head * N * 2 - ) - ); - // Q / K layernorm - ggml_set_name(tmpq, format("tmpq_%d", il).c_str()); + /* nb1 = */ ggml_element_size(tmpqkv_perm) * n_embd_head, + /* nb2 = */ ggml_element_size(tmpqkv_perm) * n_embd_head * n_head, + /* offset = */ ggml_element_size(tmpqkv_perm) * n_embd_head * n_head * N * 2 + ); + //ggml_format_name(tmpq, "printme_tmpq_%d", il); tmpq = ggml_norm(ctx0, tmpq, hparams.f_norm_eps); tmpq = ggml_mul(ctx0, tmpq, model.layers[il].attn_q_norm); tmpq = ggml_add(ctx0, tmpq, model.layers[il].attn_q_norm_b); - ggml_set_name(tmpq, format("tmpq_%d", il).c_str()); - //log_tensor(tmpq); + //ggml_format_name(tmpq, "printme_tmpk_%d", il); tmpk = ggml_norm(ctx0, tmpk, hparams.f_norm_eps); tmpk = ggml_mul(ctx0, tmpk, model.layers[il].attn_k_norm); - ggml_set_name(tmpk, format("preadd_%d", il).c_str()); tmpk = ggml_add(ctx0, tmpk, model.layers[il].attn_k_norm_b); - ggml_set_name(tmpk, format("tmpk_%d", il).c_str()); - //log_tensor(tmpk); - - const size_t n_rot = n_embd_head / 2; + struct ggml_tensor * qrot = ggml_cont(ctx0, ggml_view_3d( ctx0, tmpq, n_rot, n_head, N, - /* nb1 = */ wsize * n_embd_head, - /* nb2 = */ wsize * n_embd_head * n_head, + /* nb1 = */ ggml_element_size(tmpq) * n_embd_head, + /* nb2 = */ ggml_element_size(tmpq) * n_embd_head * n_head, /* offset = */ 0 )); // get the second half of tmpq, e.g tmpq[n_rot:, :, :] struct ggml_tensor * qpass = ggml_cont(ctx0, ggml_view_3d( ctx0, tmpq, n_rot, n_head, N, - /* nb1 = */ wsize * n_embd_head, - /* nb2 = */ wsize * n_embd_head * n_head, - /* offset = */ wsize * n_rot + /* nb1 = */ ggml_element_size(tmpq) * n_embd_head, + /* nb2 = */ ggml_element_size(tmpq) * n_embd_head * n_head, + /* offset = */ ggml_element_size(tmpq) * n_rot )); ggml_set_name(qrot, format("qrot_%d", il).c_str()); ggml_set_name(qpass, format("qpass_%d", il).c_str()); @@ -3918,20 +3888,18 @@ static struct ggml_cgraph * llm_build_adept( struct ggml_tensor * krot = ggml_cont(ctx0, ggml_view_3d( ctx0, tmpk, n_rot, n_head, N, - /* nb1 = */ wsize * n_embd_head, - /* nb2 = */ wsize * n_embd_head * n_head, + /* nb1 = */ ggml_element_size(tmpk) * n_embd_head, + /* nb2 = */ ggml_element_size(tmpk) * n_embd_head * n_head, /* offset = */ 0 )); struct ggml_tensor * kpass = ggml_cont(ctx0, ggml_view_3d( ctx0, tmpk, n_rot, n_head, N, - /* nb1 = */ wsize * n_embd_head, - /* nb2 = */ wsize * n_embd_head * n_head, - /* offset = */ wsize * n_rot + /* nb1 = */ ggml_element_size(tmpk) * n_embd_head, + /* nb2 = */ ggml_element_size(tmpk) * n_embd_head * n_head, + /* offset = */ ggml_element_size(tmpk) * n_rot )); ggml_set_name(krot, format("krot_%d", il).c_str()); ggml_set_name(kpass, format("kpass_%d", il).c_str()); - //log_tensor(krot); - //log_tensor(kpass); struct ggml_tensor * qrotated = ggml_cont(ctx0, ggml_permute(ctx0, ggml_rope_custom_inplace( @@ -3939,17 +3907,15 @@ static struct ggml_cgraph * llm_build_adept( ), 2, 1, 0, 3 )); - ggml_set_name(qrotated, format("qrotated_%d", il).c_str()); - //log_tensor(qrotated); qpass = ggml_cont(ctx0, ggml_permute(ctx0, qpass, 2, 1, 0, 3)); + + //ggml_format_name(krot, "printme_krot_%d", il); struct ggml_tensor * krotated = ggml_cont(ctx0, ggml_permute(ctx0, ggml_rope_custom_inplace( ctx0, krot, n_past, n_rot, 2, 0, freq_base, freq_scale ), 2, 1, 0, 3 )); - ggml_set_name(krotated, format("krotated_%d", il).c_str()); - //log_tensor(krotated); kpass = ggml_cont(ctx0, ggml_permute(ctx0, kpass, 2, 1, 0, 3)); struct ggml_tensor * Qcur = ggml_cont(ctx0, @@ -3962,16 +3928,12 @@ static struct ggml_cgraph * llm_build_adept( ); ggml_set_name(Qcur, format("Qcur_%d", il).c_str()); ggml_set_name(Kcur, format("Kcur_%d", il).c_str()); - //log_tensor(Qcur); - //////log_tensor(Kcur); - //log_tensor(kv_self.k); { // View v as (N, n_embd) struct ggml_tensor * Vcur = ggml_transpose( ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd, N) ); ggml_set_name(Vcur, "Vcur"); - // Select k from kv cache as 1d view (N * n_embd) struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past) @@ -3997,7 +3959,6 @@ static struct ggml_cgraph * llm_build_adept( ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il)); ggml_set_name(K, "K"); - //log_tensor(K); struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); ggml_set_name(KQ, "KQ"); @@ -4009,7 +3970,7 @@ static struct ggml_cgraph * llm_build_adept( ggml_set_name(KQ_masked, "KQ_mask"); struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); - ggml_set_name(KQ_soft_max, format("KQ_soft_max_%d", il).c_str()); + //ggml_set_name(KQ_soft_max, format("printme_KQ_soft_max_%d", il).c_str()); struct ggml_tensor * V = ggml_view_3d(ctx0, kv_self.v, @@ -4031,7 +3992,6 @@ static struct ggml_cgraph * llm_build_adept( cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur); cur = ggml_add(ctx0, cur, model.layers[il].bo); ggml_set_name(cur, "result_wo"); - //log_tensor(cur); } cur = ggml_add(ctx0, residual, cur); struct ggml_tensor * residual2 = ggml_dup(ctx0, cur); @@ -4044,17 +4004,12 @@ static struct ggml_cgraph * llm_build_adept( model.layers[il].ffn_norm_b ); } - // FFN cur = ggml_mul_mat(ctx0, model.layers[il].w3, cur); - ggml_format_name(cur, "pre_act_%d", il); cur = ggml_add(ctx0, cur, model.layers[il].b3); - // //log_tensor(cur); - // Correct through here. - // Squared ReLU cur = ggml_relu(ctx0, cur); cur = ggml_sqr(ctx0, cur); cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur); - ggml_format_name(cur, "post_ffn_down_%d", il); + //ggml_format_name(cur, "printme_ffn_down_%d", il); struct ggml_tensor * ffn_out = ggml_add(ctx0, cur, model.layers[il].b2); @@ -4105,9 +4060,9 @@ static struct ggml_cgraph * llama_build_graph( { result = llm_build_starcoder(lctx, tokens, embd, n_tokens, n_past); } break; - case LLM_ARCH_ADEPT: + case LLM_ARCH_PERSIMMON: { - result = llm_build_adept(lctx, tokens, embd, n_tokens, n_past); + result = llm_build_persimmon(lctx, tokens, embd, n_tokens, n_past); } break; default: GGML_ASSERT(false);