update: ready for PR
This commit is contained in:
parent
576d28b7f7
commit
4cad8d7d7a
8 changed files with 165 additions and 2355 deletions
|
@ -10,6 +10,13 @@
|
|||
- [ ] Bloom
|
||||
- [ ] Mixtral MoE
|
||||
|
||||
**TODO:**
|
||||
- [ ] Add OPT model
|
||||
- [ ] Add Bloom model
|
||||
- [ ] Add Mixtral MoE
|
||||
- [ ] Update version to work with both MPT and MPT-AWQ models
|
||||
- [ ] Support w3, w2
|
||||
|
||||
|
||||
## Contents
|
||||
|
||||
|
@ -33,7 +40,7 @@ git clone https://huggingface.co/datasets/mit-han-lab/awq-model-zoo awq_cache
|
|||
Example for llama 7b model
|
||||
```bash
|
||||
# For llama-7b and llama2-7b models
|
||||
python examples/awqutils/convert-awq.py models/llama-7b/ --awq-path awq_cache/llama-7b-w4-g128.pt --tmp-model-path models/llama-7b-scales --outfile models/llama_7b_fp16.gguf
|
||||
python convert.py models/llama-7b/ --awq-path awq_cache/llama-7b-w4-g128.pt --outfile models/llama_7b_fp16.gguf
|
||||
```
|
||||
|
||||
## Quantize
|
||||
|
@ -57,13 +64,9 @@ We use three types of llamacpp quantization methods to work with our version, in
|
|||
|-----------:|--------------|-------:|-------:|-------:|-------:|
|
||||
|Llama 7B | perplexity | 5.9066 | 6.1214 | 6.0643 | 6.5808 |
|
||||
|Llama 7B | file size | 12.9G | 3.5G | 3.9G | 2.7G |
|
||||
|Llama 7B | ms/tok @ 4th | xxx | xx | xx | xx |
|
||||
|Llama 7B | ms/tok @ 8th | xxx | xx | xx | xx |
|
||||
|Llama 7B | bits/weight | 16.0 | 4.5 | 5.0 | 2.6 |
|
||||
|AWQ-LLama 7B| perplexity | 5.9175 | 6.0252 | 5.9987 | 6.3692 |
|
||||
|AWQ-LLama 7B| file size | 12.9G | 3.5G | 3.9G | 2.7G |
|
||||
|AWQ-LLama 7B| ms/tok @ 4th | xxx| xxx | xxx | xxx |
|
||||
|AWQ-LLama 7B| ms/tok @ 8th | xxx| xx | xx | xx |
|
||||
|AWQ-LLama 7B| bits/weight | 16.0 | 4.5 | 5.0 | 2.6 |
|
||||
|
||||
|
||||
|
@ -73,13 +76,9 @@ We use three types of llamacpp quantization methods to work with our version, in
|
|||
|------------:|--------------|-------:|-------:|-------:|-------:|
|
||||
|Llama2 7B | perplexity | 5.8664 | 6.0260 | 6.0656 | 6.4496 |
|
||||
|Llama2 7B | file size | 12.9G | 3.5G | 3.9G | 2.7G |
|
||||
|Llama2 7B | ms/tok @ 4th | xxx | xx | xx | xx |
|
||||
|Llama2 7B | ms/tok @ 8th | xxx | xx | xx | xx |
|
||||
|Llama2 7B | bits/weight | 16.0 | 4.5 | 5.0 | 2.6 |
|
||||
|AWQ-LLama2 7B| perplexity | 5.8801 | 6.0054 | 5.9849 | 6.3650 |
|
||||
|AWQ-LLama2 7B| file size | 12.9G | 3.5G | 3.9G | 2.7G |
|
||||
|AWQ-LLama2 7B| ms/tok @ 4th | xxx| xxx | xxx | xxx |
|
||||
|AWQ-LLama2 7B| ms/tok @ 8th | xxx| xx | xx | xx |
|
||||
|AWQ-LLama2 7B| bits/weight | 16.0 | 4.5 | 5.0 | 2.6 |
|
||||
|
||||
|
||||
|
@ -88,27 +87,19 @@ We use three types of llamacpp quantization methods to work with our version, in
|
|||
| Model | Measure | F16 | Q4_0 | Q4_1 | Q2_K |
|
||||
|-------------:|--------------|-------:|-------:|-------:|-------:|
|
||||
|Mistral 7B | perplexity | 5.6931 | 5.8202 | 5.8268 | 6.1645 |
|
||||
|Mistral 7B | file size | 12.9G | 3.5G | 3.9G | 2.7G |
|
||||
|Mistral 7B | ms/tok @ 4th | xxx | xx | xx | xx |
|
||||
|Mistral 7B | ms/tok @ 8th | xxx | xx | xx | xx |
|
||||
|Mistral 7B | file size | 14.5G | 4.1G | 4.5G | 3.1G |
|
||||
|Mistral 7B | bits/weight | 16.0 | 4.5 | 5.0 | 2.6 |
|
||||
|AWQ-Mistral 7B| perplexity | 5.6934 | 5.8020 | 5.7691 | 6.0426 |
|
||||
|AWQ-Mistral 7B| file size | 12.9G | 3.5G | 3.9G | 2.7G |
|
||||
|AWQ-Mistral 7B| ms/tok @ 4th | xxx| xxx | xxx | xxx |
|
||||
|AWQ-Mistral 7B| ms/tok @ 8th | xxx| xx | xx | xx |
|
||||
|AWQ-Mistral 7B| file size | 14.5G | 4.1G | 4.5G | 3.1G |
|
||||
|AWQ-Mistral 7B| bits/weight | 16.0 | 4.5 | 5.0 | 2.6 |
|
||||
|
||||
### MPT 7B (Build with OpenBLAS)
|
||||
|
||||
| Model | Measure | F16 | Q4_0 | Q4_1 | Q2_K |
|
||||
|-------------:|--------------|-------:|-------:|-------:|-------:|
|
||||
|Mistral 7B | perplexity | xxxxxx | xxxxxx | xxxxxx | xxxxxx |
|
||||
|Mistral 7B | file size | 12.9G | 3.5G | 3.9G | 2.7G |
|
||||
|Mistral 7B | ms/tok @ 4th | xxx | xx | xx | xx |
|
||||
|Mistral 7B | ms/tok @ 8th | xxx | xx | xx | xx |
|
||||
|Mistral 7B | bits/weight | 16.0 | 4.5 | 5.0 | 2.6 |
|
||||
|AWQ-Mistral 7B| perplexity | xxxxxx | xxxxxx | xxxxx | xxxxxx |
|
||||
|AWQ-Mistral 7B| file size | 12.9G | 3.5G | 3.9G | 2.7G |
|
||||
|AWQ-Mistral 7B| ms/tok @ 4th | xxx| xxx | xxx | xxx |
|
||||
|AWQ-Mistral 7B| ms/tok @ 8th | xxx| xx | xx | xx |
|
||||
|AWQ-Mistral 7B| bits/weight | 16.0 | 4.5 | 5.0 | 2.6 |
|
||||
|---------:|--------------|-------:|-------:|-------:|--------:|
|
||||
|MPT 7B | perplexity | 8.4369 | 8.7956 | 8.6265 | 11.4913 |
|
||||
|MPT 7B | file size | 13.7G | 3.9G | 4.3G | 2.8G |
|
||||
|MPT 7B | bits/weight | 16.0 | 4.5 | 5.0 | 2.6 |
|
||||
|AWQ-MPT 7B| perplexity | 8.4944 | 8.7053 | 8.6750 | 10.2873|
|
||||
|AWQ-MPT 7B| file size | 13.7G | 3.9G | 4.3G | 2.8G |
|
||||
|AWQ-MPT 7B| bits/weight | 16.0 | 4.5 | 5.0 | 2.6 |
|
|
@ -1,13 +1,17 @@
|
|||
"""
|
||||
Original code from:
|
||||
1. https://github.com/casper-hansen/AutoAWQ
|
||||
2. https://github.com/mit-han-lab/llm-awq
|
||||
Implements the AWQ for llama.cpp use cases.
|
||||
Original paper: https://arxiv.org/abs/2306.00978
|
||||
|
||||
This code is based on versions of the AWQ implementation found in the following repositories:
|
||||
* https://github.com/mit-han-lab/llm-awq
|
||||
* https://github.com/casper-hansen/AutoAWQ
|
||||
"""
|
||||
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import AutoModelForCausalLM, AutoConfig
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoConfig
|
||||
from transformers.models.bloom.modeling_bloom import BloomGelu
|
||||
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
||||
from transformers.activations import GELUActivation
|
||||
|
@ -65,7 +69,7 @@ def get_op_by_name(module, op_name):
|
|||
|
||||
Args:
|
||||
module (nn.Module): The layer containing the submodule to find.
|
||||
op_name (str): The name of the submodule to search for, using dot notation for nested modules.
|
||||
op_name (str): The name of the submodule.
|
||||
|
||||
Returns:
|
||||
nn.Module: The requested submodule found within the given layer.
|
||||
|
@ -87,7 +91,7 @@ def scale_ln_fcs(ln, fcs, scales):
|
|||
Args:
|
||||
ln (nn.LayerNorm): The LayerNorm module to be scaled.
|
||||
fcs (List[nn.Linear]): A list of fully-connected layers to be scaled.
|
||||
scales (torch.Tensor): A 1D tensor of size (num_features,) containing the scaling factors for each feature.
|
||||
scales (torch.Tensor): A 1D tensor of size (num_features,).
|
||||
"""
|
||||
|
||||
if not isinstance(fcs, list):
|
||||
|
@ -117,7 +121,7 @@ def scale_fc_fc(fc1, fc2, scales):
|
|||
Args:
|
||||
fc1 (nn.Linear): The first fully-connected layer to be scaled.
|
||||
fc2 (nn.Linear): The second fully-connected layer to be scaled.
|
||||
scales (torch.Tensor): A 1D tensor of size (num_features,) containing the scaling factors for each feature.
|
||||
scales (torch.Tensor): A 1D tensor of size (num_features,).
|
||||
"""
|
||||
assert isinstance(fc1, nn.Linear)
|
||||
assert isinstance(fc2, nn.Linear)
|
||||
|
@ -144,7 +148,7 @@ def scale_gelu_fc(gelu, fc, scales):
|
|||
Args:
|
||||
gelu (Union[nn.GELU, BloomGelu, GELUActivation]): The GELU activation module to be scaled.
|
||||
fc (nn.Linear): The fully-connected layer to be scaled.
|
||||
scales (torch.Tensor): A 1D tensor of size (num_features,) containing the scaling factors for each feature.
|
||||
scales (torch.Tensor): A 1D tensor of size (num_features,).
|
||||
|
||||
Raises:
|
||||
TypeError: If the `gelu` module is not of type `nn.GELU`, `BloomGelu`, or `GELUActivation`.
|
||||
|
@ -166,13 +170,12 @@ def apply_scale(module, scales_list, input_feat_dict=None):
|
|||
Args:
|
||||
module (nn.Module): The module containing the layers to be scaled.
|
||||
scales_list (List[Tuple[str, List[str], torch.Tensor]]): A list of tuples containing:
|
||||
* prev_op_name (str): The name of the preceding operation or module, relative to which the layers to be
|
||||
scaled are located.
|
||||
* prev_op_name (str): The name of the preceding operation or module,
|
||||
relative to which the layers to be scaled are located.
|
||||
* layer_names (List[str]): A list of names of the layers to be scaled, relative to the preceding operation.
|
||||
* scales (torch.Tensor): A 1D tensor of size (num_features,) containing the scaling factors for each feature.
|
||||
input_feat_dict (Optional[Dict[str, torch.Tensor]]): A dictionary mapping layer names to their corresponding
|
||||
input features (optional). If provided, the input features are also
|
||||
scaled proportionally after scaling the layer weights.
|
||||
input features (optional).
|
||||
"""
|
||||
for prev_op_name, layer_names, scales in scales_list:
|
||||
prev_op = get_op_by_name(module, prev_op_name)
|
||||
|
@ -234,7 +237,8 @@ def apply_clip(module, clip_list):
|
|||
|
||||
def add_scale_weights(model_path, scale_path, tmp_path):
|
||||
"""
|
||||
Adds pre-computed Activation Weight Quantization (AWQ) results to a model, including scaling factors and clipping bounds.
|
||||
Adds pre-computed Activation-aware Weight Quantization (AWQ) results to a model,
|
||||
including scaling factors and clipping bounds.
|
||||
|
||||
Args:
|
||||
model_path (str): Path to the pre-trained model to be equipped with AWQ.
|
|
@ -12,6 +12,8 @@ from enum import IntEnum
|
|||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, ContextManager, Iterator, cast, Optional
|
||||
|
||||
from awqpy.apply_awq import add_scale_weights
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
@ -442,6 +444,10 @@ class MPTModel(Model):
|
|||
data = data_torch.squeeze().numpy()
|
||||
|
||||
# map tensor names
if "scales" in name:
    new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias", ".scales"))
    # get_name() yields None for tensors it cannot map (see the check that
    # follows); only append the suffix to a real name so the unmapped case
    # reaches that check instead of raising TypeError on `None + str`.
    if new_name is not None:
        new_name = new_name + ".scales"
else:
    new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
|
||||
if new_name is None:
|
||||
print(f"Can not map tensor {name!r}")
|
||||
|
@ -970,6 +976,9 @@ def parse_args() -> argparse.Namespace:
|
|||
"--vocab-only", action="store_true",
|
||||
help="extract only the vocab",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--awq-path", type=Path, default=None,
|
||||
help="Path to scale awq cache file")
|
||||
parser.add_argument(
|
||||
"--outfile", type=Path,
|
||||
help="path to write to; default: based on input",
|
||||
|
@ -989,7 +998,21 @@ def parse_args() -> argparse.Namespace:
|
|||
|
||||
args = parse_args()
|
||||
|
||||
if args.awq_path:
    from awqpy import add_scale_weights
    # The AWQ-scaled copy of the model lives in a sub-directory of the
    # original model directory and is reused on later runs.
    tmp_model_path = args.model / "weighted_model"
    if tmp_model_path.is_dir():
        print(f"{tmp_model_path} exists as a weighted model.")
    else:
        # pathlib.Path has no mkdirs(); this single mkdir() call already
        # creates the directory (and any missing parents).
        tmp_model_path.mkdir(parents=True, exist_ok=True)
        print("Saving new weighted model ...")
        add_scale_weights(str(args.model), str(args.awq_path), str(tmp_model_path))
        print(f"Saved weighted model at {tmp_model_path}.")
    # Convert from the weighted copy whether it was just created or reused.
    dir_model = tmp_model_path
else:
    dir_model = args.model
|
||||
|
||||
if not dir_model.is_dir():
|
||||
print(f'Error: {args.model} is not a directory', file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
|
16
convert.py
16
convert.py
|
@ -23,6 +23,8 @@ from dataclasses import dataclass
|
|||
from pathlib import Path
|
||||
from typing import IO, TYPE_CHECKING, Any, Callable, Iterable, Literal, TypeVar
|
||||
|
||||
from awqpy.apply_awq import add_scale_weights
|
||||
|
||||
import numpy as np
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
|
||||
|
@ -1139,6 +1141,7 @@ def main(args_in: list[str] | None = None) -> None:
|
|||
# We currently only support Q8_0 output on little endian systems.
|
||||
output_choices.append("q8_0")
|
||||
parser = argparse.ArgumentParser(description="Convert a LLaMa model to a GGML compatible file")
|
||||
parser.add_argument("--awq-path", type=Path, default=None, help="Path to scale awq cache file")
|
||||
parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model")
|
||||
parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file")
|
||||
parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
|
||||
|
@ -1152,6 +1155,19 @@ def main(args_in: list[str] | None = None) -> None:
|
|||
parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine")
|
||||
|
||||
args = parser.parse_args(args_in)
|
||||
if args.awq_path:
    from awqpy import add_scale_weights
    # The AWQ-scaled copy of the model lives in a sub-directory of the
    # original model directory and is reused on later runs.
    tmp_model_path = args.model / "weighted_model"
    if tmp_model_path.is_dir():
        print(f"{tmp_model_path} exists as a weighted model.")
    else:
        # pathlib.Path has no mkdirs(); this single mkdir() call already
        # creates the directory (and any missing parents).
        tmp_model_path.mkdir(parents=True, exist_ok=True)
        print("Saving new weighted model ...")
        add_scale_weights(str(args.model), str(args.awq_path), str(tmp_model_path))
        print(f"Saved weighted model at {tmp_model_path}.")
    # Everything after this point loads the model from the weighted copy.
    args.model = tmp_model_path
|
||||
|
||||
if args.dump_single:
|
||||
model_plus = lazy_load_file(args.model)
|
||||
do_dump_model(model_plus)
|
||||
|
|
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
107
llama.cpp
107
llama.cpp
|
@ -3845,7 +3845,6 @@ static struct ggml_tensor * llm_build_ffn(
|
|||
struct ggml_tensor * gate_b,
|
||||
struct ggml_tensor * down,
|
||||
struct ggml_tensor * down_b,
|
||||
struct ggml_tensor *act_scales,
|
||||
llm_ffn_op_type type_op,
|
||||
llm_ffn_gate_type type_gate,
|
||||
const llm_build_cb & cb,
|
||||
|
@ -3896,16 +3895,6 @@ static struct ggml_tensor * llm_build_ffn(
|
|||
cur = ggml_relu(ctx, cur);
|
||||
cb(cur, "ffn_relu", il);
|
||||
} break;
|
||||
case LLM_FFN_GELU_ACT:
|
||||
{
|
||||
cur = ggml_gelu(ctx, cur);
|
||||
cb(cur, "ffn_relu", il);
|
||||
struct ggml_tensor *repeat = ggml_repeat(ctx, act_scales, cur);
|
||||
cb(repeat, "ffn_repeat(scales)", il);
|
||||
cur = ggml_div(ctx, cur, repeat);
|
||||
cb(cur, "ffn_div(gelu)", il);
|
||||
}
|
||||
break;
|
||||
case LLM_FFN_RELU_SQR:
|
||||
{
|
||||
cur = ggml_relu(ctx, cur);
|
||||
|
@ -3933,6 +3922,93 @@ static struct ggml_tensor * llm_build_ffn(
|
|||
return cur;
|
||||
}
|
||||
|
||||
static struct ggml_tensor *llm_build_ffn(
|
||||
struct ggml_context *ctx,
|
||||
struct ggml_tensor *cur,
|
||||
struct ggml_tensor *up,
|
||||
struct ggml_tensor *up_b,
|
||||
struct ggml_tensor *gate,
|
||||
struct ggml_tensor *gate_b,
|
||||
struct ggml_tensor *down,
|
||||
struct ggml_tensor *down_b,
|
||||
struct ggml_tensor *act_scales,
|
||||
llm_ffn_op_type type_op,
|
||||
llm_ffn_gate_type type_gate,
|
||||
const llm_build_cb &cb,
|
||||
int il)
|
||||
{
|
||||
struct ggml_tensor *tmp = ggml_mul_mat(ctx, up, cur);
|
||||
cb(tmp, "ffn_up", il);
|
||||
|
||||
if (up_b)
|
||||
{
|
||||
tmp = ggml_add(ctx, tmp, up_b);
|
||||
cb(tmp, "ffn_up_b", il);
|
||||
}
|
||||
|
||||
if (gate)
|
||||
{
|
||||
switch (type_gate)
|
||||
{
|
||||
case LLM_FFN_SEQ:
|
||||
{
|
||||
cur = ggml_mul_mat(ctx, gate, tmp);
|
||||
cb(cur, "ffn_gate", il);
|
||||
}
|
||||
break;
|
||||
case LLM_FFN_PAR:
|
||||
{
|
||||
cur = ggml_mul_mat(ctx, gate, cur);
|
||||
cb(cur, "ffn_gate", il);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
if (gate_b)
|
||||
{
|
||||
cur = ggml_add(ctx, cur, gate_b);
|
||||
cb(cur, "ffn_gate_b", il);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
cur = tmp;
|
||||
}
|
||||
|
||||
switch (type_op)
|
||||
{
|
||||
case LLM_FFN_GELU_ACT:
|
||||
{
|
||||
cur = ggml_gelu(ctx, cur);
|
||||
cb(cur, "ffn_relu", il);
|
||||
struct ggml_tensor *repeat = ggml_repeat(ctx, act_scales, cur);
|
||||
cb(repeat, "ffn_repeat(scales)", il);
|
||||
cur = ggml_div(ctx, cur, repeat);
|
||||
cb(cur, "ffn_div(gelu)", il);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
if (type_gate == LLM_FFN_PAR)
|
||||
{
|
||||
cur = ggml_mul(ctx, cur, tmp);
|
||||
cb(cur, "ffn_gate_par", il);
|
||||
}
|
||||
|
||||
cur = ggml_mul_mat(ctx, down, cur);
|
||||
if (down_b)
|
||||
{
|
||||
cb(cur, "ffn_down", il);
|
||||
}
|
||||
|
||||
if (down_b)
|
||||
{
|
||||
cur = ggml_add(ctx, cur, down_b);
|
||||
}
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
// if max_alibi_bias > 0 then apply ALiBi
|
||||
static struct ggml_tensor * llm_build_kqv(
|
||||
struct ggml_context * ctx,
|
||||
|
@ -4211,7 +4287,6 @@ struct llm_build_context {
|
|||
model.layers[il].ffn_up, NULL,
|
||||
model.layers[il].ffn_gate, NULL,
|
||||
model.layers[il].ffn_down, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
}
|
||||
|
@ -4332,7 +4407,6 @@ struct llm_build_context {
|
|||
model.layers[il].ffn_up, NULL,
|
||||
model.layers[il].ffn_gate, NULL,
|
||||
model.layers[il].ffn_down, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
}
|
||||
|
@ -4451,7 +4525,6 @@ struct llm_build_context {
|
|||
model.layers[il].ffn_up, NULL,
|
||||
NULL, NULL,
|
||||
model.layers[il].ffn_down, NULL,
|
||||
NULL,
|
||||
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
}
|
||||
|
@ -4560,7 +4633,6 @@ struct llm_build_context {
|
|||
model.layers[il].ffn_up, model.layers[il].ffn_up_b,
|
||||
NULL, NULL,
|
||||
model.layers[il].ffn_down, model.layers[il].ffn_down_b,
|
||||
NULL,
|
||||
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
}
|
||||
|
@ -4769,7 +4841,6 @@ struct llm_build_context {
|
|||
model.layers[il].ffn_up, model.layers[il].ffn_up_b,
|
||||
NULL, NULL,
|
||||
model.layers[il].ffn_down, model.layers[il].ffn_down_b,
|
||||
NULL,
|
||||
LLM_FFN_RELU_SQR, LLM_FFN_SEQ, cb, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
}
|
||||
|
@ -4860,7 +4931,6 @@ struct llm_build_context {
|
|||
model.layers[il].ffn_up, NULL,
|
||||
model.layers[il].ffn_gate, NULL,
|
||||
model.layers[il].ffn_down, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
}
|
||||
|
@ -4960,7 +5030,6 @@ struct llm_build_context {
|
|||
model.layers[il].ffn_up, model.layers[il].ffn_up_b,
|
||||
NULL, NULL,
|
||||
model.layers[il].ffn_down, model.layers[il].ffn_down_b,
|
||||
NULL,
|
||||
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
}
|
||||
|
@ -5168,7 +5237,6 @@ struct llm_build_context {
|
|||
model.layers[il].ffn_up, NULL,
|
||||
model.layers[il].ffn_gate, NULL,
|
||||
model.layers[il].ffn_down, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
}
|
||||
|
@ -5285,7 +5353,6 @@ struct llm_build_context {
|
|||
model.layers[il].ffn_up, NULL,
|
||||
model.layers[il].ffn_gate, NULL,
|
||||
model.layers[il].ffn_down, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue