Merge branch 'master' into HEAD
commit ea02f675f9
27 changed files with 2855 additions and 1140 deletions
Makefile (7 changes)
@@ -323,6 +323,9 @@ llama.o: llama.cpp ggml.h ggml-cuda.h ggml-metal.h llama.h llama-util.h
 common.o: examples/common.cpp examples/common.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@
 
+grammar-parser.o: examples/grammar-parser.cpp examples/grammar-parser.h
+	$(CXX) $(CXXFLAGS) -c $< -o $@
+
 libllama.so: llama.o ggml.o $(OBJS)
 	$(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS)
 
@@ -333,7 +336,7 @@ clean:
 # Examples
 #
 
-main: examples/main/main.cpp build-info.h ggml.o llama.o common.o $(OBJS)
+main: examples/main/main.cpp build-info.h ggml.o llama.o common.o grammar-parser.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 	@echo
 	@echo '==== Run ./main -h for help. ===='
 
@@ -357,7 +360,7 @@ embedding: examples/embedding/embedding.cpp build-info.h ggml.
 save-load-state: examples/save-load-state/save-load-state.cpp build-info.h ggml.o llama.o common.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 
-server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp build-info.h ggml.o llama.o common.o $(OBJS)
+server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp build-info.h ggml.o llama.o common.o $(OBJS)
 	$(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) $(LWINSOCK2)
 
 $(LIB_PRE)embdinput$(DSO_EXT): examples/embd-input/embd-input.h examples/embd-input/embd-input-lib.cpp build-info.h ggml.o llama.o common.o $(OBJS)
examples/CMakeLists.txt

@@ -13,6 +13,8 @@ set(TARGET common)
 add_library(${TARGET} OBJECT
     common.h
     common.cpp
+    grammar-parser.h
+    grammar-parser.cpp
     )
 
 if (BUILD_SHARED_LIBS)
examples/baby-llama/baby-llama.cpp

@@ -8,6 +8,8 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
+static const float rms_norm_eps = 1e-6f;
+
 float frand() {
     return (float)rand()/(float)RAND_MAX;
 }
 
@@ -562,7 +564,7 @@ struct ggml_tensor * forward(
         // norm
         {
             // cur shape [n_embd,N,1,1]
-            cur = ggml_rms_norm(ctx0, inpL);
+            cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 
             // cur = attention_norm*cur
             cur = ggml_mul(ctx0,
 
@@ -685,7 +687,7 @@ struct ggml_tensor * forward(
         // norm
         {
             // cur shape [n_embd,N,1,1]
-            cur = ggml_rms_norm(ctx0, inpFF);
+            cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
 
             // cur = ffn_norm*cur
             // cur shape [n_embd,N,1,1]
 
@@ -729,7 +731,7 @@ struct ggml_tensor * forward(
     {
 
         // inpL shape [n_embd,N,1,1]
-        inpL = ggml_rms_norm(ctx0, inpL);
+        inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 
         // inpL = norm*inpL
         // inpL shape [n_embd,N,1,1]
 
@@ -817,7 +819,7 @@ struct ggml_tensor * forward_batch(
         // norm
         {
             // cur shape [n_embd,N*n_batch,1,1]
-            cur = ggml_rms_norm(ctx0, inpL);
+            cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
             assert_shape_2d(cur, n_embd, N*n_batch);
 
             // cur = attention_norm*cur
 
@@ -981,7 +983,7 @@ struct ggml_tensor * forward_batch(
         // norm
         {
             // cur shape [n_embd,N*n_batch,1,1]
-            cur = ggml_rms_norm(ctx0, inpFF);
+            cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
             assert_shape_2d(cur, n_embd, N*n_batch);
 
             // cur = ffn_norm*cur
 
@@ -1034,7 +1036,7 @@ struct ggml_tensor * forward_batch(
     {
 
         // inpL shape [n_embd,N*n_batch,1,1]
-        inpL = ggml_rms_norm(ctx0, inpL);
+        inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
         assert_shape_2d(inpL, n_embd, N*n_batch);
 
         // inpL = norm*inpL
 
@@ -1104,7 +1106,7 @@ struct ggml_tensor * forward_lora(
         // norm
         {
             // cur shape [n_embd,N,1,1]
-            cur = ggml_rms_norm(ctx0, inpL);
+            cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 
             // cur = attention_norm*cur
             cur = ggml_mul(ctx0,
 
@@ -1251,7 +1253,7 @@ struct ggml_tensor * forward_lora(
         // norm
         {
             // cur shape [n_embd,N,1,1]
-            cur = ggml_rms_norm(ctx0, inpFF);
+            cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
 
             // cur = ffn_norm*cur
             // cur shape [n_embd,N,1,1]
 
@@ -1295,7 +1297,7 @@ struct ggml_tensor * forward_lora(
     {
 
         // inpL shape [n_embd,N,1,1]
-        inpL = ggml_rms_norm(ctx0, inpL);
+        inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 
         // inpL = norm*inpL
         // inpL shape [n_embd,N,1,1]
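The only functional change in the hunks above is that ggml_rms_norm now takes an explicit epsilon instead of relying on a value hard-coded inside the kernel (LLaMA v2 checkpoints expect 1e-5 rather than 1e-6, hence the new -eps flag further down). As a rough reference sketch of what that epsilon does, not the actual ggml implementation (which is vectorized and applied per row of a tensor), RMS normalization divides each element by the square root of the mean square plus eps:

// Reference sketch only (assumed semantics); the real kernels live in ggml / ggml-cuda.cu.
#include <cmath>
#include <cstddef>

static void rms_norm_ref(const float * x, float * dst, size_t n, float eps) {
    double sumsq = 0.0;
    for (size_t i = 0; i < n; i++) {
        sumsq += (double) x[i] * x[i];                            // sum of squares over the row
    }
    const float scale = 1.0f / sqrtf((float)(sumsq / n) + eps);  // 1/sqrt(mean + eps)
    for (size_t i = 0; i < n; i++) {
        dst[i] = x[i] * scale;                                    // eps keeps this stable for near-zero rows
    }
}

The same eps value is threaded through common.cpp, server.cpp, train-text-from-scratch.cpp and the CUDA kernel below.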
examples/common.cpp

@@ -177,6 +177,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_gqa = std::stoi(argv[i]);
+        } else if (arg == "-eps" || arg == "--rms-norm-eps") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.rms_norm_eps = std::stof(argv[i]);
         } else if (arg == "--rope-freq-base") {
             if (++i >= argc) {
                 invalid_param = true;
 
@@ -438,6 +444,28 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.input_suffix = argv[i];
+        } else if (arg == "--grammar") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.grammar = argv[i];
+        } else if (arg == "--grammar-file") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            std::ifstream file(argv[i]);
+            if (!file) {
+                fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
+                invalid_param = true;
+                break;
+            }
+            std::copy(
+                std::istreambuf_iterator<char>(file),
+                std::istreambuf_iterator<char>(),
+                std::back_inserter(params.grammar)
+            );
         } else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             gpt_print_usage(argc, argv, default_params);
 
@@ -497,6 +525,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
     fprintf(stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
     fprintf(stdout, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
+    fprintf(stdout, " -eps N, --rms-norm-eps N rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps);
     fprintf(stdout, " --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
     fprintf(stdout, " --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
     fprintf(stdout, " --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
 
@@ -514,6 +543,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, " modifies the likelihood of token appearing in the completion,\n");
     fprintf(stdout, " i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n");
     fprintf(stdout, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
+    fprintf(stdout, " --grammar GRAMMAR BNF-like grammar to constrain generations (see samples in grammars/ dir)\n");
+    fprintf(stdout, " --grammar-file FNAME file to read grammar from\n");
     fprintf(stdout, " --cfg-negative-prompt PROMPT \n");
     fprintf(stdout, " negative prompt to use for guidance. (default: empty)\n");
     fprintf(stdout, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
 
@@ -591,6 +622,7 @@ struct llama_context_params llama_context_params_from_gpt_param
     lparams.n_ctx = params.n_ctx;
     lparams.n_batch = params.n_batch;
     lparams.n_gqa = params.n_gqa;
+    lparams.rms_norm_eps = params.rms_norm_eps;
     lparams.n_gpu_layers = params.n_gpu_layers;
     lparams.main_gpu = params.main_gpu;
     lparams.tensor_split = params.tensor_split;
examples/common.h

@@ -34,6 +34,7 @@ struct gpt_params {
     int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
     float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
     int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
+    float rms_norm_eps = 1e-6; // rms norm epsilon
     float rope_freq_base = 10000.0f; // RoPE base frequency
     float rope_freq_scale = 1.0f; // RoPE frequency scaling factor
 
@@ -63,6 +64,7 @@ struct gpt_params {
     std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
     std::string input_prefix = ""; // string to prefix user inputs with
     std::string input_suffix = ""; // string to suffix user inputs with
+    std::string grammar = ""; // optional BNF-like grammar to constrain sampling
     std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
 
     std::string lora_adapter = ""; // lora adapter path
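A minimal, hypothetical caller sketch for the two new gpt_params fields above (the model path and grammar text are placeholders, not part of this commit):

// Hypothetical usage sketch: populate the new fields and build context params.
gpt_params params;
params.model        = "models/7B/ggml-model.bin";  // placeholder path
params.rms_norm_eps = 1e-5f;                       // what the new -eps flag sets (e.g. LLaMA v2)
params.grammar      = "root ::= [0-9]+";           // what --grammar / --grammar-file set

// rms_norm_eps is copied into llama_context_params (see the hunk at line 591 above);
// the grammar string is only consumed later, in examples/main/main.cpp.
llama_context_params lparams = llama_context_params_from_gpt_params(params);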
examples/grammar-parser.cpp (new file, 423 additions)
@@ -0,0 +1,423 @@
#include "grammar-parser.h"
#include <cstdint>
#include <cwchar>
#include <string>
#include <utility>
#include <stdexcept>
#include <exception>

namespace grammar_parser {
    // NOTE: assumes valid utf8 (but checks for overrun)
    // copied from llama.cpp
    std::pair<uint32_t, const char *> decode_utf8(const char * src) {
        static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
        uint8_t  first_byte = static_cast<uint8_t>(*src);
        uint8_t  highbits   = first_byte >> 4;
        int      len        = lookup[highbits];
        uint8_t  mask       = (1 << (8 - len)) - 1;
        uint32_t value      = first_byte & mask;
        const char * end = src + len; // may overrun!
        const char * pos = src + 1;
        for ( ; pos < end && *pos; pos++) {
            value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
        }
        return std::make_pair(value, pos);
    }

    uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) {
        uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
        auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id));
        return result.first->second;
    }

    uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) {
        uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
        state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
        return next_id;
    }

    void add_rule(
            parse_state & state,
            uint32_t      rule_id,
            const std::vector<llama_grammar_element> & rule) {
        if (state.rules.size() <= rule_id) {
            state.rules.resize(rule_id + 1);
        }
        state.rules[rule_id] = rule;
    }

    bool is_word_char(char c) {
        return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
    }

    std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
        const char * pos   = src;
        const char * end   = src + size;
        uint32_t     value = 0;
        for ( ; pos < end && *pos; pos++) {
            value <<= 4;
            char c = *pos;
            if ('a' <= c && c <= 'f') {
                value += c - 'a' + 10;
            } else if ('A' <= c && c <= 'F') {
                value += c - 'A' + 10;
            } else if ('0' <= c && c <= '9') {
                value += c - '0';
            } else {
                break;
            }
        }
        if (pos != end) {
            throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
        }
        return std::make_pair(value, pos);
    }

    const char * parse_space(const char * src, bool newline_ok) {
        const char * pos = src;
        while (*pos == ' ' || *pos == '\t' || *pos == '#' ||
                (newline_ok && (*pos == '\r' || *pos == '\n'))) {
            if (*pos == '#') {
                while (*pos && *pos != '\r' && *pos != '\n') {
                    pos++;
                }
            } else {
                pos++;
            }
        }
        return pos;
    }

    const char * parse_name(const char * src) {
        const char * pos = src;
        while (is_word_char(*pos)) {
            pos++;
        }
        if (pos == src) {
            throw std::runtime_error(std::string("expecting name at ") + src);
        }
        return pos;
    }

    std::pair<uint32_t, const char *> parse_char(const char * src) {
        if (*src == '\\') {
            switch (src[1]) {
                case 'x': return parse_hex(src + 2, 2);
                case 'u': return parse_hex(src + 2, 4);
                case 'U': return parse_hex(src + 2, 8);
                case 't': return std::make_pair('\t', src + 2);
                case 'r': return std::make_pair('\r', src + 2);
                case 'n': return std::make_pair('\n', src + 2);
                case '\\':
                case '"':
                case '[':
                case ']':
                    return std::make_pair(src[1], src + 2);
                default:
                    throw std::runtime_error(std::string("unknown escape at ") + src);
            }
        } else if (*src) {
            return decode_utf8(src);
        }
        throw std::runtime_error("unexpected end of input");
    }

    const char * parse_alternates(
            parse_state       & state,
            const char        * src,
            const std::string & rule_name,
            uint32_t            rule_id,
            bool                is_nested);

    const char * parse_sequence(
            parse_state       & state,
            const char        * src,
            const std::string & rule_name,
            std::vector<llama_grammar_element> & out_elements,
            bool                is_nested) {
        size_t last_sym_start = out_elements.size();
        const char * pos = src;
        while (*pos) {
            if (*pos == '"') { // literal string
                pos++;
                last_sym_start = out_elements.size();
                while (*pos != '"') {
                    auto char_pair = parse_char(pos);
                    pos = char_pair.second;
                    out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
                }
                pos = parse_space(pos + 1, is_nested);
            } else if (*pos == '[') { // char range(s)
                pos++;
                enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
                if (*pos == '^') {
                    pos++;
                    start_type = LLAMA_GRETYPE_CHAR_NOT;
                }
                last_sym_start = out_elements.size();
                while (*pos != ']') {
                    auto char_pair = parse_char(pos);
                    pos = char_pair.second;
                    enum llama_gretype type = last_sym_start < out_elements.size()
                        ? LLAMA_GRETYPE_CHAR_ALT
                        : start_type;

                    out_elements.push_back({type, char_pair.first});
                    if (pos[0] == '-' && pos[1] != ']') {
                        auto endchar_pair = parse_char(pos + 1);
                        pos = endchar_pair.second;
                        out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
                    }
                }
                pos = parse_space(pos + 1, is_nested);
            } else if (is_word_char(*pos)) { // rule reference
                const char * name_end    = parse_name(pos);
                uint32_t     ref_rule_id = get_symbol_id(state, pos, name_end - pos);
                pos = parse_space(name_end, is_nested);
                last_sym_start = out_elements.size();
                out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
            } else if (*pos == '(') { // grouping
                // parse nested alternates into synthesized rule
                pos = parse_space(pos + 1, true);
                uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
                pos = parse_alternates(state, pos, rule_name, sub_rule_id, true);
                last_sym_start = out_elements.size();
                // output reference to synthesized rule
                out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
                if (*pos != ')') {
                    throw std::runtime_error(std::string("expecting ')' at ") + pos);
                }
                pos = parse_space(pos + 1, is_nested);
            } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator
                if (last_sym_start == out_elements.size()) {
                    throw std::runtime_error(std::string("expecting preceeding item to */+/? at ") + pos);
                }

                // apply transformation to previous symbol (last_sym_start to end) according to
                // rewrite rules:
                // S* --> S' ::= S S' |
                // S+ --> S' ::= S S' | S
                // S? --> S' ::= S |
                uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
                std::vector<llama_grammar_element> sub_rule;
                // add preceding symbol to generated rule
                sub_rule.insert(
                    sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
                if (*pos == '*' || *pos == '+') {
                    // cause generated rule to recurse
                    sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
                }
                // mark start of alternate def
                sub_rule.push_back({LLAMA_GRETYPE_ALT, 0});
                if (*pos == '+') {
                    // add preceding symbol as alternate only for '+' (otherwise empty)
                    sub_rule.insert(
                        sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
                }
                sub_rule.push_back({LLAMA_GRETYPE_END, 0});
                add_rule(state, sub_rule_id, sub_rule);

                // in original rule, replace previous symbol with reference to generated rule
                out_elements.resize(last_sym_start);
                out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});

                pos = parse_space(pos + 1, is_nested);
            } else {
                break;
            }
        }
        return pos;
    }

    const char * parse_alternates(
            parse_state       & state,
            const char        * src,
            const std::string & rule_name,
            uint32_t            rule_id,
            bool                is_nested) {
        std::vector<llama_grammar_element> rule;
        const char * pos = parse_sequence(state, src, rule_name, rule, is_nested);
        while (*pos == '|') {
            rule.push_back({LLAMA_GRETYPE_ALT, 0});
            pos = parse_space(pos + 1, true);
            pos = parse_sequence(state, pos, rule_name, rule, is_nested);
        }
        rule.push_back({LLAMA_GRETYPE_END, 0});
        add_rule(state, rule_id, rule);
        return pos;
    }

    const char * parse_rule(parse_state & state, const char * src) {
        const char * name_end = parse_name(src);
        const char * pos      = parse_space(name_end, false);
        size_t       name_len = name_end - src;
        uint32_t     rule_id  = get_symbol_id(state, src, name_len);
        const std::string name(src, name_len);

        if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
            throw std::runtime_error(std::string("expecting ::= at ") + pos);
        }
        pos = parse_space(pos + 3, true);

        pos = parse_alternates(state, pos, name, rule_id, false);

        if (*pos == '\r') {
            pos += pos[1] == '\n' ? 2 : 1;
        } else if (*pos == '\n') {
            pos++;
        } else if (*pos) {
            throw std::runtime_error(std::string("expecting newline or end at ") + pos);
        }
        return parse_space(pos, true);
    }

    parse_state parse(const char * src) {
        try {
            parse_state state;
            const char * pos = parse_space(src, true);
            while (*pos) {
                pos = parse_rule(state, pos);
            }
            return state;
        } catch (const std::exception & err) {
            fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
            return parse_state();
        }
    }

    void print_grammar_char(FILE * file, uint32_t c) {
        if (0x20 <= c && c <= 0x7f) {
            fprintf(file, "%c", static_cast<char>(c));
        } else {
            // cop out of encoding UTF-8
            fprintf(file, "<U+%04X>", c);
        }
    }

    bool is_char_element(llama_grammar_element elem) {
        switch (elem.type) {
            case LLAMA_GRETYPE_CHAR:           return true;
            case LLAMA_GRETYPE_CHAR_NOT:       return true;
            case LLAMA_GRETYPE_CHAR_ALT:       return true;
            case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
            default:                           return false;
        }
    }

    void print_rule_binary(FILE * file, const std::vector<llama_grammar_element> & rule) {
        for (auto elem : rule) {
            switch (elem.type) {
                case LLAMA_GRETYPE_END:            fprintf(file, "END");            break;
                case LLAMA_GRETYPE_ALT:            fprintf(file, "ALT");            break;
                case LLAMA_GRETYPE_RULE_REF:       fprintf(file, "RULE_REF");       break;
                case LLAMA_GRETYPE_CHAR:           fprintf(file, "CHAR");           break;
                case LLAMA_GRETYPE_CHAR_NOT:       fprintf(file, "CHAR_NOT");       break;
                case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
                case LLAMA_GRETYPE_CHAR_ALT:       fprintf(file, "CHAR_ALT");       break;
            }
            switch (elem.type) {
                case LLAMA_GRETYPE_END:
                case LLAMA_GRETYPE_ALT:
                case LLAMA_GRETYPE_RULE_REF:
                    fprintf(file, "(%u) ", elem.value);
                    break;
                case LLAMA_GRETYPE_CHAR:
                case LLAMA_GRETYPE_CHAR_NOT:
                case LLAMA_GRETYPE_CHAR_RNG_UPPER:
                case LLAMA_GRETYPE_CHAR_ALT:
                    fprintf(file, "(\"");
                    print_grammar_char(file, elem.value);
                    fprintf(file, "\") ");
                    break;
            }
        }
        fprintf(file, "\n");
    }

    void print_rule(
            FILE     * file,
            uint32_t   rule_id,
            const std::vector<llama_grammar_element> & rule,
            const std::map<uint32_t, std::string>    & symbol_id_names) {
        if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
            throw std::runtime_error(
                "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
        }
        fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
        for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
            llama_grammar_element elem = rule[i];
            switch (elem.type) {
                case LLAMA_GRETYPE_END:
                    throw std::runtime_error(
                        "unexpected end of rule: " + std::to_string(rule_id) + "," +
                        std::to_string(i));
                case LLAMA_GRETYPE_ALT:
                    fprintf(file, "| ");
                    break;
                case LLAMA_GRETYPE_RULE_REF:
                    fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
                    break;
                case LLAMA_GRETYPE_CHAR:
                    fprintf(file, "[");
                    print_grammar_char(file, elem.value);
                    break;
                case LLAMA_GRETYPE_CHAR_NOT:
                    fprintf(file, "[^");
                    print_grammar_char(file, elem.value);
                    break;
                case LLAMA_GRETYPE_CHAR_RNG_UPPER:
                    if (i == 0 || !is_char_element(rule[i - 1])) {
                        throw std::runtime_error(
                            "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
                            std::to_string(rule_id) + "," + std::to_string(i));
                    }
                    fprintf(file, "-");
                    print_grammar_char(file, elem.value);
                    break;
                case LLAMA_GRETYPE_CHAR_ALT:
                    if (i == 0 || !is_char_element(rule[i - 1])) {
                        throw std::runtime_error(
                            "LLAMA_GRETYPE_CHAR_ALT without preceding char: " +
                            std::to_string(rule_id) + "," + std::to_string(i));
                    }
                    print_grammar_char(file, elem.value);
                    break;
            }
            if (is_char_element(elem)) {
                switch (rule[i + 1].type) {
                    case LLAMA_GRETYPE_CHAR_ALT:
                    case LLAMA_GRETYPE_CHAR_RNG_UPPER:
                        break;
                    default:
                        fprintf(file, "] ");
                }
            }
        }
        fprintf(file, "\n");
    }

    void print_grammar(FILE * file, const parse_state & state) {
        try {
            std::map<uint32_t, std::string> symbol_id_names;
            for (auto kv : state.symbol_ids) {
                symbol_id_names[kv.second] = kv.first;
            }
            for (size_t i = 0, end = state.rules.size(); i < end; i++) {
                // fprintf(file, "%zu: ", i);
                // print_rule_binary(file, state.rules[i]);
                print_rule(file, i, state.rules[i], symbol_id_names);
                // fprintf(file, "\n");
            }
        } catch (const std::exception & err) {
            fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
        }
    }

    std::vector<const llama_grammar_element *> parse_state::c_rules() {
        std::vector<const llama_grammar_element *> ret;
        for (const auto & rule : rules) {
            ret.push_back(rule.data());
        }
        return ret;
    }
}
examples/grammar-parser.h (new file, 29 additions)
@@ -0,0 +1,29 @@
// Implements a parser for an extended Backus-Naur form (BNF), producing the
// binary context-free grammar format specified by llama.h. Supports character
// ranges, grouping, and repetition operators. As an example, a grammar for
// arithmetic might look like:
//
// root  ::= expr
// expr  ::= term ([-+*/] term)*
// term  ::= num | "(" space expr ")" space
// num   ::= [0-9]+ space
// space ::= [ \t\n]*

#pragma once
#include "llama.h"
#include <vector>
#include <map>
#include <cstdint>
#include <string>

namespace grammar_parser {
    struct parse_state {
        std::map<std::string, uint32_t> symbol_ids;
        std::vector<std::vector<llama_grammar_element>> rules;

        std::vector<const llama_grammar_element *> c_rules();
    };

    parse_state parse(const char * src);
    void print_grammar(FILE * file, const parse_state & state);
}
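For orientation, a minimal sketch of how the new parser is driven; it mirrors the calls added to examples/main/main.cpp below, and the one-rule grammar string is just an illustrative stand-in:

// Usage sketch for grammar_parser, following the flow in examples/main/main.cpp.
const char * text = "root ::= [0-9]+";                      // illustrative grammar
grammar_parser::parse_state parsed = grammar_parser::parse(text);
if (parsed.rules.empty()) {
    // parse() already reported the error on stderr; an empty state signals failure
} else {
    grammar_parser::print_grammar(stderr, parsed);          // human-readable rule dump

    std::vector<const llama_grammar_element *> rules = parsed.c_rules();
    llama_grammar * grammar = llama_grammar_init(
        rules.data(), rules.size(), parsed.symbol_ids.at("root"));
    // ... llama_sample_grammar() / llama_grammar_accept_token() during sampling ...
    llama_grammar_free(grammar);
}

Note that repetition operators are compiled away by parse_sequence: the [0-9]+ above becomes a synthesized rule of the form root_1 ::= [0-9] root_1 | [0-9], which the printed output makes visible.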
examples/main/main.cpp

@@ -6,6 +6,7 @@
 #include "common.h"
 #include "llama.h"
 #include "build-info.h"
+#include "grammar-parser.h"
 
 #include <cassert>
 #include <cinttypes>
 
@@ -337,6 +338,31 @@ int main(int argc, char ** argv) {
     fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     fprintf(stderr, "\n\n");
 
+    grammar_parser::parse_state parsed_grammar;
+    llama_grammar * grammar = NULL;
+    if (!params.grammar.empty()) {
+        parsed_grammar = grammar_parser::parse(params.grammar.c_str());
+        // will be empty (default) if there are parse errors
+        if (parsed_grammar.rules.empty()) {
+            return 1;
+        }
+        fprintf(stderr, "%s: grammar:\n", __func__);
+        grammar_parser::print_grammar(stderr, parsed_grammar);
+        fprintf(stderr, "\n");
+
+        {
+            auto it = params.logit_bias.find(llama_token_eos());
+            if (it != params.logit_bias.end() && it->second == -INFINITY) {
+                fprintf(stderr,
+                    "%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
+            }
+        }
+
+        std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
+        grammar = llama_grammar_init(
+            grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+    }
+
     // TODO: replace with ring-buffer
     std::vector<llama_token> last_n_tokens(n_ctx);
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
 
@@ -570,6 +596,10 @@ int main(int argc, char ** argv) {
                 logits[llama_token_nl()] = nl_logit;
             }
 
+            if (grammar != NULL) {
+                llama_sample_grammar(ctx, &candidates_p, grammar);
+            }
+
             if (temp <= 0) {
                 // Greedy sampling
                 id = llama_sample_token_greedy(ctx, &candidates_p);
 
@@ -595,6 +625,10 @@ int main(int argc, char ** argv) {
             }
             // printf("`%d`", candidates_p.size);
 
+            if (grammar != NULL) {
+                llama_grammar_accept_token(ctx, grammar, id);
+            }
+
             last_n_tokens.erase(last_n_tokens.begin());
             last_n_tokens.push_back(id);
         }
 
@@ -725,6 +759,18 @@ int main(int argc, char ** argv) {
         }
 
         if (n_past > 0) {
+            if (is_interacting) {
+                // reset grammar state if we're restarting generation
+                if (grammar != NULL) {
+                    llama_grammar_free(grammar);
+
+                    std::vector<const llama_grammar_element *> grammar_rules(
+                        parsed_grammar.c_rules());
+                    grammar = llama_grammar_init(
+                        grammar_rules.data(), grammar_rules.size(),
+                        parsed_grammar.symbol_ids.at("root"));
+                }
+            }
             is_interacting = false;
         }
     }
 
@@ -756,6 +802,9 @@ int main(int argc, char ** argv) {
     llama_free(ctx);
     llama_free_model(model);
 
+    if (grammar != NULL) {
+        llama_grammar_free(grammar);
+    }
     llama_backend_free();
 
     return 0;
File diff suppressed because it is too large
examples/server/public/index.html

@@ -73,6 +73,37 @@
       margin: 0;
     }
 
+    fieldset.two {
+      display: grid;
+      grid-template: "a a";
+      gap: 1em;
+    }
+
+    fieldset.three {
+      display: grid;
+      grid-template: "a a a";
+      gap: 1em;
+    }
+
+    details {
+      border: 1px solid #aaa;
+      border-radius: 4px;
+      padding: 0.5em 0.5em 0;
+      margin-top: 0.5em;
+    }
+
+    summary {
+      font-weight: bold;
+      margin: -0.5em -0.5em 0;
+      padding: 0.5em;
+      cursor: pointer;
+    }
+
+    details[open] {
+      padding: 0.5em;
+    }
+
     textarea {
       padding: 5px;
       flex-grow: 1;
 
@@ -125,10 +156,17 @@
     const params = signal({
       n_predict: 400,
       temperature: 0.7,
-      repeat_last_n: 256,
-      repeat_penalty: 1.18,
-      top_k: 40,
-      top_p: 0.5,
+      repeat_last_n: 256, // 0 = disable penalty, -1 = context size
+      repeat_penalty: 1.18, // 1.0 = disabled
+      top_k: 40, // <= 0 to use vocab size
+      top_p: 0.5, // 1.0 = disabled
+      tfs_z: 1.0, // 1.0 = disabled
+      typical_p: 1.0, // 1.0 = disabled
+      presence_penalty: 0.0, // 0.0 = disabled
+      frequency_penalty: 0.0, // 0.0 = disabled
+      mirostat: 0, // 0/1/2
+      mirostat_tau: 5, // target entropy
+      mirostat_eta: 0.1, // learning rate
     })
 
     const llamaStats = signal(null)
 
@@ -264,6 +302,27 @@
     const updateSession = (el) => session.value = { ...session.value, [el.target.name]: el.target.value }
    const updateParams = (el) => params.value = { ...params.value, [el.target.name]: el.target.value }
    const updateParamsFloat = (el) => params.value = { ...params.value, [el.target.name]: parseFloat(el.target.value) }
+    const updateParamsInt = (el) => params.value = { ...params.value, [el.target.name]: Math.floor(parseFloat(el.target.value)) }
+
+    const FloatField = ({label, max, min, name, step, value}) => {
+      return html`
+        <div>
+          <label for="${name}">${label}</label>
+          <input type="range" id="${name}" min="${min}" max="${max}" step="${step}" name="${name}" value="${value}" oninput=${updateParamsFloat} />
+          <span>${value}</span>
+        </div>
+      `
+    };
+
+    const IntField = ({label, max, min, name, value}) => {
+      return html`
+        <div>
+          <label for="${name}">${label}</label>
+          <input type="range" id="${name}" min="${min}" max="${max}" name="${name}" value="${value}" oninput=${updateParamsInt} />
+          <span>${value}</span>
+        </div>
+      `
+    };
 
     return html`
       <form>
 
@@ -272,7 +331,9 @@
             <label for="prompt">Prompt</label>
             <textarea type="text" name="prompt" value="${session.value.prompt}" rows=4 oninput=${updateSession}/>
           </div>
+        </fieldset>
 
+        <fieldset class="two">
           <div>
             <label for="user">User name</label>
             <input type="text" name="user" value="${session.value.user}" oninput=${updateSession} />
 
@@ -282,7 +343,9 @@
             <label for="bot">Bot name</label>
            <input type="text" name="char" value="${session.value.char}" oninput=${updateSession} />
           </div>
+        </fieldset>
 
+        <fieldset>
           <div>
             <label for="template">Prompt template</label>
             <textarea id="template" name="template" value="${session.value.template}" rows=4 oninput=${updateSession}/>
 
@@ -292,38 +355,44 @@
             <label for="template">Chat history template</label>
             <textarea id="template" name="historyTemplate" value="${session.value.historyTemplate}" rows=1 oninput=${updateSession}/>
           </div>
-
-          <div>
-            <label for="temperature">Temperature</label>
-            <input type="range" id="temperature" min="0.0" max="1.0" step="0.01" name="temperature" value="${params.value.temperature}" oninput=${updateParamsFloat} />
-            <span>${params.value.temperature}</span>
-          </div>
-
-          <div>
-            <label for="nPredict">Predictions</label>
-            <input type="range" id="nPredict" min="1" max="2048" step="1" name="n_predict" value="${params.value.n_predict}" oninput=${updateParamsFloat} />
-            <span>${params.value.n_predict}</span>
-          </div>
-
-          <div>
-            <label for="repeat_penalty">Penalize repeat sequence</label>
-            <input type="range" id="repeat_penalty" min="0.0" max="2.0" step="0.01" name="repeat_penalty" value="${params.value.repeat_penalty}" oninput=${updateParamsFloat} />
-            <span>${params.value.repeat_penalty}</span>
-          </div>
-
-          <div>
-            <label for="repeat_last_n">Consider N tokens for penalize</label>
-            <input type="range" id="repeat_last_n" min="0.0" max="2048" name="repeat_last_n" value="${params.value.repeat_last_n}" oninput=${updateParamsFloat} />
-            <span>${params.value.repeat_last_n}</span>
-          </div>
-
         </fieldset>
+
+        <fieldset class="two">
+          ${IntField({label: "Predictions", max: 2048, min: -1, name: "n_predict", value: params.value.n_predict})}
+          ${FloatField({label: "Temperature", max: 1.5, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature})}
+          ${FloatField({label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty})}
+          ${IntField({label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n})}
+          ${IntField({label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k})}
+          ${FloatField({label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p})}
+        </fieldset>
+        <details>
+          <summary>More options</summary>
+          <fieldset class="two">
+            ${FloatField({label: "TFS-Z", max: 1.0, min: 0.0, name: "tfs_z", step: 0.01, value: params.value.tfs_z})}
+            ${FloatField({label: "Typical P", max: 1.0, min: 0.0, name: "typical_p", step: 0.01, value: params.value.typical_p})}
+            ${FloatField({label: "Presence penalty", max: 1.0, min: 0.0, name: "presence_penalty", step: 0.01, value: params.value.presence_penalty})}
+            ${FloatField({label: "Frequency penalty", max: 1.0, min: 0.0, name: "frequency_penalty", step: 0.01, value: params.value.frequency_penalty})}
+          </fieldset>
+          <hr />
+          <fieldset class="three">
+            <div>
+              <label><input type="radio" name="mirostat" value="0" checked=${params.value.mirostat == 0} oninput=${updateParamsInt} /> no Mirostat</label>
+              <label><input type="radio" name="mirostat" value="1" checked=${params.value.mirostat == 1} oninput=${updateParamsInt} /> Mirostat v1</label>
+              <label><input type="radio" name="mirostat" value="2" checked=${params.value.mirostat == 2} oninput=${updateParamsInt} /> Mirostat v2</label>
+            </div>
+            ${FloatField({label: "Mirostat tau", max: 10.0, min: 0.0, name: "mirostat_tau", step: 0.01, value: params.value.mirostat_tau})}
+            ${FloatField({label: "Mirostat eta", max: 1.0, min: 0.0, name: "mirostat_eta", step: 0.01, value: params.value.mirostat_eta})}
+          </fieldset>
+        </details>
       </form>
     `
   }
   // poor mans markdown replacement
   const Markdownish = (params) => {
     const md = params.text
+      .replace(/&/g, '&amp;')
+      .replace(/</g, '&lt;')
+      .replace(/>/g, '&gt;')
       .replace(/^#{1,6} (.*)$/gim, '<h3>$1</h3>')
       .replace(/\*\*(.*?)\*\*/g, '<strong>$1</strong>')
       .replace(/__(.*?)__/g, '<strong>$1</strong>')
examples/server/server.cpp

@@ -609,6 +609,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     fprintf(stdout, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
     fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
     fprintf(stdout, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
+    fprintf(stdout, " -eps N, --rms-norm-eps N rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps);
     fprintf(stdout, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
     fprintf(stdout, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
     fprintf(stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
 
@@ -734,6 +735,14 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
        }
        params.n_gqa = std::stoi(argv[i]);
    }
+    else if (arg == "-eps" || arg == "--rms-norm-eps") {
+        if (++i >= argc)
+        {
+            invalid_param = true;
+            break;
+        }
+        params.rms_norm_eps = std::stof(argv[i]);
+    }
    else if (arg == "--rope-freq-base")
    {
        if (++i >= argc)
examples/train-text-from-scratch/train-text-from-scratch.cpp

@@ -16,6 +16,8 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
+static const float rms_norm_eps = 1e-6f;
+
 struct random_normal_distribution {
     std::mt19937 gen;
     std::normal_distribution<float> rd;
 
@@ -439,7 +441,7 @@ struct ggml_tensor * forward(
         // norm
         {
             // cur shape [n_embd,N,1,1]
-            cur = ggml_rms_norm(ctx0, inpL);
+            cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 
             // cur = attention_norm*cur
             cur = ggml_mul(ctx0,
 
@@ -562,7 +564,7 @@ struct ggml_tensor * forward(
         // norm
         {
             // cur shape [n_embd,N,1,1]
-            cur = ggml_rms_norm(ctx0, inpFF);
+            cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
 
             // cur = ffn_norm*cur
             // cur shape [n_embd,N,1,1]
 
@@ -606,7 +608,7 @@ struct ggml_tensor * forward(
     {
 
         // inpL shape [n_embd,N,1,1]
-        inpL = ggml_rms_norm(ctx0, inpL);
+        inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
 
         // inpL = norm*inpL
         // inpL shape [n_embd,N,1,1]
 
@@ -694,7 +696,7 @@ struct ggml_tensor * forward_batch(
         // norm
         {
             // cur shape [n_embd,N*n_batch,1,1]
-            cur = ggml_rms_norm(ctx0, inpL);
+            cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
             assert_shape_2d(cur, n_embd, N*n_batch);
 
             // cur = attention_norm*cur
 
@@ -857,7 +859,7 @@ struct ggml_tensor * forward_batch(
         // norm
         {
             // cur shape [n_embd,N*n_batch,1,1]
-            cur = ggml_rms_norm(ctx0, inpFF);
+            cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
             assert_shape_2d(cur, n_embd, N*n_batch);
 
             // cur = ffn_norm*cur
 
@@ -910,7 +912,7 @@ struct ggml_tensor * forward_batch(
     {
 
         // inpL shape [n_embd,N*n_batch,1,1]
-        inpL = ggml_rms_norm(ctx0, inpL);
+        inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
         assert_shape_2d(inpL, n_embd, N*n_batch);
 
         // inpL = norm*inpL
 
@@ -979,7 +981,7 @@ struct ggml_tensor * forward_batch_wo_cache(
         // norm
         {
             // cur shape [n_embd,N*n_batch,1,1]
-            cur = ggml_rms_norm(ctx0, inpL);
+            cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
             assert_shape_2d(cur, n_embd, N*n_batch);
 
             // cur = attention_norm*cur
 
@@ -1085,7 +1087,7 @@ struct ggml_tensor * forward_batch_wo_cache(
         // norm
         {
             // cur shape [n_embd,N*n_batch,1,1]
-            cur = ggml_rms_norm(ctx0, inpFF);
+            cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
             assert_shape_2d(cur, n_embd, N*n_batch);
 
             // cur = ffn_norm*cur
 
@@ -1138,7 +1140,7 @@ struct ggml_tensor * forward_batch_wo_cache(
     {
 
         // inpL shape [n_embd,N*n_batch,1,1]
-        inpL = ggml_rms_norm(ctx0, inpL);
+        inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
         assert_shape_2d(inpL, n_embd, N*n_batch);
 
         // inpL = norm*inpL
 
@@ -1203,7 +1205,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
 
         // norm
         {
-            cur = ggml_rms_norm(ctx0, inpL);
+            cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
             assert_shape_2d(cur, n_embd, N*n_batch);
 
             // cur = attention_norm*cur
 
@@ -1267,7 +1269,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
         {
             // norm
             {
-                cur = ggml_rms_norm(ctx0, inpFF);
+                cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
                 assert_shape_2d(cur, n_embd, N*n_batch);
 
                 // cur = ffn_norm*cur
 
@@ -1311,7 +1313,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
     // norm
     {
 
-        inpL = ggml_rms_norm(ctx0, inpL);
+        inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
         assert_shape_2d(inpL, n_embd, N*n_batch);
 
         // inpL = norm*inpL
 
@@ -1603,7 +1605,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
         struct my_llama_layer & layer = model->layers[il];
         // tensors with values necessary for backward pass are in persistent buf(-1)
         // other tensors with buf(0) and buf(1) are only temporary needed, and their memory reused after layer is completed.
-        use_buf(-1); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm (ctx0, cur)); assert_shape_2d(t02, n_embd, N*n_batch);
+        use_buf(-1); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm (ctx0, cur, rms_norm_eps)); assert_shape_2d(t02, n_embd, N*n_batch);
         use_buf( 0); struct ggml_tensor * t03 = expand(gf, ggml_repeat (ctx0, layer.attention_norm, t02)); assert_shape_2d(t03, n_embd, N*n_batch);
         use_buf(-1); struct ggml_tensor * t04 = expand(gf, ggml_mul (ctx0, t02, t03)); assert_shape_2d(t04, n_embd, N*n_batch);
         use_buf(-1); struct ggml_tensor * t05 = expand(gf, ggml_mul_mat (ctx0, layer.wq, t04)); assert_shape_2d(t05, n_embd, N*n_batch);
 
@@ -1623,7 +1625,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
         use_buf(-1); struct ggml_tensor * t19 = expand(gf, ggml_reshape_2d (ctx0, t18, n_embd, N*n_batch)); assert_shape_2d(t19, n_embd, N*n_batch);
         use_buf( 0); struct ggml_tensor * t20 = expand(gf, ggml_mul_mat (ctx0, layer.wo, t19)); assert_shape_2d(t20, n_embd, N*n_batch);
         use_buf(-1); struct ggml_tensor * t21 = expand(gf, ggml_add (ctx0, t20, cur)); assert_shape_2d(t21, n_embd, N*n_batch);
-        use_buf(-1); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm (ctx0, t21)); assert_shape_2d(t22, n_embd, N*n_batch);
+        use_buf(-1); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm (ctx0, t21, rms_norm_eps)); assert_shape_2d(t22, n_embd, N*n_batch);
         use_buf( 0); struct ggml_tensor * t23 = expand(gf, ggml_repeat (ctx0, layer.ffn_norm, t22)); assert_shape_2d(t23, n_embd, N*n_batch);
         use_buf(-1); struct ggml_tensor * t24 = expand(gf, ggml_mul (ctx0, t23, t22)); assert_shape_2d(t24, n_embd, N*n_batch);
         use_buf(-1); struct ggml_tensor * t25 = expand(gf, ggml_mul_mat (ctx0, layer.w3, t24)); assert_shape_2d(t25, n_ff, N*n_batch);
 
@@ -1666,7 +1668,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
     }
     clr_buf(0);
     use_buf(0);
-    struct ggml_tensor * t31 = expand(gf, ggml_rms_norm (ctx0, cur)); assert_shape_2d(t31, n_embd, N*n_batch);
+    struct ggml_tensor * t31 = expand(gf, ggml_rms_norm (ctx0, cur, rms_norm_eps)); assert_shape_2d(t31, n_embd, N*n_batch);
     struct ggml_tensor * t32 = expand(gf, ggml_repeat (ctx0, model->norm, t31)); assert_shape_2d(t32, n_embd, N*n_batch);
     struct ggml_tensor * t33 = expand(gf, ggml_mul (ctx0, t32, t31)); assert_shape_2d(t33, n_embd, N*n_batch);
     use_buf(-1);
105
ggml-cuda.cu
105
ggml-cuda.cu
|
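The ggml-cuda.cu hunks below stop hard-coding the RMS-norm epsilon inside the kernel and pass it in as a parameter instead, reading it from the tensor's op_params on the host side. For reference, the quantity being stabilized is

    $y_i = x_i \,/\, \sqrt{\tfrac{1}{n}\sum_j x_j^2 + \varepsilon}$

so epsilon only matters when a row of activations is nearly all zero; making it a runtime argument lets different models use different values, while the default in this change stays at 1e-6f.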
@@ -332,12 +332,10 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     }
 }

-static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) {
+static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;

-    const float eps = 1e-6f;
-
     float tmp = 0.0f; // partial sum for thread in warp

     for (int col = tid; col < ncols; col += WARP_SIZE) {
@@ -1566,12 +1564,14 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const block_q4_K * bq4_K = (const block_q4_K *) vbq;

-    // iqs is in 0...15. bq8_offset = 2 * (iqs/4) -> bq8_offset = 0, 2, 4, 6
-    const int bq8_offset = QR4_K * (iqs / (QI8_1/2));
-
     float sumf_d = 0.0f;
     float sumf_m = 0.0f;

+#ifndef GGML_QKK_64
+
+    // iqs is in 0...15. bq8_offset = 2 * (iqs/4) -> bq8_offset = 0, 2, 4, 6
+    const int bq8_offset = QR4_K * (iqs / (QI8_1/2));
+
     const float    d = bq4_K->d;
     const float dmin = bq4_K->dmin;

@@ -1616,6 +1616,43 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
     }

     return d*sumf_d - dmin*sumf_m;

+#else
+
+    uint16_t aux16[2];
+    const uint8_t * s = (const uint8_t *)aux16;
+
+    const uint16_t * a = (const uint16_t *)bq4_K->scales;
+    aux16[0] = a[0] & 0x0f0f;
+    aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+    const float dall = bq4_K->d[0];
+    const float dmin = bq4_K->d[1];
+
+    const float d8_1 = bq8_1[0].d;
+    const float d8_2 = bq8_1[1].d;
+
+    const int ui1 = *((const int *)bq8_1[0].qs + iqs);
+    const int ui2 = *((const int *)bq8_1[0].qs + iqs + 4);
+    const int ui3 = *((const int *)bq8_1[1].qs + iqs);
+    const int ui4 = *((const int *)bq8_1[1].qs + iqs + 4);
+
+    const int * q4 = (const int *)bq4_K->qs + iqs;
+    const int v1 = q4[0];
+    const int v2 = q4[4];
+
+    const int dot1 = __dp4a(ui2, v2 & 0x0f0f0f0f, __dp4a(ui1, v1 & 0x0f0f0f0f, 0));
+    const int dot2 = __dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, __dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0));
+    const int dot3 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0));
+    const int dot4 = __dp4a(0x01010101, ui4, __dp4a(0x01010101, ui3, 0));
+
+    sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]);
+    sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]);
+
+    return dall * sumf_d - dmin * sumf_m;
+
+#endif
+
 #else
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
@@ -1627,6 +1664,8 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const block_q5_K * bq5_K = (const block_q5_K *) vbq;

+#ifndef GGML_QKK_64
+
     const int bq8_offset = QR5_K * (iqs / (QI8_1/2));
     const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * (iqs%4));
     const int * qh = (const int *)(bq5_K->qh + 4 * (iqs%4));
@@ -1682,6 +1721,42 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
     }

     return d*sumf_d - dmin*sumf_m;

+#else
+
+    const int8_t * s = bq5_K->scales;
+
+    const float d = bq5_K->d;
+
+    const float d8_1 = bq8_1[0].d;
+    const float d8_2 = bq8_1[1].d;
+
+    const int ui1 = *((const int *)bq8_1[0].qs + iqs);
+    const int ui2 = *((const int *)bq8_1[0].qs + iqs + 4);
+    const int ui3 = *((const int *)bq8_1[1].qs + iqs);
+    const int ui4 = *((const int *)bq8_1[1].qs + iqs + 4);
+
+    const int * ql = (const int *)bq5_K->qs + iqs;
+    const int vl1 = ql[0];
+    const int vl2 = ql[4];
+
+    const int step = 4 * iqs; // 0, 4, 8, 12
+    const int im   = step/8;  // = 0 for iqs = 0, 1, = 1 for iqs = 2, 3
+    const int in   = step%8;  // 0, 4, 0, 4
+    const int vh = (*((const int *)(bq5_K->qh + in))) >> im;
+
+    const int v1 = (((vh << 4) & 0x10101010) ^ 0x10101010) | ((vl1 >> 0) & 0x0f0f0f0f);
+    const int v2 = (((vh << 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 0) & 0x0f0f0f0f);
+    const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f);
+    const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f);
+
+    const float sumf_d = d8_1 * (__dp4a(ui1, v1, 0) * s[0] + __dp4a(ui2, v2, 0) * s[1])
+                       + d8_2 * (__dp4a(ui3, v3, 0) * s[2] + __dp4a(ui4, v4, 0) * s[3]);
+
+    return d * sumf_d;
+
+#endif
+
 #else
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
@@ -2122,10 +2197,10 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i
     norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
 }

-static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     const dim3 block_dims(WARP_SIZE, 1, 1);
-    rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+    rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
 }

 static void quantize_row_q8_1_cuda(const float * x, void * vy, const int ndata, const int k, cudaStream_t stream) {
@@ -2876,8 +2951,11 @@ inline void ggml_cuda_op_rms_norm(
     const int64_t ne00 = src0->ne[0];
     const int64_t i01_diff = i01_high - i01_low;

+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
     // compute
-    rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
+    rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, eps, cudaStream_main);

     (void) src1;
     (void) dst;
@@ -3962,18 +4040,23 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
             }
             func = ggml_cuda_mul;
             break;
-        case GGML_OP_GELU:
+        case GGML_OP_UNARY:
+            switch (ggml_get_unary_op(tensor)) {
+                case GGML_UNARY_OP_GELU:
                     if (!any_on_device) {
                         return false;
                     }
                     func = ggml_cuda_gelu;
                     break;
-        case GGML_OP_SILU:
+                case GGML_UNARY_OP_SILU:
                     if (!any_on_device) {
                         return false;
                     }
                     func = ggml_cuda_silu;
                     break;
+                default:
+                    return false;
+            } break;
         case GGML_OP_NORM:
             if (!any_on_device) {
                 return false;

19  ggml-metal.m
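The ggml-metal.m hunks below follow the same convention the CUDA path adopted above: per-op scalar parameters such as the RMS-norm epsilon travel inside the tensor's op_params bytes instead of being hard-coded in the backend. A minimal sketch of the convention, assuming a tensor t whose op is GGML_OP_RMS_NORM (the graph-building side is not part of this diff and is shown only as an illustration):

    // backend side: recover the epsilon stored with the node (as the hunks below do)
    float eps;
    memcpy(&eps, t->op_params, sizeof(float));

    // graph-building side: store it the same way (illustrative, not from this diff)
    memcpy(t->op_params, &eps, sizeof(float));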
@@ -629,7 +629,9 @@ void ggml_metal_graph_compute(

                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                    } break;
-                case GGML_OP_SILU:
+                case GGML_OP_UNARY:
+                    switch (ggml_get_unary_op(gf->nodes[i])) {
+                        case GGML_UNARY_OP_SILU:
                    {
                        if (encoder == nil) {
                            encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
@@ -643,7 +645,7 @@ void ggml_metal_graph_compute(

                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                    } break;
-                case GGML_OP_RELU:
+                        case GGML_UNARY_OP_RELU:
                    {
                        if (encoder == nil) {
                            encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
@@ -657,7 +659,7 @@ void ggml_metal_graph_compute(

                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                    } break;
-                case GGML_OP_GELU:
+                        case GGML_UNARY_OP_GELU:
                    {
                        if (encoder == nil) {
                            encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
@@ -671,6 +673,12 @@ void ggml_metal_graph_compute(

                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                    } break;
+                        default:
+                        {
+                            fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
+                            GGML_ASSERT(false);
+                        }
+                    } break;
                case GGML_OP_SOFT_MAX:
                    {
                        if (encoder == nil) {
@@ -914,7 +922,8 @@ void ggml_metal_graph_compute(
                            encoder = [command_buffer computeCommandEncoderWithDescriptor: encoder_descriptor];
                        }

-                        const float eps = 1e-6f;
+                        float eps;
+                        memcpy(&eps, dst->op_params, sizeof(float));

                        const int nth = 512;

@@ -1089,10 +1098,12 @@ void ggml_metal_graph_compute(
                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                    } break;
                default:
+                    {
                        fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
                        GGML_ASSERT(false);
                    }
            }
+        }

        if (encoder != nil) {
            [encoder endEncoding];

113  ggml-metal.metal
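The ggml-metal.metal hunks below restructure the q4_0/q4_1 dot product so that each SIMD lane handles half a block and the activations are pre-scaled instead of bit-shifting the packed nibbles. The comment about "missing bit shifts (1, 1/16, 1/256, 1/4096)" rests on the identity

    $y \cdot \big((q \gg 8)\,\&\,\mathrm{0xF}\big) \;=\; \tfrac{y}{256}\cdot\big(q\,\&\,\mathrm{0x0F00}\big)$

so dividing the cached y values by 16, 256 and 4096 up front lets the inner loop use plain masks with no shifts at all.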
@@ -387,87 +387,90 @@ kernel void kernel_rms_norm(
         }
     }
 }

-// function for calculate inner product between a q4_0 block and 32 floats (yl), sumy is SUM(yl[i])
-float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) {
+// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
     float d = qb_curr->d;
-    float4 acc = 0.f;
-    device uint16_t * qs = ((device uint16_t *)qb_curr + 1);
-    for (int i = 0; i < 16; i+=2) {
-        acc[0] += yl[i] * (qs[i / 2] & 0x000F);
-        acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
-        acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
-        acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
+    float2 acc = 0.f;
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+                + yl[i + 1] * (qs[i / 2] & 0x0F00);
+        acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+                + yl[i + 9] * (qs[i / 2] & 0xF000);
     }
-    return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f);
+    return d * (sumy * -8.f + acc[0] + acc[1]);
 }

-// function for calculate inner product between a q4_1 block and 32 floats (yl), sumy is SUM(yl[i])
-float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) {
+// function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
     float d = qb_curr->d;
     float m = qb_curr->m;
-    float4 acc = 0.f;
-    device uint16_t * qs = ((device uint16_t *)qb_curr + 2);
-    for (int i = 0; i < 16; i+=2) {
-        acc[0] += yl[i] * (qs[i / 2] & 0x000F);
-        acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
-        acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
-        acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+    float2 acc = 0.f;
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+                + yl[i + 1] * (qs[i / 2] & 0x0F00);
+        acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
+                + yl[i + 9] * (qs[i / 2] & 0xF000);
     }
-    return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m;
+    return d * (acc[0] + acc[1]) + sumy * m;
 }

 // putting them in the kernel cause a significant performance penalty
 #define N_DST 4 // each SIMD group works on 4 rows
 #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
 #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
-template<typename block_q_type>
+//Note: This is a template, but strictly speaking it only applies to
+//      quantizations where the block size is 32. It also does not
+//      giard against the number of rows not being divisible by
+//      N_DST, so this is another explicit assumption of the implementation.
+template<typename block_q_type, int nr, int nsg, int nw>
 void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
                      int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
                      uint2 tgpig, uint tiisg, uint sgitg) {
     const int nb = ne00/QK4_0;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
-    device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
+    const int first_row = (r0 * nsg + sgitg) * nr;
+    device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb;
     device const float * y = (device const float *) src1 + r1*ne10;
-    float4 y_curr[8];       // src1 vector cache
-    float sumf[N_DST]={0.f}, all_sum;
-    thread float * yl=(thread float *)y_curr;
+    float yl[16];       // src1 vector cache
+    float sumf[nr]={0.f};

-    // each thread in a SIMD group deals with 1 block.
-    for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
+    const int ix = tiisg/2;
+    const int il = 8*(tiisg%2);
+
+    device const float * yb = y + ix * QK4_0 + il;
+
+    // each thread in a SIMD group deals with half a block.
+    for (int ib = ix; ib < nb; ib += nw/2) {
         float sumy = 0;
-        for (int i = 0; i < QK4_0 / 4; i++) {
-            y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i);
-            sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
+        for (int i = 0; i < 8; i += 2) {
+            sumy += yb[i] + yb[i+1];
+            yl[i+0] = yb[i+ 0];
+            yl[i+1] = yb[i+ 1]/256.f;
+            sumy += yb[i+16] + yb[i+17];
+            yl[i+8] = yb[i+16]/16.f;
+            yl[i+9] = yb[i+17]/4096.f;
         }

-        for (int row = 0; row < N_DST; row++) {
-            sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl);
-        }
-    }
-
-    // from now loads two rows every time and 16 blocks per row
-    int ir = tiisg / (N_SIMDWIDTH / 2);
-    int ib = tiisg % (N_SIMDWIDTH / 2);
-    for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) {
-        int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start
-        float sumy = 0;
-        for (int i = 0; i < QK4_0 / 4; i++) {
-            y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i);
-            sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
+        for (int row = 0; row < nr; row++) {
+            sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il);
         }

-        for (int row = 0; row < N_DST; row+=2) {
-            if (nb_start + ib < nb) {
-                sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl);
-            }
-        }
-    }
-
-    for (int row = 0; row < N_DST; ++row) {
-        all_sum = simd_sum(sumf[row]);
-        if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
-            dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
+        yb += QK4_0 * 16;
+    }
+
+    for (int row = 0; row < nr; ++row) {
+        const float tot = simd_sum(sumf[row]);
+        if (tiisg == 0 && first_row + row < ne01) {
+            dst[r1*ne0 + first_row + row] = tot;
         }
     }
 }
@@ -483,7 +486,7 @@ kernel void kernel_mul_mat_q4_0_f32(
         uint2 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32<block_q4_0>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
 }

 kernel void kernel_mul_mat_q4_1_f32(
@@ -497,7 +500,7 @@ kernel void kernel_mul_mat_q4_1_f32(
         uint2 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32<block_q4_1>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
 }

 kernel void kernel_mul_mat_f16_f32(

61  ggml.h
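The ggml.h hunks below fold the per-activation ops (ABS, GELU, SILU, ...) into a single GGML_OP_UNARY with a ggml_unary_op sub-type, which is what the CUDA and Metal dispatch changes above switch on via ggml_get_unary_op(). A minimal sketch of building and inspecting such a node with the declarations added below (ctx0 and inp are assumed to exist; the setup is not shown):

    // build a GELU node through the new generic entry point
    struct ggml_tensor * cur = ggml_unary(ctx0, inp, GGML_UNARY_OP_GELU);

    // a backend can then recover the sub-type when it sees GGML_OP_UNARY
    if (cur->op == GGML_OP_UNARY && ggml_get_unary_op(cur) == GGML_UNARY_OP_GELU) {
        // dispatch the GELU kernel
    }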
@@ -330,16 +330,6 @@ extern "C" {
         GGML_OP_ARGMAX,
         GGML_OP_REPEAT,
         GGML_OP_REPEAT_BACK,
-        GGML_OP_ABS,
-        GGML_OP_SGN,
-        GGML_OP_NEG,
-        GGML_OP_STEP,
-        GGML_OP_TANH,
-        GGML_OP_ELU,
-        GGML_OP_RELU,
-        GGML_OP_GELU,
-        GGML_OP_GELU_QUICK,
-        GGML_OP_SILU,
         GGML_OP_SILU_BACK,
         GGML_OP_NORM, // normalize
         GGML_OP_RMS_NORM,
@@ -378,6 +368,8 @@ extern "C" {
         GGML_OP_WIN_PART,
         GGML_OP_WIN_UNPART,

+        GGML_OP_UNARY,
+
         GGML_OP_MAP_UNARY,
         GGML_OP_MAP_BINARY,

@@ -391,6 +383,18 @@ extern "C" {
         GGML_OP_COUNT,
     };

+    enum ggml_unary_op {
+        GGML_UNARY_OP_ABS,
+        GGML_UNARY_OP_SGN,
+        GGML_UNARY_OP_NEG,
+        GGML_UNARY_OP_STEP,
+        GGML_UNARY_OP_TANH,
+        GGML_UNARY_OP_ELU,
+        GGML_UNARY_OP_RELU,
+        GGML_UNARY_OP_GELU,
+        GGML_UNARY_OP_GELU_QUICK,
+        GGML_UNARY_OP_SILU,
+    };
+
     // ggml object
     struct ggml_object {
@@ -535,6 +539,7 @@ extern "C" {

     GGML_API const char * ggml_type_name(enum ggml_type type);
     GGML_API const char * ggml_op_name  (enum ggml_op   op);
+    GGML_API const char * ggml_op_symbol(enum ggml_op   op);

     GGML_API size_t  ggml_element_size(const struct ggml_tensor * tensor);

@@ -558,6 +563,7 @@ extern "C" {
     GGML_API size_t  ggml_used_mem(const struct ggml_context * ctx);

     GGML_API size_t  ggml_set_scratch (struct ggml_context * ctx, struct ggml_scratch scratch);
+    GGML_API bool    ggml_get_no_alloc(struct ggml_context * ctx);
     GGML_API void    ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc);

     GGML_API void *  ggml_get_mem_buffer     (const struct ggml_context * ctx);
@@ -617,6 +623,8 @@ extern "C" {
     GGML_API void *  ggml_get_data    (const struct ggml_tensor * tensor);
     GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);

+    GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);
+
     GGML_API const char *         ggml_get_name   (const struct ggml_tensor * tensor);
     GGML_API struct ggml_tensor * ggml_set_name   (      struct ggml_tensor * tensor, const char * name);
     GGML_API struct ggml_tensor * ggml_format_name(      struct ggml_tensor * tensor, const char * fmt, ...);
@@ -629,6 +637,11 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);

+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_dup_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     GGML_API struct ggml_tensor * ggml_add(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -853,14 +866,17 @@ extern "C" {

     GGML_API struct ggml_tensor * ggml_rms_norm(
             struct ggml_context * ctx,
-            struct ggml_tensor  * a);
+            struct ggml_tensor  * a,
+            float                 eps);

     GGML_API struct ggml_tensor * ggml_rms_norm_inplace(
             struct ggml_context * ctx,
-            struct ggml_tensor  * a);
+            struct ggml_tensor  * a,
+            float                 eps);

     // a - x
     // b - dy
+    // TODO: update with configurable eps
     GGML_API struct ggml_tensor * ggml_rms_norm_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -952,11 +968,22 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);

+    // a -> b, in-place, return view(b)
+    GGML_API struct ggml_tensor * ggml_cpy_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b);
+
     // make contiguous
     GGML_API struct ggml_tensor * ggml_cont(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);

+    // make contiguous, in-place
+    GGML_API struct ggml_tensor * ggml_cont_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     // return view(a), b specifies the new shape
     // TODO: when we start computing gradient, make a copy instead of view
     GGML_API struct ggml_tensor * ggml_reshape(
@@ -1268,6 +1295,16 @@ extern "C" {
     typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);
     typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *);

+    GGML_API struct ggml_tensor * ggml_unary(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            enum ggml_unary_op    op);
+
+    GGML_API struct ggml_tensor * ggml_unary_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            enum ggml_unary_op    op);
+
     GGML_API struct ggml_tensor * ggml_map_unary_f32(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,

6  grammars/arithmetic.gbnf  (new file)
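The new grammars/ directory below ships a few example GBNF grammars for the grammar-constrained sampling added in llama.cpp and llama.h further down. Each file defines a root rule plus helper rules; for instance the arithmetic grammar accepts lines such as "a+b = c" followed by a newline, and the JSON grammar accepts a simplified JSON object.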
@@ -0,0 +1,6 @@
+root  ::= (expr "=" ws term "\n")+
+expr  ::= term ([-+*/] term)*
+term  ::= ident | num | "(" ws expr ")" ws
+ident ::= [a-z] [a-z0-9_]* ws
+num   ::= [0-9]+ ws
+ws    ::= [ \t\n]*

13  grammars/chess.gbnf  (new file)
@@ -0,0 +1,13 @@
+# Specifies chess moves as a list in algebraic notation, using PGN conventions
+
+# Force first move to "1. ", then any 1-2 digit number after, relying on model to follow the pattern
+root    ::= "1. " move " " move "\n" ([1-9] [0-9]? ". " move " " move "\n")+
+move    ::= (pawn | nonpawn | castle) [+#]?
+
+# piece type, optional file/rank, optional capture, dest file & rank
+nonpawn ::= [NBKQR] [a-h]? [1-8]? "x"? [a-h] [1-8]
+
+# optional file & capture, dest file & rank, optional promotion
+pawn    ::= ([a-h] "x")? [a-h] [1-8] ("=" [NBKQR])?
+
+castle  ::= "O-O" "-O"?

7  grammars/japanese.gbnf  (new file)
@@ -0,0 +1,7 @@
+# A probably incorrect grammar for Japanese
+root        ::= jp-char+ ([ \t\n] jp-char+)*
+jp-char     ::= hiragana | katakana | punctuation | cjk
+hiragana    ::= [ぁ-ゟ]
+katakana    ::= [ァ-ヿ]
+punctuation ::= [、-〾]
+cjk         ::= [一-鿿]

29  grammars/json.gbnf  (new file)
@@ -0,0 +1,29 @@
+# Grammar for subset of JSON - doesn't support full string or number syntax
+
+root    ::= object
+value   ::= object | array | string | number | boolean | "null"
+
+object  ::=
+  "{" ws (
+    string ":" ws value
+    ("," ws string ":" ws value)*
+  )? "}"
+
+array   ::=
+  "[" ws (
+    value
+    ("," ws value)*
+  )? "]"
+
+string  ::=
+  "\"" (
+    [^"\\] |
+    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
+  )* "\"" ws
+
+# Only plain integers currently
+number  ::= "-"? [0-9]+ ws
+boolean ::= ("true" | "false") ws
+
+# Optional space: by convention, applied in this grammar after literal chars when allowed
+ws      ::= ([ \t\n] ws)?

4  grammars/list.gbnf  (new file)
@@ -0,0 +1,4 @@
+root ::= item+
+
+# Excludes various line break characters
+item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n"
@@ -3297,8 +3297,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri

 #else

-    uint8_t aux8[QK_K];
+    int8_t  aux8[QK_K];
     int16_t aux16[16];
     float   sums [8];
     memset(sums, 0, 8*sizeof(float));
@@ -3308,7 +3307,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
         const uint8_t * restrict q4 = x[i].qs;
         const uint8_t * restrict hm = x[i].qh;
         const  int8_t * restrict q8 = y[i].qs;
-        uint8_t * restrict a = aux8;
+        int8_t * restrict a = aux8;
         for (int l = 0; l < 32; ++l) {
             a[l+ 0] = q4[l] & 0xF;
             a[l+32] = q4[l] >> 4;

357  llama.cpp
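The llama.cpp hunks below make two independent changes: the new rms_norm_eps value is threaded from llama_context_params through model loading into every ggml_rms_norm call, and a grammar engine (rule storage, UTF-8 decoding, pushdown stacks and candidate rejection) is added to back the new sampling API.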
@@ -186,6 +186,7 @@ struct llama_hparams {
     // LLaMAv2
     // TODO: load from model data hparams
     float f_ffn_mult = 1.0f;
+    float f_rms_norm_eps = 1e-6f;

     float rope_freq_base  = 10000.0f;
     float rope_freq_scale = 1.0f;
@@ -869,6 +870,7 @@ struct llama_context_params llama_context_default_params() {
         /*.n_ctx                       =*/ 512,
         /*.n_batch                     =*/ 512,
         /*.n_gqa                       =*/ 1,
+        /*.rms_norm_eps                =*/ 1e-6f,
         /*.gpu_layers                  =*/ 0,
         /*.main_gpu                    =*/ 0,
         /*.tensor_split                =*/ nullptr,
@@ -1000,6 +1002,7 @@ static void llama_model_load_internal(
         int n_ctx,
         int n_batch,
         int n_gqa,
+        float rms_norm_eps,
         int n_gpu_layers,
         int main_gpu,
         const float * tensor_split,
@@ -1024,6 +1027,9 @@ static void llama_model_load_internal(

     auto & hparams = model.hparams;

+    // TODO: read from file
+    hparams.f_rms_norm_eps = rms_norm_eps;
+
     {
         switch (hparams.n_layer) {
             case 26: model.type = e_model::MODEL_3B; break;
@@ -1072,6 +1078,7 @@ static void llama_model_load_internal(
         fprintf(stderr, "%s: n_layer    = %u\n",  __func__, hparams.n_layer);
         fprintf(stderr, "%s: n_rot      = %u\n",  __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim
         fprintf(stderr, "%s: n_gqa      = %u\n",  __func__, hparams.n_gqa());
+        fprintf(stderr, "%s: rnorm_eps  = %.1e\n", __func__, hparams.f_rms_norm_eps);
         fprintf(stderr, "%s: n_ff       = %u\n",  __func__, n_ff);
         fprintf(stderr, "%s: freq_base  = %.1f\n", __func__, hparams.rope_freq_base);
         fprintf(stderr, "%s: freq_scale = %g\n",  __func__, hparams.rope_freq_scale);
@@ -1330,6 +1337,7 @@ static bool llama_model_load(
         int n_ctx,
         int n_batch,
         int n_gqa,
+        float rms_norm_eps,
         int n_gpu_layers,
         int main_gpu,
         const float * tensor_split,
@@ -1343,7 +1351,7 @@ static bool llama_model_load(
         llama_progress_callback progress_callback,
         void *progress_callback_user_data) {
     try {
-        llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type,
+        llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, rms_norm_eps, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type,
                                   use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data);
         return true;
     } catch (const std::exception & err) {
@@ -1396,10 +1404,12 @@ static bool llama_eval_internal(
     const int64_t n_vocab    = hparams.n_vocab;
     const int64_t n_embd_gqa = hparams.n_embd_gqa();

     LLAMA_ASSERT(n_embd_head == hparams.n_rot);

     const float freq_base  = hparams.rope_freq_base;
     const float freq_scale = hparams.rope_freq_scale;
+    const float rms_norm_eps = hparams.f_rms_norm_eps;

     const int n_gpu_layers = model.n_gpu_layers;

@@ -1479,7 +1489,7 @@ static bool llama_eval_internal(

         // norm
         {
-            cur = ggml_rms_norm(ctx0, inpL);
+            cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
             offload_func(cur);
             ggml_set_name(cur, "rms_norm_0");

@@ -1627,7 +1637,7 @@ static bool llama_eval_internal(
         {
             // norm
             {
-                cur = ggml_rms_norm(ctx0, inpFF);
+                cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
                 offload_func(cur);
                 ggml_set_name(cur, "rms_norm_1");

@@ -1680,7 +1690,7 @@ static bool llama_eval_internal(

         // norm
         {
-            cur = ggml_rms_norm(ctx0, inpL);
+            cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
             offload_func_nr(cur);
             ggml_set_name(cur, "rms_norm_2");
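The grammar engine added in the hunk below operates on rules that are flat arrays of llama_grammar_element values (declared in the llama.h hunks further down): each rule is a sequence of elements, LLAMA_GRETYPE_ALT separates alternate definitions, and LLAMA_GRETYPE_END terminates the rule. A minimal hand-written sketch of a grammar equivalent to root ::= "yes" | "no", assuming the llama.h declarations below (this is an illustration, not code from the diff):

    const llama_grammar_element root_rule[] = {
        { LLAMA_GRETYPE_CHAR, 'y' }, { LLAMA_GRETYPE_CHAR, 'e' }, { LLAMA_GRETYPE_CHAR, 's' },
        { LLAMA_GRETYPE_ALT,  0   },                               // the next alternate definition follows
        { LLAMA_GRETYPE_CHAR, 'n' }, { LLAMA_GRETYPE_CHAR, 'o' },
        { LLAMA_GRETYPE_END,  0   },
    };
    const llama_grammar_element * rules[] = { root_rule };
    struct llama_grammar * grammar = llama_grammar_init(rules, 1, /*start_rule_index =*/ 0);
    // ... use with the sampling calls added further down, then:
    llama_grammar_free(grammar);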
@@ -1968,6 +1978,279 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
     return output;
 }

+//
+// grammar - internal
+//
+
+struct llama_grammar {
+    const std::vector<std::vector<llama_grammar_element>>   rules;
+    std::vector<std::vector<const llama_grammar_element *>> stacks;
+};
+
+struct llama_grammar_candidate {
+    size_t           index;
+    const uint32_t * code_points;
+};
+
+// NOTE: assumes valid utf8 (but checks for overrun)
+// adds a terminating 0 for use as pointer
+std::vector<uint32_t> decode_utf8(const char * src) {
+    static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
+    const char * pos = src;
+    std::vector<uint32_t> code_points;
+    while (*pos != 0) {
+        uint8_t  first_byte = static_cast<uint8_t>(*pos);
+        uint8_t  highbits   = first_byte >> 4;
+        int      len        = lookup[highbits];
+        uint8_t  mask       = (1 << (8 - len)) - 1;
+        uint32_t value      = first_byte & mask;
+        const char * end    = pos + len; // may overrun!
+        ++pos;
+        for ( ; pos < end && *pos != 0; ++pos) {
+            value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
+        }
+        code_points.push_back(value);
+    }
+    code_points.push_back(0);
+    return code_points;
+}
+
+// returns true iff pos points to the end of one of the definitions of a rule
+static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) {
+    switch (pos->type) {
+        case LLAMA_GRETYPE_END: return true;
+        case LLAMA_GRETYPE_ALT: return true;
+        default:                return false;
+    }
+}
+
+// returns true iff chr satisfies the char range at pos (regular or inverse range)
+// asserts that pos is pointing to a char range element
+static std::pair<bool, const llama_grammar_element *> llama_grammar_match_char(
+        const llama_grammar_element * pos,
+        const uint32_t                chr) {
+
+    bool found            = false;
+    bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR;
+    LLAMA_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT);
+
+    do {
+        if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) {
+            // inclusive range, e.g. [a-z]
+            found = found || (pos->value <= chr && chr <= pos[1].value);
+            pos += 2;
+        } else {
+            // exact char match, e.g. [a] or "a"
+            found = found || pos->value == chr;
+            pos += 1;
+        }
+    } while (pos->type == LLAMA_GRETYPE_CHAR_ALT);
+
+    return std::make_pair(found == is_positive_char, pos);
+}
+
+// transforms a grammar pushdown stack into N possible stacks, all ending
+// at a character range (terminal element)
+static void llama_grammar_advance_stack(
+        const std::vector<std::vector<llama_grammar_element>>   & rules,
+        const std::vector<const llama_grammar_element *>        & stack,
+        std::vector<std::vector<const llama_grammar_element *>> & new_stacks) {
+
+    if (stack.empty()) {
+        new_stacks.push_back(stack);
+        return;
+    }
+
+    const llama_grammar_element * pos = stack.back();
+
+    switch (pos->type) {
+        case LLAMA_GRETYPE_RULE_REF: {
+            const size_t                  rule_id = static_cast<size_t>(pos->value);
+            const llama_grammar_element * subpos  = rules[rule_id].data();
+            do {
+                // init new stack without the top (pos)
+                std::vector<const llama_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
+                if (!llama_grammar_is_end_of_sequence(pos + 1)) {
+                    // if this rule ref is followed by another element, add that to stack
+                    new_stack.push_back(pos + 1);
+                }
+                if (!llama_grammar_is_end_of_sequence(subpos)) {
+                    // if alternate is nonempty, add to stack
+                    new_stack.push_back(subpos);
+                }
+                llama_grammar_advance_stack(rules, new_stack, new_stacks);
+                while (!llama_grammar_is_end_of_sequence(subpos)) {
+                    // scan to end of alternate def
+                    subpos++;
+                }
+                if (subpos->type == LLAMA_GRETYPE_ALT) {
+                    // there's another alternate def of this rule to process
+                    subpos++;
+                } else {
+                    break;
+                }
+            } while (true);
+            break;
+        }
+        case LLAMA_GRETYPE_CHAR:
+        case LLAMA_GRETYPE_CHAR_NOT:
+            new_stacks.push_back(stack);
+            break;
+        default:
+            // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range
+            // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on
+            // those
+            LLAMA_ASSERT(false);
+    }
+}
+
+// takes a set of possible pushdown stacks on a grammar, which are required to
+// be positioned at a character range (see `llama_grammar_advance_stack`), and
+// produces the N possible stacks if the given char is accepted at those
+// positions
+static std::vector<std::vector<const llama_grammar_element *>> llama_grammar_accept(
+        const std::vector<std::vector<llama_grammar_element>>         & rules,
+        const std::vector<std::vector<const llama_grammar_element *>> & stacks,
+        const uint32_t                                                   chr) {
+
+    std::vector<std::vector<const llama_grammar_element *>> new_stacks;
+
+    for (const auto & stack : stacks) {
+        if (stack.empty()) {
+            continue;
+        }
+
+        auto match = llama_grammar_match_char(stack.back(), chr);
+        if (match.first) {
+            const llama_grammar_element * pos = match.second;
+
+            // update top of stack to next element, if any
+            std::vector<const llama_grammar_element *> new_stack(stack.begin(), stack.end() - 1);
+            if (!llama_grammar_is_end_of_sequence(pos)) {
+                new_stack.push_back(pos);
+            }
+            llama_grammar_advance_stack(rules, new_stack, new_stacks);
+        }
+    }
+
+    return new_stacks;
+}
+
+static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates(
+        const std::vector<std::vector<llama_grammar_element>>         & rules,
+        const std::vector<std::vector<const llama_grammar_element *>> & stacks,
+        const std::vector<llama_grammar_candidate>                    & candidates);
+
+static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
+        const std::vector<std::vector<llama_grammar_element>> & rules,
+        const std::vector<const llama_grammar_element *>      & stack,
+        const std::vector<llama_grammar_candidate>            & candidates) {
+
+    std::vector<llama_grammar_candidate> rejects;
+
+    if (stack.empty()) {
+        // accept nothing; EOS is handled elsewhere
+        rejects.insert(rejects.end(), candidates.begin(), candidates.end());
+        return rejects;
+    }
+
+    const llama_grammar_element * stack_pos = stack.back();
+
+    std::vector<llama_grammar_candidate> next_candidates;
+    for (auto tok : candidates) {
+        if (llama_grammar_match_char(stack_pos, tok.code_points[0]).first) {
+            if (tok.code_points[1] != 0) {
+                next_candidates.push_back({ tok.index, tok.code_points + 1 });
+            }
+        } else {
+            rejects.push_back(tok);
+        }
+    }
+
+    auto stack_pos_after = llama_grammar_match_char(stack_pos, 0).second;
+
+    // update top of stack to next element, if any
+    std::vector<const llama_grammar_element *> stack_after(stack.begin(), stack.end() - 1);
+    if (!llama_grammar_is_end_of_sequence(stack_pos_after)) {
+        stack_after.push_back(stack_pos_after);
+    }
+    std::vector<std::vector<const llama_grammar_element *>> next_stacks;
+    llama_grammar_advance_stack(rules, stack_after, next_stacks);
+
+    auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
+    for (auto tok : next_rejects) {
+        rejects.push_back({ tok.index, tok.code_points - 1 });
+    }
+
+    return rejects;
+}
+
+static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates(
+        const std::vector<std::vector<llama_grammar_element>>         & rules,
+        const std::vector<std::vector<const llama_grammar_element *>> & stacks,
+        const std::vector<llama_grammar_candidate>                    & candidates) {
+    LLAMA_ASSERT(!stacks.empty()); // REVIEW
+
+    if (candidates.empty()) {
+        return std::vector<llama_grammar_candidate>();
+    }
+
+    auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates);
+
+    for (size_t i = 1, size = stacks.size(); i < size; ++i) {
+        rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects);
+    }
+    return rejects;
+}
+
+//
+// grammar - external
+//
+
+struct llama_grammar * llama_grammar_init(
+            const llama_grammar_element ** rules,
+                                 size_t    n_rules,
+                                 size_t    start_rule_index) {
+    const llama_grammar_element * pos;
+
+    // copy rule definitions into vectors
+    std::vector<std::vector<llama_grammar_element>> vec_rules(n_rules);
+    for (size_t i = 0; i < n_rules; i++) {
+        for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) {
+            vec_rules[i].push_back(*pos);
+        }
+        vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
+    }
+
+    // loop over alternates of start rule to build initial stacks
+    std::vector<std::vector<const llama_grammar_element *>> stacks;
+    pos = rules[start_rule_index];
+    do {
+        std::vector<const llama_grammar_element *> stack;
+        if (!llama_grammar_is_end_of_sequence(pos)) {
+            // if alternate is nonempty, add to stack
+            stack.push_back(pos);
+        }
+        llama_grammar_advance_stack(vec_rules, stack, stacks);
+        while (!llama_grammar_is_end_of_sequence(pos)) {
+            // scan to end of alternate def
+            pos++;
+        }
+        if (pos->type == LLAMA_GRETYPE_ALT) {
+            // there's another alternate def of this rule to process
+            pos++;
+        } else {
+            break;
+        }
+    } while (true);
+
+    return new llama_grammar{ std::move(vec_rules), std::move(stacks) };
+}
+
+void llama_grammar_free(struct llama_grammar * grammar) {
+    delete grammar;
+}
+
 //
 // sampling
 //
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) {
|
||||||
|
assert(ctx);
|
||||||
|
const int64_t t_start_sample_us = ggml_time_us();
|
||||||
|
|
||||||
|
bool allow_eos = false;
|
||||||
|
for (const auto & stack : grammar->stacks) {
|
||||||
|
if (stack.empty()) {
|
||||||
|
allow_eos = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const llama_token eos = llama_token_eos();
|
||||||
|
|
||||||
|
std::vector<std::vector<uint32_t>> candidates_decoded;
|
||||||
|
std::vector<llama_grammar_candidate> candidates_grammar;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < candidates->size; ++i) {
|
||||||
|
const llama_token id = candidates->data[i].id;
|
||||||
|
const char * str = llama_token_to_str(ctx, id);
|
||||||
|
if (id == eos) {
|
||||||
|
if (!allow_eos) {
|
||||||
|
candidates->data[i].logit = -INFINITY;
|
||||||
|
}
|
||||||
|
} else if (*str == 0) {
|
||||||
|
candidates->data[i].logit = -INFINITY;
|
||||||
|
} else {
|
||||||
|
candidates_decoded.push_back(decode_utf8(str));
|
||||||
|
candidates_grammar.push_back({ i, candidates_decoded.back().data() });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto rejects =
|
||||||
|
llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar);
|
||||||
|
for (auto & reject : rejects) {
|
||||||
|
candidates->data[reject.index].logit = -INFINITY;
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||||
|
}
|
||||||
|
|
||||||
static void llama_log_softmax(float * array, size_t size) {
|
static void llama_log_softmax(float * array, size_t size) {
|
||||||
float max_l = *std::max_element(array, array + size);
|
float max_l = *std::max_element(array, array + size);
|
||||||
float sum = 0.f;
|
float sum = 0.f;
|
||||||
|
@ -2428,6 +2752,29 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
|
||||||
|
const int64_t t_start_sample_us = ggml_time_us();
|
||||||
|
|
||||||
|
if (token == llama_token_eos()) {
|
||||||
|
for (const auto & stack : grammar->stacks) {
|
||||||
|
if (stack.empty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
LLAMA_ASSERT(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
const char * str = llama_token_to_str(ctx, token);
|
||||||
|
// Note terminating 0 in decoded string
|
||||||
|
auto code_points = decode_utf8(str);
|
||||||
|
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
||||||
|
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
|
||||||
|
}
|
||||||
|
LLAMA_ASSERT(!grammar->stacks.empty());
|
||||||
|
|
||||||
|
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// quantization
|
// quantization
|
||||||
//
|
//
|
||||||
|
@ -2750,7 +3097,7 @@ struct llama_model * llama_load_model_from_file(
|
||||||
|
|
||||||
ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
||||||
|
|
||||||
if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gqa, params.n_gpu_layers,
|
if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gqa, params.rms_norm_eps, params.n_gpu_layers,
|
||||||
params.main_gpu, params.tensor_split, params.rope_freq_base, params.rope_freq_scale,params.low_vram,
|
params.main_gpu, params.tensor_split, params.rope_freq_base, params.rope_freq_scale,params.low_vram,
|
||||||
memory_type, params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback,
|
memory_type, params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback,
|
||||||
params.progress_callback_user_data)) {
|
params.progress_callback_user_data)) {
|
||||||
|
|
50
llama.h
50
llama.h
|
@ -87,6 +87,7 @@ extern "C" {
|
||||||
int32_t n_ctx; // text context
|
int32_t n_ctx; // text context
|
||||||
int32_t n_batch; // prompt processing batch size
|
int32_t n_batch; // prompt processing batch size
|
||||||
int32_t n_gqa; // grouped-query attention (TEMP - will be moved to model hparams)
|
int32_t n_gqa; // grouped-query attention (TEMP - will be moved to model hparams)
|
||||||
|
float rms_norm_eps; // rms norm epsilon (TEMP - will be moved to model hparams)
|
||||||
int32_t n_gpu_layers; // number of layers to store in VRAM
|
int32_t n_gpu_layers; // number of layers to store in VRAM
|
||||||
int32_t main_gpu; // the GPU that is used for scratch and small tensors
|
int32_t main_gpu; // the GPU that is used for scratch and small tensors
|
||||||
|
|
||||||
|
@ -141,6 +142,40 @@ extern "C" {
|
||||||
bool quantize_output_tensor; // quantize output.weight
|
bool quantize_output_tensor; // quantize output.weight
|
||||||
} llama_model_quantize_params;
|
} llama_model_quantize_params;
|
||||||
|
|
||||||
|
// grammar types
|
||||||
|
struct llama_grammar;
|
||||||
|
|
||||||
|
// grammar element type
|
||||||
|
enum llama_gretype {
|
||||||
|
// end of rule definition
|
||||||
|
LLAMA_GRETYPE_END = 0,
|
||||||
|
|
||||||
|
// start of alternate definition for rule
|
||||||
|
LLAMA_GRETYPE_ALT = 1,
|
||||||
|
|
||||||
|
// non-terminal element: reference to rule
|
||||||
|
LLAMA_GRETYPE_RULE_REF = 2,
|
||||||
|
|
||||||
|
// terminal element: character (code point)
|
||||||
|
LLAMA_GRETYPE_CHAR = 3,
|
||||||
|
|
||||||
|
// inverse char(s) ([^a], [^a-b] [^abc])
|
||||||
|
LLAMA_GRETYPE_CHAR_NOT = 4,
|
||||||
|
|
||||||
|
// modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to
|
||||||
|
// be an inclusive range ([a-z])
|
||||||
|
LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
|
||||||
|
|
||||||
|
// modifies a preceding LLAMA_GRETYPE_CHAR or
|
||||||
|
// LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA])
|
||||||
|
LLAMA_GRETYPE_CHAR_ALT = 6,
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef struct llama_grammar_element {
|
||||||
|
enum llama_gretype type;
|
||||||
|
uint32_t value; // Unicode code point or rule ID
|
||||||
|
} llama_grammar_element;
|
||||||
|
|
||||||
// performance timing information
|
// performance timing information
|
||||||
struct llama_timings {
|
struct llama_timings {
|
||||||
double t_start_ms;
|
double t_start_ms;
|
||||||
|
@ -333,6 +368,15 @@ extern "C" {
|
||||||
LLAMA_API llama_token llama_token_eos(); // end-of-sentence
|
LLAMA_API llama_token llama_token_eos(); // end-of-sentence
|
||||||
LLAMA_API llama_token llama_token_nl(); // next-line
|
LLAMA_API llama_token llama_token_nl(); // next-line
|
||||||
|
|
||||||
|
// Grammar
|
||||||
|
//
|
||||||
|
LLAMA_API struct llama_grammar * llama_grammar_init(
|
||||||
|
const llama_grammar_element ** rules,
|
||||||
|
size_t n_rules,
|
||||||
|
size_t start_rule_index);
|
||||||
|
|
||||||
|
LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
|
||||||
|
|
||||||
// Sampling functions
|
// Sampling functions
|
||||||
|
|
||||||
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
|
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
|
||||||
|
@ -367,6 +411,9 @@ extern "C" {
|
||||||
LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
|
LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
|
||||||
LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
|
LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
|
||||||
|
|
||||||
|
/// @details Apply constraints from grammar
|
||||||
|
LLAMA_API void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar);
|
||||||
|
|
||||||
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||||
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
||||||
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
||||||
|
@ -388,6 +435,9 @@ extern "C" {
|
||||||
/// @details Randomly selects a token from the candidates based on their probabilities.
|
/// @details Randomly selects a token from the candidates based on their probabilities.
|
||||||
LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
|
LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
|
||||||
|
|
||||||
|
/// @details Accepts the sampled token into the grammar
|
||||||
|
LLAMA_API void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token);
|
||||||
|
|
||||||
// Performance information
|
// Performance information
|
||||||
LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
|
LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
|
||||||
LLAMA_API void llama_print_timings(struct llama_context * ctx);
|
LLAMA_API void llama_print_timings(struct llama_context * ctx);
|
||||||
|
|
|
@ -64,7 +64,7 @@ void get_random_dims(int64_t * dims, int ndims) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * get_random_tensor(
|
struct ggml_tensor * get_random_tensor_f32(
|
||||||
struct ggml_context * ctx0,
|
struct ggml_context * ctx0,
|
||||||
int ndims,
|
int ndims,
|
||||||
int64_t ne[],
|
int64_t ne[],
|
||||||
|
@ -112,7 +112,55 @@ struct ggml_tensor * get_random_tensor(
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * get_random_tensor_int(
|
struct ggml_tensor * get_random_tensor_f16(
|
||||||
|
struct ggml_context * ctx0,
|
||||||
|
int ndims,
|
||||||
|
int64_t ne[],
|
||||||
|
float fmin,
|
||||||
|
float fmax) {
|
||||||
|
struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_F16, ndims, ne);
|
||||||
|
|
||||||
|
switch (ndims) {
|
||||||
|
case 1:
|
||||||
|
for (int i0 = 0; i0 < ne[0]; i0++) {
|
||||||
|
((ggml_fp16_t *)result->data)[i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
for (int i1 = 0; i1 < ne[1]; i1++) {
|
||||||
|
for (int i0 = 0; i0 < ne[0]; i0++) {
|
||||||
|
((ggml_fp16_t *)result->data)[i1*ne[0] + i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
for (int i2 = 0; i2 < ne[2]; i2++) {
|
||||||
|
for (int i1 = 0; i1 < ne[1]; i1++) {
|
||||||
|
for (int i0 = 0; i0 < ne[0]; i0++) {
|
||||||
|
((ggml_fp16_t *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
for (int i3 = 0; i3 < ne[3]; i3++) {
|
||||||
|
for (int i2 = 0; i2 < ne[2]; i2++) {
|
||||||
|
for (int i1 = 0; i1 < ne[1]; i1++) {
|
||||||
|
for (int i0 = 0; i0 < ne[0]; i0++) {
|
||||||
|
((ggml_fp16_t *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
assert(false);
|
||||||
|
};
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * get_random_tensor_i32(
|
||||||
struct ggml_context * ctx0,
|
struct ggml_context * ctx0,
|
||||||
int ndims,
|
int ndims,
|
||||||
int64_t ne[],
|
int64_t ne[],
|
||||||
|
@ -160,23 +208,6 @@ struct ggml_tensor * get_random_tensor_int(
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
float get_element(const struct ggml_tensor * t, int idx) {
|
|
||||||
if (t->type == GGML_TYPE_F32) {
|
|
||||||
return ((float *)t->data)[idx];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (t->type == GGML_TYPE_I32) {
|
|
||||||
return ((int32_t *)t->data)[idx];
|
|
||||||
}
|
|
||||||
|
|
||||||
assert(false);
|
|
||||||
return INFINITY;
|
|
||||||
}
|
|
||||||
|
|
||||||
void set_element(struct ggml_tensor * t, int idx, float value) {
|
|
||||||
((float *)t->data)[idx] = value;
|
|
||||||
}
|
|
||||||
|
|
||||||
void print_elements(const char* label, const struct ggml_tensor * t) {
|
void print_elements(const char* label, const struct ggml_tensor * t) {
|
||||||
if (!t) {
|
if (!t) {
|
||||||
printf("%s: %s = null\n", __func__, label);
|
printf("%s: %s = null\n", __func__, label);
|
||||||
|
@ -186,7 +217,7 @@ void print_elements(const char* label, const struct ggml_tensor * t) {
|
||||||
printf("%s: %s = [", __func__, label);
|
printf("%s: %s = [", __func__, label);
|
||||||
for (int k = 0; k < nelements; ++k) {
|
for (int k = 0; k < nelements; ++k) {
|
||||||
if (k > 0) { printf(", "); }
|
if (k > 0) { printf(", "); }
|
||||||
printf("%.5f", get_element(t, k));
|
printf("%.5f", ggml_get_f32_1d(t, k));
|
||||||
}
|
}
|
||||||
printf("] shape: [");
|
printf("] shape: [");
|
||||||
for (int k = 0; k < t->n_dims; ++k) {
|
for (int k = 0; k < t->n_dims; ++k) {
|
||||||
|
@ -237,23 +268,23 @@ bool check_gradient(
|
||||||
const int nelements = ggml_nelements(x[i]);
|
const int nelements = ggml_nelements(x[i]);
|
||||||
for (int k = 0; k < nelements; ++k) {
|
for (int k = 0; k < nelements; ++k) {
|
||||||
// compute gradient using finite differences
|
// compute gradient using finite differences
|
||||||
const float x0 = get_element(x[i], k);
|
const float x0 = ggml_get_f32_1d(x[i], k);
|
||||||
const float xm = x0 - eps;
|
const float xm = x0 - eps;
|
||||||
const float xp = x0 + eps;
|
const float xp = x0 + eps;
|
||||||
set_element(x[i], k, xp);
|
ggml_set_f32_1d(x[i], k, xp);
|
||||||
|
|
||||||
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
||||||
|
|
||||||
const float f0 = ggml_get_f32_1d(f, 0);
|
const float f0 = ggml_get_f32_1d(f, 0);
|
||||||
|
|
||||||
set_element(x[i], k, xm);
|
ggml_set_f32_1d(x[i], k, xm);
|
||||||
|
|
||||||
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
ggml_graph_compute_with_ctx(ctx0, &gf, n_threads);
|
||||||
|
|
||||||
const float f1 = ggml_get_f32_1d(f, 0);
|
const float f1 = ggml_get_f32_1d(f, 0);
|
||||||
const float g0 = (f0 - f1)/(2.0f*eps);
|
const float g0 = (f0 - f1)/(2.0f*eps);
|
||||||
|
|
||||||
set_element(x[i], k, x0);
|
ggml_set_f32_1d(x[i], k, x0);
|
||||||
|
|
||||||
// compute gradient using backward graph
|
// compute gradient using backward graph
|
||||||
ggml_graph_reset (&gf);
|
ggml_graph_reset (&gf);
|
||||||
|
@ -261,7 +292,7 @@ bool check_gradient(
|
||||||
|
|
||||||
ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
|
ggml_graph_compute_with_ctx(ctx0, &gb, n_threads);
|
||||||
|
|
||||||
const float g1 = get_element(x[i]->grad, k);
|
const float g1 = ggml_get_f32_1d(x[i]->grad, k);
|
||||||
|
|
||||||
const float error_abs = fabsf(g0 - g1);
|
const float error_abs = fabsf(g0 - g1);
|
||||||
const float error_rel = g0 != 0 ? fabsf(g0 - g1)/fabsf(g0) : 0;
|
const float error_rel = g0 != 0 ? fabsf(g0 - g1)/fabsf(g0) : 0;
|
||||||
|
@ -392,19 +423,35 @@ int main(int argc, const char ** argv) {
|
||||||
|
|
||||||
struct ggml_tensor * x[MAX_NARGS];
|
struct ggml_tensor * x[MAX_NARGS];
|
||||||
|
|
||||||
// add
|
// add f32
|
||||||
{
|
{
|
||||||
const int nargs = 2;
|
const int nargs = 2;
|
||||||
|
|
||||||
for (int ndims = 1; ndims <= 4; ++ndims) {
|
for (int ndims = 1; ndims <= 4; ++ndims) {
|
||||||
for (int i = 0; i < nargs; ++i) {
|
for (int i = 0; i < nargs; ++i) {
|
||||||
x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
|
x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
ggml_set_param(ctx0, x[i]);
|
ggml_set_param(ctx0, x[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1]));
|
struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1]));
|
||||||
|
|
||||||
check_gradient("add", ctx0, x, f, ndims, nargs, 1e-3f, 2e-3f, 2e-3f);
|
check_gradient("add f32", ctx0, x, f, ndims, nargs, 1e-3f, 2e-3f, 2e-3f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// add f16
|
||||||
|
{
|
||||||
|
const int nargs = 2;
|
||||||
|
|
||||||
|
for (int ndims = 1; ndims <= 4; ++ndims) {
|
||||||
|
for (int i = 0; i < nargs; ++i) {
|
||||||
|
x[i] = get_random_tensor_f16(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
|
ggml_set_param(ctx0, x[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1]));
|
||||||
|
|
||||||
|
check_gradient("add f16", ctx0, x, f, ndims, nargs, 1e-1f, 2e-1f, 2e-1f);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -414,7 +461,7 @@ int main(int argc, const char ** argv) {
|
||||||
|
|
||||||
for (int ndims = 1; ndims <= 4; ++ndims) {
|
for (int ndims = 1; ndims <= 4; ++ndims) {
|
||||||
for (int i = 0; i < nargs; ++i) {
|
for (int i = 0; i < nargs; ++i) {
|
||||||
x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
|
x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
ggml_set_param(ctx0, x[i]);
|
ggml_set_param(ctx0, x[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -430,7 +477,7 @@ int main(int argc, const char ** argv) {
|
||||||
|
|
||||||
for (int ndims = 1; ndims <= 4; ++ndims) {
|
for (int ndims = 1; ndims <= 4; ++ndims) {
|
||||||
for (int i = 0; i < nargs; ++i) {
|
for (int i = 0; i < nargs; ++i) {
|
||||||
x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
|
x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
ggml_set_param(ctx0, x[i]);
|
ggml_set_param(ctx0, x[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -446,7 +493,7 @@ int main(int argc, const char ** argv) {
|
||||||
|
|
||||||
for (int ndims = 1; ndims <= 4; ++ndims) {
|
for (int ndims = 1; ndims <= 4; ++ndims) {
|
||||||
for (int i = 0; i < nargs; ++i) {
|
for (int i = 0; i < nargs; ++i) {
|
||||||
x[i] = get_random_tensor(ctx0, ndims, ne, 0.5f, 1.0f);
|
x[i] = get_random_tensor_f32(ctx0, ndims, ne, 0.5f, 1.0f);
|
||||||
ggml_set_param(ctx0, x[i]);
|
ggml_set_param(ctx0, x[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -462,7 +509,7 @@ int main(int argc, const char ** argv) {
|
||||||
|
|
||||||
for (int ndims = 1; ndims <= 2; ++ndims) {
|
for (int ndims = 1; ndims <= 2; ++ndims) {
|
||||||
for (int i = 0; i < nargs; ++i) {
|
for (int i = 0; i < nargs; ++i) {
|
||||||
x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
|
x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
ggml_set_param(ctx0, x[i]);
|
ggml_set_param(ctx0, x[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -478,7 +525,7 @@ int main(int argc, const char ** argv) {
|
||||||
|
|
||||||
for (int ndims = 1; ndims <= 2; ++ndims) {
|
for (int ndims = 1; ndims <= 2; ++ndims) {
|
||||||
for (int i = 0; i < nargs; ++i) {
|
for (int i = 0; i < nargs; ++i) {
|
||||||
x[i] = get_random_tensor(ctx0, ndims, ne, 2.0f*1e-3f, 1.0f);
|
x[i] = get_random_tensor_f32(ctx0, ndims, ne, 2.0f*1e-3f, 1.0f);
|
||||||
ggml_set_param(ctx0, x[i]);
|
ggml_set_param(ctx0, x[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -494,7 +541,7 @@ int main(int argc, const char ** argv) {
|
||||||
|
|
||||||
for (int ndims = 1; ndims <= 2; ++ndims) {
|
for (int ndims = 1; ndims <= 2; ++ndims) {
|
||||||
for (int i = 0; i < nargs; ++i) {
|
for (int i = 0; i < nargs; ++i) {
|
||||||
x[i] = get_random_tensor(ctx0, ndims, ne, 2.0f*1e-3f, 1.0f);
|
x[i] = get_random_tensor_f32(ctx0, ndims, ne, 2.0f*1e-3f, 1.0f);
|
||||||
ggml_set_param(ctx0, x[i]);
|
ggml_set_param(ctx0, x[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -510,7 +557,7 @@ int main(int argc, const char ** argv) {
|
||||||
|
|
||||||
for (int ndims = 1; ndims <= 2; ++ndims) {
|
for (int ndims = 1; ndims <= 2; ++ndims) {
|
||||||
for (int i = 0; i < nargs; ++i) {
|
for (int i = 0; i < nargs; ++i) {
|
||||||
x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
|
x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
ggml_set_param(ctx0, x[i]);
|
ggml_set_param(ctx0, x[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -527,7 +574,7 @@ int main(int argc, const char ** argv) {
|
||||||
|
|
||||||
for (int ndims = 1; ndims <= 4; ++ndims) {
|
for (int ndims = 1; ndims <= 4; ++ndims) {
|
||||||
for (int i = 0; i < nargs; ++i) {
|
for (int i = 0; i < nargs; ++i) {
|
||||||
x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
|
x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
ggml_set_param(ctx0, x[i]);
|
ggml_set_param(ctx0, x[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -537,6 +584,40 @@ int main(int argc, const char ** argv) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mean, not yet fully implemented
|
||||||
|
if(0)
|
||||||
|
{
|
||||||
|
const int nargs = 1;
|
||||||
|
|
||||||
|
for (int ndims = 1; ndims <= 4; ++ndims) {
|
||||||
|
for (int i = 0; i < nargs; ++i) {
|
||||||
|
x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
|
ggml_set_param(ctx0, x[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * f = ggml_sum(ctx0, ggml_mean(ctx0, x[0]));
|
||||||
|
|
||||||
|
check_gradient("mean", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// argmax
|
||||||
|
if (0)
|
||||||
|
{
|
||||||
|
const int nargs = 1;
|
||||||
|
|
||||||
|
for (int ndims = 1; ndims <= 4; ++ndims) {
|
||||||
|
for (int i = 0; i < nargs; ++i) {
|
||||||
|
x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
|
ggml_set_param(ctx0, x[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * f = ggml_sum(ctx0, ggml_argmax(ctx0, x[0]));
|
||||||
|
|
||||||
|
check_gradient("argmax", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// repeat
|
// repeat
|
||||||
{
|
{
|
||||||
int64_t ne2[4];
|
int64_t ne2[4];
|
||||||
|
@ -549,15 +630,36 @@ int main(int argc, const char ** argv) {
|
||||||
|
|
||||||
const int nargs = 1;
|
const int nargs = 1;
|
||||||
for (int ndims = 1; ndims <= 2; ++ndims) {
|
for (int ndims = 1; ndims <= 2; ++ndims) {
|
||||||
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
|
x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
x[1] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
|
x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
|
||||||
ggml_set_param(ctx0, x[0]);
|
ggml_set_param(ctx0, x[0]);
|
||||||
|
|
||||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[1], ggml_repeat(ctx0, x[0], x[1]))));
|
struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[1], ggml_repeat(ctx0, x[0], x[1]))));
|
||||||
|
|
||||||
check_gradient("repeat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY);
|
check_gradient("repeat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// repeat back
|
||||||
|
{
|
||||||
|
int64_t ne2[4];
|
||||||
|
get_random_dims(ne2, 4);
|
||||||
|
|
||||||
|
ne2[0] = ne[0] * ne2[0];
|
||||||
|
ne2[1] = ne[1] * ne2[1];
|
||||||
|
ne2[2] = 1;
|
||||||
|
ne2[3] = 1;
|
||||||
|
|
||||||
|
const int nargs = 1;
|
||||||
|
for (int ndims = 1; ndims <= 2; ++ndims) {
|
||||||
|
x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
|
x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
|
||||||
|
ggml_set_param(ctx0, x[0]);
|
||||||
|
|
||||||
|
struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[0], ggml_repeat_back(ctx0, x[1], x[0]))));
|
||||||
|
|
||||||
|
check_gradient("repeat back", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// abs (finite differences do not work)
|
// abs (finite differences do not work)
|
||||||
|
@ -566,7 +668,7 @@ int main(int argc, const char ** argv) {
|
||||||
|
|
||||||
// for (int ndims = 1; ndims <= 2; ++ndims) {
|
// for (int ndims = 1; ndims <= 2; ++ndims) {
|
||||||
// for (int i = 0; i < nargs; ++i) {
|
// for (int i = 0; i < nargs; ++i) {
|
||||||
// x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
|
// x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
// ggml_set_param(ctx0, x[i]);
|
// ggml_set_param(ctx0, x[i]);
|
||||||
// }
|
// }
|
||||||
|
|
||||||
|
@ -576,17 +678,82 @@ int main(int argc, const char ** argv) {
|
||||||
// }
|
// }
|
||||||
//}
|
//}
|
||||||
|
|
||||||
|
// sgn
|
||||||
|
{
|
||||||
|
const int nargs = 1;
|
||||||
|
|
||||||
|
for (int ndims = 1; ndims <= 4; ++ndims) {
|
||||||
|
for (int i = 0; i < nargs; ++i) {
|
||||||
|
x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
|
ggml_set_param(ctx0, x[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* f = ggml_sum(ctx0, ggml_sgn(ctx0, x[0]));
|
||||||
|
|
||||||
|
check_gradient("sgn", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// neg
|
||||||
|
{
|
||||||
|
const int nargs = 1;
|
||||||
|
|
||||||
|
for (int ndims = 1; ndims <= 4; ++ndims) {
|
||||||
|
for (int i = 0; i < nargs; ++i) {
|
||||||
|
x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
|
ggml_set_param(ctx0, x[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* f = ggml_sum(ctx0, ggml_neg(ctx0, x[0]));
|
||||||
|
|
||||||
|
check_gradient("neg", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// step
|
||||||
|
{
|
||||||
|
const int nargs = 1;
|
||||||
|
|
||||||
|
for (int ndims = 1; ndims <= 4; ++ndims) {
|
||||||
|
for (int i = 0; i < nargs; ++i) {
|
||||||
|
x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
|
ggml_set_param(ctx0, x[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* f = ggml_sum(ctx0, ggml_step(ctx0, x[0]));
|
||||||
|
|
||||||
|
check_gradient("step", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// tanh, not yet fully implemented
|
||||||
|
if(0)
|
||||||
|
{
|
||||||
|
const int nargs = 1;
|
||||||
|
|
||||||
|
for (int ndims = 1; ndims <= 4; ++ndims) {
|
||||||
|
for (int i = 0; i < nargs; ++i) {
|
||||||
|
x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
|
ggml_set_param(ctx0, x[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* f = ggml_sum(ctx0, ggml_tanh(ctx0, x[0]));
|
||||||
|
|
||||||
|
check_gradient("tanh", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// mul_mat
|
// mul_mat
|
||||||
{
|
{
|
||||||
const int nargs = 2;
|
const int nargs = 2;
|
||||||
|
|
||||||
for (int ndims = 2; ndims <= 2; ++ndims) {
|
for (int ndims = 2; ndims <= 2; ++ndims) {
|
||||||
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
|
x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
{
|
{
|
||||||
int64_t ne2[4];
|
int64_t ne2[4];
|
||||||
get_random_dims(ne2, 4);
|
get_random_dims(ne2, 4);
|
||||||
ne2[0] = ne[0];
|
ne2[0] = ne[0];
|
||||||
x[1] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
|
x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_set_param(ctx0, x[0]);
|
ggml_set_param(ctx0, x[0]);
|
||||||
|
@ -602,13 +769,63 @@ int main(int argc, const char ** argv) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// elu, not yet fully implemented
|
||||||
|
if(0)
|
||||||
|
{
|
||||||
|
const int nargs = 1;
|
||||||
|
|
||||||
|
for (int ndims = 1; ndims <= 4; ++ndims) {
|
||||||
|
for (int i = 0; i < nargs; ++i) {
|
||||||
|
x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
|
ggml_set_param(ctx0, x[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* f = ggml_sum(ctx0, ggml_elu(ctx0, x[0]));
|
||||||
|
|
||||||
|
check_gradient("elu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// relu
|
||||||
|
{
|
||||||
|
const int nargs = 1;
|
||||||
|
|
||||||
|
for (int ndims = 1; ndims <= 4; ++ndims) {
|
||||||
|
for (int i = 0; i < nargs; ++i) {
|
||||||
|
x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
|
ggml_set_param(ctx0, x[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* f = ggml_sum(ctx0, ggml_relu(ctx0, x[0]));
|
||||||
|
|
||||||
|
check_gradient("relu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// gelu, not yet fully implemented
|
||||||
|
if(0)
|
||||||
|
{
|
||||||
|
const int nargs = 1;
|
||||||
|
|
||||||
|
for (int ndims = 1; ndims <= 4; ++ndims) {
|
||||||
|
for (int i = 0; i < nargs; ++i) {
|
||||||
|
x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
|
ggml_set_param(ctx0, x[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* f = ggml_sum(ctx0, ggml_gelu(ctx0, x[0]));
|
||||||
|
|
||||||
|
check_gradient("gelu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// silu
|
// silu
|
||||||
{
|
{
|
||||||
const int nargs = 1;
|
const int nargs = 1;
|
||||||
|
|
||||||
for (int ndims = 1; ndims <= 2; ++ndims) {
|
for (int ndims = 1; ndims <= 2; ++ndims) {
|
||||||
for (int i = 0; i < nargs; ++i) {
|
for (int i = 0; i < nargs; ++i) {
|
||||||
x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
|
x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
ggml_set_param(ctx0, x[i]);
|
ggml_set_param(ctx0, x[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -629,11 +846,11 @@ int main(int argc, const char ** argv) {
|
||||||
|
|
||||||
for (int ndims = 1; ndims <= 2; ++ndims) {
|
for (int ndims = 1; ndims <= 2; ++ndims) {
|
||||||
for (int i = 0; i < nargs; ++i) {
|
for (int i = 0; i < nargs; ++i) {
|
||||||
x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
|
x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
ggml_set_param(ctx0, x[i]);
|
ggml_set_param(ctx0, x[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_rms_norm(ctx0, x[0]));
|
struct ggml_tensor * f = ggml_sum(ctx0, ggml_rms_norm(ctx0, x[0], 1e-6f));
|
||||||
|
|
||||||
check_gradient("rms_norm", ctx0, x, f, ndims, nargs, 1e-4f, 1.0f, INFINITY);
|
check_gradient("rms_norm", ctx0, x, f, ndims, nargs, 1e-4f, 1.0f, INFINITY);
|
||||||
}
|
}
|
||||||
|
@ -647,8 +864,8 @@ int main(int argc, const char ** argv) {
|
||||||
ne2[0] = 1;
|
ne2[0] = 1;
|
||||||
|
|
||||||
for (int ndims = 1; ndims <= 2; ++ndims) {
|
for (int ndims = 1; ndims <= 2; ++ndims) {
|
||||||
x[1] = get_random_tensor(ctx0, 1, ne2, -1.0f, 1.0f);
|
x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
|
||||||
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
|
x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
|
|
||||||
ggml_set_param(ctx0, x[0]);
|
ggml_set_param(ctx0, x[0]);
|
||||||
ggml_set_param(ctx0, x[1]);
|
ggml_set_param(ctx0, x[1]);
|
||||||
|
@ -659,20 +876,37 @@ int main(int argc, const char ** argv) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// cpy
|
// cpy f32
|
||||||
{
|
{
|
||||||
const int nargs = 2;
|
const int nargs = 2;
|
||||||
|
|
||||||
for (int ndims = 1; ndims <= 2; ++ndims) {
|
for (int ndims = 1; ndims <= 2; ++ndims) {
|
||||||
for (int i = 0; i < nargs; ++i) {
|
for (int i = 0; i < nargs; ++i) {
|
||||||
x[i] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
|
x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
ggml_set_param(ctx0, x[i]);
|
ggml_set_param(ctx0, x[i]);
|
||||||
}
|
}
|
||||||
// x[1] is overwritten by x[0], so the gradients don't propagate to x[1]
|
// x[1] is overwritten by x[0], so the gradients don't propagate to x[1]
|
||||||
|
|
||||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_cpy(ctx0, x[0], x[1]));
|
struct ggml_tensor * f = ggml_sum(ctx0, ggml_cpy(ctx0, x[0], x[1]));
|
||||||
|
|
||||||
check_gradient("cpy", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
|
check_gradient("cpy f32", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cpy f16
|
||||||
|
{
|
||||||
|
const int nargs = 2;
|
||||||
|
|
||||||
|
for (int ndims = 1; ndims <= 2; ++ndims) {
|
||||||
|
for (int i = 0; i < nargs; ++i) {
|
||||||
|
x[i] = get_random_tensor_f16(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
|
ggml_set_param(ctx0, x[i]);
|
||||||
|
}
|
||||||
|
// x[1] is overwritten by x[0], so the gradients don't propagate to x[1]
|
||||||
|
|
||||||
|
struct ggml_tensor * f = ggml_sum(ctx0, ggml_cpy(ctx0, x[0], x[1]));
|
||||||
|
|
||||||
|
check_gradient("cpy f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -689,8 +923,8 @@ int main(int argc, const char ** argv) {
|
||||||
for (int i = 0; i < ndims; ++i) {
|
for (int i = 0; i < ndims; ++i) {
|
||||||
ne2[0] *= ne[i];
|
ne2[0] *= ne[i];
|
||||||
}
|
}
|
||||||
x[0] = get_random_tensor(ctx0, 1, ne2, -1.0f, 1.0f);
|
x[0] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
|
||||||
x[1] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
|
x[1] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
ggml_set_param(ctx0, x[0]);
|
ggml_set_param(ctx0, x[0]);
|
||||||
|
|
||||||
|
|
||||||
|
@ -712,8 +946,8 @@ int main(int argc, const char ** argv) {
|
||||||
for (int i = 0; i < ndims; ++i) {
|
for (int i = 0; i < ndims; ++i) {
|
||||||
ne2[0] *= ne[i];
|
ne2[0] *= ne[i];
|
||||||
}
|
}
|
||||||
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
|
x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
x[1] = get_random_tensor(ctx0, 1, ne2, -1.0f, 1.0f);
|
x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
|
||||||
ggml_set_param(ctx0, x[0]);
|
ggml_set_param(ctx0, x[0]);
|
||||||
|
|
||||||
|
|
||||||
|
@ -729,7 +963,7 @@ int main(int argc, const char ** argv) {
|
||||||
const int nargs = 2;
|
const int nargs = 2;
|
||||||
for (int ndims = 1; ndims <= 4; ++ndims) {
|
for (int ndims = 1; ndims <= 4; ++ndims) {
|
||||||
|
|
||||||
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
|
x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
ggml_set_param(ctx0, x[0]);
|
ggml_set_param(ctx0, x[0]);
|
||||||
|
|
||||||
get_random_dims(ne2, 1);
|
get_random_dims(ne2, 1);
|
||||||
|
@ -737,7 +971,7 @@ int main(int argc, const char ** argv) {
|
||||||
get_random_dims(ne2, 1);
|
get_random_dims(ne2, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
x[1] = get_random_tensor(ctx0, 1, ne2, -1.0f, 1.0f);
|
x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
|
||||||
ggml_set_param(ctx0, x[1]);
|
ggml_set_param(ctx0, x[1]);
|
||||||
|
|
||||||
const int max_offset = MAX(0, ggml_nelements(x[0]) - ggml_nelements(x[1]));
|
const int max_offset = MAX(0, ggml_nelements(x[0]) - ggml_nelements(x[1]));
|
||||||
|
@ -758,7 +992,7 @@ int main(int argc, const char ** argv) {
|
||||||
const int nargs = 2;
|
const int nargs = 2;
|
||||||
for (int ndims = 2; ndims <= 4; ++ndims) {
|
for (int ndims = 2; ndims <= 4; ++ndims) {
|
||||||
|
|
||||||
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
|
x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
ggml_set_param(ctx0, x[0]);
|
ggml_set_param(ctx0, x[0]);
|
||||||
|
|
||||||
get_random_dims(ne2, 2);
|
get_random_dims(ne2, 2);
|
||||||
|
@ -766,7 +1000,7 @@ int main(int argc, const char ** argv) {
|
||||||
get_random_dims(ne2, 2);
|
get_random_dims(ne2, 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
x[1] = get_random_tensor(ctx0, 2, ne2, -1.0f, 1.0f);
|
x[1] = get_random_tensor_f32(ctx0, 2, ne2, -1.0f, 1.0f);
|
||||||
ggml_set_param(ctx0, x[1]);
|
ggml_set_param(ctx0, x[1]);
|
||||||
|
|
||||||
max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
|
max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
|
||||||
|
@ -790,7 +1024,7 @@ int main(int argc, const char ** argv) {
|
||||||
const int nargs = 2;
|
const int nargs = 2;
|
||||||
for (int ndims = 3; ndims <= 4; ++ndims) {
|
for (int ndims = 3; ndims <= 4; ++ndims) {
|
||||||
|
|
||||||
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
|
x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
ggml_set_param(ctx0, x[0]);
|
ggml_set_param(ctx0, x[0]);
|
||||||
|
|
||||||
get_random_dims(ne2, 3);
|
get_random_dims(ne2, 3);
|
||||||
|
@ -798,7 +1032,7 @@ int main(int argc, const char ** argv) {
|
||||||
get_random_dims(ne2, 3);
|
get_random_dims(ne2, 3);
|
||||||
}
|
}
|
||||||
|
|
||||||
x[1] = get_random_tensor(ctx0, 3, ne2, -1.0f, 1.0f);
|
x[1] = get_random_tensor_f32(ctx0, 3, ne2, -1.0f, 1.0f);
|
||||||
ggml_set_param(ctx0, x[1]);
|
ggml_set_param(ctx0, x[1]);
|
||||||
|
|
||||||
max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
|
max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
|
||||||
|
@ -824,7 +1058,7 @@ int main(int argc, const char ** argv) {
|
||||||
const int nargs = 2;
|
const int nargs = 2;
|
||||||
for (int ndims = 4; ndims <= 4; ++ndims) {
|
for (int ndims = 4; ndims <= 4; ++ndims) {
|
||||||
|
|
||||||
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
|
x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
ggml_set_param(ctx0, x[0]);
|
ggml_set_param(ctx0, x[0]);
|
||||||
|
|
||||||
get_random_dims(ne2, 4);
|
get_random_dims(ne2, 4);
|
||||||
|
@ -832,7 +1066,7 @@ int main(int argc, const char ** argv) {
|
||||||
get_random_dims(ne2, 4);
|
get_random_dims(ne2, 4);
|
||||||
}
|
}
|
||||||
|
|
||||||
x[1] = get_random_tensor(ctx0, 4, ne2, -1.0f, 1.0f);
|
x[1] = get_random_tensor_f32(ctx0, 4, ne2, -1.0f, 1.0f);
|
||||||
ggml_set_param(ctx0, x[1]);
|
ggml_set_param(ctx0, x[1]);
|
||||||
|
|
||||||
max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
|
max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
|
||||||
|
@ -858,7 +1092,7 @@ int main(int argc, const char ** argv) {
|
||||||
const int nargs = 2;
|
const int nargs = 2;
|
||||||
for (int ndims = 1; ndims <= 4; ++ndims) {
|
for (int ndims = 1; ndims <= 4; ++ndims) {
|
||||||
|
|
||||||
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
|
x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
ggml_set_param(ctx0, x[0]);
|
ggml_set_param(ctx0, x[0]);
|
||||||
|
|
||||||
get_random_dims(ne2, 1);
|
get_random_dims(ne2, 1);
|
||||||
|
@ -866,7 +1100,7 @@ int main(int argc, const char ** argv) {
|
||||||
get_random_dims(ne2, 1);
|
get_random_dims(ne2, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
x[1] = get_random_tensor(ctx0, 1, ne2, -1.0f, 1.0f);
|
x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f);
|
||||||
ggml_set_param(ctx0, x[1]);
|
ggml_set_param(ctx0, x[1]);
|
||||||
|
|
||||||
const int max_offset = MAX(0, ggml_nelements(x[0]) - ggml_nelements(x[1]));
|
const int max_offset = MAX(0, ggml_nelements(x[0]) - ggml_nelements(x[1]));
|
||||||
|
@ -887,7 +1121,7 @@ int main(int argc, const char ** argv) {
|
||||||
const int nargs = 1;
|
const int nargs = 1;
|
||||||
for (int ndims = 2; ndims <= 4; ++ndims) {
|
for (int ndims = 2; ndims <= 4; ++ndims) {
|
||||||
|
|
||||||
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
|
x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
ggml_set_param(ctx0, x[0]);
|
ggml_set_param(ctx0, x[0]);
|
||||||
|
|
||||||
get_random_dims(ne2, 2);
|
get_random_dims(ne2, 2);
|
||||||
|
@ -895,7 +1129,7 @@ int main(int argc, const char ** argv) {
|
||||||
get_random_dims(ne2, 2);
|
get_random_dims(ne2, 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
x[1] = get_random_tensor(ctx0, 2, ne2, -1.0f, 1.0f);
|
x[1] = get_random_tensor_f32(ctx0, 2, ne2, -1.0f, 1.0f);
|
||||||
ggml_set_param(ctx0, x[1]);
|
ggml_set_param(ctx0, x[1]);
|
||||||
|
|
||||||
max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
|
max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]);
|
||||||
|
@ -915,7 +1149,7 @@ int main(int argc, const char ** argv) {
|
||||||
const int nargs = 1;
|
const int nargs = 1;
|
||||||
for (int ndims = 1; ndims <= 4; ++ndims) {
|
for (int ndims = 1; ndims <= 4; ++ndims) {
|
||||||
|
|
||||||
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
|
x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
|
|
||||||
ggml_set_param(ctx0, x[0]);
|
ggml_set_param(ctx0, x[0]);
|
||||||
|
|
||||||
|
@ -941,7 +1175,7 @@ int main(int argc, const char ** argv) {
|
||||||
const int nargs = 1;
|
const int nargs = 1;
|
||||||
for (int ndims = 1; ndims <= 4; ++ndims) {
|
for (int ndims = 1; ndims <= 4; ++ndims) {
|
||||||
|
|
||||||
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
|
x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
|
|
||||||
get_random_dims(ne2, 2);
|
get_random_dims(ne2, 2);
|
||||||
while (ne2[0]*ne2[1] > ggml_nelements(x[0])) {
|
while (ne2[0]*ne2[1] > ggml_nelements(x[0])) {
|
||||||
|
@ -971,7 +1205,7 @@ int main(int argc, const char ** argv) {
|
||||||
const int nargs = 1;
|
const int nargs = 1;
|
||||||
for (int ndims = 1; ndims <= 4; ++ndims) {
|
for (int ndims = 1; ndims <= 4; ++ndims) {
|
||||||
|
|
||||||
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
|
x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
|
|
||||||
get_random_dims(ne2, 3);
|
get_random_dims(ne2, 3);
|
||||||
while (ne2[0]*ne2[1]*ne2[2] > ggml_nelements(x[0])) {
|
while (ne2[0]*ne2[1]*ne2[2] > ggml_nelements(x[0])) {
|
||||||
|
@ -1010,7 +1244,7 @@ int main(int argc, const char ** argv) {
|
||||||
for (int i=ndims; i<4; ++i) {
|
for (int i=ndims; i<4; ++i) {
|
||||||
ne2[i] = 1;
|
ne2[i] = 1;
|
||||||
}
|
}
|
||||||
x[0] = get_random_tensor(ctx0, 4, ne2, -1.0f, 1.0f);
|
x[0] = get_random_tensor_f32(ctx0, 4, ne2, -1.0f, 1.0f);
|
||||||
|
|
||||||
ggml_set_param(ctx0, x[0]);
|
ggml_set_param(ctx0, x[0]);
|
||||||
|
|
||||||
|
@ -1043,7 +1277,7 @@ int main(int argc, const char ** argv) {
|
||||||
for (int i=ndims; i<4; ++i) {
|
for (int i=ndims; i<4; ++i) {
|
||||||
ne2[i] = 1;
|
ne2[i] = 1;
|
||||||
}
|
}
|
||||||
x[0] = get_random_tensor(ctx0, 4, ne2, -1.0f, 1.0f);
|
x[0] = get_random_tensor_f32(ctx0, 4, ne2, -1.0f, 1.0f);
|
||||||
|
|
||||||
ggml_set_param(ctx0, x[0]);
|
ggml_set_param(ctx0, x[0]);
|
||||||
|
|
||||||
|
@ -1060,8 +1294,8 @@ int main(int argc, const char ** argv) {
|
||||||
int64_t ne3[4] = {1+irand(ne[1]), 1, 1, 1};
|
int64_t ne3[4] = {1+irand(ne[1]), 1, 1, 1};
|
||||||
const int nargs = 1;
|
const int nargs = 1;
|
||||||
const int ndims = 2;
|
const int ndims = 2;
|
||||||
x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
|
x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
|
||||||
x[1] = get_random_tensor_int(ctx0, 1, ne3, 0, ne2[1]);
|
x[1] = get_random_tensor_i32(ctx0, 1, ne3, 0, ne2[1]);
|
||||||
|
|
||||||
ggml_set_param(ctx0, x[0]);
|
ggml_set_param(ctx0, x[0]);
|
||||||
|
|
||||||
|
@ -1075,7 +1309,7 @@ int main(int argc, const char ** argv) {
|
||||||
const int nargs = 1;
|
const int nargs = 1;
|
||||||
const int ndims = 2;
|
const int ndims = 2;
|
||||||
|
|
||||||
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
|
x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
ggml_set_param(ctx0, x[0]);
|
ggml_set_param(ctx0, x[0]);
|
||||||
|
|
||||||
int n_past = irand(ne[0]);
|
int n_past = irand(ne[0]);
|
||||||
|
@ -1090,7 +1324,7 @@ int main(int argc, const char ** argv) {
|
||||||
const int nargs = 1;
|
const int nargs = 1;
|
||||||
const int ndims = 2;
|
const int ndims = 2;
|
||||||
|
|
||||||
x[0] = get_random_tensor(ctx0, ndims, ne, -1.0f, 1.0f);
|
x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
|
||||||
ggml_set_param(ctx0, x[0]);
|
ggml_set_param(ctx0, x[0]);
|
||||||
|
|
||||||
int n_past = irand(ne[0]);
|
int n_past = irand(ne[0]);
|
||||||
|
@ -1108,7 +1342,7 @@ int main(int argc, const char ** argv) {
|
||||||
get_random_dims(ne2, 4);
|
get_random_dims(ne2, 4);
|
||||||
|
|
||||||
for (int ndims = 1; ndims <= 3; ++ndims) {
|
for (int ndims = 1; ndims <= 3; ++ndims) {
|
||||||
x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
|
x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
|
||||||
ggml_set_param(ctx0, x[0]);
|
ggml_set_param(ctx0, x[0]);
|
||||||
|
|
||||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_soft_max(ctx0, x[0]));
|
struct ggml_tensor * f = ggml_sum(ctx0, ggml_soft_max(ctx0, x[0]));
|
||||||
|
@ -1125,8 +1359,8 @@ int main(int argc, const char ** argv) {
|
||||||
get_random_dims(ne2, 4);
|
get_random_dims(ne2, 4);
|
||||||
|
|
||||||
for (int ndims = 1; ndims <= 3; ++ndims) {
|
for (int ndims = 1; ndims <= 3; ++ndims) {
|
||||||
x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
|
x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
|
||||||
x[1] = get_random_tensor(ctx0, ndims, ne2, 0.0f, 1.0f);
|
x[1] = get_random_tensor_f32(ctx0, ndims, ne2, 0.0f, 1.0f);
|
||||||
ggml_set_param(ctx0, x[0]);
|
ggml_set_param(ctx0, x[0]);
|
||||||
|
|
||||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_cross_entropy_loss(ctx0, x[0], x[1]));
|
struct ggml_tensor * f = ggml_sum(ctx0, ggml_cross_entropy_loss(ctx0, x[0], x[1]));
|
||||||
|
@ -1136,7 +1370,7 @@ int main(int argc, const char ** argv) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// rope
|
// rope f32
|
||||||
{
|
{
|
||||||
const int nargs = 1;
|
const int nargs = 1;
|
||||||
|
|
||||||
|
@ -1148,7 +1382,7 @@ int main(int argc, const char ** argv) {
|
||||||
for (int ndims = 3; ndims <= 4; ++ndims) {
|
for (int ndims = 3; ndims <= 4; ++ndims) {
|
||||||
for (int mode = 0; mode < 4; ++mode) {
|
for (int mode = 0; mode < 4; ++mode) {
|
||||||
for (int n_past = 1; n_past < ne2[2]; ++n_past) {
|
for (int n_past = 1; n_past < ne2[2]; ++n_past) {
|
||||||
x[0] = get_random_tensor(ctx0, ndims, ne2, -1.0f, 1.0f);
|
x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f);
|
||||||
|
|
||||||
ggml_set_param(ctx0, x[0]);
|
ggml_set_param(ctx0, x[0]);
|
||||||
|
|
||||||
|
@ -1163,14 +1397,48 @@ int main(int argc, const char ** argv) {
|
||||||
|
|
||||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], n_past, n_rot, mode, 0));
|
struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], n_past, n_rot, mode, 0));
|
||||||
|
|
||||||
GGML_PRINT_DEBUG("rope: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
|
GGML_PRINT_DEBUG("rope f32: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
|
||||||
check_gradient("rope", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY);
|
check_gradient("rope f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// flash_attn
|
// rope f16
|
||||||
|
{
|
||||||
|
const int nargs = 1;
|
||||||
|
|
||||||
|
int64_t ne2[4];
|
||||||
|
get_random_dims(ne2, 4);
|
||||||
|
ne2[0] += ne2[0] % 2;
|
||||||
|
int n_rot = ne2[0];
|
||||||
|
|
||||||
|
for (int ndims = 3; ndims <= 4; ++ndims) {
|
||||||
|
for (int mode = 0; mode < 4; ++mode) {
|
||||||
|
for (int n_past = 1; n_past < ne2[2]; ++n_past) {
|
||||||
|
x[0] = get_random_tensor_f16(ctx0, ndims, ne2, -1.0f, 1.0f);
|
||||||
|
|
||||||
|
ggml_set_param(ctx0, x[0]);
|
||||||
|
|
||||||
|
const bool skip_past = (mode & 1);
|
||||||
|
if (skip_past) {
|
||||||
|
// we have no past, so this would have to work on uninitialized memory.
|
||||||
|
// we only test the gradients here;
|
||||||
|
// skip_past should have no influence on gradient computation.
|
||||||
|
// so when other modes work, we assume that this does as well.
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], n_past, n_rot, mode, 0));
|
||||||
|
|
||||||
|
GGML_PRINT_DEBUG("rope f16: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode);
|
||||||
|
check_gradient("rope f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// flash_attn f32
|
||||||
{
|
{
|
||||||
const int nargs = 3;
|
const int nargs = 3;
|
||||||
|
|
||||||
|
@ -1196,16 +1464,57 @@ int main(int argc, const char ** argv) {
|
||||||
nek[3] = 1;
|
nek[3] = 1;
|
||||||
nev[3] = 1;
|
nev[3] = 1;
|
||||||
}
|
}
|
||||||
x[0] = get_random_tensor(ctx0, ndims, neq, -0.1250f, 0.1250f);
|
x[0] = get_random_tensor_f32(ctx0, ndims, neq, -0.1250f, 0.1250f);
|
||||||
x[1] = get_random_tensor(ctx0, ndims, nek, -0.1250f, 0.1250f);
|
x[1] = get_random_tensor_f32(ctx0, ndims, nek, -0.1250f, 0.1250f);
|
||||||
x[2] = get_random_tensor(ctx0, ndims, nev, -0.1250f, 0.1250f);
|
x[2] = get_random_tensor_f32(ctx0, ndims, nev, -0.1250f, 0.1250f);
|
||||||
ggml_set_param(ctx0, x[0]);
|
ggml_set_param(ctx0, x[0]);
|
||||||
ggml_set_param(ctx0, x[1]);
|
ggml_set_param(ctx0, x[1]);
|
||||||
ggml_set_param(ctx0, x[2]);
|
ggml_set_param(ctx0, x[2]);
|
||||||
|
|
||||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
|
struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
|
||||||
|
|
||||||
check_gradient("flash_attn", ctx0, x, f, ndims, nargs, 1.5e-4f, INFINITY, 3.5f);
|
check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, INFINITY, 3.5f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// flash_attn f16, not yet fully implemented
|
||||||
|
if(0)
|
||||||
|
{
|
||||||
|
const int nargs = 3;
|
||||||
|
|
||||||
|
int64_t ne2[4];
|
||||||
|
|
||||||
|
get_random_dims(ne2, 4);
|
||||||
|
int64_t D = ne2[0];
|
||||||
|
int64_t N = ne2[1];
|
||||||
|
int64_t M = ne2[2] + N;
|
||||||
|
int64_t B = ne2[3];
|
||||||
|
|
||||||
|
for (int masked = 0; masked <= 1; ++masked) {
|
||||||
|
for (int ndims = 2; ndims <= 4; ++ndims) {
|
||||||
|
int64_t neq[4] = { D, N, B, ne[3] };
|
||||||
|
int64_t nek[4] = { D, M, B, ne[3] };
|
||||||
|
int64_t nev[4] = { M, D, B, ne[3] };
|
||||||
|
if (ndims == 2) {
|
||||||
|
neq[2] = 1; neq[3] = 1;
|
||||||
|
nek[2] = 1; nek[3] = 1;
|
||||||
|
nev[2] = 1; nev[3] = 1;
|
||||||
|
} else if (ndims == 3) {
|
||||||
|
neq[3] = 1;
|
||||||
|
nek[3] = 1;
|
||||||
|
nev[3] = 1;
|
||||||
|
}
|
||||||
|
x[0] = get_random_tensor_f16(ctx0, ndims, neq, -0.1250f, 0.1250f);
|
||||||
|
x[1] = get_random_tensor_f16(ctx0, ndims, nek, -0.1250f, 0.1250f);
|
||||||
|
x[2] = get_random_tensor_f16(ctx0, ndims, nev, -0.1250f, 0.1250f);
|
||||||
|
ggml_set_param(ctx0, x[0]);
|
||||||
|
ggml_set_param(ctx0, x[1]);
|
||||||
|
ggml_set_param(ctx0, x[2]);
|
||||||
|
|
||||||
|
struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
|
||||||
|
|
||||||
|
check_gradient("flash_attn f16", ctx0, x, f, ndims, nargs, 1.5e-4f, INFINITY, 3.5f);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -125,9 +125,9 @@ int main(void) {
|
||||||
};
|
};
|
||||||
struct ggml_context * ctx = ggml_init(params);
|
struct ggml_context * ctx = ggml_init(params);
|
||||||
|
|
||||||
int64_t ne1[4] = {4, 1024, 1, 1};
|
int64_t ne1[4] = {4, 128, 1, 1};
|
||||||
int64_t ne2[4] = {4, 2048, 1, 1};;
|
int64_t ne2[4] = {4, 256, 1, 1};;
|
||||||
int64_t ne3[4] = {1024, 2048, 1, 1};
|
int64_t ne3[4] = {128, 256, 1, 1};
|
||||||
|
|
||||||
struct ggml_tensor * a = get_random_tensor(ctx, 2, ne1, -1, +1);
|
struct ggml_tensor * a = get_random_tensor(ctx, 2, ne1, -1, +1);
|
||||||
struct ggml_tensor * b = get_random_tensor(ctx, 2, ne2, -1, +1);
|
struct ggml_tensor * b = get_random_tensor(ctx, 2, ne2, -1, +1);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue