Layer skipping demo
This commit is contained in:
parent
96981f37b1
commit
d6f35c7ca5
3 changed files with 179 additions and 31 deletions
|
@ -2,6 +2,7 @@
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
|
@ -320,6 +321,31 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
||||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||||
const int n_batch = params.n_batch;
|
const int n_batch = params.n_batch;
|
||||||
|
|
||||||
|
llama_batch batch = llama_batch_get_one(NULL, 0, 0, 0);
|
||||||
|
|
||||||
|
const int32_t n_layers = 32; // model layer count
|
||||||
|
const int test_count = 6; // num perplexity chunks to run for each test
|
||||||
|
const size_t prune_target = 4; // prune this many of the worst results each pass
|
||||||
|
// end tunables
|
||||||
|
|
||||||
|
// 1 = attn, 2 = mlp, 3 = both
|
||||||
|
int32_t test_skip_type = 0; // but don't mess with this, it's set automatically.
|
||||||
|
std::vector<int32_t> layers;
|
||||||
|
layers.resize(n_layers + 1);
|
||||||
|
std::fill(layers.begin(), layers.end(), 0);
|
||||||
|
batch.run_layers = layers.data();
|
||||||
|
int32_t skip_layer = -1;
|
||||||
|
std::vector<int32_t> skips;
|
||||||
|
std::vector<int32_t> skip_types;
|
||||||
|
skip_types.resize(n_layers);
|
||||||
|
std::fill(skip_types.begin(), skip_types.end(), 0);
|
||||||
|
std::vector<std::tuple<int32_t, int32_t, double>> pass_results;
|
||||||
|
std::vector<int32_t> worsts;
|
||||||
|
worsts.resize(n_layers);
|
||||||
|
std::fill(worsts.begin(), worsts.end(), 0);
|
||||||
|
int32_t curr_best_layer = -1, curr_best_type = 0;
|
||||||
|
double curr_best_ppl = -1, ref_ppl = -1;
|
||||||
|
|
||||||
int count = 0;
|
int count = 0;
|
||||||
double nll = 0.0;
|
double nll = 0.0;
|
||||||
double nll2 = 0.0;
|
double nll2 = 0.0;
|
||||||
|
@ -327,8 +353,88 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
||||||
fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
|
fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
|
||||||
|
|
||||||
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
|
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
|
||||||
|
static const char * label = "?AMB";
|
||||||
|
|
||||||
|
auto test_t_start = std::chrono::high_resolution_clock::now();
|
||||||
for (int i = 0; i < n_chunk; ++i) {
|
for (int i = 0; i < n_chunk; ++i) {
|
||||||
|
if (i > 0 && i % test_count == 0) {
|
||||||
|
auto test_t_end = std::chrono::high_resolution_clock::now();
|
||||||
|
float test_t_total = std::chrono::duration<float>(test_t_end - test_t_start).count();
|
||||||
|
|
||||||
|
skip_layer = n_layers;
|
||||||
|
for (int32_t new_sl = 0; new_sl < n_layers; new_sl++) {
|
||||||
|
int32_t curr_skipped = (skip_types[new_sl] >> 2) | (skip_types[new_sl] & 3);
|
||||||
|
// printf("##%d, %d\n", new_sl, curr_skipped);
|
||||||
|
if (curr_skipped == 3) continue; // Already tested or perm skip.
|
||||||
|
skip_layer = new_sl;
|
||||||
|
test_skip_type = (curr_skipped & 1) != 0 ? 2 : 1;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (skip_layer >= n_layers) {
|
||||||
|
if (curr_best_layer == -1) break;
|
||||||
|
if (pass_results.size() >= prune_target * 2) {
|
||||||
|
std::sort(pass_results.begin(), pass_results.end(),
|
||||||
|
[](const std::tuple<int32_t, int32_t, double> & a, const std::tuple<int32_t, int32_t, double> & b) {
|
||||||
|
return std::get<2>(a) > std::get<2>(b);
|
||||||
|
}
|
||||||
|
);
|
||||||
|
const size_t num_prune = std::min(pass_results.size(), prune_target);
|
||||||
|
for (size_t temp = 0; temp < num_prune; temp++) {
|
||||||
|
int32_t lidx = std::get<0>(pass_results[temp]);
|
||||||
|
if (lidx == curr_best_layer && std::get<1>(pass_results[temp]) == curr_best_type) continue;
|
||||||
|
worsts[lidx] |= std::get<1>(pass_results[temp]);
|
||||||
|
printf("\nPrune[%zu]: %d (%d) - %.2f\n", temp, lidx, std::get<1>(pass_results[temp]), std::get<2>(pass_results[temp]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pass_results.clear();
|
||||||
|
printf("\n\nADD SKIP %c%3d - ppl vs ref %.4f",
|
||||||
|
int(label[curr_best_type]), curr_best_layer,
|
||||||
|
curr_best_ppl - ref_ppl);
|
||||||
|
if (curr_best_ppl > ref_ppl * 1.75) break;
|
||||||
|
skip_types[curr_best_layer] += curr_best_type;
|
||||||
|
if (std::find(skips.begin(), skips.end(), curr_best_layer) == skips.end()) {
|
||||||
|
skips.push_back(curr_best_layer);
|
||||||
|
}
|
||||||
|
curr_best_layer = -1;
|
||||||
|
curr_best_ppl = -1;
|
||||||
|
curr_best_type = 0;
|
||||||
|
skip_layer = n_layers;
|
||||||
|
for (int32_t new_sl = 0; new_sl < n_layers; new_sl++) {
|
||||||
|
skip_types[new_sl] = (skip_types[new_sl] & 3) | (worsts[new_sl] << 2);
|
||||||
|
}
|
||||||
|
for (int32_t new_sl = 0; new_sl < n_layers; new_sl++) {
|
||||||
|
int32_t curr_skipped = (skip_types[new_sl] >> 2) | (skip_types[new_sl] & 3);
|
||||||
|
// printf("||%d, %d\n", new_sl, curr_skipped);
|
||||||
|
if (curr_skipped == 3) continue; // Already tested or perm skip.
|
||||||
|
skip_layer = new_sl;
|
||||||
|
test_skip_type = (curr_skipped & 1) != 0 ? 2 : 1;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (skip_layer == -1 || skip_layer == n_layers) break;
|
||||||
|
}
|
||||||
|
|
||||||
|
i = 0;
|
||||||
|
count = 0;
|
||||||
|
nll = 0;
|
||||||
|
nll2 = 0;
|
||||||
|
logit_history.clear();
|
||||||
|
prob_history.clear();
|
||||||
|
|
||||||
|
for (int32_t i = 0; i < n_layers; i++) {
|
||||||
|
layers[i] = (skip_types[i] & 3) | (i == skip_layer ? test_skip_type : 0);
|
||||||
|
}
|
||||||
|
layers[n_layers] = -1;
|
||||||
|
printf("\nTEST %c%3d + [", int(label[test_skip_type]), skip_layer);
|
||||||
|
for (const auto l : skips) {
|
||||||
|
printf("%c%d, ", int(label[skip_types[l] & 3]), l);
|
||||||
|
}
|
||||||
|
printf("] - len: %3zu, best:(%c%3d @ %.3f), last took %.2f sec\n",
|
||||||
|
skips.size() + 1,
|
||||||
|
int(label[curr_best_type]), curr_best_layer,
|
||||||
|
curr_best_ppl != -1 ? curr_best_ppl - ref_ppl : 0,
|
||||||
|
test_t_total);
|
||||||
|
test_t_start = std::chrono::high_resolution_clock::now();
|
||||||
|
}
|
||||||
const int start = i * n_ctx;
|
const int start = i * n_ctx;
|
||||||
const int end = start + n_ctx;
|
const int end = start + n_ctx;
|
||||||
|
|
||||||
|
@ -353,7 +459,11 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
||||||
tokens[batch_start] = llama_token_bos(ctx);
|
tokens[batch_start] = llama_token_bos(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
|
batch.n_tokens = batch_size;
|
||||||
|
batch.token = tokens.data() + batch_start;
|
||||||
|
batch.all_pos_0 = j * n_batch;
|
||||||
|
|
||||||
|
if (llama_decode(ctx, batch)) {
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
return {tokens, -1, logit_history, prob_history};
|
return {tokens, -1, logit_history, prob_history};
|
||||||
}
|
}
|
||||||
|
@ -367,7 +477,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
||||||
|
|
||||||
const auto t_end = std::chrono::high_resolution_clock::now();
|
const auto t_end = std::chrono::high_resolution_clock::now();
|
||||||
|
|
||||||
if (i == 0) {
|
if (i == 0 && skip_layer < 0 && skips.empty()) {
|
||||||
const float t_total = std::chrono::duration<float>(t_end - t_start).count();
|
const float t_total = std::chrono::duration<float>(t_end - t_start).count();
|
||||||
fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
|
fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
|
||||||
int total_seconds = (int)(t_total * n_chunk);
|
int total_seconds = (int)(t_total * n_chunk);
|
||||||
|
@ -396,8 +506,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
||||||
count += n_ctx - first - 1;
|
count += n_ctx - first - 1;
|
||||||
|
|
||||||
// perplexity is e^(average negative log-likelihood)
|
// perplexity is e^(average negative log-likelihood)
|
||||||
|
double ppl = std::exp(nll / count);
|
||||||
if (params.ppl_output_type == 0) {
|
if (params.ppl_output_type == 0) {
|
||||||
printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
|
printf("[%d]%.4lf,", i + 1, ppl);
|
||||||
} else {
|
} else {
|
||||||
double av = nll/count;
|
double av = nll/count;
|
||||||
double av2 = nll2/count - av*av;
|
double av2 = nll2/count - av*av;
|
||||||
|
@ -405,6 +516,19 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
||||||
printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
|
printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
|
||||||
}
|
}
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
|
if (skip_layer >= 0 && (i + 1 == test_count || (i > 1 && ppl > ref_ppl * 3))) {
|
||||||
|
i = test_count - 1;
|
||||||
|
skip_types[skip_layer] |= test_skip_type << 2;
|
||||||
|
if (curr_best_layer == -1 || ppl < curr_best_ppl) {
|
||||||
|
curr_best_layer = skip_layer;
|
||||||
|
curr_best_ppl = ppl;
|
||||||
|
curr_best_type = test_skip_type;
|
||||||
|
}
|
||||||
|
printf(" -- %.3f", ppl - ref_ppl);
|
||||||
|
pass_results.push_back({skip_layer, test_skip_type, ppl});
|
||||||
|
} else if (skip_layer < 0) {
|
||||||
|
ref_ppl = ppl;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
printf("\n");
|
printf("\n");
|
||||||
|
|
||||||
|
|
79
llama.cpp
79
llama.cpp
|
@ -3252,7 +3252,31 @@ static struct ggml_cgraph * llm_build_llama(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int32_t * run_layer = batch.run_layers;
|
||||||
|
bool run_attn = false, run_mlp = false;
|
||||||
|
cur = inpL;
|
||||||
|
|
||||||
for (int il = 0; il < n_layer; ++il) {
|
for (int il = 0; il < n_layer; ++il) {
|
||||||
|
run_attn = run_mlp = true;
|
||||||
|
if (run_layer != NULL) {
|
||||||
|
if (*run_layer >= 0) {
|
||||||
|
run_attn = (*run_layer & 1) == 0;
|
||||||
|
run_mlp = (*run_layer & 2) == 0;
|
||||||
|
run_layer++;
|
||||||
|
} else {
|
||||||
|
run_layer = NULL;
|
||||||
|
}
|
||||||
|
} else if (ggml_allocr_is_measure(lctx.alloc) && il == n_layer - 1) {
|
||||||
|
// No idea why this is needed, but otherwise we run out of space
|
||||||
|
// when skipping attn or mlp (but not both) on the last layer
|
||||||
|
run_mlp = false;
|
||||||
|
} else if (ggml_allocr_is_measure(lctx.alloc) && il == n_layer - 2) {
|
||||||
|
// No idea why this is needed, but otherwise we run out of space
|
||||||
|
// when skipping attn or mlp (but not both) on the last layer
|
||||||
|
run_attn = false;
|
||||||
|
}
|
||||||
|
if (!run_attn && !run_mlp) continue;
|
||||||
|
|
||||||
ggml_format_name(inpL, "layer_inp_%d", il);
|
ggml_format_name(inpL, "layer_inp_%d", il);
|
||||||
|
|
||||||
offload_func_t offload_func = llama_nop;
|
offload_func_t offload_func = llama_nop;
|
||||||
|
@ -3263,10 +3287,11 @@ static struct ggml_cgraph * llm_build_llama(
|
||||||
}
|
}
|
||||||
#endif // GGML_USE_CUBLAS
|
#endif // GGML_USE_CUBLAS
|
||||||
|
|
||||||
struct ggml_tensor * inpSA = inpL;
|
struct ggml_tensor * inpFF = nullptr;
|
||||||
|
|
||||||
// norm
|
// self-attention
|
||||||
{
|
if (run_attn) {
|
||||||
|
// norm
|
||||||
cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
|
cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
|
||||||
offload_func(cur);
|
offload_func(cur);
|
||||||
ggml_set_name(cur, "rms_norm_0");
|
ggml_set_name(cur, "rms_norm_0");
|
||||||
|
@ -3275,10 +3300,7 @@ static struct ggml_cgraph * llm_build_llama(
|
||||||
cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm);
|
cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm);
|
||||||
offload_func(cur);
|
offload_func(cur);
|
||||||
ggml_set_name(cur, "attention_norm_0");
|
ggml_set_name(cur, "attention_norm_0");
|
||||||
}
|
|
||||||
|
|
||||||
// self-attention
|
|
||||||
{
|
|
||||||
// compute Q and K and RoPE them
|
// compute Q and K and RoPE them
|
||||||
struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
|
struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
|
||||||
offload_func_kq(tmpk);
|
offload_func_kq(tmpk);
|
||||||
|
@ -3395,25 +3417,25 @@ static struct ggml_cgraph * llm_build_llama(
|
||||||
cur);
|
cur);
|
||||||
offload_func(cur);
|
offload_func(cur);
|
||||||
ggml_set_name(cur, "result_wo");
|
ggml_set_name(cur, "result_wo");
|
||||||
|
|
||||||
|
inpFF = ggml_add(ctx0, cur, inpL);
|
||||||
|
offload_func(inpFF);
|
||||||
|
ggml_set_name(inpFF, "inpFF");
|
||||||
|
} else {
|
||||||
|
inpFF = inpL;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
|
|
||||||
offload_func(inpFF);
|
|
||||||
ggml_set_name(inpFF, "inpFF");
|
|
||||||
|
|
||||||
// feed-forward network
|
// feed-forward network
|
||||||
{
|
if (run_mlp) {
|
||||||
// norm
|
// norm
|
||||||
{
|
cur = ggml_rms_norm(ctx0, inpFF, norm_rms_eps);
|
||||||
cur = ggml_rms_norm(ctx0, inpFF, norm_rms_eps);
|
offload_func(cur);
|
||||||
offload_func(cur);
|
ggml_set_name(cur, "rms_norm_1");
|
||||||
ggml_set_name(cur, "rms_norm_1");
|
|
||||||
|
|
||||||
// cur = cur*ffn_norm(broadcasted)
|
// cur = cur*ffn_norm(broadcasted)
|
||||||
cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
|
cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
|
||||||
offload_func(cur);
|
offload_func(cur);
|
||||||
ggml_set_name(cur, "ffn_norm");
|
ggml_set_name(cur, "ffn_norm");
|
||||||
}
|
|
||||||
|
|
||||||
struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
|
struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
|
||||||
model.layers[il].w3,
|
model.layers[il].w3,
|
||||||
|
@ -3441,18 +3463,18 @@ static struct ggml_cgraph * llm_build_llama(
|
||||||
cur);
|
cur);
|
||||||
offload_func(cur);
|
offload_func(cur);
|
||||||
ggml_set_name(cur, "result_w2");
|
ggml_set_name(cur, "result_w2");
|
||||||
}
|
|
||||||
|
|
||||||
cur = ggml_add(ctx0, cur, inpFF);
|
cur = ggml_add(ctx0, cur, inpFF);
|
||||||
offload_func(cur);
|
offload_func(cur);
|
||||||
ggml_set_name(cur, "inpFF_+_result_w2");
|
ggml_set_name(cur, "inpFF_+_result_w2");
|
||||||
|
} else {
|
||||||
|
cur = inpFF;
|
||||||
|
}
|
||||||
|
|
||||||
// input for next layer
|
// input for next layer
|
||||||
inpL = cur;
|
inpL = cur;
|
||||||
}
|
}
|
||||||
|
|
||||||
cur = inpL;
|
|
||||||
|
|
||||||
// norm
|
// norm
|
||||||
{
|
{
|
||||||
cur = ggml_rms_norm(ctx0, cur, norm_rms_eps);
|
cur = ggml_rms_norm(ctx0, cur, norm_rms_eps);
|
||||||
|
@ -9582,7 +9604,7 @@ int llama_eval_embd(
|
||||||
int n_past) {
|
int n_past) {
|
||||||
llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);
|
llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);
|
||||||
|
|
||||||
llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, };
|
llama_batch batch = { n_tokens, nullptr, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, };
|
||||||
|
|
||||||
const int ret = llama_decode_internal(*ctx, batch);
|
const int ret = llama_decode_internal(*ctx, batch);
|
||||||
if (ret < 0) {
|
if (ret < 0) {
|
||||||
|
@ -9604,6 +9626,7 @@ struct llama_batch llama_batch_get_one(
|
||||||
llama_seq_id seq_id) {
|
llama_seq_id seq_id) {
|
||||||
return {
|
return {
|
||||||
/*n_tokens =*/ n_tokens,
|
/*n_tokens =*/ n_tokens,
|
||||||
|
/*run_layers =*/ nullptr,
|
||||||
/*tokens =*/ tokens,
|
/*tokens =*/ tokens,
|
||||||
/*embd =*/ nullptr,
|
/*embd =*/ nullptr,
|
||||||
/*pos =*/ nullptr,
|
/*pos =*/ nullptr,
|
||||||
|
@ -9617,7 +9640,7 @@ struct llama_batch llama_batch_get_one(
|
||||||
}
|
}
|
||||||
|
|
||||||
struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd, int32_t n_seq_max) {
|
struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd, int32_t n_seq_max) {
|
||||||
llama_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
|
llama_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
|
||||||
|
|
||||||
if (embd) {
|
if (embd) {
|
||||||
batch.embd = (float *) malloc(sizeof(float) * n_tokens * embd);
|
batch.embd = (float *) malloc(sizeof(float) * n_tokens * embd);
|
||||||
|
|
1
llama.h
1
llama.h
|
@ -132,6 +132,7 @@ extern "C" {
|
||||||
//
|
//
|
||||||
typedef struct llama_batch {
|
typedef struct llama_batch {
|
||||||
int32_t n_tokens;
|
int32_t n_tokens;
|
||||||
|
int32_t *run_layers; // end marked by negative value.
|
||||||
|
|
||||||
llama_token * token;
|
llama_token * token;
|
||||||
float * embd;
|
float * embd;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue