Compare commits

..

1 commit

Author SHA1 Message Date
Xuan Son Nguyen
1e6d002c56 server : correct signal handler 2025-02-10 16:51:14 +01:00
30 changed files with 678 additions and 1324 deletions

View file

@ -249,30 +249,16 @@ class chat_template {
inputs.add_generation_prompt = false;
full = apply(inputs);
}
auto eos_pos_last = full.rfind(eos_token_);
if (eos_pos_last == prefix.size() - eos_token_.size() ||
(full[full.size() - 1] == '\n' && (eos_pos_last == full.size() - eos_token_.size() - 1))) {
full = full.substr(0, eos_pos_last);
}
size_t common_prefix_length = 0;
for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
if (prefix[i] != full[i]) {
break;
if (full.find(prefix) != 0) {
if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) {
prefix = prefix.substr(0, prefix.size() - eos_token_.size());
}
if (prefix[i] == '<') {
// DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
// but it removes thinking tags for past messages.
// The prefix and full strings diverge at <think> vs. <tool▁calls▁begin>, we avoid consuming the leading <.
continue;
}
common_prefix_length = i + 1;
}
auto example = full.substr(common_prefix_length);
if (example.find("tool_name") == std::string::npos && example.find("some_value") == std::string::npos) {
if (full.find(prefix) != 0) {
fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n");
} else {
tool_call_example_ = example;
}
tool_call_example_ = full.substr(prefix.size());
}
} catch (const std::exception & e) {
fprintf(stderr, "Failed to generate tool call example: %s\n", e.what());
@ -377,7 +363,7 @@ class chat_template {
if (polyfill_tools) {
adjusted_messages = add_system(inputs.messages,
"You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) +
(!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_ + "\n\n"));
(!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_));
} else {
adjusted_messages = inputs.messages;
}

View file

@ -2,7 +2,6 @@
#include "ggml.h" // for ggml_log_level
#define LOG_CLR_TO_EOL "\033[K\r"
#define LOG_COL_DEFAULT "\033[0m"
#define LOG_COL_BOLD "\033[1m"
#define LOG_COL_RED "\033[31m"

View file

@ -1385,13 +1385,6 @@ static std::string strip(const std::string & s) {
return s.substr(start, end - start + 1);
}
static std::string capitalize(const std::string & s) {
if (s.empty()) return s;
auto result = s;
result[0] = std::toupper(result[0]);
return result;
}
static std::string html_escape(const std::string & s) {
std::string result;
result.reserve(s.size());
@ -1469,9 +1462,6 @@ public:
if (method->get_name() == "strip") {
vargs.expectArgs("strip method", {0, 0}, {0, 0});
return Value(strip(str));
} else if (method->get_name() == "capitalize") {
vargs.expectArgs("capitalize method", {0, 0}, {0, 0});
return Value(capitalize(str));
} else if (method->get_name() == "endswith") {
vargs.expectArgs("endswith method", {1, 1}, {0, 0});
auto suffix = vargs.args[0].get<std::string>();
@ -1802,7 +1792,7 @@ private:
auto left = parseStringConcat();
if (!left) throw std::runtime_error("Expected left side of 'logical compare' expression");
static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not\s+in\b)");
static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not[\r\n\s]+in\b)");
static std::regex not_tok(R"(not\b)");
std::string op_str;
while (!(op_str = consumeToken(compare_tok)).empty()) {
@ -2181,7 +2171,7 @@ private:
using TemplateTokenIterator = TemplateTokenVector::const_iterator;
std::vector<std::string> parseVarNames() {
static std::regex varnames_regex(R"(((?:\w+)(?:\s*,\s*(?:\w+))*)\s*)");
static std::regex varnames_regex(R"(((?:\w+)(?:[\r\n\s]*,[\r\n\s]*(?:\w+))*)[\r\n\s]*)");
std::vector<std::string> group;
if ((group = consumeTokenGroups(varnames_regex)).empty()) throw std::runtime_error("Expected variable names");
@ -2204,13 +2194,13 @@ private:
}
TemplateTokenVector tokenize() {
static std::regex comment_tok(R"(\{#([-~]?)([\s\S]*?)([-~]?)#\})");
static std::regex comment_tok(R"(\{#([-~]?)([\s\S\r\n]*?)([-~]?)#\})");
static std::regex expr_open_regex(R"(\{\{([-~])?)");
static std::regex block_open_regex(R"(^\{%([-~])?\s*)");
static std::regex block_open_regex(R"(^\{%([-~])?[\s\n\r]*)");
static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)");
static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)");
static std::regex expr_close_regex(R"(\s*([-~])?\}\})");
static std::regex block_close_regex(R"(\s*([-~])?%\})");
static std::regex expr_close_regex(R"([\s\n\r]*([-~])?\}\})");
static std::regex block_close_regex(R"([\s\n\r]*([-~])?%\})");
TemplateTokenVector tokens;
std::vector<std::string> group;
@ -2294,7 +2284,7 @@ private:
auto post_space = parseBlockClose();
tokens.push_back(std::make_unique<EndGenerationTemplateToken>(location, pre_space, post_space));
} else if (keyword == "set") {
static std::regex namespaced_var_regex(R"((\w+)\s*\.\s*(\w+))");
static std::regex namespaced_var_regex(R"((\w+)[\s\n\r]*\.[\s\n\r]*(\w+))");
std::string ns;
std::vector<std::string> var_names;
@ -2346,11 +2336,6 @@ private:
throw std::runtime_error("Unexpected block: " + keyword);
}
} else if (std::regex_search(it, end, match, non_text_open_regex)) {
if (!match.position()) {
if (match[0] != "{#")
throw std::runtime_error("Internal error: Expected a comment");
throw std::runtime_error("Missing end of comment tag");
}
auto text_end = it + match.position();
text = std::string(it, text_end);
it = text_end;
@ -2415,7 +2400,7 @@ private:
auto text = text_token->text;
if (post_space == SpaceHandling::Strip) {
static std::regex trailing_space_regex(R"(\s+$)");
static std::regex trailing_space_regex(R"((\s|\r|\n)+$)");
text = std::regex_replace(text, trailing_space_regex, "");
} else if (options.lstrip_blocks && it != end) {
auto i = text.size();
@ -2425,7 +2410,7 @@ private:
}
}
if (pre_space == SpaceHandling::Strip) {
static std::regex leading_space_regex(R"(^\s+)");
static std::regex leading_space_regex(R"(^(\s|\r|\n)+)");
text = std::regex_replace(text, leading_space_regex, "");
} else if (options.trim_blocks && (it - 1) != begin && !dynamic_cast<ExpressionTemplateToken*>((*(it - 2)).get())) {
if (text.length() > 0 && text[0] == '\n') {

View file

@ -37,7 +37,7 @@ Once downloaded, place your model in the models folder in llama.cpp.
##### Infinite text from a starting prompt (you can use `Ctrl-C` to stop it):
```bash
./llama-cli -m models/gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1
./llama-cli -m models\gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1
```
### Windows:

View file

@ -535,7 +535,8 @@ class HttpClient {
static void print_progress(const std::string & progress_prefix, const std::string & progress_bar,
const std::string & progress_suffix) {
printe("\r" LOG_CLR_TO_EOL "%s%s| %s", progress_prefix.c_str(), progress_bar.c_str(), progress_suffix.c_str());
printe("\r%*s\r%s%s| %s", get_terminal_width(), " ", progress_prefix.c_str(), progress_bar.c_str(),
progress_suffix.c_str());
}
// Function to write data to a file
static size_t write_data(void * ptr, size_t size, size_t nmemb, void * stream) {
@ -796,13 +797,16 @@ class LlamaData {
llama_model_ptr initialize_model(Opt & opt) {
ggml_backend_load_all();
resolve_model(opt.model_);
printe("\r" LOG_CLR_TO_EOL "Loading model");
printe(
"\r%*s"
"\rLoading model",
get_terminal_width(), " ");
llama_model_ptr model(llama_model_load_from_file(opt.model_.c_str(), opt.model_params));
if (!model) {
printe("%s: error: unable to load model from file: %s\n", __func__, opt.model_.c_str());
}
printe("\r" LOG_CLR_TO_EOL);
printe("\r%*s\r", static_cast<int>(sizeof("Loading model")), " ");
return model;
}
@ -965,7 +969,10 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
static int read_user_input(std::string & user_input) {
static const char * prompt_prefix = "> ";
#ifdef WIN32
printf("\r" LOG_CLR_TO_EOL LOG_COL_DEFAULT "%s", prompt_prefix);
printf(
"\r%*s"
"\r" LOG_COL_DEFAULT "%s",
get_terminal_width(), " ", prompt_prefix);
std::getline(std::cin, user_input);
if (std::cin.eof()) {

View file

@ -220,7 +220,7 @@ services:
The project includes a web-based user interface that enables interaction with the model through the `/chat/completions` endpoint.
The web UI is developed using:
- `react` framework for frontend development
- `vue` framework for frontend development
- `tailwindcss` and `daisyui` for styling
- `vite` for build tooling

Binary file not shown.

View file

@ -334,24 +334,24 @@ struct server_task {
if (data.contains("json_schema") && !data.contains("grammar")) {
try {
auto schema = json_value(data, "json_schema", json::object());
SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str());
LOG_DBG("JSON schema: %s\n", schema.dump(2).c_str());
params.sampling.grammar = json_schema_to_grammar(schema);
SRV_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str());
LOG_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str());
} catch (const std::exception & e) {
throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
}
} else {
params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str());
LOG_DBG("Grammar: %s\n", params.sampling.grammar.c_str());
params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy);
SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false");
LOG_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false");
}
{
auto it = data.find("chat_format");
if (it != data.end()) {
params.oaicompat_chat_format = static_cast<common_chat_format>(it->get<int>());
SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_format).c_str());
LOG_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_format).c_str());
} else {
params.oaicompat_chat_format = defaults.oaicompat_chat_format;
}
@ -367,12 +367,12 @@ struct server_task {
auto ids = common_tokenize(vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
SRV_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
LOG_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str());
params.sampling.grammar_trigger_tokens.push_back(ids[0]);
params.sampling.preserved_tokens.insert(ids[0]);
continue;
}
SRV_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
LOG_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str());
params.sampling.grammar_trigger_words.push_back(trigger);
}
}
@ -381,11 +381,11 @@ struct server_task {
for (const auto & t : *preserved_tokens) {
auto ids = common_tokenize(vocab, t.get<std::string>(), /* add_special= */ false, /* parse_special= */ true);
if (ids.size() == 1) {
SRV_DBG("Preserved token: %d\n", ids[0]);
LOG_DBG("Preserved token: %d\n", ids[0]);
params.sampling.preserved_tokens.insert(ids[0]);
} else {
// This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens.
SRV_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n", t.get<std::string>().c_str());
LOG_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n", t.get<std::string>().c_str());
}
}
}
@ -717,7 +717,7 @@ struct server_task_result_cmpl_final : server_task_result {
std::string finish_reason = "length";
common_chat_msg msg;
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
SRV_DBG("Parsing chat message: %s\n", content.c_str());
LOG_DBG("Parsing chat message: %s\n", content.c_str());
msg = common_chat_parse(content, oaicompat_chat_format);
finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
} else {
@ -1600,6 +1600,10 @@ struct server_queue {
while (true) {
std::unique_lock<std::mutex> lock(mutex_tasks);
if (!running) {
QUE_DBG("%s", "terminate\n");
return;
}
if (queue_tasks.empty()) {
lock.unlock();
break;
@ -1620,11 +1624,11 @@ struct server_queue {
QUE_DBG("%s", "waiting for new tasks\n");
{
std::unique_lock<std::mutex> lock(mutex_tasks);
if (!running) {
QUE_DBG("%s", "terminate\n");
return;
}
if (queue_tasks.empty()) {
if (!running) {
QUE_DBG("%s", "terminate\n");
return;
}
condition_tasks.wait(lock, [&]{
return (!queue_tasks.empty() || !running);
});
@ -1885,7 +1889,7 @@ struct server_context {
}
if (params_base.chat_template.empty() && !validate_builtin_chat_template(params.use_jinja)) {
SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
chat_templates = common_chat_templates_from_model(model, "chatml");
} else {
chat_templates = common_chat_templates_from_model(model, params_base.chat_template);
@ -3355,10 +3359,10 @@ static void log_server_request(const httplib::Request & req, const httplib::Resp
// reminder: this function is not covered by httplib's exception handler; if someone does more complicated stuff, think about wrapping it in try-catch
SRV_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);
LOG_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);
SRV_DBG("request: %s\n", req.body.c_str());
SRV_DBG("response: %s\n", res.body.c_str());
LOG_DBG("request: %s\n", req.body.c_str());
LOG_DBG("response: %s\n", res.body.c_str());
}
std::function<void(int)> shutdown_handler;
@ -3860,9 +3864,7 @@ int main(int argc, char ** argv) {
try {
const auto & prompt = data.at("prompt");
// TODO: this log can become very long, put it behind a flag or think about a more compact format
//SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
LOG_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
tasks.reserve(tokenized_prompts.size());
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
@ -4378,9 +4380,6 @@ int main(int argc, char ** argv) {
res.set_content("Error: gzip is not supported by this browser", "text/plain");
} else {
res.set_header("Content-Encoding", "gzip");
// COEP and COOP headers, required by pyodide (python interpreter)
res.set_header("Cross-Origin-Embedder-Policy", "require-corp");
res.set_header("Cross-Origin-Opener-Policy", "same-origin");
res.set_content(reinterpret_cast<const char*>(index_html_gz), index_html_gz_len, "text/html; charset=utf-8");
}
return false;
@ -4430,6 +4429,7 @@ int main(int argc, char ** argv) {
// clean up function, to be called before exit
auto clean_up = [&svr]() {
SRV_INF("%s: cleaning up before exit...\n", __func__);
svr->stop();
llama_backend_free();
};
@ -4446,10 +4446,6 @@ int main(int argc, char ** argv) {
}
if (!was_bound) {
//LOG_ERROR("couldn't bind HTTP server socket", {
// {"hostname", params.hostname},
// {"port", params.port},
//});
LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, params.hostname.c_str(), params.port);
clean_up();
return 1;
@ -4466,7 +4462,7 @@ int main(int argc, char ** argv) {
if (!ctx_server.load_model(params)) {
clean_up();
t.join();
// t.join(); // FIXME: see below
LOG_ERR("%s: exiting due to model loading error\n", __func__);
return 1;
}
@ -4490,13 +4486,10 @@ int main(int argc, char ** argv) {
});
shutdown_handler = [&](int) {
// this will unblock start_loop()
ctx_server.queue_tasks.terminate();
};
LOG_INF("%s: server is listening on http://%s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
ctx_server.queue_tasks.start_loop();
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
struct sigaction sigint_action;
sigint_action.sa_handler = signal_handler;
@ -4511,8 +4504,13 @@ int main(int argc, char ** argv) {
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif
LOG_INF("%s: server is listening on http://%s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
// this call blocks the main thread until queue_tasks.terminate() is called
ctx_server.queue_tasks.start_loop();
clean_up();
t.join();
// t.join(); // FIXME: http thread may stuck if there is an on-going request. we don't need to care about this for now as the HTTP connection will already be closed at this point, but it's better to fix this
return 0;
}

View file

@ -8,7 +8,6 @@
"name": "webui",
"version": "0.0.0",
"dependencies": {
"@heroicons/react": "^2.2.0",
"@sec-ant/readable-stream": "^0.6.0",
"@vscode/markdown-it-katex": "^1.1.1",
"autoprefixer": "^10.4.20",
@ -903,15 +902,6 @@
"node": "^18.18.0 || ^20.9.0 || >=21.1.0"
}
},
"node_modules/@heroicons/react": {
"version": "2.2.0",
"resolved": "https://registry.npmjs.org/@heroicons/react/-/react-2.2.0.tgz",
"integrity": "sha512-LMcepvRaS9LYHJGsF0zzmgKCUim/X3N/DQKc4jepAXJ7l8QxJ1PmxJzqplF2Z3FE4PqBAIGyJAQ/w4B5dsqbtQ==",
"license": "MIT",
"peerDependencies": {
"react": ">= 16 || ^19.0.0-rc"
}
},
"node_modules/@humanfs/core": {
"version": "0.19.1",
"resolved": "https://registry.npmjs.org/@humanfs/core/-/core-0.19.1.tgz",

View file

@ -11,7 +11,6 @@
"preview": "vite preview"
},
"dependencies": {
"@heroicons/react": "^2.2.0",
"@sec-ant/readable-stream": "^0.6.0",
"@vscode/markdown-it-katex": "^1.1.1",
"autoprefixer": "^10.4.20",

View file

@ -1,9 +1,8 @@
import { HashRouter, Outlet, Route, Routes } from 'react-router';
import Header from './components/Header';
import Sidebar from './components/Sidebar';
import { AppContextProvider, useAppContext } from './utils/app.context';
import { AppContextProvider } from './utils/app.context';
import ChatScreen from './components/ChatScreen';
import SettingDialog from './components/SettingDialog';
function App() {
return (
@ -23,23 +22,13 @@ function App() {
}
function AppLayout() {
const { showSettings, setShowSettings } = useAppContext();
return (
<>
<Sidebar />
<div
className="drawer-content grow flex flex-col h-screen w-screen mx-auto px-4 overflow-auto"
id="main-scroll"
>
<div className="chat-screen drawer-content grow flex flex-col h-screen w-screen mx-auto px-4">
<Header />
<Outlet />
</div>
{
<SettingDialog
show={showSettings}
onClose={() => setShowSettings(false)}
/>
}
</>
);
}

View file

@ -10,7 +10,6 @@ export const BASE_URL = new URL('.', document.baseURI).href
export const CONFIG_DEFAULT = {
// Note: in order not to introduce breaking changes, please keep the same data type (number, string, etc) if you want to change the default value. Do not use null or undefined for default value.
// Do not use nested objects, keep it single level. Prefix the key if you need to group them.
apiKey: '',
systemMessage: 'You are a helpful assistant.',
showTokensPerSecond: false,
@ -37,8 +36,6 @@ export const CONFIG_DEFAULT = {
dry_penalty_last_n: -1,
max_tokens: -1,
custom: '', // custom json-stringified object
// experimental features
pyIntepreterEnabled: false,
};
export const CONFIG_INFO: Record<string, string> = {
apiKey: 'Set the API Key if you are using --api-key option for the server.',

View file

@ -1,195 +0,0 @@
import { useEffect, useState } from 'react';
import { useAppContext } from '../utils/app.context';
import { OpenInNewTab, XCloseButton } from '../utils/common';
import { CanvasType } from '../utils/types';
import { PlayIcon, StopIcon } from '@heroicons/react/24/outline';
import StorageUtils from '../utils/storage';
const canInterrupt = typeof SharedArrayBuffer === 'function';
// adapted from https://pyodide.org/en/stable/usage/webworker.html
const WORKER_CODE = `
importScripts("https://cdn.jsdelivr.net/pyodide/v0.27.2/full/pyodide.js");
let stdOutAndErr = [];
let pyodideReadyPromise = loadPyodide({
stdout: (data) => stdOutAndErr.push(data),
stderr: (data) => stdOutAndErr.push(data),
});
let alreadySetBuff = false;
self.onmessage = async (event) => {
stdOutAndErr = [];
// make sure loading is done
const pyodide = await pyodideReadyPromise;
const { id, python, context, interruptBuffer } = event.data;
if (interruptBuffer && !alreadySetBuff) {
pyodide.setInterruptBuffer(interruptBuffer);
alreadySetBuff = true;
}
// Now load any packages we need, run the code, and send the result back.
await pyodide.loadPackagesFromImports(python);
// make a Python dictionary with the data from content
const dict = pyodide.globals.get("dict");
const globals = dict(Object.entries(context));
try {
self.postMessage({ id, running: true });
// Execute the python code in this context
const result = pyodide.runPython(python, { globals });
self.postMessage({ result, id, stdOutAndErr });
} catch (error) {
self.postMessage({ error: error.message, id });
}
interruptBuffer[0] = 0;
};
`;
let worker: Worker;
const interruptBuffer = canInterrupt
? new Uint8Array(new SharedArrayBuffer(1))
: null;
const startWorker = () => {
if (!worker) {
worker = new Worker(
URL.createObjectURL(new Blob([WORKER_CODE], { type: 'text/javascript' }))
);
}
};
if (StorageUtils.getConfig().pyIntepreterEnabled) {
startWorker();
}
const runCodeInWorker = (
pyCode: string,
callbackRunning: () => void
): {
donePromise: Promise<string>;
interrupt: () => void;
} => {
startWorker();
const id = Math.random() * 1e8;
const context = {};
if (interruptBuffer) {
interruptBuffer[0] = 0;
}
const donePromise = new Promise<string>((resolve) => {
worker.onmessage = (event) => {
const { error, stdOutAndErr, running } = event.data;
if (id !== event.data.id) return;
if (running) {
callbackRunning();
return;
} else if (error) {
resolve(error.toString());
} else {
resolve(stdOutAndErr.join('\n'));
}
};
worker.postMessage({ id, python: pyCode, context, interruptBuffer });
});
const interrupt = () => {
console.log('Interrupting...');
console.trace();
if (interruptBuffer) {
interruptBuffer[0] = 2;
}
};
return { donePromise, interrupt };
};
export default function CanvasPyInterpreter() {
const { canvasData, setCanvasData } = useAppContext();
const [code, setCode] = useState(canvasData?.content ?? ''); // copy to avoid direct mutation
const [running, setRunning] = useState(false);
const [output, setOutput] = useState('');
const [interruptFn, setInterruptFn] = useState<() => void>();
const [showStopBtn, setShowStopBtn] = useState(false);
const runCode = async (pycode: string) => {
interruptFn?.();
setRunning(true);
setOutput('Loading Pyodide...');
const { donePromise, interrupt } = runCodeInWorker(pycode, () => {
setOutput('Running...');
setShowStopBtn(canInterrupt);
});
setInterruptFn(() => interrupt);
const out = await donePromise;
setOutput(out);
setRunning(false);
setShowStopBtn(false);
};
// run code on mount
useEffect(() => {
setCode(canvasData?.content ?? '');
runCode(canvasData?.content ?? '');
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [canvasData?.content]);
if (canvasData?.type !== CanvasType.PY_INTERPRETER) {
return null;
}
return (
<div className="card bg-base-200 w-full h-full shadow-xl">
<div className="card-body">
<div className="flex justify-between items-center mb-4">
<span className="text-lg font-bold">Python Interpreter</span>
<XCloseButton
className="bg-base-100"
onClick={() => setCanvasData(null)}
/>
</div>
<div className="grid grid-rows-3 gap-4 h-full">
<textarea
className="textarea textarea-bordered w-full h-full font-mono"
value={code}
onChange={(e) => setCode(e.target.value)}
></textarea>
<div className="font-mono flex flex-col row-span-2">
<div className="flex items-center mb-2">
<button
className="btn btn-sm bg-base-100"
onClick={() => runCode(code)}
disabled={running}
>
<PlayIcon className="h-6 w-6" /> Run
</button>
{showStopBtn && (
<button
className="btn btn-sm bg-base-100 ml-2"
onClick={() => interruptFn?.()}
>
<StopIcon className="h-6 w-6" /> Stop
</button>
)}
<span className="grow text-right text-xs">
<OpenInNewTab href="https://github.com/ggerganov/llama.cpp/issues/11762">
Report a bug
</OpenInNewTab>
</span>
</div>
<textarea
className="textarea textarea-bordered h-full dark-color"
value={output}
readOnly
></textarea>
</div>
</div>
</div>
</div>
);
}

View file

@ -92,7 +92,7 @@ export default function ChatMessage({
<>
<textarea
dir="auto"
className="textarea textarea-bordered bg-base-100 text-base-content max-w-2xl w-[calc(90vw-8em)] h-24"
className="textarea textarea-bordered bg-base-100 text-base-content w-[calc(90vw-8em)] lg:w-96"
value={editingContent}
onChange={(e) => setEditingContent(e.target.value)}
></textarea>
@ -149,17 +149,11 @@ export default function ChatMessage({
)}
</summary>
<div className="collapse-content">
<MarkdownDisplay
content={thought}
isGenerating={isPending}
/>
<MarkdownDisplay content={thought} />
</div>
</details>
)}
<MarkdownDisplay
content={content}
isGenerating={isPending}
/>
<MarkdownDisplay content={content} />
</div>
</>
)}

View file

@ -1,11 +1,9 @@
import { useEffect, useState } from 'react';
import { useEffect, useRef, useState } from 'react';
import { useAppContext } from '../utils/app.context';
import StorageUtils from '../utils/storage';
import { useNavigate } from 'react-router';
import ChatMessage from './ChatMessage';
import { CanvasType, PendingMessage } from '../utils/types';
import { classNames } from '../utils/misc';
import CanvasPyInterpreter from './CanvasPyInterpreter';
import { PendingMessage } from '../utils/types';
export default function ChatScreen() {
const {
@ -14,24 +12,24 @@ export default function ChatScreen() {
isGenerating,
stopGenerating,
pendingMessages,
canvasData,
} = useAppContext();
const [inputMsg, setInputMsg] = useState('');
const containerRef = useRef<HTMLDivElement>(null);
const navigate = useNavigate();
const currConvId = viewingConversation?.id ?? '';
const pendingMsg: PendingMessage | undefined = pendingMessages[currConvId];
const scrollToBottom = (requiresNearBottom: boolean) => {
const mainScrollElem = document.getElementById('main-scroll');
if (!mainScrollElem) return;
if (!containerRef.current) return;
const msgListElem = containerRef.current;
const spaceToBottom =
mainScrollElem.scrollHeight -
mainScrollElem.scrollTop -
mainScrollElem.clientHeight;
msgListElem.scrollHeight -
msgListElem.scrollTop -
msgListElem.clientHeight;
if (!requiresNearBottom || spaceToBottom < 50) {
setTimeout(
() => mainScrollElem.scrollTo({ top: mainScrollElem.scrollHeight }),
() => msgListElem.scrollTo({ top: msgListElem.scrollHeight }),
1
);
}
@ -60,87 +58,66 @@ export default function ChatScreen() {
}
};
const hasCanvas = !!canvasData;
return (
<div
className={classNames({
'grid lg:gap-8 grow transition-[300ms]': true,
'grid-cols-[1fr_0fr] lg:grid-cols-[1fr_1fr]': hasCanvas, // adapted for mobile
'grid-cols-[1fr_0fr]': !hasCanvas,
})}
>
<>
{/* chat messages */}
<div
className={classNames({
'flex flex-col w-full max-w-[900px] mx-auto': true,
'hidden lg:flex': hasCanvas, // adapted for mobile
flex: !hasCanvas,
})}
id="messages-list"
className="flex flex-col grow overflow-y-auto"
ref={containerRef}
>
{/* chat messages */}
<div id="messages-list" className="grow">
<div className="mt-auto flex justify-center">
{/* placeholder to shift the message to the bottom */}
{viewingConversation ? '' : 'Send a message to start'}
</div>
{viewingConversation?.messages.map((msg) => (
<ChatMessage
key={msg.id}
msg={msg}
scrollToBottom={scrollToBottom}
/>
))}
{pendingMsg && (
<ChatMessage
msg={pendingMsg}
scrollToBottom={scrollToBottom}
isPending
id="pending-msg"
/>
)}
<div className="mt-auto flex justify-center">
{/* placeholder to shift the message to the bottom */}
{viewingConversation ? '' : 'Send a message to start'}
</div>
{viewingConversation?.messages.map((msg) => (
<ChatMessage key={msg.id} msg={msg} scrollToBottom={scrollToBottom} />
))}
{/* chat input */}
<div className="flex flex-row items-center pt-8 pb-6 sticky bottom-0 bg-base-100">
<textarea
className="textarea textarea-bordered w-full"
placeholder="Type a message (Shift+Enter to add a new line)"
value={inputMsg}
onChange={(e) => setInputMsg(e.target.value)}
onKeyDown={(e) => {
if (e.key === 'Enter' && e.shiftKey) return;
if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault();
sendNewMessage();
}
}}
id="msg-input"
dir="auto"
></textarea>
{isGenerating(currConvId) ? (
<button
className="btn btn-neutral ml-2"
onClick={() => stopGenerating(currConvId)}
>
Stop
</button>
) : (
<button
className="btn btn-primary ml-2"
onClick={sendNewMessage}
disabled={inputMsg.trim().length === 0}
>
Send
</button>
)}
</div>
</div>
<div className="w-full sticky top-[7em] h-[calc(100vh-9em)]">
{canvasData?.type === CanvasType.PY_INTERPRETER && (
<CanvasPyInterpreter />
{pendingMsg && (
<ChatMessage
msg={pendingMsg}
scrollToBottom={scrollToBottom}
isPending
id="pending-msg"
/>
)}
</div>
</div>
{/* chat input */}
<div className="flex flex-row items-center mt-8 mb-6">
<textarea
className="textarea textarea-bordered w-full"
placeholder="Type a message (Shift+Enter to add a new line)"
value={inputMsg}
onChange={(e) => setInputMsg(e.target.value)}
onKeyDown={(e) => {
if (e.key === 'Enter' && e.shiftKey) return;
if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault();
sendNewMessage();
}
}}
id="msg-input"
dir="auto"
></textarea>
{isGenerating(currConvId) ? (
<button
className="btn btn-neutral ml-2"
onClick={() => stopGenerating(currConvId)}
>
Stop
</button>
) : (
<button
className="btn btn-primary ml-2"
onClick={sendNewMessage}
disabled={inputMsg.trim().length === 0}
>
Send
</button>
)}
</div>
</>
);
}

View file

@ -5,11 +5,12 @@ import { classNames } from '../utils/misc';
import daisyuiThemes from 'daisyui/src/theming/themes';
import { THEMES } from '../Config';
import { useNavigate } from 'react-router';
import SettingDialog from './SettingDialog';
export default function Header() {
const navigate = useNavigate();
const [selectedTheme, setSelectedTheme] = useState(StorageUtils.getTheme());
const { setShowSettings } = useAppContext();
const [showSettingDialog, setShowSettingDialog] = useState(false);
const setTheme = (theme: string) => {
StorageUtils.setTheme(theme);
@ -53,7 +54,7 @@ export default function Header() {
};
return (
<div className="flex flex-row items-center pt-6 pb-6 sticky top-0 z-10 bg-base-100">
<div className="flex flex-row items-center mt-6 mb-6">
{/* open sidebar button */}
<label htmlFor="toggle-drawer" className="btn btn-ghost lg:hidden">
<svg
@ -108,7 +109,7 @@ export default function Header() {
</ul>
</div>
<div className="tooltip tooltip-bottom" data-tip="Settings">
<button className="btn" onClick={() => setShowSettings(true)}>
<button className="btn" onClick={() => setShowSettingDialog(true)}>
{/* settings button */}
<svg
xmlns="http://www.w3.org/2000/svg"
@ -171,6 +172,11 @@ export default function Header() {
</div>
</div>
</div>
<SettingDialog
show={showSettingDialog}
onClose={() => setShowSettingDialog(false)}
/>
</div>
);
}

View file

@ -9,16 +9,8 @@ import 'katex/dist/katex.min.css';
import { classNames, copyStr } from '../utils/misc';
import { ElementContent, Root } from 'hast';
import { visit } from 'unist-util-visit';
import { useAppContext } from '../utils/app.context';
import { CanvasType } from '../utils/types';
export default function MarkdownDisplay({
content,
isGenerating,
}: {
content: string;
isGenerating?: boolean;
}) {
export default function MarkdownDisplay({ content }: { content: string }) {
const preprocessedContent = useMemo(
() => preprocessLaTeX(content),
[content]
@ -29,13 +21,8 @@ export default function MarkdownDisplay({
rehypePlugins={[rehypeHightlight, rehypeKatex, rehypeCustomCopyButton]}
components={{
button: (props) => (
<CodeBlockButtons
{...props}
isGenerating={isGenerating}
origContent={preprocessedContent}
/>
<CopyCodeButton {...props} origContent={preprocessedContent} />
),
// note: do not use "pre", "p" or other basic html elements here, it will cause the node to re-render when the message is being generated (this should be a bug with react-markdown, not sure how to fix it)
}}
>
{preprocessedContent}
@ -43,12 +30,11 @@ export default function MarkdownDisplay({
);
}
const CodeBlockButtons: React.ElementType<
const CopyCodeButton: React.ElementType<
React.ClassAttributes<HTMLButtonElement> &
React.HTMLAttributes<HTMLButtonElement> &
ExtraProps & { origContent: string; isGenerating?: boolean }
> = ({ node, origContent, isGenerating }) => {
const { config } = useAppContext();
ExtraProps & { origContent: string }
> = ({ node, origContent }) => {
const startOffset = node?.position?.start.offset ?? 0;
const endOffset = node?.position?.end.offset ?? 0;
@ -61,33 +47,14 @@ const CodeBlockButtons: React.ElementType<
[origContent, startOffset, endOffset]
);
const codeLanguage = useMemo(
() =>
origContent
.substring(startOffset, startOffset + 10)
.match(/^```([^\n]+)\n/)?.[1] ?? '',
[origContent, startOffset]
);
const canRunCode =
!isGenerating &&
config.pyIntepreterEnabled &&
codeLanguage.startsWith('py');
return (
<div
className={classNames({
'text-right sticky top-[7em] mb-2 mr-2 h-0': true,
'text-right sticky top-4 mb-2 mr-2 h-0': true,
'display-none': !node?.position,
})}
>
<CopyButton className="badge btn-mini" content={copiedContent} />
{canRunCode && (
<RunPyCodeButton
className="badge btn-mini ml-2"
content={copiedContent}
/>
)}
</div>
);
};
@ -114,31 +81,6 @@ export const CopyButton = ({
);
};
export const RunPyCodeButton = ({
content,
className,
}: {
content: string;
className?: string;
}) => {
const { setCanvasData } = useAppContext();
return (
<>
<button
className={className}
onClick={() =>
setCanvasData({
type: CanvasType.PY_INTERPRETER,
content,
})
}
>
Run
</button>
</>
);
};
/**
* This injects the "button" element before each "pre" element.
* The actual button will be replaced with a react component in the MarkdownDisplay.
@ -152,7 +94,9 @@ function rehypeCustomCopyButton() {
// replace current node
preNode.properties.visited = 'true';
node.tagName = 'div';
node.properties = {};
node.properties = {
className: 'relative my-4',
};
// add node for button
const btnNode: ElementContent = {
type: 'element',

View file

@ -3,27 +3,17 @@ import { useAppContext } from '../utils/app.context';
import { CONFIG_DEFAULT, CONFIG_INFO } from '../Config';
import { isDev } from '../Config';
import StorageUtils from '../utils/storage';
import { classNames, isBoolean, isNumeric, isString } from '../utils/misc';
import {
BeakerIcon,
ChatBubbleOvalLeftEllipsisIcon,
Cog6ToothIcon,
FunnelIcon,
HandRaisedIcon,
SquaresPlusIcon,
} from '@heroicons/react/24/outline';
import { OpenInNewTab } from '../utils/common';
type SettKey = keyof typeof CONFIG_DEFAULT;
const BASIC_KEYS: SettKey[] = [
const COMMON_SAMPLER_KEYS: SettKey[] = [
'temperature',
'top_k',
'top_p',
'min_p',
'max_tokens',
];
const SAMPLER_KEYS: SettKey[] = [
const OTHER_SAMPLER_KEYS: SettKey[] = [
'dynatemp_range',
'dynatemp_exponent',
'typical_p',
@ -41,223 +31,6 @@ const PENALTY_KEYS: SettKey[] = [
'dry_penalty_last_n',
];
enum SettingInputType {
SHORT_INPUT,
LONG_INPUT,
CHECKBOX,
CUSTOM,
}
interface SettingFieldInput {
type: Exclude<SettingInputType, SettingInputType.CUSTOM>;
label: string | React.ReactElement;
help?: string | React.ReactElement;
key: SettKey;
}
interface SettingFieldCustom {
type: SettingInputType.CUSTOM;
key: SettKey;
component:
| string
| React.FC<{
value: string | boolean | number;
onChange: (value: string) => void;
}>;
}
interface SettingSection {
title: React.ReactElement;
fields: (SettingFieldInput | SettingFieldCustom)[];
}
const ICON_CLASSNAME = 'w-4 h-4 mr-1 inline';
const SETTING_SECTIONS: SettingSection[] = [
{
title: (
<>
<Cog6ToothIcon className={ICON_CLASSNAME} />
General
</>
),
fields: [
{
type: SettingInputType.SHORT_INPUT,
label: 'API Key',
key: 'apiKey',
},
{
type: SettingInputType.LONG_INPUT,
label: 'System Message (will be disabled if left empty)',
key: 'systemMessage',
},
...BASIC_KEYS.map(
(key) =>
({
type: SettingInputType.SHORT_INPUT,
label: key,
key,
}) as SettingFieldInput
),
],
},
{
title: (
<>
<FunnelIcon className={ICON_CLASSNAME} />
Samplers
</>
),
fields: [
{
type: SettingInputType.SHORT_INPUT,
label: 'Samplers queue',
key: 'samplers',
},
...SAMPLER_KEYS.map(
(key) =>
({
type: SettingInputType.SHORT_INPUT,
label: key,
key,
}) as SettingFieldInput
),
],
},
{
title: (
<>
<HandRaisedIcon className={ICON_CLASSNAME} />
Penalties
</>
),
fields: PENALTY_KEYS.map((key) => ({
type: SettingInputType.SHORT_INPUT,
label: key,
key,
})),
},
{
title: (
<>
<ChatBubbleOvalLeftEllipsisIcon className={ICON_CLASSNAME} />
Reasoning
</>
),
fields: [
{
type: SettingInputType.CHECKBOX,
label: 'Expand though process by default for generating message',
key: 'showThoughtInProgress',
},
{
type: SettingInputType.CHECKBOX,
label:
'Exclude thought process when sending request to API (Recommended for DeepSeek-R1)',
key: 'excludeThoughtOnReq',
},
],
},
{
title: (
<>
<SquaresPlusIcon className={ICON_CLASSNAME} />
Advanced
</>
),
fields: [
{
type: SettingInputType.CUSTOM,
key: 'custom', // dummy key, won't be used
component: () => {
const debugImportDemoConv = async () => {
const res = await fetch('/demo-conversation.json');
const demoConv = await res.json();
StorageUtils.remove(demoConv.id);
for (const msg of demoConv.messages) {
StorageUtils.appendMsg(demoConv.id, msg);
}
};
return (
<button className="btn" onClick={debugImportDemoConv}>
(debug) Import demo conversation
</button>
);
},
},
{
type: SettingInputType.CHECKBOX,
label: 'Show tokens per second',
key: 'showTokensPerSecond',
},
{
type: SettingInputType.LONG_INPUT,
label: (
<>
Custom JSON config (For more info, refer to{' '}
<OpenInNewTab href="https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md">
server documentation
</OpenInNewTab>
)
</>
),
key: 'custom',
},
],
},
{
title: (
<>
<BeakerIcon className={ICON_CLASSNAME} />
Experimental
</>
),
fields: [
{
type: SettingInputType.CUSTOM,
key: 'custom', // dummy key, won't be used
component: () => (
<>
<p className="mb-8">
Experimental features are not guaranteed to work correctly.
<br />
<br />
If you encounter any problems, create a{' '}
<OpenInNewTab href="https://github.com/ggerganov/llama.cpp/issues/new?template=019-bug-misc.yml">
Bug (misc.)
</OpenInNewTab>{' '}
report on Github. Please also specify <b>webui/experimental</b> on
the report title and include screenshots.
<br />
<br />
Some features may require packages downloaded from CDN, so they
need internet connection.
</p>
</>
),
},
{
type: SettingInputType.CHECKBOX,
label: (
<>
<b>Enable Python interpreter</b>
<br />
<small className="text-xs">
This feature uses{' '}
<OpenInNewTab href="https://pyodide.org">pyodide</OpenInNewTab>,
downloaded from CDN. To use this feature, ask the LLM to generate
python code inside a markdown code block. You will see a "Run"
button on the code block, near the "Copy" button.
</small>
</>
),
key: 'pyIntepreterEnabled',
},
],
},
];
export default function SettingDialog({
show,
onClose,
@ -266,7 +39,6 @@ export default function SettingDialog({
onClose: () => void;
}) {
const { config, saveConfig } = useAppContext();
const [sectionIdx, setSectionIdx] = useState(0);
// clone the config object to prevent direct mutation
const [localConfig, setLocalConfig] = useState<typeof CONFIG_DEFAULT>(
@ -280,148 +52,206 @@ export default function SettingDialog({
};
const handleSave = () => {
// copy the local config to prevent direct mutation
const newConfig: typeof CONFIG_DEFAULT = JSON.parse(
JSON.stringify(localConfig)
);
// validate the config
for (const key in newConfig) {
const value = newConfig[key as SettKey];
const mustBeBoolean = isBoolean(CONFIG_DEFAULT[key as SettKey]);
const mustBeString = isString(CONFIG_DEFAULT[key as SettKey]);
const mustBeNumeric = isNumeric(CONFIG_DEFAULT[key as SettKey]);
if (mustBeString) {
if (!isString(value)) {
alert(`Value for ${key} must be string`);
return;
}
} else if (mustBeNumeric) {
const trimedValue = value.toString().trim();
const numVal = Number(trimedValue);
if (isNaN(numVal) || !isNumeric(numVal) || trimedValue.length === 0) {
alert(`Value for ${key} must be numeric`);
return;
}
// force conversion to number
// @ts-expect-error this is safe
newConfig[key] = numVal;
} else if (mustBeBoolean) {
if (!isBoolean(value)) {
alert(`Value for ${key} must be boolean`);
return;
}
} else {
console.error(`Unknown default type for key ${key}`);
}
}
if (isDev) console.log('Saving config', newConfig);
saveConfig(newConfig);
saveConfig(localConfig);
onClose();
};
const onChange = (key: SettKey) => (value: string | boolean) => {
// note: we do not perform validation here, because we may get incomplete value as user is still typing it
setLocalConfig({ ...localConfig, [key]: value });
const debugImportDemoConv = async () => {
const res = await fetch('/demo-conversation.json');
const demoConv = await res.json();
StorageUtils.remove(demoConv.id);
for (const msg of demoConv.messages) {
StorageUtils.appendMsg(demoConv.id, msg);
}
onClose();
};
return (
<dialog className={classNames({ modal: true, 'modal-open': show })}>
<div className="modal-box w-11/12 max-w-3xl">
<dialog className={`modal ${show ? 'modal-open' : ''}`}>
<div className="modal-box">
<h3 className="text-lg font-bold mb-6">Settings</h3>
<div className="flex flex-col md:flex-row h-[calc(90vh-12rem)]">
{/* Left panel, showing sections - Desktop version */}
<div className="hidden md:flex flex-col items-stretch pr-4 mr-4 border-r-2 border-base-200">
{SETTING_SECTIONS.map((section, idx) => (
<div
key={idx}
className={classNames({
'btn btn-ghost justify-start font-normal w-44 mb-1': true,
'btn-active': sectionIdx === idx,
})}
onClick={() => setSectionIdx(idx)}
dir="auto"
>
{section.title}
</div>
))}
</div>
<div className="h-[calc(90vh-12rem)] overflow-y-auto">
<p className="opacity-40 mb-6">
Settings below are saved in browser's localStorage
</p>
{/* Left panel, showing sections - Mobile version */}
<div className="md:hidden flex flex-row gap-2 mb-4">
<details className="dropdown">
<summary className="btn bt-sm w-full m-1">
{SETTING_SECTIONS[sectionIdx].title}
</summary>
<ul className="menu dropdown-content bg-base-100 rounded-box z-[1] w-52 p-2 shadow">
{SETTING_SECTIONS.map((section, idx) => (
<div
key={idx}
className={classNames({
'btn btn-ghost justify-start font-normal': true,
'btn-active': sectionIdx === idx,
})}
onClick={() => setSectionIdx(idx)}
dir="auto"
>
{section.title}
</div>
))}
</ul>
</details>
</div>
<SettingsModalShortInput
configKey="apiKey"
configDefault={CONFIG_DEFAULT}
value={localConfig.apiKey}
onChange={(value) =>
setLocalConfig({ ...localConfig, apiKey: value })
}
/>
{/* Right panel, showing setting fields */}
<div className="grow overflow-y-auto px-4">
{SETTING_SECTIONS[sectionIdx].fields.map((field, idx) => {
const key = `${sectionIdx}-${idx}`;
if (field.type === SettingInputType.SHORT_INPUT) {
return (
<SettingsModalShortInput
key={key}
configKey={field.key}
value={localConfig[field.key]}
onChange={onChange(field.key)}
label={field.label as string}
/>
);
} else if (field.type === SettingInputType.LONG_INPUT) {
return (
<SettingsModalLongInput
key={key}
configKey={field.key}
value={localConfig[field.key].toString()}
onChange={onChange(field.key)}
label={field.label as string}
/>
);
} else if (field.type === SettingInputType.CHECKBOX) {
return (
<SettingsModalCheckbox
key={key}
configKey={field.key}
value={!!localConfig[field.key]}
onChange={onChange(field.key)}
label={field.label as string}
/>
);
} else if (field.type === SettingInputType.CUSTOM) {
return (
<div key={key} className="mb-2">
{typeof field.component === 'string'
? field.component
: field.component({
value: localConfig[field.key],
onChange: onChange(field.key),
})}
</div>
);
<label className="form-control mb-2">
<div className="label">
System Message (will be disabled if left empty)
</div>
<textarea
className="textarea textarea-bordered h-24"
placeholder={`Default: ${CONFIG_DEFAULT.systemMessage}`}
value={localConfig.systemMessage}
onChange={(e) =>
setLocalConfig({
...localConfig,
systemMessage: e.target.value,
})
}
})}
/>
</label>
<p className="opacity-40 mb-6 text-sm mt-8">
Settings are saved in browser's localStorage
</p>
</div>
{COMMON_SAMPLER_KEYS.map((key) => (
<SettingsModalShortInput
key={key}
configKey={key}
configDefault={CONFIG_DEFAULT}
value={localConfig[key]}
onChange={(value) =>
setLocalConfig({ ...localConfig, [key]: value })
}
/>
))}
<details className="collapse collapse-arrow bg-base-200 mb-2 overflow-visible">
<summary className="collapse-title font-bold">
Other sampler settings
</summary>
<div className="collapse-content">
<SettingsModalShortInput
label="Samplers queue"
configKey="samplers"
configDefault={CONFIG_DEFAULT}
value={localConfig.samplers}
onChange={(value) =>
setLocalConfig({ ...localConfig, samplers: value })
}
/>
{OTHER_SAMPLER_KEYS.map((key) => (
<SettingsModalShortInput
key={key}
configKey={key}
configDefault={CONFIG_DEFAULT}
value={localConfig[key]}
onChange={(value) =>
setLocalConfig({ ...localConfig, [key]: value })
}
/>
))}
</div>
</details>
<details className="collapse collapse-arrow bg-base-200 mb-2 overflow-visible">
<summary className="collapse-title font-bold">
Penalties settings
</summary>
<div className="collapse-content">
{PENALTY_KEYS.map((key) => (
<SettingsModalShortInput
key={key}
configKey={key}
configDefault={CONFIG_DEFAULT}
value={localConfig[key]}
onChange={(value) =>
setLocalConfig({ ...localConfig, [key]: value })
}
/>
))}
</div>
</details>
<details className="collapse collapse-arrow bg-base-200 mb-2 overflow-visible">
<summary className="collapse-title font-bold">
Reasoning models
</summary>
<div className="collapse-content">
<div className="flex flex-row items-center mb-2">
<input
type="checkbox"
className="checkbox"
checked={localConfig.showThoughtInProgress}
onChange={(e) =>
setLocalConfig({
...localConfig,
showThoughtInProgress: e.target.checked,
})
}
/>
<span className="ml-4">
Expand though process by default for generating message
</span>
</div>
<div className="flex flex-row items-center mb-2">
<input
type="checkbox"
className="checkbox"
checked={localConfig.excludeThoughtOnReq}
onChange={(e) =>
setLocalConfig({
...localConfig,
excludeThoughtOnReq: e.target.checked,
})
}
/>
<span className="ml-4">
Exclude thought process when sending request to API
(Recommended for DeepSeek-R1)
</span>
</div>
</div>
</details>
<details className="collapse collapse-arrow bg-base-200 mb-2 overflow-visible">
<summary className="collapse-title font-bold">
Advanced config
</summary>
<div className="collapse-content">
{/* this button only shows in dev mode, used to import a demo conversation to test message rendering */}
{isDev && (
<div className="flex flex-row items-center mb-2">
<button className="btn" onClick={debugImportDemoConv}>
(debug) Import demo conversation
</button>
</div>
)}
<div className="flex flex-row items-center mb-2">
<input
type="checkbox"
className="checkbox"
checked={localConfig.showTokensPerSecond}
onChange={(e) =>
setLocalConfig({
...localConfig,
showTokensPerSecond: e.target.checked,
})
}
/>
<span className="ml-4">Show tokens per second</span>
</div>
<label className="form-control mb-2">
<div className="label inline">
Custom JSON config (For more info, refer to{' '}
<a
className="underline"
href="https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md"
target="_blank"
rel="noopener noreferrer"
>
server documentation
</a>
)
</div>
<textarea
className="textarea textarea-bordered h-24"
placeholder='Example: { "mirostat": 1, "min_p": 0.1 }'
value={localConfig.custom}
onChange={(e) =>
setLocalConfig({ ...localConfig, custom: e.target.value })
}
/>
</label>
</div>
</details>
</div>
<div className="modal-action">
@ -440,97 +270,37 @@ export default function SettingDialog({
);
}
function SettingsModalLongInput({
function SettingsModalShortInput({
configKey,
configDefault,
value,
onChange,
label,
}: {
configKey: SettKey;
value: string;
configDefault: typeof CONFIG_DEFAULT;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
value: any;
onChange: (value: string) => void;
label?: string;
}) {
return (
<label className="form-control mb-2">
<div className="label inline">{label || configKey}</div>
<textarea
className="textarea textarea-bordered h-24"
placeholder={`Default: ${CONFIG_DEFAULT[configKey] || 'none'}`}
<label className="input input-bordered join-item grow flex items-center gap-2 mb-2">
<div className="dropdown dropdown-hover">
<div tabIndex={0} role="button" className="font-bold">
{label || configKey}
</div>
<div className="dropdown-content menu bg-base-100 rounded-box z-10 w-64 p-2 shadow mt-4">
{CONFIG_INFO[configKey] ?? '(no help message available)'}
</div>
</div>
<input
type="text"
className="grow"
placeholder={`Default: ${configDefault[configKey] || 'none'}`}
value={value}
onChange={(e) => onChange(e.target.value)}
/>
</label>
);
}
function SettingsModalShortInput({
configKey,
value,
onChange,
label,
}: {
configKey: SettKey;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
value: any;
onChange: (value: string) => void;
label?: string;
}) {
const helpMsg = CONFIG_INFO[configKey];
return (
<>
{/* on mobile, we simply show the help message here */}
{helpMsg && (
<div className="block md:hidden mb-1">
<b>{label || configKey}</b>
<br />
<p className="text-xs">{helpMsg}</p>
</div>
)}
<label className="input input-bordered join-item grow flex items-center gap-2 mb-2">
<div className="dropdown dropdown-hover">
<div tabIndex={0} role="button" className="font-bold hidden md:block">
{label || configKey}
</div>
{helpMsg && (
<div className="dropdown-content menu bg-base-100 rounded-box z-10 w-64 p-2 shadow mt-4">
{helpMsg}
</div>
)}
</div>
<input
type="text"
className="grow"
placeholder={`Default: ${CONFIG_DEFAULT[configKey] || 'none'}`}
value={value}
onChange={(e) => onChange(e.target.value)}
/>
</label>
</>
);
}
function SettingsModalCheckbox({
configKey,
value,
onChange,
label,
}: {
configKey: SettKey;
value: boolean;
onChange: (value: boolean) => void;
label: string;
}) {
return (
<div className="flex flex-row items-center mb-2">
<input
type="checkbox"
className="toggle"
checked={value}
onChange={(e) => onChange(e.target.checked)}
/>
<span className="ml-4">{label || configKey}</span>
</div>
);
}

View file

@ -45,9 +45,6 @@
/* Highlight.js */
[data-color-scheme='light'] {
@include meta.load-css('highlight.js/styles/stackoverflow-light');
.dark-color {
@apply bg-base-content text-base-100;
}
}
[data-color-scheme='dark'] {
@include meta.load-css('highlight.js/styles/stackoverflow-dark');
@ -55,9 +52,6 @@
[data-color-scheme='auto'] {
@media (prefers-color-scheme: light) {
@include meta.load-css('highlight.js/styles/stackoverflow-light');
.dark-color {
@apply bg-base-content text-base-100;
}
}
@media (prefers-color-scheme: dark) {
@include meta.load-css('highlight.js/styles/stackoverflow-dark');

View file

@ -1,11 +1,5 @@
import React, { createContext, useContext, useEffect, useState } from 'react';
import {
APIMessage,
CanvasData,
Conversation,
Message,
PendingMessage,
} from './types';
import { APIMessage, Conversation, Message, PendingMessage } from './types';
import StorageUtils from './storage';
import {
filterThoughtFromMsgs,
@ -16,7 +10,6 @@ import { BASE_URL, CONFIG_DEFAULT, isDev } from '../Config';
import { matchPath, useLocation } from 'react-router';
interface AppContextValue {
// conversations and messages
viewingConversation: Conversation | null;
pendingMessages: Record<Conversation['id'], PendingMessage>;
isGenerating: (convId: string) => boolean;
@ -33,15 +26,8 @@ interface AppContextValue {
onChunk?: CallbackGeneratedChunk
) => Promise<void>;
// canvas
canvasData: CanvasData | null;
setCanvasData: (data: CanvasData | null) => void;
// config
config: typeof CONFIG_DEFAULT;
saveConfig: (config: typeof CONFIG_DEFAULT) => void;
showSettings: boolean;
setShowSettings: (show: boolean) => void;
}
// for now, this callback is only used for scrolling to the bottom of the chat
@ -68,13 +54,8 @@ export const AppContextProvider = ({
Record<Conversation['id'], AbortController>
>({});
const [config, setConfig] = useState(StorageUtils.getConfig());
const [canvasData, setCanvasData] = useState<CanvasData | null>(null);
const [showSettings, setShowSettings] = useState(false);
// handle change when the convId from URL is changed
useEffect(() => {
// also reset the canvas data
setCanvasData(null);
const handleConversationChange = (changedConvId: string) => {
if (changedConvId !== convId) return;
setViewingConversation(StorageUtils.getOneConversation(convId));
@ -311,12 +292,8 @@ export const AppContextProvider = ({
sendMessage,
stopGenerating,
replaceMessageAndGenerate,
canvasData,
setCanvasData,
config,
saveConfig,
showSettings,
setShowSettings,
}}
>
{children}

View file

@ -1,38 +0,0 @@
export const XCloseButton: React.ElementType<
React.ClassAttributes<HTMLButtonElement> &
React.HTMLAttributes<HTMLButtonElement>
> = ({ className, ...props }) => (
<button className={`btn btn-square btn-sm ${className ?? ''}`} {...props}>
<svg
xmlns="http://www.w3.org/2000/svg"
className="h-6 w-6"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
strokeLinecap="round"
strokeLinejoin="round"
strokeWidth="2"
d="M6 18L18 6M6 6l12 12"
/>
</svg>
</button>
);
export const OpenInNewTab = ({
href,
children,
}: {
href: string;
children: string;
}) => (
<a
className="underline"
href={href}
target="_blank"
rel="noopener noreferrer"
>
{children}
</a>
);

View file

@ -85,6 +85,3 @@ export function classNames(classes: Record<string, boolean>): string {
.map(([key, _]) => key)
.join(' ');
}
export const delay = (ms: number) =>
new Promise((resolve) => setTimeout(resolve, ms));

View file

@ -23,14 +23,3 @@ export interface Conversation {
export type PendingMessage = Omit<Message, 'content'> & {
content: string | null;
};
export enum CanvasType {
PY_INTERPRETER,
}
export interface CanvasPyInterpreter {
type: CanvasType.PY_INTERPRETER;
content: string;
}
export type CanvasData = CanvasPyInterpreter;

View file

@ -72,9 +72,5 @@ export default defineConfig({
proxy: {
'/v1': 'http://localhost:8080',
},
headers: {
'Cross-Origin-Embedder-Policy': 'require-corp',
'Cross-Origin-Opener-Policy': 'same-origin',
},
},
});

View file

@ -10,6 +10,8 @@ extern "C" {
#define GGML_VK_NAME "Vulkan"
#define GGML_VK_MAX_DEVICES 16
GGML_BACKEND_API void ggml_vk_instance_init(void);
// backend API
GGML_BACKEND_API ggml_backend_t ggml_backend_vk_init(size_t dev_num);

View file

@ -13856,13 +13856,9 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
tp->ec = GGML_STATUS_ABORTED;
}
if (node_n + 1 < cgraph->n_nodes) {
ggml_barrier(state->threadpool);
}
ggml_barrier(state->threadpool);
}
ggml_barrier(state->threadpool);
return 0;
}

View file

@ -16,7 +16,7 @@
#include "common.cuh"
#if CUDART_VERSION >= 11080
#if CUDART_VERSION >= 11800
static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
int ret = 0;
@ -50,7 +50,7 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
return ret_low | ret_high;
}
#endif // CUDART_VERSION >= 11080
#endif // CUDART_VERSION >= 11800
template <typename T>

View file

@ -167,7 +167,6 @@ struct vk_device_struct {
uint32_t subgroup_size;
uint32_t shader_core_count;
bool uma;
bool prefer_host_memory;
bool float_controls_rte_fp16;
bool subgroup_size_control;
@ -185,12 +184,12 @@ struct vk_device_struct {
size_t idx;
bool mul_mat_l[GGML_TYPE_COUNT];
bool mul_mat_m[GGML_TYPE_COUNT];
bool mul_mat_s[GGML_TYPE_COUNT];
bool mul_mat_id_l[GGML_TYPE_COUNT];
bool mul_mat_id_m[GGML_TYPE_COUNT];
bool mul_mat_id_s[GGML_TYPE_COUNT];
bool mul_mat_l;
bool mul_mat_m;
bool mul_mat_s;
bool mul_mat_id_l;
bool mul_mat_id_m;
bool mul_mat_id_s;
// set to true to indicate that some shaders need to be compiled after the dryrun
bool need_compiles {};
@ -1295,9 +1294,7 @@ static vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk:
static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) {
vk_buffer buf;
try {
if (device->prefer_host_memory) {
buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal);
} else if (device->uma) {
if (device->uma) {
// Fall back to host memory type
buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
} else {
@ -1381,33 +1378,7 @@ static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_typ
return {64, 64};
};
static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
uint32_t lut_size = 0;
switch (src0_type) {
case GGML_TYPE_IQ2_XXS:
lut_size = 8*256;
break;
case GGML_TYPE_IQ2_XS:
lut_size = 8*512;
break;
case GGML_TYPE_IQ2_S:
lut_size = 8*1024;
break;
case GGML_TYPE_IQ3_XXS:
lut_size = 4*256;
break;
case GGML_TYPE_IQ3_S:
lut_size = 4*512;
break;
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
lut_size = 4*16;
break;
default:
break;
}
static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id) {
// Needs to be kept up to date on shader changes
const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1;
const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
@ -1417,13 +1388,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
const uint32_t mmid_row_ids = mul_mat_id ? 3072 * sizeof(uint32_t) : 0;
const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
VK_LOG_DEBUG("ggml_vk_matmul_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), "
"mul_mat_id=" << mul_mat_id << ", src0_type=" << ggml_type_name(src0_type) << ", supported=" << supported);
return supported;
return (load_bufs + mmid_row_ids + coopmat_stage) <= device->properties.limits.maxComputeSharedMemorySize;
}
static void ggml_vk_load_shaders(vk_device& device) {
@ -1507,32 +1472,62 @@ static void ggml_vk_load_shaders(vk_device& device) {
m_align = 64;
s_align = 32;
for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) {
ggml_type t = (ggml_type)i;
// Disable medium and large matrix multiplication if not enough shared memory is available
// Check mmq warptiles as the largest configuration
// Throw an error if not enough for any matrix multiplication is available
if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false, t)) {
std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl;
throw std::runtime_error("Shared memory size too small for matrix multiplication.");
} else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false, t)) {
device->mul_mat_m[i] = false;
device->mul_mat_l[i] = false;
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false, t)) {
device->mul_mat_l[i] = false;
}
// Fallback to smaller sizes if there's not enough shared memory. Given the current shaders
// and tile sizes, this should handle 16KB, 32KB, and 48KB+.
// This logic doesn't explicitly account for the 12KB row_ids in the mul_mat_mat_id shaders.
// But the numbers happen to work out for 32KB shared memory size that when using the medium
// size there's enough room for everything, and we assert for this.
uint32_t shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float);
if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) {
l_warptile = m_warptile;
l_wg_denoms = m_wg_denoms;
shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float);
GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize);
}
if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
// assert mul_mat_mat_id shaders will fit.
GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
}
// Disable mul_mat_id if not enough shared memory is available
if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true, t)) {
device->mul_mat_id_s[i] = false;
device->mul_mat_id_m[i] = false;
device->mul_mat_id_l[i] = false;
} else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true, t)) {
device->mul_mat_id_m[i] = false;
device->mul_mat_id_l[i] = false;
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true, t)) {
device->mul_mat_id_l[i] = false;
shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float);
if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) {
if (device->properties.limits.maxComputeSharedMemorySize == 32768) {
l_warptile_mmq = m_warptile_mmq;
l_mmq_wg_denoms = m_mmq_wg_denoms;
} else {
l_warptile_mmq = s_warptile_mmq;
l_mmq_wg_denoms = s_mmq_wg_denoms;
}
shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float);
GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize);
}
if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
// assert mul_mat_mat_id shaders will fit.
GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
}
// Disable medium and large matrix multiplication if not enough shared memory is available
// Check mmq warptiles as the largest configuration
// Throw an error if not enough for any matrix multiplication is available
if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false)) {
std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl;
throw std::runtime_error("Shared memory size too small for matrix multiplication.");
} else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false)) {
device->mul_mat_m = false;
device->mul_mat_l = false;
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false)) {
device->mul_mat_l = false;
}
// Disable mul_mat_id if not enough shared memory is available
if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true)) {
device->mul_mat_id_s = false;
device->mul_mat_id_m = false;
device->mul_mat_id_l = false;
} else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true)) {
device->mul_mat_id_m = false;
device->mul_mat_id_l = false;
} else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true)) {
device->mul_mat_id_l = false;
}
}
@ -1689,116 +1684,119 @@ static void ggml_vk_load_shaders(vk_device& device) {
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (device->coopmat_support) {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l[TYPE]) \
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
if (device->mul_mat ## ID ## _m[TYPE]) \
if (device->mul_mat ## ID ## _m) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
if (device->mul_mat ## ID ## _s[TYPE]) \
if (device->mul_mat ## ID ## _s) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
if (device->mul_mat ## ID ## _l[TYPE]) \
if (device->mul_mat ## ID ## _l) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
if (device->mul_mat ## ID ## _m[TYPE]) \
if (device->mul_mat ## ID ## _m) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
if (device->mul_mat ## ID ## _s[TYPE]) \
if (device->mul_mat ## ID ## _s) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
// Create 2 variants, {f16,f32} accumulator
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->coopmat_acc_f16_support) { \
CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
} \
if (device->coopmat_acc_f32_support) { \
CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
} \
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
if (device->coopmat_acc_f16_support) {
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
} else {
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
}
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
// If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
if (device->coopmat_acc_f16_support) {
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
if (device->coopmat_acc_f16_support) {
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
} else {
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
} else {
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
}
}
#undef CREATE_MM2
#undef CREATE_MM
@ -1806,135 +1804,141 @@ static void ggml_vk_load_shaders(vk_device& device) {
#endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (device->fp16) {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l[TYPE]) \
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
if (device->mul_mat ## ID ## _m[TYPE]) \
if (device->mul_mat ## ID ## _m) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
if (device->mul_mat ## ID ## _s[TYPE]) \
if (device->mul_mat ## ID ## _s) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
if (device->mul_mat ## ID ## _l[TYPE]) \
if (device->mul_mat ## ID ## _l) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
if (device->mul_mat ## ID ## _m[TYPE]) \
if (device->mul_mat ## ID ## _m) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
if (device->mul_mat ## ID ## _s[TYPE]) \
if (device->mul_mat ## ID ## _s) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
// Create 2 variants, {f16,f32} accumulator
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
// If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f16acc, matmul_id_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f16acc, matmul_id_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
}
#undef CREATE_MM2
#undef CREATE_MM
} else {
// Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l[TYPE]) \
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
if (device->mul_mat ## ID ## _l) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
if (device->mul_mat ## ID ## _m[TYPE]) \
if (device->mul_mat ## ID ## _m) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
if (device->mul_mat ## ID ## _s[TYPE]) \
if (device->mul_mat ## ID ## _s) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
if (device->mul_mat ## ID ## _l[TYPE]) \
if (device->mul_mat ## ID ## _l) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
if (device->mul_mat ## ID ## _m[TYPE]) \
if (device->mul_mat ## ID ## _m) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
if (device->mul_mat ## ID ## _s[TYPE]) \
if (device->mul_mat ## ID ## _s) \
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
// If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM(pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM(pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
}
#undef CREATE_MM
}
@ -2202,9 +2206,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->physical_device = physical_devices[dev_num];
const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY");
device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr;
bool fp16_storage = false;
bool fp16_compute = false;
bool maintenance4_support = false;
@ -2622,36 +2623,34 @@ static vk_device ggml_vk_get_device(size_t idx) {
// Shaders
// Disable matmul tile sizes early if performance low or not supported
for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) {
switch (device->vendor_id) {
switch (device->vendor_id) {
#ifndef GGML_VULKAN_RUN_TESTS
case VK_VENDOR_ID_AMD:
case VK_VENDOR_ID_INTEL:
device->mul_mat_l[i] = false;
device->mul_mat_m[i] = true;
device->mul_mat_s[i] = true;
device->mul_mat_id_l[i] = false;
device->mul_mat_id_m[i] = true;
device->mul_mat_id_s[i] = true;
break;
case VK_VENDOR_ID_APPLE:
device->mul_mat_l[i] = false;
device->mul_mat_m[i] = true;
device->mul_mat_s[i] = false;
device->mul_mat_id_l[i] = false;
device->mul_mat_id_m[i] = true;
device->mul_mat_id_s[i] = false;
break;
case VK_VENDOR_ID_AMD:
case VK_VENDOR_ID_INTEL:
device->mul_mat_l = false;
device->mul_mat_m = true;
device->mul_mat_s = true;
device->mul_mat_id_l = false;
device->mul_mat_id_m = true;
device->mul_mat_id_s = true;
break;
case VK_VENDOR_ID_APPLE:
device->mul_mat_l = false;
device->mul_mat_m = true;
device->mul_mat_s = false;
device->mul_mat_id_l = false;
device->mul_mat_id_m = true;
device->mul_mat_id_s = false;
break;
#endif
default:
device->mul_mat_l[i] = true;
device->mul_mat_m[i] = true;
device->mul_mat_s[i] = true;
device->mul_mat_id_l[i] = true;
device->mul_mat_id_m[i] = true;
device->mul_mat_id_s[i] = true;
break;
}
default:
device->mul_mat_l = true;
device->mul_mat_m = true;
device->mul_mat_s = true;
device->mul_mat_id_l = true;
device->mul_mat_id_m = true;
device->mul_mat_id_s = true;
break;
}
ggml_vk_load_shaders(device);
@ -2793,12 +2792,14 @@ static void ggml_vk_print_gpu_info(size_t idx) {
static bool ggml_vk_instance_validation_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
static void ggml_vk_instance_init() {
void ggml_vk_instance_init() {
if (vk_instance_initialized) {
return;
}
VK_LOG_DEBUG("ggml_vk_instance_init()");
vk_instance_initialized = true;
uint32_t api_version = vk::enumerateInstanceVersion();
if (api_version < VK_API_VERSION_1_2) {
@ -2849,7 +2850,6 @@ static void ggml_vk_instance_init() {
GGML_LOG_DEBUG("ggml_vulkan: Validation layers enabled\n");
}
vk_instance.instance = vk::createInstance(instance_create_info);
vk_instance_initialized = true;
size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();
@ -2874,7 +2874,7 @@ static void ggml_vk_instance_init() {
// Make sure at least one device exists
if (devices.empty()) {
std::cerr << "ggml_vulkan: Error: No devices found." << std::endl;
return;
GGML_ABORT("fatal error");
}
// Default to using all dedicated GPUs
@ -3756,31 +3756,31 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int
return split_k;
}
static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")");
if (ctx->device->coopmat2) {
if ((ctx->device->mul_mat_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
if ((ctx->device->mul_mat_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_s)) {
return aligned ? mmp->a_l : mmp->l;
}
if ((ctx->device->mul_mat_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s[src0_type]) {
if ((ctx->device->mul_mat_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s) {
return aligned ? mmp->a_m : mmp->m;
}
return aligned ? mmp->a_s : mmp->s;
}
if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) {
if ((ctx->device->mul_mat_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_l)) {
return aligned ? mmp->a_s : mmp->s;
}
if ((ctx->device->mul_mat_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l[src0_type]) {
if ((ctx->device->mul_mat_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l) {
return aligned ? mmp->a_m : mmp->m;
}
return aligned ? mmp->a_l : mmp->l;
}
static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")");
return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type)->align;
static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) {
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")");
return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true)->align;
}
static void ggml_vk_matmul(
@ -3807,31 +3807,31 @@ static void ggml_vk_matmul(
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
}
static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")");
if (ctx->device->coopmat2) {
if ((ctx->device->mul_mat_id_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
if ((ctx->device->mul_mat_id_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_s)) {
return aligned ? mmp->a_l : mmp->l;
}
if ((ctx->device->mul_mat_id_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s[src0_type]) {
if ((ctx->device->mul_mat_id_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s) {
return aligned ? mmp->a_m : mmp->m;
}
return aligned ? mmp->a_s : mmp->s;
}
if ((ctx->device->mul_mat_id_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_l[src0_type])) {
if ((ctx->device->mul_mat_id_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_l)) {
return aligned ? mmp->a_s : mmp->s;
}
if ((ctx->device->mul_mat_id_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l[src0_type]) {
if ((ctx->device->mul_mat_id_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l) {
return aligned ? mmp->a_m : mmp->m;
}
return aligned ? mmp->a_l : mmp->l;
}
static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")");
return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type)->align;
static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) {
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")");
return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true)->align;
}
static void ggml_vk_matmul_id(
@ -4012,10 +4012,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
const int y_ne = ne11 * ne10;
const int d_ne = ne11 * ne01;
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11));
const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned);
const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
@ -4594,10 +4594,10 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
const uint64_t y_ne = ne11 * ne10;
const uint64_t d_ne = ne21 * ne20;
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1));
const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned);
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
@ -8036,14 +8036,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
{
ggml_type src0_type = op->src[0]->type;
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
const vk_device& device = ggml_vk_get_device(ctx->device);
if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s && !device->mul_mat_id_m && !device->mul_mat_id_l) {
// If there's not enough shared memory for row_ids and the result tile, fallback to CPU
return false;
}
switch (src0_type) {
switch (op->src[0]->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_Q4_0:
@ -8349,13 +8348,8 @@ ggml_backend_reg_t ggml_backend_vk_reg() {
/* .iface = */ ggml_backend_vk_reg_i,
/* .context = */ nullptr,
};
try {
ggml_vk_instance_init();
return &reg;
} catch (const vk::SystemError& e) {
VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: System error: " << e.what());
return nullptr;
}
return &reg;
}
// Extension availability

View file

@ -1275,7 +1275,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
const bool use_mmap_buffer = true;
LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s)\n", __func__, ml.use_mmap ? "true" : "false");
LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s)\n", __func__, use_mmap_buffer ? "true" : "false");
// build a list of buffer types for the CPU and GPU devices
pimpl->cpu_buft_list = make_cpu_buft_list(devices);

View file

@ -9430,6 +9430,7 @@ static struct llama_model * llama_model_load_from_file_impl(
struct llama_model_params params) {
ggml_time_init();
unsigned cur_percentage = 0;
if (params.progress_callback == NULL) {
params.progress_callback_user_data = &cur_percentage;