Merge branch 'master' into HEAD

This commit is contained in:
Georgi Gerganov 2023-07-25 14:21:14 +03:00
commit ea02f675f9
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
27 changed files with 2855 additions and 1140 deletions

View file

@ -323,6 +323,9 @@ llama.o: llama.cpp ggml.h ggml-cuda.h ggml-metal.h llama.h llama-util.h
common.o: examples/common.cpp examples/common.h common.o: examples/common.cpp examples/common.h
$(CXX) $(CXXFLAGS) -c $< -o $@ $(CXX) $(CXXFLAGS) -c $< -o $@
grammar-parser.o: examples/grammar-parser.cpp examples/grammar-parser.h
$(CXX) $(CXXFLAGS) -c $< -o $@
libllama.so: llama.o ggml.o $(OBJS) libllama.so: llama.o ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS) $(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS)
@ -333,7 +336,7 @@ clean:
# Examples # Examples
# #
main: examples/main/main.cpp build-info.h ggml.o llama.o common.o $(OBJS) main: examples/main/main.cpp build-info.h ggml.o llama.o common.o grammar-parser.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
@echo @echo
@echo '==== Run ./main -h for help. ====' @echo '==== Run ./main -h for help. ===='
@ -357,7 +360,7 @@ embedding: examples/embedding/embedding.cpp build-info.h ggml.
save-load-state: examples/save-load-state/save-load-state.cpp build-info.h ggml.o llama.o common.o $(OBJS) save-load-state: examples/save-load-state/save-load-state.cpp build-info.h ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp build-info.h ggml.o llama.o common.o $(OBJS) server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp build-info.h ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) $(LWINSOCK2) $(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) $(LWINSOCK2)
$(LIB_PRE)embdinput$(DSO_EXT): examples/embd-input/embd-input.h examples/embd-input/embd-input-lib.cpp build-info.h ggml.o llama.o common.o $(OBJS) $(LIB_PRE)embdinput$(DSO_EXT): examples/embd-input/embd-input.h examples/embd-input/embd-input-lib.cpp build-info.h ggml.o llama.o common.o $(OBJS)

View file

@ -13,6 +13,8 @@ set(TARGET common)
add_library(${TARGET} OBJECT add_library(${TARGET} OBJECT
common.h common.h
common.cpp common.cpp
grammar-parser.h
grammar-parser.cpp
) )
if (BUILD_SHARED_LIBS) if (BUILD_SHARED_LIBS)

View file

@ -8,6 +8,8 @@
#pragma warning(disable: 4244 4267) // possible loss of data #pragma warning(disable: 4244 4267) // possible loss of data
#endif #endif
static const float rms_norm_eps = 1e-6f;
float frand() { float frand() {
return (float)rand()/(float)RAND_MAX; return (float)rand()/(float)RAND_MAX;
} }
@ -562,7 +564,7 @@ struct ggml_tensor * forward(
// norm // norm
{ {
// cur shape [n_embd,N,1,1] // cur shape [n_embd,N,1,1]
cur = ggml_rms_norm(ctx0, inpL); cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
// cur = attention_norm*cur // cur = attention_norm*cur
cur = ggml_mul(ctx0, cur = ggml_mul(ctx0,
@ -685,7 +687,7 @@ struct ggml_tensor * forward(
// norm // norm
{ {
// cur shape [n_embd,N,1,1] // cur shape [n_embd,N,1,1]
cur = ggml_rms_norm(ctx0, inpFF); cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
// cur = ffn_norm*cur // cur = ffn_norm*cur
// cur shape [n_embd,N,1,1] // cur shape [n_embd,N,1,1]
@ -729,7 +731,7 @@ struct ggml_tensor * forward(
{ {
// inpL shape [n_embd,N,1,1] // inpL shape [n_embd,N,1,1]
inpL = ggml_rms_norm(ctx0, inpL); inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
// inpL = norm*inpL // inpL = norm*inpL
// inpL shape [n_embd,N,1,1] // inpL shape [n_embd,N,1,1]
@ -817,7 +819,7 @@ struct ggml_tensor * forward_batch(
// norm // norm
{ {
// cur shape [n_embd,N*n_batch,1,1] // cur shape [n_embd,N*n_batch,1,1]
cur = ggml_rms_norm(ctx0, inpL); cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch); assert_shape_2d(cur, n_embd, N*n_batch);
// cur = attention_norm*cur // cur = attention_norm*cur
@ -981,7 +983,7 @@ struct ggml_tensor * forward_batch(
// norm // norm
{ {
// cur shape [n_embd,N*n_batch,1,1] // cur shape [n_embd,N*n_batch,1,1]
cur = ggml_rms_norm(ctx0, inpFF); cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch); assert_shape_2d(cur, n_embd, N*n_batch);
// cur = ffn_norm*cur // cur = ffn_norm*cur
@ -1034,7 +1036,7 @@ struct ggml_tensor * forward_batch(
{ {
// inpL shape [n_embd,N*n_batch,1,1] // inpL shape [n_embd,N*n_batch,1,1]
inpL = ggml_rms_norm(ctx0, inpL); inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(inpL, n_embd, N*n_batch); assert_shape_2d(inpL, n_embd, N*n_batch);
// inpL = norm*inpL // inpL = norm*inpL
@ -1104,7 +1106,7 @@ struct ggml_tensor * forward_lora(
// norm // norm
{ {
// cur shape [n_embd,N,1,1] // cur shape [n_embd,N,1,1]
cur = ggml_rms_norm(ctx0, inpL); cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
// cur = attention_norm*cur // cur = attention_norm*cur
cur = ggml_mul(ctx0, cur = ggml_mul(ctx0,
@ -1251,7 +1253,7 @@ struct ggml_tensor * forward_lora(
// norm // norm
{ {
// cur shape [n_embd,N,1,1] // cur shape [n_embd,N,1,1]
cur = ggml_rms_norm(ctx0, inpFF); cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
// cur = ffn_norm*cur // cur = ffn_norm*cur
// cur shape [n_embd,N,1,1] // cur shape [n_embd,N,1,1]
@ -1295,7 +1297,7 @@ struct ggml_tensor * forward_lora(
{ {
// inpL shape [n_embd,N,1,1] // inpL shape [n_embd,N,1,1]
inpL = ggml_rms_norm(ctx0, inpL); inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
// inpL = norm*inpL // inpL = norm*inpL
// inpL shape [n_embd,N,1,1] // inpL shape [n_embd,N,1,1]

View file

@ -177,6 +177,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break; break;
} }
params.n_gqa = std::stoi(argv[i]); params.n_gqa = std::stoi(argv[i]);
} else if (arg == "-eps" || arg == "--rms-norm-eps") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.rms_norm_eps = std::stof(argv[i]);
} else if (arg == "--rope-freq-base") { } else if (arg == "--rope-freq-base") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
@ -438,6 +444,28 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break; break;
} }
params.input_suffix = argv[i]; params.input_suffix = argv[i];
} else if (arg == "--grammar") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.grammar = argv[i];
} else if (arg == "--grammar-file") {
if (++i >= argc) {
invalid_param = true;
break;
}
std::ifstream file(argv[i]);
if (!file) {
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
invalid_param = true;
break;
}
std::copy(
std::istreambuf_iterator<char>(file),
std::istreambuf_iterator<char>(),
std::back_inserter(params.grammar)
);
} else { } else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
gpt_print_usage(argc, argv, default_params); gpt_print_usage(argc, argv, default_params);
@ -497,6 +525,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx); fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
fprintf(stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); fprintf(stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
fprintf(stdout, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa); fprintf(stdout, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
fprintf(stdout, " -eps N, --rms-norm-eps N rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps);
fprintf(stdout, " --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k); fprintf(stdout, " --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
fprintf(stdout, " --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p); fprintf(stdout, " --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
fprintf(stdout, " --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z); fprintf(stdout, " --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
@ -514,6 +543,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stdout, " modifies the likelihood of token appearing in the completion,\n"); fprintf(stdout, " modifies the likelihood of token appearing in the completion,\n");
fprintf(stdout, " i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"); fprintf(stdout, " i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n");
fprintf(stdout, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n"); fprintf(stdout, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
fprintf(stdout, " --grammar GRAMMAR BNF-like grammar to constrain generations (see samples in grammars/ dir)\n");
fprintf(stdout, " --grammar-file FNAME file to read grammar from\n");
fprintf(stdout, " --cfg-negative-prompt PROMPT \n"); fprintf(stdout, " --cfg-negative-prompt PROMPT \n");
fprintf(stdout, " negative prompt to use for guidance. (default: empty)\n"); fprintf(stdout, " negative prompt to use for guidance. (default: empty)\n");
fprintf(stdout, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale); fprintf(stdout, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
@ -591,6 +622,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
lparams.n_ctx = params.n_ctx; lparams.n_ctx = params.n_ctx;
lparams.n_batch = params.n_batch; lparams.n_batch = params.n_batch;
lparams.n_gqa = params.n_gqa; lparams.n_gqa = params.n_gqa;
lparams.rms_norm_eps = params.rms_norm_eps;
lparams.n_gpu_layers = params.n_gpu_layers; lparams.n_gpu_layers = params.n_gpu_layers;
lparams.main_gpu = params.main_gpu; lparams.main_gpu = params.main_gpu;
lparams.tensor_split = params.tensor_split; lparams.tensor_split = params.tensor_split;

View file

@ -34,6 +34,7 @@ struct gpt_params {
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
float rms_norm_eps = 1e-6; // rms norm epsilon
float rope_freq_base = 10000.0f; // RoPE base frequency float rope_freq_base = 10000.0f; // RoPE base frequency
float rope_freq_scale = 1.0f; // RoPE frequency scaling factor float rope_freq_scale = 1.0f; // RoPE frequency scaling factor
@ -63,6 +64,7 @@ struct gpt_params {
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
std::string input_prefix = ""; // string to prefix user inputs with std::string input_prefix = ""; // string to prefix user inputs with
std::string input_suffix = ""; // string to suffix user inputs with std::string input_suffix = ""; // string to suffix user inputs with
std::string grammar = ""; // optional BNF-like grammar to constrain sampling
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
std::string lora_adapter = ""; // lora adapter path std::string lora_adapter = ""; // lora adapter path

423
examples/grammar-parser.cpp Normal file
View file

@ -0,0 +1,423 @@
#include "grammar-parser.h"
#include <cstdint>
#include <cwchar>
#include <string>
#include <utility>
#include <stdexcept>
#include <exception>
namespace grammar_parser {
// Decodes a single UTF-8 sequence beginning at `src`.
//
// The input is assumed to be well-formed UTF-8; the only defensive measure is
// that decoding stops at a NUL byte, so a truncated multi-byte sequence never
// reads past the end of the string.
//
// Returns the decoded code point and a pointer just past the consumed bytes.
std::pair<uint32_t, const char *> decode_utf8(const char * src) {
    // sequence length, indexed by the upper four bits of the lead byte
    static const int seq_len_by_high_nibble[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };

    const uint8_t lead      = static_cast<uint8_t>(*src);
    const int     n_bytes   = seq_len_by_high_nibble[lead >> 4];
    const uint8_t lead_mask = (1 << (8 - n_bytes)) - 1;

    uint32_t     code_point = lead & lead_mask;
    const char * cursor     = src + 1;

    // fold in up to n_bytes - 1 continuation bytes, 6 payload bits each
    int remaining = n_bytes - 1;
    while (remaining > 0 && *cursor) {
        code_point = (code_point << 6) + (static_cast<uint8_t>(*cursor) & 0x3F);
        cursor++;
        remaining--;
    }
    return std::make_pair(code_point, cursor);
}
// Looks up the id of the symbol named by the `len` bytes at `src`.
// A name seen for the first time is registered with the next free id
// (the current number of known symbols).
uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) {
    const std::string name(src, len);

    const auto known = state.symbol_ids.find(name);
    if (known != state.symbol_ids.end()) {
        return known->second;
    }

    const uint32_t fresh_id = static_cast<uint32_t>(state.symbol_ids.size());
    state.symbol_ids.insert(std::make_pair(name, fresh_id));
    return fresh_id;
}
// Allocates an id for a synthesized (parser-generated) rule and registers it
// under the name "<base_name>_<id>". Returns the new id.
uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) {
    const uint32_t fresh_id = static_cast<uint32_t>(state.symbol_ids.size());

    std::string synthesized_name = base_name;
    synthesized_name += '_';
    synthesized_name += std::to_string(fresh_id);

    state.symbol_ids[synthesized_name] = fresh_id;
    return fresh_id;
}
// Stores `rule` in the rule table at index `rule_id`, growing the table when
// the index is not yet present. An existing entry at that index is replaced.
void add_rule(
        parse_state & state,
        uint32_t      rule_id,
        const std::vector<llama_grammar_element> & rule) {
    std::vector<std::vector<llama_grammar_element>> & table = state.rules;

    const bool slot_exists = rule_id < table.size();
    if (!slot_exists) {
        table.resize(rule_id + 1);
    }
    table[rule_id] = rule;
}
// True for characters allowed in a rule name: ASCII letters, digits and '-'.
bool is_word_char(char c) {
    if (c == '-') {
        return true;
    }
    if ('0' <= c && c <= '9') {
        return true;
    }
    if ('a' <= c && c <= 'z') {
        return true;
    }
    return 'A' <= c && c <= 'Z';
}
// Reads exactly `size` hexadecimal digits (either case) starting at `src`.
//
// Returns the parsed value and a pointer just past the last digit.
// Throws std::runtime_error when fewer than `size` hex digits are available
// (non-hex character or end of string reached first).
std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
    uint32_t accum    = 0;
    int      consumed = 0;

    while (consumed < size && src[consumed]) {
        const char c = src[consumed];

        int digit;
        if ('0' <= c && c <= '9') {
            digit = c - '0';
        } else if ('a' <= c && c <= 'f') {
            digit = c - 'a' + 10;
        } else if ('A' <= c && c <= 'F') {
            digit = c - 'A' + 10;
        } else {
            break; // not a hex digit: reported below as a short read
        }

        accum = (accum << 4) + digit;
        consumed++;
    }

    if (consumed != size) {
        throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
    }
    return std::make_pair(accum, src + consumed);
}
// Skips whitespace and '#' comments starting at `src` and returns a pointer to
// the first character that is neither.
//
// Spaces and tabs are always skipped. A '#' comment runs up to (not including)
// the next CR/LF. Line breaks themselves are only skipped when `newline_ok`
// is true.
const char * parse_space(const char * src, bool newline_ok) {
    const char * cursor = src;
    for (;;) {
        const char c = *cursor;
        if (c == '#') {
            // comment: advance to the end of the line (or end of input)
            while (*cursor && *cursor != '\r' && *cursor != '\n') {
                cursor++;
            }
        } else if (c == ' ' || c == '\t' || (newline_ok && (c == '\r' || c == '\n'))) {
            cursor++;
        } else {
            return cursor;
        }
    }
}
// Scans a rule name (one or more word characters) starting at `src` and
// returns a pointer just past it.
// Throws std::runtime_error if `src` does not start with a word character.
const char * parse_name(const char * src) {
    const char * name_end = src;
    for (; is_word_char(*name_end); ++name_end) {
        // advance over the name
    }
    if (name_end != src) {
        return name_end;
    }
    throw std::runtime_error(std::string("expecting name at ") + src);
}
// Parses one (possibly escaped) character of a literal or character class.
//
// Supported escapes: \xHH, \uHHHH, \UHHHHHHHH, \t, \r, \n, \\, \", \[ and \].
// Anything else is decoded as a UTF-8 sequence.
//
// Returns the code point and a pointer just past the consumed input.
// Throws std::runtime_error on an unknown escape or at end of input.
std::pair<uint32_t, const char *> parse_char(const char * src) {
    if (!*src) {
        throw std::runtime_error("unexpected end of input");
    }
    if (*src != '\\') {
        return decode_utf8(src);
    }

    const char   kind = src[1];
    const char * rest = src + 2;

    // numeric escapes
    if (kind == 'x') { return parse_hex(rest, 2); }
    if (kind == 'u') { return parse_hex(rest, 4); }
    if (kind == 'U') { return parse_hex(rest, 8); }

    // control-character escapes
    if (kind == 't') { return std::make_pair('\t', rest); }
    if (kind == 'r') { return std::make_pair('\r', rest); }
    if (kind == 'n') { return std::make_pair('\n', rest); }

    // escapes that stand for the escaped character itself
    if (kind == '\\' || kind == '"' || kind == '[' || kind == ']') {
        return std::make_pair(kind, rest);
    }

    throw std::runtime_error(std::string("unknown escape at ") + src);
}
// Forward declaration: parse_sequence() calls parse_alternates() for a
// parenthesized group, and parse_alternates() in turn calls parse_sequence().
const char * parse_alternates(
parse_state & state,
const char * src,
const std::string & rule_name,
uint32_t rule_id,
bool is_nested);
// Parses one sequence of grammar items (everything up to a '|', ')', line end
// or other character that does not start an item) and appends the resulting
// elements to `out_elements`.
//
//   state        - parser state; receives symbol ids and synthesized sub-rules
//   src          - start of the sequence text
//   rule_name    - name of the enclosing rule, used as the base name for
//                  synthesized rules
//   out_elements - destination for the parsed elements
//   is_nested    - forwarded to parse_space() as `newline_ok` after each item,
//                  i.e. line breaks are skipped only when this is true
//
// Returns a pointer to the first character that is not part of the sequence.
// Throws std::runtime_error on malformed input (errors from parse_char(),
// a missing ')' or a repetition operator with nothing before it).
const char * parse_sequence(
parse_state & state,
const char * src,
const std::string & rule_name,
std::vector<llama_grammar_element> & out_elements,
bool is_nested) {
// index in out_elements where the most recent item begins; a following
// repetition operator applies to everything from here to the end
size_t last_sym_start = out_elements.size();
const char * pos = src;
while (*pos) {
if (*pos == '"') { // literal string
pos++;
// the whole literal counts as one item for a following */+/?
last_sym_start = out_elements.size();
// parse_char() throws at end of input, so an unterminated literal is an error
while (*pos != '"') {
auto char_pair = parse_char(pos);
pos = char_pair.second;
out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
}
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '[') { // char range(s)
pos++;
enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
// a leading '^' makes the first element CHAR_NOT instead of CHAR
if (*pos == '^') {
pos++;
start_type = LLAMA_GRETYPE_CHAR_NOT;
}
last_sym_start = out_elements.size();
while (*pos != ']') {
auto char_pair = parse_char(pos);
pos = char_pair.second;
// first element of the class carries start_type, later ones are CHAR_ALT
enum llama_gretype type = last_sym_start < out_elements.size()
? LLAMA_GRETYPE_CHAR_ALT
: start_type;
out_elements.push_back({type, char_pair.first});
// "a-z": a '-' that is not directly before ']' introduces the range's upper bound
if (pos[0] == '-' && pos[1] != ']') {
auto endchar_pair = parse_char(pos + 1);
pos = endchar_pair.second;
out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
}
}
pos = parse_space(pos + 1, is_nested);
} else if (is_word_char(*pos)) { // rule reference
const char * name_end = parse_name(pos);
// registers the name if it has not been seen yet
uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos);
pos = parse_space(name_end, is_nested);
last_sym_start = out_elements.size();
out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
} else if (*pos == '(') { // grouping
// parse nested alternates into synthesized rule
pos = parse_space(pos + 1, true);
uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
pos = parse_alternates(state, pos, rule_name, sub_rule_id, true);
last_sym_start = out_elements.size();
// output reference to synthesized rule
out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
if (*pos != ')') {
throw std::runtime_error(std::string("expecting ')' at ") + pos);
}
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator
// nothing has been emitted since last_sym_start: there is no item to repeat
if (last_sym_start == out_elements.size()) {
throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos);
}
// apply transformation to previous symbol (last_sym_start to end) according to
// rewrite rules:
// S* --> S' ::= S S' |
// S+ --> S' ::= S S' | S
// S? --> S' ::= S |
uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
std::vector<llama_grammar_element> sub_rule;
// add preceding symbol to generated rule
sub_rule.insert(
sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
if (*pos == '*' || *pos == '+') {
// cause generated rule to recurse
sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
}
// mark start of alternate def
sub_rule.push_back({LLAMA_GRETYPE_ALT, 0});
if (*pos == '+') {
// add preceding symbol as alternate only for '+' (otherwise empty)
sub_rule.insert(
sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
}
sub_rule.push_back({LLAMA_GRETYPE_END, 0});
add_rule(state, sub_rule_id, sub_rule);
// in original rule, replace previous symbol with reference to generated rule
out_elements.resize(last_sym_start);
out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
pos = parse_space(pos + 1, is_nested);
} else {
// not the start of any item: the sequence ends here, the caller decides what follows
break;
}
}
return pos;
}
// Parses one or more '|'-separated sequences and stores them as the rule with
// id `rule_id`. Alternatives are separated by an ALT element and the rule is
// terminated by an END element.
//
// Returns a pointer to the first character after the last alternative.
const char * parse_alternates(
        parse_state &       state,
        const char *        src,
        const std::string & rule_name,
        uint32_t            rule_id,
        bool                is_nested) {
    std::vector<llama_grammar_element> elements;

    const char * cursor = parse_sequence(state, src, rule_name, elements, is_nested);
    while (*cursor == '|') {
        elements.push_back({LLAMA_GRETYPE_ALT, 0});
        // line breaks are always permitted directly after a '|'
        const char * next_seq = parse_space(cursor + 1, true);
        cursor = parse_sequence(state, next_seq, rule_name, elements, is_nested);
    }

    elements.push_back({LLAMA_GRETYPE_END, 0});
    add_rule(state, rule_id, elements);
    return cursor;
}
// Parses a single rule definition of the form `name ::= alternates`, followed
// by a line break or end of input, and records it in `state`.
//
// Returns a pointer to the start of the next rule (trailing whitespace,
// comments and blank lines are skipped).
// Throws std::runtime_error if "::=" is missing or the rule is followed by
// anything other than a line break or end of input.
const char * parse_rule(parse_state & state, const char * src) {
    const char * name_end = parse_name(src);
    const size_t name_len = name_end - src;
    const std::string name(src, name_len);

    const char *   pos     = parse_space(name_end, false);
    const uint32_t rule_id = get_symbol_id(state, src, name_len);

    const bool has_assign = pos[0] == ':' && pos[1] == ':' && pos[2] == '=';
    if (!has_assign) {
        throw std::runtime_error(std::string("expecting ::= at ") + pos);
    }

    const char * body = parse_space(pos + 3, true);
    pos = parse_alternates(state, body, name, rule_id, false);

    // consume exactly one line terminator (CR, LF or CRLF) if present
    switch (*pos) {
        case '\0':
            break;
        case '\n':
            pos += 1;
            break;
        case '\r':
            pos += (pos[1] == '\n') ? 2 : 1;
            break;
        default:
            throw std::runtime_error(std::string("expecting newline or end at ") + pos);
    }
    return parse_space(pos, true);
}
// Parses a complete grammar from the NUL-terminated text `src`.
//
// Returns the populated parse_state on success. On any error the message is
// written to stderr and a default-constructed (empty) parse_state is returned,
// so callers can detect failure by checking whether `rules` is empty.
//
// After all rules are read, every rule reference is checked against the rule
// table: a name that is referenced but never defined would otherwise leave an
// empty (or missing) entry in `rules`, and c_rules() would hand out a pointer
// to a rule with no END terminator or an index beyond the table.
parse_state parse(const char * src) {
    try {
        parse_state state;
        const char * pos = parse_space(src, true);
        while (*pos) {
            pos = parse_rule(state, pos);
        }

        // validate that every referenced rule has actually been defined
        for (const auto & rule : state.rules) {
            for (const auto & elem : rule) {
                if (elem.type != LLAMA_GRETYPE_RULE_REF) {
                    continue;
                }
                if (elem.value < state.rules.size() && !state.rules[elem.value].empty()) {
                    continue;
                }
                // recover the symbol name for the error message; fall back to the id
                std::string missing = std::to_string(elem.value);
                for (const auto & kv : state.symbol_ids) {
                    if (kv.second == elem.value) {
                        missing = kv.first;
                        break;
                    }
                }
                throw std::runtime_error("undefined rule identifier '" + missing + "'");
            }
        }
        return state;
    } catch (const std::exception & err) {
        fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
        return parse_state();
    }
}
// Writes the code point `c` to `file` in a human-readable form.
//
// Printable ASCII (0x20..0x7E) is written as the character itself; everything
// else, including control characters and DEL (0x7F), is written as <U+XXXX>
// rather than being UTF-8 encoded.
void print_grammar_char(FILE * file, uint32_t c) {
    // 0x7F (DEL) is a control character, so the printable range ends at 0x7E
    if (0x20 <= c && c <= 0x7e) {
        fprintf(file, "%c", static_cast<char>(c));
    } else {
        // cop out of encoding UTF-8
        fprintf(file, "<U+%04X>", static_cast<unsigned int>(c));
    }
}
// True when `elem` is one of the character-matching element kinds
// (CHAR, CHAR_NOT, CHAR_ALT or CHAR_RNG_UPPER).
bool is_char_element(llama_grammar_element elem) {
    return elem.type == LLAMA_GRETYPE_CHAR
        || elem.type == LLAMA_GRETYPE_CHAR_NOT
        || elem.type == LLAMA_GRETYPE_CHAR_ALT
        || elem.type == LLAMA_GRETYPE_CHAR_RNG_UPPER;
}
void print_rule_binary(FILE * file, const std::vector<llama_grammar_element> & rule) {
for (auto elem : rule) {
switch (elem.type) {
case LLAMA_GRETYPE_END: fprintf(file, "END"); break;
case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break;
case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break;
case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break;
case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break;
case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
}
switch (elem.type) {
case LLAMA_GRETYPE_END:
case LLAMA_GRETYPE_ALT:
case LLAMA_GRETYPE_RULE_REF:
fprintf(file, "(%u) ", elem.value);
break;
case LLAMA_GRETYPE_CHAR:
case LLAMA_GRETYPE_CHAR_NOT:
case LLAMA_GRETYPE_CHAR_RNG_UPPER:
case LLAMA_GRETYPE_CHAR_ALT:
fprintf(file, "(\"");
print_grammar_char(file, elem.value);
fprintf(file, "\") ");
break;
}
}
fprintf(file, "\n");
}
// Prints one rule to `file` as `name ::= ...` followed by a newline.
//
//   file            - output stream
//   rule_id         - id of the rule; used to look up its name and in error messages
//   rule            - element list of the rule; must end with LLAMA_GRETYPE_END
//   symbol_id_names - id -> name map used for the rule itself and for references
//
// Character elements are always rendered in bracket form ("[a]", "[^a]",
// "[a-z]", "[abc]"), including characters that came from a quoted literal.
//
// Throws std::runtime_error for a malformed element list. The .at() lookups on
// `symbol_id_names` throw std::out_of_range for an id that has no name.
void print_rule(
FILE * file,
uint32_t rule_id,
const std::vector<llama_grammar_element> & rule,
const std::map<uint32_t, std::string> & symbol_id_names) {
if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
throw std::runtime_error(
"malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
}
fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
// iterate over everything except the final END, so rule[i + 1] below is always valid
for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
llama_grammar_element elem = rule[i];
switch (elem.type) {
case LLAMA_GRETYPE_END:
// an END anywhere but the last position is an error
throw std::runtime_error(
"unexpected end of rule: " + std::to_string(rule_id) + "," +
std::to_string(i));
case LLAMA_GRETYPE_ALT:
fprintf(file, "| ");
break;
case LLAMA_GRETYPE_RULE_REF:
fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
break;
case LLAMA_GRETYPE_CHAR:
// opens a bracket; it is closed after the last element of the class (see below)
fprintf(file, "[");
print_grammar_char(file, elem.value);
break;
case LLAMA_GRETYPE_CHAR_NOT:
fprintf(file, "[^");
print_grammar_char(file, elem.value);
break;
case LLAMA_GRETYPE_CHAR_RNG_UPPER:
if (i == 0 || !is_char_element(rule[i - 1])) {
throw std::runtime_error(
"LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
std::to_string(rule_id) + "," + std::to_string(i));
}
fprintf(file, "-");
print_grammar_char(file, elem.value);
break;
case LLAMA_GRETYPE_CHAR_ALT:
if (i == 0 || !is_char_element(rule[i - 1])) {
throw std::runtime_error(
"LLAMA_GRETYPE_CHAR_ALT without preceding char: " +
std::to_string(rule_id) + "," + std::to_string(i));
}
print_grammar_char(file, elem.value);
break;
}
// close the bracket unless the next element continues the same character class
if (is_char_element(elem)) {
switch (rule[i + 1].type) {
case LLAMA_GRETYPE_CHAR_ALT:
case LLAMA_GRETYPE_CHAR_RNG_UPPER:
break;
default:
fprintf(file, "] ");
}
}
}
fprintf(file, "\n");
}
// Prints every rule of `state` to `file`, one rule per line, in rule-id order.
// If printing a rule fails, the error is reported on stderr and printing stops.
void print_grammar(FILE * file, const parse_state & state) {
    try {
        // invert the name -> id table so rules can be printed by name
        std::map<uint32_t, std::string> names_by_id;
        for (auto it = state.symbol_ids.begin(); it != state.symbol_ids.end(); ++it) {
            names_by_id[it->second] = it->first;
        }

        // (print_rule_binary() can be called here instead to dump the raw elements)
        size_t rule_idx = 0;
        for (const auto & rule : state.rules) {
            print_rule(file, rule_idx, rule, names_by_id);
            rule_idx++;
        }
    } catch (const std::exception & err) {
        fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
    }
}
std::vector<const llama_grammar_element *> parse_state::c_rules() {
std::vector<const llama_grammar_element *> ret;
for (const auto & rule : rules) {
ret.push_back(rule.data());
}
return ret;
}
}

29
examples/grammar-parser.h Normal file
View file

@ -0,0 +1,29 @@
// Implements a parser for an extended Backus-Naur form (BNF), producing the
// binary context-free grammar format specified by llama.h. Supports character
// ranges, grouping, and repetition operators. As an example, a grammar for
// arithmetic might look like:
//
// root ::= expr
// expr ::= term ([-+*/] term)*
// term ::= num | "(" space expr ")" space
// num ::= [0-9]+ space
// space ::= [ \t\n]*
#pragma once
#include "llama.h"
#include <vector>
#include <map>
#include <cstdint>
#include <string>
namespace grammar_parser {
// Result of parsing a grammar.
struct parse_state {
// rule name -> rule id; also contains the names of parser-generated rules
std::map<std::string, uint32_t> symbol_ids;
// element list of each rule, indexed by rule id
std::vector<std::vector<llama_grammar_element>> rules;
// Returns one pointer per rule (rule-id order) to that rule's first element.
// The pointers refer to storage owned by `rules`.
std::vector<const llama_grammar_element *> c_rules();
};
// Parses the NUL-terminated grammar text `src`. On a parse error the message is
// printed to stderr and a default-constructed (empty) parse_state is returned.
parse_state parse(const char * src);
// Prints the rules of `state` to `file`, one per line; problems encountered
// while printing are reported on stderr.
void print_grammar(FILE * file, const parse_state & state);
}

View file

@ -6,6 +6,7 @@
#include "common.h" #include "common.h"
#include "llama.h" #include "llama.h"
#include "build-info.h" #include "build-info.h"
#include "grammar-parser.h"
#include <cassert> #include <cassert>
#include <cinttypes> #include <cinttypes>
@ -337,6 +338,31 @@ int main(int argc, char ** argv) {
fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
fprintf(stderr, "\n\n"); fprintf(stderr, "\n\n");
grammar_parser::parse_state parsed_grammar;
llama_grammar * grammar = NULL;
if (!params.grammar.empty()) {
parsed_grammar = grammar_parser::parse(params.grammar.c_str());
// will be empty (default) if there are parse errors
if (parsed_grammar.rules.empty()) {
return 1;
}
fprintf(stderr, "%s: grammar:\n", __func__);
grammar_parser::print_grammar(stderr, parsed_grammar);
fprintf(stderr, "\n");
{
auto it = params.logit_bias.find(llama_token_eos());
if (it != params.logit_bias.end() && it->second == -INFINITY) {
fprintf(stderr,
"%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
}
}
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
}
// TODO: replace with ring-buffer // TODO: replace with ring-buffer
std::vector<llama_token> last_n_tokens(n_ctx); std::vector<llama_token> last_n_tokens(n_ctx);
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
@ -570,6 +596,10 @@ int main(int argc, char ** argv) {
logits[llama_token_nl()] = nl_logit; logits[llama_token_nl()] = nl_logit;
} }
if (grammar != NULL) {
llama_sample_grammar(ctx, &candidates_p, grammar);
}
if (temp <= 0) { if (temp <= 0) {
// Greedy sampling // Greedy sampling
id = llama_sample_token_greedy(ctx, &candidates_p); id = llama_sample_token_greedy(ctx, &candidates_p);
@ -595,6 +625,10 @@ int main(int argc, char ** argv) {
} }
// printf("`%d`", candidates_p.size); // printf("`%d`", candidates_p.size);
if (grammar != NULL) {
llama_grammar_accept_token(ctx, grammar, id);
}
last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id); last_n_tokens.push_back(id);
} }
@ -725,6 +759,18 @@ int main(int argc, char ** argv) {
} }
if (n_past > 0) { if (n_past > 0) {
if (is_interacting) {
// reset grammar state if we're restarting generation
if (grammar != NULL) {
llama_grammar_free(grammar);
std::vector<const llama_grammar_element *> grammar_rules(
parsed_grammar.c_rules());
grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(),
parsed_grammar.symbol_ids.at("root"));
}
}
is_interacting = false; is_interacting = false;
} }
} }
@ -756,6 +802,9 @@ int main(int argc, char ** argv) {
llama_free(ctx); llama_free(ctx);
llama_free_model(model); llama_free_model(model);
if (grammar != NULL) {
llama_grammar_free(grammar);
}
llama_backend_free(); llama_backend_free();
return 0; return 0;

File diff suppressed because it is too large Load diff

View file

@ -73,6 +73,37 @@
margin: 0; margin: 0;
} }
fieldset.two {
display: grid;
grid-template: "a a";
gap: 1em;
}
fieldset.three {
display: grid;
grid-template: "a a a";
gap: 1em;
}
details {
border: 1px solid #aaa;
border-radius: 4px;
padding: 0.5em 0.5em 0;
margin-top: 0.5em;
}
summary {
font-weight: bold;
margin: -0.5em -0.5em 0;
padding: 0.5em;
cursor: pointer;
}
details[open] {
padding: 0.5em;
}
textarea { textarea {
padding: 5px; padding: 5px;
flex-grow: 1; flex-grow: 1;
@ -125,10 +156,17 @@
const params = signal({ const params = signal({
n_predict: 400, n_predict: 400,
temperature: 0.7, temperature: 0.7,
repeat_last_n: 256, repeat_last_n: 256, // 0 = disable penalty, -1 = context size
repeat_penalty: 1.18, repeat_penalty: 1.18, // 1.0 = disabled
top_k: 40, top_k: 40, // <= 0 to use vocab size
top_p: 0.5, top_p: 0.5, // 1.0 = disabled
tfs_z: 1.0, // 1.0 = disabled
typical_p: 1.0, // 1.0 = disabled
presence_penalty: 0.0, // 0.0 = disabled
frequency_penalty: 0.0, // 0.0 = disabled
mirostat: 0, // 0/1/2
mirostat_tau: 5, // target entropy
mirostat_eta: 0.1, // learning rate
}) })
const llamaStats = signal(null) const llamaStats = signal(null)
@ -264,6 +302,27 @@
const updateSession = (el) => session.value = { ...session.value, [el.target.name]: el.target.value } const updateSession = (el) => session.value = { ...session.value, [el.target.name]: el.target.value }
const updateParams = (el) => params.value = { ...params.value, [el.target.name]: el.target.value } const updateParams = (el) => params.value = { ...params.value, [el.target.name]: el.target.value }
const updateParamsFloat = (el) => params.value = { ...params.value, [el.target.name]: parseFloat(el.target.value) } const updateParamsFloat = (el) => params.value = { ...params.value, [el.target.name]: parseFloat(el.target.value) }
const updateParamsInt = (el) => params.value = { ...params.value, [el.target.name]: Math.floor(parseFloat(el.target.value)) }
const FloatField = ({label, max, min, name, step, value}) => {
return html`
<div>
<label for="${name}">${label}</label>
<input type="range" id="${name}" min="${min}" max="${max}" step="${step}" name="${name}" value="${value}" oninput=${updateParamsFloat} />
<span>${value}</span>
</div>
`
};
const IntField = ({label, max, min, name, value}) => {
return html`
<div>
<label for="${name}">${label}</label>
<input type="range" id="${name}" min="${min}" max="${max}" name="${name}" value="${value}" oninput=${updateParamsInt} />
<span>${value}</span>
</div>
`
};
return html` return html`
<form> <form>
@ -272,7 +331,9 @@
<label for="prompt">Prompt</label> <label for="prompt">Prompt</label>
<textarea type="text" name="prompt" value="${session.value.prompt}" rows=4 oninput=${updateSession}/> <textarea type="text" name="prompt" value="${session.value.prompt}" rows=4 oninput=${updateSession}/>
</div> </div>
</fieldset>
<fieldset class="two">
<div> <div>
<label for="user">User name</label> <label for="user">User name</label>
<input type="text" name="user" value="${session.value.user}" oninput=${updateSession} /> <input type="text" name="user" value="${session.value.user}" oninput=${updateSession} />
@ -282,7 +343,9 @@
<label for="bot">Bot name</label> <label for="bot">Bot name</label>
<input type="text" name="char" value="${session.value.char}" oninput=${updateSession} /> <input type="text" name="char" value="${session.value.char}" oninput=${updateSession} />
</div> </div>
</fieldset>
<fieldset>
<div> <div>
<label for="template">Prompt template</label> <label for="template">Prompt template</label>
<textarea id="template" name="template" value="${session.value.template}" rows=4 oninput=${updateSession}/> <textarea id="template" name="template" value="${session.value.template}" rows=4 oninput=${updateSession}/>
@ -292,38 +355,44 @@
<label for="template">Chat history template</label> <label for="template">Chat history template</label>
<textarea id="template" name="historyTemplate" value="${session.value.historyTemplate}" rows=1 oninput=${updateSession}/> <textarea id="template" name="historyTemplate" value="${session.value.historyTemplate}" rows=1 oninput=${updateSession}/>
</div> </div>
<div>
<label for="temperature">Temperature</label>
<input type="range" id="temperature" min="0.0" max="1.0" step="0.01" name="temperature" value="${params.value.temperature}" oninput=${updateParamsFloat} />
<span>${params.value.temperature}</span>
</div>
<div>
<label for="nPredict">Predictions</label>
<input type="range" id="nPredict" min="1" max="2048" step="1" name="n_predict" value="${params.value.n_predict}" oninput=${updateParamsFloat} />
<span>${params.value.n_predict}</span>
</div>
<div>
<label for="repeat_penalty">Penalize repeat sequence</label>
<input type="range" id="repeat_penalty" min="0.0" max="2.0" step="0.01" name="repeat_penalty" value="${params.value.repeat_penalty}" oninput=${updateParamsFloat} />
<span>${params.value.repeat_penalty}</span>
</div>
<div>
<label for="repeat_last_n">Consider N tokens for penalize</label>
<input type="range" id="repeat_last_n" min="0.0" max="2048" name="repeat_last_n" value="${params.value.repeat_last_n}" oninput=${updateParamsFloat} />
<span>${params.value.repeat_last_n}</span>
</div>
</fieldset> </fieldset>
<fieldset class="two">
${IntField({label: "Predictions", max: 2048, min: -1, name: "n_predict", value: params.value.n_predict})}
${FloatField({label: "Temperature", max: 1.5, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature})}
${FloatField({label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty})}
${IntField({label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n})}
${IntField({label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k})}
${FloatField({label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p})}
</fieldset>
<details>
<summary>More options</summary>
<fieldset class="two">
${FloatField({label: "TFS-Z", max: 1.0, min: 0.0, name: "tfs_z", step: 0.01, value: params.value.tfs_z})}
${FloatField({label: "Typical P", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p})}
${FloatField({label: "Presence penalty", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty})}
${FloatField({label: "Frequency penalty", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty})}
</fieldset>
<hr />
<fieldset class="three">
<div>
<label><input type="radio" name="mirostat" value="0" checked=${params.value.mirostat == 0} oninput=${updateParamsInt} /> no Mirostat</label>
<label><input type="radio" name="mirostat" value="1" checked=${params.value.mirostat == 1} oninput=${updateParamsInt} /> Mirostat v1</label>
<label><input type="radio" name="mirostat" value="2" checked=${params.value.mirostat == 2} oninput=${updateParamsInt} /> Mirostat v2</label>
</div>
${FloatField({label: "Mirostat tau", max: 10.0, min: 0.0, name: "mirostat_tau", step: 0.01, value: params.value.mirostat_tau})}
${FloatField({label: "Mirostat eta", max: 1.0, min: 0.0, name: "mirostat_eta", step: 0.01, value: params.value.mirostat_eta})}
</fieldset>
</details>
</form> </form>
` `
} }
// poor mans markdown replacement // poor mans markdown replacement
const Markdownish = (params) => { const Markdownish = (params) => {
const md = params.text const md = params.text
.replace(/&/g, '&amp;')
.replace(/</g, '&lt;')
.replace(/>/g, '&gt;')
.replace(/^#{1,6} (.*)$/gim, '<h3>$1</h3>') .replace(/^#{1,6} (.*)$/gim, '<h3>$1</h3>')
.replace(/\*\*(.*?)\*\*/g, '<strong>$1</strong>') .replace(/\*\*(.*?)\*\*/g, '<strong>$1</strong>')
.replace(/__(.*?)__/g, '<strong>$1</strong>') .replace(/__(.*?)__/g, '<strong>$1</strong>')

View file

@ -609,6 +609,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
fprintf(stdout, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); fprintf(stdout, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx); fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
fprintf(stdout, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa); fprintf(stdout, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
fprintf(stdout, " -eps N, --rms-norm-eps N rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps);
fprintf(stdout, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base); fprintf(stdout, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
fprintf(stdout, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale); fprintf(stdout, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
fprintf(stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); fprintf(stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
@ -734,6 +735,14 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
} }
params.n_gqa = std::stoi(argv[i]); params.n_gqa = std::stoi(argv[i]);
} }
else if (arg == "-eps" || arg == "--rms-norm-eps") {
if (++i >= argc)
{
invalid_param = true;
break;
}
params.rms_norm_eps = std::stof(argv[i]);
}
else if (arg == "--rope-freq-base") else if (arg == "--rope-freq-base")
{ {
if (++i >= argc) if (++i >= argc)

View file

@ -16,6 +16,8 @@
#pragma warning(disable: 4244 4267) // possible loss of data #pragma warning(disable: 4244 4267) // possible loss of data
#endif #endif
// Epsilon passed as the third argument to the ggml_rms_norm calls in the
// forward* graph builders below.
static const float rms_norm_eps = 1e-6f;
struct random_normal_distribution { struct random_normal_distribution {
std::mt19937 gen; std::mt19937 gen;
std::normal_distribution<float> rd; std::normal_distribution<float> rd;
@ -439,7 +441,7 @@ struct ggml_tensor * forward(
// norm // norm
{ {
// cur shape [n_embd,N,1,1] // cur shape [n_embd,N,1,1]
cur = ggml_rms_norm(ctx0, inpL); cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
// cur = attention_norm*cur // cur = attention_norm*cur
cur = ggml_mul(ctx0, cur = ggml_mul(ctx0,
@ -562,7 +564,7 @@ struct ggml_tensor * forward(
// norm // norm
{ {
// cur shape [n_embd,N,1,1] // cur shape [n_embd,N,1,1]
cur = ggml_rms_norm(ctx0, inpFF); cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
// cur = ffn_norm*cur // cur = ffn_norm*cur
// cur shape [n_embd,N,1,1] // cur shape [n_embd,N,1,1]
@ -606,7 +608,7 @@ struct ggml_tensor * forward(
{ {
// inpL shape [n_embd,N,1,1] // inpL shape [n_embd,N,1,1]
inpL = ggml_rms_norm(ctx0, inpL); inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
// inpL = norm*inpL // inpL = norm*inpL
// inpL shape [n_embd,N,1,1] // inpL shape [n_embd,N,1,1]
@ -694,7 +696,7 @@ struct ggml_tensor * forward_batch(
// norm // norm
{ {
// cur shape [n_embd,N*n_batch,1,1] // cur shape [n_embd,N*n_batch,1,1]
cur = ggml_rms_norm(ctx0, inpL); cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch); assert_shape_2d(cur, n_embd, N*n_batch);
// cur = attention_norm*cur // cur = attention_norm*cur
@ -857,7 +859,7 @@ struct ggml_tensor * forward_batch(
// norm // norm
{ {
// cur shape [n_embd,N*n_batch,1,1] // cur shape [n_embd,N*n_batch,1,1]
cur = ggml_rms_norm(ctx0, inpFF); cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch); assert_shape_2d(cur, n_embd, N*n_batch);
// cur = ffn_norm*cur // cur = ffn_norm*cur
@ -910,7 +912,7 @@ struct ggml_tensor * forward_batch(
{ {
// inpL shape [n_embd,N*n_batch,1,1] // inpL shape [n_embd,N*n_batch,1,1]
inpL = ggml_rms_norm(ctx0, inpL); inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(inpL, n_embd, N*n_batch); assert_shape_2d(inpL, n_embd, N*n_batch);
// inpL = norm*inpL // inpL = norm*inpL
@ -979,7 +981,7 @@ struct ggml_tensor * forward_batch_wo_cache(
// norm // norm
{ {
// cur shape [n_embd,N*n_batch,1,1] // cur shape [n_embd,N*n_batch,1,1]
cur = ggml_rms_norm(ctx0, inpL); cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch); assert_shape_2d(cur, n_embd, N*n_batch);
// cur = attention_norm*cur // cur = attention_norm*cur
@ -1085,7 +1087,7 @@ struct ggml_tensor * forward_batch_wo_cache(
// norm // norm
{ {
// cur shape [n_embd,N*n_batch,1,1] // cur shape [n_embd,N*n_batch,1,1]
cur = ggml_rms_norm(ctx0, inpFF); cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch); assert_shape_2d(cur, n_embd, N*n_batch);
// cur = ffn_norm*cur // cur = ffn_norm*cur
@ -1138,7 +1140,7 @@ struct ggml_tensor * forward_batch_wo_cache(
{ {
// inpL shape [n_embd,N*n_batch,1,1] // inpL shape [n_embd,N*n_batch,1,1]
inpL = ggml_rms_norm(ctx0, inpL); inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(inpL, n_embd, N*n_batch); assert_shape_2d(inpL, n_embd, N*n_batch);
// inpL = norm*inpL // inpL = norm*inpL
@ -1203,7 +1205,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
// norm // norm
{ {
cur = ggml_rms_norm(ctx0, inpL); cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch); assert_shape_2d(cur, n_embd, N*n_batch);
// cur = attention_norm*cur // cur = attention_norm*cur
@ -1267,7 +1269,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
{ {
// norm // norm
{ {
cur = ggml_rms_norm(ctx0, inpFF); cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch); assert_shape_2d(cur, n_embd, N*n_batch);
// cur = ffn_norm*cur // cur = ffn_norm*cur
@ -1311,7 +1313,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
// norm // norm
{ {
inpL = ggml_rms_norm(ctx0, inpL); inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(inpL, n_embd, N*n_batch); assert_shape_2d(inpL, n_embd, N*n_batch);
// inpL = norm*inpL // inpL = norm*inpL
@ -1603,7 +1605,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
struct my_llama_layer & layer = model->layers[il]; struct my_llama_layer & layer = model->layers[il];
// tensors with values necessary for backward pass are in persistent buf(-1) // tensors with values necessary for backward pass are in persistent buf(-1)
// other tensors with buf(0) and buf(1) are only temporary needed, and their memory reused after layer is completed. // other tensors with buf(0) and buf(1) are only temporary needed, and their memory reused after layer is completed.
use_buf(-1); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm (ctx0, cur)); assert_shape_2d(t02, n_embd, N*n_batch); use_buf(-1); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm (ctx0, cur, rms_norm_eps)); assert_shape_2d(t02, n_embd, N*n_batch);
use_buf( 0); struct ggml_tensor * t03 = expand(gf, ggml_repeat (ctx0, layer.attention_norm, t02)); assert_shape_2d(t03, n_embd, N*n_batch); use_buf( 0); struct ggml_tensor * t03 = expand(gf, ggml_repeat (ctx0, layer.attention_norm, t02)); assert_shape_2d(t03, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t04 = expand(gf, ggml_mul (ctx0, t02, t03)); assert_shape_2d(t04, n_embd, N*n_batch); use_buf(-1); struct ggml_tensor * t04 = expand(gf, ggml_mul (ctx0, t02, t03)); assert_shape_2d(t04, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t05 = expand(gf, ggml_mul_mat (ctx0, layer.wq, t04)); assert_shape_2d(t05, n_embd, N*n_batch); use_buf(-1); struct ggml_tensor * t05 = expand(gf, ggml_mul_mat (ctx0, layer.wq, t04)); assert_shape_2d(t05, n_embd, N*n_batch);
@ -1623,7 +1625,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
use_buf(-1); struct ggml_tensor * t19 = expand(gf, ggml_reshape_2d (ctx0, t18, n_embd, N*n_batch)); assert_shape_2d(t19, n_embd, N*n_batch); use_buf(-1); struct ggml_tensor * t19 = expand(gf, ggml_reshape_2d (ctx0, t18, n_embd, N*n_batch)); assert_shape_2d(t19, n_embd, N*n_batch);
use_buf( 0); struct ggml_tensor * t20 = expand(gf, ggml_mul_mat (ctx0, layer.wo, t19)); assert_shape_2d(t20, n_embd, N*n_batch); use_buf( 0); struct ggml_tensor * t20 = expand(gf, ggml_mul_mat (ctx0, layer.wo, t19)); assert_shape_2d(t20, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t21 = expand(gf, ggml_add (ctx0, t20, cur)); assert_shape_2d(t21, n_embd, N*n_batch); use_buf(-1); struct ggml_tensor * t21 = expand(gf, ggml_add (ctx0, t20, cur)); assert_shape_2d(t21, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm (ctx0, t21)); assert_shape_2d(t22, n_embd, N*n_batch); use_buf(-1); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm (ctx0, t21, rms_norm_eps)); assert_shape_2d(t22, n_embd, N*n_batch);
use_buf( 0); struct ggml_tensor * t23 = expand(gf, ggml_repeat (ctx0, layer.ffn_norm, t22)); assert_shape_2d(t23, n_embd, N*n_batch); use_buf( 0); struct ggml_tensor * t23 = expand(gf, ggml_repeat (ctx0, layer.ffn_norm, t22)); assert_shape_2d(t23, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t24 = expand(gf, ggml_mul (ctx0, t23, t22)); assert_shape_2d(t24, n_embd, N*n_batch); use_buf(-1); struct ggml_tensor * t24 = expand(gf, ggml_mul (ctx0, t23, t22)); assert_shape_2d(t24, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t25 = expand(gf, ggml_mul_mat (ctx0, layer.w3, t24)); assert_shape_2d(t25, n_ff, N*n_batch); use_buf(-1); struct ggml_tensor * t25 = expand(gf, ggml_mul_mat (ctx0, layer.w3, t24)); assert_shape_2d(t25, n_ff, N*n_batch);
@ -1666,7 +1668,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
} }
clr_buf(0); clr_buf(0);
use_buf(0); use_buf(0);
struct ggml_tensor * t31 = expand(gf, ggml_rms_norm (ctx0, cur)); assert_shape_2d(t31, n_embd, N*n_batch); struct ggml_tensor * t31 = expand(gf, ggml_rms_norm (ctx0, cur, rms_norm_eps)); assert_shape_2d(t31, n_embd, N*n_batch);
struct ggml_tensor * t32 = expand(gf, ggml_repeat (ctx0, model->norm, t31)); assert_shape_2d(t32, n_embd, N*n_batch); struct ggml_tensor * t32 = expand(gf, ggml_repeat (ctx0, model->norm, t31)); assert_shape_2d(t32, n_embd, N*n_batch);
struct ggml_tensor * t33 = expand(gf, ggml_mul (ctx0, t32, t31)); assert_shape_2d(t33, n_embd, N*n_batch); struct ggml_tensor * t33 = expand(gf, ggml_mul (ctx0, t32, t31)); assert_shape_2d(t33, n_embd, N*n_batch);
use_buf(-1); use_buf(-1);

View file

@ -332,12 +332,10 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
} }
} }
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) { static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y; const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x; const int tid = threadIdx.x;
const float eps = 1e-6f;
float tmp = 0.0f; // partial sum for thread in warp float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += WARP_SIZE) { for (int col = tid; col < ncols; col += WARP_SIZE) {
@ -1566,12 +1564,14 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
const block_q4_K * bq4_K = (const block_q4_K *) vbq; const block_q4_K * bq4_K = (const block_q4_K *) vbq;
// iqs is in 0...15. bq8_offset = 2 * (iqs/4) -> bq8_offset = 0, 2, 4, 6
const int bq8_offset = QR4_K * (iqs / (QI8_1/2));
float sumf_d = 0.0f; float sumf_d = 0.0f;
float sumf_m = 0.0f; float sumf_m = 0.0f;
#ifndef GGML_QKK_64
// iqs is in 0...15. bq8_offset = 2 * (iqs/4) -> bq8_offset = 0, 2, 4, 6
const int bq8_offset = QR4_K * (iqs / (QI8_1/2));
const float d = bq4_K->d; const float d = bq4_K->d;
const float dmin = bq4_K->dmin; const float dmin = bq4_K->dmin;
@ -1616,6 +1616,43 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
} }
return d*sumf_d - dmin*sumf_m; return d*sumf_d - dmin*sumf_m;
#else
uint16_t aux16[2];
const uint8_t * s = (const uint8_t *)aux16;
const uint16_t * a = (const uint16_t *)bq4_K->scales;
aux16[0] = a[0] & 0x0f0f;
aux16[1] = (a[0] >> 4) & 0x0f0f;
const float dall = bq4_K->d[0];
const float dmin = bq4_K->d[1];
const float d8_1 = bq8_1[0].d;
const float d8_2 = bq8_1[1].d;
const int ui1 = *((const int *)bq8_1[0].qs + iqs);
const int ui2 = *((const int *)bq8_1[0].qs + iqs + 4);
const int ui3 = *((const int *)bq8_1[1].qs + iqs);
const int ui4 = *((const int *)bq8_1[1].qs + iqs + 4);
const int * q4 = (const int *)bq4_K->qs + iqs;
const int v1 = q4[0];
const int v2 = q4[4];
const int dot1 = __dp4a(ui2, v2 & 0x0f0f0f0f, __dp4a(ui1, v1 & 0x0f0f0f0f, 0));
const int dot2 = __dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, __dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0));
const int dot3 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0));
const int dot4 = __dp4a(0x01010101, ui4, __dp4a(0x01010101, ui3, 0));
sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]);
sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]);
return dall * sumf_d - dmin * sumf_m;
#endif
#else #else
return 0.0f; // only to satisfy the compiler return 0.0f; // only to satisfy the compiler
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
@ -1627,6 +1664,8 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
const block_q5_K * bq5_K = (const block_q5_K *) vbq; const block_q5_K * bq5_K = (const block_q5_K *) vbq;
#ifndef GGML_QKK_64
const int bq8_offset = QR5_K * (iqs / (QI8_1/2)); const int bq8_offset = QR5_K * (iqs / (QI8_1/2));
const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * (iqs%4)); const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * (iqs%4));
const int * qh = (const int *)(bq5_K->qh + 4 * (iqs%4)); const int * qh = (const int *)(bq5_K->qh + 4 * (iqs%4));
@ -1682,6 +1721,42 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
} }
return d*sumf_d - dmin*sumf_m; return d*sumf_d - dmin*sumf_m;
#else
const int8_t * s = bq5_K->scales;
const float d = bq5_K->d;
const float d8_1 = bq8_1[0].d;
const float d8_2 = bq8_1[1].d;
const int ui1 = *((const int *)bq8_1[0].qs + iqs);
const int ui2 = *((const int *)bq8_1[0].qs + iqs + 4);
const int ui3 = *((const int *)bq8_1[1].qs + iqs);
const int ui4 = *((const int *)bq8_1[1].qs + iqs + 4);
const int * ql = (const int *)bq5_K->qs + iqs;
const int vl1 = ql[0];
const int vl2 = ql[4];
const int step = 4 * iqs; // 0, 4, 8, 12
const int im = step/8; // = 0 for iqs = 0, 1, = 1 for iqs = 2, 3
const int in = step%8; // 0, 4, 0, 4
const int vh = (*((const int *)(bq5_K->qh + in))) >> im;
const int v1 = (((vh << 4) & 0x10101010) ^ 0x10101010) | ((vl1 >> 0) & 0x0f0f0f0f);
const int v2 = (((vh << 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 0) & 0x0f0f0f0f);
const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f);
const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f);
const float sumf_d = d8_1 * (__dp4a(ui1, v1, 0) * s[0] + __dp4a(ui2, v2, 0) * s[1])
+ d8_2 * (__dp4a(ui3, v3, 0) * s[2] + __dp4a(ui4, v4, 0) * s[3]);
return d * sumf_d;
#endif
#else #else
return 0.0f; // only to satisfy the compiler return 0.0f; // only to satisfy the compiler
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
@ -2122,10 +2197,10 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i
norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols); norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
} }
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0); GGML_ASSERT(ncols % WARP_SIZE == 0);
const dim3 block_dims(WARP_SIZE, 1, 1); const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols); rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
} }
static void quantize_row_q8_1_cuda(const float * x, void * vy, const int ndata, const int k, cudaStream_t stream) { static void quantize_row_q8_1_cuda(const float * x, void * vy, const int ndata, const int k, cudaStream_t stream) {
@ -2876,8 +2951,11 @@ inline void ggml_cuda_op_rms_norm(
const int64_t ne00 = src0->ne[0]; const int64_t ne00 = src0->ne[0];
const int64_t i01_diff = i01_high - i01_low; const int64_t i01_diff = i01_high - i01_low;
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
// compute // compute
rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main); rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, eps, cudaStream_main);
(void) src1; (void) src1;
(void) dst; (void) dst;
@ -3962,18 +4040,23 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
} }
func = ggml_cuda_mul; func = ggml_cuda_mul;
break; break;
case GGML_OP_GELU: case GGML_OP_UNARY:
switch (ggml_get_unary_op(tensor)) {
case GGML_UNARY_OP_GELU:
if (!any_on_device) { if (!any_on_device) {
return false; return false;
} }
func = ggml_cuda_gelu; func = ggml_cuda_gelu;
break; break;
case GGML_OP_SILU: case GGML_UNARY_OP_SILU:
if (!any_on_device) { if (!any_on_device) {
return false; return false;
} }
func = ggml_cuda_silu; func = ggml_cuda_silu;
break; break;
default:
return false;
} break;
case GGML_OP_NORM: case GGML_OP_NORM:
if (!any_on_device) { if (!any_on_device) {
return false; return false;

View file

@ -629,7 +629,9 @@ void ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break; } break;
case GGML_OP_SILU: case GGML_OP_UNARY:
switch (ggml_get_unary_op(gf->nodes[i])) {
case GGML_UNARY_OP_SILU:
{ {
if (encoder == nil) { if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
@ -643,7 +645,7 @@ void ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break; } break;
case GGML_OP_RELU: case GGML_UNARY_OP_RELU:
{ {
if (encoder == nil) { if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
@ -657,7 +659,7 @@ void ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break; } break;
case GGML_OP_GELU: case GGML_UNARY_OP_GELU:
{ {
if (encoder == nil) { if (encoder == nil) {
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
@ -671,6 +673,12 @@ void ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break; } break;
default:
{
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
GGML_ASSERT(false);
}
} break;
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
{ {
if (encoder == nil) { if (encoder == nil) {
@ -914,7 +922,8 @@ void ggml_metal_graph_compute(
encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor]; encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
} }
const float eps = 1e-6f; float eps;
memcpy(&eps, dst->op_params, sizeof(float));
const int nth = 512; const int nth = 512;
@ -1089,10 +1098,12 @@ void ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break; } break;
default: default:
{
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
GGML_ASSERT(false); GGML_ASSERT(false);
} }
} }
}
if (encoder != nil) { if (encoder != nil) {
[encoder endEncoding]; [encoder endEncoding];

View file

@ -387,87 +387,90 @@ kernel void kernel_rms_norm(
} }
} }
// function for calculate inner product between a q4_0 block and 32 floats (yl), sumy is SUM(yl[i]) // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) { // il indicates where the q4 quants begin (0 or QK4_0/4)
// we assume that the yl's have been multiplied with the appropriate scale factor
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
float d = qb_curr->d; float d = qb_curr->d;
float4 acc = 0.f; float2 acc = 0.f;
device uint16_t * qs = ((device uint16_t *)qb_curr + 1); device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
for (int i = 0; i < 16; i+=2) { for (int i = 0; i < 8; i+=2) {
acc[0] += yl[i] * (qs[i / 2] & 0x000F); acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0); + yl[i + 1] * (qs[i / 2] & 0x0F00);
acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00); acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000); + yl[i + 9] * (qs[i / 2] & 0xF000);
} }
return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f); return d * (sumy * -8.f + acc[0] + acc[1]);
} }
// function for calculate inner product between a q4_1 block and 32 floats (yl), sumy is SUM(yl[i]) // function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) { // il indicates where the q4 quants begin (0 or QK4_0/4)
// we assume that the yl's have been multiplied with the appropriate scale factor
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
float d = qb_curr->d; float d = qb_curr->d;
float m = qb_curr->m; float m = qb_curr->m;
float4 acc = 0.f; device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
device uint16_t * qs = ((device uint16_t *)qb_curr + 2); float2 acc = 0.f;
for (int i = 0; i < 16; i+=2) { for (int i = 0; i < 8; i+=2) {
acc[0] += yl[i] * (qs[i / 2] & 0x000F); acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0); + yl[i + 1] * (qs[i / 2] & 0x0F00);
acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00); acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000); + yl[i + 9] * (qs[i / 2] & 0xF000);
} }
return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m; return d * (acc[0] + acc[1]) + sumy * m;
} }
// putting them in the kernel cause a significant performance penalty // putting them in the kernel cause a significant performance penalty
#define N_DST 4 // each SIMD group works on 4 rows #define N_DST 4 // each SIMD group works on 4 rows
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32 #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
template<typename block_q_type> //Note: This is a template, but strictly speaking it only applies to
// quantizations where the block size is 32. It also does not
// guard against the number of rows not being divisible by
// N_DST, so this is another explicit assumption of the implementation.
template<typename block_q_type, int nr, int nsg, int nw>
void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst, void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01, int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
uint2 tgpig, uint tiisg, uint sgitg) { uint2 tgpig, uint tiisg, uint sgitg) {
const int nb = ne00/QK4_0; const int nb = ne00/QK4_0;
const int r0 = tgpig.x; const int r0 = tgpig.x;
const int r1 = tgpig.y; const int r1 = tgpig.y;
device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb; const int first_row = (r0 * nsg + sgitg) * nr;
device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb;
device const float * y = (device const float *) src1 + r1*ne10; device const float * y = (device const float *) src1 + r1*ne10;
float4 y_curr[8]; // src1 vector cache float yl[16]; // src1 vector cache
float sumf[N_DST]={0.f}, all_sum; float sumf[nr]={0.f};
thread float * yl=(thread float *)y_curr;
// each thread in a SIMD group deals with 1 block. const int ix = tiisg/2;
for (int column = 0; column < nb / N_SIMDWIDTH; column++) { const int il = 8*(tiisg%2);
device const float * yb = y + ix * QK4_0 + il;
// each thread in a SIMD group deals with half a block.
for (int ib = ix; ib < nb; ib += nw/2) {
float sumy = 0; float sumy = 0;
for (int i = 0; i < QK4_0 / 4; i++) { for (int i = 0; i < 8; i += 2) {
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i); sumy += yb[i] + yb[i+1];
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; yl[i+0] = yb[i+ 0];
yl[i+1] = yb[i+ 1]/256.f;
sumy += yb[i+16] + yb[i+17];
yl[i+8] = yb[i+16]/16.f;
yl[i+9] = yb[i+17]/4096.f;
} }
for (int row = 0; row < N_DST; row++) { for (int row = 0; row < nr; row++) {
sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl); sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il);
}
} }
// from now loads two rows every time and 16 blocks per row yb += QK4_0 * 16;
int ir = tiisg / (N_SIMDWIDTH / 2);
int ib = tiisg % (N_SIMDWIDTH / 2);
for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) {
int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start
float sumy = 0;
for (int i = 0; i < QK4_0 / 4; i++) {
y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i);
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
} }
for (int row = 0; row < N_DST; row+=2) { for (int row = 0; row < nr; ++row) {
if (nb_start + ib < nb) { const float tot = simd_sum(sumf[row]);
sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl); if (tiisg == 0 && first_row + row < ne01) {
} dst[r1*ne0 + first_row + row] = tot;
}
}
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
} }
} }
} }
@ -483,7 +486,7 @@ kernel void kernel_mul_mat_q4_0_f32(
uint2 tgpig[[threadgroup_position_in_grid]], uint2 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32<block_q4_0>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
} }
kernel void kernel_mul_mat_q4_1_f32( kernel void kernel_mul_mat_q4_1_f32(
@ -497,7 +500,7 @@ kernel void kernel_mul_mat_q4_1_f32(
uint2 tgpig[[threadgroup_position_in_grid]], uint2 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32<block_q4_1>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
} }
kernel void kernel_mul_mat_f16_f32( kernel void kernel_mul_mat_f16_f32(

779
ggml.c

File diff suppressed because it is too large Load diff

67
ggml.h
View file

@ -330,16 +330,6 @@ extern "C" {
GGML_OP_ARGMAX, GGML_OP_ARGMAX,
GGML_OP_REPEAT, GGML_OP_REPEAT,
GGML_OP_REPEAT_BACK, GGML_OP_REPEAT_BACK,
GGML_OP_ABS,
GGML_OP_SGN,
GGML_OP_NEG,
GGML_OP_STEP,
GGML_OP_TANH,
GGML_OP_ELU,
GGML_OP_RELU,
GGML_OP_GELU,
GGML_OP_GELU_QUICK,
GGML_OP_SILU,
GGML_OP_SILU_BACK, GGML_OP_SILU_BACK,
GGML_OP_NORM, // normalize GGML_OP_NORM, // normalize
GGML_OP_RMS_NORM, GGML_OP_RMS_NORM,
@ -378,6 +368,8 @@ extern "C" {
GGML_OP_WIN_PART, GGML_OP_WIN_PART,
GGML_OP_WIN_UNPART, GGML_OP_WIN_UNPART,
GGML_OP_UNARY,
GGML_OP_MAP_UNARY, GGML_OP_MAP_UNARY,
GGML_OP_MAP_BINARY, GGML_OP_MAP_BINARY,
@ -391,6 +383,18 @@ extern "C" {
GGML_OP_COUNT, GGML_OP_COUNT,
}; };
// Kinds of unary operator. A value of this enum is the `op` argument of
// ggml_unary() / ggml_unary_inplace() and the result of ggml_get_unary_op().
enum ggml_unary_op {
    GGML_UNARY_OP_ABS,
    GGML_UNARY_OP_SGN,
    GGML_UNARY_OP_NEG,
    GGML_UNARY_OP_STEP,
    GGML_UNARY_OP_TANH,
    GGML_UNARY_OP_ELU,
    GGML_UNARY_OP_RELU,
    GGML_UNARY_OP_GELU,
    GGML_UNARY_OP_GELU_QUICK,
    GGML_UNARY_OP_SILU,
};
// ggml object // ggml object
struct ggml_object { struct ggml_object {
@ -535,6 +539,7 @@ extern "C" {
GGML_API const char * ggml_type_name(enum ggml_type type); GGML_API const char * ggml_type_name(enum ggml_type type);
GGML_API const char * ggml_op_name (enum ggml_op op); GGML_API const char * ggml_op_name (enum ggml_op op);
GGML_API const char * ggml_op_symbol(enum ggml_op op);
GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor); GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor);
@ -558,6 +563,7 @@ extern "C" {
GGML_API size_t ggml_used_mem(const struct ggml_context * ctx); GGML_API size_t ggml_used_mem(const struct ggml_context * ctx);
GGML_API size_t ggml_set_scratch (struct ggml_context * ctx, struct ggml_scratch scratch); GGML_API size_t ggml_set_scratch (struct ggml_context * ctx, struct ggml_scratch scratch);
GGML_API bool ggml_get_no_alloc(struct ggml_context * ctx);
GGML_API void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc); GGML_API void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc);
GGML_API void * ggml_get_mem_buffer (const struct ggml_context * ctx); GGML_API void * ggml_get_mem_buffer (const struct ggml_context * ctx);
@ -617,9 +623,11 @@ extern "C" {
GGML_API void * ggml_get_data (const struct ggml_tensor * tensor); GGML_API void * ggml_get_data (const struct ggml_tensor * tensor);
GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor); GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
GGML_API const char * ggml_get_name(const struct ggml_tensor * tensor); GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);
GGML_API struct ggml_tensor * ggml_set_name(struct ggml_tensor * tensor, const char * name);
GGML_API struct ggml_tensor * ggml_format_name(struct ggml_tensor * tensor, const char * fmt, ...); GGML_API const char * ggml_get_name (const struct ggml_tensor * tensor);
GGML_API struct ggml_tensor * ggml_set_name ( struct ggml_tensor * tensor, const char * name);
GGML_API struct ggml_tensor * ggml_format_name( struct ggml_tensor * tensor, const char * fmt, ...);
// //
// operations on tensors with backpropagation // operations on tensors with backpropagation
@ -629,6 +637,11 @@ extern "C" {
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
// in-place, returns view(a)
GGML_API struct ggml_tensor * ggml_dup_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_add( GGML_API struct ggml_tensor * ggml_add(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
@ -853,14 +866,17 @@ extern "C" {
GGML_API struct ggml_tensor * ggml_rms_norm( GGML_API struct ggml_tensor * ggml_rms_norm(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a,
float eps);
GGML_API struct ggml_tensor * ggml_rms_norm_inplace( GGML_API struct ggml_tensor * ggml_rms_norm_inplace(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a,
float eps);
// a - x // a - x
// b - dy // b - dy
// TODO: update with configurable eps
GGML_API struct ggml_tensor * ggml_rms_norm_back( GGML_API struct ggml_tensor * ggml_rms_norm_back(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
@ -952,11 +968,22 @@ extern "C" {
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * b); struct ggml_tensor * b);
// a -> b, in-place, return view(b)
GGML_API struct ggml_tensor * ggml_cpy_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
// make contiguous // make contiguous
GGML_API struct ggml_tensor * ggml_cont( GGML_API struct ggml_tensor * ggml_cont(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
// make contiguous, in-place
GGML_API struct ggml_tensor * ggml_cont_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
// return view(a), b specifies the new shape // return view(a), b specifies the new shape
// TODO: when we start computing gradient, make a copy instead of view // TODO: when we start computing gradient, make a copy instead of view
GGML_API struct ggml_tensor * ggml_reshape( GGML_API struct ggml_tensor * ggml_reshape(
@ -1268,6 +1295,16 @@ extern "C" {
typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *); typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *); typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
GGML_API struct ggml_tensor * ggml_unary(
struct ggml_context * ctx,
struct ggml_tensor * a,
enum ggml_unary_op op);
GGML_API struct ggml_tensor * ggml_unary_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
enum ggml_unary_op op);
GGML_API struct ggml_tensor * ggml_map_unary_f32( GGML_API struct ggml_tensor * ggml_map_unary_f32(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,

6
grammars/arithmetic.gbnf Normal file
View file

@ -0,0 +1,6 @@
# One or more lines of the form: <expr> "=" <term>, each ended by a newline
root ::= (expr "=" ws term "\n")+
# terms joined by the binary operators - + * /
expr ::= term ([-+*/] term)*
# an identifier, a number, or a parenthesized expression
term ::= ident | num | "(" ws expr ")" ws
# lowercase letter followed by lowercase letters, digits or underscores
ident ::= [a-z] [a-z0-9_]* ws
# unsigned decimal integer
num ::= [0-9]+ ws
# optional whitespace
ws ::= [ \t\n]*

13
grammars/chess.gbnf Normal file
View file

@ -0,0 +1,13 @@
# Specifies chess moves as a list in algebraic notation, using PGN conventions
# Force first move to "1. ", then any 1-2 digit number after, relying on model to follow the pattern
root ::= "1. " move " " move "\n" ([1-9] [0-9]? ". " move " " move "\n")+
move ::= (pawn | nonpawn | castle) [+#]?
# piece type, optional file/rank, optional capture, dest file & rank
nonpawn ::= [NBKQR] [a-h]? [1-8]? "x"? [a-h] [1-8]
# optional file & capture, dest file & rank, optional promotion
pawn ::= ([a-h] "x")? [a-h] [1-8] ("=" [NBKQR])?
castle ::= "O-O" "-O"?

7
grammars/japanese.gbnf Normal file
View file

@ -0,0 +1,7 @@
# A probably incorrect grammar for Japanese
root ::= jp-char+ ([ \t\n] jp-char+)*
jp-char ::= hiragana | katakana | punctuation | cjk
hiragana ::= [ぁ-ゟ]
katakana ::= [ァ-ヿ]
punctuation ::= [、-〾]
cjk ::= [一-鿿]

29
grammars/json.gbnf Normal file
View file

@ -0,0 +1,29 @@
# Grammar for subset of JSON - doesn't support full string or number syntax
root ::= object
value ::= object | array | string | number | boolean | "null"
object ::=
"{" ws (
string ":" ws value
("," ws string ":" ws value)*
)? "}"
array ::=
"[" ws (
value
("," ws value)*
)? "]"
string ::=
"\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
)* "\"" ws
# Only plain integers currently
number ::= "-"? [0-9]+ ws
boolean ::= ("true" | "false") ws
# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?

4
grammars/list.gbnf Normal file
View file

@ -0,0 +1,4 @@
root ::= item+
# Excludes various line break characters
item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n"

View file

@ -3297,8 +3297,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
#else #else
int8_t aux8[QK_K];
uint8_t aux8[QK_K];
int16_t aux16[16]; int16_t aux16[16];
float sums [8]; float sums [8];
memset(sums, 0, 8*sizeof(float)); memset(sums, 0, 8*sizeof(float));
@ -3308,7 +3307,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
const uint8_t * restrict q4 = x[i].qs; const uint8_t * restrict q4 = x[i].qs;
const uint8_t * restrict hm = x[i].qh; const uint8_t * restrict hm = x[i].qh;
const int8_t * restrict q8 = y[i].qs; const int8_t * restrict q8 = y[i].qs;
uint8_t * restrict a = aux8; int8_t * restrict a = aux8;
for (int l = 0; l < 32; ++l) { for (int l = 0; l < 32; ++l) {
a[l+ 0] = q4[l] & 0xF; a[l+ 0] = q4[l] & 0xF;
a[l+32] = q4[l] >> 4; a[l+32] = q4[l] >> 4;

357
llama.cpp
View file

@ -186,6 +186,7 @@ struct llama_hparams {
// LLaMAv2 // LLaMAv2
// TODO: load from model data hparams // TODO: load from model data hparams
float f_ffn_mult = 1.0f; float f_ffn_mult = 1.0f;
float f_rms_norm_eps = 1e-6f;
float rope_freq_base = 10000.0f; float rope_freq_base = 10000.0f;
float rope_freq_scale = 1.0f; float rope_freq_scale = 1.0f;
@ -869,6 +870,7 @@ struct llama_context_params llama_context_default_params() {
/*.n_ctx =*/ 512, /*.n_ctx =*/ 512,
/*.n_batch =*/ 512, /*.n_batch =*/ 512,
/*.n_gqa =*/ 1, /*.n_gqa =*/ 1,
/*.rms_norm_eps =*/ 1e-6f,
/*.gpu_layers =*/ 0, /*.gpu_layers =*/ 0,
/*.main_gpu =*/ 0, /*.main_gpu =*/ 0,
/*.tensor_split =*/ nullptr, /*.tensor_split =*/ nullptr,
@ -1000,6 +1002,7 @@ static void llama_model_load_internal(
int n_ctx, int n_ctx,
int n_batch, int n_batch,
int n_gqa, int n_gqa,
float rms_norm_eps,
int n_gpu_layers, int n_gpu_layers,
int main_gpu, int main_gpu,
const float * tensor_split, const float * tensor_split,
@ -1024,6 +1027,9 @@ static void llama_model_load_internal(
auto & hparams = model.hparams; auto & hparams = model.hparams;
// TODO: read from file
hparams.f_rms_norm_eps = rms_norm_eps;
{ {
switch (hparams.n_layer) { switch (hparams.n_layer) {
case 26: model.type = e_model::MODEL_3B; break; case 26: model.type = e_model::MODEL_3B; break;
@ -1072,6 +1078,7 @@ static void llama_model_load_internal(
fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer); fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer);
fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim
fprintf(stderr, "%s: n_gqa = %u\n", __func__, hparams.n_gqa()); fprintf(stderr, "%s: n_gqa = %u\n", __func__, hparams.n_gqa());
fprintf(stderr, "%s: rnorm_eps = %.1e\n", __func__, hparams.f_rms_norm_eps);
fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff); fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff);
fprintf(stderr, "%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base); fprintf(stderr, "%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base);
fprintf(stderr, "%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale); fprintf(stderr, "%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale);
@ -1330,6 +1337,7 @@ static bool llama_model_load(
int n_ctx, int n_ctx,
int n_batch, int n_batch,
int n_gqa, int n_gqa,
float rms_norm_eps,
int n_gpu_layers, int n_gpu_layers,
int main_gpu, int main_gpu,
const float * tensor_split, const float * tensor_split,
@ -1343,7 +1351,7 @@ static bool llama_model_load(
llama_progress_callback progress_callback, llama_progress_callback progress_callback,
void *progress_callback_user_data) { void *progress_callback_user_data) {
try { try {
llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type, llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, rms_norm_eps, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type,
use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data); use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data);
return true; return true;
} catch (const std::exception & err) { } catch (const std::exception & err) {
@ -1396,10 +1404,12 @@ static bool llama_eval_internal(
const int64_t n_vocab = hparams.n_vocab; const int64_t n_vocab = hparams.n_vocab;
const int64_t n_embd_gqa = hparams.n_embd_gqa(); const int64_t n_embd_gqa = hparams.n_embd_gqa();
LLAMA_ASSERT(n_embd_head == hparams.n_rot); LLAMA_ASSERT(n_embd_head == hparams.n_rot);
const float freq_base = hparams.rope_freq_base; const float freq_base = hparams.rope_freq_base;
const float freq_scale = hparams.rope_freq_scale; const float freq_scale = hparams.rope_freq_scale;
const float rms_norm_eps = hparams.f_rms_norm_eps;
const int n_gpu_layers = model.n_gpu_layers; const int n_gpu_layers = model.n_gpu_layers;
@ -1479,7 +1489,7 @@ static bool llama_eval_internal(
// norm // norm
{ {
cur = ggml_rms_norm(ctx0, inpL); cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
offload_func(cur); offload_func(cur);
ggml_set_name(cur, "rms_norm_0"); ggml_set_name(cur, "rms_norm_0");
@ -1627,7 +1637,7 @@ static bool llama_eval_internal(
{ {
// norm // norm
{ {
cur = ggml_rms_norm(ctx0, inpFF); cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
offload_func(cur); offload_func(cur);
ggml_set_name(cur, "rms_norm_1"); ggml_set_name(cur, "rms_norm_1");
@ -1680,7 +1690,7 @@ static bool llama_eval_internal(
// norm // norm
{ {
cur = ggml_rms_norm(ctx0, inpL); cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
offload_func_nr(cur); offload_func_nr(cur);
ggml_set_name(cur, "rms_norm_2"); ggml_set_name(cur, "rms_norm_2");
@ -1968,6 +1978,279 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
return output; return output;
} }
//
// grammar - internal
//
// Grammar rules plus the current parse state.
struct llama_grammar {
    // rule definitions indexed by rule ID; llama_grammar_init() terminates
    // each one with a LLAMA_GRETYPE_END element
    const std::vector<std::vector<llama_grammar_element>> rules;
    // set of alternative pushdown stacks; each entry is a pointer to a
    // grammar element, and an empty stack means the grammar has been fully matched
    std::vector<std::vector<const llama_grammar_element *>> stacks;
};
// A token candidate being checked against the grammar.
struct llama_grammar_candidate {
    // index of the token in the llama_token_data_array being filtered
    size_t index;
    // not-yet-matched code points of the token text, terminated by 0
    const uint32_t * code_points;
};
// Decodes a NUL-terminated UTF-8 string into Unicode code points.
//
// NOTE: assumes valid utf8; the only malformed input handled is a multi-byte
// sequence cut short by the terminating NUL, in which case the bytes that are
// present are folded into a (meaningless) value and decoding stops there.
//
// The returned vector always ends with an extra 0 element so that its data()
// can be walked like a C string of code points.
std::vector<uint32_t> decode_utf8(const char * src) {
    // sequence length in bytes, indexed by the high nibble of the lead byte
    static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };

    std::vector<uint32_t> code_points;
    const char * pos = src;
    while (*pos != 0) {
        const uint8_t first_byte = static_cast<uint8_t>(*pos);
        const int     len        = lookup[first_byte >> 4];
        // keep only the payload bits of the lead byte
        const uint8_t mask       = (1 << (8 - len)) - 1;
        uint32_t      value      = first_byte & mask;
        ++pos;
        // Count the continuation bytes instead of comparing against `pos + len`:
        // for a truncated sequence that pointer would lie beyond the end of the
        // string, which is not allowed even if it is never dereferenced.
        for (int remaining = len - 1; remaining > 0 && *pos != 0; --remaining, ++pos) {
            value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
        }
        code_points.push_back(value);
    }
    code_points.push_back(0);
    return code_points;
}
// An alternate (one definition of a rule) ends either where the whole rule
// ends (LLAMA_GRETYPE_END) or where the next alternate begins
// (LLAMA_GRETYPE_ALT). Returns true iff pos sits on one of those markers.
static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) {
    const auto type = pos->type;
    return type == LLAMA_GRETYPE_END || type == LLAMA_GRETYPE_ALT;
}
// Tests chr against the character set starting at pos, which must be a
// LLAMA_GRETYPE_CHAR (positive set) or LLAMA_GRETYPE_CHAR_NOT (negated set)
// element; anything else trips the assertion.
//
// Returns {matched, next}: `matched` tells whether chr satisfies the set and
// `next` points to the first element after the whole set.
static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
        const llama_grammar_element * pos,
        const uint32_t chr) {
    const bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR;
    LLAMA_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT);

    bool in_set = false;
    do {
        const bool is_range = pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER;
        if (is_range) {
            // inclusive range, e.g. [a-z]: lower bound at pos, upper bound at pos + 1
            if (pos->value <= chr && chr <= pos[1].value) {
                in_set = true;
            }
            pos += 2;
        } else {
            // single character, e.g. [a] or "a"
            if (pos->value == chr) {
                in_set = true;
            }
            pos += 1;
        }
        // further members of the same set are chained with CHAR_ALT
    } while (pos->type == LLAMA_GRETYPE_CHAR_ALT);

    // a negated set matches exactly when chr is NOT a member
    const bool matched = (in_set == is_positive_char);
    return std::make_pair(matched, pos);
}
// transforms a grammar pushdown stack into N possible stacks, all ending
// at a character range (terminal element)
//
// rules      - rule definitions, indexed by rule ID
// stack      - stack to expand; its top (back) is the next element to process
// new_stacks - output: every resulting stack is appended here
//
// An empty stack is passed through unchanged. A stack whose top is a rule
// reference is expanded recursively, once per alternate of the referenced rule.
static void llama_grammar_advance_stack(
        const std::vector<std::vector<llama_grammar_element>> & rules,
        const std::vector<const llama_grammar_element *> & stack,
        std::vector<std::vector<const llama_grammar_element *>> & new_stacks) {

    if (stack.empty()) {
        // nothing left to expand; keep the empty stack as a result
        new_stacks.push_back(stack);
        return;
    }

    const llama_grammar_element * pos = stack.back();

    switch (pos->type) {
        case LLAMA_GRETYPE_RULE_REF: {
            // for a rule reference, `value` holds the ID of the referenced rule
            const size_t rule_id = static_cast<size_t>(pos->value);
            const llama_grammar_element * subpos = rules[rule_id].data();
            // one iteration per alternate of the referenced rule
            do {
                // init new stack without the top (pos)
                std::vector<const llama_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
                if (!llama_grammar_is_end_of_sequence(pos + 1)) {
                    // if this rule ref is followed by another element, add that to stack
                    new_stack.push_back(pos + 1);
                }
                if (!llama_grammar_is_end_of_sequence(subpos)) {
                    // if alternate is nonempty, add to stack
                    new_stack.push_back(subpos);
                }
                // the new top may itself be a rule reference, so expand recursively
                llama_grammar_advance_stack(rules, new_stack, new_stacks);
                while (!llama_grammar_is_end_of_sequence(subpos)) {
                    // scan to end of alternate def
                    subpos++;
                }
                if (subpos->type == LLAMA_GRETYPE_ALT) {
                    // there's another alternate def of this rule to process
                    subpos++;
                } else {
                    break;
                }
            } while (true);
            break;
        }
        case LLAMA_GRETYPE_CHAR:
        case LLAMA_GRETYPE_CHAR_NOT:
            // already positioned on a terminal; the stack is a result as-is
            new_stacks.push_back(stack);
            break;
        default:
            // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range
            // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
            // those
            LLAMA_ASSERT(false);
    }
}
// Feeds one code point to the grammar.
//
// Every stack in `stacks` must be positioned at a character range (the form
// produced by `llama_grammar_advance_stack`). Stacks whose top terminal does
// not match chr are dropped, as are empty stacks; each surviving stack is
// stepped past the terminal and expanded again, and all stacks obtained that
// way are returned.
static std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept(
        const std::vector<std::vector<llama_grammar_element>> & rules,
        const std::vector<std::vector<const llama_grammar_element *>> & stacks,
        const uint32_t chr) {
    std::vector<std::vector<const llama_grammar_element *>> accepted;

    for (const auto & stack : stacks) {
        if (stack.empty()) {
            continue;
        }

        const auto match = llama_grammar_match_char(stack.back(), chr);
        if (!match.first) {
            continue;
        }

        // replace the matched terminal with the element that follows it, if any
        const llama_grammar_element * after = match.second;
        std::vector<const llama_grammar_element *> next_stack(stack.begin(), stack.end() - 1);
        if (!llama_grammar_is_end_of_sequence(after)) {
            next_stack.push_back(after);
        }
        llama_grammar_advance_stack(rules, next_stack, accepted);
    }

    return accepted;
}
// forward declaration: llama_grammar_reject_candidates and
// llama_grammar_reject_candidates_for_stack call each other
static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates(
        const std::vector<std::vector<llama_grammar_element>> & rules,
        const std::vector<std::vector<const llama_grammar_element *>> & stacks,
        const std::vector<llama_grammar_candidate> & candidates);
// Returns the subset of `candidates` that cannot be matched starting from the
// single pushdown stack `stack`.
//
// Each candidate is checked one code point at a time: the first code point is
// tested against the terminal on top of the stack here, and candidates with
// code points left over are checked recursively against the stacks reachable
// after that terminal.
static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
        const std::vector<std::vector<llama_grammar_element>> & rules,
        const std::vector<const llama_grammar_element *> & stack,
        const std::vector<llama_grammar_candidate> & candidates) {

    std::vector<llama_grammar_candidate> rejects;

    if (stack.empty()) {
        // accept nothing; EOS is handled elsewhere
        rejects.insert(rejects.end(), candidates.begin(), candidates.end());
        return rejects;
    }

    const llama_grammar_element * stack_pos = stack.back();

    // candidates whose first code point matched and that still have more to check
    std::vector<llama_grammar_candidate> next_candidates;
    for (auto tok : candidates) {
        if (llama_grammar_match_char(stack_pos, tok.code_points[0]).first) {
            if (tok.code_points[1] != 0) {
                // more code points follow: re-check the remainder after advancing
                next_candidates.push_back({ tok.index, tok.code_points + 1 });
            }
            // otherwise the whole token matched: it is neither rejected nor re-queued
        } else {
            rejects.push_back(tok);
        }
    }

    // only the returned position is used here (the element following the
    // character set); the match result for chr 0 is ignored
    auto stack_pos_after = llama_grammar_match_char(stack_pos, 0).second;

    // update top of stack to next element, if any
    std::vector<const llama_grammar_element *> stack_after(stack.begin(), stack.end() - 1);
    if (!llama_grammar_is_end_of_sequence(stack_pos_after)) {
        stack_after.push_back(stack_pos_after);
    }
    std::vector<std::vector<const llama_grammar_element *>> next_stacks;
    llama_grammar_advance_stack(rules, stack_after, next_stacks);

    auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
    for (auto tok : next_rejects) {
        // undo the +1 applied above so the reject carries the same code_points
        // pointer this call received
        rejects.push_back({ tok.index, tok.code_points - 1 });
    }

    return rejects;
}
// Returns the candidates that are rejected by every stack in `stacks`.
//
// The candidate list is narrowed stack by stack: the first stack filters the
// full list, and each following stack only re-examines what all previous
// stacks rejected. `stacks` must not be empty.
static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates(
        const std::vector<std::vector<llama_grammar_element>> & rules,
        const std::vector<std::vector<const llama_grammar_element *>> & stacks,
        const std::vector<llama_grammar_candidate> & candidates) {
    LLAMA_ASSERT(!stacks.empty()); // REVIEW

    if (candidates.empty()) {
        return {};
    }

    auto it = stacks.begin();
    std::vector<llama_grammar_candidate> rejects =
        llama_grammar_reject_candidates_for_stack(rules, *it, candidates);

    for (++it; it != stacks.end(); ++it) {
        rejects = llama_grammar_reject_candidates_for_stack(rules, *it, rejects);
    }

    return rejects;
}
//
// grammar - external
//
struct llama_grammar * llama_grammar_init(
const llama_grammar_element ** rules,
size_t n_rules,
size_t start_rule_index) {
const llama_grammar_element * pos;
// copy rule definitions into vectors
std::vector<std::vector<llama_grammar_element>> vec_rules(n_rules);
for (size_t i = 0; i < n_rules; i++) {
for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
vec_rules[i].push_back(*pos);
}
vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
}
// loop over alternates of start rule to build initial stacks
std::vector<std::vector<const llama_grammar_element *>> stacks;
pos = rules[start_rule_index];
do {
std::vector<const llama_grammar_element *> stack;
if (!llama_grammar_is_end_of_sequence(pos)) {
// if alternate is nonempty, add to stack
stack.push_back(pos);
}
llama_grammar_advance_stack(vec_rules, stack, stacks);
while (!llama_grammar_is_end_of_sequence(pos)) {
// scan to end of alternate def
pos++;
}
if (pos->type == LLAMA_GRETYPE_ALT) {
// there's another alternate def of this rule to process
pos++;
} else {
break;
}
} while (true);
return new llama_grammar{ std::move(vec_rules), std::move(stacks) };
}
// Destroys a grammar created by llama_grammar_init(); it was allocated with
// `new`, so it is released with `delete` (a null pointer is a no-op).
void llama_grammar_free(struct llama_grammar * grammar) {
    delete grammar;
}
// //
// sampling // sampling
// //
@ -2253,6 +2536,47 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l
} }
} }
void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) {
assert(ctx);
const int64_t t_start_sample_us = ggml_time_us();
bool allow_eos = false;
for (const auto & stack : grammar->stacks) {
if (stack.empty()) {
allow_eos = true;
break;
}
}
const llama_token eos = llama_token_eos();
std::vector<std::vector<uint32_t>> candidates_decoded;
std::vector<llama_grammar_candidate> candidates_grammar;
for (size_t i = 0; i < candidates->size; ++i) {
const llama_token id = candidates->data[i].id;
const char * str = llama_token_to_str(ctx, id);
if (id == eos) {
if (!allow_eos) {
candidates->data[i].logit = -INFINITY;
}
} else if (*str == 0) {
candidates->data[i].logit = -INFINITY;
} else {
candidates_decoded.push_back(decode_utf8(str));
candidates_grammar.push_back({ i, candidates_decoded.back().data() });
}
}
const auto rejects =
llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar);
for (auto & reject : rejects) {
candidates->data[reject.index].logit = -INFINITY;
}
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
}
static void llama_log_softmax(float * array, size_t size) { static void llama_log_softmax(float * array, size_t size) {
float max_l = *std::max_element(array, array + size); float max_l = *std::max_element(array, array + size);
float sum = 0.f; float sum = 0.f;
@ -2428,6 +2752,29 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
return result; return result;
} }
// Advances the grammar state past `token`.
//
// For the EOS token nothing is advanced: the call returns immediately if some
// stack is empty (the grammar is complete) and asserts otherwise. For any
// other token, its text is decoded to code points and each one is fed to
// llama_grammar_accept(); the grammar must still have at least one stack
// afterwards.
void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
    const int64_t t_start_sample_us = ggml_time_us();

    if (token == llama_token_eos()) {
        bool grammar_complete = false;
        for (const auto & stack : grammar->stacks) {
            if (stack.empty()) {
                grammar_complete = true;
                break;
            }
        }
        if (grammar_complete) {
            return;
        }
        LLAMA_ASSERT(false);
    }

    const char * str = llama_token_to_str(ctx, token);

    // the decoded sequence carries a trailing 0 that must not be fed to the grammar
    const auto   code_points = decode_utf8(str);
    const size_t n_chars     = code_points.size() - 1;
    for (size_t i = 0; i < n_chars; ++i) {
        grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, code_points[i]);
    }
    LLAMA_ASSERT(!grammar->stacks.empty());

    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
}
// //
// quantization // quantization
// //
@ -2750,7 +3097,7 @@ struct llama_model * llama_load_model_from_file(
ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gqa, params.n_gpu_layers, if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gqa, params.rms_norm_eps, params.n_gpu_layers,
params.main_gpu, params.tensor_split, params.rope_freq_base, params.rope_freq_scale,params.low_vram, params.main_gpu, params.tensor_split, params.rope_freq_base, params.rope_freq_scale,params.low_vram,
memory_type, params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback, memory_type, params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback,
params.progress_callback_user_data)) { params.progress_callback_user_data)) {

50
llama.h
View file

@ -87,6 +87,7 @@ extern "C" {
int32_t n_ctx; // text context int32_t n_ctx; // text context
int32_t n_batch; // prompt processing batch size int32_t n_batch; // prompt processing batch size
int32_t n_gqa; // grouped-query attention (TEMP - will be moved to model hparams) int32_t n_gqa; // grouped-query attention (TEMP - will be moved to model hparams)
float rms_norm_eps; // rms norm epsilon (TEMP - will be moved to model hparams)
int32_t n_gpu_layers; // number of layers to store in VRAM int32_t n_gpu_layers; // number of layers to store in VRAM
int32_t main_gpu; // the GPU that is used for scratch and small tensors int32_t main_gpu; // the GPU that is used for scratch and small tensors
@ -141,6 +142,40 @@ extern "C" {
bool quantize_output_tensor; // quantize output.weight bool quantize_output_tensor; // quantize output.weight
} llama_model_quantize_params; } llama_model_quantize_params;
// grammar types
struct llama_grammar;
// grammar element type
//
// A rule definition is a flat sequence of elements ending with
// LLAMA_GRETYPE_END; alternates within it are separated by LLAMA_GRETYPE_ALT.
enum llama_gretype {
    // end of rule definition
    LLAMA_GRETYPE_END = 0,

    // start of alternate definition for rule
    LLAMA_GRETYPE_ALT = 1,

    // non-terminal element: reference to rule (value is the rule ID)
    LLAMA_GRETYPE_RULE_REF = 2,

    // terminal element: character (value is the code point)
    LLAMA_GRETYPE_CHAR = 3,

    // inverse char(s) ([^a], [^a-b], [^abc])
    LLAMA_GRETYPE_CHAR_NOT = 4,

    // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
    // be an inclusive range ([a-z]); value is the upper bound
    LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,

    // modifies a preceding LLAMA_GRETYPE_CHAR or
    // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
    LLAMA_GRETYPE_CHAR_ALT = 6,
};
// One element of a rule definition: a type tag plus its payload.
typedef struct llama_grammar_element {
    enum llama_gretype type;  // what kind of element this is
    uint32_t           value; // Unicode code point or rule ID
} llama_grammar_element;
// performance timing information // performance timing information
struct llama_timings { struct llama_timings {
double t_start_ms; double t_start_ms;
@ -333,6 +368,15 @@ extern "C" {
LLAMA_API llama_token llama_token_eos(); // end-of-sentence LLAMA_API llama_token llama_token_eos(); // end-of-sentence
LLAMA_API llama_token llama_token_nl(); // next-line LLAMA_API llama_token llama_token_nl(); // next-line
// Grammar
//
// Creates a grammar from `n_rules` rule definitions. Each rules[i] points to
// a sequence of elements terminated by a LLAMA_GRETYPE_END element;
// `start_rule_index` selects the rule that parsing starts from.
// The returned grammar must be released with llama_grammar_free().
LLAMA_API struct llama_grammar * llama_grammar_init(
        const llama_grammar_element ** rules,
        size_t n_rules,
        size_t start_rule_index);
LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
// Sampling functions // Sampling functions
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
@ -367,6 +411,9 @@ extern "C" {
LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep); LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp); LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
/// @details Apply constraints from grammar
LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar);
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
@ -388,6 +435,9 @@ extern "C" {
/// @details Randomly selects a token from the candidates based on their probabilities. /// @details Randomly selects a token from the candidates based on their probabilities.
LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates); LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
/// @details Accepts the sampled token into the grammar
LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
// Performance information // Performance information
LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx); LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
LLAMA_API void llama_print_timings(struct llama_context * ctx); LLAMA_API void llama_print_timings(struct llama_context * ctx);

View file

@ -64,7 +64,7 @@ void get_random_dims(int64_t * dims, int ndims) {
} }
} }
struct ggml_tensor * get_random_tensor( struct ggml_tensor * get_random_tensor_f32(
struct ggml_context * ctx0, struct ggml_context * ctx0,
int ndims, int ndims,
int64_t ne[], int64_t ne[],
@ -112,7 +112,55 @@ struct ggml_tensor * get_random_tensor(
return result; return result;
} }
// Build a tensor of type GGML_TYPE_F16 via ggml_new_tensor(ctx0, GGML_TYPE_F16, ndims, ne)
// and fill every element with frand()*(fmax - fmin) + fmin, converted to half
// precision with ggml_fp32_to_fp16.
//
//   ctx0       - context handed to ggml_new_tensor
//   ndims      - number of dimensions; only 1..4 are handled
//   ne         - extent of each dimension; ne[0..ndims-1] are read
//   fmin, fmax - bounds used to scale/shift the value returned by frand()
//
// Returns the newly created tensor.
//
// NOTE(review): presumably frand() yields values in [0, 1] so results fall in
// [fmin, fmax] - confirm against frand()'s definition.
// NOTE(review): the loop counters are int while ne[] is int64_t.
struct ggml_tensor * get_random_tensor_int( struct ggml_tensor * get_random_tensor_f16(
struct ggml_context * ctx0,
int ndims,
int64_t ne[],
float fmin,
float fmax) {
struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_F16, ndims, ne);
// One case per supported rank. Each case writes through a flat ggml_fp16_t
// view of result->data, with i0 as the fastest-varying index.
switch (ndims) {
case 1:
for (int i0 = 0; i0 < ne[0]; i0++) {
((ggml_fp16_t *)result->data)[i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin);
}
break;
case 2:
for (int i1 = 0; i1 < ne[1]; i1++) {
for (int i0 = 0; i0 < ne[0]; i0++) {
// flat offset: i1*ne[0] + i0
((ggml_fp16_t *)result->data)[i1*ne[0] + i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin);
}
}
break;
case 3:
for (int i2 = 0; i2 < ne[2]; i2++) {
for (int i1 = 0; i1 < ne[1]; i1++) {
for (int i0 = 0; i0 < ne[0]; i0++) {
// flat offset: i2*ne[1]*ne[0] + i1*ne[0] + i0
((ggml_fp16_t *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin);
}
}
}
break;
case 4:
for (int i3 = 0; i3 < ne[3]; i3++) {
for (int i2 = 0; i2 < ne[2]; i2++) {
for (int i1 = 0; i1 < ne[1]; i1++) {
for (int i0 = 0; i0 < ne[0]; i0++) {
// flat offset: i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0
((ggml_fp16_t *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin);
}
}
}
}
break;
default:
// any ndims outside 1..4 is treated as a programming error
assert(false);
};
return result;
}
struct ggml_tensor * get_random_tensor_i32(
struct ggml_context * ctx0, struct ggml_context * ctx0,
int ndims, int ndims,
int64_t ne[], int64_t ne[],
@ -160,23 +208,6 @@ struct ggml_tensor * get_random_tensor_int(
return result; return result;
} }
// Read the element at flat index `idx` of tensor `t` and return it as a float.
//
// GGML_TYPE_F32 data is returned unchanged; GGML_TYPE_I32 data is converted
// from int32_t to float. Any other tensor type trips assert(false); if the
// assertion is compiled out the function returns INFINITY.
float get_element(const struct ggml_tensor * t, int idx) {
    switch (t->type) {
        case GGML_TYPE_F32: {
            const float * values = (const float *) t->data;
            return values[idx];
        }
        case GGML_TYPE_I32: {
            const int32_t * values = (const int32_t *) t->data;
            return (float) values[idx];
        }
        default:
            break;
    }

    // unsupported element type
    assert(false);
    return INFINITY;
}
void set_element(struct ggml_tensor * t, int idx, float value) {
((float *)t->data)[idx] = value;
}
void print_elements(const char* label, const struct ggml_tensor * t) { void print_elements(const char* label, const struct ggml_tensor * t) {
if (!t) { if (!t) {
printf("%s: %s = null\n", __func__, label); printf("%s: %s = null\n", __func__, label);
@ -186,7 +217,7 @@ void print_elements(const char* label, const struct ggml_tensor * t) {
printf("%s: %s = [", __func__, label); printf("%s: %s = [", __func__, label);
for (int k = 0; k < nelements; ++k) { for (int k = 0; k < nelements; ++k) {
if (k > 0) { printf(", "); } if (k > 0) { printf(", "); }
printf("%.5f", get_element(t, k)); printf("%.5f", ggml_get_f32_1d(t, k));
} }
printf("] shape: ["); printf("] shape: [");
for (int k = 0; k < t->n_dims; ++k) { for (int k = 0; k < t->n_dims; ++k) {
@ -237,23 +268,23 @@ bool check_gradient(
const int nelements = ggml_nelements(x[i]); const int nelements = ggml_nelements(x[i]);
for (int k = 0; k < nelements; ++k) { for (int k = 0; k < nelements; ++k) {
// compute gradient using finite differences // compute gradient using finite differences
const float x0 = get_element(x[i], k); const float x0 = ggml_get_f32_1d(x[i], k);
const float xm = x0 - eps; const float xm = x0 - eps;
const float xp = x0 + eps; const float xp = x0 + eps;
set_element(x[i], k, xp); ggml_set_f32_1d(x[i], k, xp);
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads); ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
const float f0 = ggml_get_f32_1d(f, 0); const float f0 = ggml_get_f32_1d(f, 0);
set_element(x[i], k, xm); ggml_set_f32_1d(x[i], k, xm);
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads); ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
const float f1 = ggml_get_f32_1d(f, 0); const float f1 = ggml_get_f32_1d(f, 0);
const float g0 = (f0 - f1)/(2.0f*eps); const float g0 = (f0 - f1)/(2.0f*eps);
set_element(x[i], k, x0); ggml_set_f32_1d(x[i], k, x0);
// compute gradient using backward graph // compute gradient using backward graph
ggml_graph_reset (&gf); ggml_graph_reset (&gf);
@ -261,7 +292,7 @@ bool check_gradient(
ggml_graph_compute_with_ctx(ctx0, &gb, n_threads); ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
const float g1 = get_element(x[i]->grad, k); const float g1 = ggml_get_f32_1d(x[i]->grad, k);
const float error_abs = fabsf(g0 - g1); const float error_abs = fabsf(g0 - g1);
const float error_rel = g0 != 0 ? fabsf(g0 - g1)/fabsf(g0) : 0; const float error_rel = g0 != 0 ? fabsf(g0 - g1)/fabsf(g0) : 0;
@ -392,19 +423,35 @@ int main(int argc, const char ** argv) {
struct ggml_tensor * x[MAX_NARGS]; struct ggml_tensor * x[MAX_NARGS];
// add // add f32
{ {
const int nargs = 2; const int nargs = 2;
for (int ndims = 1; ndims <= 4; ++ndims) { for (int ndims = 1; ndims <= 4; ++ndims) {
for (int i = 0; i < nargs; ++i) { for (int i = 0; i < nargs; ++i) {
x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f); x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[i]); ggml_set_param(ctx0, x[i]);
} }
struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1])); struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1]));
check_gradient("add", ctx0, x, f, ndims, nargs, 1e-3f, 2e-3f, 2e-3f); check_gradient("add f32", ctx0, x, f, ndims, nargs, 1e-3f, 2e-3f, 2e-3f);
}
}
// add f16
{
const int nargs = 2;
for (int ndims = 1; ndims <= 4; ++ndims) {
for (int i = 0; i < nargs; ++i) {
x[i] = get_random_tensor_f16(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[i]);
}
struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1]));
check_gradient("add f16", ctx0, x, f, ndims, nargs, 1e-1f, 2e-1f, 2e-1f);
} }
} }
@ -414,7 +461,7 @@ int main(int argc, const char ** argv) {
for (int ndims = 1; ndims <= 4; ++ndims) { for (int ndims = 1; ndims <= 4; ++ndims) {
for (int i = 0; i < nargs; ++i) { for (int i = 0; i < nargs; ++i) {
x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f); x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[i]); ggml_set_param(ctx0, x[i]);
} }
@ -430,7 +477,7 @@ int main(int argc, const char ** argv) {
for (int ndims = 1; ndims <= 4; ++ndims) { for (int ndims = 1; ndims <= 4; ++ndims) {
for (int i = 0; i < nargs; ++i) { for (int i = 0; i < nargs; ++i) {
x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f); x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[i]); ggml_set_param(ctx0, x[i]);
} }
@ -446,7 +493,7 @@ int main(int argc, const char ** argv) {
for (int ndims = 1; ndims <= 4; ++ndims) { for (int ndims = 1; ndims <= 4; ++ndims) {
for (int i = 0; i < nargs; ++i) { for (int i = 0; i < nargs; ++i) {
x[i] = get_random_tensor(ctx0, ndims, ne, 0.5f, 1.0f); x[i] = get_random_tensor_f32(ctx0, ndims, ne, 0.5f, 1.0f);
ggml_set_param(ctx0, x[i]); ggml_set_param(ctx0, x[i]);
} }
@ -462,7 +509,7 @@ int main(int argc, const char ** argv) {
for (int ndims = 1; ndims <= 2; ++ndims) { for (int ndims = 1; ndims <= 2; ++ndims) {
for (int i = 0; i < nargs; ++i) { for (int i = 0; i < nargs; ++i) {
x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f); x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[i]); ggml_set_param(ctx0, x[i]);
} }
@ -478,7 +525,7 @@ int main(int argc, const char ** argv) {
for (int ndims = 1; ndims <= 2; ++ndims) { for (int ndims = 1; ndims <= 2; ++ndims) {
for (int i = 0; i < nargs; ++i) { for (int i = 0; i < nargs; ++i) {
x[i] = get_random_tensor(ctx0, ndims, ne, 2.0f*1e-3f, 1.0f); x[i] = get_random_tensor_f32(ctx0, ndims, ne, 2.0f*1e-3f, 1.0f);
ggml_set_param(ctx0, x[i]); ggml_set_param(ctx0, x[i]);
} }
@ -494,7 +541,7 @@ int main(int argc, const char ** argv) {
for (int ndims = 1; ndims <= 2; ++ndims) { for (int ndims = 1; ndims <= 2; ++ndims) {
for (int i = 0; i < nargs; ++i) { for (int i = 0; i < nargs; ++i) {
x[i] = get_random_tensor(ctx0, ndims, ne, 2.0f*1e-3f, 1.0f); x[i] = get_random_tensor_f32(ctx0, ndims, ne, 2.0f*1e-3f, 1.0f);
ggml_set_param(ctx0, x[i]); ggml_set_param(ctx0, x[i]);
} }
@ -510,7 +557,7 @@ int main(int argc, const char ** argv) {
for (int ndims = 1; ndims <= 2; ++ndims) { for (int ndims = 1; ndims <= 2; ++ndims) {
for (int i = 0; i < nargs; ++i) { for (int i = 0; i < nargs; ++i) {
x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f); x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[i]); ggml_set_param(ctx0, x[i]);
} }
@ -527,7 +574,7 @@ int main(int argc, const char ** argv) {
for (int ndims = 1; ndims <= 4; ++ndims) { for (int ndims = 1; ndims <= 4; ++ndims) {
for (int i = 0; i < nargs; ++i) { for (int i = 0; i < nargs; ++i) {
x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f); x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[i]); ggml_set_param(ctx0, x[i]);
} }
@ -537,6 +584,40 @@ int main(int argc, const char ** argv) {
} }
} }
// mean, not yet fully implemented
if(0)
{
const int nargs = 1;
for (int ndims = 1; ndims <= 4; ++ndims) {
for (int i = 0; i < nargs; ++i) {
x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[i]);
}
struct ggml_tensor * f = ggml_sum(ctx0, ggml_mean(ctx0, x[0]));
check_gradient("mean", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
}
}
// argmax
if (0)
{
const int nargs = 1;
for (int ndims = 1; ndims <= 4; ++ndims) {
for (int i = 0; i < nargs; ++i) {
x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[i]);
}
struct ggml_tensor * f = ggml_sum(ctx0, ggml_argmax(ctx0, x[0]));
check_gradient("argmax", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
}
}
// repeat // repeat
{ {
int64_t ne2[4]; int64_t ne2[4];
@ -549,15 +630,36 @@ int main(int argc, const char ** argv) {
const int nargs = 1; const int nargs = 1;
for (int ndims = 1; ndims <= 2; ++ndims) { for (int ndims = 1; ndims <= 2; ++ndims) {
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f); x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
x[1] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f); x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]); ggml_set_param(ctx0, x[0]);
struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[1], ggml_repeat(ctx0, x[0], x[1])))); struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[1], ggml_repeat(ctx0, x[0], x[1]))));
check_gradient("repeat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY); check_gradient("repeat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY);
} }
}
// repeat back
{
int64_t ne2[4];
get_random_dims(ne2, 4);
ne2[0] = ne[0] * ne2[0];
ne2[1] = ne[1] * ne2[1];
ne2[2] = 1;
ne2[3] = 1;
const int nargs = 1;
for (int ndims = 1; ndims <= 2; ++ndims) {
x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]);
struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[0], ggml_repeat_back(ctx0, x[1], x[0]))));
check_gradient("repeat back", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY);
}
} }
// abs (finite differences do not work) // abs (finite differences do not work)
@ -566,7 +668,7 @@ int main(int argc, const char ** argv) {
// for (int ndims = 1; ndims <= 2; ++ndims) { // for (int ndims = 1; ndims <= 2; ++ndims) {
// for (int i = 0; i < nargs; ++i) { // for (int i = 0; i < nargs; ++i) {
// x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f); // x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
// ggml_set_param(ctx0, x[i]); // ggml_set_param(ctx0, x[i]);
// } // }
@ -576,17 +678,82 @@ int main(int argc, const char ** argv) {
// } // }
//} //}
// sgn
{
const int nargs = 1;
for (int ndims = 1; ndims <= 4; ++ndims) {
for (int i = 0; i < nargs; ++i) {
x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[i]);
}
struct ggml_tensor* f = ggml_sum(ctx0, ggml_sgn(ctx0, x[0]));
check_gradient("sgn", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
}
}
// neg
{
const int nargs = 1;
for (int ndims = 1; ndims <= 4; ++ndims) {
for (int i = 0; i < nargs; ++i) {
x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[i]);
}
struct ggml_tensor* f = ggml_sum(ctx0, ggml_neg(ctx0, x[0]));
check_gradient("neg", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
}
}
// step
{
const int nargs = 1;
for (int ndims = 1; ndims <= 4; ++ndims) {
for (int i = 0; i < nargs; ++i) {
x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[i]);
}
struct ggml_tensor* f = ggml_sum(ctx0, ggml_step(ctx0, x[0]));
check_gradient("step", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
}
}
// tanh, not yet fully implemented
if(0)
{
const int nargs = 1;
for (int ndims = 1; ndims <= 4; ++ndims) {
for (int i = 0; i < nargs; ++i) {
x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[i]);
}
struct ggml_tensor* f = ggml_sum(ctx0, ggml_tanh(ctx0, x[0]));
check_gradient("tanh", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
}
}
// mul_mat // mul_mat
{ {
const int nargs = 2; const int nargs = 2;
for (int ndims = 2; ndims <= 2; ++ndims) { for (int ndims = 2; ndims <= 2; ++ndims) {
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f); x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
{ {
int64_t ne2[4]; int64_t ne2[4];
get_random_dims(ne2, 4); get_random_dims(ne2, 4);
ne2[0] = ne[0]; ne2[0] = ne[0];
x[1] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f); x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
} }
ggml_set_param(ctx0, x[0]); ggml_set_param(ctx0, x[0]);
@ -602,13 +769,63 @@ int main(int argc, const char ** argv) {
} }
} }
// elu, not yet fully implemented
if(0)
{
const int nargs = 1;
for (int ndims = 1; ndims <= 4; ++ndims) {
for (int i = 0; i < nargs; ++i) {
x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[i]);
}
struct ggml_tensor* f = ggml_sum(ctx0, ggml_elu(ctx0, x[0]));
check_gradient("elu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
}
}
// relu
{
const int nargs = 1;
for (int ndims = 1; ndims <= 4; ++ndims) {
for (int i = 0; i < nargs; ++i) {
x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[i]);
}
struct ggml_tensor* f = ggml_sum(ctx0, ggml_relu(ctx0, x[0]));
check_gradient("relu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
}
}
// gelu, not yet fully implemented
if(0)
{
const int nargs = 1;
for (int ndims = 1; ndims <= 4; ++ndims) {
for (int i = 0; i < nargs; ++i) {
x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[i]);
}
struct ggml_tensor* f = ggml_sum(ctx0, ggml_gelu(ctx0, x[0]));
check_gradient("gelu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
}
}
// silu // silu
{ {
const int nargs = 1; const int nargs = 1;
for (int ndims = 1; ndims <= 2; ++ndims) { for (int ndims = 1; ndims <= 2; ++ndims) {
for (int i = 0; i < nargs; ++i) { for (int i = 0; i < nargs; ++i) {
x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f); x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[i]); ggml_set_param(ctx0, x[i]);
} }
@ -629,11 +846,11 @@ int main(int argc, const char ** argv) {
for (int ndims = 1; ndims <= 2; ++ndims) { for (int ndims = 1; ndims <= 2; ++ndims) {
for (int i = 0; i < nargs; ++i) { for (int i = 0; i < nargs; ++i) {
x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f); x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[i]); ggml_set_param(ctx0, x[i]);
} }
struct ggml_tensor * f = ggml_sum(ctx0, ggml_rms_norm(ctx0, x[0])); struct ggml_tensor * f = ggml_sum(ctx0, ggml_rms_norm(ctx0, x[0], 1e-6f));
check_gradient("rms_norm", ctx0, x, f, ndims, nargs, 1e-4f, 1.0f, INFINITY); check_gradient("rms_norm", ctx0, x, f, ndims, nargs, 1e-4f, 1.0f, INFINITY);
} }
@ -647,8 +864,8 @@ int main(int argc, const char ** argv) {
ne2[0] = 1; ne2[0] = 1;
for (int ndims = 1; ndims <= 2; ++ndims) { for (int ndims = 1; ndims <= 2; ++ndims) {
x[1] = get_random_tensor(ctx0, 1, ne2, -1.0f, 1.0f); x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f); x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]); ggml_set_param(ctx0, x[0]);
ggml_set_param(ctx0, x[1]); ggml_set_param(ctx0, x[1]);
@ -659,20 +876,37 @@ int main(int argc, const char ** argv) {
} }
} }
// cpy // cpy f32
{ {
const int nargs = 2; const int nargs = 2;
for (int ndims = 1; ndims <= 2; ++ndims) { for (int ndims = 1; ndims <= 2; ++ndims) {
for (int i = 0; i < nargs; ++i) { for (int i = 0; i < nargs; ++i) {
x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f); x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[i]); ggml_set_param(ctx0, x[i]);
} }
// x[1] is overwritten by x[0], so the gradients don't propagate to x[1] // x[1] is overwritten by x[0], so the gradients don't propagate to x[1]
struct ggml_tensor * f = ggml_sum(ctx0, ggml_cpy(ctx0, x[0], x[1])); struct ggml_tensor * f = ggml_sum(ctx0, ggml_cpy(ctx0, x[0], x[1]));
check_gradient("cpy", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY); check_gradient("cpy f32", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
}
}
// cpy f16
{
const int nargs = 2;
for (int ndims = 1; ndims <= 2; ++ndims) {
for (int i = 0; i < nargs; ++i) {
x[i] = get_random_tensor_f16(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[i]);
}
// x[1] is overwritten by x[0], so the gradients don't propagate to x[1]
struct ggml_tensor * f = ggml_sum(ctx0, ggml_cpy(ctx0, x[0], x[1]));
check_gradient("cpy f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY);
} }
} }
@ -689,8 +923,8 @@ int main(int argc, const char ** argv) {
for (int i = 0; i < ndims; ++i) { for (int i = 0; i < ndims; ++i) {
ne2[0] *= ne[i]; ne2[0] *= ne[i];
} }
x[0] = get_random_tensor(ctx0, 1, ne2, -1.0f, 1.0f); x[0] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
x[1] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f); x[1] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]); ggml_set_param(ctx0, x[0]);
@ -712,8 +946,8 @@ int main(int argc, const char ** argv) {
for (int i = 0; i < ndims; ++i) { for (int i = 0; i < ndims; ++i) {
ne2[0] *= ne[i]; ne2[0] *= ne[i];
} }
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f); x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
x[1] = get_random_tensor(ctx0, 1, ne2, -1.0f, 1.0f); x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]); ggml_set_param(ctx0, x[0]);
@ -729,7 +963,7 @@ int main(int argc, const char ** argv) {
const int nargs = 2; const int nargs = 2;
for (int ndims = 1; ndims <= 4; ++ndims) { for (int ndims = 1; ndims <= 4; ++ndims) {
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f); x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]); ggml_set_param(ctx0, x[0]);
get_random_dims(ne2, 1); get_random_dims(ne2, 1);
@ -737,7 +971,7 @@ int main(int argc, const char ** argv) {
get_random_dims(ne2, 1); get_random_dims(ne2, 1);
} }
x[1] = get_random_tensor(ctx0, 1, ne2, -1.0f, 1.0f); x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
ggml_set_param(ctx0, x[1]); ggml_set_param(ctx0, x[1]);
const int max_offset = MAX(0, ggml_nelements(x[0]) - ggml_nelements(x[1])); const int max_offset = MAX(0, ggml_nelements(x[0]) - ggml_nelements(x[1]));
@ -758,7 +992,7 @@ int main(int argc, const char ** argv) {
const int nargs = 2; const int nargs = 2;
for (int ndims = 2; ndims <= 4; ++ndims) { for (int ndims = 2; ndims <= 4; ++ndims) {
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f); x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]); ggml_set_param(ctx0, x[0]);
get_random_dims(ne2, 2); get_random_dims(ne2, 2);
@ -766,7 +1000,7 @@ int main(int argc, const char ** argv) {
get_random_dims(ne2, 2); get_random_dims(ne2, 2);
} }
x[1] = get_random_tensor(ctx0, 2, ne2, -1.0f, 1.0f); x[1] = get_random_tensor_f32(ctx0, 2, ne2, -1.0f, 1.0f);
ggml_set_param(ctx0, x[1]); ggml_set_param(ctx0, x[1]);
max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]); max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
@ -790,7 +1024,7 @@ int main(int argc, const char ** argv) {
const int nargs = 2; const int nargs = 2;
for (int ndims = 3; ndims <= 4; ++ndims) { for (int ndims = 3; ndims <= 4; ++ndims) {
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f); x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]); ggml_set_param(ctx0, x[0]);
get_random_dims(ne2, 3); get_random_dims(ne2, 3);
@ -798,7 +1032,7 @@ int main(int argc, const char ** argv) {
get_random_dims(ne2, 3); get_random_dims(ne2, 3);
} }
x[1] = get_random_tensor(ctx0, 3, ne2, -1.0f, 1.0f); x[1] = get_random_tensor_f32(ctx0, 3, ne2, -1.0f, 1.0f);
ggml_set_param(ctx0, x[1]); ggml_set_param(ctx0, x[1]);
max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]); max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
@ -824,7 +1058,7 @@ int main(int argc, const char ** argv) {
const int nargs = 2; const int nargs = 2;
for (int ndims = 4; ndims <= 4; ++ndims) { for (int ndims = 4; ndims <= 4; ++ndims) {
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f); x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]); ggml_set_param(ctx0, x[0]);
get_random_dims(ne2, 4); get_random_dims(ne2, 4);
@ -832,7 +1066,7 @@ int main(int argc, const char ** argv) {
get_random_dims(ne2, 4); get_random_dims(ne2, 4);
} }
x[1] = get_random_tensor(ctx0, 4, ne2, -1.0f, 1.0f); x[1] = get_random_tensor_f32(ctx0, 4, ne2, -1.0f, 1.0f);
ggml_set_param(ctx0, x[1]); ggml_set_param(ctx0, x[1]);
max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]); max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
@ -858,7 +1092,7 @@ int main(int argc, const char ** argv) {
const int nargs = 2; const int nargs = 2;
for (int ndims = 1; ndims <= 4; ++ndims) { for (int ndims = 1; ndims <= 4; ++ndims) {
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f); x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]); ggml_set_param(ctx0, x[0]);
get_random_dims(ne2, 1); get_random_dims(ne2, 1);
@ -866,7 +1100,7 @@ int main(int argc, const char ** argv) {
get_random_dims(ne2, 1); get_random_dims(ne2, 1);
} }
x[1] = get_random_tensor(ctx0, 1, ne2, -1.0f, 1.0f); x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
ggml_set_param(ctx0, x[1]); ggml_set_param(ctx0, x[1]);
const int max_offset = MAX(0, ggml_nelements(x[0]) - ggml_nelements(x[1])); const int max_offset = MAX(0, ggml_nelements(x[0]) - ggml_nelements(x[1]));
@ -887,7 +1121,7 @@ int main(int argc, const char ** argv) {
const int nargs = 1; const int nargs = 1;
for (int ndims = 2; ndims <= 4; ++ndims) { for (int ndims = 2; ndims <= 4; ++ndims) {
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f); x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]); ggml_set_param(ctx0, x[0]);
get_random_dims(ne2, 2); get_random_dims(ne2, 2);
@ -895,7 +1129,7 @@ int main(int argc, const char ** argv) {
get_random_dims(ne2, 2); get_random_dims(ne2, 2);
} }
x[1] = get_random_tensor(ctx0, 2, ne2, -1.0f, 1.0f); x[1] = get_random_tensor_f32(ctx0, 2, ne2, -1.0f, 1.0f);
ggml_set_param(ctx0, x[1]); ggml_set_param(ctx0, x[1]);
max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]); max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
@ -915,7 +1149,7 @@ int main(int argc, const char ** argv) {
const int nargs = 1; const int nargs = 1;
for (int ndims = 1; ndims <= 4; ++ndims) { for (int ndims = 1; ndims <= 4; ++ndims) {
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f); x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]); ggml_set_param(ctx0, x[0]);
@ -941,7 +1175,7 @@ int main(int argc, const char ** argv) {
const int nargs = 1; const int nargs = 1;
for (int ndims = 1; ndims <= 4; ++ndims) { for (int ndims = 1; ndims <= 4; ++ndims) {
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f); x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
get_random_dims(ne2, 2); get_random_dims(ne2, 2);
while (ne2[0]*ne2[1] > ggml_nelements(x[0])) { while (ne2[0]*ne2[1] > ggml_nelements(x[0])) {
@ -971,7 +1205,7 @@ int main(int argc, const char ** argv) {
const int nargs = 1; const int nargs = 1;
for (int ndims = 1; ndims <= 4; ++ndims) { for (int ndims = 1; ndims <= 4; ++ndims) {
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f); x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
get_random_dims(ne2, 3); get_random_dims(ne2, 3);
while (ne2[0]*ne2[1]*ne2[2] > ggml_nelements(x[0])) { while (ne2[0]*ne2[1]*ne2[2] > ggml_nelements(x[0])) {
@ -1010,7 +1244,7 @@ int main(int argc, const char ** argv) {
for (int i=ndims; i<4; ++i) { for (int i=ndims; i<4; ++i) {
ne2[i] = 1; ne2[i] = 1;
} }
x[0] = get_random_tensor(ctx0, 4, ne2, -1.0f, 1.0f); x[0] = get_random_tensor_f32(ctx0, 4, ne2, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]); ggml_set_param(ctx0, x[0]);
@ -1043,7 +1277,7 @@ int main(int argc, const char ** argv) {
for (int i=ndims; i<4; ++i) { for (int i=ndims; i<4; ++i) {
ne2[i] = 1; ne2[i] = 1;
} }
x[0] = get_random_tensor(ctx0, 4, ne2, -1.0f, 1.0f); x[0] = get_random_tensor_f32(ctx0, 4, ne2, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]); ggml_set_param(ctx0, x[0]);
@ -1060,8 +1294,8 @@ int main(int argc, const char ** argv) {
int64_t ne3[4] = {1+irand(ne[1]), 1, 1, 1}; int64_t ne3[4] = {1+irand(ne[1]), 1, 1, 1};
const int nargs = 1; const int nargs = 1;
const int ndims = 2; const int ndims = 2;
x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f); x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
x[1] = get_random_tensor_int(ctx0, 1, ne3, 0, ne2[1]); x[1] = get_random_tensor_i32(ctx0, 1, ne3, 0, ne2[1]);
ggml_set_param(ctx0, x[0]); ggml_set_param(ctx0, x[0]);
@ -1075,7 +1309,7 @@ int main(int argc, const char ** argv) {
const int nargs = 1; const int nargs = 1;
const int ndims = 2; const int ndims = 2;
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f); x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]); ggml_set_param(ctx0, x[0]);
int n_past = irand(ne[0]); int n_past = irand(ne[0]);
@ -1090,7 +1324,7 @@ int main(int argc, const char ** argv) {
const int nargs = 1; const int nargs = 1;
const int ndims = 2; const int ndims = 2;
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f); x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]); ggml_set_param(ctx0, x[0]);
int n_past = irand(ne[0]); int n_past = irand(ne[0]);
@ -1108,7 +1342,7 @@ int main(int argc, const char ** argv) {
get_random_dims(ne2, 4); get_random_dims(ne2, 4);
for (int ndims = 1; ndims <= 3; ++ndims) { for (int ndims = 1; ndims <= 3; ++ndims) {
x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f); x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]); ggml_set_param(ctx0, x[0]);
struct ggml_tensor * f = ggml_sum(ctx0, ggml_soft_max(ctx0, x[0])); struct ggml_tensor * f = ggml_sum(ctx0, ggml_soft_max(ctx0, x[0]));
@ -1125,8 +1359,8 @@ int main(int argc, const char ** argv) {
get_random_dims(ne2, 4); get_random_dims(ne2, 4);
for (int ndims = 1; ndims <= 3; ++ndims) { for (int ndims = 1; ndims <= 3; ++ndims) {
x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f); x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
x[1] = get_random_tensor(ctx0, ndims, ne2, 0.0f, 1.0f); x[1] = get_random_tensor_f32(ctx0, ndims, ne2, 0.0f, 1.0f);
ggml_set_param(ctx0, x[0]); ggml_set_param(ctx0, x[0]);
struct ggml_tensor * f = ggml_sum(ctx0, ggml_cross_entropy_loss(ctx0, x[0], x[1])); struct ggml_tensor * f = ggml_sum(ctx0, ggml_cross_entropy_loss(ctx0, x[0], x[1]));
@ -1136,7 +1370,7 @@ int main(int argc, const char ** argv) {
} }
} }
// rope // rope f32
{ {
const int nargs = 1; const int nargs = 1;
@ -1148,7 +1382,7 @@ int main(int argc, const char ** argv) {
for (int ndims = 3; ndims <= 4; ++ndims) { for (int ndims = 3; ndims <= 4; ++ndims) {
for (int mode = 0; mode < 4; ++mode) { for (int mode = 0; mode < 4; ++mode) {
for (int n_past = 1; n_past < ne2[2]; ++n_past) { for (int n_past = 1; n_past < ne2[2]; ++n_past) {
x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f); x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]); ggml_set_param(ctx0, x[0]);
@ -1163,14 +1397,48 @@ int main(int argc, const char ** argv) {
struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], n_past, n_rot, mode, 0)); struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], n_past, n_rot, mode, 0));
GGML_PRINT_DEBUG("rope: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode); GGML_PRINT_DEBUG("rope f32: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
check_gradient("rope", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY); check_gradient("rope f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY);
} }
} }
} }
} }
// flash_attn // rope f16
{
const int nargs = 1;
int64_t ne2[4];
get_random_dims(ne2, 4);
ne2[0] += ne2[0] % 2;
int n_rot = ne2[0];
for (int ndims = 3; ndims <= 4; ++ndims) {
for (int mode = 0; mode < 4; ++mode) {
for (int n_past = 1; n_past < ne2[2]; ++n_past) {
x[0] = get_random_tensor_f16(ctx0, ndims, ne2, -1.0f, 1.0f);
ggml_set_param(ctx0, x[0]);
const bool skip_past = (mode & 1);
if (skip_past) {
// we have no past, so this would have to work on uninitialized memory.
// we only test the gradients here;
// skip_past should have no influence on gradient computation.
// so when other modes work, we assume that this does as well.
continue;
}
struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], n_past, n_rot, mode, 0));
GGML_PRINT_DEBUG("rope f16: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
check_gradient("rope f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY);
}
}
}
}
// flash_attn f32
{ {
const int nargs = 3; const int nargs = 3;
@ -1196,16 +1464,57 @@ int main(int argc, const char ** argv) {
nek[3] = 1; nek[3] = 1;
nev[3] = 1; nev[3] = 1;
} }
x[0] = get_random_tensor(ctx0, ndims, neq, -0.1250f, 0.1250f); x[0] = get_random_tensor_f32(ctx0, ndims, neq, -0.1250f, 0.1250f);
x[1] = get_random_tensor(ctx0, ndims, nek, -0.1250f, 0.1250f); x[1] = get_random_tensor_f32(ctx0, ndims, nek, -0.1250f, 0.1250f);
x[2] = get_random_tensor(ctx0, ndims, nev, -0.1250f, 0.1250f); x[2] = get_random_tensor_f32(ctx0, ndims, nev, -0.1250f, 0.1250f);
ggml_set_param(ctx0, x[0]); ggml_set_param(ctx0, x[0]);
ggml_set_param(ctx0, x[1]); ggml_set_param(ctx0, x[1]);
ggml_set_param(ctx0, x[2]); ggml_set_param(ctx0, x[2]);
struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0))); struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
check_gradient("flash_attn", ctx0, x, f, ndims, nargs, 1.5e-4f, INFINITY, 3.5f); check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, INFINITY, 3.5f);
}
}
}
// flash_attn f16, not yet fully implemented
if(0)
{
const int nargs = 3;
int64_t ne2[4];
get_random_dims(ne2, 4);
int64_t D = ne2[0];
int64_t N = ne2[1];
int64_t M = ne2[2] + N;
int64_t B = ne2[3];
for (int masked = 0; masked <= 1; ++masked) {
for (int ndims = 2; ndims <= 4; ++ndims) {
int64_t neq[4] = { D, N, B, ne[3] };
int64_t nek[4] = { D, M, B, ne[3] };
int64_t nev[4] = { M, D, B, ne[3] };
if (ndims == 2) {
neq[2] = 1; neq[3] = 1;
nek[2] = 1; nek[3] = 1;
nev[2] = 1; nev[3] = 1;
} else if (ndims == 3) {
neq[3] = 1;
nek[3] = 1;
nev[3] = 1;
}
x[0] = get_random_tensor_f16(ctx0, ndims, neq, -0.1250f, 0.1250f);
x[1] = get_random_tensor_f16(ctx0, ndims, nek, -0.1250f, 0.1250f);
x[2] = get_random_tensor_f16(ctx0, ndims, nev, -0.1250f, 0.1250f);
ggml_set_param(ctx0, x[0]);
ggml_set_param(ctx0, x[1]);
ggml_set_param(ctx0, x[2]);
struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
check_gradient("flash_attn f16", ctx0, x, f, ndims, nargs, 1.5e-4f, INFINITY, 3.5f);
} }
} }
} }

View file

@ -125,9 +125,9 @@ int main(void) {
}; };
struct ggml_context * ctx = ggml_init(params); struct ggml_context * ctx = ggml_init(params);
int64_t ne1[4] = {4, 1024, 1, 1}; int64_t ne1[4] = {4, 128, 1, 1};
int64_t ne2[4] = {4, 2048, 1, 1};; int64_t ne2[4] = {4, 256, 1, 1};;
int64_t ne3[4] = {1024, 2048, 1, 1}; int64_t ne3[4] = {128, 256, 1, 1};
struct ggml_tensor * a = get_random_tensor(ctx, 2, ne1, -1, +1); struct ggml_tensor * a = get_random_tensor(ctx, 2, ne1, -1, +1);
struct ggml_tensor * b = get_random_tensor(ctx, 2, ne2, -1, +1); struct ggml_tensor * b = get_random_tensor(ctx, 2, ne2, -1, +1);