Merge branch 'master' of https://github.com/ggerganov/llama.cpp
This commit is contained in:
commit
7752c97f3e
75 changed files with 32967 additions and 3133 deletions
|
@ -24,6 +24,16 @@ insert_final_newline = unset
|
|||
[examples/server/public/*]
|
||||
indent_size = 2
|
||||
|
||||
[examples/server/public/deps_*]
|
||||
trim_trailing_whitespace = unset
|
||||
indent_style = unset
|
||||
indent_size = unset
|
||||
|
||||
[examples/server/deps_*]
|
||||
trim_trailing_whitespace = unset
|
||||
indent_style = unset
|
||||
indent_size = unset
|
||||
|
||||
[examples/llama.swiftui/llama.swiftui.xcodeproj/*]
|
||||
indent_style = tab
|
||||
|
||||
|
|
17
.github/workflows/build.yml
vendored
17
.github/workflows/build.yml
vendored
|
@ -55,7 +55,13 @@ jobs:
|
|||
sysctl -a
|
||||
mkdir build
|
||||
cd build
|
||||
cmake -DLLAMA_FATAL_WARNINGS=ON -DGGML_METAL_EMBED_LIBRARY=ON -DLLAMA_CURL=ON -DGGML_RPC=ON -DBUILD_SHARED_LIBS=OFF ..
|
||||
cmake .. \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
-DLLAMA_CURL=ON \
|
||||
-DGGML_METAL_USE_BF16=ON \
|
||||
-DGGML_METAL_EMBED_LIBRARY=ON \
|
||||
-DGGML_RPC=ON \
|
||||
-DBUILD_SHARED_LIBS=OFF
|
||||
cmake --build . --config Release -j $(sysctl -n hw.logicalcpu)
|
||||
|
||||
- name: Test
|
||||
|
@ -113,7 +119,12 @@ jobs:
|
|||
sysctl -a
|
||||
# Metal is disabled due to intermittent failures with Github runners not having a GPU:
|
||||
# https://github.com/ggerganov/llama.cpp/actions/runs/8635935781/job/23674807267#step:5:2313
|
||||
cmake -B build -DLLAMA_FATAL_WARNINGS=ON -DGGML_METAL=OFF -DLLAMA_CURL=ON -DGGML_RPC=ON -DBUILD_SHARED_LIBS=OFF
|
||||
cmake -B build \
|
||||
-DLLAMA_FATAL_WARNINGS=ON \
|
||||
-DLLAMA_CURL=ON \
|
||||
-DGGML_METAL=OFF \
|
||||
-DGGML_RPC=ON \
|
||||
-DBUILD_SHARED_LIBS=OFF
|
||||
cmake --build build --config Release -j $(sysctl -n hw.logicalcpu)
|
||||
|
||||
- name: Test
|
||||
|
@ -569,6 +580,7 @@ jobs:
|
|||
mkdir build
|
||||
cd build
|
||||
cmake -G Xcode .. \
|
||||
-DGGML_METAL_USE_BF16=ON \
|
||||
-DGGML_METAL_EMBED_LIBRARY=ON \
|
||||
-DLLAMA_BUILD_EXAMPLES=OFF \
|
||||
-DLLAMA_BUILD_TESTS=OFF \
|
||||
|
@ -599,6 +611,7 @@ jobs:
|
|||
mkdir build
|
||||
cd build
|
||||
cmake -G Xcode .. \
|
||||
-DGGML_METAL_USE_BF16=ON \
|
||||
-DGGML_METAL_EMBED_LIBRARY=ON \
|
||||
-DLLAMA_BUILD_EXAMPLES=OFF \
|
||||
-DLLAMA_BUILD_TESTS=OFF \
|
||||
|
|
21
Makefile
21
Makefile
|
@ -878,6 +878,10 @@ ifdef GGML_METAL
|
|||
MK_CPPFLAGS += -DGGML_USE_METAL
|
||||
MK_LDFLAGS += -framework Foundation -framework Metal -framework MetalKit
|
||||
OBJ_GGML += ggml/src/ggml-metal.o
|
||||
|
||||
ifdef GGML_METAL_USE_BF16
|
||||
MK_CPPFLAGS += -DGGML_METAL_USE_BF16
|
||||
endif # GGML_METAL_USE_BF16
|
||||
ifdef GGML_METAL_NDEBUG
|
||||
MK_CPPFLAGS += -DGGML_METAL_NDEBUG
|
||||
endif
|
||||
|
@ -1455,22 +1459,13 @@ llama-server: \
|
|||
examples/server/server.cpp \
|
||||
examples/server/utils.hpp \
|
||||
examples/server/httplib.h \
|
||||
examples/server/colorthemes.css.hpp \
|
||||
examples/server/style.css.hpp \
|
||||
examples/server/theme-beeninorder.css.hpp \
|
||||
examples/server/theme-ketivah.css.hpp \
|
||||
examples/server/theme-mangotango.css.hpp \
|
||||
examples/server/theme-playground.css.hpp \
|
||||
examples/server/theme-polarnight.css.hpp \
|
||||
examples/server/theme-snowstorm.css.hpp \
|
||||
examples/server/index.html.hpp \
|
||||
examples/server/index-new.html.hpp \
|
||||
examples/server/index.js.hpp \
|
||||
examples/server/completion.js.hpp \
|
||||
examples/server/system-prompts.js.hpp \
|
||||
examples/server/prompt-formats.js.hpp \
|
||||
examples/server/json-schema-to-grammar.mjs.hpp \
|
||||
examples/server/loading.html.hpp \
|
||||
examples/server/deps_daisyui.min.css.hpp \
|
||||
examples/server/deps_markdown-it.js.hpp \
|
||||
examples/server/deps_tailwindcss.js.hpp \
|
||||
examples/server/deps_vue.esm-browser.js.hpp \
|
||||
common/json.hpp \
|
||||
common/stb_image.h \
|
||||
$(OBJ_ALL)
|
||||
|
|
|
@ -92,13 +92,15 @@ let package = Package(
|
|||
name: "llama",
|
||||
path: ".",
|
||||
exclude: [
|
||||
"build",
|
||||
"cmake",
|
||||
"examples",
|
||||
"scripts",
|
||||
"models",
|
||||
"tests",
|
||||
"CMakeLists.txt",
|
||||
"Makefile"
|
||||
"Makefile",
|
||||
"ggml/src/ggml-metal-embed.metal"
|
||||
],
|
||||
sources: sources,
|
||||
resources: resources,
|
||||
|
|
|
@ -39,7 +39,7 @@ SRC=`pwd`
|
|||
CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON"
|
||||
|
||||
if [ ! -z ${GG_BUILD_METAL} ]; then
|
||||
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=ON"
|
||||
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=ON -DGGML_METAL_USE_BF16=ON"
|
||||
fi
|
||||
|
||||
if [ ! -z ${GG_BUILD_CUDA} ]; then
|
||||
|
|
|
@ -1003,6 +1003,9 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
|
|||
if (s == "f16") {
|
||||
return GGML_TYPE_F16;
|
||||
}
|
||||
if (s == "bf16") {
|
||||
return GGML_TYPE_BF16;
|
||||
}
|
||||
if (s == "q8_0") {
|
||||
return GGML_TYPE_Q8_0;
|
||||
}
|
||||
|
|
|
@ -3748,10 +3748,7 @@ class JaisModel(Model):
|
|||
|
||||
# Embeddings scale
|
||||
self.embeddings_scale = 1.0
|
||||
# note: For some JAIS flavors, output is tied to (same as) wte in original model
|
||||
self.output_is_wte = False
|
||||
if 'mup_embeddings_scale' in self.hparams:
|
||||
self.output_is_wte = True # Hack (?)
|
||||
self.embeddings_scale = self.hparams['mup_embeddings_scale']
|
||||
elif 'embeddings_scale' in self.hparams:
|
||||
self.embeddings_scale = self.hparams['embeddings_scale']
|
||||
|
@ -3808,10 +3805,7 @@ class JaisModel(Model):
|
|||
|
||||
if new_name == self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD):
|
||||
tensors.append((new_name, data_torch * self.embeddings_scale))
|
||||
if self.output_is_wte:
|
||||
tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT), data_torch * self.width_scale))
|
||||
elif new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT):
|
||||
assert not self.output_is_wte
|
||||
tensors.append((new_name, data_torch * self.width_scale))
|
||||
else:
|
||||
tensors.append((new_name, data_torch))
|
||||
|
|
|
@ -377,7 +377,7 @@ found 2 SYCL devices:
|
|||
|
||||
|Chosen Device ID|Setting|
|
||||
|-|-|
|
||||
|0|`export ONEAPI_DEVICE_SELECTOR="level_zero:1"` or no action|
|
||||
|0|`export ONEAPI_DEVICE_SELECTOR="level_zero:0"` or no action|
|
||||
|1|`export ONEAPI_DEVICE_SELECTOR="level_zero:1"`|
|
||||
|0 & 1|`export ONEAPI_DEVICE_SELECTOR="level_zero:0;level_zero:1"`|
|
||||
|
||||
|
|
|
@ -23,8 +23,9 @@ CUR_PROMPT_CACHE="${CHAT_SAVE_DIR}/current-cache.bin"
|
|||
NEXT_PROMPT_FILE="${CHAT_SAVE_DIR}/next-prompt.txt"
|
||||
NEXT_PROMPT_CACHE="${CHAT_SAVE_DIR}/next-cache.bin"
|
||||
|
||||
SESSION_SIZE_MSG_PATTERN='main: session file matches [[:digit:]]+ / [[:digit:]]+'
|
||||
SAMPLE_TIME_MSG_PATTERN='sample time =[[:space:]]+[[:digit:]]+.[[:digit:]]+ ms /[[:space:]]+[[:digit:]]+'
|
||||
SESSION_AND_SAMPLE_PATTERN='main: session file matches [[:digit:]]+ / [[:digit:]]+'\
|
||||
'|'\
|
||||
'sampling time =[[:space:]]+[[:digit:]]+.[[:digit:]]+ ms /[[:space:]]+[[:digit:]]+'
|
||||
SED_DELETE_MESSAGES="/^(${USER_NAME}:|${AI_NAME}:|\\.\\.\\.)/,\$d"
|
||||
|
||||
CTX_SIZE=2048
|
||||
|
@ -129,15 +130,12 @@ while read -e line; do
|
|||
|
||||
printf ' '
|
||||
|
||||
# HACK get num tokens from debug message
|
||||
# TODO get both messages in one go
|
||||
if ! session_size_msg="$(tail -n30 "$LOG" | grep -oE "$SESSION_SIZE_MSG_PATTERN")" ||
|
||||
! sample_time_msg="$(tail -n10 "$LOG" | grep -oE "$SAMPLE_TIME_MSG_PATTERN")"; then
|
||||
if ! session_and_sample_msg=$(tail -n30 "$LOG" | grep -oE "$SESSION_AND_SAMPLE_PATTERN"); then
|
||||
echo >&2 "Couldn't get number of tokens from ./llama-cli output!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
n_tokens=$(($(cut -d/ -f2 <<<"$session_size_msg") + $(cut -d/ -f2 <<<"$sample_time_msg")))
|
||||
n_tokens=$(awk '{sum+=$1} END {print sum}' <<< "$(cut -d/ -f2 <<< "$session_and_sample_msg")")
|
||||
|
||||
if ((n_tokens > CTX_ROTATE_POINT)); then
|
||||
tail -c+$((n_prompt_len_pre + 1)) "$CUR_PROMPT_FILE" >>"$NEXT_PROMPT_FILE"
|
||||
|
|
|
@ -256,6 +256,9 @@ static ggml_type ggml_type_from_name(const std::string & s) {
|
|||
if (s == "f16") {
|
||||
return GGML_TYPE_F16;
|
||||
}
|
||||
if (s == "bf16") {
|
||||
return GGML_TYPE_BF16;
|
||||
}
|
||||
if (s == "q8_0") {
|
||||
return GGML_TYPE_Q8_0;
|
||||
}
|
||||
|
|
|
@ -15,22 +15,13 @@ set(TARGET_SRCS
|
|||
httplib.h
|
||||
)
|
||||
set(PUBLIC_ASSETS
|
||||
colorthemes.css
|
||||
style.css
|
||||
theme-beeninorder.css
|
||||
theme-ketivah.css
|
||||
theme-mangotango.css
|
||||
theme-playground.css
|
||||
theme-polarnight.css
|
||||
theme-snowstorm.css
|
||||
index.html
|
||||
index-new.html
|
||||
index.js
|
||||
completion.js
|
||||
system-prompts.js
|
||||
prompt-formats.js
|
||||
json-schema-to-grammar.mjs
|
||||
loading.html
|
||||
deps_daisyui.min.css
|
||||
deps_markdown-it.js
|
||||
deps_tailwindcss.js
|
||||
deps_vue.esm-browser.js
|
||||
)
|
||||
|
||||
foreach(asset ${PUBLIC_ASSETS})
|
||||
|
|
|
@ -928,6 +928,16 @@ Apart from error types supported by OAI, we also have custom types that are spec
|
|||
}
|
||||
```
|
||||
|
||||
### Legacy completion web UI
|
||||
|
||||
A new chat-based UI has replaced the old completion-based since [this PR](https://github.com/ggerganov/llama.cpp/pull/10175). If you want to use the old completion, start the server with `--path ./examples/server/public_legacy`
|
||||
|
||||
For example:
|
||||
|
||||
```sh
|
||||
./llama-server -m my_model.gguf -c 8192 --path ./examples/server/public_legacy
|
||||
```
|
||||
|
||||
### Extending or building alternative Web Front End
|
||||
|
||||
You can extend the front end by running the server binary with `--path` set to `./your-directory` and importing `/completion.js` to get access to the llamaComplete() method.
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import * as readline from 'node:readline'
|
||||
import { stdin, stdout } from 'node:process'
|
||||
import { readFileSync } from 'node:fs'
|
||||
import { SchemaConverter } from './public/json-schema-to-grammar.mjs'
|
||||
import { SchemaConverter } from './public_legacy/json-schema-to-grammar.mjs'
|
||||
|
||||
const args = process.argv.slice(2);
|
||||
const grammarJsonSchemaFile = args.find(
|
||||
|
|
|
@ -6,5 +6,20 @@ DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
|
|||
PUBLIC=$DIR/public
|
||||
|
||||
echo "download js bundle files"
|
||||
curl https://npm.reversehttp.com/@preact/signals-core,@preact/signals,htm/preact,preact,preact/hooks > $PUBLIC/index.js
|
||||
echo >> $PUBLIC/index.js # add newline
|
||||
|
||||
# Note for contributors: Always pin to a specific version "maj.min.patch" to avoid breaking the CI
|
||||
|
||||
curl -L https://cdn.tailwindcss.com/3.4.14 > $PUBLIC/deps_tailwindcss.js
|
||||
echo >> $PUBLIC/deps_tailwindcss.js # add newline
|
||||
|
||||
curl -L https://cdnjs.cloudflare.com/ajax/libs/daisyui/4.12.14/styled.min.css > $PUBLIC/deps_daisyui.min.css
|
||||
curl -L https://cdnjs.cloudflare.com/ajax/libs/daisyui/4.12.14/themes.min.css >> $PUBLIC/deps_daisyui.min.css
|
||||
echo >> $PUBLIC/deps_daisyui.min.css # add newline
|
||||
|
||||
curl -L https://unpkg.com/vue@3.5.12/dist/vue.esm-browser.js > $PUBLIC/deps_vue.esm-browser.js
|
||||
echo >> $PUBLIC/deps_vue.esm-browser.js # add newline
|
||||
|
||||
curl -L https://cdnjs.cloudflare.com/ajax/libs/markdown-it/13.0.2/markdown-it.js > $PUBLIC/deps_markdown-it.js
|
||||
echo >> $PUBLIC/deps_markdown-it.js # add newline
|
||||
|
||||
ls -lah $PUBLIC
|
||||
|
|
|
@ -1,12 +1,16 @@
|
|||
const paramDefaults = {
|
||||
stream: true,
|
||||
n_predict: 500,
|
||||
temperature: 0.2,
|
||||
stop: ["</s>"]
|
||||
};
|
||||
|
||||
let generation_settings = null;
|
||||
|
||||
export class CompletionError extends Error {
|
||||
constructor(message, name, data) {
|
||||
super(message);
|
||||
this.name = name;
|
||||
}
|
||||
};
|
||||
|
||||
// Completes the prompt as a generator. Recommended for most use cases.
|
||||
//
|
||||
|
@ -29,7 +33,7 @@ export async function* llama(prompt, params = {}, config = {}) {
|
|||
|
||||
const completionParams = { ...paramDefaults, ...params, prompt };
|
||||
|
||||
const response = await fetch(`${api_url}/completion`, {
|
||||
const response = await fetch(`${api_url}${config.endpoint || '/completion'}`, {
|
||||
method: 'POST',
|
||||
body: JSON.stringify(completionParams),
|
||||
headers: {
|
||||
|
@ -41,6 +45,18 @@ export async function* llama(prompt, params = {}, config = {}) {
|
|||
signal: controller.signal,
|
||||
});
|
||||
|
||||
const status = response.status;
|
||||
if (status !== 200) {
|
||||
try {
|
||||
const body = await response.json();
|
||||
if (body && body.error && body.error.message) {
|
||||
throw new CompletionError(body.error.message, 'ServerError');
|
||||
}
|
||||
} catch (err) {
|
||||
throw new CompletionError(err.message, 'ServerError');
|
||||
}
|
||||
}
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
|
||||
|
@ -78,7 +94,12 @@ export async function* llama(prompt, params = {}, config = {}) {
|
|||
for (const line of lines) {
|
||||
const match = regex.exec(line);
|
||||
if (match) {
|
||||
result[match[1]] = match[2]
|
||||
result[match[1]] = match[2];
|
||||
if (result.data === '[DONE]') {
|
||||
cont = false;
|
||||
break;
|
||||
}
|
||||
|
||||
// since we know this is llama.cpp, let's just decode the json in data
|
||||
if (result.data) {
|
||||
result.data = JSON.parse(result.data);
|
||||
|
|
13
examples/server/public/deps_daisyui.min.css
vendored
Normal file
13
examples/server/public/deps_daisyui.min.css
vendored
Normal file
File diff suppressed because one or more lines are too long
8442
examples/server/public/deps_markdown-it.js
Normal file
8442
examples/server/public/deps_markdown-it.js
Normal file
File diff suppressed because it is too large
Load diff
82
examples/server/public/deps_tailwindcss.js
Normal file
82
examples/server/public/deps_tailwindcss.js
Normal file
File diff suppressed because one or more lines are too long
18160
examples/server/public/deps_vue.esm-browser.js
Normal file
18160
examples/server/public/deps_vue.esm-browser.js
Normal file
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
209
examples/server/public_legacy/completion.js
Normal file
209
examples/server/public_legacy/completion.js
Normal file
|
@ -0,0 +1,209 @@
|
|||
const paramDefaults = {
|
||||
stream: true,
|
||||
n_predict: 500,
|
||||
temperature: 0.2,
|
||||
stop: ["</s>"]
|
||||
};
|
||||
|
||||
let generation_settings = null;
|
||||
|
||||
|
||||
// Completes the prompt as a generator. Recommended for most use cases.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// import { llama } from '/completion.js'
|
||||
//
|
||||
// const request = llama("Tell me a joke", {n_predict: 800})
|
||||
// for await (const chunk of request) {
|
||||
// document.write(chunk.data.content)
|
||||
// }
|
||||
//
|
||||
export async function* llama(prompt, params = {}, config = {}) {
|
||||
let controller = config.controller;
|
||||
const api_url = config.api_url?.replace(/\/+$/, '') || "";
|
||||
|
||||
if (!controller) {
|
||||
controller = new AbortController();
|
||||
}
|
||||
|
||||
const completionParams = { ...paramDefaults, ...params, prompt };
|
||||
|
||||
const response = await fetch(`${api_url}${config.endpoint || '/completion'}`, {
|
||||
method: 'POST',
|
||||
body: JSON.stringify(completionParams),
|
||||
headers: {
|
||||
'Connection': 'keep-alive',
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'text/event-stream',
|
||||
...(params.api_key ? {'Authorization': `Bearer ${params.api_key}`} : {})
|
||||
},
|
||||
signal: controller.signal,
|
||||
});
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
|
||||
let content = "";
|
||||
let leftover = ""; // Buffer for partially read lines
|
||||
|
||||
try {
|
||||
let cont = true;
|
||||
|
||||
while (cont) {
|
||||
const result = await reader.read();
|
||||
if (result.done) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Add any leftover data to the current chunk of data
|
||||
const text = leftover + decoder.decode(result.value);
|
||||
|
||||
// Check if the last character is a line break
|
||||
const endsWithLineBreak = text.endsWith('\n');
|
||||
|
||||
// Split the text into lines
|
||||
let lines = text.split('\n');
|
||||
|
||||
// If the text doesn't end with a line break, then the last line is incomplete
|
||||
// Store it in leftover to be added to the next chunk of data
|
||||
if (!endsWithLineBreak) {
|
||||
leftover = lines.pop();
|
||||
} else {
|
||||
leftover = ""; // Reset leftover if we have a line break at the end
|
||||
}
|
||||
|
||||
// Parse all sse events and add them to result
|
||||
const regex = /^(\S+):\s(.*)$/gm;
|
||||
for (const line of lines) {
|
||||
const match = regex.exec(line);
|
||||
if (match) {
|
||||
result[match[1]] = match[2];
|
||||
if (result.data === '[DONE]') {
|
||||
cont = false;
|
||||
break;
|
||||
}
|
||||
|
||||
// since we know this is llama.cpp, let's just decode the json in data
|
||||
if (result.data) {
|
||||
result.data = JSON.parse(result.data);
|
||||
content += result.data.content;
|
||||
|
||||
// yield
|
||||
yield result;
|
||||
|
||||
// if we got a stop token from server, we will break here
|
||||
if (result.data.stop) {
|
||||
if (result.data.generation_settings) {
|
||||
generation_settings = result.data.generation_settings;
|
||||
}
|
||||
cont = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (result.error) {
|
||||
try {
|
||||
result.error = JSON.parse(result.error);
|
||||
if (result.error.message.includes('slot unavailable')) {
|
||||
// Throw an error to be caught by upstream callers
|
||||
throw new Error('slot unavailable');
|
||||
} else {
|
||||
console.error(`llama.cpp error [${result.error.code} - ${result.error.type}]: ${result.error.message}`);
|
||||
}
|
||||
} catch(e) {
|
||||
console.error(`llama.cpp error ${result.error}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
if (e.name !== 'AbortError') {
|
||||
console.error("llama error: ", e);
|
||||
}
|
||||
throw e;
|
||||
}
|
||||
finally {
|
||||
controller.abort();
|
||||
}
|
||||
|
||||
return content;
|
||||
}
|
||||
|
||||
// Call llama, return an event target that you can subscribe to
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// import { llamaEventTarget } from '/completion.js'
|
||||
//
|
||||
// const conn = llamaEventTarget(prompt)
|
||||
// conn.addEventListener("message", (chunk) => {
|
||||
// document.write(chunk.detail.content)
|
||||
// })
|
||||
//
|
||||
export const llamaEventTarget = (prompt, params = {}, config = {}) => {
|
||||
const eventTarget = new EventTarget();
|
||||
(async () => {
|
||||
let content = "";
|
||||
for await (const chunk of llama(prompt, params, config)) {
|
||||
if (chunk.data) {
|
||||
content += chunk.data.content;
|
||||
eventTarget.dispatchEvent(new CustomEvent("message", { detail: chunk.data }));
|
||||
}
|
||||
if (chunk.data.generation_settings) {
|
||||
eventTarget.dispatchEvent(new CustomEvent("generation_settings", { detail: chunk.data.generation_settings }));
|
||||
}
|
||||
if (chunk.data.timings) {
|
||||
eventTarget.dispatchEvent(new CustomEvent("timings", { detail: chunk.data.timings }));
|
||||
}
|
||||
}
|
||||
eventTarget.dispatchEvent(new CustomEvent("done", { detail: { content } }));
|
||||
})();
|
||||
return eventTarget;
|
||||
}
|
||||
|
||||
// Call llama, return a promise that resolves to the completed text. This does not support streaming
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// llamaPromise(prompt).then((content) => {
|
||||
// document.write(content)
|
||||
// })
|
||||
//
|
||||
// or
|
||||
//
|
||||
// const content = await llamaPromise(prompt)
|
||||
// document.write(content)
|
||||
//
|
||||
export const llamaPromise = (prompt, params = {}, config = {}) => {
|
||||
return new Promise(async (resolve, reject) => {
|
||||
let content = "";
|
||||
try {
|
||||
for await (const chunk of llama(prompt, params, config)) {
|
||||
content += chunk.data.content;
|
||||
}
|
||||
resolve(content);
|
||||
} catch (error) {
|
||||
reject(error);
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* (deprecated)
|
||||
*/
|
||||
export const llamaComplete = async (params, controller, callback) => {
|
||||
for await (const chunk of llama(params.prompt, params, { controller })) {
|
||||
callback(chunk);
|
||||
}
|
||||
}
|
||||
|
||||
// Get the model info from the server. This is useful for getting the context window and so on.
|
||||
export const llamaModelInfo = async (config = {}) => {
|
||||
if (!generation_settings) {
|
||||
const api_url = config.api_url?.replace(/\/+$/, '') || "";
|
||||
const props = await fetch(`${api_url}/props`).then(r => r.json());
|
||||
generation_settings = props.default_generation_settings;
|
||||
}
|
||||
return generation_settings;
|
||||
}
|
Before Width: | Height: | Size: 4 KiB After Width: | Height: | Size: 4 KiB |
1303
examples/server/public_legacy/index.html
Normal file
1303
examples/server/public_legacy/index.html
Normal file
File diff suppressed because it is too large
Load diff
12
examples/server/public_legacy/loading.html
Normal file
12
examples/server/public_legacy/loading.html
Normal file
|
@ -0,0 +1,12 @@
|
|||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta http-equiv="refresh" content="5">
|
||||
</head>
|
||||
<body>
|
||||
<div id="loading">
|
||||
The model is loading. Please wait.<br/>
|
||||
The user interface will appear soon.
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
|
@ -14,22 +14,13 @@
|
|||
#define MIMETYPE_JSON "application/json; charset=utf-8"
|
||||
|
||||
// auto generated files (update with ./deps.sh)
|
||||
#include "colorthemes.css.hpp"
|
||||
#include "style.css.hpp"
|
||||
#include "theme-beeninorder.css.hpp"
|
||||
#include "theme-ketivah.css.hpp"
|
||||
#include "theme-mangotango.css.hpp"
|
||||
#include "theme-playground.css.hpp"
|
||||
#include "theme-polarnight.css.hpp"
|
||||
#include "theme-snowstorm.css.hpp"
|
||||
#include "index.html.hpp"
|
||||
#include "index-new.html.hpp"
|
||||
#include "index.js.hpp"
|
||||
#include "completion.js.hpp"
|
||||
#include "system-prompts.js.hpp"
|
||||
#include "prompt-formats.js.hpp"
|
||||
#include "json-schema-to-grammar.mjs.hpp"
|
||||
#include "loading.html.hpp"
|
||||
#include "deps_daisyui.min.css.hpp"
|
||||
#include "deps_markdown-it.js.hpp"
|
||||
#include "deps_tailwindcss.js.hpp"
|
||||
#include "deps_vue.esm-browser.js.hpp"
|
||||
|
||||
#include <atomic>
|
||||
#include <condition_variable>
|
||||
|
@ -378,8 +369,8 @@ struct server_queue {
|
|||
std::condition_variable condition_tasks;
|
||||
|
||||
// callback functions
|
||||
std::function<void(server_task&)> callback_new_task;
|
||||
std::function<void(void)> callback_update_slots;
|
||||
std::function<void(server_task)> callback_new_task;
|
||||
std::function<void(void)> callback_update_slots;
|
||||
|
||||
// Add a new task to the end of the queue
|
||||
int post(server_task task, bool front = false) {
|
||||
|
@ -431,7 +422,7 @@ struct server_queue {
|
|||
}
|
||||
|
||||
// Register function to process a new task
|
||||
void on_new_task(std::function<void(server_task &)> callback) {
|
||||
void on_new_task(std::function<void(server_task)> callback) {
|
||||
callback_new_task = std::move(callback);
|
||||
}
|
||||
|
||||
|
@ -481,7 +472,7 @@ struct server_queue {
|
|||
lock.unlock();
|
||||
|
||||
QUE_DBG("processing task, id = %d\n", task.id);
|
||||
callback_new_task(task);
|
||||
callback_new_task(std::move(task));
|
||||
}
|
||||
|
||||
// all tasks in the current loop is processed, slots data is now ready
|
||||
|
@ -644,17 +635,12 @@ struct server_context {
|
|||
bool load_model(const common_params & params_) {
|
||||
params = params_;
|
||||
|
||||
// reserve one extra sequence (seq_id == 0) for extra features
|
||||
params.n_parallel += 1;
|
||||
|
||||
common_init_result llama_init = common_init_from_params(params);
|
||||
|
||||
model = llama_init.model;
|
||||
ctx = llama_init.context;
|
||||
loras = llama_init.lora_adapters;
|
||||
|
||||
params.n_parallel -= 1; // but be sneaky about it
|
||||
|
||||
if (model == nullptr) {
|
||||
SRV_ERR("failed to load model, '%s'\n", params.model.c_str());
|
||||
return false;
|
||||
|
@ -1288,16 +1274,16 @@ struct server_context {
|
|||
|
||||
void send_embedding(const server_slot & slot, const llama_batch & batch) {
|
||||
server_task_result res;
|
||||
res.id = slot.id_task;
|
||||
res.error = false;
|
||||
res.stop = true;
|
||||
res.id = slot.id_task;
|
||||
res.error = false;
|
||||
res.stop = true;
|
||||
|
||||
const int n_embd = llama_n_embd(model);
|
||||
|
||||
std::vector<float> embd_res(n_embd, 0.0f);
|
||||
|
||||
for (int i = 0; i < batch.n_tokens; ++i) {
|
||||
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
|
||||
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -1332,12 +1318,12 @@ struct server_context {
|
|||
|
||||
void send_rerank(const server_slot & slot, const llama_batch & batch) {
|
||||
server_task_result res;
|
||||
res.id = slot.id_task;
|
||||
res.error = false;
|
||||
res.stop = true;
|
||||
res.id = slot.id_task;
|
||||
res.error = false;
|
||||
res.stop = true;
|
||||
|
||||
for (int i = 0; i < batch.n_tokens; ++i) {
|
||||
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
|
||||
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -1510,7 +1496,7 @@ struct server_context {
|
|||
// Functions to process the task
|
||||
//
|
||||
|
||||
void process_single_task(const server_task & task) {
|
||||
void process_single_task(server_task task) {
|
||||
switch (task.type) {
|
||||
case SERVER_TASK_TYPE_INFERENCE:
|
||||
{
|
||||
|
@ -1646,7 +1632,7 @@ struct server_context {
|
|||
std::string filename = task.data.at("filename");
|
||||
std::string filepath = task.data.at("filepath");
|
||||
|
||||
const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), token_count);
|
||||
const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);
|
||||
|
||||
const int64_t t_end = ggml_time_us();
|
||||
const double t_save_ms = (t_end - t_start) / 1000.0;
|
||||
|
@ -1688,7 +1674,7 @@ struct server_context {
|
|||
|
||||
slot->cache_tokens.resize(slot->n_ctx);
|
||||
size_t token_count = 0;
|
||||
size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
|
||||
size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
|
||||
if (nread == 0) {
|
||||
slot->cache_tokens.resize(0);
|
||||
send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
|
||||
|
@ -1731,7 +1717,7 @@ struct server_context {
|
|||
|
||||
// Erase token cache
|
||||
const size_t n_erased = slot->cache_tokens.size();
|
||||
llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1);
|
||||
llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
|
||||
slot->cache_tokens.clear();
|
||||
|
||||
server_task_result result;
|
||||
|
@ -1808,8 +1794,8 @@ struct server_context {
|
|||
|
||||
SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
|
||||
|
||||
llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard);
|
||||
llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, slot.n_past, -n_discard);
|
||||
llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
|
||||
llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
|
||||
|
||||
if (slot.params.cache_prompt) {
|
||||
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
|
||||
|
@ -1836,7 +1822,7 @@ struct server_context {
|
|||
|
||||
slot.i_batch = batch.n_tokens;
|
||||
|
||||
common_batch_add(batch, slot.sampled, slot.n_past, { slot.id + 1 }, true);
|
||||
common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
|
||||
|
||||
slot.n_past += 1;
|
||||
|
||||
|
@ -1983,8 +1969,8 @@ struct server_context {
|
|||
|
||||
const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
|
||||
|
||||
llama_kv_cache_seq_rm (ctx, slot.id + 1, head_p, head_c);
|
||||
llama_kv_cache_seq_add(ctx, slot.id + 1, head_c, -1, kv_shift);
|
||||
llama_kv_cache_seq_rm (ctx, slot.id, head_p, head_c);
|
||||
llama_kv_cache_seq_add(ctx, slot.id, head_c, -1, kv_shift);
|
||||
|
||||
for (size_t i = 0; i < n_match; i++) {
|
||||
slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
|
||||
|
@ -2033,9 +2019,9 @@ struct server_context {
|
|||
}
|
||||
|
||||
// keep only the common part
|
||||
if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, slot.n_past, -1)) {
|
||||
if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) {
|
||||
// could not partially delete (likely using a non-Transformer model)
|
||||
llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
|
||||
llama_kv_cache_seq_rm(ctx, slot.id, -1, -1);
|
||||
|
||||
// there is no common part left
|
||||
slot.n_past = 0;
|
||||
|
@ -2048,7 +2034,7 @@ struct server_context {
|
|||
|
||||
// add prompt tokens for processing in the current batch
|
||||
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
|
||||
common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id + 1 }, false);
|
||||
common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, false);
|
||||
|
||||
if (slot.params.cache_prompt) {
|
||||
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
|
||||
|
@ -2290,16 +2276,6 @@ int main(int argc, char ** argv) {
|
|||
std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
|
||||
|
||||
svr->set_default_headers({{"Server", "llama.cpp"}});
|
||||
|
||||
// CORS preflight
|
||||
svr->Options(R"(.*)", [](const httplib::Request &, httplib::Response & res) {
|
||||
// Access-Control-Allow-Origin is already set by middleware
|
||||
res.set_header("Access-Control-Allow-Credentials", "true");
|
||||
res.set_header("Access-Control-Allow-Methods", "POST");
|
||||
res.set_header("Access-Control-Allow-Headers", "*");
|
||||
return res.set_content("", "text/html"); // blank response, no data
|
||||
});
|
||||
|
||||
svr->set_logger(log_server_request);
|
||||
|
||||
auto res_error = [](httplib::Response & res, const json & error_data) {
|
||||
|
@ -2412,6 +2388,14 @@ int main(int argc, char ** argv) {
|
|||
// register server middlewares
|
||||
svr->set_pre_routing_handler([&middleware_validate_api_key, &middleware_server_state](const httplib::Request & req, httplib::Response & res) {
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
// If this is OPTIONS request, skip validation because browsers don't include Authorization header
|
||||
if (req.method == "OPTIONS") {
|
||||
res.set_header("Access-Control-Allow-Credentials", "true");
|
||||
res.set_header("Access-Control-Allow-Methods", "GET, POST");
|
||||
res.set_header("Access-Control-Allow-Headers", "*");
|
||||
res.set_content("", "text/html"); // blank response, no data
|
||||
return httplib::Server::HandlerResponse::Handled; // skip further processing
|
||||
}
|
||||
if (!middleware_server_state(req, res)) {
|
||||
return httplib::Server::HandlerResponse::Handled;
|
||||
}
|
||||
|
@ -3121,33 +3105,19 @@ int main(int argc, char ** argv) {
|
|||
// register static assets routes
|
||||
if (!params.public_path.empty()) {
|
||||
// Set the base directory for serving static files
|
||||
svr->set_base_dir(params.public_path);
|
||||
}
|
||||
|
||||
if (!params.api_keys.empty()) {
|
||||
// for now, if API key is set, web UI is unusable
|
||||
svr->Get("/", [&](const httplib::Request &, httplib::Response & res) {
|
||||
return res.set_content("Web UI is disabled because API key is set.", "text/html; charset=utf-8");
|
||||
});
|
||||
bool is_found = svr->set_mount_point("/", params.public_path);
|
||||
if (!is_found) {
|
||||
LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str());
|
||||
return 1;
|
||||
}
|
||||
} else {
|
||||
// using embedded static files
|
||||
svr->Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8"));
|
||||
svr->Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8"));
|
||||
svr->Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8"));
|
||||
svr->Get("/json-schema-to-grammar.mjs", handle_static_file(json_schema_to_grammar_mjs, json_schema_to_grammar_mjs_len, "text/javascript; charset=utf-8"));
|
||||
|
||||
// add new-ui files
|
||||
svr->Get("/colorthemes.css", handle_static_file(colorthemes_css, colorthemes_css_len, "text/css; charset=utf-8"));
|
||||
svr->Get("/style.css", handle_static_file(style_css, style_css_len, "text/css; charset=utf-8"));
|
||||
svr->Get("/theme-beeninorder.css", handle_static_file(theme_beeninorder_css, theme_beeninorder_css_len, "text/css; charset=utf-8"));
|
||||
svr->Get("/theme-ketivah.css", handle_static_file(theme_ketivah_css, theme_ketivah_css_len, "text/css; charset=utf-8"));
|
||||
svr->Get("/theme-mangotango.css", handle_static_file(theme_mangotango_css, theme_mangotango_css_len, "text/css; charset=utf-8"));
|
||||
svr->Get("/theme-playground.css", handle_static_file(theme_playground_css, theme_playground_css_len, "text/css; charset=utf-8"));
|
||||
svr->Get("/theme-polarnight.css", handle_static_file(theme_polarnight_css, theme_polarnight_css_len, "text/css; charset=utf-8"));
|
||||
svr->Get("/theme-snowstorm.css", handle_static_file(theme_snowstorm_css, theme_snowstorm_css_len, "text/css; charset=utf-8"));
|
||||
svr->Get("/index-new.html", handle_static_file(index_new_html, index_new_html_len, "text/html; charset=utf-8"));
|
||||
svr->Get("/system-prompts.js", handle_static_file(system_prompts_js, system_prompts_js_len, "text/javascript; charset=utf-8"));
|
||||
svr->Get("/prompt-formats.js", handle_static_file(prompt_formats_js, prompt_formats_js_len, "text/javascript; charset=utf-8"));
|
||||
svr->Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8"));
|
||||
svr->Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8"));
|
||||
svr->Get("/deps_daisyui.min.css", handle_static_file(deps_daisyui_min_css, deps_daisyui_min_css_len, "text/css; charset=utf-8"));
|
||||
svr->Get("/deps_markdown-it.js", handle_static_file(deps_markdown_it_js, deps_markdown_it_js_len, "text/javascript; charset=utf-8"));
|
||||
svr->Get("/deps_tailwindcss.js", handle_static_file(deps_tailwindcss_js, deps_tailwindcss_js_len, "text/javascript; charset=utf-8"));
|
||||
svr->Get("/deps_vue.esm-browser.js", handle_static_file(deps_vue_esm_browser_js, deps_vue_esm_browser_js_len, "text/javascript; charset=utf-8"));
|
||||
}
|
||||
|
||||
// register API routes
|
||||
|
|
|
@ -64,5 +64,5 @@ Feature: Security
|
|||
| localhost | Access-Control-Allow-Origin | localhost |
|
||||
| web.mydomain.fr | Access-Control-Allow-Origin | web.mydomain.fr |
|
||||
| origin | Access-Control-Allow-Credentials | true |
|
||||
| web.mydomain.fr | Access-Control-Allow-Methods | POST |
|
||||
| web.mydomain.fr | Access-Control-Allow-Methods | GET, POST |
|
||||
| web.mydomain.fr | Access-Control-Allow-Headers | * |
|
||||
|
|
6
flake.lock
generated
6
flake.lock
generated
|
@ -20,11 +20,11 @@
|
|||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1730200266,
|
||||
"narHash": "sha256-l253w0XMT8nWHGXuXqyiIC/bMvh1VRszGXgdpQlfhvU=",
|
||||
"lastModified": 1730785428,
|
||||
"narHash": "sha256-Zwl8YgTVJTEum+L+0zVAWvXAGbWAuXHax3KzuejaDyo=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "807e9154dcb16384b1b765ebe9cd2bba2ac287fd",
|
||||
"rev": "4aa36568d413aca0ea84a1684d2d46f55dbabad7",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
|
|
@ -153,6 +153,7 @@ option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation"
|
|||
option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF)
|
||||
option(GGML_KOMPUTE "ggml: use Kompute" OFF)
|
||||
option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT})
|
||||
option(GGML_METAL_USE_BF16 "ggml: use bfloat if available" OFF)
|
||||
option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF)
|
||||
option(GGML_METAL_SHADER_DEBUG "ggml: compile Metal with -fno-fast-math" OFF)
|
||||
option(GGML_METAL_EMBED_LIBRARY "ggml: embed Metal library" ${GGML_METAL})
|
||||
|
@ -218,12 +219,12 @@ include(CMakePackageConfigHelpers)
|
|||
# all public headers
|
||||
set(GGML_PUBLIC_HEADERS
|
||||
include/ggml.h
|
||||
include/ggml-cpu.h
|
||||
include/ggml-alloc.h
|
||||
include/ggml-backend.h
|
||||
include/ggml-blas.h
|
||||
include/ggml-cann.h
|
||||
include/ggml-cuda.h
|
||||
include/ggml.h
|
||||
include/ggml-kompute.h
|
||||
include/ggml-metal.h
|
||||
include/ggml-rpc.h
|
||||
|
|
|
@ -509,7 +509,7 @@ extern "C" {
|
|||
GGML_OP_WIN_UNPART,
|
||||
GGML_OP_GET_REL_POS,
|
||||
GGML_OP_ADD_REL_POS,
|
||||
GGML_OP_RWKV_WKV,
|
||||
GGML_OP_RWKV_WKV6,
|
||||
|
||||
GGML_OP_UNARY,
|
||||
|
||||
|
@ -1746,6 +1746,9 @@ extern "C" {
|
|||
struct ggml_tensor * a,
|
||||
enum ggml_prec prec);
|
||||
|
||||
GGML_API enum ggml_prec ggml_flash_attn_ext_get_prec(
|
||||
const struct ggml_tensor * a);
|
||||
|
||||
// TODO: needs to be adapted to ggml_flash_attn_ext
|
||||
GGML_API struct ggml_tensor * ggml_flash_attn_back(
|
||||
struct ggml_context * ctx,
|
||||
|
@ -1819,7 +1822,7 @@ extern "C" {
|
|||
struct ggml_tensor * pw,
|
||||
struct ggml_tensor * ph);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_rwkv_wkv(
|
||||
GGML_API struct ggml_tensor * ggml_rwkv_wkv6(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * k,
|
||||
struct ggml_tensor * v,
|
||||
|
|
|
@ -58,6 +58,10 @@ if (GGML_METAL)
|
|||
add_compile_definitions(GGML_METAL_NDEBUG)
|
||||
endif()
|
||||
|
||||
if (GGML_METAL_USE_BF16)
|
||||
add_compile_definitions(GGML_METAL_USE_BF16)
|
||||
endif()
|
||||
|
||||
# copy ggml-common.h and ggml-metal.metal to bin directory
|
||||
configure_file(ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY)
|
||||
configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
|
||||
|
@ -1261,8 +1265,13 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
|
|||
endif()
|
||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
|
||||
message(STATUS "PowerPC detected")
|
||||
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
|
||||
list(APPEND ARCH_FLAGS -mcpu=powerpc64le)
|
||||
execute_process(COMMAND bash -c "grep POWER10 /proc/cpuinfo | head -n 1"
|
||||
OUTPUT_VARIABLE POWER10_M)
|
||||
string(FIND ${POWER10_M} "POWER10" substring_index)
|
||||
if(${substring_index} GREATER_EQUAL 0)
|
||||
list(APPEND ARCH_FLAGS -mcpu=power10)
|
||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
|
||||
list(APPEND ARCH_FLAGS -mcpu=powerpc64le)
|
||||
else()
|
||||
list(APPEND ARCH_FLAGS -mcpu=native -mtune=native)
|
||||
#TODO: Add targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be)
|
||||
|
|
|
@ -409,6 +409,8 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
|
|||
.gemm = ggml_gemm_q4_0_4x8_q8_0,
|
||||
},
|
||||
[GGML_TYPE_Q4_0_8_8] = {
|
||||
.vec_dot = NULL,
|
||||
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||
.nrows = 1,
|
||||
.ncols = 8,
|
||||
.gemv = ggml_gemv_q4_0_8x8_q8_0,
|
||||
|
@ -11642,24 +11644,30 @@ static void ggml_compute_forward_add_rel_pos(
|
|||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_rwkv_wkv
|
||||
// ggml_compute_forward_rwkv_wkv6
|
||||
|
||||
static void ggml_compute_forward_rwkv_wkv_f32(
|
||||
static void ggml_compute_forward_rwkv_wkv6_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
const size_t T = dst->src[1]->ne[3];
|
||||
const size_t C = dst->ne[0];
|
||||
const size_t H = dst->src[1]->ne[2];
|
||||
const size_t n_seqs = dst->src[5]->ne[1];
|
||||
const int64_t T = dst->src[1]->ne[3];
|
||||
const int64_t C = dst->ne[0];
|
||||
const int64_t HEADS = dst->src[1]->ne[2];
|
||||
const int64_t n_seqs = dst->src[5]->ne[1];
|
||||
const int64_t head_size = C / HEADS;
|
||||
|
||||
float * dst_data = (float *) dst->data;
|
||||
float * state = ((float *) dst->data) + C * T;
|
||||
|
||||
if (params->ith != 0) {
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
if (ith >= HEADS) {
|
||||
return;
|
||||
}
|
||||
|
||||
memset(dst_data, 0, T * C * sizeof(float));
|
||||
const int h_start = (HEADS * ith) / nth;
|
||||
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
|
||||
(HEADS * (ith + 1)) / nth : HEADS;
|
||||
|
||||
float * k = (float *) dst->src[0]->data;
|
||||
float * v = (float *) dst->src[1]->data;
|
||||
|
@ -11667,54 +11675,160 @@ static void ggml_compute_forward_rwkv_wkv_f32(
|
|||
float * time_faaaa = (float *) dst->src[3]->data;
|
||||
float * time_decay = (float *) dst->src[4]->data;
|
||||
|
||||
size_t t_stride = H * (C / H);
|
||||
size_t t_stride = HEADS * head_size; // Same to C
|
||||
|
||||
size_t h_stride = C / H;
|
||||
size_t h_stride_2d = (C / H) * (C / H);
|
||||
size_t h_stride = C / HEADS;
|
||||
GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
|
||||
size_t h_stride_2d = head_size * head_size;
|
||||
|
||||
// basically fused operations:
|
||||
// dst = r @ (time_faaaa * (k @ v) + state),
|
||||
// state = time_decay * state + (k @ v),
|
||||
// recursive through each token
|
||||
for (size_t t = 0; t < T; t++) {
|
||||
size_t t_offset = t * t_stride;
|
||||
size_t state_offset = (C / H) * C * (t / (T / n_seqs));
|
||||
float * state_cur = state + state_offset;
|
||||
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
|
||||
if (ith == 0) {
|
||||
memset(dst_data, 0, T * C * sizeof(float));
|
||||
}
|
||||
ggml_barrier(params->threadpool);
|
||||
|
||||
for (size_t h = 0; h < H; h++) {
|
||||
size_t h_offset = h * h_stride;
|
||||
size_t t_h_offset = t_offset + h_offset;
|
||||
size_t h_2d_offset = h * h_stride_2d;
|
||||
|
||||
for (size_t i = 0; i < C / H; i++) {
|
||||
size_t t_h_i_offset = t_h_offset + i;
|
||||
size_t h_i_offset = h_offset + i;
|
||||
size_t h_2d_i_offset = h_2d_offset + i * h_stride;
|
||||
#if defined(__AVX__) && !defined(__AVX512F__)
|
||||
#define GGML_F32X GGML_F32x8
|
||||
#define GGML_F32X_SET1 GGML_F32x8_SET1
|
||||
#define GGML_F32X_LOAD GGML_F32x8_LOAD
|
||||
#define GGML_F32X_STORE GGML_F32x8_STORE
|
||||
#define GGML_F32X_MUL GGML_F32x8_MUL
|
||||
#define GGML_F32X_FMA GGML_F32x8_FMA
|
||||
#define WKV_VECTOR_SIZE 8
|
||||
#elif defined(__AVX512F__)
|
||||
#define GGML_F32X GGML_F32x16
|
||||
#define GGML_F32X_SET1 GGML_F32x16_SET1
|
||||
#define GGML_F32X_LOAD GGML_F32x16_LOAD
|
||||
#define GGML_F32X_STORE GGML_F32x16_STORE
|
||||
#define GGML_F32X_MUL GGML_F32x16_MUL
|
||||
#define GGML_F32X_FMA GGML_F32x16_FMA
|
||||
#define WKV_VECTOR_SIZE 16
|
||||
#elif defined(__ARM_NEON) && defined(__aarch64__)
|
||||
#define GGML_F32X GGML_F32x4
|
||||
#define GGML_F32X_SET1 GGML_F32x4_SET1
|
||||
#define GGML_F32X_LOAD GGML_F32x4_LOAD
|
||||
#define GGML_F32X_STORE GGML_F32x4_STORE
|
||||
#define GGML_F32X_MUL GGML_F32x4_MUL
|
||||
#define GGML_F32X_FMA GGML_F32x4_FMA
|
||||
#define WKV_VECTOR_SIZE 4
|
||||
#endif
|
||||
|
||||
float k_val = k[t_h_i_offset];
|
||||
float r_val = r[t_h_i_offset];
|
||||
float time_faaaa_val = time_faaaa[h_i_offset];
|
||||
// RWKV v6: different time_decay for each token.
|
||||
float time_decay_val = time_decay[t_h_i_offset];
|
||||
#ifdef WKV_VECTOR_SIZE
|
||||
const int64_t vec_count = head_size / WKV_VECTOR_SIZE;
|
||||
|
||||
for (size_t j = 0; j < C / H; j ++) {
|
||||
size_t t_h_j_offset = t_h_offset + j;
|
||||
size_t h_2d_i_j_offset = h_2d_i_offset + j;
|
||||
for (int64_t t = 0; t < T; t++) {
|
||||
size_t t_offset = t * t_stride;
|
||||
size_t state_offset = head_size * C * (t / (T / n_seqs));
|
||||
float * state_cur = state + state_offset;
|
||||
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
|
||||
|
||||
float v_val = v[t_h_j_offset];
|
||||
float kv_val = v_val * k_val;
|
||||
float prev_state_val = state_prev[h_2d_i_j_offset];
|
||||
float temp_val = kv_val * time_faaaa_val + prev_state_val;
|
||||
dst_data[t_h_j_offset] += temp_val * r_val;
|
||||
state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
|
||||
for (int64_t h = h_start; h < h_end; h++) {
|
||||
size_t h_offset = h * h_stride;
|
||||
size_t t_h_offset = t_offset + h_offset;
|
||||
size_t h_2d_offset = h * h_stride_2d;
|
||||
|
||||
for (int64_t i = 0; i < head_size; i++) {
|
||||
size_t t_h_i_offset = t_h_offset + i;
|
||||
size_t h_i_offset = h_offset + i;
|
||||
size_t h_2d_i_offset = h_2d_offset + i * h_stride;
|
||||
|
||||
float k_val = k[t_h_i_offset];
|
||||
float r_val = r[t_h_i_offset];
|
||||
float time_faaaa_val = time_faaaa[h_i_offset];
|
||||
float time_decay_val = time_decay[t_h_i_offset];
|
||||
|
||||
// Broadcast scalar values to vectors
|
||||
GGML_F32X k_vec = GGML_F32X_SET1(k_val);
|
||||
GGML_F32X r_vec = GGML_F32X_SET1(r_val);
|
||||
GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val);
|
||||
GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);
|
||||
|
||||
for (int64_t j = 0; j < vec_count; j++) {
|
||||
size_t base_j = j * WKV_VECTOR_SIZE;
|
||||
size_t t_h_j_offset = t_h_offset + base_j;
|
||||
size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
|
||||
|
||||
// Load x elements at once
|
||||
GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
|
||||
GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
|
||||
GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
|
||||
|
||||
// Compute kv = v * k
|
||||
GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
|
||||
|
||||
// Compute temp = kv * time_faaaa + prev_state
|
||||
GGML_F32X temp_vec = GGML_F32X_FMA(prev_state_vec, kv_vec, time_faaaa_vec);
|
||||
|
||||
// Update dst: dst += temp * r
|
||||
dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, r_vec);
|
||||
GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
|
||||
|
||||
// Update state: state = prev_state * time_decay + kv
|
||||
GGML_F32X new_state_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, time_decay_vec);
|
||||
GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], new_state_vec);
|
||||
}
|
||||
|
||||
// Handle remaining elements, this will not be used.
|
||||
for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) {
|
||||
size_t t_h_j_offset = t_h_offset + j;
|
||||
size_t h_2d_i_j_offset = h_2d_i_offset + j;
|
||||
float v_val = v[t_h_j_offset];
|
||||
float kv_val = v_val * k_val;
|
||||
float prev_state_val = state_prev[h_2d_i_j_offset];
|
||||
float temp_val = kv_val * time_faaaa_val + prev_state_val;
|
||||
dst_data[t_h_j_offset] += temp_val * r_val;
|
||||
state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
// basically fused operations:
|
||||
// dst = r @ (time_faaaa * (k @ v) + state),
|
||||
// state = time_decay * state + (k @ v),
|
||||
// recursive through each token
|
||||
for (int64_t t = 0; t < T; t++) {
|
||||
size_t t_offset = t * t_stride;
|
||||
size_t state_offset = head_size * C * (t / (T / n_seqs));
|
||||
float * state_cur = state + state_offset;
|
||||
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
|
||||
|
||||
for (int64_t h = h_start; h < h_end; h++) {
|
||||
size_t h_offset = h * h_stride;
|
||||
size_t t_h_offset = t_offset + h_offset;
|
||||
size_t h_2d_offset = h * h_stride_2d;
|
||||
|
||||
for (int64_t i = 0; i < head_size; i++) {
|
||||
size_t t_h_i_offset = t_h_offset + i;
|
||||
size_t h_i_offset = h_offset + i;
|
||||
size_t h_2d_i_offset = h_2d_offset + i * h_stride;
|
||||
|
||||
float k_val = k[t_h_i_offset];
|
||||
float r_val = r[t_h_i_offset];
|
||||
float time_faaaa_val = time_faaaa[h_i_offset];
|
||||
// RWKV v6: different time_decay for each token.
|
||||
float time_decay_val = time_decay[t_h_i_offset];
|
||||
|
||||
for (int64_t j = 0; j < head_size; j++) {
|
||||
size_t t_h_j_offset = t_h_offset + j;
|
||||
size_t h_2d_i_j_offset = h_2d_i_offset + j;
|
||||
|
||||
float v_val = v[t_h_j_offset];
|
||||
float kv_val = v_val * k_val;
|
||||
float prev_state_val = state_prev[h_2d_i_j_offset];
|
||||
float temp_val = kv_val * time_faaaa_val + prev_state_val;
|
||||
dst_data[t_h_j_offset] += temp_val * r_val;
|
||||
state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_rwkv_wkv(
|
||||
|
||||
static void ggml_compute_forward_rwkv_wkv6(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
|
@ -11723,7 +11837,7 @@ static void ggml_compute_forward_rwkv_wkv(
|
|||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_rwkv_wkv_f32(params, dst);
|
||||
ggml_compute_forward_rwkv_wkv6_f32(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
|
@ -12475,9 +12589,9 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|||
{
|
||||
ggml_compute_forward_add_rel_pos(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_RWKV_WKV:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
{
|
||||
ggml_compute_forward_rwkv_wkv(params, tensor);
|
||||
ggml_compute_forward_rwkv_wkv6(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_MAP_UNARY:
|
||||
{
|
||||
|
@ -12775,7 +12889,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|||
case GGML_OP_WIN_PART:
|
||||
case GGML_OP_WIN_UNPART:
|
||||
case GGML_OP_GET_REL_POS:
|
||||
case GGML_OP_RWKV_WKV:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_MAP_UNARY:
|
||||
case GGML_OP_MAP_BINARY:
|
||||
case GGML_OP_MAP_CUSTOM1_F32:
|
||||
|
|
|
@ -36,7 +36,7 @@
|
|||
#include "ggml-cuda/tsembd.cuh"
|
||||
#include "ggml-cuda/unary.cuh"
|
||||
#include "ggml-cuda/upscale.cuh"
|
||||
#include "ggml-cuda/rwkv-wkv.cuh"
|
||||
#include "ggml-cuda/wkv6.cuh"
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
|
@ -2319,8 +2319,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
ggml_cuda_cross_entropy_loss(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_RWKV_WKV:
|
||||
ggml_cuda_op_rwkv_wkv(ctx, dst);
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
ggml_cuda_op_rwkv_wkv6(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
||||
ggml_cuda_cross_entropy_loss_back(ctx, dst);
|
||||
|
@ -3153,12 +3153,15 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
case GGML_OP_ARANGE:
|
||||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_RWKV_WKV:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
return true;
|
||||
case GGML_OP_FLASH_ATTN_EXT: {
|
||||
#ifndef FLASH_ATTN_AVAILABLE
|
||||
return false;
|
||||
#endif
|
||||
if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -44,7 +44,7 @@ void ggml_cuda_count_equal(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||
|
||||
const int64_t ne = ggml_nelements(src0);
|
||||
GGML_ASSERT(ne < (1 << 30) && "atomicAdd implementation only supports int");
|
||||
const int64_t dne = GGML_PAD(ne / (4*nsm), CUDA_COUNT_EQUAL_CHUNK_SIZE);
|
||||
const int64_t dne = GGML_PAD((ne + 4*nsm - 1) / (4*nsm), CUDA_COUNT_EQUAL_CHUNK_SIZE);
|
||||
|
||||
CUDA_CHECK(cudaMemsetAsync(dst_d, 0, ggml_nbytes(dst), stream));
|
||||
|
||||
|
|
|
@ -13,9 +13,9 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
|
|||
const ggml_tensor * KQV = dst;
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
|
||||
const int32_t precision = KQV->op_params[3];
|
||||
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
|
||||
|
||||
if (precision != GGML_PREC_DEFAULT) {
|
||||
if (prec != GGML_PREC_DEFAULT) {
|
||||
if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
|
||||
constexpr int cols_per_block = 16;
|
||||
switch (Q->ne[0]) {
|
||||
|
@ -301,11 +301,11 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|||
|
||||
ggml_cuda_set_device(ctx.device);
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
const int32_t precision = KQV->op_params[3];
|
||||
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
|
||||
|
||||
// On AMD the tile kernels perform poorly, use the vec kernel instead:
|
||||
if (cc >= CC_OFFSET_AMD) {
|
||||
if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
|
||||
if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
||||
|
@ -332,7 +332,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|||
}
|
||||
|
||||
if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
|
||||
if (precision == GGML_PREC_DEFAULT) {
|
||||
if (prec == GGML_PREC_DEFAULT) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||
return;
|
||||
} else if(Q->ne[0] <= 128) {
|
||||
|
|
|
@ -1,5 +0,0 @@
|
|||
#include "common.cuh"
|
||||
|
||||
#define CUDA_WKV_BLOCK_SIZE 64
|
||||
|
||||
void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
|
@ -1,5 +1,5 @@
|
|||
#include "common.cuh"
|
||||
#include "rwkv-wkv.cuh"
|
||||
#include "wkv6.cuh"
|
||||
|
||||
static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
|
||||
const int tid = threadIdx.x;
|
||||
|
@ -64,7 +64,7 @@ static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const
|
|||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const float * k_d = (const float *)dst->src[0]->data;
|
||||
const float * v_d = (const float *)dst->src[1]->data;
|
||||
const float * r_d = (const float *)dst->src[2]->data;
|
||||
|
@ -83,7 +83,7 @@ void ggml_cuda_op_rwkv_wkv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||
|
||||
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(C % H == 0);
|
||||
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE);
|
||||
GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE); // The current cuda kernel is designed for RWKV6, HEAD_SIZE == 64
|
||||
|
||||
rwkv_wkv_f32<<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
|
||||
}
|
5
ggml/src/ggml-cuda/wkv6.cuh
Normal file
5
ggml/src/ggml-cuda/wkv6.cuh
Normal file
|
@ -0,0 +1,5 @@
|
|||
#include "common.cuh"
|
||||
|
||||
#define CUDA_WKV_BLOCK_SIZE 64
|
||||
|
||||
void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
@ -26,5 +26,8 @@
|
|||
#include "softmax.hpp"
|
||||
#include "tsembd.hpp"
|
||||
#include "im2col.hpp"
|
||||
#include "wkv6.hpp"
|
||||
#include "outprod.hpp"
|
||||
#include "element_wise.hpp"
|
||||
|
||||
#endif // GGML_SYCL_BACKEND_HPP
|
||||
|
|
|
@ -62,3 +62,43 @@ int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block
|
|||
}
|
||||
return sycl_down_blk_size;
|
||||
}
|
||||
|
||||
void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const ggml_sycl_op_flatten_t op) try {
|
||||
const int64_t nrows0 = ggml_nrows(src0);
|
||||
|
||||
const bool use_src1 = src1 != nullptr;
|
||||
const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1;
|
||||
|
||||
GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
|
||||
GGML_ASSERT( dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
|
||||
|
||||
ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
|
||||
ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
|
||||
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
|
||||
|
||||
// dd = data device
|
||||
float * src0_ddf = (float *) src0->data;
|
||||
float * src1_ddf = use_src1 ? (float *) src1->data : nullptr;
|
||||
float * dst_ddf = (float *) dst->data;
|
||||
|
||||
ggml_sycl_pool_alloc<float> src0_f(ctx.pool());
|
||||
ggml_sycl_pool_alloc<float> src1_f(ctx.pool());
|
||||
ggml_sycl_pool_alloc<float> dst_f(ctx.pool());
|
||||
|
||||
ggml_sycl_set_device(ctx.device);
|
||||
queue_ptr main_stream = ctx.stream();
|
||||
// GGML_SYCL_DEBUG("ctx.device=%d, main_stream=%p src0_on_device=%d, src1_on_device=%d, dst_on_device=%d\n",
|
||||
// ctx.device, main_stream, src0_on_device, src1_on_device, dst_on_device);
|
||||
|
||||
// do the computation
|
||||
op(ctx, src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
|
||||
// print_ggml_tensor("tensor", dst);
|
||||
}
|
||||
catch (sycl::exception const &exc) {
|
||||
|
||||
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
|
||||
<< ", line:" << __LINE__ << std::endl;
|
||||
std::exit(1);
|
||||
}
|
||||
|
|
|
@ -404,4 +404,262 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor<Tp, dim> acc) {
|
|||
|
||||
int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size);
|
||||
|
||||
typedef void (*ggml_sycl_op_flatten_t)(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1,
|
||||
ggml_tensor *dst, const float *src0_dd,
|
||||
const float *src1_dd, float *dst_dd,
|
||||
const queue_ptr &main_stream);
|
||||
|
||||
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
|
||||
static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
|
||||
int ne0, int ne1, int ne2, int ne3,
|
||||
int ne10, int ne11, int ne12, int ne13,
|
||||
/*int s0, */ int s1, int s2, int s3,
|
||||
/*int s10,*/ int s11, int s12, int s13,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
|
||||
item_ct1.get_local_id(1));
|
||||
const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
|
||||
item_ct1.get_local_id(0)) /
|
||||
ne3;
|
||||
const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
|
||||
item_ct1.get_local_id(0)) %
|
||||
ne3;
|
||||
|
||||
if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int i11 = i1 % ne11;
|
||||
const int i12 = i2 % ne12;
|
||||
const int i13 = i3 % ne13;
|
||||
|
||||
const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
|
||||
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
|
||||
const size_t i_dst = i_src0;
|
||||
|
||||
const src0_t * src0_row = src0 + i_src0;
|
||||
const src1_t * src1_row = src1 + i_src1;
|
||||
dst_t * dst_row = dst + i_dst;
|
||||
|
||||
for (int i0 = i0s; i0 < ne0;
|
||||
i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
|
||||
const int i10 = i0 % ne10;
|
||||
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
|
||||
}
|
||||
}
|
||||
|
||||
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
|
||||
static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
|
||||
int ne0, int ne1, int ne2, int ne3,
|
||||
int ne10, int ne11, int ne12, int ne13,
|
||||
/*int s0, */ int s1, int s2, int s3,
|
||||
/*int s10,*/ int s11, int s12, int s13,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
item_ct1.get_local_id(2);
|
||||
|
||||
const int i3 = i/(ne2*ne1*ne0);
|
||||
const int i2 = (i/(ne1*ne0)) % ne2;
|
||||
const int i1 = (i/ne0) % ne1;
|
||||
const int i0 = i % ne0;
|
||||
|
||||
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int i11 = i1 % ne11;
|
||||
const int i12 = i2 % ne12;
|
||||
const int i13 = i3 % ne13;
|
||||
|
||||
const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
|
||||
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
|
||||
const size_t i_dst = i_src0;
|
||||
|
||||
const src0_t * src0_row = src0 + i_src0;
|
||||
const src1_t * src1_row = src1 + i_src1;
|
||||
dst_t * dst_row = dst + i_dst;
|
||||
|
||||
const int i10 = i0 % ne10;
|
||||
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
|
||||
}
|
||||
|
||||
|
||||
template<float (*bin_op)(const float, const float)>
|
||||
struct bin_bcast_sycl {
|
||||
template <typename src0_t, typename src1_t, typename dst_t>
|
||||
void operator()(ggml_backend_sycl_context & ctx,
|
||||
const struct ggml_tensor *src0,
|
||||
const struct ggml_tensor *src1, struct ggml_tensor *dst,
|
||||
const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd,
|
||||
queue_ptr stream) {
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
int nr0 = ne10/ne0;
|
||||
int nr1 = ne11/ne1;
|
||||
int nr2 = ne12/ne2;
|
||||
int nr3 = ne13/ne3;
|
||||
|
||||
int nr[4] = { nr0, nr1, nr2, nr3 };
|
||||
|
||||
// collapse dimensions until first broadcast dimension
|
||||
int64_t cne0[] = {ne0, ne1, ne2, ne3};
|
||||
int64_t cne1[] = {ne10, ne11, ne12, ne13};
|
||||
size_t cnb0[] = {nb0, nb1, nb2, nb3};
|
||||
size_t cnb1[] = {nb10, nb11, nb12, nb13};
|
||||
auto collapse = [](int64_t cne[]) {
|
||||
cne[0] *= cne[1];
|
||||
cne[1] = cne[2];
|
||||
cne[2] = cne[3];
|
||||
cne[3] = 1;
|
||||
};
|
||||
|
||||
auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
|
||||
cnb[1] *= cne[1];
|
||||
cnb[2] *= cne[2];
|
||||
cnb[3] *= cne[3];
|
||||
};
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
if (nr[i] != 1) {
|
||||
break;
|
||||
}
|
||||
if (i > 0) {
|
||||
collapse_nb(cnb0, cne0);
|
||||
collapse_nb(cnb1, cne1);
|
||||
collapse(cne0);
|
||||
collapse(cne1);
|
||||
}
|
||||
}
|
||||
{
|
||||
int64_t ne0 = cne0[0];
|
||||
int64_t ne1 = cne0[1];
|
||||
int64_t ne2 = cne0[2];
|
||||
int64_t ne3 = cne0[3];
|
||||
|
||||
int64_t ne10 = cne1[0];
|
||||
int64_t ne11 = cne1[1];
|
||||
int64_t ne12 = cne1[2];
|
||||
int64_t ne13 = cne1[3];
|
||||
|
||||
size_t nb0 = cnb0[0];
|
||||
size_t nb1 = cnb0[1];
|
||||
size_t nb2 = cnb0[2];
|
||||
size_t nb3 = cnb0[3];
|
||||
|
||||
size_t nb10 = cnb1[0];
|
||||
size_t nb11 = cnb1[1];
|
||||
size_t nb12 = cnb1[2];
|
||||
size_t nb13 = cnb1[3];
|
||||
|
||||
size_t s0 = nb0 / sizeof(dst_t);
|
||||
size_t s1 = nb1 / sizeof(dst_t);
|
||||
size_t s2 = nb2 / sizeof(dst_t);
|
||||
size_t s3 = nb3 / sizeof(dst_t);
|
||||
|
||||
size_t s10 = nb10 / sizeof(src1_t);
|
||||
size_t s11 = nb11 / sizeof(src1_t);
|
||||
size_t s12 = nb12 / sizeof(src1_t);
|
||||
size_t s13 = nb13 / sizeof(src1_t);
|
||||
|
||||
GGML_ASSERT(s0 == 1);
|
||||
GGML_ASSERT(s10 == 1);
|
||||
|
||||
const int block_size = 128;
|
||||
|
||||
int64_t hne0 = std::max(ne0/2LL, 1LL);
|
||||
|
||||
sycl::range<3> block_dims(1, 1, 1);
|
||||
block_dims[2] = std::min<unsigned int>(hne0, block_size);
|
||||
block_dims[1] = std::min<unsigned int>(
|
||||
ne1, block_size / (unsigned int)block_dims[2]);
|
||||
block_dims[0] = std::min(
|
||||
std::min<unsigned int>(
|
||||
ne2 * ne3, block_size / (unsigned int)block_dims[2] /
|
||||
(unsigned int)block_dims[1]),
|
||||
64U);
|
||||
|
||||
sycl::range<3> block_nums(
|
||||
(ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
|
||||
(ne1 + block_dims[1] - 1) / block_dims[1],
|
||||
(hne0 + block_dims[2] - 1) / block_dims[2]);
|
||||
|
||||
if (block_nums[0] > 65535) {
|
||||
// this is the maximum number of blocks in z direction, fallback to 1D grid kernel
|
||||
int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
|
||||
{
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
|
||||
sycl::range<3>(1, 1, block_size),
|
||||
sycl::range<3>(1, 1, block_size)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
k_bin_bcast_unravel<bin_op>(
|
||||
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
|
||||
ne10, ne11, ne12, ne13, s1, s2, s3, s11, s12,
|
||||
s13, item_ct1);
|
||||
});
|
||||
}
|
||||
} else {
|
||||
/*
|
||||
DPCT1049:16: The work-group size passed to the SYCL kernel may
|
||||
exceed the limit. To get the device limit, query
|
||||
info::device::max_work_group_size. Adjust the work-group size if
|
||||
needed.
|
||||
*/
|
||||
dpct::has_capability_or_fail(stream->get_device(),
|
||||
{sycl::aspect::fp16});
|
||||
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
|
||||
ne2, ne3, ne10, ne11, ne12, ne13,
|
||||
s1, s2, s3, s11, s12, s13,
|
||||
item_ct1);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <class op>
|
||||
inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_dd, const float *src1_dd,
|
||||
float *dst_dd,
|
||||
const queue_ptr &main_stream) {
|
||||
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
op()(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
|
||||
op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd,
|
||||
(sycl::half *)dst_dd, main_stream);
|
||||
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
|
||||
op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, dst_dd,
|
||||
main_stream);
|
||||
} else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
|
||||
op()(ctx, src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd,
|
||||
main_stream);
|
||||
} else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
|
||||
op()(ctx, src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd,
|
||||
main_stream);
|
||||
} else {
|
||||
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
|
||||
ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const ggml_sycl_op_flatten_t op);
|
||||
|
||||
#endif // GGML_SYCL_COMMON_HPP
|
||||
|
|
|
@ -106,6 +106,7 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
|
|||
concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1);
|
||||
});
|
||||
break;
|
||||
// dim >=2 will be dispatched to the default path
|
||||
default:
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(gridDim *
|
||||
|
|
1011
ggml/src/ggml-sycl/element_wise.cpp
Normal file
1011
ggml/src/ggml-sycl/element_wise.cpp
Normal file
File diff suppressed because it is too large
Load diff
76
ggml/src/ggml-sycl/element_wise.hpp
Normal file
76
ggml/src/ggml-sycl/element_wise.hpp
Normal file
|
@ -0,0 +1,76 @@
|
|||
#ifndef GGML_SYCL_ELEMENTWISE_HPP
|
||||
#define GGML_SYCL_ELEMENTWISE_HPP
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
static __dpct_inline__ float op_repeat(const float a, const float b) {
|
||||
return b;
|
||||
GGML_UNUSED(a);
|
||||
}
|
||||
|
||||
static __dpct_inline__ float op_add(const float a, const float b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
static __dpct_inline__ float op_sub(const float a, const float b) {
|
||||
return a - b;
|
||||
}
|
||||
|
||||
static __dpct_inline__ float op_mul(const float a, const float b) {
|
||||
return a * b;
|
||||
}
|
||||
|
||||
static __dpct_inline__ float op_div(const float a, const float b) {
|
||||
return a / b;
|
||||
}
|
||||
|
||||
|
||||
void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_sin(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_cos(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_acc(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_silu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_exp(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_log(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_neg(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_step(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_add(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_sub(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_mul(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
||||
|
||||
void ggml_sycl_div(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
|
||||
|
||||
#endif // GGML_SYCL_ELEMENTWISE_HPP
|
55
ggml/src/ggml-sycl/outprod.cpp
Normal file
55
ggml/src/ggml-sycl/outprod.cpp
Normal file
|
@ -0,0 +1,55 @@
|
|||
#include <sycl/sycl.hpp>
|
||||
#include "outprod.hpp"
|
||||
|
||||
|
||||
void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
|
||||
const ggml_tensor* src1, ggml_tensor* dst) {
|
||||
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
// Get SYCL queue
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
|
||||
// Dimension checks
|
||||
GGML_ASSERT(ne01 == ne11); // Inner dimensions must match
|
||||
GGML_ASSERT(ne0 == ne00); // Output rows match src0 rows
|
||||
GGML_ASSERT(ne1 == ne10); // Output cols match src1 cols
|
||||
|
||||
// Get data pointers
|
||||
const float* src0_d = (const float*)src0->data;
|
||||
const float* src1_d = (const float*)src1->data;
|
||||
float* dst_d = (float*)dst->data;
|
||||
|
||||
// GEMM parameters
|
||||
const float alpha = 1.0f;
|
||||
const float beta = 0.0f;
|
||||
|
||||
// Handle transposition of src1
|
||||
const bool src1_T = ggml_is_transposed(src1);
|
||||
const oneapi::mkl::transpose src1_op =
|
||||
src1_T ? oneapi::mkl::transpose::nontrans : oneapi::mkl::transpose::trans;
|
||||
const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
|
||||
|
||||
try {
|
||||
// Perform matrix multiplication using oneMKL GEMM
|
||||
oneapi::mkl::blas::gemm(*stream,
|
||||
oneapi::mkl::transpose::nontrans, src1_op,
|
||||
ne0, ne1, ne01,
|
||||
alpha,
|
||||
src0_d, ne00,
|
||||
src1_d, ldb,
|
||||
beta,
|
||||
dst_d, ne0);
|
||||
}
|
||||
catch (sycl::exception const& exc) {
|
||||
std::cerr << exc.what() << std::endl;
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
}
|
11
ggml/src/ggml-sycl/outprod.hpp
Normal file
11
ggml/src/ggml-sycl/outprod.hpp
Normal file
|
@ -0,0 +1,11 @@
|
|||
#ifndef GGML_SYCL_OUTPROD_HPP
|
||||
#define GGML_SYCL_OUTPROD_HPP
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
|
||||
const ggml_tensor* src1, ggml_tensor* dst);
|
||||
|
||||
|
||||
#endif // GGML_SYCL_OUTPROD_HPP
|
||||
|
|
@ -25,6 +25,11 @@
|
|||
#define SYCL_RELU_BLOCK_SIZE 256
|
||||
#define SYCL_HARDSIGMOID_BLOCK_SIZE 256
|
||||
#define SYCL_HARDSWISH_BLOCK_SIZE 256
|
||||
#define SYCL_EXP_BLOCK_SIZE 256
|
||||
#define SYCL_NEG_BLOCK_SIZE 256
|
||||
#define SYCL_SIGMOID_BLOCK_SIZE 256
|
||||
#define SYCL_SQRT_BLOCK_SIZE 256
|
||||
#define SYCL_SIN_BLOCK_SIZE 256
|
||||
#define SYCL_SQR_BLOCK_SIZE 256
|
||||
#define SYCL_CPY_BLOCK_SIZE 32
|
||||
#define SYCL_SCALE_BLOCK_SIZE 256
|
||||
|
@ -41,6 +46,7 @@
|
|||
#define SYCL_ACC_BLOCK_SIZE 256
|
||||
#define SYCL_IM2COL_BLOCK_SIZE 256
|
||||
#define SYCL_POOL2D_BLOCK_SIZE 256
|
||||
#define SYCL_ARGMAX_BLOCK_SIZE 256
|
||||
#define SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE 256
|
||||
#define SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE 256
|
||||
|
||||
|
|
138
ggml/src/ggml-sycl/wkv6.cpp
Normal file
138
ggml/src/ggml-sycl/wkv6.cpp
Normal file
|
@ -0,0 +1,138 @@
|
|||
#include <sycl/sycl.hpp>
|
||||
#include "wkv6.hpp"
|
||||
|
||||
constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE
|
||||
|
||||
// Helper function for the main kernel
|
||||
static void rwkv_wkv_f32_kernel(
|
||||
const int B, const int T, const int C, const int H,
|
||||
const float* k, const float* v, const float* r,
|
||||
const float* tf, const float* td, const float* s,
|
||||
float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
|
||||
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
const int bid = item_ct1.get_group(2);
|
||||
|
||||
const int head_size = WKV_BLOCK_SIZE;
|
||||
const int batch_i = bid / H;
|
||||
const int head_i = bid % H;
|
||||
const int state_size = C * head_size;
|
||||
const int n_seq_tokens = T / B;
|
||||
|
||||
// Set up shared memory pointers
|
||||
float* _k = shared_mem;
|
||||
float* _r = _k + head_size;
|
||||
float* _tf = _r + head_size;
|
||||
float* _td = _tf + head_size;
|
||||
|
||||
// Local state array
|
||||
float state[WKV_BLOCK_SIZE];
|
||||
|
||||
// Load initial state
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
|
||||
}
|
||||
|
||||
// Sync threads before shared memory operations
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
// Load time-mixing parameters
|
||||
_tf[tid] = tf[head_i * head_size + tid];
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
// Main sequence processing loop
|
||||
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
|
||||
t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
|
||||
t += C) {
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
// Load current timestep data to shared memory
|
||||
_k[tid] = k[t];
|
||||
_r[tid] = r[t];
|
||||
_td[tid] = td[t];
|
||||
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
const float _v = v[t];
|
||||
float y = 0;
|
||||
|
||||
// Process in chunks of 4 for better vectorization
|
||||
sycl::float4 k4, r4, tf4, td4, s4, kv4;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < head_size; j += 4) {
|
||||
// Load data in vec4 chunks
|
||||
k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
||||
r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
||||
tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
|
||||
td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
|
||||
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
|
||||
|
||||
// Compute key-value product
|
||||
sycl::float4 kv4 = k4 * _v;
|
||||
|
||||
// Accumulate weighted sum
|
||||
y += sycl::dot(r4, tf4 * kv4 + s4);
|
||||
|
||||
// Update state
|
||||
s4 = s4 * td4 + kv4;
|
||||
|
||||
// Store updated state
|
||||
state[j] = s4.x();
|
||||
state[j+1] = s4.y();
|
||||
state[j+2] = s4.z();
|
||||
state[j+3] = s4.w();
|
||||
}
|
||||
|
||||
dst[t] = y;
|
||||
}
|
||||
|
||||
// Save final state
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
|
||||
const ggml_tensor* src1, ggml_tensor* dst) {
|
||||
|
||||
const float* k_d = (const float*)dst->src[0]->data;
|
||||
const float* v_d = (const float*)dst->src[1]->data;
|
||||
const float* r_d = (const float*)dst->src[2]->data;
|
||||
const float* tf_d = (const float*)dst->src[3]->data;
|
||||
const float* td_d = (const float*)dst->src[4]->data;
|
||||
const float* s_d = (const float*)dst->src[5]->data;
|
||||
float* dst_d = (float*)dst->data;
|
||||
|
||||
const int64_t B = dst->src[5]->ne[1];
|
||||
const int64_t T = dst->src[0]->ne[3];
|
||||
const int64_t C = dst->ne[0];
|
||||
const int64_t H = dst->src[0]->ne[2];
|
||||
|
||||
GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(C % H == 0);
|
||||
GGML_ASSERT(C / H == WKV_BLOCK_SIZE); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64
|
||||
|
||||
dpct::queue_ptr stream = ctx.stream();
|
||||
|
||||
// Calculate execution configuration
|
||||
const size_t shared_mem_size = WKV_BLOCK_SIZE * 4 * sizeof(float); // For k, r, tf, td
|
||||
sycl::range<3> block_dims(1, 1, C / H);
|
||||
sycl::range<3> grid_dims(1, 1, B * H);
|
||||
|
||||
// Submit kernel
|
||||
stream->submit([&](sycl::handler& cgh) {
|
||||
sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(grid_dims * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
rwkv_wkv_f32_kernel(
|
||||
B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
|
||||
item_ct1, shared_mem_acc.get_pointer()
|
||||
);
|
||||
});
|
||||
});
|
||||
}
|
10
ggml/src/ggml-sycl/wkv6.hpp
Normal file
10
ggml/src/ggml-sycl/wkv6.hpp
Normal file
|
@ -0,0 +1,10 @@
|
|||
#ifndef GGML_SYCL_WKV6_HPP
|
||||
#define GGML_SYCL_WKV6_HPP
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor * dst);
|
||||
|
||||
|
||||
#endif // GGML_SYCL_WKV6_HPP
|
|
@ -3147,7 +3147,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|||
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
|
||||
const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
|
||||
|
||||
if (mmp == nullptr) {
|
||||
if (qx_needs_dequant) {
|
||||
// Fall back to dequant + f16 mulmat
|
||||
mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16);
|
||||
}
|
||||
|
@ -3630,9 +3630,19 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
|
|||
|
||||
static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
||||
VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")");
|
||||
if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1) {
|
||||
if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 &&
|
||||
// detect 0213 permutation, and batch size of 1
|
||||
src0->nb[0] <= src0->nb[2] &&
|
||||
src0->nb[2] <= src0->nb[1] &&
|
||||
src0->nb[1] <= src0->nb[3] &&
|
||||
src1->nb[0] <= src1->nb[2] &&
|
||||
src1->nb[2] <= src1->nb[1] &&
|
||||
src1->nb[1] <= src1->nb[3] &&
|
||||
src0->ne[3] == 1 &&
|
||||
src1->ne[3] == 1) {
|
||||
ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, src0, src1, dst, dryrun);
|
||||
} else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1) {
|
||||
} else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
|
||||
!ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
|
||||
ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, src0, src1, dst, dryrun);
|
||||
} else if (dst->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
|
||||
ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
|
||||
|
@ -3708,7 +3718,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|||
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
|
||||
const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
|
||||
|
||||
if (mmp == nullptr) {
|
||||
if (qx_needs_dequant) {
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
|
@ -4470,7 +4480,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|||
const uint32_t OH = is_2D ? dst->ne[2] : 1;
|
||||
const uint32_t OW = dst->ne[1];
|
||||
|
||||
const uint32_t batch = src1->ne[3];
|
||||
const uint32_t batch = src1->ne[is_2D ? 3 : 2];
|
||||
|
||||
elements = { OW * KW * KH, OH, batch * IC };
|
||||
} break;
|
||||
|
@ -4915,7 +4925,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|||
const uint32_t OW = dst->ne[1];
|
||||
|
||||
const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
|
||||
const uint32_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
|
||||
const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
|
||||
|
||||
const uint32_t pelements = OW * KW * KH;
|
||||
|
||||
|
@ -6804,6 +6814,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|||
if (a->ne[3] != b->ne[3]) {
|
||||
return false;
|
||||
}
|
||||
if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) ||
|
||||
!(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
} break;
|
||||
case GGML_OP_GET_ROWS:
|
||||
|
|
|
@ -975,7 +975,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|||
"WIN_UNPART",
|
||||
"GET_REL_POS",
|
||||
"ADD_REL_POS",
|
||||
"RWKV_WKV",
|
||||
"RWKV_WKV6",
|
||||
|
||||
"UNARY",
|
||||
|
||||
|
@ -1070,7 +1070,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|||
"win_unpart(x)",
|
||||
"get_rel_pos(x)",
|
||||
"add_rel_pos(x)",
|
||||
"rwkv_wkv(k, v, r, tf, td, s)",
|
||||
"rwkv_wkv6(k, v, r, tf, td, s)",
|
||||
|
||||
"unary(x)",
|
||||
|
||||
|
@ -1407,11 +1407,11 @@ static inline bool ggml_can_repeat_rows(const struct ggml_tensor * t0, const str
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct ggml_context * ggml_init(struct ggml_init_params params) {
|
||||
static bool is_first_call = false;
|
||||
static bool is_first_call = true;
|
||||
|
||||
ggml_critical_section_start();
|
||||
|
||||
if (!is_first_call) {
|
||||
if (is_first_call) {
|
||||
// initialize time system (required on Windows)
|
||||
ggml_time_init();
|
||||
|
||||
|
@ -1422,7 +1422,8 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
|
|||
} u = {i};
|
||||
ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16);
|
||||
}
|
||||
is_first_call = true;
|
||||
|
||||
is_first_call = false;
|
||||
}
|
||||
|
||||
ggml_critical_section_end();
|
||||
|
@ -4227,6 +4228,15 @@ void ggml_flash_attn_ext_set_prec(
|
|||
ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second
|
||||
}
|
||||
|
||||
enum ggml_prec ggml_flash_attn_ext_get_prec(
|
||||
const struct ggml_tensor * a) {
|
||||
GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
|
||||
|
||||
const int32_t prec_i32 = ggml_get_op_params_i32(a, 3);
|
||||
|
||||
return (enum ggml_prec) prec_i32;
|
||||
}
|
||||
|
||||
// ggml_flash_attn_back
|
||||
|
||||
struct ggml_tensor * ggml_flash_attn_back(
|
||||
|
@ -4502,9 +4512,9 @@ struct ggml_tensor * ggml_add_rel_pos_inplace(
|
|||
return ggml_add_rel_pos_impl(ctx, a, pw, ph, true);
|
||||
}
|
||||
|
||||
// ggml_rwkv_wkv
|
||||
// ggml_rwkv_wkv6
|
||||
|
||||
struct ggml_tensor * ggml_rwkv_wkv(
|
||||
struct ggml_tensor * ggml_rwkv_wkv6(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * k,
|
||||
struct ggml_tensor * v,
|
||||
|
@ -4536,7 +4546,7 @@ struct ggml_tensor * ggml_rwkv_wkv(
|
|||
const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
|
||||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
||||
|
||||
result->op = GGML_OP_RWKV_WKV;
|
||||
result->op = GGML_OP_RWKV_WKV6;
|
||||
result->src[0] = k;
|
||||
result->src[1] = v;
|
||||
result->src[2] = r;
|
||||
|
@ -6083,7 +6093,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
} break;
|
||||
case GGML_OP_GET_REL_POS:
|
||||
case GGML_OP_ADD_REL_POS:
|
||||
case GGML_OP_RWKV_WKV:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_MAP_UNARY:
|
||||
case GGML_OP_MAP_BINARY:
|
||||
case GGML_OP_MAP_CUSTOM1_F32:
|
||||
|
|
|
@ -106,6 +106,10 @@ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
|
|||
inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
|
||||
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
|
||||
#if defined(__MMA__)
|
||||
typedef vector unsigned char vec_t;
|
||||
typedef __vector_quad acc_t;
|
||||
#endif
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// VECTORIZED FUSED MULTIPLY ADD
|
||||
|
||||
|
@ -1026,6 +1030,600 @@ class tinyBLAS_Q0_AVX {
|
|||
};
|
||||
#endif // __AVX__
|
||||
|
||||
//PPC Implementation
|
||||
#if defined(__MMA__)
|
||||
|
||||
#define SAVE_ACC(ACC, ii, jj) \
|
||||
__builtin_mma_disassemble_acc(vec_C, ACC); \
|
||||
for (int I = 0; I < 4; I++) { \
|
||||
for (int J = 0; J < 4; J++) { \
|
||||
*((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J); \
|
||||
} \
|
||||
} \
|
||||
|
||||
template <typename TA, typename TB, typename TC>
|
||||
class tinyBLAS_PPC {
|
||||
public:
|
||||
tinyBLAS_PPC(int64_t k,
|
||||
const TA *A, int64_t lda,
|
||||
const TB *B, int64_t ldb,
|
||||
TC *C, int64_t ldc,
|
||||
int ith, int nth)
|
||||
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
|
||||
}
|
||||
|
||||
void matmul(int64_t m, int64_t n) {
|
||||
mnpack(0, m, 0, n);
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
|
||||
|
||||
void READ_BLOCK(const float* a, int64_t lda, int rows, int cols, float* vec) {
|
||||
int64_t i, j;
|
||||
float *aoffset = NULL, *boffset = NULL;
|
||||
float *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
|
||||
float *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
|
||||
|
||||
aoffset = const_cast<float*>(a);
|
||||
boffset = vec;
|
||||
j = (rows >> 3);
|
||||
if (j > 0) {
|
||||
do {
|
||||
aoffset1 = aoffset;
|
||||
aoffset2 = aoffset1 + lda;
|
||||
aoffset3 = aoffset2 + lda;
|
||||
aoffset4 = aoffset3 + lda;
|
||||
aoffset5 = aoffset4 + lda;
|
||||
aoffset6 = aoffset5 + lda;
|
||||
aoffset7 = aoffset6 + lda;
|
||||
aoffset8 = aoffset7 + lda;
|
||||
aoffset += 8 * lda;
|
||||
i = (cols >> 3);
|
||||
if (i > 0) {
|
||||
__vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
|
||||
vector float c1[2], c2[2], c3[2], c4[2], c5[2], c6[2], c7[2], c8[2];
|
||||
vector float t1, t2, t3, t4, t5, t6, t7, t8;
|
||||
do {
|
||||
C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
|
||||
C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
|
||||
C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
|
||||
C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
|
||||
C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5);
|
||||
C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6);
|
||||
C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7);
|
||||
C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8);
|
||||
__builtin_vsx_disassemble_pair(c1, &C1);
|
||||
__builtin_vsx_disassemble_pair(c2, &C2);
|
||||
__builtin_vsx_disassemble_pair(c3, &C3);
|
||||
__builtin_vsx_disassemble_pair(c4, &C4);
|
||||
__builtin_vsx_disassemble_pair(c5, &C5);
|
||||
__builtin_vsx_disassemble_pair(c6, &C6);
|
||||
__builtin_vsx_disassemble_pair(c7, &C7);
|
||||
__builtin_vsx_disassemble_pair(c8, &C8);
|
||||
|
||||
t1 = vec_mergeh(c1[0], c2[0]);
|
||||
t2 = vec_mergeh(c3[0], c4[0]);
|
||||
t3 = vec_mergeh(c5[0], c6[0]);
|
||||
t4 = vec_mergeh(c7[0], c8[0]);
|
||||
t5 = vec_xxpermdi(t1, t2, 0);
|
||||
t6 = vec_xxpermdi(t3, t4, 0);
|
||||
t7 = vec_xxpermdi(t1, t2, 3);
|
||||
t8 = vec_xxpermdi(t3, t4, 3);
|
||||
vec_xst(t5, 0, boffset);
|
||||
vec_xst(t6, 0, boffset+4);
|
||||
vec_xst(t7, 0, boffset+8);
|
||||
vec_xst(t8, 0, boffset+12);
|
||||
|
||||
t1 = vec_mergel(c1[0], c2[0]);
|
||||
t2 = vec_mergel(c3[0], c4[0]);
|
||||
t3 = vec_mergel(c5[0], c6[0]);
|
||||
t4 = vec_mergel(c7[0], c8[0]);
|
||||
t5 = vec_xxpermdi(t1, t2, 0);
|
||||
t6 = vec_xxpermdi(t3, t4, 0);
|
||||
t7 = vec_xxpermdi(t1, t2, 3);
|
||||
t8 = vec_xxpermdi(t3, t4, 3);
|
||||
vec_xst(t5, 0, boffset+16);
|
||||
vec_xst(t6, 0, boffset+20);
|
||||
vec_xst(t7, 0, boffset+24);
|
||||
vec_xst(t8, 0, boffset+28);
|
||||
|
||||
t1 = vec_mergeh(c1[1], c2[1]);
|
||||
t2 = vec_mergeh(c3[1], c4[1]);
|
||||
t3 = vec_mergeh(c5[1], c6[1]);
|
||||
t4 = vec_mergeh(c7[1], c8[1]);
|
||||
t5 = vec_xxpermdi(t1, t2, 0);
|
||||
t6 = vec_xxpermdi(t3, t4, 0);
|
||||
t7 = vec_xxpermdi(t1, t2, 3);
|
||||
t8 = vec_xxpermdi(t3, t4, 3);
|
||||
vec_xst(t5, 0, boffset+32);
|
||||
vec_xst(t6, 0, boffset+36);
|
||||
vec_xst(t7, 0, boffset+40);
|
||||
vec_xst(t8, 0, boffset+44);
|
||||
|
||||
t1 = vec_mergel(c1[1], c2[1]);
|
||||
t2 = vec_mergel(c3[1], c4[1]);
|
||||
t3 = vec_mergel(c5[1], c6[1]);
|
||||
t4 = vec_mergel(c7[1], c8[1]);
|
||||
t5 = vec_xxpermdi(t1, t2, 0);
|
||||
t6 = vec_xxpermdi(t3, t4, 0);
|
||||
t7 = vec_xxpermdi(t1, t2, 3);
|
||||
t8 = vec_xxpermdi(t3, t4, 3);
|
||||
vec_xst(t5, 0, boffset+48);
|
||||
vec_xst(t6, 0, boffset+52);
|
||||
vec_xst(t7, 0, boffset+56);
|
||||
vec_xst(t8, 0, boffset+60);
|
||||
|
||||
aoffset1 += 8*lda;
|
||||
aoffset2 += 8*lda;
|
||||
aoffset3 += 8*lda;
|
||||
aoffset4 += 8*lda;
|
||||
boffset += 64;
|
||||
i--;
|
||||
} while(i > 0);
|
||||
}
|
||||
if (cols & 4) {
|
||||
vector float c1, c2, c3, c4, c5, c6, c7, c8;
|
||||
vector float t1, t2, t3, t4, t5, t6, t7, t8;
|
||||
c1 = vec_xl(0, aoffset1);
|
||||
c2 = vec_xl(0, aoffset2);
|
||||
c3 = vec_xl(0, aoffset3);
|
||||
c4 = vec_xl(0, aoffset4);
|
||||
c5 = vec_xl(0, aoffset5);
|
||||
c6 = vec_xl(0, aoffset6);
|
||||
c7 = vec_xl(0, aoffset7);
|
||||
c8 = vec_xl(0, aoffset8);
|
||||
|
||||
t1 = vec_mergeh(c1, c2);
|
||||
t2 = vec_mergeh(c3, c4);
|
||||
t3 = vec_mergeh(c5, c6);
|
||||
t4 = vec_mergeh(c7, c8);
|
||||
t5 = vec_xxpermdi(t1, t2, 0);
|
||||
t6 = vec_xxpermdi(t3, t4, 0);
|
||||
t7 = vec_xxpermdi(t1, t2, 3);
|
||||
t8 = vec_xxpermdi(t3, t4, 3);
|
||||
vec_xst(t5, 0, boffset);
|
||||
vec_xst(t6, 0, boffset+4);
|
||||
vec_xst(t7, 0, boffset+8);
|
||||
vec_xst(t8, 0, boffset+12);
|
||||
|
||||
t1 = vec_mergel(c1, c2);
|
||||
t2 = vec_mergel(c3, c4);
|
||||
t3 = vec_mergel(c5, c6);
|
||||
t4 = vec_mergel(c7, c8);
|
||||
t5 = vec_xxpermdi(t1, t2, 0);
|
||||
t6 = vec_xxpermdi(t3, t4, 0);
|
||||
t7 = vec_xxpermdi(t1, t2, 3);
|
||||
t8 = vec_xxpermdi(t3, t4, 3);
|
||||
vec_xst(t5, 0, boffset+16);
|
||||
vec_xst(t6, 0, boffset+20);
|
||||
vec_xst(t7, 0, boffset+24);
|
||||
vec_xst(t8, 0, boffset+28);
|
||||
}
|
||||
j--;
|
||||
} while(j > 0);
|
||||
}
|
||||
|
||||
if (rows & 4) {
|
||||
aoffset1 = aoffset;
|
||||
aoffset2 = aoffset1 + lda;
|
||||
aoffset3 = aoffset2 + lda;
|
||||
aoffset4 = aoffset3 + lda;
|
||||
aoffset += 4 * lda;
|
||||
i = (cols >> 3);
|
||||
if (i > 0) {
|
||||
__vector_pair C1, C2, C3, C4;
|
||||
vector float c1[2], c2[2], c3[2], c4[2];
|
||||
vector float t1, t2, t3, t4, t5, t6, t7, t8;
|
||||
do {
|
||||
C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
|
||||
C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
|
||||
C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
|
||||
C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
|
||||
__builtin_vsx_disassemble_pair(c1, &C1);
|
||||
__builtin_vsx_disassemble_pair(c2, &C2);
|
||||
__builtin_vsx_disassemble_pair(c3, &C3);
|
||||
__builtin_vsx_disassemble_pair(c4, &C4);
|
||||
|
||||
t1 = vec_mergeh(c1[0], c2[0]);
|
||||
t2 = vec_mergeh(c3[0], c4[0]);
|
||||
t3 = vec_mergel(c1[0], c2[0]);
|
||||
t4 = vec_mergel(c3[0], c4[0]);
|
||||
t5 = vec_xxpermdi(t1, t2, 0);
|
||||
t6 = vec_xxpermdi(t1, t2, 3);
|
||||
t7 = vec_xxpermdi(t3, t4, 0);
|
||||
t8 = vec_xxpermdi(t3, t4, 3);
|
||||
vec_xst(t5, 0, boffset);
|
||||
vec_xst(t6, 0, boffset+4);
|
||||
vec_xst(t7, 0, boffset+8);
|
||||
vec_xst(t8, 0, boffset+12);
|
||||
|
||||
t1 = vec_mergeh(c1[1], c2[1]);
|
||||
t2 = vec_mergeh(c3[1], c4[1]);
|
||||
t3 = vec_mergel(c1[1], c2[1]);
|
||||
t4 = vec_mergel(c3[1], c4[1]);
|
||||
t5 = vec_xxpermdi(t1, t2, 0);
|
||||
t6 = vec_xxpermdi(t1, t2, 3);
|
||||
t7 = vec_xxpermdi(t3, t4, 0);
|
||||
t8 = vec_xxpermdi(t3, t4, 3);
|
||||
vec_xst(t5, 0, boffset+16);
|
||||
vec_xst(t6, 0, boffset+20);
|
||||
vec_xst(t7, 0, boffset+24);
|
||||
vec_xst(t8, 0, boffset+28);
|
||||
|
||||
aoffset1 += 8*lda;
|
||||
aoffset2 += 8*lda;
|
||||
aoffset3 += 8*lda;
|
||||
aoffset4 += 8*lda;
|
||||
boffset += 32;
|
||||
i--;
|
||||
} while(i > 0);
|
||||
}
|
||||
|
||||
if (cols & 4) {
|
||||
vector float c1, c2, c3, c4;
|
||||
vector float t1, t2, t3, t4;
|
||||
c1 = vec_xl(0, aoffset1);
|
||||
c2 = vec_xl(0, aoffset2);
|
||||
c3 = vec_xl(0, aoffset3);
|
||||
c4 = vec_xl(0, aoffset4);
|
||||
|
||||
t1 = vec_mergeh(c1, c2);
|
||||
t2 = vec_mergeh(c3, c4);
|
||||
t3 = vec_xxpermdi(t1, t2, 0);
|
||||
t4 = vec_xxpermdi(t1, t2, 3);
|
||||
vec_xst(t3, 0, boffset);
|
||||
vec_xst(t4, 0, boffset+4);
|
||||
|
||||
t1 = vec_mergel(c1, c2);
|
||||
t2 = vec_mergel(c3, c4);
|
||||
t3 = vec_xxpermdi(t1, t2, 0);
|
||||
t4 = vec_xxpermdi(t1, t2, 3);
|
||||
vec_xst(t3, 0, boffset+8);
|
||||
vec_xst(t4, 0, boffset+12);
|
||||
}
|
||||
}
|
||||
if (rows & 3) {
|
||||
aoffset1 = aoffset;
|
||||
aoffset2 = aoffset1 + lda;
|
||||
aoffset3 = aoffset2 + lda;
|
||||
if (cols & 4) {
|
||||
vector float c1, c2, c3, c4 = {0};
|
||||
vector float t1, t2, t3, t4;
|
||||
c1 = vec_xl(0, aoffset1);
|
||||
c2 = vec_xl(0, aoffset2);
|
||||
c3 = vec_xl(0, aoffset3);
|
||||
|
||||
t1 = vec_mergeh(c1, c2);
|
||||
t2 = vec_mergeh(c3, c4);
|
||||
t3 = vec_xxpermdi(t1, t2, 0);
|
||||
t4 = vec_xxpermdi(t1, t2, 3);
|
||||
vec_xst(t3, 0, boffset);
|
||||
vec_xst(t4, 0, boffset+4);
|
||||
|
||||
t1 = vec_mergel(c1, c2);
|
||||
t2 = vec_mergel(c3, c4);
|
||||
t3 = vec_xxpermdi(t1, t2, 0);
|
||||
t4 = vec_xxpermdi(t1, t2, 3);
|
||||
vec_xst(t3, 0, boffset+8);
|
||||
vec_xst(t4, 0, boffset+12);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void KERNEL_4x4(int64_t ii, int64_t jj) {
|
||||
vec_t vec_A[4], vec_B[4], vec_C[4];
|
||||
acc_t acc_0;
|
||||
__builtin_mma_xxsetaccz(&acc_0);
|
||||
for (int l = 0; l < k; l+=4) {
|
||||
READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
|
||||
READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
|
||||
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
|
||||
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
|
||||
__builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
|
||||
__builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
|
||||
}
|
||||
SAVE_ACC(&acc_0, ii, jj);
|
||||
}
|
||||
|
||||
void KERNEL_4x8(int64_t ii, int64_t jj) {
|
||||
vec_t vec_A[4], vec_B[8], vec_C[4];
|
||||
acc_t acc_0, acc_1;
|
||||
__builtin_mma_xxsetaccz(&acc_0);
|
||||
__builtin_mma_xxsetaccz(&acc_1);
|
||||
for (int64_t l = 0; l < k; l+=4) {
|
||||
READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
|
||||
READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
|
||||
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
|
||||
__builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
|
||||
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
|
||||
__builtin_mma_xvf32gerpp(&acc_1, vec_A[1], (vec_t)vec_B[3]);
|
||||
__builtin_mma_xvf32gerpp(&acc_0, vec_A[2], (vec_t)vec_B[4]);
|
||||
__builtin_mma_xvf32gerpp(&acc_1, vec_A[2], (vec_t)vec_B[5]);
|
||||
__builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]);
|
||||
__builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]);
|
||||
}
|
||||
SAVE_ACC(&acc_0, ii, jj);
|
||||
SAVE_ACC(&acc_1, ii, jj+4);
|
||||
}
|
||||
|
||||
void KERNEL_8x4(int64_t ii, int64_t jj) {
|
||||
vec_t vec_A[8], vec_B[4], vec_C[4];
|
||||
acc_t acc_0, acc_1;
|
||||
__builtin_mma_xxsetaccz(&acc_0);
|
||||
__builtin_mma_xxsetaccz(&acc_1);
|
||||
for (int64_t l = 0; l < k; l+=4) {
|
||||
READ_BLOCK(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
|
||||
READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
|
||||
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
|
||||
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
|
||||
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
|
||||
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[3], vec_B[1]);
|
||||
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[4], vec_B[2]);
|
||||
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[5], vec_B[2]);
|
||||
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]);
|
||||
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]);
|
||||
}
|
||||
SAVE_ACC(&acc_0, ii, jj);
|
||||
SAVE_ACC(&acc_1, ii+4, jj);
|
||||
}
|
||||
|
||||
void KERNEL_8x8(int64_t ii, int64_t jj) {
|
||||
vec_t vec_A[16], vec_B[16], vec_C[4];
|
||||
acc_t acc_0, acc_1, acc_2, acc_3;
|
||||
__builtin_mma_xxsetaccz(&acc_0);
|
||||
__builtin_mma_xxsetaccz(&acc_1);
|
||||
__builtin_mma_xxsetaccz(&acc_2);
|
||||
__builtin_mma_xxsetaccz(&acc_3);
|
||||
for (int l = 0; l < k; l+=8) {
|
||||
READ_BLOCK(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
|
||||
READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
|
||||
for(int x = 0; x < 16; x+=2) {
|
||||
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
|
||||
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
|
||||
__builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x+1], vec_B[x]);
|
||||
__builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x+1], vec_B[x+1]);
|
||||
}
|
||||
}
|
||||
SAVE_ACC(&acc_0, ii, jj);
|
||||
SAVE_ACC(&acc_1, ii, jj+4);
|
||||
SAVE_ACC(&acc_2, ii+4, jj);
|
||||
SAVE_ACC(&acc_3, ii+4, jj+4);
|
||||
}
|
||||
|
||||
void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
||||
int64_t mc, nc, mp, np;
|
||||
int m_rem = MIN(m - m0, 16);
|
||||
int n_rem = MIN(n - n0, 16);
|
||||
if (m_rem >= 16 && n_rem >= 8) {
|
||||
mc = 8;
|
||||
nc = 8;
|
||||
gemm<8,8>(m0, m, n0, n);
|
||||
} else if(m_rem >= 8 && n_rem >= 16) {
|
||||
mc = 8;
|
||||
nc = 8;
|
||||
gemm<8,8>(m0, m, n0, n);
|
||||
} else if (m_rem >= 8 && n_rem >= 8) {
|
||||
mc = 8;
|
||||
nc = 8;
|
||||
gemm<8,8>(m0, m, n0, n);
|
||||
} else if (m_rem >= 4 && n_rem >= 8) {
|
||||
mc = 4;
|
||||
nc = 8;
|
||||
gemm<4,8>(m0, m, n0, n);
|
||||
} else if (m_rem >= 8 && n_rem >= 4) {
|
||||
mc = 8;
|
||||
nc = 4;
|
||||
gemm<8,4>(m0, m, n0, n);
|
||||
} else if (m_rem >= 4 && n_rem >= 4) {
|
||||
mc = 4;
|
||||
nc = 4;
|
||||
gemm<4,4>(m0, m, n0, n);
|
||||
} else if ((m_rem < 4) && (n_rem > 4)) {
|
||||
nc = 4;
|
||||
switch(m_rem) {
|
||||
case 1:
|
||||
mc = 1;
|
||||
gemm_small(m0, m, n0, n, mc, nc);
|
||||
break;
|
||||
case 2:
|
||||
mc = 2;
|
||||
gemm_small(m0, m, n0, n, mc, nc);
|
||||
break;
|
||||
case 3:
|
||||
mc = 3;
|
||||
gemm_small(m0, m, n0, n, mc, nc);
|
||||
break;
|
||||
default:
|
||||
return;
|
||||
}
|
||||
} else if ((m_rem > 4) && (n_rem < 4)) {
|
||||
mc = 4;
|
||||
switch(n_rem) {
|
||||
case 1:
|
||||
nc = 1;
|
||||
gemm_small(m0, m, n0, n, mc, nc);
|
||||
break;
|
||||
case 2:
|
||||
nc = 2;
|
||||
gemm_small(m0, m, n0, n, mc, nc);
|
||||
break;
|
||||
case 3:
|
||||
nc = 3;
|
||||
gemm_small(m0, m, n0, n, mc, nc);
|
||||
break;
|
||||
default:
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
switch((m_rem << 4) | n_rem) {
|
||||
case 0x43:
|
||||
mc = 4;
|
||||
nc = 3;
|
||||
gemm_small(m0, m, n0, n, mc, nc);
|
||||
break;
|
||||
case 0x42:
|
||||
mc = 4;
|
||||
nc = 2;
|
||||
gemm_small(m0, m, n0, n, mc, nc);
|
||||
break;
|
||||
case 0x41:
|
||||
mc = 4;
|
||||
nc = 1;
|
||||
gemm_small(m0, m, n0, n, mc, nc);
|
||||
break;
|
||||
case 0x34:
|
||||
mc = 3;
|
||||
nc = 4;
|
||||
gemm_small(m0, m, n0, n, mc, nc);
|
||||
break;
|
||||
case 0x33:
|
||||
mc = 3;
|
||||
nc = 3;
|
||||
gemm_small(m0, m, n0, n, mc, nc);
|
||||
break;
|
||||
case 0x32:
|
||||
mc = 3;
|
||||
nc = 2;
|
||||
gemm_small(m0, m, n0, n, mc, nc);
|
||||
break;
|
||||
case 0x31:
|
||||
mc = 3;
|
||||
nc = 1;
|
||||
gemm_small(m0, m, n0, n, mc, nc);
|
||||
break;
|
||||
case 0x24:
|
||||
mc = 2;
|
||||
nc = 4;
|
||||
gemm_small(m0, m, n0, n, mc, nc);
|
||||
break;
|
||||
case 0x23:
|
||||
mc = 2;
|
||||
nc = 3;
|
||||
gemm_small(m0, m, n0, n, mc, nc);
|
||||
break;
|
||||
case 0x22:
|
||||
mc = 2;
|
||||
nc = 2;
|
||||
gemm_small(m0, m, n0, n, mc, nc);
|
||||
break;
|
||||
case 0x21:
|
||||
mc = 2;
|
||||
nc = 1;
|
||||
gemm_small(m0, m, n0, n, mc, nc);
|
||||
break;
|
||||
case 0x14:
|
||||
mc = 1;
|
||||
nc = 4;
|
||||
gemm_small(m0, m, n0, n, mc, nc);
|
||||
break;
|
||||
case 0x13:
|
||||
mc = 1;
|
||||
nc = 3;
|
||||
gemm_small(m0, m, n0, n, mc, nc);
|
||||
break;
|
||||
case 0x12:
|
||||
mc = 1;
|
||||
nc = 2;
|
||||
gemm_small(m0, m, n0, n, mc, nc);
|
||||
break;
|
||||
case 0x11:
|
||||
mc = 1;
|
||||
nc = 1;
|
||||
gemm_small(m0, m, n0, n, mc, nc);
|
||||
break;
|
||||
default:
|
||||
return;
|
||||
}
|
||||
}
|
||||
mp = m0 + (m - m0) / mc * mc;
|
||||
np = n0 + (n - n0) / nc * nc;
|
||||
mnpack(mp, m, n0, np);
|
||||
mnpack(m0, m, np, n);
|
||||
}
|
||||
|
||||
void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
|
||||
int64_t ytiles = (m - m0) / RM;
|
||||
int64_t xtiles = (n - n0) / RN;
|
||||
int64_t tiles = xtiles * ytiles;
|
||||
int64_t duty = (tiles + nth - 1) / nth;
|
||||
int64_t start = duty * ith;
|
||||
int64_t end = start + duty;
|
||||
if (end > tiles)
|
||||
end = tiles;
|
||||
for (int64_t job = start; job < end; ++job) {
|
||||
int64_t ii = m0 + job / xtiles * RM;
|
||||
int64_t jj = n0 + job % xtiles * RN;
|
||||
vec_t vec_C[4];
|
||||
acc_t acc_0;
|
||||
__builtin_mma_xxsetaccz(&acc_0);
|
||||
vec_t vec_A[4], vec_B[4];
|
||||
for (int l=0; l<k; l+=4) {
|
||||
if (RN >= 4 && RM == 1) {
|
||||
float* a = const_cast<float*>(A+(ii)*lda+l);
|
||||
READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
|
||||
vec_A[0] = (vec_t)vec_xl(0,a);
|
||||
vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1));
|
||||
vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2));
|
||||
vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3));
|
||||
} else {
|
||||
READ_BLOCK(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
|
||||
READ_BLOCK(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
|
||||
}
|
||||
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
|
||||
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
|
||||
__builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
|
||||
__builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
|
||||
}
|
||||
__builtin_mma_disassemble_acc(vec_C, &acc_0);
|
||||
for (int I = 0; I < RM; I++) {
|
||||
for (int J = 0; J < RN; J++) {
|
||||
*((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int RM, int RN>
|
||||
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
||||
int64_t ytiles = (m - m0) / RM;
|
||||
int64_t xtiles = (n - n0) / RN;
|
||||
int64_t tiles = xtiles * ytiles;
|
||||
int64_t duty = (tiles + nth - 1) / nth;
|
||||
int64_t start = duty * ith;
|
||||
int64_t end = start + duty;
|
||||
if (RM == 4 && RN == 4) {
|
||||
kernel = &tinyBLAS_PPC::KERNEL_4x4;
|
||||
} else if (RM == 4 && RN == 8) {
|
||||
kernel = &tinyBLAS_PPC::KERNEL_4x8;
|
||||
} else if (RM == 8 && RN == 4) {
|
||||
kernel = &tinyBLAS_PPC::KERNEL_8x4;
|
||||
} else if (RM == 8 && RN == 8) {
|
||||
kernel = &tinyBLAS_PPC::KERNEL_8x8;
|
||||
}
|
||||
if (end > tiles)
|
||||
end = tiles;
|
||||
for (int64_t job = start; job < end; ++job) {
|
||||
int64_t ii = m0 + job / xtiles * RM;
|
||||
int64_t jj = n0 + job % xtiles * RN;
|
||||
(this->*kernel)(ii, jj);
|
||||
}
|
||||
}
|
||||
|
||||
const TA *const A;
|
||||
const TB *const B;
|
||||
TC *C;
|
||||
TA *At;
|
||||
TB *Bt;
|
||||
const int64_t k;
|
||||
const int64_t lda;
|
||||
const int64_t ldb;
|
||||
const int64_t ldc;
|
||||
const int ith;
|
||||
const int nth;
|
||||
};
|
||||
#endif
|
||||
} // namespace
|
||||
|
||||
/**
|
||||
|
@ -1114,6 +1712,16 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
|
|||
ith, nth};
|
||||
tb.matmul(m, n);
|
||||
return true;
|
||||
#elif defined(__MMA__)
|
||||
if (k % 8)
|
||||
return false;
|
||||
tinyBLAS_PPC<float, float, float> tb{
|
||||
k, (const float *)A, lda,
|
||||
(const float *)B, ldb,
|
||||
(float *)C, ldc,
|
||||
ith, nth};
|
||||
tb.matmul(m, n);
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
|
|
|
@ -124,7 +124,7 @@ You can use GBNF grammars:
|
|||
- In [llama-cli](../examples/main), passed as the `--json` / `-j` flag
|
||||
- To convert to a grammar ahead of time:
|
||||
- in CLI, with [examples/json_schema_to_grammar.py](../examples/json_schema_to_grammar.py)
|
||||
- in JavaScript with [json-schema-to-grammar.mjs](../examples/server/public/json-schema-to-grammar.mjs) (this is used by the [server](../examples/server)'s Web UI)
|
||||
- in JavaScript with [json-schema-to-grammar.mjs](../examples/server/public_legacy/json-schema-to-grammar.mjs) (this is used by the [server](../examples/server)'s Web UI)
|
||||
|
||||
Take a look at [tests](../tests/test-json-schema-to-grammar.cpp) to see which features are likely supported (you'll also find usage examples in https://github.com/ggerganov/llama.cpp/pull/5978, https://github.com/ggerganov/llama.cpp/pull/6659 & https://github.com/ggerganov/llama.cpp/pull/6555).
|
||||
|
||||
|
|
|
@ -114,46 +114,22 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
|
|||
|
||||
# replace filenames:
|
||||
#
|
||||
# CMakelists.txt -> ggml/CMakeLists.txt
|
||||
# src/CMakeLists.txt -> ggml/src/CMakeLists.txt
|
||||
# cmake/FindSIMD.cmake -> ggml/cmake/FindSIMD.cmake
|
||||
# CMakelists.txt -> ggml/CMakeLists.txt
|
||||
# src/CMakeLists.txt -> ggml/src/CMakeLists.txt
|
||||
# cmake/FindSIMD.cmake -> ggml/cmake/FindSIMD.cmake
|
||||
#
|
||||
# src/ggml.c -> ggml/src/ggml.c
|
||||
# src/ggml-aarch64.c -> ggml/src/ggml-aarch64.c
|
||||
# src/ggml-aarch64.h -> ggml/src/ggml-aarch64.h
|
||||
# src/ggml-alloc.c -> ggml/src/ggml-alloc.c
|
||||
# src/ggml-amx/* -> ggml/src/ggml-amx/
|
||||
# src/ggml-amx.cpp -> ggml/src/ggml-amx.cpp
|
||||
# src/ggml-backend-impl.h -> ggml/src/ggml-backend-impl.h
|
||||
# src/ggml-backend.cpp -> ggml/src/ggml-backend.cpp
|
||||
# src/ggml-cann/* -> ggml/src/ggml-cann/
|
||||
# src/ggml-cann.cpp -> ggml/src/ggml-cann.cpp
|
||||
# src/ggml-common.h -> ggml/src/ggml-common.h
|
||||
# src/ggml-cuda/* -> ggml/src/ggml-cuda/
|
||||
# src/ggml-cuda.cu -> ggml/src/ggml-cuda.cu
|
||||
# src/ggml-impl.h -> ggml/src/ggml-impl.h
|
||||
# src/ggml-kompute.cpp -> ggml/src/ggml-kompute.cpp
|
||||
# src/ggml-metal.m -> ggml/src/ggml-metal.m
|
||||
# src/ggml-quants.c -> ggml/src/ggml-quants.c
|
||||
# src/ggml-quants.h -> ggml/src/ggml-quants.h
|
||||
# src/ggml-rpc.cpp -> ggml/src/ggml-rpc.cpp
|
||||
# src/ggml-sycl/* -> ggml/src/ggml-sycl/
|
||||
# src/ggml-sycl.cpp -> ggml/src/ggml-sycl.cpp
|
||||
# src/ggml-vulkan.cpp -> ggml/src/ggml-vulkan.cpp
|
||||
# src/vulkan-shaders/* -> ggml/src/vulkan-shaders/
|
||||
# src/ggml*.c -> ggml/src/ggml*.c
|
||||
# src/ggml*.cpp -> ggml/src/ggml*.cpp
|
||||
# src/ggml*.h -> ggml/src/ggml*.h
|
||||
# src/ggml*.cu -> ggml/src/ggml*.cu
|
||||
# src/ggml*.m -> ggml/src/ggml*.m
|
||||
# src/ggml-amx/* -> ggml/src/ggml-amx/
|
||||
# src/ggml-cann/* -> ggml/src/ggml-cann/
|
||||
# src/ggml-cuda/* -> ggml/src/ggml-cuda/
|
||||
# src/ggml-sycl/* -> ggml/src/ggml-sycl/
|
||||
# src/vulkan-shaders/* -> ggml/src/vulkan-shaders/
|
||||
#
|
||||
# include/ggml.h -> ggml/include/ggml.h
|
||||
# include/ggml-alloc.h -> ggml/include/ggml-alloc.h
|
||||
# include/ggml-amx.h -> ggml/include/ggml-amx.h
|
||||
# include/ggml-backend.h -> ggml/include/ggml-backend.h
|
||||
# include/ggml-blas.h -> ggml/include/ggml-blas.h
|
||||
# include/ggml-cann.h -> ggml/include/ggml-cann.h
|
||||
# include/ggml-cuda.h -> ggml/include/ggml-cuda.h
|
||||
# include/ggml-kompute.h -> ggml/include/ggml-kompute.h
|
||||
# include/ggml-metal.h -> ggml/include/ggml-metal.h
|
||||
# include/ggml-rpc.h -> ggml/include/ggml-rpc.h
|
||||
# include/ggml-sycl.h -> ggml/include/ggml-sycl.h
|
||||
# include/ggml-vulkan.h -> ggml/include/ggml-vulkan.h
|
||||
# include/ggml*.h -> ggml/include/ggml*.h
|
||||
#
|
||||
# tests/test-opt.cpp -> tests/test-opt.cpp
|
||||
# tests/test-grad0.cpp -> tests/test-grad0.cpp
|
||||
|
@ -168,41 +144,17 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then
|
|||
-e 's/([[:space:]]|[ab]\/)CMakeLists.txt/\1ggml\/CMakeLists.txt/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)src\/CMakeLists.txt/\1ggml\/src\/CMakeLists.txt/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)cmake\/FindSIMD.cmake/\1ggml\/cmake\/FindSIMD.cmake/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)src\/ggml\.c/\1ggml\/src\/ggml.c/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)src\/ggml-aarch64\.c/\1ggml\/src\/ggml-aarch64.c/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)src\/ggml-aarch64\.h/\1ggml\/src\/ggml-aarch64.h/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)src\/ggml-alloc\.c/\1ggml\/src\/ggml-alloc.c/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)src\/ggml(.*)\.c/\1ggml\/src\/ggml\1.c/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)src\/ggml(.*)\.cpp/\1ggml\/src\/ggml\1.cpp/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)src\/ggml(.*)\.h/\1ggml\/src\/ggml\1.h/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)src\/ggml(.*)\.cu/\1ggml\/src\/ggml\1.cu/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)src\/ggml(.*)\.m/\1ggml\/src\/ggml\1.m/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)src\/ggml-amx\//\1ggml\/src\/ggml-amx\//g' \
|
||||
-e 's/([[:space:]]|[ab]\/)src\/ggml-amx\.cpp/\1ggml\/src\/ggml-amx.cpp/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)src\/ggml-backend-impl\.h/\1ggml\/src\/ggml-backend-impl.h/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)src\/ggml-backend\.cpp/\1ggml\/src\/ggml-backend.cpp/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)src\/ggml-cann\//\1ggml\/src\/ggml-cann\//g' \
|
||||
-e 's/([[:space:]]|[ab]\/)src\/ggml-cann\.cpp/\1ggml\/src\/ggml-cann.cpp/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)src\/ggml-common\.h/\1ggml\/src\/ggml-common.h/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)src\/ggml-cuda\//\1ggml\/src\/ggml-cuda\//g' \
|
||||
-e 's/([[:space:]]|[ab]\/)src\/ggml-cuda\.cu/\1ggml\/src\/ggml-cuda.cu/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)src\/ggml-impl\.h/\1ggml\/src\/ggml-impl.h/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)src\/ggml-kompute\.cpp/\1ggml\/src\/ggml-kompute.cpp/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)src\/ggml-metal\.m/\1ggml\/src\/ggml-metal.m/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)src\/ggml-quants\.c/\1ggml\/src\/ggml-quants.c/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)src\/ggml-quants\.h/\1ggml\/src\/ggml-quants.h/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)src\/ggml-rpc\.cpp/\1ggml\/src\/ggml-rpc.cpp/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)src\/ggml-sycl\//\1ggml\/src\/ggml-sycl\//g' \
|
||||
-e 's/([[:space:]]|[ab]\/)src\/ggml-sycl\.cpp/\1ggml\/src\/ggml-sycl.cpp/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)src\/ggml-vulkan\.cpp/\1ggml\/src\/ggml-vulkan.cpp/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)src\/vulkan-shaders\//\1ggml\/src\/vulkan-shaders\//g' \
|
||||
-e 's/([[:space:]]|[ab]\/)include\/ggml\.h/\1ggml\/include\/ggml.h/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)include\/ggml-alloc\.h/\1ggml\/include\/ggml-alloc.h/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)include\/ggml-amx\.h/\1ggml\/include\/ggml-amx.h/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)include\/ggml-backend\.h/\1ggml\/include\/ggml-backend.h/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)include\/ggml-blas\.h/\1ggml\/include\/ggml-blas.h/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)include\/ggml-cann\.h/\1ggml\/include\/ggml-cann.h/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)include\/ggml-cuda\.h/\1ggml\/include\/ggml-cuda.h/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)include\/ggml-kompute\.h/\1ggml\/include\/ggml-kompute.h/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)include\/ggml-metal\.h/\1ggml\/include\/ggml-metal.h/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)include\/ggml-rpc\.h/\1ggml\/include\/ggml-rpc.h/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)include\/ggml-sycl\.h/\1ggml\/include\/ggml-sycl.h/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)include\/ggml-vulkan\.h/\1ggml\/include\/ggml-vulkan.h/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)include\/ggml(.*)\.h/\1ggml\/include\/ggml\1.h/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)examples\/common\.h/\1examples\/common.h/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)examples\/common\.cpp/\1examples\/common.cpp/g' \
|
||||
-e 's/([[:space:]]|[ab]\/)examples\/common-ggml\.h/\1examples\/common-ggml.h/g' \
|
||||
|
|
|
@ -1 +1 @@
|
|||
a099cb514d6687e436a5a423d1fb0448be0feb20
|
||||
89952d649e0c5cabbb9ff8c4906f5a843a789fb2
|
||||
|
|
|
@ -4,43 +4,18 @@ cp -rpv ../ggml/CMakeLists.txt ./ggml/CMakeLists.txt
|
|||
cp -rpv ../ggml/src/CMakeLists.txt ./ggml/src/CMakeLists.txt
|
||||
cp -rpv ../ggml/cmake/FindSIMD.cmake ./ggml/cmake/FindSIMD.cmake
|
||||
|
||||
cp -rpv ../ggml/src/ggml.c ./ggml/src/ggml.c
|
||||
cp -rpv ../ggml/src/ggml-aarch64.c ./ggml/src/ggml-aarch64.c
|
||||
cp -rpv ../ggml/src/ggml-aarch64.h ./ggml/src/ggml-aarch64.h
|
||||
cp -rpv ../ggml/src/ggml-alloc.c ./ggml/src/ggml-alloc.c
|
||||
cp -rpv ../ggml/src/ggml-amx/* ./ggml/src/ggml-amx/
|
||||
cp -rpv ../ggml/src/ggml-amx.cpp ./ggml/src/ggml-amx.cpp
|
||||
cp -rpv ../ggml/src/ggml-backend-impl.h ./ggml/src/ggml-backend-impl.h
|
||||
cp -rpv ../ggml/src/ggml-backend.cpp ./ggml/src/ggml-backend.cpp
|
||||
cp -rpv ../ggml/src/ggml-cann/* ./ggml/src/ggml-cann/
|
||||
cp -rpv ../ggml/src/ggml-cann.cpp ./ggml/src/ggml-cann.cpp
|
||||
cp -rpv ../ggml/src/ggml-common.h ./ggml/src/ggml-common.h
|
||||
cp -rpv ../ggml/src/ggml-cuda/* ./ggml/src/ggml-cuda/
|
||||
cp -rpv ../ggml/src/ggml-cuda.cu ./ggml/src/ggml-cuda.cu
|
||||
cp -rpv ../ggml/src/ggml-impl.h ./ggml/src/ggml-impl.h
|
||||
cp -rpv ../ggml/src/ggml-kompute.cpp ./ggml/src/ggml-kompute.cpp
|
||||
cp -rpv ../ggml/src/ggml-metal.m ./ggml/src/ggml-metal.m
|
||||
cp -rpv ../ggml/src/ggml-metal.metal ./ggml/src/ggml-metal.metal
|
||||
cp -rpv ../ggml/src/ggml-quants.c ./ggml/src/ggml-quants.c
|
||||
cp -rpv ../ggml/src/ggml-quants.h ./ggml/src/ggml-quants.h
|
||||
cp -rpv ../ggml/src/ggml-rpc.cpp ./ggml/src/ggml-rpc.cpp
|
||||
cp -rpv ../ggml/src/ggml-sycl/* ./ggml/src/ggml-sycl/
|
||||
cp -rpv ../ggml/src/ggml-sycl.cpp ./ggml/src/ggml-sycl.cpp
|
||||
cp -rpv ../ggml/src/ggml-vulkan.cpp ./ggml/src/ggml-vulkan.cpp
|
||||
cp -rpv ../ggml/src/vulkan-shaders/* ./ggml/src/vulkan-shaders/
|
||||
cp -rpv ../ggml/src/ggml*.c ./ggml/src/
|
||||
cp -rpv ../ggml/src/ggml*.cpp ./ggml/src/
|
||||
cp -rpv ../ggml/src/ggml*.h ./ggml/src/
|
||||
cp -rpv ../ggml/src/ggml*.cu ./ggml/src/
|
||||
cp -rpv ../ggml/src/ggml*.m ./ggml/src/
|
||||
cp -rpv ../ggml/src/ggml-amx/* ./ggml/src/ggml-amx/
|
||||
cp -rpv ../ggml/src/ggml-cann/* ./ggml/src/ggml-cann/
|
||||
cp -rpv ../ggml/src/ggml-cuda/* ./ggml/src/ggml-cuda/
|
||||
cp -rpv ../ggml/src/ggml-sycl/* ./ggml/src/ggml-sycl/
|
||||
cp -rpv ../ggml/src/vulkan-shaders/* ./ggml/src/vulkan-shaders/
|
||||
|
||||
cp -rpv ../ggml/include/ggml.h ./ggml/include/ggml.h
|
||||
cp -rpv ../ggml/include/ggml-alloc.h ./ggml/include/ggml-alloc.h
|
||||
cp -rpv ../ggml/include/ggml-amx.h ./ggml/include/ggml-amx.h
|
||||
cp -rpv ../ggml/include/ggml-backend.h ./ggml/include/ggml-backend.h
|
||||
cp -rpv ../ggml/include/ggml-blas.h ./ggml/include/ggml-blas.h
|
||||
cp -rpv ../ggml/include/ggml-cann.h ./ggml/include/ggml-cann.h
|
||||
cp -rpv ../ggml/include/ggml-cuda.h ./ggml/include/ggml-cuda.h
|
||||
cp -rpv ../ggml/include/ggml-kompute.h ./ggml/include/ggml-kompute.h
|
||||
cp -rpv ../ggml/include/ggml-metal.h ./ggml/include/ggml-metal.h
|
||||
cp -rpv ../ggml/include/ggml-rpc.h ./ggml/include/ggml-rpc.h
|
||||
cp -rpv ../ggml/include/ggml-sycl.h ./ggml/include/ggml-sycl.h
|
||||
cp -rpv ../ggml/include/ggml-vulkan.h ./ggml/include/ggml-vulkan.h
|
||||
cp -rpv ../ggml/include/ggml*.h ./ggml/include/
|
||||
|
||||
cp -rpv ../ggml/tests/test-opt.cpp ./tests/test-opt.cpp
|
||||
cp -rpv ../ggml/tests/test-grad0.cpp ./tests/test-grad0.cpp
|
||||
|
|
|
@ -1876,8 +1876,11 @@ static void llama_sampler_dry_reset(struct llama_sampler * smpl) {
|
|||
static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (llama_sampler_dry *) smpl->ctx;
|
||||
|
||||
// nullptr is passed as vocab because it is only needed for raw sequence breaker processing, which we have already done and will be copying
|
||||
auto * result = llama_sampler_init_dry(nullptr, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
|
||||
llama_vocab dummy_vocab;
|
||||
|
||||
// dummy vocab is passed because it is only needed for raw sequence breaker processing, which we have already done and will simply be copying
|
||||
auto * result = llama_sampler_init_dry_impl(dummy_vocab, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
|
||||
|
||||
// Copy the state, including the processed breakers
|
||||
{
|
||||
auto * result_ctx = (llama_sampler_dry *) result->ctx;
|
||||
|
|
|
@ -2301,6 +2301,7 @@ enum e_model {
|
|||
MODEL_1B,
|
||||
MODEL_1_3B,
|
||||
MODEL_1_4B,
|
||||
MODEL_1_5B,
|
||||
MODEL_1_6B,
|
||||
MODEL_2B,
|
||||
MODEL_2_8B,
|
||||
|
@ -5227,6 +5228,7 @@ static const char * llama_model_type_name(e_model type) {
|
|||
case MODEL_1B: return "1B";
|
||||
case MODEL_1_3B: return "1.3B";
|
||||
case MODEL_1_4B: return "1.4B";
|
||||
case MODEL_1_5B: return "1.5B";
|
||||
case MODEL_1_6B: return "1.6B";
|
||||
case MODEL_2B: return "2B";
|
||||
case MODEL_2_8B: return "2.8B";
|
||||
|
@ -5598,6 +5600,7 @@ static void llm_load_hparams(
|
|||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
switch (hparams.n_layer) {
|
||||
case 24: model.type = hparams.n_embd == 1024 ? e_model::MODEL_0_5B : e_model::MODEL_1B; break;
|
||||
case 28: model.type = hparams.n_embd == 1536 ? e_model::MODEL_1_5B : e_model::MODEL_7B; break;
|
||||
case 32: model.type = e_model::MODEL_7B; break;
|
||||
case 40: model.type = hparams.n_head() == 20 ? e_model::MODEL_4B : e_model::MODEL_13B; break;
|
||||
case 80: model.type = e_model::MODEL_70B; break;
|
||||
|
@ -7011,7 +7014,7 @@ static const std::map<llm_tensor, llm_tensor_info> llm_tensor_info_mapping = {
|
|||
{LLM_TENSOR_TIME_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV}},
|
||||
{LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}},
|
||||
{LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ATTN_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_ATTN_OUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
|
@ -7127,7 +7130,7 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
|
|||
ggml_tensor * C = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs);
|
||||
op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C);
|
||||
} break;
|
||||
case GGML_OP_RWKV_WKV:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
{
|
||||
// FIXME
|
||||
const int64_t S = 123;
|
||||
|
@ -7140,7 +7143,7 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
|
|||
ggml_tensor * tf = w;
|
||||
ggml_tensor * td = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
|
||||
ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_seqs, S, H);
|
||||
op_tensor = ggml_rwkv_wkv(ctx, k, v, r, tf, td, state);
|
||||
op_tensor = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name);
|
||||
|
@ -9134,7 +9137,7 @@ static bool llm_load_tensors(
|
|||
|
||||
// print memory requirements per buffer type
|
||||
for (auto & buf : model.bufs) {
|
||||
LLAMA_LOG_INFO("%s: %10s model buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0);
|
||||
LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0);
|
||||
}
|
||||
|
||||
// populate tensors_by_name
|
||||
|
@ -10083,7 +10086,7 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
|
|||
v = ggml_transpose(ctx, v);
|
||||
r = ggml_transpose(ctx, r);
|
||||
|
||||
struct ggml_tensor * wkv_output = ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
|
||||
struct ggml_tensor * wkv_output = ggml_rwkv_wkv6(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
|
||||
cur = ggml_view_1d(ctx, wkv_output, n_embd * n_tokens, 0);
|
||||
*wkv_state = ggml_view_1d(ctx, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import { readFileSync } from "fs"
|
||||
import { SchemaConverter } from "../examples/server/public/json-schema-to-grammar.mjs"
|
||||
import { SchemaConverter } from "../examples/server/public_legacy/json-schema-to-grammar.mjs"
|
||||
|
||||
const [, , file] = process.argv
|
||||
const url = `file://${file}`
|
||||
|
|
|
@ -1614,8 +1614,8 @@ struct test_ssm_scan : public test_case {
|
|||
}
|
||||
};
|
||||
|
||||
// GGML_OP_RWKV_WKV
|
||||
struct test_rwkv_wkv : public test_case {
|
||||
// GGML_OP_RWKV_WKV6
|
||||
struct test_rwkv_wkv6 : public test_case {
|
||||
const ggml_type type;
|
||||
|
||||
const int64_t head_count;
|
||||
|
@ -1627,7 +1627,7 @@ struct test_rwkv_wkv : public test_case {
|
|||
return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
|
||||
}
|
||||
|
||||
test_rwkv_wkv(ggml_type type = GGML_TYPE_F32,
|
||||
test_rwkv_wkv6(ggml_type type = GGML_TYPE_F32,
|
||||
int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
|
||||
: type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
|
||||
|
||||
|
@ -1639,7 +1639,7 @@ struct test_rwkv_wkv : public test_case {
|
|||
ggml_tensor * tf = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size, head_count }.data());
|
||||
ggml_tensor * td = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
|
||||
ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
|
||||
ggml_tensor * out = ggml_rwkv_wkv(ctx, k, v, r, tf, td, s);
|
||||
ggml_tensor * out = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, s);
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
@ -3499,10 +3499,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
|
||||
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4));
|
||||
|
||||
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 1, 1));
|
||||
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 1));
|
||||
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 32, 4));
|
||||
test_cases.emplace_back(new test_rwkv_wkv(GGML_TYPE_F32, 32, 64, 128, 4));
|
||||
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 1, 1));
|
||||
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 1));
|
||||
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4));
|
||||
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4));
|
||||
|
||||
#if 1
|
||||
for (ggml_type type_a : base_types) {
|
||||
|
@ -3599,7 +3599,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
for (int n_mats : {4}) {
|
||||
for (int n_used : {2}) {
|
||||
for (bool b : {false}) {
|
||||
for (int n : {1}) {
|
||||
for (int n : {1, 32}) {
|
||||
int m = 512;
|
||||
int k = 256;
|
||||
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k));
|
||||
|
@ -3745,7 +3745,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
|||
for (int nh : { 32, }) {
|
||||
for (int kv : { 512, 1024, }) {
|
||||
for (int nb : { 1, 3, 32, 35, }) {
|
||||
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
|
||||
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
|
||||
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV));
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue