[wip] open_llama_3b support

Not working, perplexity around 2000.
Henri Vasserman 2023-05-24 21:35:46 +03:00
parent ac7876ac20
commit ff99507049
2 changed files with 20 additions and 3 deletions


@@ -143,12 +143,22 @@ class Params:
     def guessed(model: 'LazyModel', file_type: GGMLFileType) -> 'Params':
         n_vocab, n_embd = model["tok_embeddings.weight"].shape
 
+        n_mult=256
+        n_head=n_embd // 128
+        n_layer=next(i for i in itertools.count() if f"layers.{i}.attention.wq.weight" not in model)
+
+        # TODO: hack for open_llama_3b
+        if n_embd == 3200:
+            n_mult = 108
+            n_head = 32
+            n_layer = 26
+
         return Params(
             n_vocab=n_vocab,
             n_embd=n_embd,
-            n_mult=256,
-            n_head=n_embd // 128,
-            n_layer=next(i for i in itertools.count() if f"layers.{i}.attention.wq.weight" not in model),
+            n_mult=n_mult,
+            n_head=n_head,
+            n_layer=n_layer,
             file_type=file_type,
         )
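Why 108: the GGML header does not store the feed-forward width, so the loader re-derives it from n_embd and n_mult. Assuming the derivation llama.cpp used at the time, n_ff = ((2*(4*n_embd)/3 + n_mult - 1)/n_mult)*n_mult, the default n_mult = 256 yields 8704 for n_embd = 3200, whereas open_llama_3b's actual FFN width is 8640. A minimal sketch of that check (calc_n_ff is a hypothetical helper, not part of the patch):

// Sketch: re-derive n_ff the way the loader does, under the formula
// above; only the n_mult = 108 choice reproduces 8640 at n_embd = 3200.
#include <cstdint>
#include <cstdio>

static uint32_t calc_n_ff(uint32_t n_embd, uint32_t n_mult) {
    return ((2*(4*n_embd)/3 + n_mult - 1)/n_mult)*n_mult;
}

int main() {
    printf("n_mult = 256 -> n_ff = %u\n", calc_n_ff(3200, 256)); // 8704: tensor shapes mismatch
    printf("n_mult = 108 -> n_ff = %u\n", calc_n_ff(3200, 108)); // 8640: matches open_llama_3b
    return 0;
}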


@@ -42,6 +42,7 @@
 // available llama models
 enum e_model {
     MODEL_UNKNOWN,
+    MODEL_3B,
     MODEL_7B,
     MODEL_13B,
     MODEL_30B,
@@ -58,6 +59,7 @@ static const size_t MB = 1024*1024;
 static const std::map<e_model, size_t> & MEM_REQ_SCRATCH0()
 {
     static std::map<e_model, size_t> k_sizes = {
+        { MODEL_3B,  128ull * MB },
         { MODEL_7B,  512ull * MB },
         { MODEL_13B, 512ull * MB },
         { MODEL_30B, 512ull * MB },
@@ -69,6 +71,7 @@ static const std::map<e_model, size_t> & MEM_REQ_SCRATCH0()
 static const std::map<e_model, size_t> & MEM_REQ_SCRATCH1()
 {
     static std::map<e_model, size_t> k_sizes = {
+        { MODEL_3B,  128ull * MB },
         { MODEL_7B,  512ull * MB },
         { MODEL_13B, 512ull * MB },
         { MODEL_30B, 512ull * MB },
@@ -81,6 +84,7 @@ static const std::map<e_model, size_t> & MEM_REQ_SCRATCH1()
 static const std::map<e_model, size_t> & MEM_REQ_KV_SELF()
 {
     static std::map<e_model, size_t> k_sizes = {
+        { MODEL_3B,   682ull * MB },
         { MODEL_7B,  1026ull * MB },
         { MODEL_13B, 1608ull * MB },
         { MODEL_30B, 3124ull * MB },
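As a sanity check on the new 682 MB entry: the f16 self-attention KV cache holds one key and one value vector per layer per context position. A rough back-of-the-envelope sketch with open_llama_3b's hyperparameters (the numbers are mine, not from the commit):

// Sketch: estimate the f16 KV cache size for the 3B model at the
// default 2048-token context (n_layer = 26, n_embd = 3200).
#include <cstdint>
#include <cstdio>

int main() {
    const uint64_t n_layer = 26, n_ctx = 2048, n_embd = 3200;
    const uint64_t bytes = 2 /* K and V */ * n_layer * n_ctx * n_embd * 2 /* sizeof f16 */;
    printf("%llu bytes ~= %llu MiB\n",
           (unsigned long long) bytes,
           (unsigned long long) (bytes / (1024ull*1024ull))); // 681574400 bytes = 650 MiB
    return 0;
}

That works out to 650 MiB exactly; 682 is consistent with the same byte count expressed in decimal megabytes (681574400 / 10^6 ≈ 682), so the table value appears to be the raw size divided by 10^6.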
@@ -94,6 +98,7 @@ static const std::map<e_model, size_t> & MEM_REQ_KV_SELF()
 static const std::map<e_model, size_t> & MEM_REQ_EVAL()
 {
     static std::map<e_model, size_t> k_sizes = {
+        { MODEL_3B,   512ull * MB },
         { MODEL_7B,   768ull * MB },
         { MODEL_13B, 1024ull * MB },
         { MODEL_30B, 1280ull * MB },
@@ -899,6 +904,7 @@ static const char *llama_ftype_name(enum llama_ftype ftype) {
 static const char *llama_model_type_name(e_model type) {
     switch (type) {
+        case MODEL_3B: return "3B";
         case MODEL_7B: return "7B";
         case MODEL_13B: return "13B";
         case MODEL_30B: return "30B";
@@ -932,6 +938,7 @@ static void llama_model_load_internal(
     {
         switch (hparams.n_layer) {
+            case 26: model.type = e_model::MODEL_3B; break;
             case 32: model.type = e_model::MODEL_7B; break;
             case 40: model.type = e_model::MODEL_13B; break;
             case 60: model.type = e_model::MODEL_30B; break;