Compare commits


20 commits

Author SHA1 Message Date
Georgi Gerganov
92a4f86879
llama : make starcoder graph build more consistent with others 2023-09-15 17:57:10 +03:00
Georgi Gerganov
f82328ab65
metal : fix out-of-bounds access in soft_max kernels 2023-09-15 17:56:49 +03:00
Meng Zhang
6c353dc7c2 cleanup useless code 2023-09-15 19:00:14 +08:00
Meng Zhang
a1cf66ea94 working in cpu, metal buggy 2023-09-15 18:45:43 +08:00
Meng Zhang
101c578715 add TBD 2023-09-15 15:23:50 +08:00
Meng Zhang
8bc76a225d add input embeddings handling 2023-09-15 14:47:04 +08:00
Meng Zhang
ab13d071e1 store mqa directly 2023-09-15 14:18:36 +08:00
Meng Zhang
4420cff654 fix vram calculation for starcoder 2023-09-15 13:52:43 +08:00
Meng Zhang
dac31da489 fix comments 2023-09-15 12:57:38 +08:00
Meng Zhang
0be15e162c fix head count kv 2023-09-15 12:56:20 +08:00
Meng Zhang
77c7ec179c properly load all starcoder params 2023-09-15 12:47:22 +08:00
Meng Zhang
2683611944 set n_positions to max_positioin_embeddings 2023-09-15 12:35:46 +08:00
Meng Zhang
a17ef39792 add max_position_embeddings 2023-09-15 12:35:17 +08:00
Meng Zhang
57f064d7c2 load starcoder weight 2023-09-15 12:12:33 +08:00
Meng Zhang
166a259f67 set head_count_kv = 1 2023-09-15 12:12:27 +08:00
Meng Zhang
7298c37e7e add LLM_ARCH_STARCODER to llama.cpp 2023-09-15 11:49:21 +08:00
Meng Zhang
7e0a843b6a fix ffn_down name 2023-09-15 11:45:18 +08:00
Meng Zhang
76d32cca59 convert MQA to MHA 2023-09-15 11:42:16 +08:00
Meng Zhang
eb7f0eba3e support convert starcoder weights to gguf 2023-09-15 11:24:24 +08:00
Meng Zhang
0c5d4d87b0 add placeholder of starcoder in gguf / llama.cpp 2023-09-15 10:39:47 +08:00
4 changed files with 679 additions and 29 deletions

convert-starcoder-hf-to-gguf.py (new executable file, 267 lines)

@@ -0,0 +1,267 @@
#!/usr/bin/env python3
# HF starcoder --> gguf conversion
from __future__ import annotations
import argparse
import json
import os
import struct
import sys
from pathlib import Path
from typing import Any
import numpy as np
import torch
from transformers import AutoTokenizer # type: ignore[import]
if 'NO_LOCAL_GGUF' not in os.environ:
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
import gguf
def bytes_to_unicode():
# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8+n)
n += 1
return dict(zip(bs, (chr(n) for n in cs)))
def count_model_parts(dir_model: Path) -> int:
num_parts = 0
for filename in os.listdir(dir_model):
if filename.startswith("pytorch_model-"):
num_parts += 1
if num_parts > 0:
print("gguf: found " + str(num_parts) + " model parts")
return num_parts
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Convert a StarCoder model to a GGML compatible file")
parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.bin)")
parser.add_argument("ftype", type=int, help="output format - use 0 for float32, 1 for float16", choices=[0, 1], default = 1)
return parser.parse_args()
args = parse_args()
dir_model = args.model
ftype = args.ftype
if not dir_model.is_dir():
print(f'Error: {args.model} is not a directory', file = sys.stderr)
sys.exit(1)
# possible tensor data types
# ftype == 0 -> float32
# ftype == 1 -> float16
# map from ftype to string
ftype_str = ["f32", "f16"]
if args.outfile is not None:
fname_out = args.outfile
else:
# output in the same directory as the model by default
fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'
print("gguf: loading model "+dir_model.name)
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
hparams = json.load(f)
if hparams["architectures"][0] != "GPTBigCodeForCausalLM":
print("Model architecture not supported: " + hparams["architectures"][0])
sys.exit(1)
# get number of model parts
num_parts = count_model_parts(dir_model)
ARCH=gguf.MODEL_ARCH.STARCODER
gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])
print("gguf: get model metadata")
block_count = hparams["n_layer"]
gguf_writer.add_name("StarCoder")
gguf_writer.add_context_length(2048) # not in config.json
gguf_writer.add_embedding_length(hparams["n_embd"])
gguf_writer.add_max_position_embeddings(hparams["n_positions"])
gguf_writer.add_feed_forward_length(4 * hparams["n_embd"])
gguf_writer.add_block_count(block_count)
gguf_writer.add_head_count(hparams["n_head"])
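# note: K/V heads are expanded from multi-query to full multi-head attention during the tensor conversion below,
# so the KV head count written here intentionally equals n_head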
gguf_writer.add_head_count_kv(hparams["n_head"])
gguf_writer.add_layer_norm_eps(hparams["layer_norm_epsilon"])
gguf_writer.add_file_type(ftype)
# TOKENIZATION
print("gguf: get tokenizer metadata")
tokens: list[bytearray] = []
scores: list[float] = []
toktypes: list[int] = []
tokenizer_json_file = dir_model / 'tokenizer.json'
if not tokenizer_json_file.is_file():
print(f'Error: Missing {tokenizer_json_file}', file = sys.stderr)
sys.exit(1)
# gpt2 tokenizer
gguf_writer.add_tokenizer_model("gpt2")
with open(tokenizer_json_file, "r", encoding="utf-8") as f:
tokenizer_json = json.load(f)
print("gguf: get gpt2 tokenizer vocab")
# The number of tokens in tokenizer.json can differ from the expected vocab size.
# This causes downstream issues with mismatched tensor sizes when running the inference
vocab_size = hparams["vocab_size"] if "vocab_size" in hparams else len(tokenizer_json["model"]["vocab"])
# ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py
tokenizer = AutoTokenizer.from_pretrained(dir_model)
reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}
for i in range(vocab_size):
if i in reverse_vocab:
try:
text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
except KeyError:
text = bytearray()
for c in reverse_vocab[i]:
if ord(c) < 256: # single byte character
text.append(byte_decoder[ord(c)])
else: # multibyte special token character
text.extend(c.encode('utf-8'))
else:
print(f"Key {i} not in tokenizer vocabulary. Padding with an arbitrary token.")
pad_token = f"[PAD{i}]".encode("utf8")
text = bytearray(pad_token)
tokens.append(text)
scores.append(0.0) # dummy
toktypes.append(gguf.TokenType.NORMAL) # dummy
gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores)
gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(dir_model, load_merges = True)
special_vocab.add_to_gguf(gguf_writer)
# TENSORS
tensor_map = gguf.get_tensor_name_map(ARCH,block_count)
# params for qkv transform
n_head = hparams["n_head"]
n_head_kv = hparams["n_head_kv"] if "n_head_kv" in hparams else 1
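# GPTBigCode checkpoints use multi-query attention (one shared K/V head), so "n_head_kv" is normally absent and defaults to 1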
head_dim = hparams["n_embd"] // n_head
# tensor info
print("gguf: get tensor metadata")
if num_parts == 0:
part_names = iter(("pytorch_model.bin",))
else:
part_names = (
f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
)
for part_name in part_names:
if args.vocab_only:
break
print("gguf: loading model part '" + part_name + "'")
model_part = torch.load(dir_model / part_name, map_location="cpu")
for name in model_part.keys():
data = model_part[name]
old_dtype = data.dtype
# convert any unsupported data types to float32
if data.dtype != torch.float16 and data.dtype != torch.float32:
data = data.to(torch.float32)
data = data.squeeze().numpy()
if name.endswith(".attn.c_attn.weight") or name.endswith(".attn.c_attn.bias"):
print("Duplicate K,V heads to use MHA instead of MQA for", name)
embed_dim = hparams["n_embd"]
head_dim = embed_dim // hparams["n_head"]
# ((n_heads + 2) * head_dim, hidden_dim) -> (3 * n_heads * head_dim, hidden_dim)
q, k ,v = np.split(data, (hparams["n_head"] * head_dim, (hparams["n_head"] + 1) * head_dim), axis=0)
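# e.g. with n_embd = 2048 and n_head = 16 (head_dim = 128), data is (2304, 2048):
# q -> (2048, 2048), k -> (128, 2048), v -> (128, 2048); after tiling and concatenation the result is (6144, 2048)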
# duplicate k, v along the first axis (head_dim, hidden_dim) -> (n_heads * head_dim, hidden_dim)
if len(k.shape) == 2:
k = np.tile(k, (hparams["n_head"], 1))
v = np.tile(v, (hparams["n_head"], 1))
elif len(k.shape) == 1:
k = np.tile(k, (hparams["n_head"]))
v = np.tile(v, (hparams["n_head"]))
# concat q, k, v along the first axis (n_heads * head_dim, hidden_dim) -> (3 * n_heads * head_dim, hidden_dim)
data = np.concatenate((q, k, v), axis=0)
# map tensor names
new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
if new_name is None:
print("Can not map tensor '" + name + "'")
sys.exit()
n_dims = len(data.shape)
data_dtype = data.dtype
# if f32 desired, convert any float16 to float32
if ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)
# TODO: Why can't we use these float16 values as-is? There should be no reason to store float16 as float32
if ftype == 1 and data_dtype == np.float16 and n_dims == 1:
data = data.astype(np.float32)
# if f16 desired, convert any float32 2-dim weight tensors to float16
if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
data = data.astype(np.float16)
print(name, "=>", new_name + ", shape = " + str(data.shape) + ", " + str(old_dtype) + " --> " + str(data.dtype))
gguf_writer.add_tensor(new_name, data)
print("gguf: write header")
gguf_writer.write_header_to_file()
print("gguf: write metadata")
gguf_writer.write_kv_data_to_file()
if not args.vocab_only:
print("gguf: write tensors")
gguf_writer.write_tensors_to_file()
gguf_writer.close()
print(f"gguf: model successfully exported to '{fname_out}'")
print("")

ggml-metal.metal

@@ -118,7 +118,7 @@ kernel void kernel_soft_max(
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
// parallel max
- float lmax = psrc0[tpitg[0]];
+ float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY;
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
lmax = MAX(lmax, psrc0[i00]);
}
@@ -158,7 +158,7 @@ kernel void kernel_soft_max_4(
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
// parallel max
- float4 lmax4 = psrc4[tpitg[0]];
+ float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY;
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
lmax4 = fmax(lmax4, psrc4[i00]);
}

gguf-py/gguf/gguf.py

@@ -36,12 +36,13 @@ KEY_GENERAL_SOURCE_HF_REPO = "general.source.hugginface.repository"
KEY_GENERAL_FILE_TYPE = "general.file_type"
# LLM
KEY_CONTEXT_LENGTH = "{arch}.context_length"
KEY_EMBEDDING_LENGTH = "{arch}.embedding_length"
KEY_BLOCK_COUNT = "{arch}.block_count"
KEY_FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"
KEY_USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
KEY_TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
KEY_MAX_POSITION_EMBEDDINGS = "{arch}.max_position_embeddings"
# attention
KEY_ATTENTION_HEAD_COUNT = "{arch}.attention.head_count"
@@ -77,13 +78,14 @@ KEY_TOKENIZER_RWKV = "tokenizer.rwkv.world"
class MODEL_ARCH(IntEnum):
LLAMA : int = auto()
FALCON : int = auto()
BAICHUAN : int = auto()
GPT2 : int = auto()
GPTJ : int = auto()
GPTNEOX : int = auto()
MPT : int = auto()
STARCODER : int = auto()
class MODEL_TENSOR(IntEnum):
@@ -107,13 +109,14 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.LLAMA: "llama",
MODEL_ARCH.FALCON: "falcon",
MODEL_ARCH.BAICHUAN: "baichuan",
MODEL_ARCH.GPT2: "gpt2",
MODEL_ARCH.GPTJ: "gptj",
MODEL_ARCH.GPTNEOX: "gptneox",
MODEL_ARCH.MPT: "mpt",
MODEL_ARCH.STARCODER: "starcoder",
}
MODEL_TENSOR_NAMES: dict[MODEL_ARCH, dict[MODEL_TENSOR, str]] = {
@@ -171,6 +174,18 @@ MODEL_TENSOR_NAMES: dict[MODEL_ARCH, dict[MODEL_TENSOR, str]] = {
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
},
MODEL_ARCH.STARCODER: {
MODEL_TENSOR.TOKEN_EMBD: "token_embd",
MODEL_TENSOR.POS_EMBD: "position_embd",
MODEL_TENSOR.OUTPUT_NORM: "output_norm",
MODEL_TENSOR.OUTPUT: "output",
MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm",
MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv",
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
},
MODEL_ARCH.GPT2: {
# TODO
},
@@ -703,6 +718,10 @@ class GGUFWriter:
self.add_uint32(
KEY_EMBEDDING_LENGTH.format(arch=self.arch), length)
def add_max_position_embeddings(self, length: int):
self.add_uint32(
KEY_MAX_POSITION_EMBEDDINGS.format(arch=self.arch), length)
def add_block_count(self, length: int):
self.add_uint32(
KEY_BLOCK_COUNT.format(arch=self.arch), length)

llama.cpp (378 changed lines)

@@ -160,17 +160,19 @@ enum llm_arch {
LLM_ARCH_GPTJ,
LLM_ARCH_GPTNEOX,
LLM_ARCH_MPT,
LLM_ARCH_STARCODER,
LLM_ARCH_UNKNOWN,
};
static std::map<llm_arch, std::string> LLM_ARCH_NAMES = {
{ LLM_ARCH_LLAMA, "llama" },
{ LLM_ARCH_FALCON, "falcon" },
{ LLM_ARCH_GPT2, "gpt2" },
{ LLM_ARCH_GPTJ, "gptj" },
{ LLM_ARCH_GPTNEOX, "gptneox" },
{ LLM_ARCH_MPT, "mpt" },
{ LLM_ARCH_BAICHUAN, "baichuan" },
{ LLM_ARCH_STARCODER, "starcoder" },
};
enum llm_kv {
@@ -191,6 +193,7 @@ enum llm_kv {
LLM_KV_FEED_FORWARD_LENGTH,
LLM_KV_USE_PARALLEL_RESIDUAL,
LLM_KV_TENSOR_DATA_LAYOUT,
LLM_KV_MAX_POSITION_EMBEDDINGS,
LLM_KV_ATTENTION_HEAD_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -235,6 +238,7 @@ static std::map<llm_kv, std::string> LLM_KV_NAMES = {
{ LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" },
{ LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" },
{ LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" },
{ LLM_KV_MAX_POSITION_EMBEDDINGS, "%s.max_position_embeddings" },
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@@ -376,6 +380,21 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
},
},
{
LLM_ARCH_STARCODER,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_POS_EMBD, "position_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
},
},
{
LLM_ARCH_UNKNOWN,
{
@@ -895,6 +914,7 @@ static llama_state g_state;
// available llama models
enum e_model {
MODEL_UNKNOWN,
MODEL_1B,
MODEL_3B,
MODEL_7B,
MODEL_13B,
@@ -919,6 +939,7 @@ struct llama_hparams {
uint32_t n_layer = 32;
uint32_t n_rot = 64;
uint32_t n_ff = 11008;
uint32_t n_positions = 0; // StarCoder
float f_norm_eps = 1e-5;
float f_norm_rms_eps = 1e-5;
@@ -966,13 +987,22 @@ struct llama_layer {
struct ggml_tensor * wo;
struct ggml_tensor * wqkv;
// attention bias
struct ggml_tensor * bo;
struct ggml_tensor * bqkv;
// normalization
struct ggml_tensor * ffn_norm;
struct ggml_tensor * ffn_norm_b;
// ff
struct ggml_tensor * w1; // ffn_gate
struct ggml_tensor * w2; // ffn_down
struct ggml_tensor * w3; // ffn_up
// ff bias
struct ggml_tensor * b2; // ffn_down
struct ggml_tensor * b3; // ffn_up
};
struct llama_kv_cache {
@@ -1050,6 +1080,7 @@ struct llama_model {
llama_vocab vocab;
struct ggml_tensor * tok_embeddings;
struct ggml_tensor * pos_embeddings;
struct ggml_tensor * output_norm;
struct ggml_tensor * output_norm_b;
@@ -1634,6 +1665,7 @@ static void llm_load_hparams(
GGUF_GET_KEY(ctx, hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH));
GGUF_GET_KEY(ctx, hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT));
GGUF_GET_KEY(ctx, hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT));
GGUF_GET_KEY(ctx, hparams.n_positions, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_MAX_POSITION_EMBEDDINGS));
// n_head_kv is optional, default to n_head
hparams.n_head_kv = hparams.n_head;
@@ -1713,6 +1745,14 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_STARCODER:
{
GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));
switch (hparams.n_layer) {
case 24: model.type = e_model::MODEL_1B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
default: (void)0;
};
@@ -2166,6 +2206,85 @@ static void llm_load_tensors(
}
}
} break;
case LLM_ARCH_STARCODER:
{
model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
model.pos_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_positions}, GGML_BACKEND_CPU);
// output
{
ggml_backend backend_norm;
ggml_backend backend_output;
if (n_gpu_layers > int(n_layer)) {
// norm is not performance relevant on its own but keeping it in VRAM reduces data copying
// on Windows however this is detrimental unless everything is on the GPU
#ifndef _WIN32
backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
#else
backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
#endif // _WIN32
backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
} else {
backend_norm = GGML_BACKEND_CPU;
backend_output = GGML_BACKEND_CPU;
}
model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm);
model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm);
model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output);
if (backend_norm == GGML_BACKEND_GPU) {
vram_weights += ggml_nbytes(model.output_norm);
vram_weights += ggml_nbytes(model.output_norm_b);
}
if (backend_output == GGML_BACKEND_GPU_SPLIT) {
vram_weights += ggml_nbytes(model.output);
}
}
const uint32_t n_ff = hparams.n_ff;
const int i_gpu_start = n_layer - n_gpu_layers;
model.layers.resize(n_layer);
for (uint32_t i = 0; i < n_layer; ++i) {
const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
auto & layer = model.layers[i];
layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend);
layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3*n_embd}, backend_split);
layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {3*n_embd}, backend_split);
layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend_split);
layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend);
layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split);
layer.b2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend_split);
layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
layer.b3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend_split);
if (backend == GGML_BACKEND_GPU) {
vram_weights +=
ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.attn_norm_b) +
ggml_nbytes(layer.wqkv) + ggml_nbytes(layer.bqkv) +
ggml_nbytes(layer.wo) + ggml_nbytes(layer.bo) +
ggml_nbytes(layer.ffn_norm) + ggml_nbytes(layer.ffn_norm_b) +
ggml_nbytes(layer.w2) + ggml_nbytes(layer.b2) +
ggml_nbytes(layer.w3) + ggml_nbytes(layer.b3);
}
}
} break;
default:
throw std::runtime_error("unknown architecture");
};
@@ -3305,6 +3424,247 @@ static struct ggml_cgraph * llm_build_falcon(
return gf;
}
static struct ggml_cgraph * llm_build_starcoder(
llama_context & lctx,
const llama_token * tokens,
const float * embd,
int n_tokens,
int n_past) {
GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT
const int N = n_tokens;
const auto & model = lctx.model;
const auto & hparams = model.hparams;
const auto & kv_self = lctx.kv_self;
GGML_ASSERT(!!kv_self.ctx);
const int64_t n_embd = hparams.n_embd;
const int64_t n_layer = hparams.n_layer;
const int64_t n_ctx = hparams.n_ctx;
const int64_t n_head = hparams.n_head;
const int64_t n_head_kv = hparams.n_head_kv;
const int64_t n_embd_head = hparams.n_embd_head();
const int64_t n_embd_gqa = hparams.n_embd_gqa();
GGML_ASSERT(n_embd_head == hparams.n_rot);
const float norm_eps = hparams.f_norm_eps;
auto & buf_compute = lctx.buf_compute;
struct ggml_init_params params = {
/*.mem_size =*/ buf_compute.size,
/*.mem_buffer =*/ buf_compute.data,
/*.no_alloc =*/ false,
};
params.no_alloc = true;
struct ggml_context * ctx0 = ggml_init(params);
ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_tensor * cur;
struct ggml_tensor * token;
struct ggml_tensor * position;
struct ggml_tensor * inpL;
if (tokens) {
struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
ggml_allocr_alloc(lctx.alloc, inp_tokens);
if (!ggml_allocr_is_measure(lctx.alloc)) {
memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
}
ggml_set_name(inp_tokens, "inp_tokens");
token = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
} else {
#ifdef GGML_USE_MPI
GGML_ASSERT(false && "not implemented");
#endif
token = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);
ggml_allocr_alloc(lctx.alloc, token);
if (!ggml_allocr_is_measure(lctx.alloc)) {
memcpy(token->data, embd, N * n_embd * ggml_element_size(inpL));
}
}
{
// Compute position embeddings.
struct ggml_tensor * inp_positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
ggml_allocr_alloc(lctx.alloc, inp_positions);
if (!ggml_allocr_is_measure(lctx.alloc)) {
for (int i = 0; i < N; ++i) {
((int32_t *) inp_positions->data)[i] = n_past + i;
}
}
ggml_set_name(inp_positions, "inp_positions");
position = ggml_get_rows(ctx0, model.pos_embeddings, inp_positions);
}
struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
ggml_allocr_alloc(lctx.alloc, KQ_scale);
if (!ggml_allocr_is_measure(lctx.alloc)) {
ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
}
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
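// input is the sum of the token embedding and a learned absolute position embedding (GPT-2 style)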
inpL = ggml_add(ctx0, token, position);
ggml_set_name(inpL, "inpL");
for (int il = 0; il < n_layer; ++il) {
{
// Norm
cur = ggml_norm(ctx0, inpL, norm_eps);
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].attn_norm), model.layers[il].attn_norm_b);
}
{
// Self Attention
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wqkv, cur), model.layers[il].bqkv);
struct ggml_tensor * tmpq = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0*sizeof(float)*n_embd);
struct ggml_tensor * tmpk = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1*sizeof(float)*n_embd);
struct ggml_tensor * tmpv = ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2*sizeof(float)*n_embd);
struct ggml_tensor * Qcur = tmpq;
struct ggml_tensor * Kcur = tmpk;
{
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, N));
ggml_set_name(Vcur, "Vcur");
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past));
ggml_set_name(k, "k");
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa,
( n_ctx)*ggml_element_size(kv_self.v),
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
}
struct ggml_tensor * Q =
ggml_permute(ctx0,
ggml_cpy(ctx0,
Qcur,
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd/n_head, n_head, N)),
0, 2, 1, 3);
ggml_set_name(Q, "Q");
struct ggml_tensor * K =
ggml_view_3d(ctx0, kv_self.k,
n_embd_head, n_past + N, n_head_kv,
ggml_element_size(kv_self.k)*n_embd_gqa,
ggml_element_size(kv_self.k)*n_embd_head,
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
ggml_set_name(K, "K");
// K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
ggml_set_name(KQ, "KQ");
// KQ_scaled = KQ / sqrt(n_embd_head)
// KQ_scaled shape [n_past + N, N, n_head, 1]
struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
ggml_set_name(KQ_scaled, "KQ_scaled");
// KQ_masked = mask_past(KQ_scaled)
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
ggml_set_name(KQ_masked, "KQ_masked");
// KQ = soft_max(KQ_masked)
struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
ggml_set_name(KQ_soft_max, "KQ_soft_max");
// split cached V into n_head heads
struct ggml_tensor * V =
ggml_view_3d(ctx0, kv_self.v,
n_past + N, n_embd_head, n_head_kv,
ggml_element_size(kv_self.v)*n_ctx,
ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
ggml_set_name(V, "V");
#if 1
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
ggml_set_name(KQV, "KQV");
#else
// make V contiguous in memory to speed up the matmul, however we waste time on the copy
// on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation
// is there a better way?
struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd_head, n_head));
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
#endif
// KQV_merged = KQV.permute(0, 2, 1, 3)
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
ggml_set_name(KQV_merged, "KQV_merged");
// cur = KQV_merged.contiguous().view(n_embd, N)
cur = ggml_cpy(ctx0,
KQV_merged,
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
ggml_set_name(cur, "KQV_merged_contiguous");
}
// Projection
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wo, cur), model.layers[il].bo);
// add the input
cur = ggml_add(ctx0, cur, inpL);
struct ggml_tensor * inpFF = cur;
// FF
{
// norm
{
cur = ggml_norm(ctx0, inpFF, norm_eps);
// cur = ln_2_g*cur + ln_2_b
// [ 768, N]
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ffn_norm), model.layers[il].ffn_norm_b);
}
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w3, cur), model.layers[il].b3);
// GELU activation
cur = ggml_gelu(ctx0, cur);
// projection
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w2, cur), model.layers[il].b2);
}
inpL = ggml_add(ctx0, cur, inpFF);
}
// norm
{
cur = ggml_norm(ctx0, inpL, norm_eps);
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.output_norm), model.output_norm_b);
}
ggml_set_name(cur, "result_norm");
cur = ggml_mul_mat(ctx0, model.output, cur);
ggml_set_name(cur, "result_output");
ggml_build_forward_expand(gf, cur);
ggml_free(ctx0);
// norm
return gf;
}
static struct ggml_cgraph * llama_build_graph(
llama_context & lctx,
const llama_token * tokens,
@@ -3328,6 +3688,10 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm_build_falcon(lctx, tokens, embd, n_tokens, n_past);
} break;
case LLM_ARCH_STARCODER:
{
result = llm_build_starcoder(lctx, tokens, embd, n_tokens, n_past);
} break;
default:
GGML_ASSERT(false);
};