merge master
commit d63497b3a3
62 changed files with 1828 additions and 482 deletions
|
@ -69,6 +69,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
|||
- [x] [Qwen models](https://huggingface.co/models?search=Qwen/Qwen)
|
||||
- [x] [PLaMo-13B](https://github.com/ggerganov/llama.cpp/pull/3557)
|
||||
- [x] [Phi models](https://huggingface.co/models?search=microsoft/phi)
|
||||
- [x] [PhiMoE](https://github.com/ggerganov/llama.cpp/pull/11003)
|
||||
- [x] [GPT-2](https://huggingface.co/gpt2)
|
||||
- [x] [Orion 14B](https://github.com/ggerganov/llama.cpp/pull/5118)
|
||||
- [x] [InternLM2](https://huggingface.co/models?search=internlm2)
|
||||
|
@ -98,6 +99,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
|
|||
- [x] [Jais](https://huggingface.co/inceptionai/jais-13b-chat)
|
||||
- [x] [Bielik-11B-v2.3](https://huggingface.co/collections/speakleash/bielik-11b-v23-66ee813238d9b526a072408a)
|
||||
- [x] [RWKV-6](https://github.com/BlinkDL/RWKV-LM)
|
||||
- [x] [QRWKV-6](https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1)
|
||||
- [x] [GigaChat-20B-A3B](https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct)
|
||||
|
||||
#### Multimodal
|
||||
|
|
|
@ -326,6 +326,7 @@ class Model:
|
|||
gguf.MODEL_TENSOR.TIME_MIX_W2,
|
||||
gguf.MODEL_TENSOR.TIME_MIX_DECAY_W1,
|
||||
gguf.MODEL_TENSOR.TIME_MIX_DECAY_W2,
|
||||
gguf.MODEL_TENSOR.TIME_MIX_LERP_FUSED,
|
||||
gguf.MODEL_TENSOR.POSNET_NORM1,
|
||||
gguf.MODEL_TENSOR.POSNET_NORM2,
|
||||
)
|
||||
|
@ -477,6 +478,11 @@ class Model:
|
|||
return modelcls
|
||||
return func
|
||||
|
||||
@classmethod
|
||||
def print_registered_models(cls):
|
||||
for name in sorted(cls._model_classes.keys()):
|
||||
logger.error(f"- {name}")
|
||||
|
||||
@classmethod
|
||||
def from_model_architecture(cls, arch: str) -> type[Model]:
|
||||
try:
|
||||
|
@ -2562,6 +2568,63 @@ class Phi3MiniModel(Model):
|
|||
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32))
|
||||
|
||||
|
||||
@Model.register("PhiMoEForCausalLM")
|
||||
class PhiMoeModel(Phi3MiniModel):
|
||||
model_arch = gguf.MODEL_ARCH.PHIMOE
|
||||
|
||||
_experts: list[dict[str, Tensor]] | None = None
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
|
||||
self.gguf_writer.add_expert_count(self.hparams["num_local_experts"])
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# process the experts separately
|
||||
if name.find("block_sparse_moe.experts") != -1:
|
||||
n_experts = self.hparams["num_local_experts"]
|
||||
assert bid is not None
|
||||
|
||||
if self._experts is None:
|
||||
self._experts = [{} for _ in range(self.block_count)]
|
||||
|
||||
self._experts[bid][name] = data_torch
|
||||
|
||||
if len(self._experts[bid]) >= n_experts * 3:
|
||||
tensors: list[tuple[str, Tensor]] = []
|
||||
|
||||
# merge the experts into a single 3d tensor
|
||||
for w_name in ["w1", "w2", "w3"]:
|
||||
datas: list[Tensor] = []
|
||||
|
||||
for xid in range(n_experts):
|
||||
ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{w_name}.weight"
|
||||
datas.append(self._experts[bid][ename])
|
||||
del self._experts[bid][ename]
|
||||
|
||||
data_torch = torch.stack(datas, dim=0)
|
||||
|
||||
merged_name = f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight"
|
||||
|
||||
new_name = self.map_tensor_name(merged_name)
|
||||
|
||||
tensors.append((new_name, data_torch))
|
||||
return tensors
|
||||
else:
|
||||
return []
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
def prepare_tensors(self):
|
||||
super().prepare_tensors()
|
||||
|
||||
if self._experts is not None:
|
||||
# flatten `list[dict[str, Tensor]]` into `list[str]`
|
||||
experts = [k for d in self._experts for k in d.keys()]
|
||||
if len(experts) > 0:
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@Model.register("PlamoForCausalLM")
|
||||
class PlamoModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.PLAMO
|
||||
|
@ -3259,6 +3322,8 @@ class Rwkv6Model(Model):
|
|||
# required by llama.cpp, unused
|
||||
self.gguf_writer.add_head_count(0)
|
||||
|
||||
lerp_weights: dict[int, dict[str, Tensor]] = {}
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
new_name = self.map_tensor_name(name)
|
||||
|
||||
|
@ -3274,14 +3339,84 @@ class Rwkv6Model(Model):
|
|||
if new_name.endswith("time_mix_decay.weight") or "lerp" in new_name:
|
||||
data_torch = data_torch.squeeze()
|
||||
|
||||
rescale_every_n_layers = self.hparams["rescale_every"]
|
||||
if rescale_every_n_layers > 0:
|
||||
if new_name.endswith("time_mix_output.weight") or new_name.endswith("channel_mix_value.weight"):
|
||||
data_torch = data_torch.div_(2 ** int(bid // rescale_every_n_layers))
|
||||
try:
|
||||
rescale_every_n_layers = self.hparams["rescale_every"]
|
||||
if rescale_every_n_layers > 0:
|
||||
if new_name.endswith("time_mix_output.weight") or new_name.endswith("channel_mix_value.weight"):
|
||||
data_torch = data_torch.div_(2 ** int(bid // rescale_every_n_layers))
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
# concat time_mix_lerp weights to reduce some cpu overhead
|
||||
# also reduces the number of tensors in the model
|
||||
if bid is not None and "time_mix_lerp" in new_name and "time_mix_lerp_x" not in new_name:
|
||||
try:
|
||||
self.lerp_weights[bid][new_name] = data_torch
|
||||
except KeyError:
|
||||
self.lerp_weights[bid] = {new_name: data_torch}
|
||||
if all(f"blk.{bid}.time_mix_lerp_{i}.weight" in self.lerp_weights[bid].keys() for i in ["w", "k", "v", "r", "g"]):
|
||||
new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
|
||||
data = torch.stack([self.lerp_weights[bid][f"blk.{bid}.time_mix_lerp_{i}.weight"].unsqueeze(0) for i in ["w", "k", "v", "r", "g"]], dim=0).unsqueeze(1)
|
||||
yield (new_name, data)
|
||||
return
|
||||
|
||||
yield (new_name, data_torch)
|
||||
|
||||
|
||||
@Model.register("RWKV6Qwen2ForCausalLM")
|
||||
class RWKV6Qwen2Model(Rwkv6Model):
|
||||
model_arch = gguf.MODEL_ARCH.RWKV6QWEN2
|
||||
|
||||
def set_vocab(self):
|
||||
try:
|
||||
self._set_vocab_sentencepiece()
|
||||
except FileNotFoundError:
|
||||
self._set_vocab_gpt2()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
num_attention_heads = self.hparams["num_attention_heads"]
|
||||
num_key_value_heads = self.hparams["num_key_value_heads"]
|
||||
hidden_size = self.hparams["hidden_size"]
|
||||
head_size = hidden_size // num_attention_heads
|
||||
rms_norm_eps = self.hparams["rms_norm_eps"]
|
||||
intermediate_size = self.hparams["intermediate_size"]
|
||||
time_mix_extra_dim = 64 if hidden_size >= 4096 else 32
|
||||
time_decay_extra_dim = 128 if hidden_size >= 4096 else 64
|
||||
|
||||
# RWKV isn't context limited
|
||||
self.gguf_writer.add_context_length(1048576)
|
||||
self.gguf_writer.add_embedding_length(hidden_size)
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_wkv_head_size(head_size)
|
||||
self.gguf_writer.add_time_mix_extra_dim(time_mix_extra_dim)
|
||||
self.gguf_writer.add_time_decay_extra_dim(time_decay_extra_dim)
|
||||
self.gguf_writer.add_feed_forward_length(intermediate_size)
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
|
||||
# special parameters for time_mixing in RWKV6QWEN2
|
||||
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
|
||||
self.gguf_writer.add_token_shift_count(1)
|
||||
# RWKV6QWEN2 uses grouped key/value heads like GQA
|
||||
self.gguf_writer.add_head_count_kv(num_key_value_heads)
|
||||
|
||||
# required by llama.cpp, unused
|
||||
self.gguf_writer.add_head_count(0)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
for new_name, data in super().modify_tensors(data_torch, name, bid):
|
||||
if "time_mix_w1" in new_name or "time_mix_w2" in new_name:
|
||||
data = data.view(5, -1, data.shape[-1])
|
||||
# rwkv6qwen2 stores these in rkvwg order instead of the original wkvrg order
|
||||
# permute them here to avoid code changes
|
||||
data = torch.stack([data[3], data[1], data[2], data[0], data[4]], dim=0).view(-1, data.shape[-1])
|
||||
if "w2" in new_name:
|
||||
data = data.view(5, -1, data.shape[-1])
|
||||
yield (new_name, data)
|
||||
continue
|
||||
yield (new_name, data)
|
||||
|
||||
|
||||
@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
|
||||
class MambaModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.MAMBA
|
||||
|
@ -4799,6 +4934,7 @@ def parse_args() -> argparse.Namespace:
|
|||
parser.add_argument(
|
||||
"model", type=Path,
|
||||
help="directory containing model file",
|
||||
nargs="?",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-temp-file", action="store_true",
|
||||
|
@ -4836,8 +4972,15 @@ def parse_args() -> argparse.Namespace:
|
|||
"--metadata", type=Path,
|
||||
help="Specify the path for an authorship metadata override file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--print-supported-models", action="store_true",
|
||||
help="Print the supported models"
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
args = parser.parse_args()
|
||||
if not args.print_supported_models and args.model is None:
|
||||
parser.error("the following arguments are required: model")
|
||||
return args
|
||||
|
||||
|
||||
def split_str_to_n_bytes(split_str: str) -> int:
|
||||
|
@ -4861,6 +5004,11 @@ def split_str_to_n_bytes(split_str: str) -> int:
|
|||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
if args.print_supported_models:
|
||||
logger.error("Supported models:")
|
||||
Model.print_registered_models()
|
||||
sys.exit(0)
|
||||
|
||||
if args.verbose:
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
else:
|
||||
|
|
|
@ -127,6 +127,8 @@ For detailed info, please refer to [llama.cpp for SYCL](./backend/SYCL.md).
|
|||
|
||||
This provides GPU acceleration using an NVIDIA GPU. Make sure to have the CUDA toolkit installed. You can download it from your Linux distro's package manager (e.g. `apt install nvidia-cuda-toolkit`) or from the [NVIDIA developer site](https://developer.nvidia.com/cuda-downloads).
|
||||
|
||||
If you are using Fedora (Fedora Workstation or an 'Atomic' variant such as Silverblue), or would like to set up CUDA in a toolbox, please see our [Fedora CUDA guide](./cuda-fedora.md). Unfortunately, the process is not as simple as one might expect.
|
||||
|
||||
- Using `CMake`:
|
||||
|
||||
```bash
|
||||
|
|
317
docs/cuda-fedora.md
Normal file
|
@ -0,0 +1,317 @@
|
|||
# Setting Up CUDA on Fedora
|
||||
|
||||
In this guide we set up [NVIDIA CUDA](https://docs.nvidia.com/cuda/) in a toolbox container. This guide is applicable to:
|
||||
- [Fedora Workstation](https://fedoraproject.org/workstation/)
|
||||
- [Atomic Desktops for Fedora](https://fedoraproject.org/atomic-desktops/)
|
||||
- [Fedora Spins](https://fedoraproject.org/spins)
|
||||
- [Other Distributions](https://containertoolbx.org/distros/), including `Red Hat Enterprise Linux >= 8`, `Arch Linux`, and `Ubuntu`.
|
||||
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Prerequisites](#prerequisites)
|
||||
- [Monitoring NVIDIA CUDA Repositories](#monitoring-nvidia-cuda-repositories)
|
||||
- [Using the Fedora 39 CUDA Repository](#using-the-fedora-39-cuda-repository)
|
||||
- [Creating a Fedora Toolbox Environment](#creating-a-fedora-toolbox-environment)
|
||||
- [Installing Essential Development Tools](#installing-essential-development-tools)
|
||||
- [Adding the CUDA Repository](#adding-the-cuda-repository)
|
||||
- [Installing `nvidia-driver-libs`](#installing-nvidia-driver-libs)
|
||||
- [Manually Resolving Package Conflicts](#manually-resolving-package-conflicts)
|
||||
- [Finalizing the Installation of `nvidia-driver-libs`](#finalizing-the-installation-of-nvidia-driver-libs)
|
||||
- [Installing the CUDA Meta-Package](#installing-the-cuda-meta-package)
|
||||
- [Configuring the Environment](#configuring-the-environment)
|
||||
- [Verifying the Installation](#verifying-the-installation)
|
||||
- [Conclusion](#conclusion)
|
||||
- [Troubleshooting](#troubleshooting)
|
||||
- [Additional Notes](#additional-notes)
|
||||
- [References](#references)
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- **Toolbox Installed on the Host System.** `Fedora Silverblue` and `Fedora Workstation` both have toolbox installed by default; other distributions may need to install the [toolbox package](https://containertoolbx.org/install/).
|
||||
- **NVIDIA Drivers and Graphics Card installed on the Host System (optional).** To run CUDA programs, such as `llama.cpp`, the host should be set up to access your NVIDIA hardware. Fedora hosts can use the [RPM Fusion Repository](https://rpmfusion.org/Howto/NVIDIA).
|
||||
- **Internet connectivity** to download packages.
|
||||
|
||||
### Monitoring NVIDIA CUDA Repositories
|
||||
|
||||
Before proceeding, it is advisable to check if NVIDIA has updated their CUDA repositories for your Fedora version. NVIDIA's repositories can be found at:
|
||||
|
||||
- [Fedora 40 CUDA Repository](https://developer.download.nvidia.com/compute/cuda/repos/fedora40/x86_64/)
|
||||
- [Fedora 41 CUDA Repository](https://developer.download.nvidia.com/compute/cuda/repos/fedora41/x86_64/)
|
||||
|
||||
As of the latest update, these repositories do not contain the `cuda` meta-package or are missing essential components.
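If you want to re-check this yourself, one rough way (assuming the repository root still serves a browsable file listing) is to grep the index page for `cuda` packages:

```bash
# List any cuda* RPMs advertised by the Fedora 41 repository index page.
# (Assumption: the repo root serves a browsable HTML listing.)
curl -s https://developer.download.nvidia.com/compute/cuda/repos/fedora41/x86_64/ \
  | grep -oE 'cuda-[a-z0-9.-]+\.rpm' | sort -u | head
```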
|
||||
|
||||
### Using the Fedora 39 CUDA Repository
|
||||
|
||||
Since the newer repositories are incomplete, we'll use the Fedora 39 repository:
|
||||
|
||||
- [Fedora 39 CUDA Repository](https://developer.download.nvidia.com/compute/cuda/repos/fedora39/x86_64/)
|
||||
|
||||
**Note:** Fedora 39 is no longer maintained, so we recommend using a toolbox environment to prevent system conflicts.
|
||||
|
||||
## Creating a Fedora Toolbox Environment
|
||||
|
||||
This guide focuses on Fedora hosts, but with small adjustments, it can work for other hosts. Using a Fedora 39 toolbox allows us to install the necessary packages without affecting the host system.
|
||||
|
||||
**Note:** Toolbox is available for other systems, and even without Toolbox, it is possible to use Podman or Docker.
|
||||
|
||||
We do not recommend installing the Fedora 39 packages directly on the host system, as Fedora 39 is out of maintenance; your host should stay on a maintained version of Fedora.
|
||||
|
||||
1. **Create a Fedora 39 Toolbox:**
|
||||
|
||||
```bash
|
||||
toolbox create --image registry.fedoraproject.org/fedora-toolbox:39 --container fedora-toolbox-39-cuda
|
||||
```
|
||||
|
||||
2. **Enter the Toolbox:**
|
||||
|
||||
```bash
|
||||
toolbox enter --container fedora-toolbox-39-cuda
|
||||
```
|
||||
|
||||
Inside the toolbox, you have root privileges and can install packages without affecting the host system.
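If you are ever unsure whether a shell is running inside the toolbox or on the host, a quick check (assuming the marker file that toolbox creates) is:

```bash
# Toolbox containers carry a marker file; the host does not.
test -f /run/.toolboxenv && echo "inside the toolbox" || echo "on the host"
```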
|
||||
|
||||
## Installing Essential Development Tools
|
||||
|
||||
1. **Synchronize the DNF Package Manager:**
|
||||
|
||||
```bash
|
||||
sudo dnf distro-sync
|
||||
```
|
||||
|
||||
2. **Install the Default Text Editor (Optional):**
|
||||
|
||||
```bash
|
||||
sudo dnf install vim-default-editor --allowerasing
|
||||
```
|
||||
|
||||
The `--allowerasing` flag lets DNF remove conflicting packages so the installation can proceed.
|
||||
|
||||
3. **Install Development Tools and Libraries:**
|
||||
|
||||
```bash
|
||||
sudo dnf install @c-development @development-tools cmake
|
||||
```
|
||||
|
||||
This installs essential packages for compiling software, including `gcc`, `make`, and other development headers.
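As a quick sanity check, you can confirm the toolchain is in place:

```bash
gcc --version | head -n1
make --version | head -n1
cmake --version | head -n1
```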
|
||||
|
||||
## Adding the CUDA Repository
|
||||
|
||||
Add the NVIDIA CUDA repository to your DNF configuration:
|
||||
|
||||
```bash
|
||||
sudo dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/fedora39/x86_64/cuda-fedora39.repo
|
||||
```
|
||||
|
||||
After adding the repository, synchronize the package manager again:
|
||||
|
||||
```bash
|
||||
sudo dnf distro-sync
|
||||
```
|
||||
|
||||
## Installing `nvidia-driver-libs`
|
||||
|
||||
Attempt to install `nvidia-driver-libs`:
|
||||
|
||||
```bash
|
||||
sudo dnf install nvidia-driver-libs
|
||||
```
|
||||
|
||||
**Explanation:**
|
||||
|
||||
- `nvidia-driver-libs` contains necessary NVIDIA driver libraries required by CUDA.
|
||||
- This step might fail due to conflicts with existing NVIDIA drivers on the host system.
|
||||
|
||||
## Manually Resolving Package Conflicts
|
||||
|
||||
If the installation fails due to conflicts, we'll manually download and install the required packages, excluding conflicting files.
|
||||
|
||||
### 1. Download the `nvidia-driver-libs` RPM
|
||||
|
||||
```bash
|
||||
sudo dnf download --arch x86_64 nvidia-driver-libs
|
||||
```
|
||||
|
||||
You should see a file similar to:
|
||||
|
||||
```
|
||||
nvidia-driver-libs-560.35.05-1.fc39.x86_64.rpm
|
||||
```
|
||||
|
||||
### 2. Attempt to Install the RPM
|
||||
|
||||
```bash
|
||||
sudo dnf install nvidia-driver-libs-560.35.05-1.fc39.x86_64.rpm
|
||||
```
|
||||
|
||||
**Expected Error:**
|
||||
|
||||
Installation may fail with errors pointing to conflicts with `egl-gbm` and `egl-wayland`.
|
||||
|
||||
**Note: It is important to carefully read the error messages to identify the exact paths that need to be excluded.**
|
||||
|
||||
### 3. Download Dependencies
|
||||
|
||||
```bash
|
||||
sudo dnf download --arch x86_64 egl-gbm egl-wayland
|
||||
```
|
||||
|
||||
### 4. Install `egl-gbm` with Excluded Paths
|
||||
|
||||
Exclude conflicting files during installation:
|
||||
|
||||
```bash
|
||||
sudo rpm --install --verbose --hash \
|
||||
--excludepath=/usr/lib64/libnvidia-egl-gbm.so.1.1.2 \
|
||||
--excludepath=/usr/share/egl/egl_external_platform.d/15_nvidia_gbm.json \
|
||||
egl-gbm-1.1.2^20240919gitb24587d-3.fc39.x86_64.rpm
|
||||
```
|
||||
|
||||
**Explanation:**
|
||||
|
||||
- The `--excludepath` option skips installing files that conflict with existing files.
|
||||
- Adjust the paths based on the error messages you receive.
|
||||
|
||||
### 5. Install `egl-wayland` with Excluded Paths
|
||||
|
||||
```bash
|
||||
sudo rpm --install --verbose --hash \
|
||||
--excludepath=/usr/share/egl/egl_external_platform.d/10_nvidia_wayland.json \
|
||||
egl-wayland-1.1.17^20241118giteeb29e1-5.fc39.x86_64.rpm
|
||||
```
|
||||
|
||||
### 6. Install `nvidia-driver-libs` with Excluded Paths
|
||||
|
||||
```bash
|
||||
sudo rpm --install --verbose --hash \
|
||||
--excludepath=/usr/share/glvnd/egl_vendor.d/10_nvidia.json \
|
||||
--excludepath=/usr/share/nvidia/nvoptix.bin \
|
||||
nvidia-driver-libs-560.35.05-1.fc39.x86_64.rpm
|
||||
```
|
||||
|
||||
**Note:**
|
||||
|
||||
- Replace the paths with the ones causing conflicts in your installation if they differ.
|
||||
- The `--verbose` and `--hash` options provide detailed output during installation.
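After these manual installs, you can verify that the three packages are now registered in the RPM database, for example:

```bash
rpm -q egl-gbm egl-wayland nvidia-driver-libs
```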
|
||||
|
||||
## Finalizing the Installation of `nvidia-driver-libs`
|
||||
|
||||
After manually installing the dependencies, run:
|
||||
|
||||
```bash
|
||||
sudo dnf install nvidia-driver-libs
|
||||
```
|
||||
|
||||
You should receive a message indicating the package is already installed:
|
||||
|
||||
```
|
||||
Package nvidia-driver-libs-3:560.35.05-1.fc39.x86_64 is already installed.
|
||||
Dependencies resolved.
|
||||
Nothing to do.
|
||||
Complete!
|
||||
```
|
||||
|
||||
## Installing the CUDA Meta-Package
|
||||
|
||||
Now that the driver libraries are installed, proceed to install CUDA:
|
||||
|
||||
```bash
|
||||
sudo dnf install cuda
|
||||
```
|
||||
|
||||
This installs the CUDA toolkit and associated packages.
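To see which CUDA packages ended up installed, one option is to query the RPM database:

```bash
rpm -qa 'cuda*' | sort | head
```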
|
||||
|
||||
## Configuring the Environment
|
||||
|
||||
To use CUDA, add its binary directory to your system's `PATH`.
|
||||
|
||||
1. **Create a Profile Script:**
|
||||
|
||||
```bash
|
||||
sudo sh -c 'echo "export PATH=\$PATH:/usr/local/cuda/bin" >> /etc/profile.d/cuda.sh'
|
||||
```
|
||||
|
||||
**Explanation:**
|
||||
|
||||
- We add the script to `/etc/profile.d/` because the `/etc/` directory is unique to this particular container and is not shared with other containers or the host system.
|
||||
- The backslash `\` before `$PATH` prevents the shell from expanding the variable while writing the script, so the literal text `$PATH` ends up in `cuda.sh` and is expanded at login time instead.
|
||||
|
||||
2. **Make the Script Executable:**
|
||||
|
||||
```bash
|
||||
sudo chmod +x /etc/profile.d/cuda.sh
|
||||
```
|
||||
|
||||
3. **Source the Script to Update Your Environment:**
|
||||
|
||||
```bash
|
||||
source /etc/profile.d/cuda.sh
|
||||
```
|
||||
|
||||
**Note:** This command updates your current shell session with the new `PATH`. The `/etc/profile.d/cuda.sh` script ensures that the CUDA binaries are available in your `PATH` for all future sessions.
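For example, you can confirm the script's contents and that the CUDA directory is now on your `PATH`:

```bash
cat /etc/profile.d/cuda.sh          # should contain: export PATH=$PATH:/usr/local/cuda/bin
echo "$PATH" | tr ':' '\n' | grep cuda
```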
|
||||
|
||||
## Verifying the Installation
|
||||
|
||||
To confirm that CUDA is correctly installed and configured, check the version of the NVIDIA CUDA Compiler (`nvcc`):
|
||||
|
||||
```bash
|
||||
nvcc --version
|
||||
```
|
||||
|
||||
You should see output similar to:
|
||||
|
||||
```
|
||||
nvcc: NVIDIA (R) Cuda compiler driver
|
||||
Copyright (c) 2005-2024 NVIDIA Corporation
|
||||
Built on Tue_Oct_29_23:50:19_PDT_2024
|
||||
Cuda compilation tools, release 12.6, V12.6.85
|
||||
Build cuda_12.6.r12.6/compiler.35059454_0
|
||||
```
|
||||
|
||||
This output confirms that the CUDA compiler is accessible and indicates the installed version.
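Optionally, you can go one step further and compile a trivial CUDA program to confirm that the compiler and the driver work together (this assumes your host's NVIDIA driver is accessible from the toolbox):

```bash
# Minimal smoke test: build and run a one-kernel CUDA program.
cat > /tmp/hello.cu <<'EOF'
#include <cstdio>

__global__ void hello() { printf("Hello from the GPU\n"); }

int main() {
    hello<<<1, 1>>>();
    cudaDeviceSynchronize();
    return 0;
}
EOF
nvcc /tmp/hello.cu -o /tmp/hello && /tmp/hello
```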
|
||||
|
||||
## Conclusion
|
||||
|
||||
You have successfully set up CUDA on Fedora within a toolbox environment using the Fedora 39 CUDA repository. By manually resolving package conflicts and configuring the environment, you can develop CUDA applications without affecting your host system.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
- **Installation Failures:**
|
||||
- If you encounter errors during installation, carefully read the error messages. They often indicate conflicting files or missing dependencies.
|
||||
- Use the `--excludepath` option with `rpm` to exclude conflicting files during manual installations.
|
||||
|
||||
- **Driver Conflicts:**
|
||||
- Since the host system may already have NVIDIA drivers installed, conflicts can arise. Using the toolbox environment helps isolate these issues.
|
||||
|
||||
- **Environment Variables Not Set:**
|
||||
- If `nvcc` is not found after installation, ensure that `/usr/local/cuda/bin` is in your `PATH`.
|
||||
- Run `echo $PATH` to check if the path is included.
|
||||
- Re-source the profile script or open a new terminal session.
|
||||
|
||||
## Additional Notes
|
||||
|
||||
- **Updating CUDA in the Future:**
|
||||
- Keep an eye on the official NVIDIA repositories for updates to your Fedora version.
|
||||
- When an updated repository becomes available, adjust your `dnf` configuration accordingly.
|
||||
|
||||
- **Building `llama.cpp`:**
|
||||
- With CUDA installed, you can follow these [build instructions for `llama.cpp`](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md) to compile it with CUDA support.
|
||||
- Ensure that any CUDA-specific build flags or paths are correctly set in your build configuration; a minimal build sketch follows this list.
|
||||
|
||||
- **Using the Toolbox Environment:**
|
||||
- The toolbox environment is isolated from your host system, which helps prevent conflicts.
|
||||
- Remember that system files and configurations inside the toolbox are separate from the host. By default, the user's home directory is shared between the host and the toolbox.
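With CUDA set up inside the toolbox, a minimal CUDA-enabled `llama.cpp` build might look like the sketch below (assuming `GGML_CUDA` is still the CMake option; see the linked build instructions for the current flags):

```bash
# Inside the toolbox, from the llama.cpp source tree:
cmake -B build -DGGML_CUDA=ON
cmake --build build --config Release -j
```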
|
||||
|
||||
---
|
||||
|
||||
**Disclaimer:** Manually installing and modifying system packages can make the container unstable. The above steps are provided as a guideline and may need adjustments based on your specific system configuration. Always back up important data before making significant system changes, especially as your home folder is writable and shared with the toolbox.
|
||||
|
||||
**Acknowledgments:** Special thanks to the Fedora community and NVIDIA documentation for providing resources that assisted in creating this guide.
|
||||
|
||||
## References
|
||||
|
||||
- [Fedora Toolbox Documentation](https://docs.fedoraproject.org/en-US/fedora-silverblue/toolbox/)
|
||||
- [NVIDIA CUDA Installation Guide](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html)
|
||||
- [Podman Documentation](https://podman.io/get-started)
|
||||
|
||||
---
|
|
@ -28,7 +28,7 @@ The required steps to implement for an HF model are:
|
|||
```python
|
||||
@Model.register("MyModelForCausalLM")
|
||||
class MyModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.GROK
|
||||
model_arch = gguf.MODEL_ARCH.MYMODEL
|
||||
```
|
||||
|
||||
2. Define the layout of the GGUF tensors in [constants.py](/gguf-py/gguf/constants.py)
|
||||
|
@ -79,14 +79,14 @@ Depending on the model configuration, tokenizer, code and tensors layout, you wi
|
|||
- `Model#set_vocab`
|
||||
- `Model#write_tensors`
|
||||
|
||||
NOTE: Tensor names must end with `.weight` suffix, that is the convention and several tools like `quantize` expect this to proceed the weights.
|
||||
NOTE: Tensor names must end with the `.weight` or `.bias` suffix; that is the convention, and several tools like `quantize` expect this suffix to precede the weights.
|
||||
|
||||
### 2. Define the model architecture in `llama.cpp`
|
||||
|
||||
The model params and tensors layout must be defined in `llama.cpp`:
|
||||
1. Define a new `llm_arch`
|
||||
2. Define the tensors layout in `LLM_TENSOR_NAMES`
|
||||
3. Add any non standard metadata in `llm_load_hparams`
|
||||
3. Add any non-standard metadata in `llm_load_hparams`
|
||||
4. Create the tensors for inference in `llm_load_tensors`
|
||||
5. If the model has a RoPE operation, add the rope type in `llama_rope_type`
|
||||
|
||||
|
@ -96,9 +96,9 @@ NOTE: The dimensions in `ggml` are typically in the reverse order of the `pytorc
|
|||
|
||||
This is the most fun part: you have to provide the inference graph implementation of the new model architecture in `llama_build_graph`.
|
||||
|
||||
Have a look at existing implementation like `build_llama`, `build_dbrx` or `build_bert`.
|
||||
Have a look at existing implementations like `build_llama`, `build_dbrx` or `build_bert`.
|
||||
|
||||
When implementing a new graph, please note that the underlying `ggml` backends might not support them all, support for missing backend operations can be added in another PR.
|
||||
Some `ggml` backends do not support all operations. Backend implementations can be added in a separate PR.
|
||||
|
||||
Note: to debug the inference graph, you can use [llama-eval-callback](/examples/eval-callback/).
|
||||
|
||||
|
|
Binary file not shown.
|
@ -62,53 +62,57 @@
|
|||
<!-- action buttons (top right) -->
|
||||
<div class="flex items-center">
|
||||
<div v-if="messages.length > 0" class="dropdown dropdown-end">
|
||||
<!-- "more" button -->
|
||||
<!-- "..." button -->
|
||||
<button tabindex="0" role="button" class="btn m-1" :disabled="isGenerating">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" class="bi bi-three-dots-vertical" viewBox="0 0 16 16">
|
||||
<path d="M9.5 13a1.5 1.5 0 1 1-3 0 1.5 1.5 0 0 1 3 0m0-5a1.5 1.5 0 1 1-3 0 1.5 1.5 0 0 1 3 0m0-5a1.5 1.5 0 1 1-3 0 1.5 1.5 0 0 1 3 0"/>
|
||||
</svg>
|
||||
</button>
|
||||
<!-- "more" dropdown menu -->
|
||||
<!-- "delete" dropdown menu -->
|
||||
<ul tabindex="0" class="dropdown-content menu bg-base-100 rounded-box z-[1] w-52 p-2 shadow">
|
||||
<li @click="downloadConv(viewingConvId)"><a>Download</a></li>
|
||||
<li class="text-error" @click="deleteConv(viewingConvId)"><a>Delete</a></li>
|
||||
</ul>
|
||||
</div>
|
||||
<button class="btn" @click="showConfigDialog = true" :disabled="isGenerating">
|
||||
<!-- settings button -->
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" class="bi bi-gear" viewBox="0 0 16 16">
|
||||
<path d="M8 4.754a3.246 3.246 0 1 0 0 6.492 3.246 3.246 0 0 0 0-6.492M5.754 8a2.246 2.246 0 1 1 4.492 0 2.246 2.246 0 0 1-4.492 0"/>
|
||||
<path d="M9.796 1.343c-.527-1.79-3.065-1.79-3.592 0l-.094.319a.873.873 0 0 1-1.255.52l-.292-.16c-1.64-.892-3.433.902-2.54 2.541l.159.292a.873.873 0 0 1-.52 1.255l-.319.094c-1.79.527-1.79 3.065 0 3.592l.319.094a.873.873 0 0 1 .52 1.255l-.16.292c-.892 1.64.901 3.434 2.541 2.54l.292-.159a.873.873 0 0 1 1.255.52l.094.319c.527 1.79 3.065 1.79 3.592 0l.094-.319a.873.873 0 0 1 1.255-.52l.292.16c1.64.893 3.434-.902 2.54-2.541l-.159-.292a.873.873 0 0 1 .52-1.255l.319-.094c1.79-.527 1.79-3.065 0-3.592l-.319-.094a.873.873 0 0 1-.52-1.255l.16-.292c.893-1.64-.902-3.433-2.541-2.54l-.292.159a.873.873 0 0 1-1.255-.52zm-2.633.283c.246-.835 1.428-.835 1.674 0l.094.319a1.873 1.873 0 0 0 2.693 1.115l.291-.16c.764-.415 1.6.42 1.184 1.185l-.159.292a1.873 1.873 0 0 0 1.116 2.692l.318.094c.835.246.835 1.428 0 1.674l-.319.094a1.873 1.873 0 0 0-1.115 2.693l.16.291c.415.764-.42 1.6-1.185 1.184l-.291-.159a1.873 1.873 0 0 0-2.693 1.116l-.094.318c-.246.835-1.428.835-1.674 0l-.094-.319a1.873 1.873 0 0 0-2.692-1.115l-.292.16c-.764.415-1.6-.42-1.184-1.185l.159-.291A1.873 1.873 0 0 0 1.945 8.93l-.319-.094c-.835-.246-.835-1.428 0-1.674l.319-.094A1.873 1.873 0 0 0 3.06 4.377l-.16-.292c-.415-.764.42-1.6 1.185-1.184l.292.159a1.873 1.873 0 0 0 2.692-1.115z"/>
|
||||
</svg>
|
||||
</button>
|
||||
<div class="tooltip tooltip-bottom" data-tip="Settings">
|
||||
<button class="btn" @click="showConfigDialog = true" :disabled="isGenerating">
|
||||
<!-- settings button -->
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" class="bi bi-gear" viewBox="0 0 16 16">
|
||||
<path d="M8 4.754a3.246 3.246 0 1 0 0 6.492 3.246 3.246 0 0 0 0-6.492M5.754 8a2.246 2.246 0 1 1 4.492 0 2.246 2.246 0 0 1-4.492 0"/>
|
||||
<path d="M9.796 1.343c-.527-1.79-3.065-1.79-3.592 0l-.094.319a.873.873 0 0 1-1.255.52l-.292-.16c-1.64-.892-3.433.902-2.54 2.541l.159.292a.873.873 0 0 1-.52 1.255l-.319.094c-1.79.527-1.79 3.065 0 3.592l.319.094a.873.873 0 0 1 .52 1.255l-.16.292c-.892 1.64.901 3.434 2.541 2.54l.292-.159a.873.873 0 0 1 1.255.52l.094.319c.527 1.79 3.065 1.79 3.592 0l.094-.319a.873.873 0 0 1 1.255-.52l.292.16c1.64.893 3.434-.902 2.54-2.541l-.159-.292a.873.873 0 0 1 .52-1.255l.319-.094c1.79-.527 1.79-3.065 0-3.592l-.319-.094a.873.873 0 0 1-.52-1.255l.16-.292c.893-1.64-.902-3.433-2.541-2.54l-.292.159a.873.873 0 0 1-1.255-.52zm-2.633.283c.246-.835 1.428-.835 1.674 0l.094.319a1.873 1.873 0 0 0 2.693 1.115l.291-.16c.764-.415 1.6.42 1.184 1.185l-.159.292a1.873 1.873 0 0 0 1.116 2.692l.318.094c.835.246.835 1.428 0 1.674l-.319.094a1.873 1.873 0 0 0-1.115 2.693l.16.291c.415.764-.42 1.6-1.185 1.184l-.291-.159a1.873 1.873 0 0 0-2.693 1.116l-.094.318c-.246.835-1.428.835-1.674 0l-.094-.319a1.873 1.873 0 0 0-2.692-1.115l-.292.16c-.764.415-1.6-.42-1.184-1.185l.159-.291A1.873 1.873 0 0 0 1.945 8.93l-.319-.094c-.835-.246-.835-1.428 0-1.674l.319-.094A1.873 1.873 0 0 0 3.06 4.377l-.16-.292c-.415-.764.42-1.6 1.185-1.184l.292.159a1.873 1.873 0 0 0 2.692-1.115z"/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- theme controller is copied from https://daisyui.com/components/theme-controller/ -->
|
||||
<div class="dropdown dropdown-end dropdown-bottom">
|
||||
<div tabindex="0" role="button" class="btn m-1">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" class="bi bi-palette2" viewBox="0 0 16 16">
|
||||
<path d="M0 .5A.5.5 0 0 1 .5 0h5a.5.5 0 0 1 .5.5v5.277l4.147-4.131a.5.5 0 0 1 .707 0l3.535 3.536a.5.5 0 0 1 0 .708L10.261 10H15.5a.5.5 0 0 1 .5.5v5a.5.5 0 0 1-.5.5H3a3 3 0 0 1-2.121-.879A3 3 0 0 1 0 13.044m6-.21 7.328-7.3-2.829-2.828L6 7.188zM4.5 13a1.5 1.5 0 1 0-3 0 1.5 1.5 0 0 0 3 0M15 15v-4H9.258l-4.015 4zM0 .5v12.495zm0 12.495V13z"/>
|
||||
</svg>
|
||||
<div class="tooltip tooltip-bottom" data-tip="Themes">
|
||||
<div class="dropdown dropdown-end dropdown-bottom">
|
||||
<div tabindex="0" role="button" class="btn m-1">
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" class="bi bi-palette2" viewBox="0 0 16 16">
|
||||
<path d="M0 .5A.5.5 0 0 1 .5 0h5a.5.5 0 0 1 .5.5v5.277l4.147-4.131a.5.5 0 0 1 .707 0l3.535 3.536a.5.5 0 0 1 0 .708L10.261 10H15.5a.5.5 0 0 1 .5.5v5a.5.5 0 0 1-.5.5H3a3 3 0 0 1-2.121-.879A3 3 0 0 1 0 13.044m6-.21 7.328-7.3-2.829-2.828L6 7.188zM4.5 13a1.5 1.5 0 1 0-3 0 1.5 1.5 0 0 0 3 0M15 15v-4H9.258l-4.015 4zM0 .5v12.495zm0 12.495V13z"/>
|
||||
</svg>
|
||||
</div>
|
||||
<ul tabindex="0" class="dropdown-content bg-base-300 rounded-box z-[1] w-52 p-2 shadow-2xl h-80 overflow-y-auto">
|
||||
<li>
|
||||
<button
|
||||
class="btn btn-sm btn-block btn-ghost justify-start"
|
||||
:class="{ 'btn-active': selectedTheme === 'auto' }"
|
||||
@click="setSelectedTheme('auto')">
|
||||
auto
|
||||
</button>
|
||||
</li>
|
||||
<li v-for="theme in themes">
|
||||
<input
|
||||
type="radio"
|
||||
name="theme-dropdown"
|
||||
class="theme-controller btn btn-sm btn-block btn-ghost justify-start"
|
||||
:aria-label="theme"
|
||||
:value="theme"
|
||||
:checked="selectedTheme === theme"
|
||||
@click="setSelectedTheme(theme)" />
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
<ul tabindex="0" class="dropdown-content bg-base-300 rounded-box z-[1] w-52 p-2 shadow-2xl h-80 overflow-y-auto">
|
||||
<li>
|
||||
<button
|
||||
class="btn btn-sm btn-block btn-ghost justify-start"
|
||||
:class="{ 'btn-active': selectedTheme === 'auto' }"
|
||||
@click="setSelectedTheme('auto')">
|
||||
auto
|
||||
</button>
|
||||
</li>
|
||||
<li v-for="theme in themes">
|
||||
<input
|
||||
type="radio"
|
||||
name="theme-dropdown"
|
||||
class="theme-controller btn btn-sm btn-block btn-ghost justify-start"
|
||||
:aria-label="theme"
|
||||
:value="theme"
|
||||
:checked="selectedTheme === theme"
|
||||
@click="setSelectedTheme(theme)" />
|
||||
</li>
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
|
80
examples/tts/README.md
Normal file
|
@ -0,0 +1,80 @@
|
|||
# llama.cpp/example/tts
|
||||
This example demonstrates the text-to-speech (TTS) feature. It uses a
|
||||
[model](https://www.outeai.com/blog/outetts-0.2-500m) from
|
||||
[outeai](https://www.outeai.com/).
|
||||
|
||||
## Quickstart
|
||||
If you have built llama.cpp with `-DLLAMA_CURL=ON` you can simply run the
|
||||
following command and the required models will be downloaded automatically:
|
||||
```console
|
||||
$ build/bin/llama-tts --tts-oute-default -p "Hello world" && aplay output.wav
|
||||
```
|
||||
For details about the models and how to convert them to the required format
|
||||
see the following sections.
|
||||
|
||||
### Model conversion
|
||||
Check out or download the model repository that contains the LLM model:
|
||||
```console
|
||||
$ pushd models
|
||||
$ git clone --branch main --single-branch --depth 1 https://huggingface.co/OuteAI/OuteTTS-0.2-500M
|
||||
$ cd OuteTTS-0.2-500M && git lfs install && git lfs pull
|
||||
$ popd
|
||||
```
|
||||
Convert the model to .gguf format:
|
||||
```console
|
||||
(venv) python convert_hf_to_gguf.py models/OuteTTS-0.2-500M \
|
||||
--outfile models/outetts-0.2-0.5B-f16.gguf --outtype f16
|
||||
```
|
||||
The generated model will be `models/outetts-0.2-0.5B-f16.gguf`.
|
||||
|
||||
We can optionally quantize this to Q8_0 using the following command:
|
||||
```console
|
||||
$ build/bin/llama-quantize models/outetts-0.2-0.5B-f16.gguf \
|
||||
models/outetts-0.2-0.5B-q8_0.gguf q8_0
|
||||
```
|
||||
The quantized model will be `models/outetts-0.2-0.5B-q8_0.gguf`.
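As a rough sanity check, the quantized file should be noticeably smaller than the f16 one:

```console
$ ls -lh models/outetts-0.2-0.5B-*.gguf
```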
|
||||
|
||||
Next we do something similar for the audio decoder. First download or check out
|
||||
the model for the voice decoder:
|
||||
```console
|
||||
$ pushd models
|
||||
$ git clone --branch main --single-branch --depth 1 https://huggingface.co/novateur/WavTokenizer-large-speech-75token
|
||||
$ cd WavTokenizer-large-speech-75token && git lfs install && git lfs pull
|
||||
$ popd
|
||||
```
|
||||
This model file is a PyTorch checkpoint (.ckpt) and we first need to convert it to
|
||||
huggingface format:
|
||||
```console
|
||||
(venv) python examples/tts/convert_pt_to_hf.py \
|
||||
models/WavTokenizer-large-speech-75token/wavtokenizer_large_speech_320_24k.ckpt
|
||||
...
|
||||
Model has been successfully converted and saved to models/WavTokenizer-large-speech-75token/model.safetensors
|
||||
Metadata has been saved to models/WavTokenizer-large-speech-75token/index.json
|
||||
Config has been saved to models/WavTokenizer-large-speech-75tokenconfig.json
|
||||
```
|
||||
Then we can convert the huggingface format to gguf:
|
||||
```console
|
||||
(venv) python convert_hf_to_gguf.py models/WavTokenizer-large-speech-75token \
|
||||
--outfile models/wavtokenizer-large-75-f16.gguf --outtype f16
|
||||
...
|
||||
INFO:hf-to-gguf:Model successfully exported to models/wavtokenizer-large-75-f16.gguf
|
||||
```
|
||||
|
||||
### Running the example
|
||||
|
||||
With both of the models generated, the LLM model and the voice decoder model,
|
||||
we can run the example:
|
||||
```console
|
||||
$ build/bin/llama-tts -m ./models/outetts-0.2-0.5B-q8_0.gguf \
|
||||
-mv ./models/wavtokenizer-large-75-f16.gguf \
|
||||
-p "Hello world"
|
||||
...
|
||||
main: audio written to file 'output.wav'
|
||||
```
|
||||
The output.wav file will contain the audio of the prompt. This can be heard
|
||||
by playing the file with a media player. On Linux the following command will
|
||||
play the audio:
|
||||
```console
|
||||
$ aplay output.wav
|
||||
```
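If `aplay` (part of `alsa-utils`) is not available, any other audio player works; for example, assuming `ffmpeg` is installed:

```console
$ ffplay -autoexit -nodisp output.wav
```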
|
||||
|
|
@ -501,6 +501,7 @@ extern "C" {
|
|||
GGML_OP_GET_REL_POS,
|
||||
GGML_OP_ADD_REL_POS,
|
||||
GGML_OP_RWKV_WKV6,
|
||||
GGML_OP_GATED_LINEAR_ATTN,
|
||||
|
||||
GGML_OP_UNARY,
|
||||
|
||||
|
@ -1859,6 +1860,15 @@ extern "C" {
|
|||
struct ggml_tensor * td,
|
||||
struct ggml_tensor * state);
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_gated_linear_attn(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * k,
|
||||
struct ggml_tensor * v,
|
||||
struct ggml_tensor * q,
|
||||
struct ggml_tensor * g,
|
||||
struct ggml_tensor * state,
|
||||
float scale);
|
||||
|
||||
// custom operators
|
||||
|
||||
typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
|
||||
|
|
|
@ -11803,9 +11803,9 @@ static void ggml_compute_forward_add_rel_pos(
|
|||
static void ggml_compute_forward_rwkv_wkv6_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
const int64_t T = dst->src[1]->ne[3];
|
||||
const int64_t T = dst->src[1]->ne[2];
|
||||
const int64_t C = dst->ne[0];
|
||||
const int64_t HEADS = dst->src[1]->ne[2];
|
||||
const int64_t HEADS = dst->src[1]->ne[1];
|
||||
const int64_t n_seqs = dst->src[5]->ne[1];
|
||||
const int64_t head_size = C / HEADS;
|
||||
|
||||
|
@ -12000,6 +12000,197 @@ static void ggml_compute_forward_rwkv_wkv6(
|
|||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_gla
|
||||
|
||||
static void ggml_compute_forward_gla_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
const int64_t T = dst->src[1]->ne[2];
|
||||
const int64_t C = dst->ne[0];
|
||||
const int64_t HEADS = dst->src[1]->ne[1];
|
||||
const int64_t n_seqs = dst->src[4]->ne[1];
|
||||
const int64_t head_size = C / HEADS;
|
||||
const float scale = ggml_get_op_params_f32(dst, 0);
|
||||
|
||||
float * dst_data = (float *) dst->data;
|
||||
float * state = ((float *) dst->data) + C * T;
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
if (ith >= HEADS) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int h_start = (HEADS * ith) / nth;
|
||||
const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
|
||||
(HEADS * (ith + 1)) / nth : HEADS;
|
||||
|
||||
float * k = (float *) dst->src[0]->data;
|
||||
float * v = (float *) dst->src[1]->data;
|
||||
float * q = (float *) dst->src[2]->data;
|
||||
float * g = (float *) dst->src[3]->data;
|
||||
|
||||
size_t t_stride = HEADS * head_size; // same as C
|
||||
|
||||
size_t h_stride = C / HEADS;
|
||||
GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
|
||||
size_t h_stride_2d = head_size * head_size;
|
||||
|
||||
if (ith == 0) {
|
||||
memset(dst_data, 0, T * C * sizeof(float));
|
||||
}
|
||||
ggml_barrier(params->threadpool);
|
||||
|
||||
|
||||
#if defined(__AVX__) && !defined(__AVX512F__)
|
||||
#define GGML_F32X GGML_F32x8
|
||||
#define GGML_F32X_SET1 GGML_F32x8_SET1
|
||||
#define GGML_F32X_LOAD GGML_F32x8_LOAD
|
||||
#define GGML_F32X_STORE GGML_F32x8_STORE
|
||||
#define GGML_F32X_MUL GGML_F32x8_MUL
|
||||
#define GGML_F32X_FMA GGML_F32x8_FMA
|
||||
#define GLA_VECTOR_SIZE 8
|
||||
#elif defined(__AVX512F__)
|
||||
#define GGML_F32X GGML_F32x16
|
||||
#define GGML_F32X_SET1 GGML_F32x16_SET1
|
||||
#define GGML_F32X_LOAD GGML_F32x16_LOAD
|
||||
#define GGML_F32X_STORE GGML_F32x16_STORE
|
||||
#define GGML_F32X_MUL GGML_F32x16_MUL
|
||||
#define GGML_F32X_FMA GGML_F32x16_FMA
|
||||
#define GLA_VECTOR_SIZE 16
|
||||
#elif defined(__ARM_NEON) && defined(__aarch64__)
|
||||
#define GGML_F32X GGML_F32x4
|
||||
#define GGML_F32X_SET1 GGML_F32x4_SET1
|
||||
#define GGML_F32X_LOAD GGML_F32x4_LOAD
|
||||
#define GGML_F32X_STORE GGML_F32x4_STORE
|
||||
#define GGML_F32X_MUL GGML_F32x4_MUL
|
||||
#define GGML_F32X_FMA GGML_F32x4_FMA
|
||||
#define GLA_VECTOR_SIZE 4
|
||||
#endif
|
||||
|
||||
#ifdef GLA_VECTOR_SIZE
|
||||
const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
|
||||
|
||||
for (int64_t t = 0; t < T; t++) {
|
||||
size_t t_offset = t * t_stride;
|
||||
size_t state_offset = head_size * C * (t / (T / n_seqs));
|
||||
float * state_cur = state + state_offset;
|
||||
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
|
||||
|
||||
for (int64_t h = h_start; h < h_end; h++) {
|
||||
size_t h_offset = h * h_stride;
|
||||
size_t t_h_offset = t_offset + h_offset;
|
||||
size_t h_2d_offset = h * h_stride_2d;
|
||||
|
||||
for (int64_t i = 0; i < head_size; i++) {
|
||||
size_t t_h_i_offset = t_h_offset + i;
|
||||
size_t h_2d_i_offset = h_2d_offset + i * h_stride;
|
||||
|
||||
float k_val = k[t_h_i_offset];
|
||||
float q_val = q[t_h_i_offset] * scale;
|
||||
float g_val = g[t_h_i_offset];
|
||||
|
||||
// Broadcast scalar values to vectors
|
||||
GGML_F32X k_vec = GGML_F32X_SET1(k_val);
|
||||
GGML_F32X q_vec = GGML_F32X_SET1(q_val);
|
||||
GGML_F32X g_vec = GGML_F32X_SET1(g_val);
|
||||
|
||||
for (int64_t j = 0; j < vec_count; j++) {
|
||||
size_t base_j = j * GLA_VECTOR_SIZE;
|
||||
size_t t_h_j_offset = t_h_offset + base_j;
|
||||
size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
|
||||
|
||||
// Load x elements at once
|
||||
GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
|
||||
GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
|
||||
GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
|
||||
|
||||
// Compute kv = v * k
|
||||
GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
|
||||
|
||||
// Compute temp = prev_state * g + kv
|
||||
GGML_F32X temp_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, g_vec);
|
||||
|
||||
// Update dst: dst += temp * q
|
||||
dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, q_vec);
|
||||
GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
|
||||
|
||||
// Update state
|
||||
GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], temp_vec);
|
||||
}
|
||||
|
||||
// Handle remaining elements (head_size is expected to be a multiple of the vector size, so this normally does not run)
|
||||
for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
|
||||
size_t t_h_j_offset = t_h_offset + j;
|
||||
size_t h_2d_i_j_offset = h_2d_i_offset + j;
|
||||
float v_val = v[t_h_j_offset];
|
||||
float kv_val = v_val * k_val;
|
||||
float prev_state_val = state_prev[h_2d_i_j_offset];
|
||||
float temp_val = kv_val + prev_state_val * g_val;
|
||||
dst_data[t_h_j_offset] += temp_val * q_val;
|
||||
state_cur[h_2d_i_j_offset] = temp_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
for (int64_t t = 0; t < T; t++) {
|
||||
size_t t_offset = t * t_stride;
|
||||
size_t state_offset = head_size * C * (t / (T / n_seqs));
|
||||
float * state_cur = state + state_offset;
|
||||
float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
|
||||
|
||||
for (int64_t h = h_start; h < h_end; h++) {
|
||||
size_t h_offset = h * h_stride;
|
||||
size_t t_h_offset = t_offset + h_offset;
|
||||
size_t h_2d_offset = h * h_stride_2d;
|
||||
|
||||
for (int64_t i = 0; i < head_size; i++) {
|
||||
size_t t_h_i_offset = t_h_offset + i;
|
||||
size_t h_2d_i_offset = h_2d_offset + i * h_stride;
|
||||
|
||||
float k_val = k[t_h_i_offset];
|
||||
float q_val = q[t_h_i_offset] * scale;
|
||||
float g_val = g[t_h_i_offset];
|
||||
|
||||
for (int64_t j = 0; j < head_size; j++) {
|
||||
size_t t_h_j_offset = t_h_offset + j;
|
||||
size_t h_2d_i_j_offset = h_2d_i_offset + j;
|
||||
|
||||
float v_val = v[t_h_j_offset];
|
||||
float kv_val = v_val * k_val;
|
||||
float prev_state_val = state_prev[h_2d_i_j_offset];
|
||||
float temp_val = prev_state_val * g_val + kv_val;
|
||||
dst_data[t_h_j_offset] += temp_val * q_val;
|
||||
state_cur[h_2d_i_j_offset] = temp_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
static void ggml_compute_forward_gla(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_gla_f32(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_map_unary
|
||||
|
||||
static void ggml_compute_forward_map_unary_f32(
|
||||
|
@ -12749,6 +12940,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|||
{
|
||||
ggml_compute_forward_rwkv_wkv6(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
{
|
||||
ggml_compute_forward_gla(params, tensor);
|
||||
} break;
|
||||
case GGML_OP_MAP_UNARY:
|
||||
{
|
||||
ggml_unary_op_f32_t fun;
|
||||
|
@ -13047,6 +13242,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|||
case GGML_OP_WIN_UNPART:
|
||||
case GGML_OP_GET_REL_POS:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
case GGML_OP_MAP_UNARY:
|
||||
case GGML_OP_MAP_BINARY:
|
||||
case GGML_OP_MAP_CUSTOM1_F32:
|
||||
|
|
|
@ -37,6 +37,7 @@
|
|||
#include "ggml-cuda/unary.cuh"
|
||||
#include "ggml-cuda/upscale.cuh"
|
||||
#include "ggml-cuda/wkv6.cuh"
|
||||
#include "ggml-cuda/gla.cuh"
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
|
@ -2167,6 +2168,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|||
case GGML_OP_RWKV_WKV6:
|
||||
ggml_cuda_op_rwkv_wkv6(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
ggml_cuda_op_gated_linear_attn(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
||||
ggml_cuda_cross_entropy_loss_back(ctx, dst);
|
||||
break;
|
||||
|
@ -3011,6 +3015,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
case GGML_OP_TIMESTEP_EMBEDDING:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_GATED_LINEAR_ATTN:
|
||||
return true;
|
||||
case GGML_OP_FLASH_ATTN_EXT: {
|
||||
#ifndef FLASH_ATTN_AVAILABLE
|
||||
|
|
93
ggml/src/ggml-cuda/gla.cu
Normal file
|
@ -0,0 +1,93 @@
|
|||
#include "common.cuh"
|
||||
#include "gla.cuh"
|
||||
|
||||
template<int HEAD_SIZE>
|
||||
static __global__ void gated_linear_attn_f32(const int B, const int T, const int C, const int H, const float scale,
|
||||
const float * k, const float * v, const float * r, const float * td, const float * s, float * dst) {
|
||||
const int tid = threadIdx.x;
|
||||
const int bid = blockIdx.x;
|
||||
|
||||
const int head_size = HEAD_SIZE;
|
||||
const int batch_i = bid / H;
|
||||
const int head_i = bid % H;
|
||||
const int state_size = C * head_size;
|
||||
const int n_seq_tokens = T / B;
|
||||
|
||||
float state[head_size];
|
||||
__shared__ float _k[head_size], _r[head_size], _td[head_size];
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
|
||||
}
|
||||
|
||||
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
|
||||
__syncthreads();
|
||||
_k[tid] = k[t];
|
||||
_r[tid] = r[t];
|
||||
_td[tid] = td[t];
|
||||
__syncthreads();
|
||||
|
||||
const float _v = v[t];
|
||||
float y = 0;
|
||||
for (int j = 0; j < head_size; j += 4) {
|
||||
const float4 & k = (float4 &)(_k[j]);
|
||||
const float4 & r = (float4 &)(_r[j]);
|
||||
const float4 & td = (float4 &)(_td[j]);
|
||||
float4 & s = (float4 &)(state[j]);
|
||||
float4 kv;
|
||||
|
||||
kv.x = k.x * _v;
|
||||
kv.y = k.y * _v;
|
||||
kv.z = k.z * _v;
|
||||
kv.w = k.w * _v;
|
||||
|
||||
s.x = s.x * td.x + kv.x;
|
||||
s.y = s.y * td.y + kv.y;
|
||||
s.z = s.z * td.z + kv.z;
|
||||
s.w = s.w * td.w + kv.w;
|
||||
|
||||
y += r.x * s.x;
|
||||
y += r.y * s.y;
|
||||
y += r.z * s.z;
|
||||
y += r.w * s.w;
|
||||
}
|
||||
dst[t] = y * scale;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < head_size; i++) {
|
||||
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_gated_linear_attn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const float * k_d = (const float *)dst->src[0]->data;
|
||||
const float * v_d = (const float *)dst->src[1]->data;
|
||||
const float * r_d = (const float *)dst->src[2]->data;
|
||||
const float * td_d = (const float *)dst->src[3]->data;
|
||||
const float * s_d = (const float *)dst->src[4]->data;
|
||||
|
||||
const int64_t B = dst->src[4]->ne[1];
|
||||
const int64_t T = dst->src[0]->ne[2];
|
||||
const int64_t C = dst->ne[0];
|
||||
const int64_t H = dst->src[0]->ne[1];
|
||||
|
||||
float scale;
|
||||
memcpy(&scale, (float*)dst->op_params, sizeof(float));
|
||||
|
||||
float * dst_d = (float *)dst->data;
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(dst->src[4]->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(C % H == 0);
|
||||
GGML_ASSERT(C / H == 64 || C / H == 128);
|
||||
|
||||
|
||||
if (C / H == 64) {
|
||||
gated_linear_attn_f32<64><<<B * H, C / H, 0, stream>>>(B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
|
||||
} else {
|
||||
gated_linear_attn_f32<128><<<B * H, C / H, 0, stream>>>(B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
|
||||
}
|
||||
}
|
3
ggml/src/ggml-cuda/gla.cuh
Normal file
|
@ -0,0 +1,3 @@
|
|||
#include "common.cuh"
|
||||
|
||||
void ggml_cuda_op_gated_linear_attn(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
|
@ -73,9 +73,9 @@ void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
|
|||
const float * s_d = (const float *)dst->src[5]->data;
|
||||
|
||||
const int64_t B = dst->src[5]->ne[1];
|
||||
const int64_t T = dst->src[0]->ne[3];
|
||||
const int64_t T = dst->src[0]->ne[2];
|
||||
const int64_t C = dst->ne[0];
|
||||
const int64_t H = dst->src[0]->ne[2];
|
||||
const int64_t H = dst->src[0]->ne[1];
|
||||
|
||||
float * dst_d = (float *)dst->data;
|
||||
|
||||
|
|
|
@ -51,6 +51,10 @@ void ggml_sycl_host_free(void* ptr) try {
|
|||
std::exit(1);
|
||||
}
|
||||
|
||||
bool gpu_has_xmx(sycl::device &dev) {
|
||||
return dev.has(sycl::aspect::ext_intel_matrix);
|
||||
}
|
||||
|
||||
int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size) {
|
||||
const int64_t max_range = std::numeric_limits<int>::max();
|
||||
int64_t sycl_down_blk_size = block_size;
|
||||
|
|
|
@ -662,6 +662,7 @@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t
|
|||
}
|
||||
}
|
||||
|
||||
bool gpu_has_xmx(sycl::device &dev);
|
||||
|
||||
void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
|
|
|
@ -158,8 +158,9 @@ static void concat_f32_sycl_non_cont(
|
|||
});
|
||||
}
|
||||
|
||||
void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst) {
|
||||
void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
const ggml_tensor *src0 = dst->src[0];
|
||||
const ggml_tensor *src1 = dst->src[1];
|
||||
queue_ptr stream = ctx.stream();
|
||||
|
||||
const int32_t dim = ((int32_t *)dst->op_params)[0];
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst);
|
||||
void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
|
||||
|
||||
#endif // GGML_SYCL_CONCAT_HPP
|
||||
|
|
|
@ -71,8 +71,9 @@ static void conv_transpose_1d_f32_f32_sycl(
|
|||
});
|
||||
}
|
||||
|
||||
void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst) {
|
||||
void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
|
||||
const ggml_tensor *src0 = dst->src[0];
|
||||
const ggml_tensor *src1 = dst->src[1];
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
const float * src1_d = (const float *)src1->data;
|
||||
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
|
||||
#include "common.hpp"
|
||||
|
||||
void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst);
|
||||
void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
|
||||
|
||||
#endif // GGML_SYCL_CONV_HPP
|
||||
|
|
|
@ -882,149 +882,149 @@ inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, const ggml_tensor
}

void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sqrt);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sqrt);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}

void ggml_sycl_sin(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
void ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sin);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sin);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}

void ggml_sycl_cos(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
void ggml_sycl_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_cos);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_cos);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}

void ggml_sycl_acc(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
void ggml_sycl_acc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_acc);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_acc);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}

void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_gelu);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_gelu);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}

void ggml_sycl_silu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
void ggml_sycl_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_silu);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_silu);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}

void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_gelu_quick);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_gelu_quick);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_tanh);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_tanh);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}

void ggml_sycl_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
void ggml_sycl_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_relu);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_relu);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}

void ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
void ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sigmoid);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sigmoid);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}

void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardsigmoid);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_hardsigmoid);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}

void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardswish);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_hardswish);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}

void ggml_sycl_exp(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
void ggml_sycl_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_exp);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_exp);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}

void ggml_sycl_log(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_log);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_log);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}

void ggml_sycl_neg(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
void ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_neg);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_neg);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}

void ggml_sycl_step(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
void ggml_sycl_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_step);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_step);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}

void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_leaky_relu);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_leaky_relu);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}

void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sqr);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sqr);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_upscale);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_upscale);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}

void ggml_sycl_pad(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_pad);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_pad);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}

void ggml_sycl_add(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_add);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_add);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}

void ggml_sycl_sub(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sub);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sub);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}

void ggml_sycl_mul(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_mul);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_mul);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}

void ggml_sycl_div(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
void ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_div);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_div);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
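All of the elementwise hunks above apply the same mechanical refactor: each SYCL wrapper drops its explicit src0/src1 parameters and reads the sources from dst->src[] instead. A minimal sketch of the resulting wrapper shape, using a hypothetical op name rather than one from this commit:

// Sketch only: "my_op" is a placeholder. ggml_sycl_op_flatten still expects the
// source tensors, so the wrapper pulls them out of dst->src[] before dispatching.
static void ggml_sycl_my_op(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    GGML_SYCL_DEBUG("call %s\n", __func__);
    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_my_op);
    GGML_SYCL_DEBUG("call %s done\n", __func__);
}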
@ -25,52 +25,52 @@ static __dpct_inline__ float op_div(const float a, const float b) {
}

void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_sin(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_cos(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_sycl_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_acc(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_sycl_acc(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_silu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_sycl_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_sycl_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_exp(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_sycl_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_log(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_neg(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_step(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_sycl_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_pad(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_add(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_sub(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_mul(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_div(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
void ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

#endif // GGML_SYCL_ELEMENTWISE_HPP
@ -54,18 +54,12 @@ static ggml_sycl_device_info ggml_sycl_init() {
GGML_ASSERT(info.device_count <= GGML_SYCL_MAX_DEVICES);

int64_t total_vram = 0;
#if defined(GGML_SYCL_FORCE_MMQ)
GGML_LOG_INFO("%s: GGML_SYCL_FORCE_MMQ: yes\n", __func__);
#else
GGML_LOG_INFO("%s: GGML_SYCL_FORCE_MMQ: no\n", __func__);
#endif
#if defined(SYCL_USE_XMX)
GGML_LOG_INFO("%s: SYCL_USE_XMX: yes\n", __func__);
#else
GGML_LOG_INFO("%s: SYCL_USE_XMX: no\n", __func__);
#endif
GGML_LOG_INFO("%s: found %d %s devices:\n", __func__, info.device_count, GGML_SYCL_NAME);

/* This is a bit misleading; reserved for later */
// #if defined(SYCL_USE_XMX)
// GGML_LOG_INFO("%s: SYCL_USE_XMX: yes\n", __func__);
// #else
// GGML_LOG_INFO("%s: SYCL_USE_XMX: no\n", __func__);
// #endif
for (int i = 0; i < info.device_count; ++i) {
info.devices[i].vmm = 0;
dpct::device_info prop;

@ -109,11 +103,11 @@ void print_device_detail(int id, sycl::device &device, std::string device_type)
name = std::regex_replace(name, std::regex("\\(TM\\)"), "");

auto global_mem_size = prop.get_global_mem_size()/1000000;

GGML_LOG_INFO("|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|\n", id, device_type.c_str(),
std::string xmx = gpu_has_xmx(device) ? "yes" : "no";
GGML_LOG_INFO("|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|%14s|\n", id, device_type.c_str(),
name.c_str(), version.c_str(), prop.get_max_compute_units(),
prop.get_max_work_group_size(), prop.get_max_sub_group_size(),
global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str(), xmx.c_str());
}

void ggml_backend_sycl_print_sycl_devices() {

@ -124,16 +118,16 @@ void ggml_backend_sycl_print_sycl_devices() {

GGML_LOG_INFO(
"| | | | "
" |Max | |Max |Global | |\n");
" |Max | |Max |Global | | XMX |\n");
GGML_LOG_INFO(
"| | | | "
" |compute|Max work|sub |mem | |\n");
" |compute|Max work|sub |mem | | or |\n");
GGML_LOG_INFO(
"|ID| Device Type| "
"Name|Version|units |group |group|size | Driver version|\n");
"Name|Version|units |group |group|size | Driver version| Tensor Cores |\n");
GGML_LOG_INFO(
"|--|-------------------|---------------------------------------|------"
"-|-------|--------|-----|-------|---------------------|\n");
"-|-------|--------|-----|-------|---------------------|--------------|\n");

for (int id = 0; id < device_count; ++id) {
sycl::device device = dpct::dev_mgr::instance().get_device(id);
@ -164,14 +158,18 @@ static void ggml_check_sycl() try {
static bool initialized = false;

if (!initialized) {
GGML_LOG_INFO("[SYCL] call ggml_check_sycl\n");
GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
GGML_LOG_INFO("%s: GGML_SYCL_DEBUG: %d\n", __func__, g_ggml_sycl_debug);

#if defined(GGML_SYCL_F16)
GGML_LOG_INFO("%s: GGML_SYCL_F16: yes\n", __func__);
GGML_LOG_INFO("GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
#if defined(GGML_SYCL_FORCE_MMQ)
GGML_LOG_INFO("GGML_SYCL_FORCE_MMQ: yes\n");
#else
GGML_LOG_INFO("%s: GGML_SYCL_F16: no\n", __func__);
GGML_LOG_INFO("GGML_SYCL_FORCE_MMQ: no\n");
#endif
#if defined(GGML_SYCL_F16)
GGML_LOG_INFO("GGML_SYCL_F16: yes\n");
#else
GGML_LOG_INFO("GGML_SYCL_F16: no\n");
#endif

/* NOT REMOVE, keep it for next optimize for XMX.

@ -1189,7 +1187,6 @@ std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(q
/// kernels

typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
typedef void (*ggml_sycl_func_t)(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
typedef void (*ggml_sycl_op_mul_mat_t)(
ggml_backend_sycl_context & ctx,
const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
@ -3171,33 +3168,33 @@ catch (sycl::exception const &exc) {
}

static void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
static void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_repeat);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_repeat);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}

static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_get_rows);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_get_rows);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}

static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_norm);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_norm);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}

static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_rms_norm);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_rms_norm);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}

static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_SYCL_DEBUG("call %s\n", __func__);
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_group_norm);
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_group_norm);
GGML_SYCL_DEBUG("call %s done\n", __func__);
}
@ -3572,9 +3569,10 @@ __dpct_inline__ static void k_copy_dst_from_contiguous(
}
}

static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1,
static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
ggml_tensor *dst) try {
const ggml_tensor *src0 = dst->src[0];
const ggml_tensor *src1 = dst->src[1];
GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer) && "mul_mat_id does not support split buffers");

const ggml_tensor *ids = dst->src[2];

@ -3740,12 +3738,12 @@ catch (sycl::exception const &exc) {
std::exit(1);
}

static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_scale);
static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_scale);
}

static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_clamp);
static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_clamp);
}

static void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,

@ -3787,7 +3785,6 @@ static void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor *sr
ggml_type_name(src0->type), ggml_type_name(src1->type));
GGML_ABORT("fatal error");
}

GGML_UNUSED(dst);
}
catch (sycl::exception const &exc) {
@ -3796,59 +3793,52 @@ catch (sycl::exception const &exc) {
std::exit(1);
}

static void ggml_sycl_dup(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
static void ggml_sycl_dup(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
// TODO: why do we pass dst as src1 here?
ggml_sycl_cpy(ctx, src0, dst, nullptr);
GGML_UNUSED(src1);
ggml_sycl_cpy(ctx, dst->src[0], dst, nullptr);
}

static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_diag_mask_inf);
static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_diag_mask_inf);
}

static void ggml_sycl_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_soft_max);
static void ggml_sycl_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_soft_max);
}

static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_rope);
static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_rope);
}

static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_pool2d);
static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_pool2d);
}

static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_im2col);
static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_im2col);
}

static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(src0));
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sum);
static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sum);
}

static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(src0));
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sum_rows);
static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sum_rows);
}

static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(src0));
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_argsort);
static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_argsort);
}

static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(src0));
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_argmax);
static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_argmax);
}

static void ggml_sycl_nop(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_UNUSED(src0);
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(ctx);
}

void ggml_sycl_set_main_device(const int main_device) try {
if (dpct::get_current_device_id() == static_cast<unsigned int> (main_device)) {
@ -3871,191 +3861,189 @@ catch (sycl::exception const &exc) {
std::exit(1);
}

bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * tensor) {
bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) {
if (!g_sycl_loaded) return false;

ggml_sycl_func_t func;
if (dst->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(dst->src[0]->buffer)) {
ggml_sycl_set_peer_access(dst->src[1]->ne[1], ctx.device);
}

switch (tensor->op) {
switch (dst->op) {
case GGML_OP_ARGMAX:
func = ggml_sycl_argmax;
ggml_sycl_argmax(ctx, dst);
break;
case GGML_OP_CONV_TRANSPOSE_1D:
func = ggml_sycl_op_conv_transpose_1d;
ggml_sycl_op_conv_transpose_1d(ctx, dst);
break;
case GGML_OP_REPEAT:
func = ggml_sycl_repeat;
ggml_sycl_repeat(ctx, dst);
break;
case GGML_OP_GET_ROWS:
func = ggml_sycl_get_rows;
ggml_sycl_get_rows(ctx, dst);
break;
case GGML_OP_DUP:
func = ggml_sycl_dup;
ggml_sycl_dup(ctx, dst);
break;
case GGML_OP_ADD:
case GGML_OP_ADD1: // TODO: more efficient implementation
func = ggml_sycl_add;
ggml_sycl_add(ctx, dst);
break;
case GGML_OP_SUB:
func = ggml_sycl_sub;
ggml_sycl_sub(ctx, dst);
break;
case GGML_OP_ACC:
func = ggml_sycl_acc;
ggml_sycl_acc(ctx, dst);
break;
case GGML_OP_MUL:
func = ggml_sycl_mul;
ggml_sycl_mul(ctx, dst);
break;
case GGML_OP_LOG:
func = ggml_sycl_log;
ggml_sycl_log(ctx, dst);
break;
case GGML_OP_DIV:
func = ggml_sycl_div;
ggml_sycl_div(ctx, dst);
break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(tensor)) {
switch (ggml_get_unary_op(dst)) {
case GGML_UNARY_OP_NEG:
func = ggml_sycl_neg;
ggml_sycl_neg(ctx, dst);
break;
case GGML_UNARY_OP_STEP:
func = ggml_sycl_step;
ggml_sycl_step(ctx, dst);
break;
case GGML_UNARY_OP_GELU:
func = ggml_sycl_gelu;
ggml_sycl_gelu(ctx, dst);
break;
case GGML_UNARY_OP_SILU:
func = ggml_sycl_silu;
ggml_sycl_silu(ctx, dst);
break;
case GGML_UNARY_OP_GELU_QUICK:
func = ggml_sycl_gelu_quick;
ggml_sycl_gelu_quick(ctx, dst);
break;
case GGML_UNARY_OP_TANH:
func = ggml_sycl_tanh;
ggml_sycl_tanh(ctx, dst);
break;
case GGML_UNARY_OP_RELU:
func = ggml_sycl_relu;
ggml_sycl_relu(ctx, dst);
break;
case GGML_UNARY_OP_SIGMOID:
func = ggml_sycl_sigmoid;
ggml_sycl_sigmoid(ctx, dst);
break;
case GGML_UNARY_OP_HARDSIGMOID:
func = ggml_sycl_hardsigmoid;
ggml_sycl_hardsigmoid(ctx, dst);
break;
case GGML_UNARY_OP_HARDSWISH:
func = ggml_sycl_hardswish;
ggml_sycl_hardswish(ctx, dst);
break;
case GGML_UNARY_OP_EXP:
func = ggml_sycl_exp;
ggml_sycl_exp(ctx, dst);
break;
default:
return false;
}
break;
case GGML_OP_NORM:
func = ggml_sycl_norm;
ggml_sycl_norm(ctx, dst);
break;
case GGML_OP_GROUP_NORM:
func = ggml_sycl_group_norm;
ggml_sycl_group_norm(ctx, dst);
break;
case GGML_OP_CONCAT:
func = ggml_sycl_op_concat;
ggml_sycl_op_concat(ctx, dst);
break;
case GGML_OP_UPSCALE:
func = ggml_sycl_upscale;
ggml_sycl_upscale(ctx, dst);
break;
case GGML_OP_PAD:
func = ggml_sycl_pad;
ggml_sycl_pad(ctx, dst);
break;
case GGML_OP_LEAKY_RELU:
func = ggml_sycl_leaky_relu;
ggml_sycl_leaky_relu(ctx, dst);
break;
case GGML_OP_RMS_NORM:
func = ggml_sycl_rms_norm;
ggml_sycl_rms_norm(ctx, dst);
break;
case GGML_OP_MUL_MAT:
if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
return false;
}
func = ggml_sycl_mul_mat;
/* ggml_sycl_mul_mat_id is dependent on ggml_sycl_mul_mat */
ggml_sycl_mul_mat(ctx, dst->src[0], dst->src[1], dst);
break;
case GGML_OP_MUL_MAT_ID:
if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
return false;
}
func = ggml_sycl_mul_mat_id;
ggml_sycl_mul_mat_id(ctx, dst);
break;
case GGML_OP_OUT_PROD:
func = ggml_sycl_op_out_prod;
ggml_sycl_op_out_prod(ctx, dst);
break;
case GGML_OP_SCALE:
func = ggml_sycl_scale;
ggml_sycl_scale(ctx, dst);
break;
case GGML_OP_SQR:
func = ggml_sycl_sqr;
ggml_sycl_sqr(ctx, dst);
break;
case GGML_OP_SQRT:
func = ggml_sycl_sqrt;
ggml_sycl_sqrt(ctx, dst);
break;
case GGML_OP_SIN:
func = ggml_sycl_sin;
ggml_sycl_sin(ctx, dst);
break;
case GGML_OP_COS:
func = ggml_sycl_cos;
ggml_sycl_cos(ctx, dst);
break;
case GGML_OP_CLAMP:
func = ggml_sycl_clamp;
ggml_sycl_clamp(ctx, dst);
break;
case GGML_OP_CPY:
func = ggml_sycl_cpy;
ggml_sycl_cpy(ctx, dst->src[0], dst->src[1], dst);
break;
case GGML_OP_CONT:
func = ggml_sycl_dup;
ggml_sycl_dup(ctx, dst);
break;
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
func = ggml_sycl_nop;
GGML_SYCL_DEBUG("%s: Tensor NO-OP\n", __func__);
break;
case GGML_OP_DIAG_MASK_INF:
func = ggml_sycl_diag_mask_inf;
ggml_sycl_diag_mask_inf(ctx, dst);
break;
case GGML_OP_SOFT_MAX:
func = ggml_sycl_soft_max;
ggml_sycl_soft_max(ctx, dst);
break;
case GGML_OP_ROPE:
func = ggml_sycl_rope;
ggml_sycl_rope(ctx, dst);
break;
case GGML_OP_IM2COL:
func = ggml_sycl_im2col;
ggml_sycl_im2col(ctx, dst);
break;
case GGML_OP_POOL_2D:
func = ggml_sycl_pool2d;
ggml_sycl_pool2d(ctx, dst);
break;
case GGML_OP_SUM:
func = ggml_sycl_sum;
ggml_sycl_sum(ctx, dst);
break;
case GGML_OP_SUM_ROWS:
func = ggml_sycl_sum_rows;
ggml_sycl_sum_rows(ctx, dst);
break;
case GGML_OP_ARGSORT:
func = ggml_sycl_argsort;
ggml_sycl_argsort(ctx, dst);
break;
case GGML_OP_TIMESTEP_EMBEDDING:
func = ggml_sycl_op_timestep_embedding;
ggml_sycl_op_timestep_embedding(ctx, dst);
break;
case GGML_OP_RWKV_WKV6:
func = ggml_sycl_op_rwkv_wkv6;
ggml_sycl_op_rwkv_wkv6(ctx, dst);
break;
default:
return false;
}

if (tensor->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(tensor->src[0]->buffer)) {
ggml_sycl_set_peer_access(tensor->src[1]->ne[1], ctx.device);
}

func(ctx, tensor->src[0], tensor->src[1], tensor);
return true;
}
@ -3,9 +3,9 @@
#include "outprod.hpp"

void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
const ggml_tensor* src1, ggml_tensor* dst) {

void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
const ggml_tensor *src0 = dst->src[0];
const ggml_tensor *src1 = dst->src[1];

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);

@ -3,8 +3,7 @@
#include "common.hpp"

void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
const ggml_tensor* src1, ggml_tensor* dst);
void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst);

#endif // GGML_SYCL_OUTPROD_HPP
@ -55,8 +55,9 @@ static void timestep_embedding_f32_sycl(
});
}

void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor * dst) {
void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
const ggml_tensor *src0 = dst->src[0];
const ggml_tensor *src1 = dst->src[1];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
dpct::queue_ptr stream = ctx.stream();

@ -15,7 +15,6 @@
#include "common.hpp"

void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor * dst);
void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

#endif // GGML_SYCL_TSEMBD_HPP
@ -95,8 +95,10 @@ static void rwkv_wkv_f32_kernel(
}
}

void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
const ggml_tensor* src1, ggml_tensor* dst) {
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {

const ggml_tensor *src0 = dst->src[0];
const ggml_tensor *src1 = dst->src[1];

const float* k_d = (const float*)dst->src[0]->data;
const float* v_d = (const float*)dst->src[1]->data;

@ -107,9 +109,9 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, const ggml_tensor* s
float* dst_d = (float*)dst->data;

const int64_t B = dst->src[5]->ne[1];
const int64_t T = dst->src[0]->ne[3];
const int64_t T = dst->src[0]->ne[2];
const int64_t C = dst->ne[0];
const int64_t H = dst->src[0]->ne[2];
const int64_t H = dst->src[0]->ne[1];

GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
GGML_ASSERT(C % H == 0);

@ -3,8 +3,7 @@
#include "common.hpp"

void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor * dst);
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

#endif // GGML_SYCL_WKV6_HPP
@ -2277,6 +2277,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
if (device->subgroup_size_control) {
device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
device_extensions.push_back("VK_EXT_subgroup_size_control");
}

device->subgroup_size_control = device->subgroup_size_control &&

@ -2285,7 +2286,6 @@ static vk_device ggml_vk_get_device(size_t idx) {

if (device->subgroup_size_control) {
device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups;
device_extensions.push_back("VK_EXT_subgroup_size_control");
}

#if defined(VK_KHR_cooperative_matrix)

@ -5633,9 +5633,9 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
}

static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
const size_t seq_length = dst->src[0]->ne[3];
const size_t seq_length = dst->src[0]->ne[2];
const size_t n_embed = dst->ne[0];
const size_t n_heads = dst->src[0]->ne[2];
const size_t n_heads = dst->src[0]->ne[1];
const size_t n_seqs = dst->src[5]->ne[1];

ggml_vk_op_f32_rwkv6(
@ -1,9 +1,6 @@
#version 450

#ifdef FLOAT16
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#endif
#extension GL_EXT_shader_explicit_arithmetic_types : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require

#include "mul_mat_vec_base.comp"

@ -27,8 +24,8 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const

#if K_PER_ITER == 8
#if QUANT_R == 2
const B_TYPE_VEC4 bv02 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4];
const B_TYPE_VEC4 bv13 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4];
const vec4 bv02 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]);
const vec4 bv13 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4]);
const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y);
const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w);
#else

@ -1,5 +1,5 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require

#include "mul_mat_vec_base.comp"

@ -36,19 +36,19 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));

const f16vec2 d = data_a[ib0 + i].d;
vec2 d = vec2(data_a[ib0 + i].d);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);

[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2];
B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8];
B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16];
B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24];
B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32];
B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40];
B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48];
B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56];
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]);
vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);

FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
@ -1,5 +1,5 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require

#include "mul_mat_vec_base.comp"

@ -46,14 +46,14 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, co
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);

[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2];
B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8];
B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16];
B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24];
B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32];
B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40];
B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48];
B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56];
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]);
vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);

FLOAT_TYPE sum = FLOAT_TYPE(0.0);
[[unroll]] for (int l = 0; l < 2; ++l) {

@ -1,6 +1,6 @@
#version 450

#extension GL_EXT_shader_explicit_arithmetic_types : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require

#include "mul_mat_vec_base.comp"

@ -43,7 +43,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {

[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
const f16vec2 d = data_a[ib0 + i].d;
vec2 d = vec2(data_a[ib0 + i].d);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);

@ -96,10 +96,10 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const FLOAT_TYPE q4_15 = qs64_hi4.w;

[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
B_TYPE_VEC4 by10 = data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4];
B_TYPE_VEC4 by132 = data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8];
B_TYPE_VEC4 by20 = data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4];
B_TYPE_VEC4 by232 = data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8];
vec4 by10 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 ]);
vec4 by132 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]);
vec4 by20 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 ]);
vec4 by232 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]);

const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3)));
const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7)));
@ -1,6 +1,6 @@
#version 450

#extension GL_EXT_shader_explicit_arithmetic_types : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require

#include "mul_mat_vec_base.comp"

@ -42,7 +42,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {

[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
f16vec2 d = data_a[ib0 + i].d;
vec2 d = vec2(data_a[ib0 + i].d);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);

@ -107,14 +107,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const FLOAT_TYPE q4_15 = qs64_80_hi4.w;

[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
B_TYPE_VEC2 by10 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2];
B_TYPE_VEC2 by116 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8];
B_TYPE_VEC2 by132 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16];
B_TYPE_VEC2 by148 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24];
B_TYPE_VEC2 by20 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2];
B_TYPE_VEC2 by216 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8];
B_TYPE_VEC2 by232 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16];
B_TYPE_VEC2 by248 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24];
vec2 by10 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 ]);
vec2 by116 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8]);
vec2 by132 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]);
vec2 by148 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]);
vec2 by20 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 ]);
vec2 by216 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8]);
vec2 by232 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]);
vec2 by248 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]);

const FLOAT_TYPE sx =
fma(FLOAT_TYPE(by10.x), q4_0,
@ -1,6 +1,6 @@
#version 450

#extension GL_EXT_shader_explicit_arithmetic_types : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require

#include "mul_mat_vec_base.comp"

@ -57,10 +57,10 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);

[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
B_TYPE_VEC4 by0 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4];
B_TYPE_VEC4 by32 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8];
B_TYPE_VEC4 by64 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16];
B_TYPE_VEC4 by96 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24];
vec4 by0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 ]);
vec4 by32 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8]);
vec4 by64 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]);
vec4 by96 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]);

FLOAT_TYPE sum[4] = {0, 0, 0, 0};
[[unroll]] for (uint l = 0; l < 4; ++l) {

@ -1,6 +1,5 @@
#version 450

#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_control_flow_attributes : enable

layout (push_constant) uniform parameter

@ -2,7 +2,10 @@
#if !defined(GGML_TYPES_COMP)
#define GGML_TYPES_COMP

#extension GL_EXT_shader_explicit_arithmetic_types : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#extension GL_EXT_shader_16bit_storage : require

#if defined(DATA_A_F32)
#define QUANT_K 1
@ -968,6 +968,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GET_REL_POS",
"ADD_REL_POS",
"RWKV_WKV6",
"GATED_LINEAR_ATTN",

"UNARY",

@ -987,7 +988,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"OPT_STEP_ADAMW",
};

static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",

@ -1064,6 +1065,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"get_rel_pos(x)",
"add_rel_pos(x)",
"rwkv_wkv6(k, v, r, tf, td, s)",
"gated_linear_attn(k, v, q, gate, s)",

"unary(x)",

@ -1083,7 +1085,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"adamw(x)",
};

static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@ -4629,15 +4631,13 @@ struct ggml_tensor * ggml_rwkv_wkv6(
GGML_ASSERT(ggml_is_contiguous(state));

const int64_t S = k->ne[0];
const int64_t H = k->ne[2];
const int64_t n_tokens = k->ne[3];
const int64_t H = k->ne[1];
const int64_t n_tokens = k->ne[2];
const int64_t n_seqs = state->ne[1];
{
GGML_ASSERT(k->ne[1] == 1);
GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
// TODO: RWKV v4 and v5
GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
GGML_ASSERT(r->ne[0] == S && r->ne[1] == H && r->ne[2] == n_tokens);
GGML_ASSERT(td->ne[0] == S && td->ne[1] == H && td->ne[2] == n_tokens);
GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
}
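The reworked asserts above drop the leading singleton dimension, so the wkv6 inputs are now plain [S, H, n_tokens] tensors. A hedged caller-side sketch under that assumption; the tensor construction and the tf shape are illustrative, not taken from this diff:

// Assumed layout after this change: k/v/r/td are [S, H, n_tokens] and the
// recurrent state carries S*S*H values per sequence.
struct ggml_tensor * k     = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
struct ggml_tensor * v     = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
struct ggml_tensor * r     = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
struct ggml_tensor * tf    = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, S, H); // illustrative shape
struct ggml_tensor * td    = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
struct ggml_tensor * state = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, S * S * H, n_seqs);
struct ggml_tensor * out   = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state);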
@ -4656,6 +4656,49 @@ struct ggml_tensor * ggml_rwkv_wkv6(
return result;
}

// ggml_gated_linear_attn

struct ggml_tensor * ggml_gated_linear_attn(
struct ggml_context * ctx,
struct ggml_tensor * k,
struct ggml_tensor * v,
struct ggml_tensor * q,
struct ggml_tensor * g,
struct ggml_tensor * state,
float scale) {
GGML_ASSERT(ggml_is_contiguous(k));
GGML_ASSERT(ggml_is_contiguous(v));
GGML_ASSERT(ggml_is_contiguous(q));
GGML_ASSERT(ggml_is_contiguous(g));
GGML_ASSERT(ggml_is_contiguous(state));

const int64_t S = k->ne[0];
const int64_t H = k->ne[1];
const int64_t n_tokens = k->ne[2];
const int64_t n_seqs = state->ne[1];
{
GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
GGML_ASSERT(q->ne[0] == S && q->ne[1] == H && q->ne[2] == n_tokens);
GGML_ASSERT(g->ne[0] == S && g->ne[1] == H && g->ne[2] == n_tokens);
GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
}

// concat output and new_state
const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);

ggml_set_op_params_f32(result, 0, scale);

result->op = GGML_OP_GATED_LINEAR_ATTN;
result->src[0] = k;
result->src[1] = v;
result->src[2] = q;
result->src[3] = g;
result->src[4] = state;

return result;
}
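Going by the asserts and the result shape above, a caller-side sketch of the new operator could look like this; the input shapes and the scale value are assumptions for illustration, not taken from this commit:

// k, v, q, g assumed to be [S, H, n_tokens]; state holds S*S*H values per sequence.
struct ggml_tensor * out = ggml_gated_linear_attn(ctx, k, v, q, g, state, 1.0f);
// As with ggml_rwkv_wkv6, the result concatenates the per-token output
// (S*H x n_tokens) and the updated state (S*H x S*n_seqs) along dim 1,
// so the caller is expected to split it apart with views afterwards.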
|
||||
|
||||
// ggml_unary
|
||||
|
||||
static struct ggml_tensor * ggml_unary_impl(
|
||||
|
|
|
@@ -15,6 +15,8 @@ pip install gguf

[examples/writer.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/examples/writer.py) — Generates `example.gguf` in the current directory to demonstrate generating a GGUF file. Note that this file cannot be used as a model.

[examples/reader.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/examples/reader.py) — Extracts and displays key-value pairs and tensor details from a GGUF file in a readable format.

[gguf/scripts/gguf_dump.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_dump.py) — Dumps a GGUF file's metadata to the console.

[gguf/scripts/gguf_set_metadata.py](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_set_metadata.py) — Allows changing simple metadata values in a GGUF file by key.
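As a quick illustration of the reader side described above, a minimal sketch using the `gguf` package to list a file's metadata keys and tensors might look as follows (the path is a placeholder; attribute names follow the gguf-py reader classes):

```python
from gguf import GGUFReader

reader = GGUFReader("example.gguf")  # placeholder path

# metadata key-value fields
for name, field in reader.fields.items():
    print(name, field.types)

# tensor names, shapes and quantization types
for tensor in reader.tensors:
    print(tensor.name, tensor.shape, tensor.tensor_type)
```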
@@ -115,6 +115,7 @@ class Keys:
        TIME_DECAY_EXTRA_DIM = "{arch}.time_decay_extra_dim"
        RESIDUAL_SCALE = "{arch}.residual_scale"
        EMBEDDING_SCALE = "{arch}.embedding_scale"
        TOKEN_SHIFT_COUNT = "{arch}.token_shift_count"

    class Attention:
        HEAD_COUNT = "{arch}.attention.head_count"

@@ -244,6 +245,7 @@ class MODEL_ARCH(IntEnum):
    QWEN2VL = auto()
    PHI2 = auto()
    PHI3 = auto()
    PHIMOE = auto()
    PLAMO = auto()
    CODESHELL = auto()
    ORION = auto()

@@ -254,6 +256,7 @@ class MODEL_ARCH(IntEnum):
    GEMMA2 = auto()
    STARCODER2 = auto()
    RWKV6 = auto()
    RWKV6QWEN2 = auto()
    MAMBA = auto()
    XVERSE = auto()
    COMMAND_R = auto()

@@ -333,6 +336,7 @@ class MODEL_TENSOR(IntEnum):
    TIME_MIX_LERP_V = auto()
    TIME_MIX_LERP_R = auto()
    TIME_MIX_LERP_G = auto()
    TIME_MIX_LERP_FUSED = auto()
    TIME_MIX_LERP_W = auto()
    TIME_MIX_FIRST = auto()
    TIME_MIX_DECAY = auto()

@@ -428,6 +432,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
    MODEL_ARCH.QWEN2VL: "qwen2vl",
    MODEL_ARCH.PHI2: "phi2",
    MODEL_ARCH.PHI3: "phi3",
    MODEL_ARCH.PHIMOE: "phimoe",
    MODEL_ARCH.PLAMO: "plamo",
    MODEL_ARCH.CODESHELL: "codeshell",
    MODEL_ARCH.ORION: "orion",

@@ -438,6 +443,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
    MODEL_ARCH.GEMMA2: "gemma2",
    MODEL_ARCH.STARCODER2: "starcoder2",
    MODEL_ARCH.RWKV6: "rwkv6",
    MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2",
    MODEL_ARCH.MAMBA: "mamba",
    MODEL_ARCH.XVERSE: "xverse",
    MODEL_ARCH.COMMAND_R: "command-r",

@@ -517,6 +523,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
    MODEL_TENSOR.TIME_MIX_LERP_V: "blk.{bid}.time_mix_lerp_v",
    MODEL_TENSOR.TIME_MIX_LERP_R: "blk.{bid}.time_mix_lerp_r",
    MODEL_TENSOR.TIME_MIX_LERP_G: "blk.{bid}.time_mix_lerp_g",
    MODEL_TENSOR.TIME_MIX_LERP_FUSED: "blk.{bid}.time_mix_lerp_fused",
    MODEL_TENSOR.TIME_MIX_LERP_W: "blk.{bid}.time_mix_lerp_w",
    MODEL_TENSOR.TIME_MIX_FIRST: "blk.{bid}.time_mix_first",
    MODEL_TENSOR.TIME_MIX_DECAY: "blk.{bid}.time_mix_decay",

@@ -940,6 +947,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
    ],
    MODEL_ARCH.PHIMOE: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT_NORM,
        MODEL_TENSOR.OUTPUT,
        MODEL_TENSOR.ROPE_FACTORS_LONG,
        MODEL_TENSOR.ROPE_FACTORS_SHORT,
        MODEL_TENSOR.ATTN_NORM,
        MODEL_TENSOR.ATTN_QKV,
        MODEL_TENSOR.ATTN_Q,
        MODEL_TENSOR.ATTN_K,
        MODEL_TENSOR.ATTN_V,
        MODEL_TENSOR.ATTN_OUT,
        MODEL_TENSOR.FFN_NORM,
        MODEL_TENSOR.FFN_GATE_INP,
        MODEL_TENSOR.FFN_GATE_EXP,
        MODEL_TENSOR.FFN_DOWN_EXP,
        MODEL_TENSOR.FFN_UP_EXP,
    ],
    MODEL_ARCH.CODESHELL: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.POS_EMBD,

@@ -1083,6 +1108,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
        MODEL_TENSOR.TIME_MIX_LERP_R,
        MODEL_TENSOR.TIME_MIX_LERP_G,
        MODEL_TENSOR.TIME_MIX_LERP_W,
        MODEL_TENSOR.TIME_MIX_LERP_FUSED,
        MODEL_TENSOR.TIME_MIX_FIRST,
        MODEL_TENSOR.TIME_MIX_DECAY,
        MODEL_TENSOR.TIME_MIX_DECAY_W1,

@@ -1099,6 +1125,35 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
        MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE,
        MODEL_TENSOR.CHANNEL_MIX_VALUE,
    ],
    MODEL_ARCH.RWKV6QWEN2: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT_NORM,
        MODEL_TENSOR.OUTPUT,
        MODEL_TENSOR.ATTN_NORM,
        MODEL_TENSOR.TIME_MIX_W1,
        MODEL_TENSOR.TIME_MIX_W2,
        MODEL_TENSOR.TIME_MIX_LERP_X,
        MODEL_TENSOR.TIME_MIX_LERP_K,
        MODEL_TENSOR.TIME_MIX_LERP_V,
        MODEL_TENSOR.TIME_MIX_LERP_R,
        MODEL_TENSOR.TIME_MIX_LERP_G,
        MODEL_TENSOR.TIME_MIX_LERP_W,
        MODEL_TENSOR.TIME_MIX_LERP_FUSED,
        MODEL_TENSOR.TIME_MIX_FIRST,
        MODEL_TENSOR.TIME_MIX_DECAY,
        MODEL_TENSOR.TIME_MIX_DECAY_W1,
        MODEL_TENSOR.TIME_MIX_DECAY_W2,
        MODEL_TENSOR.TIME_MIX_KEY,
        MODEL_TENSOR.TIME_MIX_VALUE,
        MODEL_TENSOR.TIME_MIX_RECEPTANCE,
        MODEL_TENSOR.TIME_MIX_GATE,
        MODEL_TENSOR.TIME_MIX_LN,
        MODEL_TENSOR.TIME_MIX_OUTPUT,
        MODEL_TENSOR.FFN_NORM,
        MODEL_TENSOR.FFN_GATE,
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
    ],
    MODEL_ARCH.MAMBA: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT_NORM,
@@ -743,6 +743,9 @@ class GGUFWriter:
    def add_wkv_head_size(self, size: int) -> None:
        self.add_uint32(Keys.WKV.HEAD_SIZE.format(arch=self.arch), size)

    def add_token_shift_count(self, count: int) -> None:
        self.add_uint32(Keys.LLM.TOKEN_SHIFT_COUNT.format(arch=self.arch), count)

    def add_layer_norm_eps(self, value: float) -> None:
        self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)
@@ -11,8 +11,8 @@ from pathlib import Path
import numpy as np

# Necessary to load the local gguf package
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
    sys.path.insert(0, str(Path(__file__).parent.parent))
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists():
    sys.path.insert(0, str(Path(__file__).parent.parent.parent))

import gguf

@@ -12,8 +12,8 @@ from typing import Any
import numpy as np

# Necessary to load the local gguf package
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
    sys.path.insert(0, str(Path(__file__).parent.parent))
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists():
    sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from gguf import GGUFReader, GGUFValueType, ReaderTensor  # noqa: E402

@@ -13,8 +13,8 @@ from pathlib import Path
from tqdm import tqdm

# Necessary to load the local gguf package
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
    sys.path.insert(0, str(Path(__file__).parent.parent))
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists():
    sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from gguf import GGUFReader  # noqa: E402

@@ -13,8 +13,8 @@ from tqdm import tqdm
from typing import Any, Sequence, NamedTuple

# Necessary to load the local gguf package
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
    sys.path.insert(0, str(Path(__file__).parent.parent))
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists():
    sys.path.insert(0, str(Path(__file__).parent.parent.parent))

import gguf

@@ -6,8 +6,8 @@ import sys
from pathlib import Path

# Necessary to load the local gguf package
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
    sys.path.insert(0, str(Path(__file__).parent.parent))
if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists():
    sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from gguf import GGUFReader  # noqa: E402
@ -13,7 +13,7 @@ class TensorNameMap:
|
|||
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone
|
||||
"transformer.word_embeddings", # falcon
|
||||
"word_embeddings", # bloom
|
||||
"model.embed_tokens", # llama-hf nemotron olmoe olmo2
|
||||
"model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2
|
||||
"tok_embeddings", # llama-pth
|
||||
"embeddings.word_embeddings", # bert nomic-bert
|
||||
"language_model.embedding.word_embeddings", # persimmon
|
||||
|
@ -55,7 +55,7 @@ class TensorNameMap:
|
|||
# Output
|
||||
MODEL_TENSOR.OUTPUT: (
|
||||
"embed_out", # gptneox
|
||||
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2
|
||||
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2 phimoe
|
||||
"output", # llama-pth bloom internlm2
|
||||
"word_embeddings_for_head", # persimmon
|
||||
"lm_head.linear", # phi2
|
||||
|
@ -68,7 +68,7 @@ class TensorNameMap:
|
|||
MODEL_TENSOR.OUTPUT_NORM: (
|
||||
"gpt_neox.final_layer_norm", # gptneox
|
||||
"transformer.ln_f", # gpt2 gpt-j falcon jais exaone
|
||||
"model.norm", # llama-hf baichuan internlm2 olmoe olmo2
|
||||
"model.norm", # llama-hf baichuan internlm2 olmoe olmo2 phimoe
|
||||
"norm", # llama-pth
|
||||
"transformer.norm_f", # mpt dbrx
|
||||
"ln_f", # refact bloom qwen gpt2
|
||||
|
@ -108,7 +108,7 @@ class TensorNameMap:
|
|||
"transformer.h.{bid}.input_layernorm", # falcon7b
|
||||
"h.{bid}.input_layernorm", # bloom
|
||||
"transformer.h.{bid}.ln_mlp", # falcon40b
|
||||
"model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe
|
||||
"model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe phimoe
|
||||
"layers.{bid}.attention_norm", # llama-pth
|
||||
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
|
||||
"model.layers.{bid}.ln1", # yi
|
||||
|
@ -152,7 +152,7 @@ class TensorNameMap:
|
|||
|
||||
# Attention query
|
||||
MODEL_TENSOR.ATTN_Q: (
|
||||
"model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe olmo2
|
||||
"model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe olmo2 phimoe
|
||||
"model.layers.{bid}.self_attn.q_proj_no_perm", # llama-custom
|
||||
"layers.{bid}.attention.wq", # llama-pth
|
||||
"encoder.layer.{bid}.attention.self.query", # bert
|
||||
|
@ -165,7 +165,7 @@ class TensorNameMap:
|
|||
|
||||
# Attention key
|
||||
MODEL_TENSOR.ATTN_K: (
|
||||
"model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe olmo2
|
||||
"model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe olmo2 phimoe
|
||||
"model.layers.{bid}.self_attn.k_proj_no_perm", # llama-custom
|
||||
"layers.{bid}.attention.wk", # llama-pth
|
||||
"encoder.layer.{bid}.attention.self.key", # bert
|
||||
|
@ -179,7 +179,7 @@ class TensorNameMap:
|
|||
|
||||
# Attention value
|
||||
MODEL_TENSOR.ATTN_V: (
|
||||
"model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo2
|
||||
"model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo2 phimoe
|
||||
"layers.{bid}.attention.wv", # llama-pth
|
||||
"encoder.layer.{bid}.attention.self.value", # bert
|
||||
"transformer.h.{bid}.attn.v_proj", # gpt-j
|
||||
|
@ -197,7 +197,7 @@ class TensorNameMap:
|
|||
"transformer.blocks.{bid}.attn.out_proj", # mpt
|
||||
"transformer.h.{bid}.self_attention.dense", # falcon
|
||||
"h.{bid}.self_attention.dense", # bloom
|
||||
"model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2
|
||||
"model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2 phimoe
|
||||
"model.layers.{bid}.self_attn.linear_attn", # deci
|
||||
"layers.{bid}.attention.wo", # llama-pth
|
||||
"encoder.layer.{bid}.attention.output.dense", # bert
|
||||
|
@ -242,7 +242,7 @@ class TensorNameMap:
|
|||
"transformer.h.{bid}.ln_2", # gpt2 refact qwen jais exaone
|
||||
"h.{bid}.post_attention_layernorm", # bloom
|
||||
"transformer.blocks.{bid}.norm_2", # mpt
|
||||
"model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron olmoe
|
||||
"model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron olmoe phimoe
|
||||
"layers.{bid}.ffn_norm", # llama-pth
|
||||
"language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
|
||||
"model.layers.{bid}.ln2", # yi
|
||||
|
@ -265,7 +265,7 @@ class TensorNameMap:
|
|||
|
||||
MODEL_TENSOR.FFN_GATE_INP: (
|
||||
"layers.{bid}.feed_forward.gate", # mixtral
|
||||
"model.layers.{bid}.block_sparse_moe.gate", # mixtral
|
||||
"model.layers.{bid}.block_sparse_moe.gate", # mixtral phimoe
|
||||
"model.layers.{bid}.mlp.gate", # qwen2moe olmoe
|
||||
"transformer.decoder_layer.{bid}.router", # Grok
|
||||
"transformer.blocks.{bid}.ffn.router.layer", # dbrx
|
||||
|
@ -310,10 +310,11 @@ class TensorNameMap:
|
|||
),
|
||||
|
||||
MODEL_TENSOR.FFN_UP_EXP: (
|
||||
"layers.{bid}.feed_forward.experts.w3", # mixtral (merged)
|
||||
"transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged)
|
||||
"transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
|
||||
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged)
|
||||
"layers.{bid}.feed_forward.experts.w3", # mixtral (merged)
|
||||
"transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged)
|
||||
"transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
|
||||
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged)
|
||||
"model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged)
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_UP_SHEXP: (
|
||||
|
@ -342,10 +343,11 @@ class TensorNameMap:
|
|||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_EXP: (
|
||||
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
|
||||
"transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
|
||||
"transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
|
||||
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged)
|
||||
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
|
||||
"transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
|
||||
"transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
|
||||
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged)
|
||||
"model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged)
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP: (
|
||||
|
@ -387,6 +389,7 @@ class TensorNameMap:
|
|||
"transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
|
||||
"model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
|
||||
"model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
|
||||
"model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged)
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP: (
|
||||
|
@ -461,34 +464,42 @@ class TensorNameMap:
|
|||
|
||||
MODEL_TENSOR.TIME_MIX_W1: (
|
||||
"rwkv.blocks.{bid}.attention.time_maa_w1", # rwkv v6
|
||||
"model.layers.{bid}.self_attn.time_maa_w1", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_W2: (
|
||||
"rwkv.blocks.{bid}.attention.time_maa_w2", # rwkv v6
|
||||
"model.layers.{bid}.self_attn.time_maa_w2", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_LERP_X: (
|
||||
"rwkv.blocks.{bid}.attention.time_maa_x", # rwkv v6
|
||||
"model.layers.{bid}.self_attn.time_maa_x", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_LERP_K: (
|
||||
"rwkv.blocks.{bid}.attention.time_maa_k", # rwkv v6
|
||||
"model.layers.{bid}.self_attn.time_maa_k", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_LERP_V: (
|
||||
"rwkv.blocks.{bid}.attention.time_maa_v", # rwkv v6
|
||||
"model.layers.{bid}.self_attn.time_maa_v", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_LERP_R: (
|
||||
"rwkv.blocks.{bid}.attention.time_maa_r", # rwkv v6
|
||||
"model.layers.{bid}.self_attn.time_maa_r", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_LERP_G: (
|
||||
"rwkv.blocks.{bid}.attention.time_maa_g", # rwkv v6
|
||||
"model.layers.{bid}.self_attn.time_maa_g", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_LERP_W: (
|
||||
"rwkv.blocks.{bid}.attention.time_maa_w", # rwkv v6
|
||||
"model.layers.{bid}.self_attn.time_maa_w", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_FIRST: (
|
||||
|
@ -497,30 +508,37 @@ class TensorNameMap:
|
|||
|
||||
MODEL_TENSOR.TIME_MIX_DECAY: (
|
||||
"rwkv.blocks.{bid}.attention.time_decay", # rwkv v6
|
||||
"model.layers.{bid}.self_attn.time_decay", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_DECAY_W1: (
|
||||
"rwkv.blocks.{bid}.attention.time_decay_w1", # rwkv v6
|
||||
"model.layers.{bid}.self_attn.time_decay_w1", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_DECAY_W2: (
|
||||
"rwkv.blocks.{bid}.attention.time_decay_w2", # rwkv v6
|
||||
"model.layers.{bid}.self_attn.time_decay_w2", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_KEY: (
|
||||
"rwkv.blocks.{bid}.attention.key", # rwkv
|
||||
"rwkv.blocks.{bid}.attention.key", # rwkv
|
||||
"model.layers.{bid}.self_attn.k_proj", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_VALUE: (
|
||||
"rwkv.blocks.{bid}.attention.value", # rwkv
|
||||
"rwkv.blocks.{bid}.attention.value", # rwkv
|
||||
"model.layers.{bid}.self_attn.v_proj", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_RECEPTANCE: (
|
||||
"rwkv.blocks.{bid}.attention.receptance", # rwkv
|
||||
"model.layers.{bid}.self_attn.q_proj", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_GATE: (
|
||||
"rwkv.blocks.{bid}.attention.gate", # rwkv
|
||||
"rwkv.blocks.{bid}.attention.gate", # rwkv
|
||||
"model.layers.{bid}.self_attn.gate", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_LN: (
|
||||
|
@ -528,7 +546,8 @@ class TensorNameMap:
|
|||
),
|
||||
|
||||
MODEL_TENSOR.TIME_MIX_OUTPUT: (
|
||||
"rwkv.blocks.{bid}.attention.output", # rwkv
|
||||
"rwkv.blocks.{bid}.attention.output", # rwkv
|
||||
"model.layers.{bid}.self_attn.o_proj", # rwkv6qwen2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.CHANNEL_MIX_LERP_K: (
|
||||
|
|
|
@@ -1,6 +1,6 @@
[tool.poetry]
name = "gguf"
version = "0.14.0"
version = "0.15.0"
description = "Read and write ML models in GGUF for GGML"
authors = ["GGML <ggml@ggml.ai>"]
packages = [
Binary file not shown (before: 195 KiB).
@ -27,6 +27,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|||
{ LLM_ARCH_QWEN2VL, "qwen2vl" },
|
||||
{ LLM_ARCH_PHI2, "phi2" },
|
||||
{ LLM_ARCH_PHI3, "phi3" },
|
||||
{ LLM_ARCH_PHIMOE, "phimoe" },
|
||||
{ LLM_ARCH_PLAMO, "plamo" },
|
||||
{ LLM_ARCH_CODESHELL, "codeshell" },
|
||||
{ LLM_ARCH_ORION, "orion" },
|
||||
|
@ -56,6 +57,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|||
{ LLM_ARCH_NEMOTRON, "nemotron" },
|
||||
{ LLM_ARCH_EXAONE, "exaone" },
|
||||
{ LLM_ARCH_RWKV6, "rwkv6" },
|
||||
{ LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" },
|
||||
{ LLM_ARCH_GRANITE, "granite" },
|
||||
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
|
||||
{ LLM_ARCH_CHAMELEON, "chameleon" },
|
||||
|
@ -105,6 +107,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
|||
{ LLM_KV_TIME_DECAY_EXTRA_DIM, "%s.time_decay_extra_dim" },
|
||||
{ LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" },
|
||||
{ LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" },
|
||||
{ LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" },
|
||||
|
||||
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
|
||||
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
|
||||
|
@ -584,6 +587,27 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
|||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_PHIMOE,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" },
|
||||
{ LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_PLAMO,
|
||||
{
|
||||
|
@ -1144,6 +1168,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
|||
{ LLM_TENSOR_TIME_MIX_LERP_V, "blk.%d.time_mix_lerp_v" },
|
||||
{ LLM_TENSOR_TIME_MIX_LERP_R, "blk.%d.time_mix_lerp_r" },
|
||||
{ LLM_TENSOR_TIME_MIX_LERP_G, "blk.%d.time_mix_lerp_g" },
|
||||
{ LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" },
|
||||
{ LLM_TENSOR_TIME_MIX_FIRST, "blk.%d.time_mix_first" },
|
||||
{ LLM_TENSOR_TIME_MIX_DECAY, "blk.%d.time_mix_decay" },
|
||||
{ LLM_TENSOR_TIME_MIX_DECAY_W1, "blk.%d.time_mix_decay_w1" },
|
||||
|
@ -1161,6 +1186,32 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
|||
{ LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "blk.%d.channel_mix_receptance" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_RWKV6QWEN2,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" },
|
||||
{ LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" },
|
||||
{ LLM_TENSOR_TIME_MIX_LERP_X, "blk.%d.time_mix_lerp_x" },
|
||||
{ LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" },
|
||||
{ LLM_TENSOR_TIME_MIX_FIRST, "blk.%d.time_mix_first" },
|
||||
{ LLM_TENSOR_TIME_MIX_DECAY, "blk.%d.time_mix_decay" },
|
||||
{ LLM_TENSOR_TIME_MIX_DECAY_W1, "blk.%d.time_mix_decay_w1" },
|
||||
{ LLM_TENSOR_TIME_MIX_DECAY_W2, "blk.%d.time_mix_decay_w2" },
|
||||
{ LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" },
|
||||
{ LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" },
|
||||
{ LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" },
|
||||
{ LLM_TENSOR_TIME_MIX_GATE, "blk.%d.time_mix_gate" },
|
||||
{ LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_GRANITE,
|
||||
{
|
||||
|
@ -1343,6 +1394,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
|||
{LLM_TENSOR_TIME_MIX_LERP_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_LERP_FUSED, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
{LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}},
|
||||
{LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
|
|
|
@@ -31,6 +31,7 @@ enum llm_arch {
    LLM_ARCH_QWEN2VL,
    LLM_ARCH_PHI2,
    LLM_ARCH_PHI3,
    LLM_ARCH_PHIMOE,
    LLM_ARCH_PLAMO,
    LLM_ARCH_CODESHELL,
    LLM_ARCH_ORION,

@@ -60,6 +61,7 @@ enum llm_arch {
    LLM_ARCH_NEMOTRON,
    LLM_ARCH_EXAONE,
    LLM_ARCH_RWKV6,
    LLM_ARCH_RWKV6QWEN2,
    LLM_ARCH_GRANITE,
    LLM_ARCH_GRANITE_MOE,
    LLM_ARCH_CHAMELEON,

@@ -109,6 +111,7 @@ enum llm_kv {
    LLM_KV_TIME_DECAY_EXTRA_DIM,
    LLM_KV_RESIDUAL_SCALE,
    LLM_KV_EMBEDDING_SCALE,
    LLM_KV_TOKEN_SHIFT_COUNT,

    LLM_KV_ATTENTION_HEAD_COUNT,
    LLM_KV_ATTENTION_HEAD_COUNT_KV,

@@ -252,6 +255,7 @@ enum llm_tensor {
    LLM_TENSOR_TIME_MIX_LERP_V,
    LLM_TENSOR_TIME_MIX_LERP_R,
    LLM_TENSOR_TIME_MIX_LERP_G,
    LLM_TENSOR_TIME_MIX_LERP_FUSED,
    LLM_TENSOR_TIME_MIX_FIRST,
    LLM_TENSOR_TIME_MIX_DECAY,
    LLM_TENSOR_TIME_MIX_DECAY_W1,
@@ -35,6 +35,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
    { "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN },
    { "mistral-v7", LLM_CHAT_TEMPLATE_MISTRAL_V7 },
    { "phi3", LLM_CHAT_TEMPLATE_PHI_3 },
    { "phi4", LLM_CHAT_TEMPLATE_PHI_4 },
    { "falcon3", LLM_CHAT_TEMPLATE_FALCON_3 },
    { "zephyr", LLM_CHAT_TEMPLATE_ZEPHYR },
    { "monarch", LLM_CHAT_TEMPLATE_MONARCH },

@@ -73,7 +74,9 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
        return tmpl.find(haystack) != std::string::npos;
    };
    if (tmpl_contains("<|im_start|>")) {
        return LLM_CHAT_TEMPLATE_CHATML;
        return tmpl_contains("<|im_sep|>")
            ? LLM_CHAT_TEMPLATE_PHI_4
            : LLM_CHAT_TEMPLATE_CHATML;
    } else if (tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) {
        if (tmpl_contains("[SYSTEM_PROMPT]")) {
            return LLM_CHAT_TEMPLATE_MISTRAL_V7;

@@ -269,6 +272,14 @@ int32_t llm_chat_apply_template(
        if (add_ass) {
            ss << "<|assistant|>\n";
        }
    } else if (tmpl == LLM_CHAT_TEMPLATE_PHI_4) {
        // chatml template
        for (auto message : chat) {
            ss << "<|im_start|>" << message->role << "<|im_sep|>" << message->content << "<|im_end|>";
        }
        if (add_ass) {
            ss << "<|im_start|>assistant<|im_sep|>";
        }
    } else if (tmpl == LLM_CHAT_TEMPLATE_FALCON_3) {
        // Falcon 3
        for (auto message : chat) {
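As a sanity check on the new phi4 branch above, the prompt it produces can be mirrored in a few lines of Python (the message list is a made-up example; the special tokens are exactly the ones written by the C++ code):

```python
def render_phi4(messages, add_assistant=True):
    # mirrors the LLM_CHAT_TEMPLATE_PHI_4 branch of llm_chat_apply_template
    out = ""
    for m in messages:
        out += f"<|im_start|>{m['role']}<|im_sep|>{m['content']}<|im_end|>"
    if add_assistant:
        out += "<|im_start|>assistant<|im_sep|>"
    return out

print(render_phi4([{"role": "user", "content": "Hello"}]))
# -> <|im_start|>user<|im_sep|>Hello<|im_end|><|im_start|>assistant<|im_sep|>
```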
@@ -15,6 +15,7 @@ enum llm_chat_template {
    LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN,
    LLM_CHAT_TEMPLATE_MISTRAL_V7,
    LLM_CHAT_TEMPLATE_PHI_3,
    LLM_CHAT_TEMPLATE_PHI_4,
    LLM_CHAT_TEMPLATE_FALCON_3,
    LLM_CHAT_TEMPLATE_ZEPHYR,
    LLM_CHAT_TEMPLATE_MONARCH,
@@ -52,7 +52,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
uint32_t llama_hparams::n_embd_k_s() const {
    if (wkv_head_size != 0) {
        // for RWKV models
        return 2 * n_embd;
        return token_shift_count * n_embd;
    }

    // TODO: maybe support other convolution strides than 1

@@ -76,6 +76,7 @@ struct llama_hparams {
    uint32_t time_mix_extra_dim = 0;
    uint32_t time_decay_extra_dim = 0;
    uint32_t wkv_head_size = 0;
    uint32_t token_shift_count = 2;

    float rope_attn_factor = 1.0f;
    float rope_freq_base_train;
@@ -76,6 +76,7 @@ const char * llm_type_name(llm_type type) {
        case MODEL_8x7B: return "8x7B";
        case MODEL_8x22B: return "8x22B";
        case MODEL_16x12B: return "16x12B";
        case MODEL_16x3_8B: return "16x3.8B";
        case MODEL_10B_128x3_66B: return "10B+128x3.66B";
        case MODEL_57B_A14B: return "57B.A14B";
        case MODEL_27B: return "27B";

@@ -661,6 +662,15 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) {
                    throw std::runtime_error("invalid value for sliding_window");
                }
            } break;
        case LLM_ARCH_PHIMOE:
            {
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

                switch (hparams.n_layer) {
                    case 32: model.type = e_model::MODEL_16x3_8B; break;
                    default: model.type = e_model::MODEL_UNKNOWN;
                }
            } break;
        case LLM_ARCH_PLAMO:
            {
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

@@ -1044,12 +1054,15 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) {
                }
            } break;
        case LLM_ARCH_RWKV6:
        case LLM_ARCH_RWKV6QWEN2:
            {
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false);
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false);
                ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size);
                ml.get_key(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim);
                ml.get_key(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim);
                ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false);
                ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false);

                switch (hparams.n_layer) {
                    case 24: model.type = e_model::MODEL_1_6B; break;

@@ -1060,6 +1073,7 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) {
                            default: model.type = e_model::MODEL_UNKNOWN;
                        } break;
                    case 61: model.type = e_model::MODEL_14B; break;
                    case 64: model.type = e_model::MODEL_32B; break;
                    default: model.type = e_model::MODEL_UNKNOWN;
                }
            } break;

@@ -2054,6 +2068,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
        case LLM_ARCH_T5ENCODER:
        case LLM_ARCH_JAIS:
        case LLM_ARCH_RWKV6:
        case LLM_ARCH_RWKV6QWEN2:
        case LLM_ARCH_WAVTOKENIZER_DEC:
            return LLAMA_ROPE_TYPE_NONE;

@@ -2094,6 +2109,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
        case LLM_ARCH_OLMOE:
        case LLM_ARCH_PHI2:
        case LLM_ARCH_PHI3:
        case LLM_ARCH_PHIMOE:
        case LLM_ARCH_GEMMA:
        case LLM_ARCH_GEMMA2:
        case LLM_ARCH_STARCODER2:

@@ -2197,6 +2213,7 @@ bool llama_model_is_recurrent(const struct llama_model * model) {
    switch (model->arch) {
        case LLM_ARCH_MAMBA: return true;
        case LLM_ARCH_RWKV6: return true;
        case LLM_ARCH_RWKV6QWEN2: return true;
        default: return false;
    }
}
@@ -73,6 +73,7 @@ enum llm_type {
    MODEL_8x7B,
    MODEL_8x22B,
    MODEL_16x12B,
    MODEL_16x3_8B,
    MODEL_10B_128x3_66B,
    MODEL_57B_A14B,
    MODEL_27B,

@@ -240,15 +241,19 @@ struct llama_layer {
    struct ggml_tensor * time_mix_lerp_v = nullptr;
    struct ggml_tensor * time_mix_lerp_r = nullptr;
    struct ggml_tensor * time_mix_lerp_g = nullptr;
    struct ggml_tensor * time_mix_lerp_fused = nullptr;

    struct ggml_tensor * time_mix_first = nullptr;
    struct ggml_tensor * time_mix_decay = nullptr;
    struct ggml_tensor * time_mix_decay_w1 = nullptr;
    struct ggml_tensor * time_mix_decay_w2 = nullptr;
    struct ggml_tensor * time_mix_key = nullptr;
    struct ggml_tensor * time_mix_value = nullptr;
    struct ggml_tensor * time_mix_receptance = nullptr;
    struct ggml_tensor * time_mix_gate = nullptr;
    struct ggml_tensor * time_mix_first = nullptr;
    struct ggml_tensor * time_mix_decay = nullptr;
    struct ggml_tensor * time_mix_decay_w1 = nullptr;
    struct ggml_tensor * time_mix_decay_w2 = nullptr;
    struct ggml_tensor * time_mix_key = nullptr;
    struct ggml_tensor * time_mix_key_b = nullptr;
    struct ggml_tensor * time_mix_value = nullptr;
    struct ggml_tensor * time_mix_value_b = nullptr;
    struct ggml_tensor * time_mix_receptance = nullptr;
    struct ggml_tensor * time_mix_receptance_b = nullptr;
    struct ggml_tensor * time_mix_gate = nullptr;

    struct ggml_tensor * time_mix_ln = nullptr;
    struct ggml_tensor * time_mix_ln_b = nullptr;
|
@@ -620,7 +620,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::

    qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;

    // sanity checks
    // sanity checks for models that have attention layers
    if (qs.n_attention_wv != 0)
    {
        const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
        // attention layers have a non-zero number of kv heads

@@ -758,6 +759,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
        quantize &= name.find("time_mix_w2.weight") == std::string::npos;
        quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos;
        quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos;
        quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos;

        // do not quantize relative position bias (T5)
        quantize &= name.find("attn_rel_b.weight") == std::string::npos;
src/llama.cpp (425)
@@ -134,11 +134,11 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
            const int64_t H = 123;
            const int64_t n_tokens = 123;
            const int64_t n_seqs = 123;
            ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, 1, H, n_tokens);
            ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
            ggml_tensor * r = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
            ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
            ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
            ggml_tensor * r = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
            ggml_tensor * tf = w;
            ggml_tensor * td = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
            ggml_tensor * td = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
            ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_seqs, S, H);
            op_tensor = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state);
        } break;
@ -1212,6 +1212,50 @@ static bool llm_load_tensors(
|
|||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }, 0);
|
||||
|
||||
layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
||||
layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_PHIMOE:
|
||||
{
|
||||
const int64_t n_embd_head = n_embd / n_head;
|
||||
|
||||
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
|
||||
|
||||
// output
|
||||
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
|
||||
model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
|
||||
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0);
|
||||
model.output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), { n_vocab }, 0);
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
|
||||
layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), { n_embd }, 0);
|
||||
|
||||
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
if (layer.wqkv == nullptr) {
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
|
||||
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0);
|
||||
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
|
||||
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0);
|
||||
|
||||
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
||||
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0);
|
||||
}
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0);
|
||||
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, 0);
|
||||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
|
||||
layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), { n_embd }, 0);
|
||||
|
||||
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
|
||||
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
|
||||
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0);
|
||||
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
|
||||
|
||||
layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
||||
layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
||||
}
|
||||
|
@ -2142,11 +2186,13 @@ static bool llm_load_tensors(
|
|||
layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0);
|
||||
|
||||
layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0);
|
||||
layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, 0);
|
||||
layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0);
|
||||
layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, 0);
|
||||
layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, 0);
|
||||
layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, 0);
|
||||
layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
GGML_ASSERT(!(layer.time_mix_lerp_fused == NULL && layer.time_mix_lerp_w == NULL));
|
||||
|
||||
layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, 0);
|
||||
layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0);
|
||||
|
@ -2170,6 +2216,59 @@ static bool llm_load_tensors(
|
|||
}
|
||||
|
||||
} break;
|
||||
case LLM_ARCH_RWKV6QWEN2:
|
||||
{
|
||||
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
const int time_mix_extra_dim = hparams.time_mix_extra_dim;
|
||||
const int time_decay_extra_dim = hparams.time_decay_extra_dim;
|
||||
const int head_size = hparams.wkv_head_size;
|
||||
const int attn_hidden_size = n_embd;
|
||||
const int n_head_kv = hparams.n_head_kv();
|
||||
int attn_key_value_size;
|
||||
if (n_head_kv == 0 || attn_hidden_size / head_size == n_head_kv) {
|
||||
attn_key_value_size = attn_hidden_size;
|
||||
} else {
|
||||
attn_key_value_size = n_head_kv * head_size;
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}, 0);
|
||||
layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0);
|
||||
|
||||
layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0);
|
||||
layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, 0);
|
||||
|
||||
layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0);
|
||||
layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0);
|
||||
layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}, 0);
|
||||
layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {n_embd, attn_key_value_size}, 0);
|
||||
layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {n_embd, attn_key_value_size}, 0);
|
||||
layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0);
|
||||
layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0);
|
||||
// optional bias tensors
|
||||
layer.time_mix_key_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "bias", i), {attn_key_value_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.time_mix_value_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "bias", i), {attn_key_value_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.time_mix_receptance_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "bias", i), {attn_hidden_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
|
||||
layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0);
|
||||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_CHAMELEON:
|
||||
{
|
||||
model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
@ -3293,16 +3392,20 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
|
|||
const struct llama_layer * layer,
|
||||
struct ggml_tensor * cur,
|
||||
struct ggml_tensor * x_prev,
|
||||
struct ggml_tensor ** wkv_state) {
|
||||
struct ggml_tensor ** wkv_state,
|
||||
size_t wkv_head_size,
|
||||
size_t head_count_kv) {
|
||||
size_t n_embd = cur->ne[0];
|
||||
size_t n_seq_tokens = cur->ne[1];
|
||||
size_t n_seqs = cur->ne[2];
|
||||
|
||||
size_t head_size = layer->time_mix_first->ne[0];
|
||||
size_t head_count = layer->time_mix_first->ne[1];
|
||||
size_t head_size = wkv_head_size;
|
||||
size_t head_count = n_embd / head_size;
|
||||
|
||||
size_t n_tokens = n_seqs * n_seq_tokens;
|
||||
|
||||
bool is_qrwkv = layer->time_mix_first == nullptr;
|
||||
|
||||
struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
|
||||
|
||||
sx = ggml_reshape_2d(ctx, sx, n_embd, n_tokens);
|
||||
|
@ -3331,69 +3434,64 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
|
|||
xxx
|
||||
);
|
||||
|
||||
struct ggml_tensor *mw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
|
||||
struct ggml_tensor *mk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
|
||||
struct ggml_tensor *mv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
|
||||
struct ggml_tensor *mr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
|
||||
struct ggml_tensor *mg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
|
||||
struct ggml_tensor *xw, *xk, *xv, *xr, *xg;
|
||||
if (layer->time_mix_lerp_fused) {
|
||||
// fusing these weights makes some performance improvement
|
||||
sx = ggml_reshape_3d(ctx, sx, n_embd, 1, n_tokens);
|
||||
cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens);
|
||||
xxx = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xxx, layer->time_mix_lerp_fused), sx), cur);
|
||||
xw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
|
||||
xk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
|
||||
xv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
|
||||
xr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
|
||||
xg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
|
||||
} else {
|
||||
// for backward compatibility
|
||||
xw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
|
||||
xk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
|
||||
xv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
|
||||
xr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
|
||||
xg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
|
||||
|
||||
struct ggml_tensor * xw = ggml_add(
|
||||
ctx,
|
||||
ggml_mul(
|
||||
ctx,
|
||||
ggml_add(ctx, mw, layer->time_mix_lerp_w),
|
||||
sx
|
||||
),
|
||||
cur
|
||||
);
|
||||
xw = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xw, layer->time_mix_lerp_w), sx), cur);
|
||||
xk = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xk, layer->time_mix_lerp_k), sx), cur);
|
||||
xv = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xv, layer->time_mix_lerp_v), sx), cur);
|
||||
xr = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xr, layer->time_mix_lerp_r), sx), cur);
|
||||
xg = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xg, layer->time_mix_lerp_g), sx), cur);
|
||||
}
|
||||
|
||||
struct ggml_tensor * xk = ggml_add(
|
||||
ctx,
|
||||
ggml_mul(
|
||||
ctx,
|
||||
ggml_add(ctx, mk, layer->time_mix_lerp_k),
|
||||
sx
|
||||
),
|
||||
cur
|
||||
);
|
||||
struct ggml_tensor * r = llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr);
|
||||
struct ggml_tensor * k = llm_build_lora_mm(lctx, ctx, layer->time_mix_key, xk);
|
||||
struct ggml_tensor * v = llm_build_lora_mm(lctx, ctx, layer->time_mix_value, xv);
|
||||
if (layer->time_mix_receptance_b) {
|
||||
r = ggml_add(ctx, r, layer->time_mix_receptance_b);
|
||||
}
|
||||
if (layer->time_mix_key_b) {
|
||||
k = ggml_add(ctx, k, layer->time_mix_key_b);
|
||||
}
|
||||
if (layer->time_mix_value_b) {
|
||||
v = ggml_add(ctx, v, layer->time_mix_value_b);
|
||||
}
|
||||
|
||||
struct ggml_tensor * xv = ggml_add(
|
||||
ctx,
|
||||
ggml_mul(
|
||||
ctx,
|
||||
ggml_add(ctx, mv, layer->time_mix_lerp_v),
|
||||
sx
|
||||
),
|
||||
cur
|
||||
);
|
||||
struct ggml_tensor * g = llm_build_lora_mm(lctx, ctx, layer->time_mix_gate, xg);
|
||||
if (is_qrwkv) {
|
||||
g = ggml_sigmoid(ctx, g);
|
||||
} else {
|
||||
g = ggml_silu(ctx, g);
|
||||
}
|
||||
|
||||
struct ggml_tensor * xr = ggml_add(
|
||||
ctx,
|
||||
ggml_mul(
|
||||
ctx,
|
||||
ggml_add(ctx, mr, layer->time_mix_lerp_r),
|
||||
sx
|
||||
),
|
||||
cur
|
||||
);
|
||||
if (head_count_kv != head_count) {
|
||||
GGML_ASSERT(head_count % head_count_kv == 0);
|
||||
k = ggml_reshape_4d(ctx, k, head_size, 1, head_count_kv, n_tokens);
|
||||
v = ggml_reshape_4d(ctx, v, head_size, 1, head_count_kv, n_tokens);
|
||||
struct ggml_tensor * tmp = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_size, head_count / head_count_kv, head_count_kv, n_tokens);
|
||||
k = ggml_repeat(ctx, k, tmp);
|
||||
v = ggml_repeat(ctx, v, tmp);
|
||||
}
|
||||
|
||||
struct ggml_tensor * xg = ggml_add(
|
||||
ctx,
|
||||
ggml_mul(
|
||||
ctx,
|
||||
ggml_add(ctx, mg, layer->time_mix_lerp_g),
|
||||
sx
|
||||
),
|
||||
cur
|
||||
);
|
||||
|
||||
struct ggml_tensor * r = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr), head_size, 1, head_count, n_tokens);
|
||||
struct ggml_tensor * k = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_key, xk), 1, head_size, head_count, n_tokens);
|
||||
struct ggml_tensor * v = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_value, xv), head_size, 1, head_count, n_tokens);
|
||||
struct ggml_tensor * g = ggml_silu(
|
||||
ctx,
|
||||
llm_build_lora_mm(lctx, ctx, layer->time_mix_gate, xg)
|
||||
);
|
||||
k = ggml_reshape_3d(ctx, k, head_size, head_count, n_tokens);
|
||||
v = ggml_reshape_3d(ctx, v, head_size, head_count, n_tokens);
|
||||
r = ggml_reshape_3d(ctx, r, head_size, head_count, n_tokens);
|
||||
|
||||
struct ggml_tensor * w = ggml_mul_mat(
|
||||
ctx,
|
||||
|
@ -3404,25 +3502,35 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
|
|||
)
|
||||
);
|
||||
|
||||
w = ggml_add(ctx, w, ggml_reshape_1d(ctx, layer->time_mix_decay, n_embd));
|
||||
w = ggml_add(ctx, w, layer->time_mix_decay);
|
||||
w = ggml_exp(ctx, ggml_neg(ctx, ggml_exp(ctx, w)));
|
||||
w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, n_tokens);
|
||||
w = ggml_reshape_3d(ctx, w, head_size, head_count, n_tokens);
|
||||
|
||||
k = ggml_transpose(ctx, k);
|
||||
v = ggml_transpose(ctx, v);
|
||||
r = ggml_transpose(ctx, r);
|
||||
if (is_qrwkv) {
|
||||
// k = k * (1 - w)
|
||||
k = ggml_sub(ctx, k, ggml_mul(ctx, k, w));
|
||||
}
|
||||
|
||||
struct ggml_tensor * wkv_output = ggml_rwkv_wkv6(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
|
||||
struct ggml_tensor * wkv_output;
|
||||
if (!layer->time_mix_first) {
|
||||
wkv_output = ggml_gated_linear_attn(ctx, k, v, r, w, *wkv_state, pow(head_size, -0.5f));
|
||||
} else {
|
||||
wkv_output = ggml_rwkv_wkv6(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
|
||||
}
|
||||
cur = ggml_view_1d(ctx, wkv_output, n_embd * n_tokens, 0);
|
||||
*wkv_state = ggml_view_1d(ctx, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));
|
||||
|
||||
// group norm with head_count groups
|
||||
cur = ggml_reshape_3d(ctx, cur, n_embd / head_count, head_count, n_tokens);
|
||||
cur = ggml_norm(ctx, cur, 64e-5f);
|
||||
if (!is_qrwkv) {
|
||||
// group norm with head_count groups
|
||||
cur = ggml_reshape_3d(ctx, cur, n_embd / head_count, head_count, n_tokens);
|
||||
cur = ggml_norm(ctx, cur, 64e-5f);
|
||||
|
||||
// Convert back to regular vectors.
|
||||
cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
|
||||
cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);
|
||||
// Convert back to regular vectors.
|
||||
cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
|
||||
cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);
|
||||
} else {
|
||||
cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
|
||||
}
|
||||
|
||||
cur = ggml_mul(ctx, cur, g);
|
||||
cur = llm_build_lora_mm(lctx, ctx, layer->time_mix_output, cur);
|
||||
|
@ -6266,7 +6374,7 @@ struct llm_build_context {
|
|||
|
||||
struct ggml_tensor* attn_norm_output = llm_build_norm(ctx0, inpL, hparams,
|
||||
model.layers[il].attn_norm,
|
||||
NULL,
|
||||
model.layers[il].attn_norm_b,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(attn_norm_output, "attn_norm", il);
|
||||
|
||||
|
@ -6281,8 +6389,7 @@ struct llm_build_context {
|
|||
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0 * sizeof(float) * (n_embd)));
|
||||
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd)));
|
||||
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)));
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq);
|
||||
Kcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk);
|
||||
Vcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv);
|
||||
|
@ -6326,14 +6433,12 @@ struct llm_build_context {
|
|||
residual = cur;
|
||||
|
||||
cur = llm_build_norm(ctx0, cur, hparams,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
// FF
|
||||
// special-case: the up and gate tensors are merged into a single tensor
|
||||
// TOOD: support into llm_build_ffn
|
||||
{
|
||||
// feed-forward network
|
||||
if (model.layers[il].ffn_gate_inp == nullptr) {
|
||||
cur = llm_build_ffn(ctx0, lctx, cur,
|
||||
model.layers[il].ffn_up, NULL, NULL,
|
||||
NULL, NULL, NULL,
|
||||
|
@ -6341,6 +6446,20 @@ struct llm_build_context {
|
|||
NULL,
|
||||
LLM_FFN_SWIGLU, LLM_FFN_SEQ, cb, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
} else {
|
||||
// MoE branch
|
||||
cur = llm_build_moe_ffn(ctx0, lctx, cur,
|
||||
model.layers[il].ffn_gate_inp,
|
||||
model.layers[il].ffn_up_exps,
|
||||
model.layers[il].ffn_gate_exps,
|
||||
model.layers[il].ffn_down_exps,
|
||||
nullptr,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, true,
|
||||
false, 0.0,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||
cb, il);
|
||||
cb(cur, "ffn_moe_out", il);
|
||||
}
|
||||
|
||||
cur = ggml_add(ctx0, residual, cur);
|
||||
|
@ -6353,11 +6472,16 @@ struct llm_build_context {
|
|||
|
||||
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||
model.output_norm,
|
||||
NULL,
|
||||
model.output_norm_b,
|
||||
LLM_NORM_RMS, cb, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
|
||||
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
||||
|
||||
if (model.output_b != nullptr) {
|
||||
cb(cur, "result_output_no_bias", -1);
|
||||
cur = ggml_add(ctx0, cur, model.output_b);
|
||||
}
|
||||
cb(cur, "result_output", -1);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
@ -9988,7 +10112,7 @@ struct llm_build_context {
|
|||
1
|
||||
);
|
||||
|
||||
cur = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states));
|
||||
cur = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, hparams.wkv_head_size, n_embd / hparams.wkv_head_size));
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
ggml_build_forward_expand(
|
||||
gf,
|
||||
|
@ -10055,6 +10179,118 @@ struct llm_build_context {
        return gf;
    }

    // ref: https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1/blob/main/modeling_rwkv6qwen2.py
    ggml_cgraph * build_rwkv6qwen2() {
        ggml_cgraph *gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

        GGML_ASSERT(n_embd == hparams.n_embd_k_s());

        const int64_t n_seqs = ubatch.n_seqs;
        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
        const int64_t n_tokens = ubatch.n_tokens;
        GGML_ASSERT(n_seqs != 0);
        GGML_ASSERT(ubatch.equal_seqs);
        GGML_ASSERT(n_tokens == n_seq_tokens * n_seqs);

        struct ggml_tensor * cur;
        struct ggml_tensor * inpL;
        struct ggml_tensor * state_copy = build_inp_s_copy();
        struct ggml_tensor * state_mask = build_inp_s_mask();

        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);

        for (int il = 0; il < n_layer; ++il) {
            const llama_layer * layer = &model.layers[il];

            // (ab)using the KV cache to store the states
            struct ggml_tensor * token_shift = llm_build_copy_mask_state(ctx0,
                    gf, kv_self.k_l[il], state_copy, state_mask,
                    hparams.n_embd_k_s(), kv_self.size, kv_head, n_kv, n_seqs);
            struct ggml_tensor * wkv_states = llm_build_copy_mask_state(ctx0,
                    gf, kv_self.v_l[il], state_copy, state_mask,
                    hparams.n_embd_v_s(), kv_self.size, kv_head, n_kv, n_seqs);

            cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
            token_shift = ggml_reshape_3d(ctx0, token_shift, n_embd, 1, n_seqs);

            struct ggml_tensor * x_norm_att = llm_build_norm(ctx0, cur, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, cb, il);
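            // token shift: x_prev[t] is the previous token's normalized embedding;
            // position 0 comes from the per-sequence token_shift state cached above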
            struct ggml_tensor * x_prev = ggml_concat(
                ctx0,
                token_shift,
                ggml_view_3d(ctx0, x_norm_att, n_embd, n_seq_tokens - 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], 0),
                1
            );

            ggml_build_forward_expand(
                gf,
                ggml_cpy(
                    ctx0,
                    wkv_states,
                    ggml_view_1d(
                        ctx0,
                        kv_self.v_l[il],
                        hparams.n_embd_v_s() * n_seqs,
                        hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
                    )
                )
            );

            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, hparams.wkv_head_size, hparams.n_head_kv()));
            ggml_build_forward_expand(gf, ffn_inp);
            ggml_build_forward_expand(
                gf,
                ggml_cpy(
                    ctx0,
                    wkv_states,
                    ggml_view_1d(
                        ctx0,
                        kv_self.v_l[il],
                        hparams.n_embd_v_s() * n_seqs,
                        hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
                    )
                )
            );

            cb(ffn_inp, "ffn_inp", il);

            // feed-forward network
            cur = llm_build_norm(ctx0, ffn_inp, hparams,
                    model.layers[il].ffn_norm, NULL,
                    LLM_NORM_RMS, cb, il);
            cb(cur, "ffn_norm", il);

            cur = llm_build_ffn(ctx0, lctx, cur,
                    model.layers[il].ffn_up, NULL, NULL,
                    model.layers[il].ffn_gate, NULL, NULL,
                    model.layers[il].ffn_down, NULL, NULL,
                    NULL,
                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
            cb(cur, "ffn_out", il);

            cur = ggml_add(ctx0, cur, ffn_inp);
            cur = lctx.cvec.apply_to(ctx0, cur, il);
            cb(cur, "l_out", il);

            // input for next layer
            inpL = cur;
        }

        cur = inpL;
        struct ggml_tensor * inp_out_ids = build_inp_out_ids();
        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
        cur = ggml_get_rows(ctx0, cur, inp_out_ids);

        cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM_RMS, cb, -1);
        cb(cur, "result_norm", -1);

        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
        cb(cur, "result_output", -1);

        ggml_build_forward_expand(gf, cur);

        return gf;
    }
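    // for reference, the per-head recurrence that llm_build_rwkv6_time_mix feeds into
    // ggml_rwkv_wkv6 is roughly (a sketch; per-element broadcasting details omitted):
    //   a_t   = outer(k_t, v_t)
    //   out_t = r_t * (S_{t-1} + tf * a_t)
    //   S_t   = td_t * S_{t-1} + a_t
    // where S is the per-sequence state carried in wkv_states above.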

    // ref: https://github.com/facebookresearch/chameleon
    // based on the original build_llama() function, changes:
    // * qk-norm

@ -10536,6 +10772,7 @@ static struct ggml_cgraph * llama_build_graph(
                result = llm.build_phi2();
            } break;
        case LLM_ARCH_PHI3:
        case LLM_ARCH_PHIMOE:
            {
                result = llm.build_phi3();
            } break;

@ -10663,6 +10900,10 @@ static struct ggml_cgraph * llama_build_graph(
            {
                result = llm.build_rwkv6();
            } break;
        case LLM_ARCH_RWKV6QWEN2:
            {
                result = llm.build_rwkv6qwen2();
            } break;
        case LLM_ARCH_CHAMELEON:
            {
                result = llm.build_chameleon();

@ -1659,17 +1659,46 @@ struct test_rwkv_wkv6 : public test_case {
    ggml_tensor * build_graph(ggml_context * ctx) override {
        const int64_t n_tokens = n_seq_tokens * n_seqs;
        ggml_tensor * r = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
        ggml_tensor * k = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ head_size, 1, head_count, n_tokens }.data());
        ggml_tensor * v = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
        ggml_tensor * r = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
        ggml_tensor * k = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
        ggml_tensor * v = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
        ggml_tensor * tf = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size, head_count }.data());
        ggml_tensor * td = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
        ggml_tensor * td = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
        ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
        ggml_tensor * out = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, s);
        return out;
    }
};
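// note on the change above: the singleton dimensions are dropped, so r, k, v and td are now
// plain 3-d tensors of shape [head_size, head_count, n_tokens] when handed to ggml_rwkv_wkv6;
// tf and the flattened [head_size * head_size * head_count, n_seqs] state keep their layout.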

// GGML_OP_GATED_LINEAR_ATTN
struct test_gla : public test_case {
    const ggml_type type;

    const int64_t head_count;
    const int64_t head_size;
    const int64_t n_seq_tokens;
    const int64_t n_seqs;

    std::string vars() override {
        return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
    }

    test_gla(ggml_type type = GGML_TYPE_F32,
            int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
        : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        const int64_t n_tokens = n_seq_tokens * n_seqs;
        ggml_tensor * q = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
        ggml_tensor * k = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
        ggml_tensor * v = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
        ggml_tensor * g = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
        ggml_tensor * s = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
        ggml_tensor * out = ggml_gated_linear_attn(ctx, k, v, q, g, s, pow(head_size, -0.5));
        return out;
    }
};
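// a sketch of the recurrence this op is expected to realize per head (not spelled out in this
// diff; the broadcasting details are an assumption): S_t = g_t * S_{t-1} + outer(k_t, v_t) and
// out_t = scale * q_t . S_t, with scale = head_size^-0.5 as passed above and s carrying S_t
// across the tokens of each sequence.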

// GGML_OP_MUL_MAT
struct test_mul_mat : public test_case {
    const ggml_type type_a;

@ -3626,6 +3655,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
    test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4));
    test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4));

    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 1, 1));
    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 1));
    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4));

    for (int i = 1; i < 9; ++i) {
        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
        test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));

@ -78,7 +78,9 @@ int main(void) {
        // ai-sage/GigaChat-20B-A3B-instruct
"{% if messages[0]['role'] == 'system' -%}\n {%- set loop_messages = messages[1:] -%}\n {%- set system_message = bos_token + messages[0]['content'] + additional_special_tokens[1] -%}\n{%- else -%}\n {%- set loop_messages = messages -%}\n {%- set system_message = bos_token + '' -%}\n{%- endif -%}\n{%- for message in loop_messages %}\n {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}\n {% endif %}\n \n {%- if loop.index0 == 0 -%}\n {{ system_message -}}\n {%- endif -%}\n {%- if message['role'] == 'user' -%}\n {{ message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1] -}}\n {{ 'available functions' + additional_special_tokens[0] + additional_special_tokens[2] + additional_special_tokens[3] + additional_special_tokens[1] -}}\n {%- endif -%}\n {%- if message['role'] == 'assistant' -%}\n {{ message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1] -}}\n {%- endif -%}\n {%- if loop.last and add_generation_prompt -%}\n {{ 'assistant' + additional_special_tokens[0] -}}\n {%- endif -%}\n{%- endfor %}",
        // Infinigence/Megrez-3B-Instruct
u8"{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|role_start|>system<|role_end|>你是Megrez-3B-Instruct,将针对用户的问题给出详细的、积极的回答。<|turn_end|>' }}{% endif %}{{ '<|role_start|>' + message['role'] + '<|role_end|>' + message['content'] + '<|turn_end|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|role_start|>assistant<|role_end|>' }}{% endif %}"
u8"{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|role_start|>system<|role_end|>你是Megrez-3B-Instruct,将针对用户的问题给出详细的、积极的回答。<|turn_end|>' }}{% endif %}{{ '<|role_start|>' + message['role'] + '<|role_end|>' + message['content'] + '<|turn_end|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|role_start|>assistant<|role_end|>' }}{% endif %}",
        // phi-4
"{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|im_start|>system<|im_sep|>' + message['content'] + '<|im_end|>'}}{% elif (message['role'] == 'user') %}{{'<|im_start|>user<|im_sep|>' + message['content'] + '<|im_end|><|im_start|>assistant<|im_sep|>'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|im_end|>'}}{% endif %}{% endfor %}",
    };
    std::vector<std::string> expected_output = {
        // teknium/OpenHermes-2.5-Mistral-7B

@ -137,6 +139,8 @@ int main(void) {
"<s>You are a helpful assistant<|message_sep|>user<|role_sep|>Hello<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>Hi there<|message_sep|>user<|role_sep|>Who are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|> I am an assistant <|message_sep|>user<|role_sep|>Another question<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>",
        // Infinigence/Megrez-3B-Instruct
"<|role_start|>system<|role_end|>You are a helpful assistant<|turn_end|><|role_start|>user<|role_end|>Hello<|turn_end|><|role_start|>assistant<|role_end|>Hi there<|turn_end|><|role_start|>user<|role_end|>Who are you<|turn_end|><|role_start|>assistant<|role_end|> I am an assistant <|turn_end|><|role_start|>user<|role_end|>Another question<|turn_end|><|role_start|>assistant<|role_end|>",
        // phi-4
"<|im_start|>system<|im_sep|>You are a helpful assistant<|im_end|><|im_start|>user<|im_sep|>Hello<|im_end|><|im_start|>assistant<|im_sep|>Hi there<|im_end|><|im_start|>user<|im_sep|>Who are you<|im_end|><|im_start|>assistant<|im_sep|> I am an assistant <|im_end|><|im_start|>user<|im_sep|>Another question<|im_end|><|im_start|>assistant<|im_sep|>",
    };
    std::vector<char> formatted_chat(1024);
    int32_t res;