Merge remote-tracking branch 'origin/master' into sl/mmid-cpu-perf

This commit is contained in:
slaren 2025-02-09 16:27:24 +01:00
commit 2d493d26ab
93 changed files with 9479 additions and 3868 deletions

View file

@@ -1674,21 +1674,28 @@ struct test_silu_back : public test_case {
struct test_norm : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
float eps;
const bool v; // whether a is a non-contiguous view
const float eps;
std::string vars() override {
return VARS_TO_STR3(type, ne, eps);
return VARS_TO_STR4(type, ne, v, eps);
}
test_norm(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {64, 5, 4, 3},
bool v = false,
float eps = 1e-6f)
: type(type), ne(ne), eps(eps) {}
: type(type), ne(ne), v(v), eps(eps) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(a, "a");
if (v) {
a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0);
ggml_set_name(a, "view of a");
}
ggml_tensor * out = ggml_norm(ctx, a, eps);
ggml_set_name(out, "out");
@@ -1700,22 +1707,29 @@ struct test_norm : public test_case {
struct test_rms_norm : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
float eps;
const bool v; // whether a is a non-contiguous view
const float eps;
std::string vars() override {
return VARS_TO_STR3(type, ne, eps);
return VARS_TO_STR4(type, ne, v, eps);
}
test_rms_norm(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {64, 5, 4, 3},
bool v = false,
float eps = 1e-6f)
: type(type), ne(ne), eps(eps) {}
: type(type), ne(ne), v(v), eps(eps) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(ctx, a);
ggml_set_name(a, "a");
if (v) {
a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0);
ggml_set_name(a, "view of a");
}
ggml_tensor * out = ggml_rms_norm(ctx, a, eps);
ggml_set_name(out, "out");
@@ -1741,7 +1755,7 @@ struct test_rms_norm : public test_case {
struct test_rms_norm_back : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
float eps;
const float eps;
std::string vars() override {
return VARS_TO_STR3(type, ne, eps);
@@ -2919,7 +2933,7 @@ struct test_group_norm : public test_case {
const float eps;
std::string vars() override {
return VARS_TO_STR3(type, ne, num_groups);
return VARS_TO_STR4(type, ne, num_groups, eps);
}
test_group_norm(ggml_type type = GGML_TYPE_F32,
@@ -3964,9 +3978,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_scale());
test_cases.emplace_back(new test_silu_back());
for (float eps : {0.0f, 1e-7f, 1e-4f, 1e-1f}) {
test_cases.emplace_back(new test_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
test_cases.emplace_back(new test_rms_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {
for (bool v : {false, true}) {
test_cases.emplace_back(new test_norm (GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
}
test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
}

View file

@@ -18,12 +18,8 @@
using json = nlohmann::ordered_json;
static common_chat_msg msg_from_json(const json & message) {
common_chat_msg ret{
"assistant",
"",
{},
/* .tool_plan = */ "",
};
common_chat_msg ret;
ret.role = "assistant";
if (message.contains("content") && !message.at("content").is_null()) {
ret.content = message.at("content");
}
@@ -289,7 +285,7 @@ static void test_template(const common_chat_template & tmpl, const std::vector<s
static void test_template_output_parsers() {
json text_message {
{ "role", "assistant" },
{ "content", "Hello, world!" },
{ "content", "Hello, world!\nWhat's up?" },
};
json tool_calls = json::array({{
{ "type", "function" },
@@ -379,7 +375,7 @@ static void test_template_output_parsers() {
common_chat_inputs inputs_no_tools;
inputs_no_tools.messages = {
{ { "role", "user" }, { "content", "Hey" } }
{ { "role", "user" }, { "content", "Hey\nThere" } }
};
common_chat_inputs inputs_tools = inputs_no_tools;
@@ -408,7 +404,8 @@ static void test_template_output_parsers() {
" {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
"]<|END_ACTION|>");
test_template(tmpl, end_tokens, text_message, tools,
"<|START_RESPONSE|>Hello, world!<|END_RESPONSE|>",
"<|START_RESPONSE|>Hello, world!\n"
"What's up?<|END_RESPONSE|>",
/* expect_grammar_triggered= */ false);
}
{
@@ -428,7 +425,7 @@ static void test_template_output_parsers() {
assert_msg_equals(msg_from_json(text_message),
common_chat_parse("{\n"
" \"response\": \"Hello, world!\"\n"
" \"response\": \"Hello, world!\\nWhat's up?\"\n"
"}",
common_chat_params_init(tmpl, inputs_tools).format));
test_template(tmpl, end_tokens, tool_call_message_with_id, tools,
@@ -451,7 +448,7 @@ static void test_template_output_parsers() {
assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_params_init(tmpl, inputs_tools).format);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
test_template(
tmpl, end_tokens, tool_call_message_with_id, tools,
"[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]");
@@ -476,7 +473,7 @@ static void test_template_output_parsers() {
inputs_tools)
.format);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, tool_call_message, tools,
"<tool_call>\n"
"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
@@ -516,7 +513,7 @@ static void test_template_output_parsers() {
assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_params_init(tmpl, inputs_tools).format);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, tool_call_message, tools,
"{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}");
}
@@ -528,7 +525,7 @@ static void test_template_output_parsers() {
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
common_chat_params_init(tmpl, inputs_tools).format);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, tool_call_message, tools,
"<function=special_function>{\"arg1\": 1}</function>");
}
@@ -542,7 +539,8 @@ static void test_template_output_parsers() {
test_template(tmpl, end_tokens, text_message, {},
"all\n"
"Hello, world!",
"Hello, world!\n"
"What's up?",
/* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, tool_call_message, tools,
"special_function\n"
@@ -555,7 +553,7 @@ static void test_template_output_parsers() {
assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_params_init(tmpl, inputs_tools).format);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, tool_call_message, tools,
" functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]");
}
@@ -566,7 +564,7 @@ static void test_template_output_parsers() {
assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_params_init(tmpl, inputs_tools).format);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!", /* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, text_message, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
test_template(tmpl, end_tokens, tool_call_message, tools,
"<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>special_function\n"
"```json\n"