option to not use scratch
This commit is contained in:
parent
d5e4cf7ffe
commit
43c2891afa
5 changed files with 102 additions and 32 deletions
|
@ -549,7 +549,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
return res;
|
||||
}
|
||||
// determine the required inference memory per token:
|
||||
gpt2_eval(gpt2_ctx_v3, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, file_format);
|
||||
gpt2_eval(gpt2_ctx_v3, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, calc_mem_with_scratch);
|
||||
return ModelLoadResult::SUCCESS;
|
||||
}
|
||||
else
|
||||
|
@ -616,14 +616,14 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
}
|
||||
|
||||
// determine the required inference memory per token:
|
||||
gptj_eval(gptj_ctx_v3, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
|
||||
gptj_eval(gptj_ctx_v3, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, calc_mem_with_scratch);
|
||||
|
||||
//if the logits are NAN or duplicated, it means the model is incompatible
|
||||
std::vector<float> oldlogits(logits);
|
||||
|
||||
//this is another hack because they change the library - we run the eval through the model
|
||||
//twice and compare logits. if they give the same logits for different inputs, model is broken
|
||||
gptj_eval(gptj_ctx_v3, params.n_threads, 0, {4, 5, 6, 7}, logits, mem_per_token);
|
||||
gptj_eval(gptj_ctx_v3, params.n_threads, 0, {4, 5, 6, 7}, logits, mem_per_token, calc_mem_with_scratch);
|
||||
|
||||
if(logits.size()>0 && (IsNanCheck(logits[0]) || LogitsDuplicated(oldlogits,logits)))
|
||||
{
|
||||
|
@ -688,7 +688,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
}
|
||||
|
||||
// determine the required inference memory per token:
|
||||
gpt_neox_eval(neox_ctx_v3, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
|
||||
gpt_neox_eval(neox_ctx_v3, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, calc_mem_with_scratch);
|
||||
|
||||
return ModelLoadResult::SUCCESS;
|
||||
}
|
||||
|
@ -745,7 +745,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
|
|||
}
|
||||
|
||||
// determine the required inference memory per token:
|
||||
mpt_eval(mpt_ctx_v3, params.n_threads, 0, { 0, 1, 2, 3 }, logits, false, mem_per_token);
|
||||
mpt_eval(mpt_ctx_v3, params.n_threads, 0, { 0, 1, 2, 3 }, logits, false, mem_per_token, calc_mem_with_scratch);
|
||||
return ModelLoadResult::SUCCESS;
|
||||
}
|
||||
else
|
||||
|
@ -1078,7 +1078,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
|
|||
}
|
||||
else if(file_format==FileFormat::GPT2_4)
|
||||
{
|
||||
evalres = gpt2_eval(gpt2_ctx_v3, params.n_threads, n_past, embd, logits, mem_per_token, file_format);
|
||||
evalres = gpt2_eval(gpt2_ctx_v3, params.n_threads, n_past, embd, logits, mem_per_token);
|
||||
}
|
||||
else if(file_format==FileFormat::NEOX_1 || file_format == FileFormat::NEOX_2 || file_format == FileFormat::NEOX_3 || file_format==FileFormat::NEOX_4 || file_format==FileFormat::NEOX_5)
|
||||
{
|
||||
|
|
|
@ -389,7 +389,7 @@ bool gpt2_eval(
|
|||
const std::vector<gpt_vocab::id> & embd_inp,
|
||||
std::vector<float> & embd_w,
|
||||
size_t & mem_per_token,
|
||||
FileFormat file_format) {
|
||||
bool use_scratch=true) {
|
||||
const int N = embd_inp.size();
|
||||
|
||||
const auto & hparams = model.hparams;
|
||||
|
@ -406,13 +406,21 @@ bool gpt2_eval(
|
|||
// use 2 scratch buffers
|
||||
// TODO: very hacky solution - reimplement in a more elegant way
|
||||
static size_t scr0_size = (n_ctx>1024?512u:256u)*1024*1024;
|
||||
static void * scr0 = malloc(scr0_size);
|
||||
static void * scr0;
|
||||
|
||||
static size_t scr1_size = (n_ctx>1024?512u:256u)*1024*1024;
|
||||
static void * scr1 = malloc(scr1_size);
|
||||
static void * scr1;
|
||||
|
||||
if (mem_per_token > 0 && mem_per_token*N*1.05 > buf_size) {
|
||||
const size_t buf_size_new = 64u*1024*1024 + 1.15*(mem_per_token*N); // add 10% to account for ggml object overhead
|
||||
if(use_scratch)
|
||||
{
|
||||
scr0 = malloc(scr0_size);
|
||||
scr1 = malloc(scr1_size);
|
||||
}
|
||||
|
||||
size_t scratch_needed_mem = mem_per_token*N;
|
||||
|
||||
if (mem_per_token > 0 && scratch_needed_mem*1.1 > buf_size) {
|
||||
const size_t buf_size_new = 64u*1024*1024 + 1.2*(scratch_needed_mem); // add 10% to account for ggml object overhead
|
||||
//printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
|
||||
|
||||
// reallocate
|
||||
|
@ -455,7 +463,9 @@ bool gpt2_eval(
|
|||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * cur;
|
||||
|
||||
if(use_scratch){
|
||||
ggml_set_scratch(ctx0, { 0, scr0_size, scr0, });
|
||||
}
|
||||
|
||||
// norm
|
||||
{
|
||||
|
@ -603,7 +613,9 @@ bool gpt2_eval(
|
|||
|
||||
struct ggml_tensor * inpFF = cur;
|
||||
|
||||
if(use_scratch){
|
||||
ggml_set_scratch(ctx0, { 0, scr1_size, scr1, });
|
||||
}
|
||||
|
||||
// feed-forward network
|
||||
{
|
||||
|
@ -661,7 +673,9 @@ bool gpt2_eval(
|
|||
inpL = ggml_add(ctx0, cur, inpFF);
|
||||
}
|
||||
|
||||
if(use_scratch){
|
||||
ggml_set_scratch(ctx0, { 0, scr0_size, scr0, });
|
||||
}
|
||||
|
||||
// norm
|
||||
{
|
||||
|
@ -677,7 +691,9 @@ bool gpt2_eval(
|
|||
ggml_repeat(ctx0, model.ln_f_b, inpL));
|
||||
}
|
||||
|
||||
if(use_scratch){
|
||||
ggml_set_scratch(ctx0, { 0, 0, nullptr, });
|
||||
}
|
||||
|
||||
// inpL = WTE * inpL
|
||||
// [ 768, 50257] - model.lm_head
|
||||
|
|
|
@ -382,7 +382,8 @@ bool gptj_eval(
|
|||
const int n_past,
|
||||
const std::vector<gpt_vocab::id> & embd_inp,
|
||||
std::vector<float> & embd_w,
|
||||
size_t & mem_per_token) {
|
||||
size_t & mem_per_token,
|
||||
bool use_scratch=true) {
|
||||
const int N = embd_inp.size();
|
||||
|
||||
const auto & hparams = model.hparams;
|
||||
|
@ -400,13 +401,18 @@ bool gptj_eval(
|
|||
// use 2 scratch buffers
|
||||
// TODO: very hacky solution - reimplement in a more elegant way
|
||||
static size_t scr0_size = (n_ctx>1024?512u:256u)*1024*1024;
|
||||
static void * scr0 = malloc(scr0_size);
|
||||
static void * scr0;
|
||||
|
||||
static size_t scr1_size = (n_ctx>1024?512u:256u)*1024*1024;
|
||||
static void * scr1 = malloc(scr1_size);
|
||||
static void * scr1;
|
||||
if(use_scratch)
|
||||
{
|
||||
scr0 = malloc(scr0_size);
|
||||
scr1 = malloc(scr1_size);
|
||||
}
|
||||
|
||||
if (mem_per_token > 0 && mem_per_token*N*1.05 > buf_size) {
|
||||
const size_t buf_size_new = 64u*1024*1024 + 1.15*(mem_per_token*N); // add 10% to account for ggml object overhead
|
||||
if (mem_per_token > 0 && 32u*1024*1024 + mem_per_token*N*1.2 > buf_size) {
|
||||
const size_t buf_size_new = 64u*1024*1024 + 1.2*(mem_per_token*N); // add 10% to account for ggml object overhead
|
||||
//printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
|
||||
|
||||
// reallocate
|
||||
|
@ -441,7 +447,9 @@ bool gptj_eval(
|
|||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * cur;
|
||||
|
||||
if(use_scratch){
|
||||
ggml_set_scratch(ctx0, { 0, scr0_size, scr0, });
|
||||
}
|
||||
|
||||
// norm
|
||||
{
|
||||
|
@ -530,7 +538,9 @@ bool gptj_eval(
|
|||
cur);
|
||||
}
|
||||
|
||||
if(use_scratch){
|
||||
ggml_set_scratch(ctx0, { 0, scr1_size, scr1, });
|
||||
}
|
||||
|
||||
struct ggml_tensor * inpFF = cur;
|
||||
|
||||
|
@ -567,7 +577,9 @@ bool gptj_eval(
|
|||
inpL = ggml_add(ctx0, cur, inpL);
|
||||
}
|
||||
|
||||
if(use_scratch){
|
||||
ggml_set_scratch(ctx0, { 0, scr0_size, scr0, });
|
||||
}
|
||||
|
||||
// norm
|
||||
{
|
||||
|
@ -581,7 +593,9 @@ bool gptj_eval(
|
|||
ggml_repeat(ctx0, model.ln_f_b, inpL));
|
||||
}
|
||||
|
||||
if(use_scratch){
|
||||
ggml_set_scratch(ctx0, { 0, 0, nullptr, });
|
||||
}
|
||||
|
||||
// lm_head
|
||||
{
|
||||
|
|
|
@ -316,7 +316,8 @@ bool mpt_model_load(const std::string & fname, mpt_model & model, gpt_vocab & vo
|
|||
// - embd_w: the predicted logits for the next token
|
||||
//
|
||||
bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
|
||||
const std::vector<gpt_vocab::id> & embd_inp, std::vector<float> & embd_w, bool logits_all, size_t & mem_per_token) {
|
||||
const std::vector<gpt_vocab::id> & embd_inp, std::vector<float> & embd_w,
|
||||
bool logits_all, size_t & mem_per_token, bool use_scratch=true) {
|
||||
const int N = embd_inp.size();
|
||||
|
||||
const auto & hparams = model.hparams;
|
||||
|
@ -332,22 +333,37 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
|
|||
|
||||
// use 2 scratch buffers
|
||||
// TODO: very hacky solution - reimplement in a more elegant way
|
||||
|
||||
static size_t scr0_size = (n_ctx>2048?1024u:512u)*1024*1024;
|
||||
static void * scr0 = malloc(scr0_size);
|
||||
|
||||
static size_t scr1_size = (n_ctx>2048?1024u:512u)*1024*1024;
|
||||
static void * scr1 = malloc(scr1_size);
|
||||
|
||||
if (mem_per_token > 0 && mem_per_token * N > buf_size) {
|
||||
const size_t buf_size_new = 1.1 * (mem_per_token * N); // add 10% to account for ggml object overhead
|
||||
if(n_embd>=7168) //MPT 30B needs more scratch memory
|
||||
{
|
||||
scr0_size *= 2;
|
||||
scr1_size *= 2;
|
||||
}
|
||||
|
||||
static void * scr0;
|
||||
static void * scr1;
|
||||
if(use_scratch)
|
||||
{
|
||||
scr0 = malloc(scr0_size);
|
||||
scr1 = malloc(scr1_size);
|
||||
}
|
||||
|
||||
if (mem_per_token > 0 && mem_per_token * N *1.1 > buf_size) {
|
||||
const size_t buf_size_new = 64u*1024*1024 + 1.2 * (mem_per_token * N); // add 10% to account for ggml object overhead
|
||||
// printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__,
|
||||
// buf_size, buf_size_new);
|
||||
// reallocate
|
||||
buf_size = buf_size_new;
|
||||
buf = realloc(buf, buf_size);
|
||||
if (buf == nullptr) {
|
||||
fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
|
||||
return false;
|
||||
if (buf_size_new > buf_size)
|
||||
{
|
||||
buf_size = buf_size_new;
|
||||
buf = realloc(buf, buf_size);
|
||||
if (buf == nullptr) {
|
||||
fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -369,7 +385,9 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
|
|||
|
||||
struct ggml_tensor * cur;
|
||||
|
||||
if(use_scratch){
|
||||
ggml_set_scratch(ctx0, { 0, scr0_size, scr0, });
|
||||
}
|
||||
|
||||
// a = self.ln_1(x)
|
||||
{
|
||||
|
@ -465,7 +483,9 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
|
|||
|
||||
inpL = ggml_add(ctx0, inpL, cur);
|
||||
|
||||
if(use_scratch){
|
||||
ggml_set_scratch(ctx0, { 0, scr1_size, scr1, });
|
||||
}
|
||||
|
||||
// m = self.ln_2(x)
|
||||
{
|
||||
|
@ -491,7 +511,9 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
|
|||
inpL = ggml_add(ctx0, inpL, cur);
|
||||
}
|
||||
|
||||
if(use_scratch){
|
||||
ggml_set_scratch(ctx0, { 0, scr0_size, scr0, });
|
||||
}
|
||||
|
||||
// norm
|
||||
{
|
||||
|
@ -500,7 +522,9 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
|
|||
inpL = ggml_mul(ctx0, ggml_repeat(ctx0, model.norm_f_weight, inpL), inpL);
|
||||
}
|
||||
|
||||
if(use_scratch){
|
||||
ggml_set_scratch(ctx0, { 0, 0, nullptr, });
|
||||
}
|
||||
|
||||
// output embedding weight tied to input embedding
|
||||
inpL = ggml_mul_mat(ctx0, model.wte_weight, inpL);
|
||||
|
|
|
@ -400,7 +400,8 @@ bool gpt_neox_eval(
|
|||
const int n_past,
|
||||
const std::vector<gpt_vocab::id> & embd_inp,
|
||||
std::vector<float> & embd_w,
|
||||
size_t & mem_per_token) {
|
||||
size_t & mem_per_token,
|
||||
bool use_scratch=true) {
|
||||
const int N = embd_inp.size();
|
||||
|
||||
const auto & hparams = model.hparams;
|
||||
|
@ -418,13 +419,20 @@ bool gpt_neox_eval(
|
|||
// use 2 scratch buffers
|
||||
// TODO: very hacky solution - reimplement in a more elegant way
|
||||
static size_t scr0_size = (n_ctx>1024?512u:256u)*1024*1024;
|
||||
static void * scr0 = malloc(scr0_size);
|
||||
static void * scr0;
|
||||
|
||||
static size_t scr1_size = (n_ctx>1024?512u:256u)*1024*1024;
|
||||
static void * scr1 = malloc(scr1_size);
|
||||
static void * scr1;
|
||||
if(use_scratch)
|
||||
{
|
||||
scr0 = malloc(scr0_size);
|
||||
scr1 = malloc(scr1_size);
|
||||
}
|
||||
|
||||
if (mem_per_token > 0 && mem_per_token*N*1.05 > buf_size) {
|
||||
const size_t buf_size_new = 64u*1024*1024 + 1.15*(mem_per_token*N); // add 10% to account for ggml object overhead
|
||||
size_t scratch_needed_mem = mem_per_token*N;
|
||||
|
||||
if (mem_per_token > 0 && scratch_needed_mem*1.1 > buf_size) {
|
||||
const size_t buf_size_new = 64u*1024*1024 + 1.2*(scratch_needed_mem); // add 10% to account for ggml object overhead
|
||||
//printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
|
||||
|
||||
// reallocate
|
||||
|
@ -459,7 +467,9 @@ bool gpt_neox_eval(
|
|||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * cur;
|
||||
|
||||
if(use_scratch){
|
||||
ggml_set_scratch(ctx0, { 0, scr0_size, scr0, });
|
||||
}
|
||||
|
||||
// self-attention
|
||||
{
|
||||
|
@ -564,7 +574,9 @@ bool gpt_neox_eval(
|
|||
}
|
||||
}
|
||||
|
||||
if(use_scratch){
|
||||
ggml_set_scratch(ctx0, { 0, scr1_size, scr1, });
|
||||
}
|
||||
|
||||
if (hparams.par_res == 0) {
|
||||
struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpL);
|
||||
|
@ -588,7 +600,9 @@ bool gpt_neox_eval(
|
|||
}
|
||||
}
|
||||
|
||||
if(use_scratch){
|
||||
ggml_set_scratch(ctx0, { 0, scr0_size, scr0, });
|
||||
}
|
||||
|
||||
// norm
|
||||
{
|
||||
|
@ -602,7 +616,9 @@ bool gpt_neox_eval(
|
|||
ggml_repeat(ctx0, model.ln_f_b, inpL));
|
||||
}
|
||||
|
||||
if(use_scratch){
|
||||
ggml_set_scratch(ctx0, { 0, 0, nullptr, });
|
||||
}
|
||||
|
||||
// lm_head
|
||||
{
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue