update: support 4 models

This commit is contained in:
Trần Đức Nam 2023-12-18 15:49:15 +07:00
parent e851199ad3
commit eb9a790c11
7 changed files with 1257 additions and 92 deletions

View file

@ -1,5 +1,15 @@
# AWQ: Activation-aware Weight Quantization for LLM - version apply to llamacpp # AWQ: Activation-aware Weight Quantization for LLM - version apply to llamacpp
[[Paper](https://arxiv.org/abs/2306.00978)][[Original Repo](https://github.com/mit-han-lab/llm-awq)][[easy-to-use Repo](https://github.com/casper-hansen/AutoAWQ)] [[Paper](https://arxiv.org/abs/2306.00978)][[Original Repo](https://github.com/mit-han-lab/llm-awq)][[Easy-to-use Repo](https://github.com/casper-hansen/AutoAWQ)]
**Supported models:**
- [X] LLaMA 🦙
- [x] LLaMA 2 🦙🦙
- [X] MPT
- [X] Mistral AI v0.1
- [ ] Bloom
- [ ] Mixtral MoE
## Contents ## Contents
@ -22,37 +32,26 @@ git clone https://huggingface.co/datasets/mit-han-lab/awq-model-zoo awq_cache
## Convert ## Convert
Example for llama 7b model Example for llama 7b model
```bash ```bash
python convert-awq-hf-to-gguf.py models/llama-7b/ --awq-path awq_cache/llama-7b-w4-g128.pt --tmp-model-path models/llama-7b-scales --outfile models/llama_7b_fp16.gguf # For LLaMA 7B and LLaMA 2 7B models
python examples/awqutils/convert-awq.py models/llama-7b/ --awq-path awq_cache/llama-7b-w4-g128.pt --tmp-model-path models/llama-7b-scales --outfile models/llama_7b_fp16.gguf
``` ```
## Quantize ## Quantize
```bash ```bash
./build/bin/quantize models/llama_7b_fp16.gguf models/llama_7b_q4_0.gguf q4_0 ./quantize models/llama_7b_fp16.gguf models/llama_7b_q4_0.gguf q4_0
``` ```
## Benchmark ## Benchmark
The perplexity measurements in table above are done against the `wikitext2` test dataset (https://paperswithcode.com/dataset/wikitext-2), with context length of 512. The perplexity measurements in table above are done against the `wikitext2` test dataset (https://paperswithcode.com/dataset/wikitext-2), with context length of 512.
```bash ```bash
./build/bin/perplexity -m models/llama_7b_q4_0.gguf -f datasets/wikitext-2-raw/wiki.test.raw ./perplexity -m models/llama_7b_q4_0.gguf -f datasets/wikitext-2-raw/wiki.test.raw
``` ```
## Results ## Results
Results are run on OpenBLAS (CPU) and CuBLAS (GPU) for fair comparison
We use three types of llamacpp quantization methods to work with our version, including q4, q4_1, and q2_k
### Llama 7B ### Llama 7B (Build with OpenBLAS)
Build with OpenBLAS
#### Memory/Disk Requirements
| Model | Original | AWQ-4bit |
|------:|--------------:|--------------:|
| fp16 | 12.853 GB | 12.853 GB |
| q4_0 | 3.647 GB | 3.647 GB |
| q4_1 | 4.041 GB | 4.041 GB |
| q2_k | 2.649 GB | 2.649 GB |
#### Quantization
Several quantization methods are supported. They differ in the resulting model disk size and inference speed.
| Model | Measure | F16 | Q4_0 | Q4_1 | Q2_K | | Model | Measure | F16 | Q4_0 | Q4_1 | Q2_K |
|-----------:|--------------|-------:|-------:|-------:|-------:| |-----------:|--------------|-------:|-------:|-------:|-------:|
@ -68,21 +67,7 @@ Several quantization methods are supported. They differ in the resulting model d
|AWQ-LLama 7B| bits/weight | 16.0 | 4.5 | 5.0 | 2.6 | |AWQ-LLama 7B| bits/weight | 16.0 | 4.5 | 5.0 | 2.6 |
### Llama2 7B ### Llama2 7B (Build with CuBLAS)
Build with CuBLAS
#### Memory/Disk Requirements
| Model | Original | AWQ-4bit |
|------:|--------------:|--------------:|
| fp16 | 12.853 GB | 12.853 GB |
| q4_0 | 3.647 GB | 3.647 GB |
| q4_1 | 4.041 GB | 4.041 GB |
| q2_k | 2.649 GB | 2.649 GB |
#### Quantization
Several quantization methods are supported. They differ in the resulting model disk size and inference speed.
| Model | Measure | F16 | Q4_0 | Q4_1 | Q2_K | | Model | Measure | F16 | Q4_0 | Q4_1 | Q2_K |
|------------:|--------------|-------:|-------:|-------:|-------:| |------------:|--------------|-------:|-------:|-------:|-------:|
@ -98,21 +83,7 @@ Several quantization methods are supported. They differ in the resulting model d
|AWQ-LLama2 7B| bits/weight | 16.0 | 4.5 | 5.0 | 2.6 | |AWQ-LLama2 7B| bits/weight | 16.0 | 4.5 | 5.0 | 2.6 |
### Mistral 7B v0.1 ### Mistral 7B v0.1 (Build with CuBLAS)
Build with CuBLAS
#### Memory/Disk Requirements
| Model | Original | AWQ-4bit |
|------:|--------------:|--------------:|
| fp16 | 12.853 GB | 12.853 GB |
| q4_0 | 3.647 GB | 3.647 GB |
| q4_1 | 4.041 GB | 4.041 GB |
| q2_k | 2.649 GB | 2.649 GB |
#### Quantization
Several quantization methods are supported. They differ in the resulting model disk size and inference speed.
| Model | Measure | F16 | Q4_0 | Q4_1 | Q2_K | | Model | Measure | F16 | Q4_0 | Q4_1 | Q2_K |
|-------------:|--------------|-------:|-------:|-------:|-------:| |-------------:|--------------|-------:|-------:|-------:|-------:|
@ -126,3 +97,18 @@ Several quantization methods are supported. They differ in the resulting model d
|AWQ-Mistral 7B| ms/tok @ 4th | xxx| xxx | xxx | xxx | |AWQ-Mistral 7B| ms/tok @ 4th | xxx| xxx | xxx | xxx |
|AWQ-Mistral 7B| ms/tok @ 8th | xxx| xx | xx | xx | |AWQ-Mistral 7B| ms/tok @ 8th | xxx| xx | xx | xx |
|AWQ-Mistral 7B| bits/weight | 16.0 | 4.5 | 5.0 | 2.6 | |AWQ-Mistral 7B| bits/weight | 16.0 | 4.5 | 5.0 | 2.6 |
### MPT 7B (Build with OpenBLAS)
| Model | Measure | F16 | Q4_0 | Q4_1 | Q2_K |
|-------------:|--------------|-------:|-------:|-------:|-------:|
|MPT 7B | perplexity | xxxxxx | xxxxxx | xxxxxx | xxxxxx |
|MPT 7B | file size | 12.9G | 3.5G | 3.9G | 2.7G |
|MPT 7B | ms/tok @ 4th | xxx | xx | xx | xx |
|MPT 7B | ms/tok @ 8th | xxx | xx | xx | xx |
|MPT 7B | bits/weight | 16.0 | 4.5 | 5.0 | 2.6 |
|AWQ-MPT 7B| perplexity | xxxxxx | xxxxxx | xxxxx | xxxxxx |
|AWQ-MPT 7B| file size | 12.9G | 3.5G | 3.9G | 2.7G |
|AWQ-MPT 7B| ms/tok @ 4th | xxx| xxx | xxx | xxx |
|AWQ-MPT 7B| ms/tok @ 8th | xxx| xx | xx | xx |
|AWQ-MPT 7B| bits/weight | 16.0 | 4.5 | 5.0 | 2.6 |

View file

@ -1,14 +1,32 @@
"""
Original code from:
1. https://github.com/casper-hansen/AutoAWQ
2. https://github.com/mit-han-lab/llm-awq
"""
import os import os
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoConfig from transformers import AutoModelForCausalLM, AutoConfig
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu from transformers.models.bloom.modeling_bloom import BloomGelu
from transformers.models.opt.modeling_opt import OPTDecoderLayer from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
from transformers.activations import GELUActivation from transformers.activations import GELUActivation
class ScaledActivation(nn.Module): class ScaledActivation(nn.Module):
"""
ScaledActivation wraps an existing activation function and divides its
output by per-feature scale factors.
Args:
module (nn.Module): The activation function to be scaled.
scales (torch.Tensor): A tensor of size (num_features,) containing the initial
scale factors for each feature.
Returns:
torch.Tensor: The scaled output of the activation function.
"""
def __init__(self, module, scales): def __init__(self, module, scales):
super().__init__() super().__init__()
self.act = module self.act = module
@ -17,8 +35,18 @@ class ScaledActivation(nn.Module):
def forward(self, x): def forward(self, x):
return self.act(x) / self.scales.view(1, 1, -1).to(x.device) return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
def set_op_by_name(layer, name, new_module): def set_op_by_name(layer, name, new_module):
levels = name.split('.') """
Set the new module for given module's name.
Args:
layer (nn.Module): The layer in which to replace the submodule.
name (str): The path to the submodule to be replaced, using dot notation
to access nested modules.
new_module (nn.Module): The new module to replace the existing one.
"""
levels = name.split(".")
if len(levels) > 1: if len(levels) > 1:
mod_ = layer mod_ = layer
for l_idx in range(len(levels) - 1): for l_idx in range(len(levels) - 1):
@ -30,22 +58,45 @@ def set_op_by_name(layer, name, new_module):
else: else:
setattr(layer, name, new_module) setattr(layer, name, new_module)
def get_op_by_name(module, op_name): def get_op_by_name(module, op_name):
# get the op by its name relative to the module """
Retrieves a submodule within a given layer based on its name.
Args:
module (nn.Module): The layer containing the submodule to find.
op_name (str): The name of the submodule to search for, using dot notation for nested modules.
Returns:
nn.Module: The requested submodule found within the given layer.
Raises:
ValueError: If the specified submodule cannot be found within the layer.
"""
for name, m in module.named_modules(): for name, m in module.named_modules():
if name == op_name: if name == op_name:
return m return m
raise ValueError(f"Cannot find op {op_name} in module {module}") raise ValueError(f"Cannot find op {op_name} in module {module}")
@torch.no_grad() @torch.no_grad()
def scale_ln_fcs(ln, fcs, scales): def scale_ln_fcs(ln, fcs, scales):
"""
Scales the weights of a LayerNorm and a list of fully-connected layers proportionally.
Args:
ln (nn.LayerNorm): The LayerNorm module to be scaled.
fcs (List[nn.Linear]): A list of fully-connected layers to be scaled.
scales (torch.Tensor): A 1D tensor of size (num_features,) containing the scaling factors for each feature.
"""
if not isinstance(fcs, list): if not isinstance(fcs, list):
fcs = [fcs] fcs = [fcs]
scales = scales.to(ln.weight.device) scales = scales.to(ln.weight.device)
ln.weight.div_(scales) ln.weight.div_(scales)
if hasattr(ln, 'bias') and ln.bias is not None: if hasattr(ln, "bias") and ln.bias is not None:
ln.bias.div_(scales) ln.bias.div_(scales)
for fc in fcs: for fc in fcs:
@ -60,13 +111,19 @@ def scale_ln_fcs(ln, fcs, scales):
@torch.no_grad() @torch.no_grad()
def scale_fc_fc(fc1, fc2, scales): def scale_fc_fc(fc1, fc2, scales):
"""
Scales the weights of two fully-connected layers in a specific pattern.
Args:
fc1 (nn.Linear): The first fully-connected layer to be scaled.
fc2 (nn.Linear): The second fully-connected layer to be scaled.
scales (torch.Tensor): A 1D tensor of size (num_features,) containing the scaling factors for each feature.
"""
assert isinstance(fc1, nn.Linear) assert isinstance(fc1, nn.Linear)
assert isinstance(fc2, nn.Linear) assert isinstance(fc2, nn.Linear)
# assert fc1.out_features == fc2.in_features
scales = scales.to(fc1.weight.device) scales = scales.to(fc1.weight.device)
# fc1.weight.div_(scales.view(-1, 1))
fc1.weight[-scales.size(0) :].div_(scales.view(-1, 1)) fc1.weight[-scales.size(0) :].div_(scales.view(-1, 1))
if fc1.bias is not None: if fc1.bias is not None:
fc1.bias.div_(scales.view(-1)) fc1.bias.div_(scales.view(-1))
@ -81,6 +138,18 @@ def scale_fc_fc(fc1, fc2, scales):
@torch.no_grad() @torch.no_grad()
def scale_gelu_fc(gelu, fc, scales): def scale_gelu_fc(gelu, fc, scales):
"""
Scales the weight of a GELU activation and a fully-connected layer proportionally.
Args:
gelu (Union[nn.GELU, BloomGelu, GELUActivation]): The GELU activation module to be scaled.
fc (nn.Linear): The fully-connected layer to be scaled.
scales (torch.Tensor): A 1D tensor of size (num_features,) containing the scaling factors for each feature.
Raises:
AssertionError: If the `gelu` module is not of type `nn.GELU`, `BloomGelu`, or `GELUActivation`.
AssertionError: If the `fc` module is not of type `nn.Linear`.
"""
assert isinstance(gelu, (nn.GELU, BloomGelu, GELUActivation)) assert isinstance(gelu, (nn.GELU, BloomGelu, GELUActivation))
assert isinstance(fc, nn.Linear) assert isinstance(fc, nn.Linear)
@ -91,6 +160,20 @@ def scale_gelu_fc(gelu, fc, scales):
def apply_scale(module, scales_list, input_feat_dict=None): def apply_scale(module, scales_list, input_feat_dict=None):
"""
Applies different scaling strategies to layers based on their type and hierarchy within a given module.
Args:
module (nn.Module): The module containing the layers to be scaled.
scales_list (List[Tuple[str, List[str], torch.Tensor]]): A list of tuples containing:
* prev_op_name (str): The name of the preceding operation or module, relative to which the layers to be
scaled are located.
* layer_names (List[str]): A list of names of the layers to be scaled, relative to the preceding operation.
* scales (torch.Tensor): A 1D tensor of size (num_features,) containing the scaling factors for each feature.
input_feat_dict (Optional[Dict[str, torch.Tensor]]): A dictionary mapping layer names to their corresponding
input features (optional). If provided, the input features are also
scaled proportionally after scaling the layer weights.
"""
for prev_op_name, layer_names, scales in scales_list: for prev_op_name, layer_names, scales in scales_list:
prev_op = get_op_by_name(module, prev_op_name) prev_op = get_op_by_name(module, prev_op_name)
layers = [get_op_by_name(module, name) for name in layer_names] layers = [get_op_by_name(module, name) for name in layer_names]
@ -103,15 +186,17 @@ def apply_scale(module, scales_list, input_feat_dict=None):
if isinstance(prev_op, nn.Linear): if isinstance(prev_op, nn.Linear):
assert len(layers) == 1 assert len(layers) == 1
scale_fc_fc(prev_op, layers[0], scales) scale_fc_fc(prev_op, layers[0], scales)
elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm)): elif (
isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm))
or "rmsnorm" in str(prev_op.__class__).lower()
):
scale_ln_fcs(prev_op, layers, scales) scale_ln_fcs(prev_op, layers, scales)
elif isinstance(prev_op, (nn.GELU, BloomGelu, GELUActivation)): elif isinstance(prev_op, (nn.GELU, BloomGelu, GELUActivation)):
new_module = ScaledActivation(prev_op, scales) new_module = ScaledActivation(prev_op, scales)
set_op_by_name(module, prev_op_name, new_module) set_op_by_name(module, prev_op_name, new_module)
scale_gelu_fc(prev_op, layers[0], scales) scale_gelu_fc(prev_op, layers[0], scales)
else: else:
raise NotImplementedError( raise NotImplementedError(f"prev_op {type(prev_op)} not supported yet!")
f"prev_op {type(prev_op)} not supported yet!")
# apply the scaling to input feat if given; prepare it for clipping # apply the scaling to input feat if given; prepare it for clipping
if input_feat_dict is not None: if input_feat_dict is not None:
@ -124,8 +209,18 @@ def apply_scale(module, scales_list, input_feat_dict=None):
layer.cpu() layer.cpu()
scales.cpu() scales.cpu()
@torch.no_grad() @torch.no_grad()
def apply_clip(module, clip_list): def apply_clip(module, clip_list):
"""
Applies element-wise clipping to the weight of a specific layer within a given module.
Args:
module (nn.Module): The module containing the layer to be clipped.
clip_list (List[Tuple[str, torch.Tensor]]): A list of tuples containing:
* name (str): The name of the layer to be clipped, relative to the root of the module.
* max_val (torch.Tensor): A 1D or 2D tensor defining the upper bound for each element of the layer's weight.
"""
for name, max_val in clip_list: for name, max_val in clip_list:
layer = get_op_by_name(module, name) layer = get_op_by_name(module, name)
layer.cuda() layer.cuda()
@ -136,25 +231,23 @@ def apply_clip(module, clip_list):
layer.weight.data = layer.weight.data.reshape(org_shape) layer.weight.data = layer.weight.data.reshape(org_shape)
layer.cpu() layer.cpu()
def apply_awq(model, awq_results):
apply_scale(model, awq_results["scale"])
apply_clip(model, awq_results["clip"])
def add_scale_weights(model, model_path, scale_path, tmp_path): def add_scale_weights(model_path, scale_path, tmp_path):
print("Loading pre-computed AWQ results from", str(scale_path)) """
awq_results = torch.load(str(scale_path), map_location="cpu") Adds pre-computed Activation-aware Weight Quantization (AWQ) results to a model, including scaling factors and clipping bounds.
apply_awq(model, awq_results)
model.save_pretrained(str(tmp_path))
os.system(f"cp {str(model_path)}/tokenizer* {str(tmp_path)}")
return True
Args:
if __name__ == "__main__": model_path (str): Path to the pre-trained model to be equipped with AWQ.
model_path = "/data/namtd12/llm_models/Llama-2-7b-hf" scale_path (str): Path to the AWQ scale factors (.pt file).
scale_path = "awq_cache_pretrained/llama-2-7b-chat-w4-g128.pt" tmp_path (str): Path to the temporary directory where the equipped model will be saved.
tmp_path = "debug" """
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
model_path, config=config, trust_remote_code=True) model_path, config=config, trust_remote_code=True
)
model.eval() model.eval()
add_scale_weights(model, scale_path, tmp_path) awq_results = torch.load(str(scale_path), map_location="cpu")
apply_scale(model, awq_results["scale"])
apply_clip(model, awq_results["clip"])
model.save_pretrained(str(tmp_path))
os.system(f"cp {str(model_path)}/tokenizer* {str(tmp_path)}")

File diff suppressed because it is too large Load diff

View file

@ -25,7 +25,7 @@ from typing import IO, TYPE_CHECKING, Any, Callable, Iterable, Literal, TypeVar
import numpy as np import numpy as np
from sentencepiece import SentencePieceProcessor from sentencepiece import SentencePieceProcessor
from awqutils.apply_awq import add_scale_weights from apply_awq import add_scale_weights
from transformers import AutoModelForCausalLM, AutoConfig from transformers import AutoModelForCausalLM, AutoConfig
@ -1158,11 +1158,7 @@ def main(args_in: list[str] | None = None) -> None:
args = parser.parse_args(args_in) args = parser.parse_args(args_in)
if args.awq_path and args.tmp_model_path: if args.awq_path and args.tmp_model_path:
config = AutoConfig.from_pretrained(args.model, trust_remote_code=True) add_scale_weights(str(args.model), str(args.awq_path), str(args.tmp_model_path))
model = AutoModelForCausalLM.from_pretrained(
args.model, config=config, trust_remote_code=True)
model.eval()
add_scale_weights(model, args.model, args.awq_path, args.tmp_model_path)
args.model = args.tmp_model_path args.model = args.tmp_model_path
if args.dump_single: if args.dump_single:

View file

@ -114,6 +114,7 @@ class MODEL_TENSOR(IntEnum):
FFN_GATE = auto() FFN_GATE = auto()
FFN_DOWN = auto() FFN_DOWN = auto()
FFN_UP = auto() FFN_UP = auto()
FFN_ACT = auto()
FFN_NORM = auto() FFN_NORM = auto()
ATTN_Q_NORM = auto() ATTN_Q_NORM = auto()
ATTN_K_NORM = auto() ATTN_K_NORM = auto()
@ -158,6 +159,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate", MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
MODEL_TENSOR.FFN_ACT: "blk.{bid}.ffn",
} }
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@ -251,6 +253,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_NORM, MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP, MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_ACT,
], ],
MODEL_ARCH.GPTJ: [ MODEL_ARCH.GPTJ: [
MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD,

View file

@ -164,6 +164,11 @@ class TensorNameMap:
"transformer.h.{bid}.mlp.w1", # qwen "transformer.h.{bid}.mlp.w1", # qwen
), ),
# Awq-activation gate
MODEL_TENSOR.FFN_ACT: (
"transformer.blocks.{bid}.ffn.act", # mpt
),
# Feed-forward gate # Feed-forward gate
MODEL_TENSOR.FFN_GATE: ( MODEL_TENSOR.FFN_GATE: (
"model.layers.{bid}.mlp.gate_proj", # llama-hf refact "model.layers.{bid}.mlp.gate_proj", # llama-hf refact

View file

@ -341,6 +341,7 @@ enum llm_tensor {
LLM_TENSOR_FFN_GATE, LLM_TENSOR_FFN_GATE,
LLM_TENSOR_FFN_DOWN, LLM_TENSOR_FFN_DOWN,
LLM_TENSOR_FFN_UP, LLM_TENSOR_FFN_UP,
LLM_TENSOR_FFN_ACT,
LLM_TENSOR_FFN_NORM, LLM_TENSOR_FFN_NORM,
LLM_TENSOR_ATTN_Q_NORM, LLM_TENSOR_ATTN_Q_NORM,
LLM_TENSOR_ATTN_K_NORM, LLM_TENSOR_ATTN_K_NORM,
@ -453,6 +454,7 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{LLM_TENSOR_FFN_ACT, "blk.%d.ffn.act"},
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
}, },
}, },
@ -1271,6 +1273,7 @@ struct llama_layer {
// ff bias // ff bias
struct ggml_tensor * ffn_down_b; // b2 struct ggml_tensor * ffn_down_b; // b2
struct ggml_tensor * ffn_up_b; // b3 struct ggml_tensor * ffn_up_b; // b3
struct ggml_tensor *ffn_act;
}; };
struct llama_kv_cell { struct llama_kv_cell {
@ -3420,6 +3423,7 @@ static void llm_load_tensors(
layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split);
layer.ffn_act = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, backend);
layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
if (backend == GGML_BACKEND_GPU) { if (backend == GGML_BACKEND_GPU) {
@ -3429,6 +3433,7 @@ static void llm_load_tensors(
ggml_nbytes(layer.wo) + ggml_nbytes(layer.wo) +
ggml_nbytes(layer.ffn_norm) + ggml_nbytes(layer.ffn_norm) +
ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_down) +
ggml_nbytes(layer.ffn_act) +
ggml_nbytes(layer.ffn_up); ggml_nbytes(layer.ffn_up);
} }
} }
@ -3672,6 +3677,7 @@ enum llm_rope_type {
enum llm_ffn_op_type { enum llm_ffn_op_type {
LLM_FFN_SILU, LLM_FFN_SILU,
LLM_FFN_GELU, LLM_FFN_GELU,
LLM_FFN_GELU_ACT,
LLM_FFN_RELU, LLM_FFN_RELU,
LLM_FFN_RELU_SQR, LLM_FFN_RELU_SQR,
}; };
@ -3839,6 +3845,7 @@ static struct ggml_tensor * llm_build_ffn(
struct ggml_tensor * gate_b, struct ggml_tensor * gate_b,
struct ggml_tensor * down, struct ggml_tensor * down,
struct ggml_tensor * down_b, struct ggml_tensor * down_b,
struct ggml_tensor *act_scales,
llm_ffn_op_type type_op, llm_ffn_op_type type_op,
llm_ffn_gate_type type_gate, llm_ffn_gate_type type_gate,
const llm_build_cb & cb, const llm_build_cb & cb,
@ -3889,6 +3896,16 @@ static struct ggml_tensor * llm_build_ffn(
cur = ggml_relu(ctx, cur); cur = ggml_relu(ctx, cur);
cb(cur, "ffn_relu", il); cb(cur, "ffn_relu", il);
} break; } break;
case LLM_FFN_GELU_ACT:
{
cur = ggml_gelu(ctx, cur);
cb(cur, "ffn_relu", il);
struct ggml_tensor *repeat = ggml_repeat(ctx, act_scales, cur);
cb(repeat, "ffn_repeat(scales)", il);
cur = ggml_div(ctx, cur, repeat);
cb(cur, "ffn_div(gelu)", il);
}
break;
case LLM_FFN_RELU_SQR: case LLM_FFN_RELU_SQR:
{ {
cur = ggml_relu(ctx, cur); cur = ggml_relu(ctx, cur);
@ -4194,6 +4211,7 @@ struct llm_build_context {
model.layers[il].ffn_up, NULL, model.layers[il].ffn_up, NULL,
model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate, NULL,
model.layers[il].ffn_down, NULL, model.layers[il].ffn_down, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il); LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il); cb(cur, "ffn_out", il);
} }
@ -4314,6 +4332,7 @@ struct llm_build_context {
model.layers[il].ffn_up, NULL, model.layers[il].ffn_up, NULL,
model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate, NULL,
model.layers[il].ffn_down, NULL, model.layers[il].ffn_down, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il); LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il); cb(cur, "ffn_out", il);
} }
@ -4432,6 +4451,7 @@ struct llm_build_context {
model.layers[il].ffn_up, NULL, model.layers[il].ffn_up, NULL,
NULL, NULL, NULL, NULL,
model.layers[il].ffn_down, NULL, model.layers[il].ffn_down, NULL,
NULL,
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
cb(cur, "ffn_out", il); cb(cur, "ffn_out", il);
} }
@ -4540,6 +4560,7 @@ struct llm_build_context {
model.layers[il].ffn_up, model.layers[il].ffn_up_b, model.layers[il].ffn_up, model.layers[il].ffn_up_b,
NULL, NULL, NULL, NULL,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, model.layers[il].ffn_down, model.layers[il].ffn_down_b,
NULL,
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
cb(cur, "ffn_out", il); cb(cur, "ffn_out", il);
} }
@ -4748,6 +4769,7 @@ struct llm_build_context {
model.layers[il].ffn_up, model.layers[il].ffn_up_b, model.layers[il].ffn_up, model.layers[il].ffn_up_b,
NULL, NULL, NULL, NULL,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, model.layers[il].ffn_down, model.layers[il].ffn_down_b,
NULL,
LLM_FFN_RELU_SQR, LLM_FFN_SEQ, cb, il); LLM_FFN_RELU_SQR, LLM_FFN_SEQ, cb, il);
cb(cur, "ffn_out", il); cb(cur, "ffn_out", il);
} }
@ -4838,6 +4860,7 @@ struct llm_build_context {
model.layers[il].ffn_up, NULL, model.layers[il].ffn_up, NULL,
model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate, NULL,
model.layers[il].ffn_down, NULL, model.layers[il].ffn_down, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il); LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il); cb(cur, "ffn_out", il);
} }
@ -4937,6 +4960,7 @@ struct llm_build_context {
model.layers[il].ffn_up, model.layers[il].ffn_up_b, model.layers[il].ffn_up, model.layers[il].ffn_up_b,
NULL, NULL, NULL, NULL,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, model.layers[il].ffn_down, model.layers[il].ffn_down_b,
NULL,
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
cb(cur, "ffn_out", il); cb(cur, "ffn_out", il);
} }
@ -5031,7 +5055,8 @@ struct llm_build_context {
model.layers[il].ffn_up, NULL, model.layers[il].ffn_up, NULL,
NULL, NULL, NULL, NULL,
model.layers[il].ffn_down, NULL, model.layers[il].ffn_down, NULL,
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); model.layers[il].ffn_act,
LLM_FFN_GELU_ACT, LLM_FFN_SEQ, cb, il);
cb(cur, "ffn_out", il); cb(cur, "ffn_out", il);
} }
@ -5143,6 +5168,7 @@ struct llm_build_context {
model.layers[il].ffn_up, NULL, model.layers[il].ffn_up, NULL,
model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate, NULL,
model.layers[il].ffn_down, NULL, model.layers[il].ffn_down, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il); LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il); cb(cur, "ffn_out", il);
} }
@ -5259,6 +5285,7 @@ struct llm_build_context {
model.layers[il].ffn_up, NULL, model.layers[il].ffn_up, NULL,
model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate, NULL,
model.layers[il].ffn_down, NULL, model.layers[il].ffn_down, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il); LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il); cb(cur, "ffn_out", il);
} }
@ -5441,6 +5468,7 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
{ "ffn_gate", OFFLOAD_FUNC }, { "ffn_gate", OFFLOAD_FUNC },
{ "ffn_gate_b", OFFLOAD_FUNC }, { "ffn_gate_b", OFFLOAD_FUNC },
{ "ffn_gate_par", OFFLOAD_FUNC }, { "ffn_gate_par", OFFLOAD_FUNC },
{"ffn_act", OFFLOAD_FUNC },
{ "ffn_down", OFFLOAD_FUNC }, { "ffn_down", OFFLOAD_FUNC },
{ "ffn_down_b", OFFLOAD_FUNC }, { "ffn_down_b", OFFLOAD_FUNC },
{ "ffn_out", OFFLOAD_FUNC }, { "ffn_out", OFFLOAD_FUNC },