Merge branch 'master' into stablelm-support

This commit is contained in:
Galunid 2023-10-22 21:33:53 +02:00
commit cf5eff36ae
16 changed files with 98 additions and 45 deletions

View file

@ -101,7 +101,7 @@ as the main playground for developing new features for the [ggml](https://github
- Python: [abetlen/llama-cpp-python](https://github.com/abetlen/llama-cpp-python) - Python: [abetlen/llama-cpp-python](https://github.com/abetlen/llama-cpp-python)
- Go: [go-skynet/go-llama.cpp](https://github.com/go-skynet/go-llama.cpp) - Go: [go-skynet/go-llama.cpp](https://github.com/go-skynet/go-llama.cpp)
- Node.js: [withcatai/node-llama-cpp](https://github.com/withcatai/node-llama-cpp), [hlhr202/llama-node](https://github.com/hlhr202/llama-node) - Node.js: [withcatai/node-llama-cpp](https://github.com/withcatai/node-llama-cpp)
- Ruby: [yoshoku/llama_cpp.rb](https://github.com/yoshoku/llama_cpp.rb) - Ruby: [yoshoku/llama_cpp.rb](https://github.com/yoshoku/llama_cpp.rb)
- Rust: [mdrokz/rust-llama.cpp](https://github.com/mdrokz/rust-llama.cpp) - Rust: [mdrokz/rust-llama.cpp](https://github.com/mdrokz/rust-llama.cpp)
- C#/.NET: [SciSharp/LLamaSharp](https://github.com/SciSharp/LLamaSharp) - C#/.NET: [SciSharp/LLamaSharp](https://github.com/SciSharp/LLamaSharp)

View file

@ -632,6 +632,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
process_escapes(params.prompt); process_escapes(params.prompt);
process_escapes(params.input_prefix); process_escapes(params.input_prefix);
process_escapes(params.input_suffix); process_escapes(params.input_suffix);
process_escapes(sparams.cfg_negative_prompt);
for (auto & antiprompt : params.antiprompt) { for (auto & antiprompt : params.antiprompt) {
process_escapes(antiprompt); process_escapes(antiprompt);
} }

View file

@ -230,7 +230,7 @@ gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores) gguf_writer.add_token_scores(scores)
gguf_writer.add_token_types(toktypes) gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(dir_model) special_vocab = gguf.SpecialVocab(dir_model, n_vocab = len(tokens))
special_vocab.add_to_gguf(gguf_writer) special_vocab.add_to_gguf(gguf_writer)
# TENSORS # TENSORS

View file

@ -129,7 +129,7 @@ gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores) gguf_writer.add_token_scores(scores)
gguf_writer.add_token_types(toktypes) gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(dir_model, load_merges=True) special_vocab = gguf.SpecialVocab(dir_model, load_merges=True, n_vocab = len(tokens))
special_vocab.add_to_gguf(gguf_writer) special_vocab.add_to_gguf(gguf_writer)
# TENSORS # TENSORS

View file

@ -152,7 +152,7 @@ gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores) gguf_writer.add_token_scores(scores)
gguf_writer.add_token_types(toktypes) gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(dir_model, load_merges = True) special_vocab = gguf.SpecialVocab(dir_model, load_merges = True, n_vocab = len(tokens))
special_vocab.add_to_gguf(gguf_writer) special_vocab.add_to_gguf(gguf_writer)
# TENSORS # TENSORS

View file

@ -134,7 +134,7 @@ gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores) gguf_writer.add_token_scores(scores)
gguf_writer.add_token_types(toktypes) gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(dir_model, load_merges = True) special_vocab = gguf.SpecialVocab(dir_model, load_merges = True, n_vocab = len(tokens))
special_vocab.add_to_gguf(gguf_writer) special_vocab.add_to_gguf(gguf_writer)
# TENSORS # TENSORS

View file

@ -388,7 +388,9 @@ def handle_metadata(cfg, hp):
cfg.vocab_dir if cfg.vocab_dir is not None else cfg.model_metadata_dir, cfg.vocab_dir if cfg.vocab_dir is not None else cfg.model_metadata_dir,
cfg.vocabtype ) cfg.vocabtype )
# FIXME: Respect cfg.vocab_dir? # FIXME: Respect cfg.vocab_dir?
svocab = gguf.SpecialVocab(cfg.model_metadata_dir) svocab = gguf.SpecialVocab(cfg.model_metadata_dir,
load_merges = cfg.vocabtype == 'bpe',
n_vocab = vocab.vocab_size)
convert.check_vocab_size(params, vocab) convert.check_vocab_size(params, vocab)
return (params, vocab, svocab) return (params, vocab, svocab)

View file

@ -128,18 +128,25 @@ vocab_size = hparams["vocab_size"]
# ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py # ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py
tokenizer = AutoTokenizer.from_pretrained(dir_model) tokenizer = AutoTokenizer.from_pretrained(dir_model)
added_vocab = tokenizer.get_added_vocab()
reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()} reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
for i in range(vocab_size): for i in range(vocab_size):
tokens.append(reverse_vocab[i] if i in reverse_vocab else f"[PAD{i}]") if i not in reverse_vocab:
scores.append(0.0) # dummy tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.NORMAL) toktypes.append(gguf.TokenType.USER_DEFINED)
elif reverse_vocab[i] in added_vocab:
# NOTE: wouldn't we like to distinguish CONTROL tokens here?
tokens.append(reverse_vocab[i])
toktypes.append(gguf.TokenType.USER_DEFINED)
else:
tokens.append(reverse_vocab[i])
toktypes.append(gguf.TokenType.NORMAL)
gguf_writer.add_token_list(tokens) gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores)
gguf_writer.add_token_types(toktypes) gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(dir_model, load_merges = True) special_vocab = gguf.SpecialVocab(dir_model, load_merges = True, n_vocab = len(tokens))
special_vocab.add_to_gguf(gguf_writer) special_vocab.add_to_gguf(gguf_writer)
# TENSORS # TENSORS

View file

@ -150,7 +150,7 @@ gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores) gguf_writer.add_token_scores(scores)
gguf_writer.add_token_types(toktypes) gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(dir_model, load_merges=True) special_vocab = gguf.SpecialVocab(dir_model, load_merges=True, n_vocab = len(tokens))
special_vocab.add_to_gguf(gguf_writer) special_vocab.add_to_gguf(gguf_writer)
# TENSORS # TENSORS

View file

@ -122,7 +122,7 @@ gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores) gguf_writer.add_token_scores(scores)
gguf_writer.add_token_types(toktypes) gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(dir_model, load_merges = True) special_vocab = gguf.SpecialVocab(dir_model, load_merges = True, n_vocab = len(tokens))
special_vocab.add_to_gguf(gguf_writer) special_vocab.add_to_gguf(gguf_writer)
# TENSORS # TENSORS

View file

@ -369,7 +369,7 @@ class SentencePieceVocab:
expected_ids = list(range(vocab_size, vocab_size + len(added_tokens))) expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
actual_ids = sorted(added_tokens.values()) actual_ids = sorted(added_tokens.values())
if expected_ids != actual_ids: if expected_ids != actual_ids:
raise Exception(f"Expected added token IDs to be sequential and start at {len(added_tokens)}; got {actual_ids}") raise Exception(f"Expected added token IDs to be sequential and start at {vocab_size}; got {actual_ids}")
items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1]) items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1])
self.added_tokens_list = [text for (text, idx) in items] self.added_tokens_list = [text for (text, idx) in items]
@ -1163,10 +1163,13 @@ def main(args_in: list[str] | None = None) -> None:
vocab: Vocab vocab: Vocab
if args.vocab_only: if args.vocab_only:
assert args.outfile, "need --outfile if using --vocab-only" if not args.outfile:
raise ValueError("need --outfile if using --vocab-only")
# FIXME: Try to respect vocab_dir somehow? # FIXME: Try to respect vocab_dir somehow?
vocab = load_vocab(args.vocab_dir or args.model, args.vocabtype) vocab = load_vocab(args.vocab_dir or args.model, args.vocabtype)
special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent, load_merges = args.vocabtype == 'bpe') special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent,
load_merges = args.vocabtype == 'bpe',
n_vocab = vocab.vocab_size)
outfile = args.outfile outfile = args.outfile
OutputFile.write_vocab_only(outfile, params, vocab, special_vocab) OutputFile.write_vocab_only(outfile, params, vocab, special_vocab)
print(f"Wrote {outfile}") print(f"Wrote {outfile}")
@ -1178,7 +1181,9 @@ def main(args_in: list[str] | None = None) -> None:
vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent
vocab = load_vocab(vocab_dir, args.vocabtype) vocab = load_vocab(vocab_dir, args.vocabtype)
# FIXME: Try to respect vocab_dir somehow? # FIXME: Try to respect vocab_dir somehow?
special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent, load_merges = args.vocabtype == 'bpe') special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent,
load_merges = args.vocabtype == 'bpe',
n_vocab = vocab.vocab_size)
model = model_plus.model model = model_plus.model
model = convert_model_names(model, params) model = convert_model_names(model, params)

View file

@ -761,6 +761,9 @@ int main(int argc, char ** argv) {
n_consumed = embd_inp.size(); n_consumed = embd_inp.size();
embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end()); embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
} }
if (params.escape) {
process_escapes(buffer);
}
const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true); const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
const auto line_inp = ::llama_tokenize(ctx, buffer, false, false); const auto line_inp = ::llama_tokenize(ctx, buffer, false, false);

View file

@ -1005,12 +1005,15 @@ class SpecialVocab:
merges: list[str] = [] merges: list[str] = []
special_token_types: tuple[str, ...] = ('bos', 'eos', 'unk', 'sep', 'pad') special_token_types: tuple[str, ...] = ('bos', 'eos', 'unk', 'sep', 'pad')
special_token_ids: dict[str, int] = {} special_token_ids: dict[str, int] = {}
n_vocab: int | None = None
def __init__( def __init__(
self, path: str | os.PathLike[str], load_merges: bool = False, self, path: str | os.PathLike[str], load_merges: bool = False,
special_token_types: tuple[str, ...] | None = None, special_token_types: tuple[str, ...] | None = None,
n_vocab: int | None = None,
): ):
self.special_token_ids = {} self.special_token_ids = {}
self.n_vocab = n_vocab
self.load_merges = load_merges self.load_merges = load_merges
if special_token_types is not None: if special_token_types is not None:
self.special_token_types = special_token_types self.special_token_types = special_token_types
@ -1020,6 +1023,16 @@ class SpecialVocab:
if not self._try_load_from_tokenizer_json(path): if not self._try_load_from_tokenizer_json(path):
self._try_load_from_config_json(path) self._try_load_from_config_json(path)
def _set_special_token(self, typ: str, tid: Any):
if not isinstance(tid, int) or tid < 0:
return
if self.n_vocab is None or tid < self.n_vocab:
self.special_token_ids[typ] = tid
return
print(f'gguf: WARNING: Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping',
file = sys.stderr)
def _try_load_from_tokenizer_json(self, path: Path) -> bool: def _try_load_from_tokenizer_json(self, path: Path) -> bool:
tokenizer_file = path / 'tokenizer.json' tokenizer_file = path / 'tokenizer.json'
if not tokenizer_file.is_file(): if not tokenizer_file.is_file():
@ -1047,10 +1060,11 @@ class SpecialVocab:
tc_content = entry_content tc_content = entry_content
else: else:
continue continue
for maybe_token_id in (atok.get('id') for atok in added_tokens if atok.get('content') == tc_content): # We only need the first match here.
if isinstance(maybe_token_id, int) and maybe_token_id >= 0: maybe_token_id = next((
self.special_token_ids[typ] = maybe_token_id atok.get('id') for atok in added_tokens
break if atok.get('content') == tc_content), None)
self._set_special_token(typ, maybe_token_id)
return True return True
def _try_load_from_config_json(self, path: Path) -> bool: def _try_load_from_config_json(self, path: Path) -> bool:
@ -1060,21 +1074,21 @@ class SpecialVocab:
with open(config_file, encoding = 'utf-8') as f: with open(config_file, encoding = 'utf-8') as f:
config = json.load(f) config = json.load(f)
for typ in self.special_token_types: for typ in self.special_token_types:
maybe_token_id = config.get(f'{typ}_token_id') self._set_special_token(typ, config.get(f'{typ}_token_id'))
if isinstance(maybe_token_id, int) and maybe_token_id >= 0:
self.special_token_ids[typ] = maybe_token_id
return True return True
def add_to_gguf(self, gw: GGUFWriter) -> None: def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None:
if len(self.merges) > 0: if len(self.merges) > 0:
print(f'gguf: Adding {len(self.merges)} merge(s).') if not quiet:
print(f'gguf: Adding {len(self.merges)} merge(s).')
gw.add_token_merges(self.merges) gw.add_token_merges(self.merges)
for typ, tokid in self.special_token_ids.items(): for typ, tokid in self.special_token_ids.items():
handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None) handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None)
if handler is None: if handler is None:
print(f'gguf: WARNING: No handler for special token type {typ} with id {tokid} - skipping') print(f'gguf: WARNING: No handler for special token type {typ} with id {tokid} - skipping', file = sys.stderr)
continue continue
print(f'gguf: Setting special token type {typ} to {tokid}') if not quiet:
print(f'gguf: Setting special token type {typ} to {tokid}')
handler(tokid) handler(tokid)
def __repr__(self) -> str: def __repr__(self) -> str:

View file

@ -996,14 +996,15 @@ static void llama_nop(struct ggml_tensor * tensor) { // don't offload by default
(void) tensor; (void) tensor;
} }
static std::string llama_token_to_str(const struct llama_context * ctx, llama_token token) { static std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
std::vector<char> result(8, 0); std::vector<char> result(8, 0);
const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size()); const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
if (n_tokens < 0) { if (n_tokens < 0) {
result.resize(-n_tokens); result.resize(-n_tokens);
int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size()); int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
GGML_ASSERT(check == -n_tokens); GGML_ASSERT(check == -n_tokens);
} else { }
else {
result.resize(n_tokens); result.resize(n_tokens);
} }
@ -1223,10 +1224,10 @@ struct llama_vocab {
id special_eot_id = 32010; id special_eot_id = 32010;
int find_bpe_rank(std::string token_left, std::string token_right) const { int find_bpe_rank(std::string token_left, std::string token_right) const {
replace_all(token_left, " ", "\u0120"); GGML_ASSERT(token_left.find(" ") == std::string::npos);
replace_all(token_left, "\n", "\u010A"); GGML_ASSERT(token_left.find("\n") == std::string::npos);
replace_all(token_right, " ", "\u0120"); GGML_ASSERT(token_right.find(" ") == std::string::npos);
replace_all(token_right, "\n", "\u010A"); GGML_ASSERT(token_right.find("\n") == std::string::npos);
auto it = bpe_ranks.find(std::make_pair(token_left, token_right)); auto it = bpe_ranks.find(std::make_pair(token_left, token_right));
if (it == bpe_ranks.end()) { if (it == bpe_ranks.end()) {
@ -2269,15 +2270,35 @@ static void llm_load_vocab(
if (vocab.type == LLAMA_VOCAB_TYPE_SPM) { if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
vocab.linefeed_id = llama_byte_to_token(vocab, '\n'); vocab.linefeed_id = llama_byte_to_token(vocab, '\n');
} else { } else {
vocab.linefeed_id = llama_tokenize_internal(vocab, "\u010A", false)[0]; const std::vector<int> ids = llama_tokenize_internal(vocab, "\u010A", false);
GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
vocab.linefeed_id = ids[0];
} }
// special tokens // special tokens
GGUF_GET_KEY(ctx, vocab.special_bos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_BOS_ID)); {
GGUF_GET_KEY(ctx, vocab.special_eos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_EOS_ID)); const std::vector<std::pair<enum llm_kv, int32_t &>> special_token_types = {
GGUF_GET_KEY(ctx, vocab.special_unk_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_UNK_ID)); { LLM_KV_TOKENIZER_BOS_ID, vocab.special_bos_id },
GGUF_GET_KEY(ctx, vocab.special_sep_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_SEP_ID)); { LLM_KV_TOKENIZER_EOS_ID, vocab.special_eos_id },
GGUF_GET_KEY(ctx, vocab.special_pad_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_PAD_ID)); { LLM_KV_TOKENIZER_UNK_ID, vocab.special_unk_id },
{ LLM_KV_TOKENIZER_SEP_ID, vocab.special_sep_id },
{ LLM_KV_TOKENIZER_PAD_ID, vocab.special_pad_id },
};
for (const auto & it : special_token_types) {
const std::string & key = kv(std::get<0>(it));
int32_t & id = std::get<1>(it), old_id = id;
GGUF_GET_KEY(ctx, id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, key);
// Must be >= -1 and < vocab size. Since the key is unsigned, -1
// can only come from the default value, so there's no point in
// validating that.
if (size_t(id + 1) > vocab.id_to_token.size()) {
LLAMA_LOG_WARN("%s: bad special token: '%s' = %d, using default id %d\n",
__func__, key.c_str(), id, old_id);
id = old_id;
}
}
}
// build special tokens cache // build special tokens cache
{ {
@ -6616,11 +6637,10 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
} }
static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) { static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) {
static const char * hex = "0123456789ABCDEF";
switch (llama_vocab_get_type(vocab)) { switch (llama_vocab_get_type(vocab)) {
case LLAMA_VOCAB_TYPE_SPM: { case LLAMA_VOCAB_TYPE_SPM: {
char buf[7]; const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
int result = snprintf(buf, sizeof(buf), "<0x%02X>", ch);
GGML_ASSERT(0 <= result && result < 7);
return vocab.token_to_id.at(buf); return vocab.token_to_id.at(buf);
} }
case LLAMA_VOCAB_TYPE_BPE: { case LLAMA_VOCAB_TYPE_BPE: {
@ -7993,7 +8013,7 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
for (size_t i = 0; i < candidates->size; ++i) { for (size_t i = 0; i < candidates->size; ++i) {
const llama_token id = candidates->data[i].id; const llama_token id = candidates->data[i].id;
const std::string piece = llama_token_to_str(ctx, id); const std::string piece = llama_token_to_piece(ctx, id);
if (id == eos) { if (id == eos) {
if (!allow_eos) { if (!allow_eos) {
candidates->data[i].logit = -INFINITY; candidates->data[i].logit = -INFINITY;
@ -8205,7 +8225,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
GGML_ASSERT(false); GGML_ASSERT(false);
} }
const std::string piece = llama_token_to_str(ctx, token); const std::string piece = llama_token_to_piece(ctx, token);
// Note terminating 0 in decoded string // Note terminating 0 in decoded string
const auto decoded = decode_utf8(piece.c_str(), grammar->partial_utf8); const auto decoded = decode_utf8(piece.c_str(), grammar->partial_utf8);

BIN
models/ggml-vocab-mpt.gguf Normal file

Binary file not shown.

View file

@ -31,6 +31,7 @@ llama_test_executable (test-tokenizer-1-llama test-tokenizer-1-llama.cpp ${CMAKE
llama_build_executable(test-tokenizer-1-bpe.cpp) llama_build_executable(test-tokenizer-1-bpe.cpp)
llama_test_executable (test-tokenizer-1-falcon test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf) llama_test_executable (test-tokenizer-1-falcon test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
llama_test_executable(test-tokenizer-1-aquila test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.gguf) llama_test_executable(test-tokenizer-1-aquila test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.gguf)
llama_test_executable(test-tokenizer-1-mpt test-tokenizer-1-bpe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-mpt.gguf)
llama_build_and_test_executable(test-grammar-parser.cpp) llama_build_and_test_executable(test-grammar-parser.cpp)
llama_build_and_test_executable(test-llama-grammar.cpp) llama_build_and_test_executable(test-llama-grammar.cpp)
llama_build_and_test_executable(test-grad0.cpp) # SLOW llama_build_and_test_executable(test-grad0.cpp) # SLOW