falcon : CPU inference working
This commit is contained in:
parent
085228e1f5
commit
3c7c325b98
1 changed files with 288 additions and 15 deletions
295
llama.cpp
295
llama.cpp
|
@ -1031,8 +1031,6 @@ struct llama_context {
|
|||
// key + value cache for the self attention
|
||||
struct llama_kv_cache kv_self;
|
||||
|
||||
size_t mem_per_token = 0;
|
||||
|
||||
// decode output (2-dimensional array: [n_tokens][n_vocab])
|
||||
std::vector<float> logits;
|
||||
bool logits_all = false;
|
||||
|
@ -2014,7 +2012,7 @@ static bool llama_model_load(
|
|||
return true;
|
||||
}
|
||||
|
||||
static struct ggml_cgraph * llama_build_graph(
|
||||
static struct ggml_cgraph * llm_build_llama(
|
||||
llama_context & lctx,
|
||||
const llama_token * tokens,
|
||||
const float * embd,
|
||||
|
@ -2048,7 +2046,6 @@ static struct ggml_cgraph * llama_build_graph(
|
|||
|
||||
const int n_gpu_layers = model.n_gpu_layers;
|
||||
|
||||
auto & mem_per_token = lctx.mem_per_token;
|
||||
auto & buf_compute = lctx.buf_compute;
|
||||
|
||||
struct ggml_init_params params = {
|
||||
|
@ -2340,20 +2337,296 @@ static struct ggml_cgraph * llama_build_graph(
|
|||
cur = ggml_mul_mat(ctx0, model.output, cur);
|
||||
ggml_set_name(cur, "result_output");
|
||||
|
||||
// logits -> probs
|
||||
//cur = ggml_soft_max_inplace(ctx0, cur);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
if (mem_per_token == 0) {
|
||||
mem_per_token = ggml_used_mem(ctx0)/N;
|
||||
}
|
||||
|
||||
ggml_free(ctx0);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
static struct ggml_cgraph * llm_build_falcon(
|
||||
llama_context & lctx,
|
||||
const llama_token * tokens,
|
||||
const float * embd,
|
||||
int n_tokens,
|
||||
int n_past) {
|
||||
|
||||
GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT
|
||||
|
||||
const int N = n_tokens;
|
||||
|
||||
const auto & model = lctx.model;
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const auto & kv_self = lctx.kv_self;
|
||||
|
||||
GGML_ASSERT(!!kv_self.ctx);
|
||||
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
const int64_t n_layer = hparams.n_layer;
|
||||
const int64_t n_ctx = hparams.n_ctx;
|
||||
const int64_t n_head = hparams.n_head;
|
||||
const int64_t n_head_kv = hparams.n_head_kv;
|
||||
const int64_t n_embd_head = hparams.n_embd_head();
|
||||
//const int64_t n_embd_gqa = hparams.n_embd_gqa();
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
|
||||
auto & buf_compute = lctx.buf_compute;
|
||||
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ buf_compute.size,
|
||||
/*.mem_buffer =*/ buf_compute.data,
|
||||
/*.no_alloc =*/ false,
|
||||
};
|
||||
|
||||
params.no_alloc = true;
|
||||
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
|
||||
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||
|
||||
struct ggml_tensor * cur;
|
||||
struct ggml_tensor * inpL;
|
||||
|
||||
if (tokens) {
|
||||
struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||
|
||||
ggml_allocr_alloc(lctx.alloc, inp_tokens);
|
||||
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
||||
memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
|
||||
}
|
||||
ggml_set_name(inp_tokens, "inp_tokens");
|
||||
|
||||
inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
|
||||
} else {
|
||||
#ifdef GGML_USE_MPI
|
||||
GGML_ASSERT(false && "not implemented");
|
||||
#endif
|
||||
|
||||
inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);
|
||||
|
||||
ggml_allocr_alloc(lctx.alloc, inpL);
|
||||
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
||||
memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
|
||||
ggml_allocr_alloc(lctx.alloc, KQ_scale);
|
||||
if (!ggml_allocr_is_measure(lctx.alloc)) {
|
||||
ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
|
||||
}
|
||||
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * cur;
|
||||
struct ggml_tensor * layernorm_output;
|
||||
|
||||
// self-attention
|
||||
{
|
||||
layernorm_output = ggml_norm(ctx0, inpL);
|
||||
|
||||
layernorm_output = ggml_add(ctx0,
|
||||
ggml_mul(ctx0,
|
||||
ggml_repeat(ctx0, model.layers[il].attn_norm, layernorm_output),
|
||||
layernorm_output),
|
||||
ggml_repeat(ctx0, model.layers[il].attn_norm_b, layernorm_output));
|
||||
|
||||
if ( hparams.n_head_kv == 8 ) { // Falcon-40B
|
||||
cur = ggml_norm(ctx0, inpL);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_mul(ctx0,
|
||||
ggml_repeat(ctx0, model.layers[il].attn_norm_2, cur),
|
||||
cur),
|
||||
ggml_repeat(ctx0, model.layers[il].attn_norm_2_b, cur));
|
||||
}
|
||||
else { // Falcon 7B
|
||||
cur = layernorm_output;
|
||||
}
|
||||
|
||||
// compute QKV
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
|
||||
|
||||
// Note that the strides for Kcur, Vcur are set up so that the
|
||||
// resulting views are misaligned with the tensor's storage
|
||||
// (by applying the K/V offset we shift the tensor's original
|
||||
// view to stick out behind the viewed QKV tensor's allocated
|
||||
// memory, so to say). This is ok because no actual accesses
|
||||
// happen to that out-of-range memory, but it can require some
|
||||
// trickery when trying to accurately dump these views for
|
||||
// debugging.
|
||||
|
||||
struct ggml_tensor * Qcur = ggml_view_3d(
|
||||
ctx0, cur, n_embd_head, n_head, N,
|
||||
n_embd_head * ggml_type_size(GGML_TYPE_F32),
|
||||
n_embd_head * (n_head + 2 * n_head_kv) * ggml_type_size(GGML_TYPE_F32),
|
||||
0);
|
||||
|
||||
struct ggml_tensor * Kcur = ggml_view_3d(
|
||||
ctx0, cur, n_embd_head, n_head_kv, N,
|
||||
n_embd_head * ggml_type_size(GGML_TYPE_F32),
|
||||
n_embd_head * (n_head + 2 * n_head_kv) * ggml_type_size(GGML_TYPE_F32),
|
||||
n_embd_head * n_head * ggml_type_size(GGML_TYPE_F32));
|
||||
|
||||
struct ggml_tensor * Vcur = ggml_view_3d(
|
||||
ctx0, cur, n_embd_head, n_head_kv, N,
|
||||
n_embd_head * ggml_type_size(GGML_TYPE_F32),
|
||||
n_embd_head * (n_head + 2 * n_head_kv) * ggml_type_size(GGML_TYPE_F32),
|
||||
n_embd_head * (n_head + n_head_kv) * ggml_type_size(GGML_TYPE_F32));
|
||||
|
||||
// using mode = 2 for neox mode
|
||||
Qcur = ggml_rope_inplace(ctx0, Qcur, n_past, n_embd_head, 2, 0);
|
||||
Kcur = ggml_rope_inplace(ctx0, Kcur, n_past, n_embd_head, 2, 0);
|
||||
|
||||
// store key and value to memory
|
||||
{
|
||||
struct ggml_tensor* k = ggml_view_1d(
|
||||
ctx0, kv_self.k, N * n_head_kv * n_embd_head,
|
||||
(ggml_element_size(kv_self.k) * n_head_kv * n_embd_head) *
|
||||
(il * n_ctx + n_past));
|
||||
struct ggml_tensor* v = ggml_view_1d(
|
||||
ctx0, kv_self.v, N * n_head_kv * n_embd_head,
|
||||
(ggml_element_size(kv_self.v) * n_head_kv * n_embd_head) *
|
||||
(il * n_ctx + n_past));
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
|
||||
}
|
||||
|
||||
struct ggml_tensor * K = ggml_permute(
|
||||
ctx0,
|
||||
ggml_reshape_3d(
|
||||
ctx0,
|
||||
ggml_view_1d(ctx0, kv_self.k, (n_past + N) * n_head_kv * n_embd_head,
|
||||
il * n_ctx *
|
||||
ggml_element_size(kv_self.k) *
|
||||
n_head_kv *
|
||||
n_embd_head),
|
||||
n_embd_head, n_head_kv, n_past + N),
|
||||
0, 2, 1, 3);
|
||||
|
||||
// K * Q
|
||||
|
||||
// K = ggml_cont(ctx0, ggml_repeat2(ctx0, K, repeat_dummy));
|
||||
|
||||
struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
|
||||
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
||||
|
||||
// KQ_scaled = KQ / sqrt(n_embd/n_head)
|
||||
struct ggml_tensor * KQ_scaled =
|
||||
ggml_scale_inplace(ctx0,
|
||||
KQ,
|
||||
ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd_head)))
|
||||
);
|
||||
|
||||
// KQ_masked = mask_past(KQ_scaled)
|
||||
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
|
||||
|
||||
// KQ = soft_max(KQ_masked)
|
||||
struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
|
||||
|
||||
// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
|
||||
struct ggml_tensor* V = ggml_permute(
|
||||
ctx0,
|
||||
ggml_reshape_3d(
|
||||
ctx0,
|
||||
ggml_view_1d(ctx0, kv_self.v, (n_past + N) * n_head_kv * n_embd_head,
|
||||
il * n_ctx *
|
||||
ggml_element_size(kv_self.v) *
|
||||
n_head_kv *
|
||||
n_embd_head),
|
||||
n_embd_head, n_head_kv, n_past + N),
|
||||
0, 2, 1, 3);
|
||||
|
||||
// V = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_repeat2(ctx0, V, repeat_dummy)));
|
||||
V = ggml_cont(ctx0, ggml_transpose(ctx0, V));
|
||||
|
||||
// KQV = transpose(V) * KQ_soft_max
|
||||
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
||||
|
||||
// KQV_merged = KQV.permute(0, 2, 1, 3)
|
||||
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||
|
||||
// cur = KQV_merged.contiguous().view(n_embd, N)
|
||||
cur = ggml_cpy(ctx0,
|
||||
KQV_merged,
|
||||
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
|
||||
|
||||
// projection
|
||||
{
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
model.layers[il].wo,
|
||||
cur);
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_tensor* inpFF = layernorm_output;
|
||||
struct ggml_tensor* attn_out = ggml_cpy(
|
||||
ctx0, cur, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
|
||||
|
||||
{
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].w3, inpFF);
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur);
|
||||
}
|
||||
|
||||
cur = ggml_add(ctx0, cur, attn_out);
|
||||
cur = ggml_add(ctx0, cur, inpL);
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
// norm
|
||||
{
|
||||
cur = ggml_norm(ctx0, inpL);
|
||||
|
||||
cur = ggml_add(ctx0,
|
||||
ggml_mul(ctx0,
|
||||
ggml_repeat(ctx0, model.output_norm, cur),
|
||||
cur),
|
||||
ggml_repeat(ctx0, model.output_norm_b, cur));
|
||||
ggml_set_name(cur, "result_norm");
|
||||
}
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.output, cur);
|
||||
ggml_set_name(cur, "result_output");
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
ggml_free(ctx0);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
static struct ggml_cgraph * llama_build_graph(
|
||||
llama_context & lctx,
|
||||
const llama_token * tokens,
|
||||
const float * embd,
|
||||
int n_tokens,
|
||||
int n_past) {
|
||||
const auto & model = lctx.model;
|
||||
|
||||
struct ggml_cgraph * result = NULL;
|
||||
|
||||
switch (model.arch) {
|
||||
case LLM_ARCH_LLAMA:
|
||||
{
|
||||
result = llm_build_llama(lctx, tokens, embd, n_tokens, n_past);
|
||||
} break;
|
||||
case LLM_ARCH_FALCON:
|
||||
{
|
||||
result = llm_build_falcon(lctx, tokens, embd, n_tokens, n_past);
|
||||
} break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
};
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// evaluate the transformer
|
||||
//
|
||||
// - lctx: llama context
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue