update: ready for PR

This commit is contained in:
Trần Đức Nam 2023-12-19 11:19:01 +07:00
parent 576d28b7f7
commit 4cad8d7d7a
8 changed files with 165 additions and 2355 deletions

View file

@ -10,6 +10,13 @@
- [ ] Bloom - [ ] Bloom
- [ ] Mixtral MoE - [ ] Mixtral MoE
**TODO:**
- [ ] Add OPT model
- [ ] Add Bloom model
- [ ] Add Mixtral MoE
- [ ] Update the implementation to work with both MPT and MPT-AWQ models
- [ ] Support w3, w2
## Contents ## Contents
@ -33,7 +40,7 @@ git clone https://huggingface.co/datasets/mit-han-lab/awq-model-zoo awq_cache
Example for llama 7b model Example for llama 7b model
```bash ```bash
# For llama7b and llama27b models # For llama7b and llama27b models
python examples/awqutils/convert-awq.py models/llama-7b/ --awq-path awq_cache/llama-7b-w4-g128.pt --tmp-model-path models/llama-7b-scales --outfile models/llama_7b_fp16.gguf python convert.py models/llama-7b/ --awq-path awq_cache/llama-7b-w4-g128.pt --outfile models/llama_7b_fp16.gguf
``` ```
## Quantize ## Quantize
@ -57,13 +64,9 @@ We use three types of llamacpp quantization methods to work with our version, in
|-----------:|--------------|-------:|-------:|-------:|-------:| |-----------:|--------------|-------:|-------:|-------:|-------:|
|Llama 7B | perplexity | 5.9066 | 6.1214 | 6.0643 | 6.5808 | |Llama 7B | perplexity | 5.9066 | 6.1214 | 6.0643 | 6.5808 |
|Llama 7B | file size | 12.9G | 3.5G | 3.9G | 2.7G | |Llama 7B | file size | 12.9G | 3.5G | 3.9G | 2.7G |
|Llama 7B | ms/tok @ 4th | xxx | xx | xx | xx |
|Llama 7B | ms/tok @ 8th | xxx | xx | xx | xx |
|Llama 7B | bits/weight | 16.0 | 4.5 | 5.0 | 2.6 | |Llama 7B | bits/weight | 16.0 | 4.5 | 5.0 | 2.6 |
|AWQ-LLama 7B| perplexity | 5.9175 | 6.0252 | 5.9987 | 6.3692 | |AWQ-LLama 7B| perplexity | 5.9175 | 6.0252 | 5.9987 | 6.3692 |
|AWQ-LLama 7B| file size | 12.9G | 3.5G | 3.9G | 2.7G | |AWQ-LLama 7B| file size | 12.9G | 3.5G | 3.9G | 2.7G |
|AWQ-LLama 7B| ms/tok @ 4th | xxx| xxx | xxx | xxx |
|AWQ-LLama 7B| ms/tok @ 8th | xxx| xx | xx | xx |
|AWQ-LLama 7B| bits/weight | 16.0 | 4.5 | 5.0 | 2.6 | |AWQ-LLama 7B| bits/weight | 16.0 | 4.5 | 5.0 | 2.6 |
@ -73,13 +76,9 @@ We use three types of llamacpp quantization methods to work with our version, in
|------------:|--------------|-------:|-------:|-------:|-------:| |------------:|--------------|-------:|-------:|-------:|-------:|
|Llama2 7B | perplexity | 5.8664 | 6.0260 | 6.0656 | 6.4496 | |Llama2 7B | perplexity | 5.8664 | 6.0260 | 6.0656 | 6.4496 |
|Llama2 7B | file size | 12.9G | 3.5G | 3.9G | 2.7G | |Llama2 7B | file size | 12.9G | 3.5G | 3.9G | 2.7G |
|Llama2 7B | ms/tok @ 4th | xxx | xx | xx | xx |
|Llama2 7B | ms/tok @ 8th | xxx | xx | xx | xx |
|Llama2 7B | bits/weight | 16.0 | 4.5 | 5.0 | 2.6 | |Llama2 7B | bits/weight | 16.0 | 4.5 | 5.0 | 2.6 |
|AWQ-LLama2 7B| perplexity | 5.8801 | 6.0054 | 5.9849 | 6.3650 | |AWQ-LLama2 7B| perplexity | 5.8801 | 6.0054 | 5.9849 | 6.3650 |
|AWQ-LLama2 7B| file size | 12.9G | 3.5G | 3.9G | 2.7G | |AWQ-LLama2 7B| file size | 12.9G | 3.5G | 3.9G | 2.7G |
|AWQ-LLama2 7B| ms/tok @ 4th | xxx| xxx | xxx | xxx |
|AWQ-LLama2 7B| ms/tok @ 8th | xxx| xx | xx | xx |
|AWQ-LLama2 7B| bits/weight | 16.0 | 4.5 | 5.0 | 2.6 | |AWQ-LLama2 7B| bits/weight | 16.0 | 4.5 | 5.0 | 2.6 |
@ -88,27 +87,19 @@ We use three types of llamacpp quantization methods to work with our version, in
| Model | Measure | F16 | Q4_0 | Q4_1 | Q2_K | | Model | Measure | F16 | Q4_0 | Q4_1 | Q2_K |
|-------------:|--------------|-------:|-------:|-------:|-------:| |-------------:|--------------|-------:|-------:|-------:|-------:|
|Mistral 7B | perplexity | 5.6931 | 5.8202 | 5.8268 | 6.1645 | |Mistral 7B | perplexity | 5.6931 | 5.8202 | 5.8268 | 6.1645 |
|Mistral 7B | file size | 12.9G | 3.5G | 3.9G | 2.7G | |Mistral 7B | file size | 14.5G | 4.1G | 4.5G | 3.1G |
|Mistral 7B | ms/tok @ 4th | xxx | xx | xx | xx |
|Mistral 7B | ms/tok @ 8th | xxx | xx | xx | xx |
|Mistral 7B | bits/weight | 16.0 | 4.5 | 5.0 | 2.6 | |Mistral 7B | bits/weight | 16.0 | 4.5 | 5.0 | 2.6 |
|AWQ-Mistral 7B| perplexity | 5.6934 | 5.8020 | 5.7691 | 6.0426 | |AWQ-Mistral 7B| perplexity | 5.6934 | 5.8020 | 5.7691 | 6.0426 |
|AWQ-Mistral 7B| file size | 12.9G | 3.5G | 3.9G | 2.7G | |AWQ-Mistral 7B| file size | 14.5G | 4.1G | 4.5G | 3.1G |
|AWQ-Mistral 7B| ms/tok @ 4th | xxx| xxx | xxx | xxx |
|AWQ-Mistral 7B| ms/tok @ 8th | xxx| xx | xx | xx |
|AWQ-Mistral 7B| bits/weight | 16.0 | 4.5 | 5.0 | 2.6 | |AWQ-Mistral 7B| bits/weight | 16.0 | 4.5 | 5.0 | 2.6 |
### MPT 7B (Build with OpenBLAS) ### MPT 7B (Build with OpenBLAS)
| Model | Measure | F16 | Q4_0 | Q4_1 | Q2_K | | Model | Measure | F16 | Q4_0 | Q4_1 | Q2_K |
|-------------:|--------------|-------:|-------:|-------:|-------:| |---------:|--------------|-------:|-------:|-------:|--------:|
|Mistral 7B | perplexity | xxxxxx | xxxxxx | xxxxxx | xxxxxx | |MPT 7B | perplexity | 8.4369 | 8.7956 | 8.6265 | 11.4913 |
|Mistral 7B | file size | 12.9G | 3.5G | 3.9G | 2.7G | |MPT 7B | file size | 13.7G | 3.9G | 4.3G | 2.8G |
|Mistral 7B | ms/tok @ 4th | xxx | xx | xx | xx | |MPT 7B | bits/weight | 16.0 | 4.5 | 5.0 | 2.6 |
|Mistral 7B | ms/tok @ 8th | xxx | xx | xx | xx | |AWQ-MPT 7B| perplexity | 8.4944 | 8.7053 | 8.6750 | 10.2873|
|Mistral 7B | bits/weight | 16.0 | 4.5 | 5.0 | 2.6 | |AWQ-MPT 7B| file size | 13.7G | 3.9G | 4.3G | 2.8G |
|AWQ-Mistral 7B| perplexity | xxxxxx | xxxxxx | xxxxx | xxxxxx | |AWQ-MPT 7B| bits/weight | 16.0 | 4.5 | 5.0 | 2.6 |
|AWQ-Mistral 7B| file size | 12.9G | 3.5G | 3.9G | 2.7G |
|AWQ-Mistral 7B| ms/tok @ 4th | xxx| xxx | xxx | xxx |
|AWQ-Mistral 7B| ms/tok @ 8th | xxx| xx | xx | xx |
|AWQ-Mistral 7B| bits/weight | 16.0 | 4.5 | 5.0 | 2.6 |

View file

@ -1,13 +1,17 @@
""" """
Original code from: Implements the AWQ for llama.cpp use cases.
1. https://github.com/casper-hansen/AutoAWQ Original paper: https://arxiv.org/abs/2306.00978
2. https://github.com/mit-han-lab/llm-awq
This code is based on versions of the AWQ implementation found in the following repositories:
* https://github.com/mit-han-lab/llm-awq
* https://github.com/casper-hansen/AutoAWQ
""" """
import os import os
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoConfig
from transformers import AutoModelForCausalLM, AutoConfig
from transformers.models.bloom.modeling_bloom import BloomGelu from transformers.models.bloom.modeling_bloom import BloomGelu
from transformers.models.llama.modeling_llama import LlamaRMSNorm from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers.activations import GELUActivation from transformers.activations import GELUActivation
@ -65,7 +69,7 @@ def get_op_by_name(module, op_name):
Args: Args:
module (nn.Module): The layer containing the submodule to find. module (nn.Module): The layer containing the submodule to find.
op_name (str): The name of the submodule to search for, using dot notation for nested modules. op_name (str): The name of the submodule.
Returns: Returns:
nn.Module: The requested submodule found within the given layer. nn.Module: The requested submodule found within the given layer.
@ -87,7 +91,7 @@ def scale_ln_fcs(ln, fcs, scales):
Args: Args:
ln (nn.LayerNorm): The LayerNorm module to be scaled. ln (nn.LayerNorm): The LayerNorm module to be scaled.
fcs (List[nn.Linear]): A list of fully-connected layers to be scaled. fcs (List[nn.Linear]): A list of fully-connected layers to be scaled.
scales (torch.Tensor): A 1D tensor of size (num_features,) containing the scaling factors for each feature. scales (torch.Tensor): A 1D tensor of size (num_features,).
""" """
if not isinstance(fcs, list): if not isinstance(fcs, list):
@ -117,14 +121,14 @@ def scale_fc_fc(fc1, fc2, scales):
Args: Args:
fc1 (nn.Linear): The first fully-connected layer to be scaled. fc1 (nn.Linear): The first fully-connected layer to be scaled.
fc2 (nn.Linear): The second fully-connected layer to be scaled. fc2 (nn.Linear): The second fully-connected layer to be scaled.
scales (torch.Tensor): A 1D tensor of size (num_features,) containing the scaling factors for each feature. scales (torch.Tensor): A 1D tensor of size (num_features,).
""" """
assert isinstance(fc1, nn.Linear) assert isinstance(fc1, nn.Linear)
assert isinstance(fc2, nn.Linear) assert isinstance(fc2, nn.Linear)
scales = scales.to(fc1.weight.device) scales = scales.to(fc1.weight.device)
fc1.weight[-scales.size(0) :].div_(scales.view(-1, 1)) fc1.weight[-scales.size(0):].div_(scales.view(-1, 1))
if fc1.bias is not None: if fc1.bias is not None:
fc1.bias.div_(scales.view(-1)) fc1.bias.div_(scales.view(-1))
@ -144,7 +148,7 @@ def scale_gelu_fc(gelu, fc, scales):
Args: Args:
gelu (Union[nn.GELU, BloomGelu, GELUActivation]): The GELU activation module to be scaled. gelu (Union[nn.GELU, BloomGelu, GELUActivation]): The GELU activation module to be scaled.
fc (nn.Linear): The fully-connected layer to be scaled. fc (nn.Linear): The fully-connected layer to be scaled.
scales (torch.Tensor): A 1D tensor of size (num_features,) containing the scaling factors for each feature. scales (torch.Tensor): A 1D tensor of size (num_features,).
Raises: Raises:
TypeError: If the `gelu` module is not of type `nn.GELU`, `BloomGelu`, or `GELUActivation`. TypeError: If the `gelu` module is not of type `nn.GELU`, `BloomGelu`, or `GELUActivation`.
@ -166,13 +170,12 @@ def apply_scale(module, scales_list, input_feat_dict=None):
Args: Args:
module (nn.Module): The module containing the layers to be scaled. module (nn.Module): The module containing the layers to be scaled.
scales_list (List[Tuple[str, List[str], torch.Tensor]]): A list of tuples containing: scales_list (List[Tuple[str, List[str], torch.Tensor]]): A list of tuples containing:
* prev_op_name (str): The name of the preceding operation or module, relative to which the layers to be * prev_op_name (str): The name of the preceding operation or module,
scaled are located. relative to which the layers to be scaled are located.
* layer_names (List[str]): A list of names of the layers to be scaled, relative to the preceding operation. * layer_names (List[str]): A list of names of the layers to be scaled, relative to the preceding operation.
* scales (torch.Tensor): A 1D tensor of size (num_features,) containing the scaling factors for each feature. * scales (torch.Tensor): A 1D tensor of size (num_features,) containing the scaling factors for each feature.
input_feat_dict (Optional[Dict[str, torch.Tensor]]): A dictionary mapping layer names to their corresponding input_feat_dict (Optional[Dict[str, torch.Tensor]]): A dictionary mapping layer names to their corresponding
input features (optional). If provided, the input features are also input features (optional).
scaled proportionally after scaling the layer weights.
""" """
for prev_op_name, layer_names, scales in scales_list: for prev_op_name, layer_names, scales in scales_list:
prev_op = get_op_by_name(module, prev_op_name) prev_op = get_op_by_name(module, prev_op_name)
@ -234,7 +237,8 @@ def apply_clip(module, clip_list):
def add_scale_weights(model_path, scale_path, tmp_path): def add_scale_weights(model_path, scale_path, tmp_path):
""" """
Adds pre-computed Activation Weight Quantization (AWQ) results to a model, including scaling factors and clipping bounds. Adds pre-computed Activation Weight Quantization (AWQ) results to a model,
including scaling factors and clipping bounds.
Args: Args:
model_path (str): Path to the pre-trained model to be equipped with AWQ. model_path (str): Path to the pre-trained model to be equipped with AWQ.

View file

@ -12,6 +12,8 @@ from enum import IntEnum
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, ContextManager, Iterator, cast, Optional from typing import TYPE_CHECKING, Any, ContextManager, Iterator, cast, Optional
from awqpy.apply_awq import add_scale_weights
import numpy as np import numpy as np
import torch import torch
@ -442,7 +444,11 @@ class MPTModel(Model):
data = data_torch.squeeze().numpy() data = data_torch.squeeze().numpy()
# map tensor names # map tensor names
# Map the checkpoint tensor name to its GGUF name. AWQ scale tensors are
# looked up with the extra ".scales" suffix and get ".scales" appended.
if "scales" in name:
    new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias", ".scales"))
    # get_name() returns None for an unknown tensor; leave it as None so the
    # "Can not map tensor" check that follows reports it, instead of raising
    # TypeError on `None + ".scales"`.
    if new_name is not None:
        new_name = new_name + ".scales"
else:
    new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
if new_name is None: if new_name is None:
print(f"Can not map tensor {name!r}") print(f"Can not map tensor {name!r}")
sys.exit() sys.exit()
@ -970,6 +976,9 @@ def parse_args() -> argparse.Namespace:
"--vocab-only", action="store_true", "--vocab-only", action="store_true",
help="extract only the vocab", help="extract only the vocab",
) )
parser.add_argument(
"--awq-path", type=Path, default=None,
help="Path to scale awq cache file")
parser.add_argument( parser.add_argument(
"--outfile", type=Path, "--outfile", type=Path,
help="path to write to; default: based on input", help="path to write to; default: based on input",
@ -989,7 +998,21 @@ def parse_args() -> argparse.Namespace:
args = parse_args() args = parse_args()
# When an AWQ cache is supplied, convert from a copy of the model that has the
# AWQ results applied; that copy is kept in <model>/weighted_model and reused
# on later runs if the directory already exists.
if args.awq_path:
    # Same module path as the top-of-file import.
    from awqpy.apply_awq import add_scale_weights
    tmp_model_path = args.model / "weighted_model"
    if tmp_model_path.is_dir():
        print(f"{tmp_model_path} exists as a weighted model.")
    else:
        # pathlib.Path has no mkdirs(); mkdir(parents=True) already creates
        # the whole directory chain.
        tmp_model_path.mkdir(parents=True, exist_ok=True)
        print("Saving new weighted model ...")
        add_scale_weights(str(args.model), str(args.awq_path), str(tmp_model_path))
        print(f"Saved weighted model at {tmp_model_path}.")
    dir_model = tmp_model_path
else:
    dir_model = args.model
if not dir_model.is_dir(): if not dir_model.is_dir():
print(f'Error: {args.model} is not a directory', file=sys.stderr) print(f'Error: {args.model} is not a directory', file=sys.stderr)
sys.exit(1) sys.exit(1)

View file

@ -23,6 +23,8 @@ from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import IO, TYPE_CHECKING, Any, Callable, Iterable, Literal, TypeVar from typing import IO, TYPE_CHECKING, Any, Callable, Iterable, Literal, TypeVar
from awqpy.apply_awq import add_scale_weights
import numpy as np import numpy as np
from sentencepiece import SentencePieceProcessor from sentencepiece import SentencePieceProcessor
@ -1139,6 +1141,7 @@ def main(args_in: list[str] | None = None) -> None:
# We currently only support Q8_0 output on little endian systems. # We currently only support Q8_0 output on little endian systems.
output_choices.append("q8_0") output_choices.append("q8_0")
parser = argparse.ArgumentParser(description="Convert a LLaMa model to a GGML compatible file") parser = argparse.ArgumentParser(description="Convert a LLaMa model to a GGML compatible file")
parser.add_argument("--awq-path", type=Path, default=None, help="Path to scale awq cache file")
parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model") parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model")
parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file") parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file")
parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab") parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
@ -1152,6 +1155,19 @@ def main(args_in: list[str] | None = None) -> None:
parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine") parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine")
args = parser.parse_args(args_in) args = parser.parse_args(args_in)
# When an AWQ cache is supplied, convert from a copy of the model that has the
# AWQ results applied; that copy is kept in <model>/weighted_model and reused
# on later runs if the directory already exists.
if args.awq_path:
    # Same module path as the top-of-file import.
    from awqpy.apply_awq import add_scale_weights
    tmp_model_path = args.model / "weighted_model"
    if tmp_model_path.is_dir():
        print(f"{tmp_model_path} exists as a weighted model.")
    else:
        # pathlib.Path has no mkdirs(); mkdir(parents=True) already creates
        # the whole directory chain.
        tmp_model_path.mkdir(parents=True, exist_ok=True)
        print("Saving new weighted model ...")
        add_scale_weights(str(args.model), str(args.awq_path), str(tmp_model_path))
        print(f"Saved weighted model at {tmp_model_path}.")
    # Everything below loads from the weighted copy.
    args.model = tmp_model_path
if args.dump_single: if args.dump_single:
model_plus = lazy_load_file(args.model) model_plus = lazy_load_file(args.model)
do_dump_model(model_plus) do_dump_model(model_plus)

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

109
llama.cpp
View file

@ -454,7 +454,7 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{LLM_TENSOR_FFN_ACT, "blk.%d.ffn.act"}, { LLM_TENSOR_FFN_ACT, "blk.%d.ffn.act"},
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
}, },
}, },
@ -3845,7 +3845,6 @@ static struct ggml_tensor * llm_build_ffn(
struct ggml_tensor * gate_b, struct ggml_tensor * gate_b,
struct ggml_tensor * down, struct ggml_tensor * down,
struct ggml_tensor * down_b, struct ggml_tensor * down_b,
struct ggml_tensor *act_scales,
llm_ffn_op_type type_op, llm_ffn_op_type type_op,
llm_ffn_gate_type type_gate, llm_ffn_gate_type type_gate,
const llm_build_cb & cb, const llm_build_cb & cb,
@ -3896,16 +3895,6 @@ static struct ggml_tensor * llm_build_ffn(
cur = ggml_relu(ctx, cur); cur = ggml_relu(ctx, cur);
cb(cur, "ffn_relu", il); cb(cur, "ffn_relu", il);
} break; } break;
case LLM_FFN_GELU_ACT:
{
cur = ggml_gelu(ctx, cur);
cb(cur, "ffn_relu", il);
struct ggml_tensor *repeat = ggml_repeat(ctx, act_scales, cur);
cb(repeat, "ffn_repeat(scales)", il);
cur = ggml_div(ctx, cur, repeat);
cb(cur, "ffn_div(gelu)", il);
}
break;
case LLM_FFN_RELU_SQR: case LLM_FFN_RELU_SQR:
{ {
cur = ggml_relu(ctx, cur); cur = ggml_relu(ctx, cur);
@ -3933,6 +3922,93 @@ static struct ggml_tensor * llm_build_ffn(
return cur; return cur;
} }
// Builds the feed-forward sub-graph for one layer, with an activation whose
// output is divided by a tensor of scales (`act_scales`).
//
//   ctx         ggml context the graph nodes are created in
//   cur         input to the feed-forward block
//   up, up_b    up-projection weight and optional bias (up_b may be NULL)
//   gate,gate_b optional gate projection and its optional bias (may be NULL)
//   down,down_b down-projection weight and optional bias (down_b may be NULL)
//   act_scales  divisor applied to the GELU output; when NULL the division
//               is skipped and the plain GELU output is used
//   type_op     activation selector; only LLM_FFN_GELU_ACT is applied here,
//               any other value leaves the tensor unchanged
//   type_gate   LLM_FFN_SEQ: gate is applied to the up-projection output;
//               LLM_FFN_PAR: gate is applied to the input and the activated
//               result is multiplied by the up-projection output
//   cb          callback invoked with (node, label, il) for the created nodes
//   il          layer index forwarded to `cb`
//
// Returns the last node of the sub-graph (down projection, plus bias if any).
static struct ggml_tensor * llm_build_ffn(
        struct ggml_context * ctx,
        struct ggml_tensor * cur,
        struct ggml_tensor * up,
        struct ggml_tensor * up_b,
        struct ggml_tensor * gate,
        struct ggml_tensor * gate_b,
        struct ggml_tensor * down,
        struct ggml_tensor * down_b,
        struct ggml_tensor * act_scales,
        llm_ffn_op_type type_op,
        llm_ffn_gate_type type_gate,
        const llm_build_cb & cb,
        int il) {
    struct ggml_tensor * tmp = ggml_mul_mat(ctx, up, cur);
    cb(tmp, "ffn_up", il);

    if (up_b) {
        tmp = ggml_add(ctx, tmp, up_b);
        cb(tmp, "ffn_up_b", il);
    }

    if (gate) {
        switch (type_gate) {
            case LLM_FFN_SEQ:
                {
                    cur = ggml_mul_mat(ctx, gate, tmp);
                    cb(cur, "ffn_gate", il);
                } break;
            case LLM_FFN_PAR:
                {
                    cur = ggml_mul_mat(ctx, gate, cur);
                    cb(cur, "ffn_gate", il);
                } break;
        }

        if (gate_b) {
            cur = ggml_add(ctx, cur, gate_b);
            cb(cur, "ffn_gate_b", il);
        }
    } else {
        cur = tmp;
    }

    switch (type_op) {
        case LLM_FFN_GELU_ACT:
            {
                cur = ggml_gelu(ctx, cur);
                // NOTE(review): the label says "relu" although the op is GELU;
                // kept unchanged because `cb` receives the label at runtime.
                cb(cur, "ffn_relu", il);

                // Guard against a missing scales tensor instead of handing a
                // NULL pointer to ggml_repeat().
                if (act_scales) {
                    struct ggml_tensor * repeat = ggml_repeat(ctx, act_scales, cur);
                    cb(repeat, "ffn_repeat(scales)", il);

                    cur = ggml_div(ctx, cur, repeat);
                    cb(cur, "ffn_div(gelu)", il);
                }
            } break;
        default:
            // No other activation is applied by this overload; the tensor is
            // passed through unchanged.
            break;
    }

    if (type_gate == LLM_FFN_PAR) {
        cur = ggml_mul(ctx, cur, tmp);
        cb(cur, "ffn_gate_par", il);
    }

    cur = ggml_mul_mat(ctx, down, cur);
    // The projection is labelled "ffn_down" only when a bias add follows it.
    if (down_b) {
        cb(cur, "ffn_down", il);
        cur = ggml_add(ctx, cur, down_b);
    }

    return cur;
}
// if max_alibi_bias > 0 then apply ALiBi // if max_alibi_bias > 0 then apply ALiBi
static struct ggml_tensor * llm_build_kqv( static struct ggml_tensor * llm_build_kqv(
struct ggml_context * ctx, struct ggml_context * ctx,
@ -4211,7 +4287,6 @@ struct llm_build_context {
model.layers[il].ffn_up, NULL, model.layers[il].ffn_up, NULL,
model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate, NULL,
model.layers[il].ffn_down, NULL, model.layers[il].ffn_down, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il); LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il); cb(cur, "ffn_out", il);
} }
@ -4332,7 +4407,6 @@ struct llm_build_context {
model.layers[il].ffn_up, NULL, model.layers[il].ffn_up, NULL,
model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate, NULL,
model.layers[il].ffn_down, NULL, model.layers[il].ffn_down, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il); LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il); cb(cur, "ffn_out", il);
} }
@ -4451,7 +4525,6 @@ struct llm_build_context {
model.layers[il].ffn_up, NULL, model.layers[il].ffn_up, NULL,
NULL, NULL, NULL, NULL,
model.layers[il].ffn_down, NULL, model.layers[il].ffn_down, NULL,
NULL,
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
cb(cur, "ffn_out", il); cb(cur, "ffn_out", il);
} }
@ -4560,7 +4633,6 @@ struct llm_build_context {
model.layers[il].ffn_up, model.layers[il].ffn_up_b, model.layers[il].ffn_up, model.layers[il].ffn_up_b,
NULL, NULL, NULL, NULL,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, model.layers[il].ffn_down, model.layers[il].ffn_down_b,
NULL,
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
cb(cur, "ffn_out", il); cb(cur, "ffn_out", il);
} }
@ -4769,7 +4841,6 @@ struct llm_build_context {
model.layers[il].ffn_up, model.layers[il].ffn_up_b, model.layers[il].ffn_up, model.layers[il].ffn_up_b,
NULL, NULL, NULL, NULL,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, model.layers[il].ffn_down, model.layers[il].ffn_down_b,
NULL,
LLM_FFN_RELU_SQR, LLM_FFN_SEQ, cb, il); LLM_FFN_RELU_SQR, LLM_FFN_SEQ, cb, il);
cb(cur, "ffn_out", il); cb(cur, "ffn_out", il);
} }
@ -4860,7 +4931,6 @@ struct llm_build_context {
model.layers[il].ffn_up, NULL, model.layers[il].ffn_up, NULL,
model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate, NULL,
model.layers[il].ffn_down, NULL, model.layers[il].ffn_down, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il); LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il); cb(cur, "ffn_out", il);
} }
@ -4960,7 +5030,6 @@ struct llm_build_context {
model.layers[il].ffn_up, model.layers[il].ffn_up_b, model.layers[il].ffn_up, model.layers[il].ffn_up_b,
NULL, NULL, NULL, NULL,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, model.layers[il].ffn_down, model.layers[il].ffn_down_b,
NULL,
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
cb(cur, "ffn_out", il); cb(cur, "ffn_out", il);
} }
@ -5168,7 +5237,6 @@ struct llm_build_context {
model.layers[il].ffn_up, NULL, model.layers[il].ffn_up, NULL,
model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate, NULL,
model.layers[il].ffn_down, NULL, model.layers[il].ffn_down, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il); LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il); cb(cur, "ffn_out", il);
} }
@ -5285,7 +5353,6 @@ struct llm_build_context {
model.layers[il].ffn_up, NULL, model.layers[il].ffn_up, NULL,
model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate, NULL,
model.layers[il].ffn_down, NULL, model.layers[il].ffn_down, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il); LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il); cb(cur, "ffn_out", il);
} }