rebase to the latest

ds5t5 2023-09-29 01:13:41 -07:00
parent 8b8c6d5052
commit af19099ab1
3 changed files with 87 additions and 55 deletions


@@ -6,10 +6,8 @@ from __future__ import annotations
 import argparse
 import json
 import os
-import struct
 import sys
 from pathlib import Path
-from typing import Any
 
 import numpy as np
 import torch
@@ -235,6 +233,27 @@ for part_name in part_names:
     print("gguf: loading model part '" + part_name + "'")
     model_part = torch.load(dir_model / part_name, map_location="cpu")
 
+    for i in range(block_count):
+        if f"transformer.h.{i}.attn.kv.weight" in model_part:
+            data = model_part[f"transformer.h.{i}.attn.kv.weight"]
+            model_part[f"model.layers.{i}.self_attn.k_proj.weight"] = data[
+                : n_head_kv * head_dim
+            ]
+            model_part[f"model.layers.{i}.self_attn.v_proj.weight"] = data[
+                n_head_kv * head_dim :
+            ]
+            del model_part[f"transformer.h.{i}.attn.kv.weight"]
+        if f"transformer.h.{i}.attn.q.weight" in model_part:
+            model_part[f"model.layers.{i}.self_attn.q_proj.weight"] = model_part[
+                f"transformer.h.{i}.attn.q.weight"
+            ]
+            del model_part[f"transformer.h.{i}.attn.q.weight"]
+        if f"transformer.h.{i}.mlp.gate_up_proj.weight" in model_part:
+            data = model_part[f"transformer.h.{i}.mlp.gate_up_proj.weight"]
+            model_part[f"model.layers.{i}.mlp.gate_proj.weight"] = data[:ff_dim]
+            model_part[f"model.layers.{i}.mlp.up_proj.weight"] = data[ff_dim:]
+            del model_part[f"transformer.h.{i}.mlp.gate_up_proj.weight"]
+
     for name in model_part.keys():
         data = model_part[name]
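The block above rewrites Refact's fused checkpoint tensors (attn.kv, attn.q, mlp.gate_up_proj) into the separate llama-style projections the rest of the converter already understands. As a standalone illustration of the same slicing, here is a minimal sketch with made-up dimensions (n_head_kv, head_dim, ff_dim and n_embd are assumptions for the example, not values read from a real config):

import torch

# Assumed toy dimensions for illustration only.
n_head_kv, head_dim, ff_dim, n_embd = 2, 64, 512, 256

# Fused [K;V] projection: the first n_head_kv*head_dim rows are K, the rest are V.
kv = torch.randn(2 * n_head_kv * head_dim, n_embd)
k_proj = kv[: n_head_kv * head_dim]
v_proj = kv[n_head_kv * head_dim :]

# Fused [gate;up] MLP projection: the first ff_dim rows are gate, the rest are up.
gate_up = torch.randn(2 * ff_dim, n_embd)
gate_proj = gate_up[:ff_dim]
up_proj   = gate_up[ff_dim:]

print(k_proj.shape, v_proj.shape, gate_proj.shape, up_proj.shape)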


@@ -286,21 +286,18 @@ class TensorNameMap:
         # Attention query
         MODEL_TENSOR.ATTN_Q: (
             "model.layers.{bid}.self_attn.q_proj", # llama-hf
-            "transformer.h.{bid}.attn.q",          # refact
             "layers.{bid}.attention.wq",           # llama-pth
         ),
 
         # Attention key
         MODEL_TENSOR.ATTN_K: (
             "model.layers.{bid}.self_attn.k_proj", # llama-hf
-            "transformer.h.{bid}.attn.k",          # refact
             "layers.{bid}.attention.wk",           # llama-pth
         ),
 
         # Attention value
         MODEL_TENSOR.ATTN_V: (
             "model.layers.{bid}.self_attn.v_proj", # llama-hf
-            "transformer.h.{bid}.attn.v",          # refact
             "layers.{bid}.attention.wv",           # llama-pth
         ),
@@ -335,15 +332,13 @@ class TensorNameMap:
             "transformer.h.{bid}.mlp.c_fc",          # gpt2
             "transformer.blocks.{bid}.ffn.up_proj",  # mpt
             "transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
-            "model.layers.{bid}.mlp.up_proj",        # llama-hf
+            "model.layers.{bid}.mlp.up_proj",        # llama-hf refact
             "layers.{bid}.feed_forward.w3",          # llama-pth
-            "transformer.h.{bid}.mlp.linear_3",      # refact
         ),
 
         # Feed-forward gate
         MODEL_TENSOR.FFN_GATE: (
-            "model.layers.{bid}.mlp.gate_proj",      # llama-hf
-            "transformer.h.{bid}.mlp.linear_1",      # refact
+            "model.layers.{bid}.mlp.gate_proj",      # llama-hf refact
             "layers.{bid}.feed_forward.w1",          # llama-pth
         ),
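With the converter emitting llama-hf tensor names, the refact-specific patterns above become redundant and the existing model.layers.{bid}.* entries cover Refact as well. For readers unfamiliar with the {bid} templating, a tiny hypothetical sketch of how such patterns resolve to per-block names (the mapping dict and the mapped-to names here are invented for illustration, not the real table):

# Hypothetical, trimmed name table; the real entries live in TensorNameMap.
mapping = {
    "model.layers.{bid}.self_attn.q_proj": "blk.{bid}.attn_q",
    "model.layers.{bid}.mlp.gate_proj":    "blk.{bid}.ffn_gate",
}

def resolve(name, n_blocks):
    # Try every block id and pattern; return the mapped name on a match.
    for bid in range(n_blocks):
        for src, dst in mapping.items():
            if name == src.format(bid=bid):
                return dst.format(bid=bid)
    return None

print(resolve("model.layers.3.self_attn.q_proj", 32))  # -> blk.3.attn_q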

llama.cpp

@@ -3369,17 +3369,10 @@ static struct ggml_cgraph * llm_build_baichaun(
 static struct ggml_cgraph * llm_build_refact(
          llama_context & lctx,
-     const llama_token * tokens,
-           const float * embd,
-                   int   n_tokens,
-                   int   n_past) {
-
-    GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT
-
-    const int N = n_tokens;
+     const llama_batch & batch) {
 
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
+    const auto & cparams = lctx.cparams;
 
     const auto & kv_self = lctx.kv_self;
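The builder now takes the whole llama_batch instead of separate tokens/embd/n_tokens/n_past arguments. A rough sketch (in Python, for illustration only; not the actual C struct) of the batch fields this graph builder reads further down:

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class LlamaBatchSketch:
    # Field names follow the diff; this is an illustrative stand-in, not llama.cpp's struct.
    n_tokens: int                    # number of tokens in the batch
    token:    Optional[List[int]]    # token ids, or None when embeddings are supplied
    embd:     Optional[List[float]]  # input embeddings, or None when token ids are supplied
    pos:      List[int]              # position of each token in its sequence
    seq_id:   List[int]              # sequence id of each token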
@@ -3387,7 +3380,7 @@ static struct ggml_cgraph * llm_build_refact(
     const int64_t n_embd      = hparams.n_embd;
     const int64_t n_layer     = hparams.n_layer;
-    const int64_t n_ctx       = hparams.n_ctx;
+    const int64_t n_ctx       = cparams.n_ctx;
     const int64_t n_head      = hparams.n_head;
     const int64_t n_head_kv   = hparams.n_head_kv;
     const int64_t n_embd_head = hparams.n_embd_head();
@@ -3397,6 +3390,12 @@ static struct ggml_cgraph * llm_build_refact(
     const int n_gpu_layers = model.n_gpu_layers;
 
+    const int32_t n_tokens = batch.n_tokens;
+    const int32_t n_kv     = ggml_allocr_is_measure(lctx.alloc) ? n_ctx            : kv_self.n;
+    const int32_t kv_head  = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
+
+    // printf("n_kv = %d\n", n_kv);
+
     auto & buf_compute = lctx.buf_compute;
 
     struct ggml_init_params params = {
@@ -3414,12 +3413,12 @@ static struct ggml_cgraph * llm_build_refact(
     struct ggml_tensor * cur;
     struct ggml_tensor * inpL;
 
-    if (tokens) {
-        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+    if (batch.token) {
+        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
 
         ggml_allocr_alloc(lctx.alloc, inp_tokens);
         if (!ggml_allocr_is_measure(lctx.alloc)) {
-            memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
+            memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens));
         }
         ggml_set_name(inp_tokens, "inp_tokens");
@@ -3429,11 +3428,11 @@ static struct ggml_cgraph * llm_build_refact(
         GGML_ASSERT(false && "not implemented");
 #endif
 
-        inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);
+        inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens);
 
         ggml_allocr_alloc(lctx.alloc, inpL);
         if (!ggml_allocr_is_measure(lctx.alloc)) {
-            memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
+            memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL));
         }
     }
@@ -3442,9 +3441,6 @@ static struct ggml_cgraph * llm_build_refact(
     // offload functions set the tensor output backend to GPU
     // tensors are GPU-accelerated if any input or the output has been offloaded
-    //
-    // with the low VRAM option VRAM scratch is disabled in llama_load_model_internal
-    // in that case ggml_cuda_assign_buffers has no effect
     offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
     offload_func_t offload_func_kq = llama_nop;
     offload_func_t offload_func_v  = llama_nop;
@@ -3461,12 +3457,36 @@ static struct ggml_cgraph * llm_build_refact(
     }
 #endif // GGML_USE_CUBLAS
 
+    // KQ_scale
     struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+    ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
     ggml_allocr_alloc(lctx.alloc, KQ_scale);
     if (!ggml_allocr_is_measure(lctx.alloc)) {
-        ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
+        ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd_head)));
     }
-    ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
+
+    // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+    offload_func_kq(KQ_mask);
+    ggml_set_name(KQ_mask, "KQ_mask");
+    ggml_allocr_alloc(lctx.alloc, KQ_mask);
+    if (!ggml_allocr_is_measure(lctx.alloc)) {
+        float * data = (float *) KQ_mask->data;
+        memset(data, 0, ggml_nbytes(KQ_mask));
+
+        for (int h = 0; h < 1; ++h) {
+            for (int j = 0; j < n_tokens; ++j) {
+                const llama_pos    pos    = batch.pos[j];
+                const llama_seq_id seq_id = batch.seq_id[j];
+
+                for (int i = 0; i < n_kv; ++i) {
+                    if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) {
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
+                    }
+                }
+            }
+        }
+    }
 
     for (int il = 0; il < n_layer; ++il) {
         ggml_format_name(inpL, "layer_inp_%d", il);
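The KQ_mask built above replaces the old ggml_diag_mask_inf(..., n_past) call: for each batch token j it writes -INFINITY into every KV slot that belongs to a different sequence or to a later position, and leaves 0 elsewhere. A small numpy sketch of the same construction, with invented per-token positions and sequence ids standing in for batch.pos/batch.seq_id and the KV cell metadata:

import numpy as np

# Toy stand-ins for batch.pos, batch.seq_id and the KV cache cells (assumed values).
n_tokens, n_kv = 3, 5
batch_pos    = [2, 3, 4]                  # position of each new token
batch_seq_id = [0, 0, 0]                  # sequence id of each new token
cell_pos     = [0, 1, 2, 3, 4]            # position stored in each KV cell
cell_seq_ids = [{0}, {0}, {0}, {0}, {1}]  # sequence ids each cell belongs to

mask = np.zeros((n_tokens, n_kv), dtype=np.float32)
for j in range(n_tokens):
    for i in range(n_kv):
        # Same condition as above: wrong sequence, or a cell from the future.
        if batch_seq_id[j] not in cell_seq_ids[i] or cell_pos[i] > batch_pos[j]:
            mask[j, i] = -np.inf

print(mask)  # 0 where attention is allowed, -inf where it is masked out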
@@ -3504,36 +3524,33 @@ static struct ggml_cgraph * llm_build_refact(
             offload_func_kq(tmpq);
             ggml_set_name(tmpq, "tmpq");
 
-            struct ggml_tensor * Kcur;
-            struct ggml_tensor * Qcur;
-            Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N);
-            Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N);
+            struct ggml_tensor * Kcur = ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens);
             offload_func_kq(Kcur);
             ggml_set_name(Kcur, "Kcur");
 
+            struct ggml_tensor * Qcur = ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens);
             offload_func_kq(Qcur);
             ggml_set_name(Qcur, "Qcur");
 
             // store key and value to memory
             {
-                // compute the transposed [N, n_embd] V matrix
+                // compute the transposed [n_tokens, n_embd] V matrix
 
                 struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
                 offload_func_v(tmpv);
                 ggml_set_name(tmpv, "tmpv");
 
-                struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, N));
+                struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd_gqa, n_tokens));
                 offload_func_v(Vcur);
                 ggml_set_name(Vcur, "Vcur");
 
-                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past));
+                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head));
                 offload_func_kq(k);
                 ggml_set_name(k, "k");
 
-                struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa,
+                struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa,
                         (   n_ctx)*ggml_element_size(kv_self.v),
-                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v));
+                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v));
                 offload_func_v(v);
                 ggml_set_name(v, "v");
@@ -3547,7 +3564,7 @@ static struct ggml_cgraph * llm_build_refact(
             struct ggml_tensor * K =
                 ggml_view_3d(ctx0, kv_self.k,
-                        n_embd_head, n_past + N, n_head_kv,
+                        n_embd_head, n_kv, n_head_kv,
                         ggml_element_size(kv_self.k)*n_embd_gqa,
                         ggml_element_size(kv_self.k)*n_embd_head,
                         ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
@@ -3560,25 +3577,28 @@ static struct ggml_cgraph * llm_build_refact(
             ggml_set_name(KQ, "KQ");
 
             // KQ_scaled = KQ / sqrt(n_embd_head)
-            // KQ_scaled shape [n_past + N, N, n_head, 1]
-            struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
+            // KQ_scaled shape [n_kv, n_tokens, n_head, 1]
+            struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
             offload_func_kq(KQ_scaled);
             ggml_set_name(KQ_scaled, "KQ_scaled");
 
-            struct ggml_tensor * KQ_masked;
-            struct ggml_tensor * KQ_scaled_alibi;
-
-            KQ_scaled_alibi = ggml_alibi(ctx0, KQ_scaled, n_past, n_head, 8);
+            // KQ_masked = mask_past(KQ_scaled)
+            struct ggml_tensor * KQ_scaled_alibi = ggml_alibi(ctx0, KQ_scaled, /*n_past*/ 0, n_head, 8);
             ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi");
 
-            KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled_alibi, n_past);
-            struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
+            struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled_alibi, KQ_mask);
+            offload_func_kq(KQ_masked);
+            ggml_set_name(KQ_masked, "KQ_masked");
+
+            // KQ = soft_max(KQ_masked)
+            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
             offload_func_v(KQ_soft_max);
             ggml_set_name(KQ_soft_max, "KQ_soft_max");
 
             // split cached V into n_head heads
             struct ggml_tensor * V =
                 ggml_view_3d(ctx0, kv_self.v,
-                        n_past + N, n_embd_head, n_head_kv,
+                        n_kv, n_embd_head, n_head_kv,
                         ggml_element_size(kv_self.v)*n_ctx,
                         ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
                         ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
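The hunk above is the standard masked-attention pipeline: scale KQ by 1/sqrt(n_embd_head), add the ALiBi bias, add the additive KQ_mask, then soft_max. A compact numpy sketch of the same sequence on toy tensors (shapes, positions and the single-head ALiBi slope are assumptions for illustration):

import numpy as np

n_tokens, n_kv, n_embd_head = 3, 5, 8
rng = np.random.default_rng(0)
Q = rng.standard_normal((n_tokens, n_embd_head))
K = rng.standard_normal((n_kv, n_embd_head))

KQ        = Q @ K.T                            # attention scores, [n_tokens, n_kv]
KQ_scaled = KQ / np.sqrt(n_embd_head)          # KQ_scale = 1/sqrt(n_embd_head)

slope    = 0.5                                 # assumed ALiBi slope for a single head
KQ_alibi = KQ_scaled + slope * np.arange(n_kv)[None, :]  # linear positional bias

# Additive mask playing the role of KQ_mask: 0 for visible cells, -inf otherwise.
token_pos = np.array([2, 3, 4])                # assumed positions of the batch tokens
KQ_mask   = np.where(np.arange(n_kv)[None, :] <= token_pos[:, None], 0.0, -np.inf)
KQ_masked = KQ_alibi + KQ_mask

KQ_soft_max = np.exp(KQ_masked - KQ_masked.max(axis=-1, keepdims=True))
KQ_soft_max /= KQ_soft_max.sum(axis=-1, keepdims=True)
print(KQ_soft_max.sum(axis=-1))                # each row sums to 1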
@@ -3593,7 +3613,7 @@ static struct ggml_cgraph * llm_build_refact(
             // make V contiguous in memory to speed up the matmul, however we waste time on the copy
             // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation
             // is there a better way?
-            struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd_head, n_head));
+            struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_ctx, n_embd_head, n_head));
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
 #endif
@@ -3602,10 +3622,8 @@ static struct ggml_cgraph * llm_build_refact(
             offload_func_v(KQV_merged);
             ggml_set_name(KQV_merged, "KQV_merged");
 
-            // cur = KQV_merged.contiguous().view(n_embd, N)
-            cur = ggml_cpy(ctx0,
-                    KQV_merged,
-                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+            // cur = KQV_merged.contiguous().view(n_embd, n_tokens)
+            cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens);
             offload_func_v(cur);
             ggml_set_name(cur, "KQV_merged_contiguous");
@@ -4338,7 +4356,7 @@ static struct ggml_cgraph * llama_build_graph(
             } break;
         case LLM_ARCH_REFACT:
             {
-                result = llm_build_refact(lctx, tokens, embd, n_tokens, n_past);
+                result = llm_build_refact(lctx, batch);
             } break;
         default:
             GGML_ASSERT(false);