Merge branch 'ggerganov:master' into vulkan

Commit e85e232439, 32 changed files with 1981 additions and 324 deletions.

.devops/llama-cli-cann.Dockerfile (new file, 44 lines)

@@ -0,0 +1,44 @@
ARG ASCEND_VERSION=8.0.rc2.alpha003-910b-openeuler22.03-py3.8

FROM cosdt/cann:$ASCEND_VERSION AS build

WORKDIR /app

COPY . .

RUN yum install -y gcc g++ cmake make
ENV ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest
ENV LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:$LIBRARY_PATH
ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/lib64/plugin/opskernel:${ASCEND_TOOLKIT_HOME}/lib64/plugin/nnengine:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe/op_tiling:${LD_LIBRARY_PATH}
ENV PYTHONPATH=${ASCEND_TOOLKIT_HOME}/python/site-packages:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe:${PYTHONPATH}
ENV PATH=${ASCEND_TOOLKIT_HOME}/bin:${ASCEND_TOOLKIT_HOME}/compiler/ccec_compiler/bin:${PATH}
ENV ASCEND_AICPU_PATH=${ASCEND_TOOLKIT_HOME}
ENV ASCEND_OPP_PATH=${ASCEND_TOOLKIT_HOME}/opp
ENV TOOLCHAIN_HOME=${ASCEND_TOOLKIT_HOME}/toolkit
ENV ASCEND_HOME_PATH=${ASCEND_TOOLKIT_HOME}

# find libascend_hal.so, because the driver hasn't been mounted.
ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/runtime/lib64/stub:$LD_LIBRARY_PATH

RUN echo "Building with static libs" && \
    source /usr/local/Ascend/ascend-toolkit/set_env.sh --force && \
    cmake -B build -DGGML_CANN=ON -DBUILD_SHARED_LIBS=OFF && \
    cmake --build build --config Release --target llama-cli

# TODO: use image with NNRT
FROM cosdt/cann:$ASCEND_VERSION AS runtime
COPY --from=build /app/build/bin/llama-cli /llama-cli

ENV LC_ALL=C.utf8

ENV ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest
ENV LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:$LIBRARY_PATH
ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/lib64/plugin/opskernel:${ASCEND_TOOLKIT_HOME}/lib64/plugin/nnengine:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe/op_tiling:${LD_LIBRARY_PATH}
ENV PYTHONPATH=${ASCEND_TOOLKIT_HOME}/python/site-packages:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe:${PYTHONPATH}
ENV PATH=${ASCEND_TOOLKIT_HOME}/bin:${ASCEND_TOOLKIT_HOME}/compiler/ccec_compiler/bin:${PATH}
ENV ASCEND_AICPU_PATH=${ASCEND_TOOLKIT_HOME}
ENV ASCEND_OPP_PATH=${ASCEND_TOOLKIT_HOME}/opp
ENV TOOLCHAIN_HOME=${ASCEND_TOOLKIT_HOME}/toolkit
ENV ASCEND_HOME_PATH=${ASCEND_TOOLKIT_HOME}

ENTRYPOINT ["/llama-cli" ]

.gitignore (vendored, 3 lines changed)

@@ -129,3 +129,6 @@ poetry.toml

# Scripts
!/scripts/install-oneapi.bat

# Test models for lora adapters
/lora-tests

@@ -105,6 +105,7 @@ Typically finetunes of the base models below are supported as well.

- [x] [Open Elm models](https://huggingface.co/collections/apple/openelm-instruct-models-6619ad295d7ae9f868b759ca)
- [x] [ChatGLM3-6b](https://huggingface.co/THUDM/chatglm3-6b) + [ChatGLM4-9b](https://huggingface.co/THUDM/glm-4-9b)
- [x] [SmolLM](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966)
- [x] [EXAONE-3.0-7.8B-Instruct](https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct)

(instructions for supporting more models: [HOWTO-add-model.md](./docs/development/HOWTO-add-model.md))

@@ -424,6 +425,7 @@ Please refer to [Build llama.cpp locally](./docs/build.md)

| [CUDA](./docs/build.md#cuda) | Nvidia GPU |
| [hipBLAS](./docs/build.md#hipblas) | AMD GPU |
| [Vulkan](./docs/build.md#vulkan) | GPU |
| [CANN](./docs/build.md#cann) | Ascend NPU |

## Tools

@@ -110,8 +110,34 @@ int32_t cpu_get_num_physical_cores() {

    if (result == 0) {
        return num_physical_cores;
    }
#elif defined(_WIN32)
    //TODO: Implement
#elif defined(_WIN32) && (_WIN32_WINNT >= 0x0601) && !defined(__MINGW64__) // windows 7 and later
    // TODO: windows + arm64 + mingw64
    unsigned int n_threads_win = std::thread::hardware_concurrency();
    unsigned int default_threads = n_threads_win > 0 ? (n_threads_win <= 4 ? n_threads_win : n_threads_win / 2) : 4;

    DWORD buffer_size = 0;
    if (!GetLogicalProcessorInformationEx(RelationProcessorCore, nullptr, &buffer_size)) {
        if (GetLastError() != ERROR_INSUFFICIENT_BUFFER) {
            return default_threads;
        }
    }

    std::vector<char> buffer(buffer_size);
    if (!GetLogicalProcessorInformationEx(RelationProcessorCore, reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(buffer.data()), &buffer_size)) {
        return default_threads;
    }

    int32_t num_physical_cores = 0;
    PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX info = reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(buffer.data());
    while (buffer_size > 0) {
        if (info->Relationship == RelationProcessorCore) {
            num_physical_cores += info->Processor.GroupCount;
        }
        buffer_size -= info->Size;
        info = reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(reinterpret_cast<char*>(info) + info->Size);
    }

    return num_physical_cores > 0 ? num_physical_cores : default_threads;
#endif
    unsigned int n_threads = std::thread::hardware_concurrency();
    return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;

@@ -1727,7 +1753,13 @@ std::string gpt_params_get_system_info(const gpt_params & params) {

    if (params.n_threads_batch != -1) {
        os << " (n_threads_batch = " << params.n_threads_batch << ")";
    }
#if defined(_WIN32) && (_WIN32_WINNT >= 0x0601) && !defined(__MINGW64__) // windows 7 and later
    // TODO: windows + arm64 + mingw64
    DWORD logicalProcessorCount = GetActiveProcessorCount(ALL_PROCESSOR_GROUPS);
    os << " / " << logicalProcessorCount << " | " << llama_print_system_info();
#else
    os << " / " << std::thread::hardware_concurrency() << " | " << llama_print_system_info();
#endif

    return os.str();
}

@@ -596,6 +596,9 @@ class Model:

        if chkhsh == "bc01ce58980e1db43859146dc51b1758b3b88729b217a74792e9f8d43e479d21":
            # ref: https://huggingface.co/TurkuNLP/gpt3-finnish-small
            res = "gpt3-finnish"
        if chkhsh == "4e2b24cc4770243d65a2c9ec19770a72f08cffc161adbb73fcbb6b7dd45a0aae":
            # ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct
            res = "exaone"

        if res is None:
            logger.warning("\n")

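For context: the `chkhsh` strings compared above are fingerprints of a model's pre-tokenizer behaviour. A rough sketch of how such a fingerprint can be computed follows; the probe text and helper name here are illustrative, not the converter's actual ones.

```python
from hashlib import sha256
from transformers import AutoTokenizer


def pretokenizer_fingerprint(model_dir: str, probe_text: str) -> str:
    # Tokenize a fixed probe string and hash the resulting token ids, so two
    # models with identical pre-tokenizer behaviour yield the same digest.
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    token_ids = tokenizer.encode(probe_text)
    return sha256(str(token_ids).encode()).hexdigest()
```

A new model family then only needs its digest added to the chain of `if chkhsh == ...` checks, as the `gpt3-finnish` and `exaone` entries above do.
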
@@ -3740,6 +3743,118 @@ class ChatGLMModel(Model):

        name = name.removeprefix("transformer.")
        return [(self.map_tensor_name(name), data_torch)]


@Model.register("NemotronForCausalLM")
class NemotronModel(Model):
    model_arch = gguf.MODEL_ARCH.NEMOTRON

    def set_vocab(self):
        self._set_vocab_sentencepiece()
        self.gguf_writer.add_pad_token_id(0)
        self.gguf_writer.add_unk_token_id(1)

    def set_gguf_parameters(self):
        super().set_gguf_parameters()
        hparams = self.hparams
        self.gguf_writer.add_vocab_size(hparams["vocab_size"])

        f_norm_eps = self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon", "norm_eps"])
        self.gguf_writer.add_layer_norm_eps(f_norm_eps)

        # * Partial RoPE
        rot_pct = self.find_hparam(["partial_rotary_factor", "rope_pct", "rope_percent"])
        n_embd = self.find_hparam(["hidden_size", "n_embd"])
        n_head = self.find_hparam(["num_attention_heads", "n_head"])
        self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head)

        # * RopeScaling for Nemotron
        if "rope_scaling" not in self.hparams or self.hparams["rope_scaling"] is None:
            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
        else:
            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
            self.gguf_writer.add_rope_scaling_factor(self.hparams["factor"])

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        # * Adding +1 to LayerNorm's weights here to implement layernorm1p w/o changing anything on the GGML engine side
        #   model.layers.{l}.input_layernorm.weight
        #   model.layers.{l}.post_attention_layernorm.weight
        #   model.norm.weight
        if name.endswith("norm.weight"):
            data_torch = data_torch + 1

        return [(self.map_tensor_name(name), data_torch)]

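The `+1` above works because a "layernorm1p" layer with weight `w` is numerically the same as a standard LayerNorm with weight `w + 1`, so only the exported weights need to change. A small standalone sketch of that identity (not part of the converter):

```python
import torch


def layernorm1p(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Nemotron-style LayerNorm: the learned weight is applied as (1 + w).
    mu = x.mean(-1, keepdim=True)
    var = x.var(-1, keepdim=True, unbiased=False)
    return (x - mu) / torch.sqrt(var + eps) * (1 + w)


x = torch.randn(4, 8)
w = torch.randn(8)
plain = torch.nn.functional.layer_norm(x, (8,), weight=w + 1, eps=1e-5)
assert torch.allclose(layernorm1p(x, w), plain, atol=1e-6)
```
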
@Model.register("ExaoneForCausalLM")
class ExaoneModel(Model):
    model_arch = gguf.MODEL_ARCH.EXAONE

    def set_gguf_parameters(self):
        hparams = self.hparams

        assert(hparams["activation_function"] == "silu")

        max_position_embeddings = hparams["max_position_embeddings"]
        embed_dim = hparams["hidden_size"]
        num_heads = hparams["num_attention_heads"]
        num_kv_heads = hparams.get("num_key_value_heads", num_heads)
        layer_norm_eps = hparams["layer_norm_epsilon"]
        intermediate_size = hparams["intermediate_size"] if "intermediate_size" in hparams else 4 * embed_dim
        num_layers = hparams["num_layers"]
        # ignore for now as EXAONE-3.0-7.8B-Instruct attention_dropout is 0.0
        # attention_dropout_rate = hparams["attention_dropout"]
        # ignore for now as EXAONE-3.0-7.8B-Instruct embed_dropout is 0.0
        # embed_dropout_rate = hparams["embed_dropout"]
        self.gguf_writer.add_embedding_length(embed_dim)
        self.gguf_writer.add_head_count(num_heads)
        self.gguf_writer.add_head_count_kv(num_kv_heads)
        self.gguf_writer.add_context_length(max_position_embeddings)
        self.gguf_writer.add_layer_norm_rms_eps(layer_norm_eps)
        self.gguf_writer.add_feed_forward_length(intermediate_size)
        self.gguf_writer.add_block_count(num_layers)
        self.gguf_writer.add_file_type(self.ftype)

        if (rope_theta := self.hparams.get("rope_theta")) is not None:
            self.gguf_writer.add_rope_freq_base(rope_theta)
        rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"], optional=True)
        rotary_factor = rotary_factor if rotary_factor is not None else 1.0
        self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
        if hparams.get("rope_scaling") is not None and "factor" in hparams["rope_scaling"]:
            if hparams["rope_scaling"].get("type") == "linear":
                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
                self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])

    def prepare_tensors(self):
        if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
            if rope_scaling.get("rope_type", '').lower() == "llama3":
                base = self.hparams.get("rope_theta", 10000.0)
                dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
                freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

                factor = rope_scaling.get("factor", 8.0)
                low_freq_factor = rope_scaling.get("low_freq_factor", 1.0)
                high_freq_factor = rope_scaling.get("high_freq_factor", 4.0)
                old_context_len = self.hparams.get("original_max_position_embeddings", 8192)

                low_freq_wavelen = old_context_len / low_freq_factor
                high_freq_wavelen = old_context_len / high_freq_factor
                assert low_freq_wavelen != high_freq_wavelen

                rope_factors = []
                for freq in freqs:
                    wavelen = 2 * math.pi / freq
                    if wavelen < high_freq_wavelen:
                        rope_factors.append(1)
                    elif wavelen > low_freq_wavelen:
                        rope_factors.append(factor)
                    else:
                        smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
                        rope_factors.append(1 / ((1 - smooth) / factor + smooth))

                self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32))

        super().prepare_tensors()
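
As a reading aid, the loop above implements Llama-3-style piecewise RoPE frequency scaling: for a rotary frequency with wavelength $\lambda = 2\pi/\theta$, original context length $L$, scaling factor $s$, and low/high frequency factors $f_{\text{low}}$, $f_{\text{high}}$, the stored factor is

$$
\mathrm{scale}(\lambda) =
\begin{cases}
1 & \lambda < L/f_{\text{high}} \\
s & \lambda > L/f_{\text{low}} \\
\left(\dfrac{1-\alpha}{s} + \alpha\right)^{-1} & \text{otherwise, with } \alpha = \dfrac{L/\lambda - f_{\text{low}}}{f_{\text{high}} - f_{\text{low}}}
\end{cases}
$$

The code above remains the source of truth; this is the same computation written out.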

###### CONVERSION LOGIC ######

@@ -96,6 +96,7 @@ models = [

    {"name": "smollm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/HuggingFaceTB/SmolLM-135M", },
    {'name': "bloom", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigscience/bloom", },
    {'name': "gpt3-finnish", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/TurkuNLP/gpt3-finnish-small", },
    {"name": "exaone", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", },
]

@@ -116,7 +116,7 @@ class Tensor:

        assert quant is not None, 'Unknown tensor type'
        (blksize, tysize) = quant
        offset += 12
        self.dtype = dtype
        self.dtype = gguf.GGMLQuantizationType(dtype)
        self.dims = struct.unpack(f'<{n_dims}I', data[offset:offset + (4 * n_dims)])
        offset += 4 * n_dims
        self.name = bytes(data[offset:offset + name_len])
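
The change above stores the parsed dtype as a `gguf.GGMLQuantizationType` enum instead of a bare integer, which gives a readable name and a loud failure on unknown type ids. A small illustration; the numeric value used here assumes ggml's usual type numbering (e.g. `Q4_0 == 2`):

```python
import gguf

raw_dtype = 2                                 # value as read from the tensor header
dtype = gguf.GGMLQuantizationType(raw_dtype)  # e.g. GGMLQuantizationType.Q4_0
print(dtype.name)                             # readable name instead of a bare int

try:
    gguf.GGMLQuantizationType(9999)           # unknown ids raise instead of silently propagating
except ValueError as exc:
    print(exc)
```
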
docs/backend/CANN.md (new file, 259 lines)

@@ -0,0 +1,259 @@

# llama.cpp for CANN

- [Background](#background)
- [News](#news)
- [OS](#os)
- [Hardware](#hardware)
- [Model Supports](#model-supports)
- [DataType Supports](#datatype-supports)
- [Docker](#docker)
- [Linux](#linux)
- [TODO](#todo)


## Background

**Ascend NPU** is a range of AI processors built around a Neural Processing Unit. It efficiently handles matrix-matrix multiplication, dot products and scalar operations.

**CANN** (Compute Architecture for Neural Networks) is a heterogeneous computing architecture for AI scenarios. It supports multiple AI frameworks on top while serving AI processors and programming at the bottom, bridging the gap between the upper and lower layers, and is a key platform for improving the computing efficiency of Ascend AI processors. It also offers an efficient and easy-to-use programming interface for diverse application scenarios, allowing users to rapidly build AI applications and services based on the Ascend platform.

**Llama.cpp + CANN**

The llama.cpp CANN backend is designed to support Ascend NPUs. It uses the AscendC and ACLNN capabilities that are integrated into the CANN Toolkit and kernels to drive the Ascend NPU directly.

## News

- 2024.8
  - Support `Q4_0` and `Q8_0` data types for Ascend NPU.
- 2024.7
  - Create the CANN backend for Ascend NPU.

## OS

| OS    | Status  | Verified                     |
|:-----:|:-------:|:----------------------------:|
| Linux | Support | Ubuntu 22.04, OpenEuler22.03 |


## Hardware

### Ascend NPU

**Verified devices**
| Ascend NPU    | Status  |
|:-------------:|:-------:|
| Atlas 300T A2 | Support |

*Notes:*

- If you have trouble with your Ascend NPU device, please create an issue with the **[CANN]** prefix/tag.
- If you run successfully with your Ascend NPU device, please help update the table above.

## Model Supports

| Model Name | FP16 | Q8_0 | Q4_0 |
|:----------------------------|:-----:|:----:|:----:|
| AquilaChat2-7B | √ | √ | √ |
| Baichuan-7b | √ | √ | √ |
| Baichuan2-7B-Chat | √ | √ | √ |
| bitnet_b1_58-large | √ | √ | √ |
| bloom-560m | √ | x | √ |
| bloomz-alpaca-560m | √ | x | √ |
| c4ai-command-r-35B-v01 | x | x | x |
| chatglm3-6B | x | x | x |
| chinese-alpaca-2-1.3b | √ | √ | √ |
| CodeShell-7B | √ | √ | √ |
| deepseek-ai_deepseek-coder-1.3B-base | x | x | x |
| deepseek-ai_DeepSeek-V2-Lite | x | x | x |
| deepseek-coder-6.7B-instruct | x | x | x |
| DeepSeek-V2-Lite-64x1.5B | x | x | x |
| falcon-7b-instruct | √ | √ | √ |
| flan-t5-large | √ | √ | √ |
| gemma-2-9b-it | √ | √ | √ |
| glm-4-9B | x | x | x |
| gpt2 | √ | √ | √ |
| Gpt2-163M | √ | √ | √ |
| granite-3B-code-instruct | √ | √ | √ |
| GritLM-7B | √ | √ | √ |
| internlm2_5-7b-chat | √ | √ | √ |
| koala-7B-HF | √ | √ | √ |
| Llama-2-7b-chat-hf | √ | √ | √ |
| Llama-3-Smaug-8B | √ | √ | √ |
| Llama2-Chinese-7b-Chat | √ | √ | √ |
| Llama3-8B | √ | √ | √ |
| Llama3-8b-chinese | √ | √ | √ |
| mamba-130m-hf | √ | √ | √ |
| Mistral-7B-Instruct-v0.2 | √ | √ | √ |
| Mixtral-8x7B-Instruct-v0.1 | x | √ | √ |
| mpt-7B | √ | √ | √ |
| OLMo-1B-hf | √ | √ | √ |
| OpenELM-3B-Instruct | √ | √ | √ |
| Orion-14b-base | √ | √ | √ |
| phi1 | x | x | x |
| phi2 | x | x | x |
| Phi-3-mini-4k-instruct | √ | √ | √ |
| plamo-13b | √ | √ | √ |
| pythia-70M | x | x | x |
| Qwen-7B | √ | √ | √ |
| Qwen2-1.5B-Instruct | √ | x | √ |
| Refact-1_6B-fim | √ | √ | √ |
| SmolLM-135M | √ | √ | √ |
| stablelm-zephyr | x | x | x |
| stablelm-2-zephyr-1_6b | x | x | x |
| starcoderbase-1b | √ | √ | √ |
| starcoder2-3b | √ | √ | √ |
| vigogne-7b-chat | √ | √ | √ |
| xverse-7b-chat | √ | √ | √ |
| Yi-6b-Chat | √ | √ | √ |



## DataType Supports

| DataType | Status  |
|:--------:|:-------:|
| FP16     | Support |
| Q8_0     | Support |
| Q4_0     | Support |

## Docker

### Build Images
You can get an image with llama.cpp in one command.
```sh
docker build -t llama-cpp-cann -f .devops/llama-cli-cann.Dockerfile .
```

### Run container

```sh
# Find all cards.
npu-smi info

# Select the cards that you want to use and make sure they are not in use by someone else.
# The following uses the card of device 0.
docker run --name llamacpp --device /dev/davinci0 --device /dev/davinci_manager --device /dev/devmm_svm --device /dev/hisi_hdc -v /usr/local/dcmi:/usr/local/dcmi -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info -v /PATH_TO_YOUR_MODELS/:/app/models -it llama-cpp-cann -m /app/models/MODEL_PATH -ngl 32 -p "Building a website can be done in 10 simple steps:"
```

*Notes:*

- You may need to install the Ascend Driver and firmware on the **host** machine *(Please refer to the [Linux configuration](#linux) for details)*.

## Linux

### I. Setup Environment

1. **Install Ascend Driver and firmware**

    ```sh
    # create driver running user.
    sudo groupadd -g HwHiAiUser
    sudo useradd -g HwHiAiUser -d /home/HwHiAiUser -m HwHiAiUser -s /bin/bash
    sudo usermod -aG HwHiAiUser $USER

    # download the driver from https://www.hiascend.com/hardware/firmware-drivers/community according to your system
    # and install the driver.
    sudo sh Ascend-hdk-910b-npu-driver_x.x.x_linux-{arch}.run --full --install-for-all
    ```

    Once installed, run `npu-smi info` to check whether the driver is installed successfully.
    ```sh
    +-------------------------------------------------------------------------------------------+
    | npu-smi 24.1.rc2                 Version: 24.1.rc2                                         |
    +----------------------+---------------+----------------------------------------------------+
    | NPU   Name           | Health        | Power(W)   Temp(C)          Hugepages-Usage(page)  |
    | Chip                 | Bus-Id        | AICore(%)  Memory-Usage(MB) HBM-Usage(MB)          |
    +======================+===============+====================================================+
    | 2     xxx            | OK            | 64.4       51               15 / 15                |
    | 0                    | 0000:01:00.0  | 0          1873 / 15077     0 / 32768              |
    +======================+===============+====================================================+
    | 5     xxx            | OK            | 64.0       52               15 / 15                |
    | 0                    | 0000:81:00.0  | 0          1874 / 15077     0 / 32768              |
    +======================+===============+====================================================+
    | No running processes found in NPU 2                                                        |
    +======================+===============+====================================================+
    | No running processes found in NPU 5                                                        |
    +======================+===============+====================================================+
    ```

2. **Install Ascend Firmware**
    ```sh
    # download the firmware from https://www.hiascend.com/hardware/firmware-drivers/community according to your system
    # and install the firmware.
    sudo sh Ascend-hdk-910b-npu-firmware_x.x.x.x.X.run --full
    ```
    If the following message appears, the firmware has been installed successfully.
    ```sh
    Firmware package installed successfully!
    ```


3. **Install CANN toolkit and kernels**

    The CANN toolkit and kernels can be obtained from the official [CANN Toolkit](https://www.hiascend.com/zh/developer/download/community/result?module=cann) page.

    Please download the version that matches your system. The minimum required version is 8.0.RC2.alpha002; the install commands are:
    ```sh
    pip3 install attrs numpy decorator sympy cffi pyyaml pathlib2 psutil protobuf scipy requests absl-py wheel typing_extensions
    sh Ascend-cann-toolkit_8.0.RC2.alpha002_linux-aarch64.run --install
    sh Ascend-cann-kernels-910b_8.0.RC2.alpha002_linux.run --install
    ```

    Set the Ascend variables:
    ```sh
    echo "source ~/Ascend/ascend-toolkit/set_env.sh" >> ~/.bashrc
    source ~/.bashrc
    ```

    Upon a successful installation, CANN is enabled for the available Ascend devices.

### II. Build llama.cpp

```sh
cmake -B build -DGGML_CANN=on -DCMAKE_BUILD_TYPE=release
cmake --build build --config release
```

### III. Run the inference

1. **Retrieve and prepare model**

    You can refer to the general [*Prepare and Quantize*](../../README.md#prepare-and-quantize) guide for model preparation.

    **Notes**:

    - The CANN backend currently supports only FP16/Q4_0/Q8_0 models.

2. **Launch inference**

    There are two device selection modes:

    - Single device: use the one device specified by the user.
    - Multiple devices: automatically choose the devices with the same backend.

    | Device selection | Parameter                              |
    |:----------------:|:--------------------------------------:|
    | Single device    | --split-mode none --main-gpu DEVICE_ID |
    | Multiple devices | --split-mode layer (default)           |

    Examples:

    - Use device 0:

    ```sh
    ./build/bin/llama-cli -m path_to_model -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33 -sm none -mg 0
    ```

    - Use multiple devices:

    ```sh
    ./build/bin/llama-cli -m path_to_model -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33 -sm layer
    ```

### **GitHub contribution**:
Please add the **[CANN]** prefix/tag in issue/PR titles to help the CANN team check and address them without delay.


## TODO
- Support more models and data types.

@@ -352,6 +352,31 @@ cmake --build build --config Release

# ggml_vulkan: Using Intel(R) Graphics (ADL GT2) | uma: 1 | fp16: 1 | warp size: 32
```

### CANN
This provides NPU acceleration using the AI cores of your Ascend NPU. [CANN](https://www.hiascend.com/en/software/cann) is a hierarchical set of APIs that helps you quickly build AI applications and services based on the Ascend NPU.

For more information about the Ascend NPU, see the [Ascend Community](https://www.hiascend.com/en/).

Make sure to have the CANN toolkit installed. You can download it from here: [CANN Toolkit](https://www.hiascend.com/developer/download/community/result?module=cann)

Go to the `llama.cpp` directory and build using CMake.
```bash
cmake -B build -DGGML_CANN=on -DCMAKE_BUILD_TYPE=release
cmake --build build --config release
```

You can test with:

`./build/llama-cli -m PATH_TO_MODEL -p "Building a website can be done in 10 steps:" -ngl 32`

If the following info is printed on screen, you are using the `llama.cpp` CANN backend:
```bash
llm_load_tensors: CANN buffer size = 13313.00 MiB
llama_new_context_with_model: CANN compute buffer size = 1260.81 MiB
```

For detailed info, such as supported models/devices and CANN installation, please refer to [llama.cpp for CANN](./backend/CANN.md).

### Android

To read documentation for how to build on Android, [click here](./android.md)

@@ -16,8 +16,8 @@ Convert PyTorch model to gguf files (You can also download the converted [gguf](

```bash
python ./examples/minicpmv/minicpmv-surgery.py -m ../MiniCPM-Llama3-V-2_5
python ./examples/minicpmv/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-Llama3-V-2_5 --minicpmv-projector ../MiniCPM-Llama3-V-2_5/minicpmv.projector --output-dir ../MiniCPM-Llama3-V-2_5/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5
python ./convert-hf-to-gguf.py ../MiniCPM-Llama3-V-2_5/model
python ./examples/minicpmv/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-Llama3-V-2_5 --minicpmv-projector ../MiniCPM-Llama3-V-2_5/minicpmv.projector --output-dir ../MiniCPM-Llama3-V-2_5/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 --minicpmv_version 2
python ./convert_hf_to_gguf.py ../MiniCPM-Llama3-V-2_5/model

# quantize int4 version
./llama-quantize ../MiniCPM-Llama3-V-2_5/model/model-8B-F16.gguf ../MiniCPM-Llama3-V-2_5/model/ggml-model-Q4_K_M.gguf Q4_K_M

examples/llava/README-minicpmv2.6.md (new file, 107 lines)

@@ -0,0 +1,107 @@

## MiniCPM-V 2.6

### Prepare models and code

Download the [MiniCPM-V-2_6](https://huggingface.co/openbmb/MiniCPM-V-2_6) PyTorch model from Hugging Face into the "MiniCPM-V-2_6" folder.

Clone llama.cpp:
```bash
git clone git@github.com:OpenBMB/llama.cpp.git
cd llama.cpp
git checkout minicpmv-main
```

### Usage of MiniCPM-V 2.6

Convert the PyTorch model to gguf files (you can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-V-2_6-gguf) provided by us):

```bash
python ./examples/llava/minicpmv-surgery.py -m ../MiniCPM-V-2_6
python ./examples/llava/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-V-2_6 --minicpmv-projector ../MiniCPM-V-2_6/minicpmv.projector --output-dir ../MiniCPM-V-2_6/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 --minicpmv_version 3
python ./convert_hf_to_gguf.py ../MiniCPM-V-2_6/model

# quantize int4 version
./llama-quantize ../MiniCPM-V-2_6/model/ggml-model-f16.gguf ../MiniCPM-V-2_6/model/ggml-model-Q4_K_M.gguf Q4_K_M
```

Build for Linux or Mac

```bash
make
make llama-minicpmv-cli
```

Inference on Linux or Mac
```
# run f16 version
./llama-minicpmv-cli -m ../MiniCPM-V-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"

# run quantized int4 version
./llama-minicpmv-cli -m ../MiniCPM-V-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"

# or run in interactive mode
./llama-minicpmv-cli -m ../MiniCPM-V-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -i
```

### Video
Install FFmpeg
```
brew install ffmpeg
brew install pkg-config
```

### Android

#### Build on Android device using Termux
We have found that building on the Android device gives better runtime performance, so we recommend building on the device.

[Termux](https://github.com/termux/termux-app#installation) is a terminal app for Android devices (no root required).

Install tools in Termux:
```
apt update && apt upgrade -y
apt install git make cmake
```

It's recommended to move your model inside the `~/` directory for best performance:
```
cd storage/downloads
mv model.gguf ~/
```

#### Building the Project using Android NDK
Obtain the [Android NDK](https://developer.android.com/ndk) and then build with CMake.

Execute the following commands on your computer to avoid downloading the NDK to your mobile device. Alternatively, you can also do this in Termux:

```bash
mkdir build-android
cd build-android
export NDK=/your_ndk_path
cmake -DCMAKE_TOOLCHAIN_FILE=$NDK/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=android-23 -DCMAKE_C_FLAGS=-march=armv8.4a+dotprod ..
make
```

Install [termux](https://github.com/termux/termux-app#installation) on your device and run `termux-setup-storage` to get access to your SD card (if Android 11+ then run the command twice).

Finally, copy these built `llama` binaries and the model file to your device storage. Because the file permissions in the Android sdcard cannot be changed, you can copy the executable files to the `/data/data/com.termux/files/home/bin` path, and then execute the following commands in Termux to add executable permission:

(Assuming that you have pushed the built executable files to the /sdcard/llama.cpp/bin path using `adb push`)
```
$cp -r /sdcard/llama.cpp/bin /data/data/com.termux/files/home/
$cd /data/data/com.termux/files/home/bin
$chmod +x ./*
```

Download models and push them to `/sdcard/llama.cpp/`, then move them to `/data/data/com.termux/files/home/model/`

```
$mv /sdcard/llama.cpp/ggml-model-Q4_K_M.gguf /data/data/com.termux/files/home/model/
$mv /sdcard/llama.cpp/mmproj-model-f16.gguf /data/data/com.termux/files/home/model/
```

Now, you can start chatting:
```
$cd /data/data/com.termux/files/home/bin
$./llama-minicpmv-cli -m ../model/ggml-model-Q4_K_M.gguf --mmproj ../model/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
```

@ -85,6 +85,7 @@ static std::string format(const char * fmt, ...) {
|
|||
#define KEY_HAS_VIS_ENC "clip.has_vision_encoder"
|
||||
#define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector"
|
||||
#define KEY_HAS_MINICPMV_PROJ "clip.has_minicpmv_projector"
|
||||
#define KEY_MINICPMV_VERSION "clip.minicpmv_version"
|
||||
#define KEY_USE_GELU "clip.use_gelu"
|
||||
#define KEY_N_EMBD "clip.%s.embedding_length"
|
||||
#define KEY_N_FF "clip.%s.feed_forward_length"
|
||||
|
@ -530,6 +531,7 @@ struct clip_ctx {
|
|||
bool has_vision_encoder = false;
|
||||
bool has_llava_projector = false;
|
||||
bool has_minicpmv_projector = false;
|
||||
int minicpmv_version = 2;
|
||||
|
||||
struct clip_vision_model vision_model;
|
||||
projector_type proj_type = PROJECTOR_TYPE_MLP;
|
||||
|
@ -645,7 +647,12 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
if (ctx->has_minicpmv_projector) {
|
||||
int pos_w = image_size_width/patch_size;
|
||||
int pos_h = image_size_height/patch_size;
|
||||
pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 4096, pos_w * pos_h, 1);
|
||||
if (ctx->minicpmv_version == 2) {
|
||||
pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 4096, pos_w * pos_h, 1);
|
||||
}
|
||||
else if (ctx->minicpmv_version == 3) {
|
||||
pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1);
|
||||
}
|
||||
ggml_set_name(pos_embed, "pos_embed");
|
||||
ggml_set_input(pos_embed);
|
||||
}
|
||||
|
@ -772,8 +779,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
embeddings = ggml_gelu(ctx0, embeddings);
|
||||
embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
|
||||
embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
|
||||
|
||||
} else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
|
||||
}
|
||||
else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
|
||||
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
|
||||
embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
|
||||
// ggml_tensor_printf(embeddings, "mm_0_w",0,true,false);
|
||||
|
@ -953,10 +960,20 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
|||
}
|
||||
|
||||
{ // attention
|
||||
const int hidden_size = 4096;
|
||||
int hidden_size = 4096;
|
||||
const int d_head = 128;
|
||||
const int n_head = hidden_size/d_head;
|
||||
const int num_query = 96;
|
||||
int n_head = hidden_size/d_head;
|
||||
int num_query = 96;
|
||||
if (ctx->minicpmv_version == 2) {
|
||||
hidden_size = 4096;
|
||||
n_head = hidden_size/d_head;
|
||||
num_query = 96;
|
||||
}
|
||||
else if (ctx->minicpmv_version == 3) {
|
||||
hidden_size = 3584;
|
||||
n_head = hidden_size/d_head;
|
||||
num_query = 64;
|
||||
}
|
||||
|
||||
struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b);
|
||||
Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
|
||||
|
@ -1157,6 +1174,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
|||
new_clip->has_minicpmv_projector = gguf_get_val_bool(ctx, idx);
|
||||
}
|
||||
|
||||
idx = gguf_find_key(ctx, KEY_MINICPMV_VERSION);
|
||||
if (idx != -1) {
|
||||
new_clip->minicpmv_version = gguf_get_val_i32(ctx, idx);
|
||||
}
|
||||
|
||||
// GGML_ASSERT(new_clip->has_llava_projector); // see monatis/clip.cpp for image and/or text encoding for semantic search
|
||||
|
||||
GGML_ASSERT(new_clip->has_vision_encoder);
|
||||
|
@ -1918,10 +1940,12 @@ int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) {
|
|||
// returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector
|
||||
// res_imgs memory is being allocated here, previous allocations will be freed if found
|
||||
bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32_batch * res_imgs) {
|
||||
if (clip_is_minicpmv(ctx)) {
|
||||
std::vector<std::vector<clip_image_u8 *>> imgs = uhd_slice_image(img);
|
||||
|
||||
if(clip_is_minicpmv(ctx)){
|
||||
int max_slice_nums = 9;
|
||||
std::vector<std::vector<clip_image_u8 *>> imgs = uhd_slice_image(img, max_slice_nums);
|
||||
res_imgs->size = 0;
|
||||
for (size_t i = 0; i < imgs.size(); ++i) {
|
||||
for (size_t i = 0; i < imgs.size(); ++i){
|
||||
res_imgs->size += imgs[i].size();
|
||||
}
|
||||
res_imgs->data = new clip_image_f32[res_imgs->size];
|
||||
|
@ -2154,7 +2178,12 @@ int clip_n_patches(const struct clip_ctx * ctx) {
|
|||
if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2) {
|
||||
n_patches /= 4;
|
||||
} else if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
|
||||
n_patches = 96;
|
||||
if (ctx->minicpmv_version == 2) {
|
||||
n_patches = 96;
|
||||
}
|
||||
else if (ctx->minicpmv_version == 3) {
|
||||
n_patches = 64;
|
||||
}
|
||||
}
|
||||
|
||||
return n_patches;
|
||||
|
@ -2290,6 +2319,11 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
const int patch_size = hparams.patch_size;
|
||||
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
|
||||
const int num_positions = num_patches + (ctx->has_class_embedding ? 1 : 0);
|
||||
if(ctx->load_image_size==nullptr){
|
||||
ctx->load_image_size= clip_image_size_init();
|
||||
}
|
||||
const int pos_w = ctx->load_image_size->width/patch_size;
|
||||
const int pos_h = ctx->load_image_size->height/patch_size;
|
||||
|
||||
{
|
||||
struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
|
||||
|
@ -2324,8 +2358,18 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
// -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316
|
||||
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
|
||||
int* positions_data = (int*)malloc(ggml_nbytes(positions));
|
||||
for (int i = 0; i < num_positions; i++) {
|
||||
positions_data[i] = std::floor(70.0*i/num_positions);
|
||||
int bucket_coords_h[70];
|
||||
int bucket_coords_w[70];
|
||||
for (int i = 0; i < pos_h; i++){
|
||||
bucket_coords_h[i] = std::floor(70.0*i/pos_h);
|
||||
}
|
||||
for (int i = 0; i < pos_w; i++){
|
||||
bucket_coords_w[i] = std::floor(70.0*i/pos_w);
|
||||
}
|
||||
for (int i = 0, id = 0; i < pos_h; i++){
|
||||
for (int j = 0; j < pos_w; j++){
|
||||
positions_data[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j];
|
||||
}
|
||||
}
|
||||
ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
|
||||
free(positions_data);
|
||||
|
@ -2336,12 +2380,13 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
// -> https://huggingface.co/Qwen/Qwen-VL/tree/main
|
||||
// -> https://huggingface.co/Qwen/Qwen-VL/blob/0547ed36a86561e2e42fecec8fd0c4f6953e33c4/visual.py#L23
|
||||
struct ggml_tensor * pos_embed = ggml_graph_get_tensor(gf, "pos_embed");
|
||||
if(ctx->load_image_size==nullptr){
|
||||
ctx->load_image_size= clip_image_size_init();
|
||||
}
|
||||
int pos_w = ctx->load_image_size->width/patch_size;
|
||||
int pos_h = ctx->load_image_size->height/patch_size;
|
||||
int embed_dim = 4096;
|
||||
if (ctx->minicpmv_version == 2) {
|
||||
embed_dim = 4096;
|
||||
}
|
||||
else if (ctx->minicpmv_version == 3) {
|
||||
embed_dim = 3584;
|
||||
}
|
||||
auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h));
|
||||
|
||||
float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed));
|
||||
|
@ -2354,7 +2399,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||
ggml_backend_tensor_set(pos_embed, pos_embed_data, 0, ggml_nbytes(pos_embed));
|
||||
free(pos_embed_data);
|
||||
}
|
||||
} else {
|
||||
}
|
||||
else{
|
||||
{
|
||||
if (ctx->has_class_embedding) {
|
||||
struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings");
|
||||
|
@ -2556,13 +2602,21 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
|||
return ctx->vision_model.mm_3_b->ne[0];
|
||||
}
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
|
||||
return 4096;
|
||||
if (ctx->minicpmv_version == 2) {
|
||||
return 4096;
|
||||
}
|
||||
else if (ctx->minicpmv_version == 3) {
|
||||
return 3584;
|
||||
}
|
||||
}
|
||||
|
||||
std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type];
|
||||
throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
|
||||
}
|
||||
|
||||
bool clip_is_minicpmv(const struct clip_ctx * ctx) {
|
||||
return ctx->has_minicpmv_projector;
|
||||
int clip_is_minicpmv(const struct clip_ctx * ctx) {
|
||||
if (ctx->has_minicpmv_projector) {
|
||||
return ctx->minicpmv_version;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -85,7 +85,7 @@ CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, cons
|
|||
|
||||
CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype);
|
||||
|
||||
CLIP_API bool clip_is_minicpmv(const struct clip_ctx * ctx);
|
||||
CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -256,7 +256,14 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
|
|||
load_image_size->width = img_res_v.data[i].nx;
|
||||
load_image_size->height = img_res_v.data[i].ny;
|
||||
clip_add_load_image_size(ctx_clip, load_image_size);
|
||||
const bool encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
|
||||
bool encoded = false;
|
||||
int has_minicpmv_projector = clip_is_minicpmv(ctx_clip);
|
||||
if (has_minicpmv_projector == 2) {
|
||||
encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]);
|
||||
}
|
||||
else if (has_minicpmv_projector == 3) {
|
||||
encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]);
|
||||
}
|
||||
if (!encoded) {
|
||||
LOG_TEE("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) img_res_v.size);
|
||||
return false;
|
||||
|
|
|
@ -134,7 +134,13 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e
|
|||
std::string system_prompt;
|
||||
int idx = 0;
|
||||
int num_image_embeds = embeds->n_image_pos / clip_n_patches(ctx_llava->ctx_clip);
|
||||
system_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n";
|
||||
int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip);
|
||||
if (has_minicpmv_projector == 2) {
|
||||
system_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n";
|
||||
}
|
||||
else if (has_minicpmv_projector == 3) {
|
||||
system_prompt = "<|im_start|>user\n";
|
||||
}
|
||||
LOG_TEE("%s: image token past: %d\n", __func__, n_past);
|
||||
eval_string(ctx_llava->ctx_llama, (system_prompt+"<image>").c_str(), params->n_batch, &n_past, false);
|
||||
process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++);
|
||||
|
@ -210,10 +216,24 @@ static struct llava_context * minicpmv_init(gpt_params * params, const std::stri
|
|||
|
||||
static struct llama_sampling_context * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){
|
||||
std::string user_prompt = prompt;
|
||||
if (!is_first) user_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + prompt;
|
||||
int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip);
|
||||
if (!is_first) {
|
||||
if (has_minicpmv_projector == 2) {
|
||||
user_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + prompt;
|
||||
}
|
||||
else if (has_minicpmv_projector == 3) {
|
||||
user_prompt = "<|im_start|>user\n" + prompt;
|
||||
}
|
||||
}
|
||||
|
||||
eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false);
|
||||
eval_string(ctx_llava->ctx_llama, "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", params->n_batch, &n_past, false);
|
||||
if (has_minicpmv_projector == 2) {
|
||||
eval_string(ctx_llava->ctx_llama, "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", params->n_batch, &n_past, false);
|
||||
}
|
||||
else if (has_minicpmv_projector == 3) {
|
||||
eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false);
|
||||
}
|
||||
|
||||
// generate the response
|
||||
|
||||
LOG_TEE("\n");
|
||||
|
|
|
@ -1,9 +1,416 @@
|
|||
import argparse
|
||||
# coding=utf-8
|
||||
# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch Siglip model. """
|
||||
# Copied from HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit and add tgt_sizes
|
||||
|
||||
|
||||
import os
|
||||
import math
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn.init import _calculate_fan_in_and_fan_out
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.utils import (
|
||||
logging,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
class SiglipVisionConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
|
||||
Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
|
||||
configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
|
||||
[google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
Args:
|
||||
hidden_size (`int`, *optional*, defaults to 768):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
intermediate_size (`int`, *optional*, defaults to 3072):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 12):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 12):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_channels (`int`, *optional*, defaults to 3):
|
||||
Number of channels in the input images.
|
||||
image_size (`int`, *optional*, defaults to 224):
|
||||
The size (resolution) of each image.
|
||||
patch_size (`int`, *optional*, defaults to 16):
|
||||
The size (resolution) of each patch.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the layer normalization layers.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
Example:
|
||||
```python
|
||||
>>> from transformers import SiglipVisionConfig, SiglipVisionModel
|
||||
>>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
|
||||
>>> configuration = SiglipVisionConfig()
|
||||
>>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
|
||||
>>> model = SiglipVisionModel(configuration)
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "siglip_vision_model"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size=768,
|
||||
intermediate_size=3072,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
num_channels=3,
|
||||
image_size=224,
|
||||
patch_size=16,
|
||||
hidden_act="gelu_pytorch_tanh",
|
||||
layer_norm_eps=1e-6,
|
||||
attention_dropout=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_channels = num_channels
|
||||
self.patch_size = patch_size
|
||||
self.image_size = image_size
|
||||
self.attention_dropout = attention_dropout
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.hidden_act = hidden_act
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
|
||||
|
||||
SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"google/siglip-base-patch16-224",
|
||||
# See all SigLIP models at https://huggingface.co/models?filter=siglip
|
||||
]
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
||||
def _get_unpad_data(attention_mask):
|
||||
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
||||
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
||||
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
||||
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
||||
return (
|
||||
indices,
|
||||
cu_seqlens,
|
||||
max_seqlen_in_batch,
|
||||
)
|
||||
|
||||
|
||||
def _trunc_normal_(tensor, mean, std, a, b):
|
||||
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
||||
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
||||
def norm_cdf(x):
|
||||
# Computes standard normal cumulative distribution function
|
||||
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
||||
|
||||
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
||||
warnings.warn(
|
||||
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
||||
"The distribution of values may be incorrect.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# Values are generated by using a truncated uniform distribution and
|
||||
# then using the inverse CDF for the normal distribution.
|
||||
# Get upper and lower cdf values
|
||||
l = norm_cdf((a - mean) / std)
|
||||
u = norm_cdf((b - mean) / std)
|
||||
|
||||
# Uniformly fill tensor with values from [l, u], then translate to
|
||||
# [2l-1, 2u-1].
|
||||
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
||||
|
||||
# Use inverse cdf transform for normal distribution to get truncated
|
||||
# standard normal
|
||||
if tensor.dtype in [torch.float16, torch.bfloat16]:
|
||||
# The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu
|
||||
og_dtype = tensor.dtype
|
||||
tensor = tensor.to(torch.float32)
|
||||
tensor.erfinv_()
|
||||
tensor = tensor.to(og_dtype)
|
||||
else:
|
||||
tensor.erfinv_()
|
||||
|
||||
# Transform to proper mean, std
|
||||
tensor.mul_(std * math.sqrt(2.0))
|
||||
tensor.add_(mean)
|
||||
|
||||
# Clamp to ensure it's in the proper range
|
||||
if tensor.dtype == torch.float16:
|
||||
# The `clamp_` op is not (yet?) defined in float16+cpu
|
||||
tensor = tensor.to(torch.float32)
|
||||
tensor.clamp_(min=a, max=b)
|
||||
tensor = tensor.to(torch.float16)
|
||||
else:
|
||||
tensor.clamp_(min=a, max=b)
|
||||
|
||||
|
||||
def trunc_normal_tf_(
|
||||
tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
|
||||
):
|
||||
"""Fills the input Tensor with values drawn from a truncated
|
||||
normal distribution. The values are effectively drawn from the
|
||||
normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
|
||||
with values outside :math:`[a, b]` redrawn until they are within
|
||||
the bounds. The method used for generating the random values works
|
||||
best when :math:`a \\leq \text{mean} \\leq b`.
|
||||
NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
|
||||
bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
|
||||
and the result is subsequently scaled and shifted by the mean and std args.
|
||||
Args:
|
||||
tensor: an n-dimensional `torch.Tensor`
|
||||
mean: the mean of the normal distribution
|
||||
std: the standard deviation of the normal distribution
|
||||
a: the minimum cutoff value
|
||||
b: the maximum cutoff value
|
||||
"""
|
||||
with torch.no_grad():
|
||||
_trunc_normal_(tensor, 0, 1.0, a, b)
|
||||
tensor.mul_(std).add_(mean)
|
||||
|
||||
|
||||
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
|
||||
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
|
||||
denom = fan_in
|
||||
if mode == "fan_in":
|
||||
denom = fan_in
|
||||
elif mode == "fan_out":
|
||||
denom = fan_out
|
||||
elif mode == "fan_avg":
|
||||
denom = (fan_in + fan_out) / 2
|
||||
|
||||
variance = scale / denom
|
||||
|
||||
if distribution == "truncated_normal":
|
||||
# constant is stddev of standard normal truncated to (-2, 2)
|
||||
trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
|
||||
elif distribution == "normal":
|
||||
with torch.no_grad():
|
||||
tensor.normal_(std=math.sqrt(variance))
|
||||
elif distribution == "uniform":
|
||||
bound = math.sqrt(3 * variance)
|
||||
with torch.no_grad():
|
||||
tensor.uniform_(-bound, bound)
|
||||
else:
|
||||
raise ValueError(f"invalid distribution {distribution}")
|
||||
|
||||
|
||||
def lecun_normal_(tensor):
|
||||
variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
|
||||
|
||||
|
||||
def default_flax_embed_init(tensor):
|
||||
variance_scaling_(tensor, mode="fan_in", distribution="normal")
|
||||
|
||||
class SiglipVisionEmbeddings(nn.Module):
|
||||
def __init__(self, config: SiglipVisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
self.patch_embedding = nn.Conv2d(
|
||||
in_channels=config.num_channels,
|
||||
out_channels=self.embed_dim,
|
||||
kernel_size=self.patch_size,
|
||||
stride=self.patch_size,
|
||||
padding="valid",
|
||||
)
|
||||
|
||||
self.num_patches_per_side = self.image_size // self.patch_size
|
||||
self.num_patches = self.num_patches_per_side**2
|
||||
self.num_positions = self.num_patches
|
||||
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
||||
|
||||
class SiglipAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
# Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
if self.head_dim * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
||||
f" {self.num_heads})."
|
||||
)
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
|
||||
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)


# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
class SiglipMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)


# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip
class SiglipEncoderLayer(nn.Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self.self_attn = (
            SiglipAttention(config)
        )
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)


class SiglipPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = SiglipVisionConfig
    base_model_prefix = "siglip"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""

        if isinstance(module, SiglipVisionEmbeddings):
            width = self.config.hidden_size
            nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
        elif isinstance(module, nn.Embedding):
            default_flax_embed_init(module.weight)
        elif isinstance(module, SiglipAttention):
            nn.init.normal_(module.q_proj.weight)
            nn.init.normal_(module.k_proj.weight)
            nn.init.normal_(module.v_proj.weight)
            nn.init.normal_(module.out_proj.weight)
            nn.init.zeros_(module.q_proj.bias)
            nn.init.zeros_(module.k_proj.bias)
            nn.init.zeros_(module.v_proj.bias)
            nn.init.zeros_(module.out_proj.bias)
        elif isinstance(module, SiglipMLP):
            nn.init.normal_(module.fc1.weight)
            nn.init.normal_(module.fc2.weight)
            nn.init.normal_(module.fc1.bias, std=1e-6)
            nn.init.normal_(module.fc2.bias, std=1e-6)
        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            lecun_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


SIGLIP_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)
    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.
    Parameters:
        config ([`SiglipVisionConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


SIGLIP_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip
class SiglipEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`SiglipEncoderLayer`].
    Args:
        config: SiglipConfig
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False


class SiglipVisionTransformer(SiglipPreTrainedModel):
    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"
    _supports_flash_attn_2 = True

    def __init__(self, config: SiglipVisionConfig):
        super().__init__(config)
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.embeddings.patch_embedding


import argparse
import json
import re

import torch
import numpy as np
from gguf import *
from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer, Idefics2VisionConfig
|
||||
|
@ -94,6 +501,7 @@ default_image_mean = [0.48145466, 0.4578275, 0.40821073]
|
|||
default_image_std = [0.26862954, 0.26130258, 0.27577711]
|
||||
ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None)
|
||||
ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None)
|
||||
ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3', default=2)
|
||||
|
||||
# with proper
|
||||
args = ap.parse_args()
|
||||
|
@ -135,6 +543,15 @@ if args.use_f32:
|
|||
# model = CLIPModel.from_pretrained(dir_model)
|
||||
# processor = CLIPProcessor.from_pretrained(dir_model)
|
||||
|
||||
minicpmv_version = args.minicpmv_version
|
||||
emb_dim = 4096
|
||||
if minicpmv_version == 1:
|
||||
emb_dim = 2304
|
||||
elif minicpmv_version == 2:
|
||||
emb_dim = 4096
|
||||
elif minicpmv_version == 3:
|
||||
emb_dim = 3584
|
||||
|
||||
default_vision_config = {
|
||||
"hidden_size": 1152,
|
||||
"image_size": 980,
|
||||
|
@ -144,8 +561,12 @@ default_vision_config = {
|
|||
"num_hidden_layers": 27,
|
||||
"patch_size": 14,
|
||||
}
|
||||
|
||||
vision_config = Idefics2VisionConfig(**default_vision_config)
|
||||
model = Idefics2VisionTransformer(vision_config)
|
||||
if minicpmv_version == 3:
|
||||
vision_config = SiglipVisionConfig(**default_vision_config)
|
||||
model = SiglipVisionTransformer(vision_config)
|
||||
|
||||
processor = None
|
||||
# if model.attn_pool is not None:
|
||||
|
@ -158,6 +579,7 @@ fname_middle = None
|
|||
has_text_encoder = True
|
||||
has_vision_encoder = True
|
||||
has_minicpmv_projector = False
|
||||
|
||||
if args.text_only:
|
||||
fname_middle = "text-"
|
||||
has_vision_encoder = False
|
||||
|
@ -165,6 +587,7 @@ elif args.minicpmv_projector is not None:
|
|||
fname_middle = "mmproj-"
|
||||
has_text_encoder = False
|
||||
has_minicpmv_projector = True
|
||||
minicpmv_version = 3
|
||||
elif args.vision_only:
|
||||
fname_middle = "vision-"
|
||||
has_text_encoder = False
|
||||
|
@ -189,6 +612,7 @@ elif has_minicpmv_projector:
|
|||
fout.add_description("image encoder for MiniCPM-V")
|
||||
# add projector type
|
||||
fout.add_string("clip.projector_type", "resampler")
|
||||
fout.add_int32("clip.minicpmv_version", minicpmv_version)
|
||||
else:
|
||||
fout.add_description("two-tower CLIP model")
|
||||
|
||||
|
@ -274,11 +698,11 @@ def _replace_name_resampler(s, v):
|
|||
if re.match("resampler.pos_embed", s):
|
||||
return {
|
||||
s: v,
|
||||
re.sub("pos_embed", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(4096, (70, 70))),
|
||||
re.sub("pos_embed", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(emb_dim, (70, 70))),
|
||||
}
|
||||
if re.match("resampler.proj", s):
|
||||
return {
|
||||
re.sub("proj", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(4096, (70, 70))),
|
||||
re.sub("proj", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(emb_dim, (70, 70))),
|
||||
re.sub("proj", "proj.weight", s): v.transpose(-1, -2).contiguous(),
|
||||
}
|
||||
if re.match("resampler.attn.in_proj_.*", s):
|
||||
|
|
|
@ -4,7 +4,7 @@ import torch
|
|||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("-m", "--model", help="Path to MiniCPM-V-2.5 model")
|
||||
ap.add_argument("-m", "--model", help="Path to MiniCPM-V model")
|
||||
args = ap.parse_args()
|
||||
|
||||
# find the model part that includes the multimodal projector weights
|
||||
|
@ -29,7 +29,6 @@ if len(clip_tensors) > 0:
|
|||
f.write("{}\n")
|
||||
|
||||
config = model.llm.config
|
||||
config._name_or_path = "openbmb/MiniCPM-Llama3-V-2.5"
|
||||
config.auto_map = {
|
||||
"AutoConfig": "configuration_minicpm.MiniCPMConfig",
|
||||
"AutoModel": "modeling_minicpm.MiniCPMModel",
|
||||
|
@ -40,7 +39,6 @@ config.auto_map = {
|
|||
model.llm.save_pretrained(f"{args.model}/model")
|
||||
tok = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
|
||||
tok.save_pretrained(f"{args.model}/model")
|
||||
# os.system(f"cp {args.model}/modeling_minicpm.py {args.model}/MiniCPM_l3/modeling_minicpm.py")
|
||||
|
||||
print("Done!")
|
||||
print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.")
|
||||
|
|
|
@ -34,7 +34,7 @@ Run the quantized model:
|
|||
|
||||
```bash
|
||||
# start inference on a gguf model
|
||||
./llama-cli -m ./models/mymodel/ggml-model-Q4_K_M.gguf -n 128
|
||||
./llama-cli -m ./models/mymodel/ggml-model-Q4_K_M.gguf -cnv -p "You are a helpful assistant"
|
||||
```
|
||||
|
||||
When running the larger models, make sure you have enough disk space to store all the intermediate files.
|
||||
|
|
|
@ -368,15 +368,16 @@ node index.js
|
|||
|
||||
## API Endpoints
|
||||
|
||||
### GET `/health`: Returns the current state of the server
|
||||
### GET `/health`: Returns health check result
|
||||
|
||||
- 503 -> `{"status": "loading model"}` if the model is still being loaded.
|
||||
- 500 -> `{"status": "error"}` if the model failed to load.
|
||||
- 200 -> `{"status": "ok", "slots_idle": 1, "slots_processing": 2 }` if the model is successfully loaded and the server is ready for further requests mentioned below.
|
||||
- 200 -> `{"status": "no slot available", "slots_idle": 0, "slots_processing": 32}` if no slots are currently available.
|
||||
- 503 -> `{"status": "no slot available", "slots_idle": 0, "slots_processing": 32}` if the query parameter `fail_on_no_slot` is provided and no slots are currently available.
|
||||
**Response format**
|
||||
|
||||
If the query parameter `include_slots` is passed, `slots` field will contain internal slots data except if `--slots-endpoint-disable` is set.
|
||||
- HTTP status code 503
|
||||
- Body: `{"error": {"code": 503, "message": "Loading model", "type": "unavailable_error"}}`
|
||||
- Explanation: the model is still being loaded.
|
||||
- HTTP status code 200
|
||||
- Body: `{"status": "ok" }`
|
||||
- Explanation: the model is successfully loaded and the server is ready.
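For example, a client can poll this endpoint until the model has finished loading; this is a sketch using only the Python standard library and assuming the default `127.0.0.1:8080` address:

```python
import json, time, urllib.error, urllib.request

def wait_until_healthy(base_url="http://127.0.0.1:8080", timeout=60):
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            with urllib.request.urlopen(f"{base_url}/health") as r:
                if r.status == 200 and json.load(r).get("status") == "ok":
                    return True
        except urllib.error.HTTPError as e:
            if e.code != 503:      # 503 means the model is still being loaded
                raise
        except urllib.error.URLError:
            pass                   # server not reachable yet
        time.sleep(0.5)
    return False
```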
|
||||
|
||||
### POST `/completion`: Given a `prompt`, it returns the predicted completion.
|
||||
|
||||
|
@ -639,10 +640,16 @@ Given a ChatML-formatted json description in `messages`, it returns the predicte
|
|||
}'
|
||||
```
|
||||
|
||||
### GET `/slots`: Returns the current slots processing state. Can be disabled with `--slots-endpoint-disable`.
|
||||
### GET `/slots`: Returns the current slots processing state
|
||||
|
||||
This endpoint can be disabled with `--no-slots`
|
||||
|
||||
If the query param `?fail_on_no_slot=1` is set, this endpoint will respond with status code 503 if there are no available slots.
|
||||
|
||||
**Response format**
|
||||
|
||||
Example:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
|
@ -702,7 +709,13 @@ Given a ChatML-formatted json description in `messages`, it returns the predicte
|
|||
]
|
||||
```
|
||||
|
||||
### GET `/metrics`: Prometheus compatible metrics exporter endpoint if `--metrics` is enabled:
|
||||
Possible values for `slot[i].state` are:
|
||||
- `0`: SLOT_STATE_IDLE
|
||||
- `1`: SLOT_STATE_PROCESSING
|
||||
|
||||
### GET `/metrics`: Prometheus compatible metrics exporter
|
||||
|
||||
This endpoint is only accessible if `--metrics` is set.
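As an illustrative sketch (default address assumed, server started with `--metrics`), the plain-text Prometheus exposition can be scraped directly:

```python
import urllib.request

with urllib.request.urlopen("http://127.0.0.1:8080/metrics") as r:
    text = r.read().decode("utf-8")

for line in text.splitlines():
    if line.startswith("llamacpp:"):
        print(line)   # one sample per metric, e.g. the counters listed below
```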
|
||||
|
||||
Available metrics:
|
||||
- `llamacpp:prompt_tokens_total`: Number of prompt tokens processed.
|
||||
|
@ -767,6 +780,10 @@ Available metrics:
|
|||
|
||||
### GET `/lora-adapters`: Get list of all LoRA adapters
|
||||
|
||||
This endpoint returns the loaded LoRA adapters. You can add adapters using `--lora` when starting the server, for example: `--lora my_adapter_1.gguf --lora my_adapter_2.gguf ...`
|
||||
|
||||
By default, all adapters will be loaded with their scale set to 1. To initialize all adapters with scale set to 0 instead, add `--lora-init-without-apply`.
|
||||
|
||||
If an adapter is disabled, the scale will be set to 0.
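A hedged sketch of reading this endpoint from Python (default address assumed; the exact keys follow the response format documented below):

```python
import json, urllib.request

with urllib.request.urlopen("http://127.0.0.1:8080/lora-adapters") as r:
    adapters = json.load(r)

for adapter in adapters:
    print(adapter.get("id"), adapter.get("scale"))   # scale 0 means the adapter is disabled
```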
|
||||
|
||||
**Response format**
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
// Change JSON_ASSERT from assert() to GGML_ASSERT:
|
||||
#define JSON_ASSERT GGML_ASSERT
|
||||
#include "json.hpp"
|
||||
// mime type for sending response
|
||||
#define MIMETYPE_JSON "application/json; charset=utf-8"
|
||||
|
||||
// auto generated files (update with ./deps.sh)
|
||||
#include "colorthemes.css.hpp"
|
||||
|
@ -67,7 +69,6 @@ enum slot_command {
|
|||
enum server_state {
|
||||
SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
|
||||
SERVER_STATE_READY, // Server is ready and model is loaded
|
||||
SERVER_STATE_ERROR // An error occurred, load_model failed
|
||||
};
|
||||
|
||||
enum server_task_type {
|
||||
|
@ -695,6 +696,7 @@ struct server_context {
|
|||
|
||||
add_bos_token = llama_add_bos_token(model);
|
||||
has_eos_token = !llama_add_eos_token(model);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -2555,19 +2557,19 @@ int main(int argc, char ** argv) {
|
|||
svr->set_default_headers({{"Server", "llama.cpp"}});
|
||||
|
||||
// CORS preflight
|
||||
svr->Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) {
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
svr->Options(R"(.*)", [](const httplib::Request &, httplib::Response & res) {
|
||||
// Access-Control-Allow-Origin is already set by middleware
|
||||
res.set_header("Access-Control-Allow-Credentials", "true");
|
||||
res.set_header("Access-Control-Allow-Methods", "POST");
|
||||
res.set_header("Access-Control-Allow-Headers", "*");
|
||||
return res.set_content("", "application/json; charset=utf-8");
|
||||
return res.set_content("", "text/html"); // blank response, no data
|
||||
});
|
||||
|
||||
svr->set_logger(log_server_request);
|
||||
|
||||
auto res_error = [](httplib::Response & res, json error_data) {
|
||||
json final_response {{"error", error_data}};
|
||||
res.set_content(final_response.dump(), "application/json; charset=utf-8");
|
||||
res.set_content(final_response.dump(), MIMETYPE_JSON);
|
||||
res.status = json_value(error_data, "code", 500);
|
||||
};
|
||||
|
||||
|
@ -2597,11 +2599,6 @@ int main(int argc, char ** argv) {
|
|||
svr->set_read_timeout (params.timeout_read);
|
||||
svr->set_write_timeout(params.timeout_write);
|
||||
|
||||
if (!svr->bind_to_port(params.hostname, params.port)) {
|
||||
fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", params.hostname.c_str(), params.port);
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, std::string> log_data;
|
||||
|
||||
log_data["hostname"] = params.hostname;
|
||||
|
@ -2617,35 +2614,6 @@ int main(int argc, char ** argv) {
|
|||
// Necessary similarity of prompt for slot selection
|
||||
ctx_server.slot_prompt_similarity = params.slot_prompt_similarity;
|
||||
|
||||
// load the model
|
||||
if (!ctx_server.load_model(params)) {
|
||||
state.store(SERVER_STATE_ERROR);
|
||||
return 1;
|
||||
} else {
|
||||
ctx_server.init();
|
||||
state.store(SERVER_STATE_READY);
|
||||
}
|
||||
|
||||
LOG_INFO("model loaded", {});
|
||||
|
||||
const auto model_meta = ctx_server.model_meta();
|
||||
|
||||
// if a custom chat template is not supplied, we will use the one that comes with the model (if any)
|
||||
if (params.chat_template.empty()) {
|
||||
if (!ctx_server.validate_model_chat_template()) {
|
||||
LOG_WARNING("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
|
||||
params.chat_template = "chatml";
|
||||
}
|
||||
}
|
||||
|
||||
// print sample chat example to make it clear which template is used
|
||||
{
|
||||
LOG_INFO("chat template", {
|
||||
{"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)},
|
||||
{"built_in", params.chat_template.empty()},
|
||||
});
|
||||
}
|
||||
|
||||
//
|
||||
// Middlewares
|
||||
//
|
||||
|
@ -2689,8 +2657,6 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// API key is invalid or not provided
|
||||
// TODO: make another middleware for CORS related logic
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION));
|
||||
|
||||
LOG_WARNING("Unauthorized: Invalid API Key", {});
|
||||
|
@ -2698,8 +2664,21 @@ int main(int argc, char ** argv) {
|
|||
return false;
|
||||
};
|
||||
|
||||
auto middleware_server_state = [&res_error, &state](const httplib::Request &, httplib::Response & res) {
|
||||
server_state current_state = state.load();
|
||||
if (current_state == SERVER_STATE_LOADING_MODEL) {
|
||||
res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
// register server middlewares
|
||||
svr->set_pre_routing_handler([&middleware_validate_api_key](const httplib::Request & req, httplib::Response & res) {
|
||||
svr->set_pre_routing_handler([&middleware_validate_api_key, &middleware_server_state](const httplib::Request & req, httplib::Response & res) {
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
if (!middleware_server_state(req, res)) {
|
||||
return httplib::Server::HandlerResponse::Handled;
|
||||
}
|
||||
if (!middleware_validate_api_key(req, res)) {
|
||||
return httplib::Server::HandlerResponse::Handled;
|
||||
}
|
||||
|
@ -2710,62 +2689,15 @@ int main(int argc, char ** argv) {
|
|||
// Route handlers (or controllers)
|
||||
//
|
||||
|
||||
const auto handle_health = [&](const httplib::Request & req, httplib::Response & res) {
|
||||
server_state current_state = state.load();
|
||||
switch (current_state) {
|
||||
case SERVER_STATE_READY:
|
||||
{
|
||||
// request slots data using task queue
|
||||
server_task task;
|
||||
task.id = ctx_server.queue_tasks.get_new_id();
|
||||
task.type = SERVER_TASK_TYPE_METRICS;
|
||||
task.id_target = -1;
|
||||
|
||||
ctx_server.queue_results.add_waiting_task_id(task.id);
|
||||
ctx_server.queue_tasks.post(task);
|
||||
|
||||
// get the result
|
||||
server_task_result result = ctx_server.queue_results.recv(task.id);
|
||||
ctx_server.queue_results.remove_waiting_task_id(task.id);
|
||||
|
||||
const int n_idle_slots = result.data.at("idle");
|
||||
const int n_processing_slots = result.data.at("processing");
|
||||
|
||||
json health = {
|
||||
{"status", "ok"},
|
||||
{"slots_idle", n_idle_slots},
|
||||
{"slots_processing", n_processing_slots}
|
||||
};
|
||||
|
||||
res.status = 200; // HTTP OK
|
||||
if (params.endpoint_slots && req.has_param("include_slots")) {
|
||||
health["slots"] = result.data.at("slots");
|
||||
}
|
||||
|
||||
if (n_idle_slots == 0) {
|
||||
health["status"] = "no slot available";
|
||||
if (req.has_param("fail_on_no_slot")) {
|
||||
res.status = 503; // HTTP Service Unavailable
|
||||
}
|
||||
}
|
||||
|
||||
res.set_content(health.dump(), "application/json");
|
||||
break;
|
||||
}
|
||||
case SERVER_STATE_LOADING_MODEL:
|
||||
{
|
||||
res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
|
||||
} break;
|
||||
case SERVER_STATE_ERROR:
|
||||
{
|
||||
res_error(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER));
|
||||
} break;
|
||||
}
|
||||
const auto handle_health = [&](const httplib::Request &, httplib::Response & res) {
|
||||
// error and loading states are handled by middleware
|
||||
json health = {{"status", "ok"}};
|
||||
res.set_content(health.dump(), "application/json");
|
||||
};
|
||||
|
||||
const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) {
|
||||
const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) {
|
||||
if (!params.endpoint_slots) {
|
||||
res_error(res, format_error_response("This server does not support slots endpoint.", ERROR_TYPE_NOT_SUPPORTED));
|
||||
res_error(res, format_error_response("This server does not support slots endpoint. Start it without `--no-slots`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -2783,13 +2715,22 @@ int main(int argc, char ** argv) {
|
|||
server_task_result result = ctx_server.queue_results.recv(task.id);
|
||||
ctx_server.queue_results.remove_waiting_task_id(task.id);
|
||||
|
||||
res.set_content(result.data.at("slots").dump(), "application/json");
|
||||
// optionally return "fail_on_no_slot" error
|
||||
const int n_idle_slots = result.data.at("idle");
|
||||
if (req.has_param("fail_on_no_slot")) {
|
||||
if (n_idle_slots == 0) {
|
||||
res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
res.set_content(result.data.at("slots").dump(), MIMETYPE_JSON);
|
||||
res.status = 200; // HTTP OK
|
||||
};
|
||||
|
||||
const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
|
||||
if (!params.endpoint_metrics) {
|
||||
res_error(res, format_error_response("This server does not support metrics endpoint.", ERROR_TYPE_NOT_SUPPORTED));
|
||||
res_error(res, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -2914,7 +2855,7 @@ int main(int argc, char ** argv) {
|
|||
if (result.error) {
|
||||
res_error(res, result.data);
|
||||
} else {
|
||||
res.set_content(result.data.dump(), "application/json");
|
||||
res.set_content(result.data.dump(), MIMETYPE_JSON);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -2944,7 +2885,7 @@ int main(int argc, char ** argv) {
|
|||
if (result.error) {
|
||||
res_error(res, result.data);
|
||||
} else {
|
||||
res.set_content(result.data.dump(), "application/json");
|
||||
res.set_content(result.data.dump(), MIMETYPE_JSON);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -2964,13 +2905,11 @@ int main(int argc, char ** argv) {
|
|||
if (result.error) {
|
||||
res_error(res, result.data);
|
||||
} else {
|
||||
res.set_content(result.data.dump(), "application/json");
|
||||
res.set_content(result.data.dump(), MIMETYPE_JSON);
|
||||
}
|
||||
};
|
||||
|
||||
const auto handle_slots_action = [&res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
|
||||
std::string id_slot_str = req.path_params.at("id_slot");
|
||||
int id_slot;
|
||||
|
||||
|
@ -2994,7 +2933,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
};
|
||||
|
||||
const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
|
||||
const auto handle_props = [&ctx_server](const httplib::Request &, httplib::Response & res) {
|
||||
std::string template_key = "tokenizer.chat_template", curr_tmpl;
|
||||
int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0);
|
||||
if (tlen > 0) {
|
||||
|
@ -3003,7 +2942,6 @@ int main(int argc, char ** argv) {
|
|||
curr_tmpl = std::string(curr_tmpl_buf.data(), tlen);
|
||||
}
|
||||
}
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
json data = {
|
||||
{ "system_prompt", ctx_server.system_prompt.c_str() },
|
||||
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
|
||||
|
@ -3011,7 +2949,7 @@ int main(int argc, char ** argv) {
|
|||
{ "chat_template", curr_tmpl.c_str() }
|
||||
};
|
||||
|
||||
res.set_content(data.dump(), "application/json; charset=utf-8");
|
||||
res.set_content(data.dump(), MIMETYPE_JSON);
|
||||
};
|
||||
|
||||
const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
|
||||
|
@ -3020,8 +2958,6 @@ int main(int argc, char ** argv) {
|
|||
return;
|
||||
}
|
||||
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
|
||||
json data = json::parse(req.body);
|
||||
|
||||
const int id_task = ctx_server.queue_tasks.get_new_id();
|
||||
|
@ -3032,7 +2968,7 @@ int main(int argc, char ** argv) {
|
|||
if (!json_value(data, "stream", false)) {
|
||||
server_task_result result = ctx_server.queue_results.recv(id_task);
|
||||
if (!result.error && result.stop) {
|
||||
res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
|
||||
res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
|
||||
} else {
|
||||
res_error(res, result.data);
|
||||
}
|
||||
|
@ -3095,9 +3031,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
};
|
||||
|
||||
const auto handle_models = [¶ms, &model_meta](const httplib::Request & req, httplib::Response & res) {
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
|
||||
const auto handle_models = [¶ms, &ctx_server](const httplib::Request &, httplib::Response & res) {
|
||||
json models = {
|
||||
{"object", "list"},
|
||||
{"data", {
|
||||
|
@ -3106,12 +3040,12 @@ int main(int argc, char ** argv) {
|
|||
{"object", "model"},
|
||||
{"created", std::time(0)},
|
||||
{"owned_by", "llamacpp"},
|
||||
{"meta", model_meta}
|
||||
{"meta", ctx_server.model_meta()}
|
||||
},
|
||||
}}
|
||||
};
|
||||
|
||||
res.set_content(models.dump(), "application/json; charset=utf-8");
|
||||
res.set_content(models.dump(), MIMETYPE_JSON);
|
||||
};
|
||||
|
||||
const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error](const httplib::Request & req, httplib::Response & res) {
|
||||
|
@ -3119,8 +3053,6 @@ int main(int argc, char ** argv) {
|
|||
res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
||||
return;
|
||||
}
|
||||
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
|
||||
|
||||
const int id_task = ctx_server.queue_tasks.get_new_id();
|
||||
|
@ -3135,7 +3067,7 @@ int main(int argc, char ** argv) {
|
|||
if (!result.error && result.stop) {
|
||||
json result_oai = format_final_response_oaicompat(data, result.data, completion_id);
|
||||
|
||||
res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
|
||||
res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
|
||||
} else {
|
||||
res_error(res, result.data);
|
||||
}
|
||||
|
@ -3197,8 +3129,6 @@ int main(int argc, char ** argv) {
|
|||
return;
|
||||
}
|
||||
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
|
||||
json data = json::parse(req.body);
|
||||
|
||||
const int id_task = ctx_server.queue_tasks.get_new_id();
|
||||
|
@ -3209,7 +3139,7 @@ int main(int argc, char ** argv) {
|
|||
if (!json_value(data, "stream", false)) {
|
||||
server_task_result result = ctx_server.queue_results.recv(id_task);
|
||||
if (!result.error && result.stop) {
|
||||
res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
|
||||
res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
|
||||
} else {
|
||||
res_error(res, result.data);
|
||||
}
|
||||
|
@ -3257,7 +3187,6 @@ int main(int argc, char ** argv) {
|
|||
};
|
||||
|
||||
const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
const json body = json::parse(req.body);
|
||||
|
||||
std::vector<llama_token> tokens;
|
||||
|
@ -3266,11 +3195,10 @@ int main(int argc, char ** argv) {
|
|||
tokens = ctx_server.tokenize(body.at("content"), add_special);
|
||||
}
|
||||
const json data = format_tokenizer_response(tokens);
|
||||
return res.set_content(data.dump(), "application/json; charset=utf-8");
|
||||
return res.set_content(data.dump(), MIMETYPE_JSON);
|
||||
};
|
||||
|
||||
const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
const json body = json::parse(req.body);
|
||||
|
||||
std::string content;
|
||||
|
@ -3280,12 +3208,10 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
const json data = format_detokenized_response(content);
|
||||
return res.set_content(data.dump(), "application/json; charset=utf-8");
|
||||
return res.set_content(data.dump(), MIMETYPE_JSON);
|
||||
};
|
||||
|
||||
const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
|
||||
const json body = json::parse(req.body);
|
||||
bool is_openai = false;
|
||||
|
||||
|
@ -3331,11 +3257,10 @@ int main(int argc, char ** argv) {
|
|||
json root = is_openai
|
||||
? format_embeddings_response_oaicompat(body, responses)
|
||||
: responses[0];
|
||||
return res.set_content(root.dump(), "application/json; charset=utf-8");
|
||||
return res.set_content(root.dump(), MIMETYPE_JSON);
|
||||
};
|
||||
|
||||
const auto handle_lora_adapters_list = [&](const httplib::Request & req, httplib::Response & res) {
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
|
||||
json result = json::array();
|
||||
for (size_t i = 0; i < ctx_server.lora_adapters.size(); ++i) {
|
||||
auto & la = ctx_server.lora_adapters[i];
|
||||
|
@ -3345,13 +3270,11 @@ int main(int argc, char ** argv) {
|
|||
{"scale", la.scale},
|
||||
});
|
||||
}
|
||||
res.set_content(result.dump(), "application/json");
|
||||
res.set_content(result.dump(), MIMETYPE_JSON);
|
||||
res.status = 200; // HTTP OK
|
||||
};
|
||||
|
||||
const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
|
||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||
|
||||
const std::vector<json> body = json::parse(req.body);
|
||||
int max_idx = ctx_server.lora_adapters.size();
|
||||
|
||||
|
@ -3379,7 +3302,7 @@ int main(int argc, char ** argv) {
|
|||
server_task_result result = ctx_server.queue_results.recv(id_task);
|
||||
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
||||
|
||||
res.set_content(result.data.dump(), "application/json");
|
||||
res.set_content(result.data.dump(), MIMETYPE_JSON);
|
||||
res.status = 200; // HTTP OK
|
||||
};
|
||||
|
||||
|
@ -3455,35 +3378,75 @@ int main(int argc, char ** argv) {
|
|||
log_data["n_threads_http"] = std::to_string(params.n_threads_http);
|
||||
svr->new_task_queue = [¶ms] { return new httplib::ThreadPool(params.n_threads_http); };
|
||||
|
||||
LOG_INFO("HTTP server listening", log_data);
|
||||
// clean up function, to be called before exit
|
||||
auto clean_up = [&svr]() {
|
||||
svr->stop();
|
||||
llama_backend_free();
|
||||
};
|
||||
|
||||
// run the HTTP server in a thread - see comment below
|
||||
std::thread t([&]() {
|
||||
if (!svr->listen_after_bind()) {
|
||||
state.store(SERVER_STATE_ERROR);
|
||||
return 1;
|
||||
// bind HTTP listen port, run the HTTP server in a thread
|
||||
if (!svr->bind_to_port(params.hostname, params.port)) {
|
||||
LOG_ERROR("couldn't bind HTTP server socket", {
|
||||
{"hostname", params.hostname},
|
||||
{"port", params.port},
|
||||
});
|
||||
clean_up();
|
||||
LOG_ERROR("exiting due to HTTP server error", {});
|
||||
return 1;
|
||||
}
|
||||
std::thread t([&]() { svr->listen_after_bind(); });
|
||||
svr->wait_until_ready();
|
||||
|
||||
LOG_INFO("HTTP server is listening", log_data);
|
||||
|
||||
// load the model
|
||||
LOG_INFO("loading model", log_data);
|
||||
if (!ctx_server.load_model(params)) {
|
||||
clean_up();
|
||||
t.join();
|
||||
LOG_ERROR("exiting due to model loading error", {});
|
||||
return 1;
|
||||
} else {
|
||||
ctx_server.init();
|
||||
state.store(SERVER_STATE_READY);
|
||||
|
||||
LOG_INFO("model loaded", {});
|
||||
|
||||
// if a custom chat template is not supplied, we will use the one that comes with the model (if any)
|
||||
if (params.chat_template.empty()) {
|
||||
if (!ctx_server.validate_model_chat_template()) {
|
||||
LOG_WARNING("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
|
||||
params.chat_template = "chatml";
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
});
|
||||
// print sample chat example to make it clear which template is used
|
||||
{
|
||||
LOG_INFO("chat template", {
|
||||
{"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)},
|
||||
{"built_in", params.chat_template.empty()},
|
||||
});
|
||||
}
|
||||
|
||||
ctx_server.queue_tasks.on_new_task(std::bind(
|
||||
&server_context::process_single_task, &ctx_server, std::placeholders::_1));
|
||||
ctx_server.queue_tasks.on_finish_multitask(std::bind(
|
||||
&server_context::on_finish_multitask, &ctx_server, std::placeholders::_1));
|
||||
ctx_server.queue_tasks.on_update_slots(std::bind(
|
||||
&server_context::update_slots, &ctx_server));
|
||||
ctx_server.queue_results.on_multitask_update(std::bind(
|
||||
&server_queue::update_multitask,
|
||||
&ctx_server.queue_tasks,
|
||||
std::placeholders::_1,
|
||||
std::placeholders::_2,
|
||||
std::placeholders::_3
|
||||
));
|
||||
ctx_server.queue_tasks.on_new_task(std::bind(
|
||||
&server_context::process_single_task, &ctx_server, std::placeholders::_1));
|
||||
ctx_server.queue_tasks.on_finish_multitask(std::bind(
|
||||
&server_context::on_finish_multitask, &ctx_server, std::placeholders::_1));
|
||||
ctx_server.queue_tasks.on_update_slots(std::bind(
|
||||
&server_context::update_slots, &ctx_server));
|
||||
ctx_server.queue_results.on_multitask_update(std::bind(
|
||||
&server_queue::update_multitask,
|
||||
&ctx_server.queue_tasks,
|
||||
std::placeholders::_1,
|
||||
std::placeholders::_2,
|
||||
std::placeholders::_3
|
||||
));
|
||||
|
||||
shutdown_handler = [&](int) {
|
||||
ctx_server.queue_tasks.terminate();
|
||||
};
|
||||
shutdown_handler = [&](int) {
|
||||
ctx_server.queue_tasks.terminate();
|
||||
};
|
||||
ctx_server.queue_tasks.start_loop();
|
||||
}
|
||||
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
struct sigaction sigint_action;
|
||||
|
@ -3499,12 +3462,8 @@ int main(int argc, char ** argv) {
|
|||
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
|
||||
#endif
|
||||
|
||||
ctx_server.queue_tasks.start_loop();
|
||||
|
||||
svr->stop();
|
||||
clean_up();
|
||||
t.join();
|
||||
|
||||
llama_backend_free();
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -205,27 +205,20 @@ def step_start_server(context):
|
|||
async def step_wait_for_the_server_to_be_started(context, expecting_status: Literal['healthy', 'ready', 'idle', 'busy'] | str):
|
||||
match expecting_status:
|
||||
case 'healthy':
|
||||
await wait_for_health_status(context, context.base_url, 200, 'ok',
|
||||
timeout=30)
|
||||
await wait_for_slots_status(context, context.base_url, 200,
|
||||
timeout=30)
|
||||
|
||||
case 'ready' | 'idle':
|
||||
await wait_for_health_status(context, context.base_url, 200, 'ok',
|
||||
timeout=30,
|
||||
params={'fail_on_no_slot': 0, 'include_slots': 0},
|
||||
slots_idle=context.n_slots,
|
||||
slots_processing=0,
|
||||
expected_slots=[{'id': slot_id, 'state': 0}
|
||||
for slot_id in
|
||||
range(context.n_slots if context.n_slots else 1)])
|
||||
await wait_for_slots_status(context, context.base_url, 200,
|
||||
timeout=30,
|
||||
params={'fail_on_no_slot': 1},
|
||||
slots_idle=context.n_slots,
|
||||
slots_processing=0)
|
||||
case 'busy':
|
||||
await wait_for_health_status(context, context.base_url, 503,
|
||||
'no slot available',
|
||||
params={'fail_on_no_slot': 0, 'include_slots': 0},
|
||||
slots_idle=0,
|
||||
slots_processing=context.n_slots,
|
||||
expected_slots=[{'id': slot_id, 'state': 1}
|
||||
for slot_id in
|
||||
range(context.n_slots if context.n_slots else 1)])
|
||||
await wait_for_slots_status(context, context.base_url, 503,
|
||||
params={'fail_on_no_slot': 1},
|
||||
slots_idle=0,
|
||||
slots_processing=context.n_slots)
|
||||
case _:
|
||||
assert False, "unknown status"
|
||||
|
||||
|
@ -1187,17 +1180,15 @@ async def gather_tasks_results(context):
|
|||
return n_completions
|
||||
|
||||
|
||||
async def wait_for_health_status(context,
|
||||
base_url,
|
||||
expected_http_status_code,
|
||||
expected_health_status,
|
||||
timeout=3,
|
||||
params=None,
|
||||
slots_idle=None,
|
||||
slots_processing=None,
|
||||
expected_slots=None):
|
||||
async def wait_for_slots_status(context,
|
||||
base_url,
|
||||
expected_http_status_code,
|
||||
timeout=3,
|
||||
params=None,
|
||||
slots_idle=None,
|
||||
slots_processing=None):
|
||||
if context.debug:
|
||||
print(f"Starting checking for health for expected_health_status={expected_health_status}")
|
||||
print(f"Starting checking for health for expected_http_status_code={expected_http_status_code}")
|
||||
interval = 0.5
|
||||
counter = 0
|
||||
if 'GITHUB_ACTIONS' in os.environ:
|
||||
|
@ -1205,26 +1196,19 @@ async def wait_for_health_status(context,
|
|||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
while True:
|
||||
async with await session.get(f'{base_url}/health', params=params) as health_response:
|
||||
status_code = health_response.status
|
||||
health = await health_response.json()
|
||||
async with await session.get(f'{base_url}/slots', params=params) as slots_response:
|
||||
status_code = slots_response.status
|
||||
slots = await slots_response.json()
|
||||
if context.debug:
|
||||
print(f"HEALTH - response for expected health status='{expected_health_status}' on "
|
||||
f"'{base_url}/health'?{params} is {health}\n")
|
||||
if (status_code == expected_http_status_code
|
||||
and health['status'] == expected_health_status
|
||||
and (slots_idle is None or health['slots_idle'] == slots_idle)
|
||||
and (slots_processing is None or health['slots_processing'] == slots_processing)):
|
||||
if expected_slots is not None:
|
||||
assert_slots_status(health['slots'], expected_slots)
|
||||
return
|
||||
if (status_code == expected_http_status_code
|
||||
and health['status'] == expected_health_status
|
||||
and (slots_idle is None or health['slots_idle'] == slots_idle)
|
||||
and (slots_processing is None or health['slots_processing'] == slots_processing)):
|
||||
if expected_slots is not None:
|
||||
assert_slots_status(health['slots'], expected_slots)
|
||||
print(f"slots responses {slots}\n")
|
||||
if status_code == 503 and status_code == expected_http_status_code:
|
||||
return
|
||||
if status_code == 200 and status_code == expected_http_status_code:
|
||||
n_slots_idle = sum(1 if slot["state"] == 0 else 0 for slot in slots)
|
||||
n_slots_processing = sum(1 if slot["state"] != 0 else 0 for slot in slots)
|
||||
if ((slots_idle is None or slots_idle == n_slots_idle)
|
||||
and (slots_processing is None or slots_processing == n_slots_processing)):
|
||||
return
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
counter += interval
|
||||
|
@ -1238,7 +1222,7 @@ async def wait_for_health_status(context,
|
|||
if n_completions > 0:
|
||||
return
|
||||
|
||||
assert False, f'{expected_health_status} timeout exceeded {counter}s>={timeout}'
|
||||
assert False, f'slots check timeout exceeded {counter}s>={timeout}'
|
||||
|
||||
|
||||
def assert_embeddings(embeddings):
|
||||
|
|
6
flake.lock
generated
|
@ -20,11 +20,11 @@
|
|||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1723175592,
|
||||
"narHash": "sha256-M0xJ3FbDUc4fRZ84dPGx5VvgFsOzds77KiBMW/mMTnI=",
|
||||
"lastModified": 1723637854,
|
||||
"narHash": "sha256-med8+5DSWa2UnOqtdICndjDAEjxr5D7zaIiK4pn0Q7c=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "5e0ca22929f3342b19569b21b2f3462f053e497b",
|
||||
"rev": "c3aa7b8938b17aebd2deecf7be0636000d62a2b9",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
|
|
@ -1018,10 +1018,6 @@ static bool ggml_is_view_op(enum ggml_op op) {
|
|||
#define GGML_SCHED_MAX_BACKENDS 16
|
||||
#endif
|
||||
|
||||
#ifndef GGML_SCHED_MAX_SPLITS
|
||||
#define GGML_SCHED_MAX_SPLITS 2048
|
||||
#endif
|
||||
|
||||
#ifndef GGML_SCHED_MAX_SPLIT_INPUTS
|
||||
#define GGML_SCHED_MAX_SPLIT_INPUTS GGML_MAX_SRC
|
||||
#endif
|
||||
|
@ -1125,7 +1121,8 @@ static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, co
|
|||
}
|
||||
|
||||
#if 0
|
||||
static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS][128]; // debug only
|
||||
#define GGML_SCHED_MAX_SPLITS_DEBUG 4096
|
||||
static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_SCHED_MAX_SPLITS_DEBUG*GGML_SCHED_MAX_SPLIT_INPUTS][128]; // debug only
|
||||
#define SET_CAUSE(node, ...) sprintf(causes[hash_id(node)], __VA_ARGS__)
|
||||
#define GET_CAUSE(node) causes[hash_id(node)]
|
||||
#else
|
||||
|
@ -1549,7 +1546,6 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
|
|||
sched->splits = realloc(sched->splits, sched->splits_capacity * sizeof(struct ggml_backend_sched_split));
|
||||
GGML_ASSERT(sched->splits != NULL);
|
||||
}
|
||||
GGML_ASSERT(i_split < GGML_SCHED_MAX_SPLITS);
|
||||
split = &sched->splits[i_split];
|
||||
split->backend_id = node_backend_id;
|
||||
split->i_start = i;
|
||||
|
@ -1865,13 +1861,14 @@ ggml_backend_sched_t ggml_backend_sched_new(
|
|||
sched->hv_tensor_backend_ids = malloc(sched->hash_set.size * sizeof(sched->hv_tensor_backend_ids[0]));
|
||||
sched->hv_tensor_copies = malloc(sched->hash_set.size * sched->n_backends * sched->n_copies * sizeof(struct ggml_tensor *));
|
||||
|
||||
const size_t nodes_size = graph_size + GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2;
|
||||
const size_t ggml_sched_max_splits = graph_size; // at most there is one split for each node in the graph
|
||||
const size_t nodes_size = graph_size + ggml_sched_max_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2;
|
||||
sched->node_backend_ids = calloc(nodes_size, sizeof(sched->node_backend_ids[0]));
|
||||
sched->leaf_backend_ids = calloc(nodes_size, sizeof(sched->leaf_backend_ids[0]));
|
||||
sched->prev_node_backend_ids = calloc(nodes_size, sizeof(sched->prev_node_backend_ids[0]));
|
||||
sched->prev_leaf_backend_ids = calloc(nodes_size, sizeof(sched->prev_leaf_backend_ids[0]));
|
||||
|
||||
sched->context_buffer_size = GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + ggml_graph_overhead_custom(graph_size, false);
|
||||
sched->context_buffer_size = ggml_sched_max_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + ggml_graph_overhead_custom(graph_size, false);
|
||||
sched->context_buffer = malloc(sched->context_buffer_size);
|
||||
|
||||
const int initial_splits_capacity = 16;
|
||||
|
|
|
@ -82,17 +82,18 @@ static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of
|
|||
|
||||
// RPC commands
|
||||
enum rpc_cmd {
|
||||
ALLOC_BUFFER = 0,
|
||||
GET_ALIGNMENT,
|
||||
GET_MAX_SIZE,
|
||||
BUFFER_GET_BASE,
|
||||
FREE_BUFFER,
|
||||
BUFFER_CLEAR,
|
||||
SET_TENSOR,
|
||||
GET_TENSOR,
|
||||
COPY_TENSOR,
|
||||
GRAPH_COMPUTE,
|
||||
GET_DEVICE_MEMORY,
|
||||
RPC_CMD_ALLOC_BUFFER = 0,
|
||||
RPC_CMD_GET_ALIGNMENT,
|
||||
RPC_CMD_GET_MAX_SIZE,
|
||||
RPC_CMD_BUFFER_GET_BASE,
|
||||
RPC_CMD_FREE_BUFFER,
|
||||
RPC_CMD_BUFFER_CLEAR,
|
||||
RPC_CMD_SET_TENSOR,
|
||||
RPC_CMD_GET_TENSOR,
|
||||
RPC_CMD_COPY_TENSOR,
|
||||
RPC_CMD_GRAPH_COMPUTE,
|
||||
RPC_CMD_GET_DEVICE_MEMORY,
|
||||
RPC_CMD_COUNT,
|
||||
};
|
||||
|
||||
// RPC data structures
|
||||
|
@ -330,7 +331,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t
|
|||
uint64_t remote_ptr = ctx->remote_ptr;
|
||||
memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
|
||||
std::vector<uint8_t> output;
|
||||
bool status = send_rpc_cmd(ctx->sock, FREE_BUFFER, input, output);
|
||||
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, input, output);
|
||||
GGML_ASSERT(status);
|
||||
GGML_ASSERT(output.empty());
|
||||
delete ctx;
|
||||
|
@ -346,7 +347,7 @@ GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t b
|
|||
uint64_t remote_ptr = ctx->remote_ptr;
|
||||
memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
|
||||
std::vector<uint8_t> output;
|
||||
bool status = send_rpc_cmd(ctx->sock, BUFFER_GET_BASE, input, output);
|
||||
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, input, output);
|
||||
GGML_ASSERT(status);
|
||||
GGML_ASSERT(output.size() == sizeof(uint64_t));
|
||||
// output serialization format: | base_ptr (8 bytes) |
|
||||
|
@ -405,7 +406,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t b
|
|||
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
|
||||
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
|
||||
std::vector<uint8_t> output;
|
||||
bool status = send_rpc_cmd(ctx->sock, SET_TENSOR, input, output);
|
||||
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input, output);
|
||||
GGML_ASSERT(status);
|
||||
}
|
||||
|
||||
|
@ -419,7 +420,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t b
|
|||
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
|
||||
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size));
|
||||
std::vector<uint8_t> output;
|
||||
bool status = send_rpc_cmd(ctx->sock, GET_TENSOR, input, output);
|
||||
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, input, output);
|
||||
GGML_ASSERT(status);
|
||||
GGML_ASSERT(output.size() == size);
|
||||
// output serialization format: | data (size bytes) |
|
||||
|
@ -444,7 +445,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t b
|
|||
memcpy(input.data(), &rpc_src, sizeof(rpc_src));
|
||||
memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst));
|
||||
std::vector<uint8_t> output;
|
||||
bool status = send_rpc_cmd(ctx->sock, COPY_TENSOR, input, output);
|
||||
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, input, output);
|
||||
GGML_ASSERT(status);
|
||||
// output serialization format: | result (1 byte) |
|
||||
GGML_ASSERT(output.size() == 1);
|
||||
|
@ -459,7 +460,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer
|
|||
memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr));
|
||||
memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value));
|
||||
std::vector<uint8_t> output;
|
||||
bool status = send_rpc_cmd(ctx->sock, BUFFER_CLEAR, input, output);
|
||||
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, input, output);
|
||||
GGML_ASSERT(status);
|
||||
}
|
||||
|
||||
|
@ -488,7 +489,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
|
|||
memcpy(input.data(), &size, sizeof(size));
|
||||
std::vector<uint8_t> output;
|
||||
auto sock = get_socket(buft_ctx->endpoint);
|
||||
bool status = send_rpc_cmd(sock, ALLOC_BUFFER, input, output);
|
||||
bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, input, output);
|
||||
GGML_ASSERT(status);
|
||||
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
|
||||
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
|
||||
|
@ -511,7 +512,7 @@ static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
|
|||
// input serialization format: | 0 bytes |
|
||||
std::vector<uint8_t> input;
|
||||
std::vector<uint8_t> output;
|
||||
bool status = send_rpc_cmd(sock, GET_ALIGNMENT, input, output);
|
||||
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, input, output);
|
||||
GGML_ASSERT(status);
|
||||
GGML_ASSERT(output.size() == sizeof(uint64_t));
|
||||
// output serialization format: | alignment (8 bytes) |
|
||||
|
@ -529,7 +530,7 @@ static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
|
|||
// input serialization format: | 0 bytes |
|
||||
std::vector<uint8_t> input;
|
||||
std::vector<uint8_t> output;
|
||||
bool status = send_rpc_cmd(sock, GET_MAX_SIZE, input, output);
|
||||
bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, input, output);
|
||||
GGML_ASSERT(status);
|
||||
GGML_ASSERT(output.size() == sizeof(uint64_t));
|
||||
// output serialization format: | max_size (8 bytes) |
|
||||
|
@ -622,7 +623,7 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t
|
|||
serialize_graph(cgraph, input);
|
||||
std::vector<uint8_t> output;
|
||||
auto sock = get_socket(rpc_ctx->endpoint);
|
||||
bool status = send_rpc_cmd(sock, GRAPH_COMPUTE, input, output);
|
||||
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input, output);
|
||||
GGML_ASSERT(status);
|
||||
GGML_ASSERT(output.size() == 1);
|
||||
return (enum ggml_status)output[0];
|
||||
|
@ -636,7 +637,7 @@ GGML_CALL static bool ggml_backend_rpc_supports_op(ggml_backend_t backend, const
|
|||
}
|
||||
|
||||
GGML_CALL static bool ggml_backend_rpc_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
|
||||
if (buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
|
||||
if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
|
||||
return false;
|
||||
}
|
||||
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
|
||||
|
@ -678,6 +679,7 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const
|
|||
}
|
||||
auto sock = get_socket(endpoint);
|
||||
if (sock == nullptr) {
|
||||
fprintf(stderr, "Failed to connect to %s\n", endpoint);
|
||||
return nullptr;
|
||||
}
|
||||
size_t alignment = get_alignment(sock);
|
||||
|
@ -719,7 +721,7 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f
|
|||
// input serialization format: | 0 bytes |
|
||||
std::vector<uint8_t> input;
|
||||
std::vector<uint8_t> output;
|
||||
bool status = send_rpc_cmd(sock, GET_DEVICE_MEMORY, input, output);
|
||||
bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, input, output);
|
||||
GGML_ASSERT(status);
|
||||
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
|
||||
// output serialization format: | free (8 bytes) | total (8 bytes) |
|
||||
|
@ -1098,59 +1100,69 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
|
|||
if (!recv_data(sockfd, &cmd, 1)) {
|
||||
break;
|
||||
}
|
||||
if (cmd >= RPC_CMD_COUNT) {
|
||||
// fail fast if the command is invalid
|
||||
fprintf(stderr, "Unknown command: %d\n", cmd);
|
||||
break;
|
||||
}
|
||||
std::vector<uint8_t> input;
|
||||
std::vector<uint8_t> output;
|
||||
uint64_t input_size;
|
||||
if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
|
||||
break;
|
||||
}
|
||||
input.resize(input_size);
|
||||
try {
|
||||
input.resize(input_size);
|
||||
} catch (const std::bad_alloc & e) {
|
||||
fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", input_size);
|
||||
break;
|
||||
}
|
||||
if (!recv_data(sockfd, input.data(), input_size)) {
|
||||
break;
|
||||
}
|
||||
bool ok = true;
|
||||
switch (cmd) {
|
||||
case ALLOC_BUFFER: {
|
||||
case RPC_CMD_ALLOC_BUFFER: {
|
||||
ok = server.alloc_buffer(input, output);
|
||||
break;
|
||||
}
|
||||
case GET_ALIGNMENT: {
|
||||
case RPC_CMD_GET_ALIGNMENT: {
|
||||
server.get_alignment(output);
|
||||
break;
|
||||
}
|
||||
case GET_MAX_SIZE: {
|
||||
case RPC_CMD_GET_MAX_SIZE: {
|
||||
server.get_max_size(output);
|
||||
break;
|
||||
}
|
||||
case BUFFER_GET_BASE: {
|
||||
case RPC_CMD_BUFFER_GET_BASE: {
|
||||
ok = server.buffer_get_base(input, output);
|
||||
break;
|
||||
}
|
||||
case FREE_BUFFER: {
|
||||
case RPC_CMD_FREE_BUFFER: {
|
||||
ok = server.free_buffer(input);
|
||||
break;
|
||||
}
|
||||
case BUFFER_CLEAR: {
|
||||
case RPC_CMD_BUFFER_CLEAR: {
|
||||
ok = server.buffer_clear(input);
|
||||
break;
|
||||
}
|
||||
case SET_TENSOR: {
|
||||
case RPC_CMD_SET_TENSOR: {
|
||||
ok = server.set_tensor(input);
|
||||
break;
|
||||
}
|
||||
case GET_TENSOR: {
|
||||
case RPC_CMD_GET_TENSOR: {
|
||||
ok = server.get_tensor(input, output);
|
||||
break;
|
||||
}
|
||||
case COPY_TENSOR: {
|
||||
case RPC_CMD_COPY_TENSOR: {
|
||||
ok = server.copy_tensor(input, output);
|
||||
break;
|
||||
}
|
||||
case GRAPH_COMPUTE: {
|
||||
case RPC_CMD_GRAPH_COMPUTE: {
|
||||
ok = server.graph_compute(input, output);
|
||||
break;
|
||||
}
|
||||
case GET_DEVICE_MEMORY: {
|
||||
case RPC_CMD_GET_DEVICE_MEMORY: {
|
||||
// output serialization format: | free (8 bytes) | total (8 bytes) |
|
||||
output.resize(2*sizeof(uint64_t), 0);
|
||||
memcpy(output.data(), &free_mem, sizeof(free_mem));
|
||||
|
@ -1203,8 +1215,10 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free
|
|||
return;
|
||||
}
|
||||
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
|
||||
fflush(stdout);
|
||||
rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
|
||||
printf("Client connection closed\n");
|
||||
fflush(stdout);
|
||||
}
|
||||
#ifdef _WIN32
|
||||
WSACleanup();
|
||||
|
|
|
@@ -219,6 +219,8 @@ class MODEL_ARCH(IntEnum):
T5 = auto()
T5ENCODER = auto()
JAIS = auto()
NEMOTRON = auto()
EXAONE = auto()


class MODEL_TENSOR(IntEnum):

@@ -347,6 +349,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.T5: "t5",
MODEL_ARCH.T5ENCODER: "t5encoder",
MODEL_ARCH.JAIS: "jais",
MODEL_ARCH.NEMOTRON: "nemotron",
MODEL_ARCH.EXAONE: "exaone",
}

TENSOR_NAMES: dict[MODEL_TENSOR, str] = {

@@ -1065,6 +1069,37 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.NEMOTRON: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.EXAONE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
# TODO
}

@@ -1105,6 +1140,10 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_ARCH.CHATGLM: [
MODEL_TENSOR.ROPE_FREQS,
],
MODEL_ARCH.NEMOTRON: [
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
],
}

#

@@ -10,10 +10,10 @@ class TensorNameMap:
# Token embeddings
MODEL_TENSOR.TOKEN_EMBD: (
"gpt_neox.embed_in", # gptneox
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone
"transformer.word_embeddings", # falcon
"word_embeddings", # bloom
"model.embed_tokens", # llama-hf
"model.embed_tokens", # llama-hf nemotron
"tok_embeddings", # llama-pth
"embeddings.word_embeddings", # bert nomic-bert
"language_model.embedding.word_embeddings", # persimmon

@@ -52,7 +52,7 @@ class TensorNameMap:
# Output
MODEL_TENSOR.OUTPUT: (
"embed_out", # gptneox
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone
"output", # llama-pth bloom internlm2
"word_embeddings_for_head", # persimmon
"lm_head.linear", # phi2

@@ -62,7 +62,7 @@ class TensorNameMap:
# Output norm
MODEL_TENSOR.OUTPUT_NORM: (
"gpt_neox.final_layer_norm", # gptneox
"transformer.ln_f", # gpt2 gpt-j falcon jais
"transformer.ln_f", # gpt2 gpt-j falcon jais exaone
"model.norm", # llama-hf baichuan internlm2
"norm", # llama-pth
"transformer.norm_f", # mpt dbrx

@@ -75,6 +75,7 @@ class TensorNameMap:
"transformer.rms_norm", # Grok
"encoder.final_layernorm", # chatglm
"transformer.norm", # openelm
"model.norm", # nemotron
),

# Rope frequencies

@@ -88,12 +89,12 @@ class TensorNameMap:
# Attention norm
MODEL_TENSOR.ATTN_NORM: (
"gpt_neox.layers.{bid}.input_layernorm", # gptneox
"transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen jais
"transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen jais exaone
"transformer.blocks.{bid}.norm_1", # mpt
"transformer.h.{bid}.input_layernorm", # falcon7b
"h.{bid}.input_layernorm", # bloom
"transformer.h.{bid}.ln_mlp", # falcon40b
"model.layers.{bid}.input_layernorm", # llama-hf
"model.layers.{bid}.input_layernorm", # llama-hf nemotron
"layers.{bid}.attention_norm", # llama-pth
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
"model.layers.{bid}.ln1", # yi

@@ -135,18 +136,19 @@ class TensorNameMap:

# Attention query
MODEL_TENSOR.ATTN_Q: (
"model.layers.{bid}.self_attn.q_proj", # llama-hf
"model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron
"layers.{bid}.attention.wq", # llama-pth
"encoder.layer.{bid}.attention.self.query", # bert
"transformer.h.{bid}.attn.q_proj", # gpt-j
"model.layers.layers.{bid}.self_attn.q_proj", # plamo
"model.layers.{bid}.attention.wq", # internlm2
"transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok
"transformer.h.{bid}.attn.attention.q_proj", # exaone
),

# Attention key
MODEL_TENSOR.ATTN_K: (
"model.layers.{bid}.self_attn.k_proj", # llama-hf
"model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron
"layers.{bid}.attention.wk", # llama-pth
"encoder.layer.{bid}.attention.self.key", # bert
"transformer.h.{bid}.attn.k_proj", # gpt-j

@@ -154,18 +156,20 @@ class TensorNameMap:
"model.layers.layers.{bid}.self_attn.k_proj", # plamo
"model.layers.{bid}.attention.wk", # internlm2
"transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok
"transformer.h.{bid}.attn.attention.k_proj", # exaone
),

# Attention value
MODEL_TENSOR.ATTN_V: (
"model.layers.{bid}.self_attn.v_proj", # llama-hf
"model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron
"layers.{bid}.attention.wv", # llama-pth
"encoder.layer.{bid}.attention.self.value", # bert
"transformer.h.{bid}.attn.v_proj", # gpt-j
"transformer.h.{bid}.attn.v", # refact
"model.layers.layers.{bid}.self_attn.v_proj", # plamo
"model.layers.{bid}.attention.wv", # internlm2
"transformer.decoder_layer.{bid}.multi_head_attention.value" # Grok
"transformer.decoder_layer.{bid}.multi_head_attention.value",# Grok
"transformer.h.{bid}.attn.attention.v_proj", # exaone
),

# Attention output

@@ -175,7 +179,7 @@ class TensorNameMap:
"transformer.blocks.{bid}.attn.out_proj", # mpt
"transformer.h.{bid}.self_attention.dense", # falcon
"h.{bid}.self_attention.dense", # bloom
"model.layers.{bid}.self_attn.o_proj", # llama-hf
"model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron
"layers.{bid}.attention.wo", # llama-pth
"encoder.layer.{bid}.attention.output.dense", # bert
"transformer.h.{bid}.attn.out_proj", # gpt-j

@@ -190,6 +194,7 @@ class TensorNameMap:
"transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx
"encoder.layers.{bid}.self_attention.dense", # chatglm
"transformer.layers.{bid}.attn.out_proj", # openelm
"transformer.h.{bid}.attn.attention.out_proj", # exaone
),

# Attention output norm

@@ -215,10 +220,10 @@ class TensorNameMap:
# Feed-forward norm
MODEL_TENSOR.FFN_NORM: (
"gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
"transformer.h.{bid}.ln_2", # gpt2 refact qwen jais
"transformer.h.{bid}.ln_2", # gpt2 refact qwen jais exaone
"h.{bid}.post_attention_layernorm", # bloom
"transformer.blocks.{bid}.norm_2", # mpt
"model.layers.{bid}.post_attention_layernorm", # llama-hf
"model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron
"layers.{bid}.ffn_norm", # llama-pth
"language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
"model.layers.{bid}.ln2", # yi

@@ -258,7 +263,7 @@ class TensorNameMap:
"transformer.blocks.{bid}.ffn.up_proj", # mpt
"transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
"h.{bid}.mlp.dense_h_to_4h", # bloom
"model.layers.{bid}.mlp.up_proj", # llama-hf refact
"model.layers.{bid}.mlp.up_proj", # llama-hf refact nemotron
"layers.{bid}.feed_forward.w3", # llama-pth
"encoder.layer.{bid}.intermediate.dense", # bert
"transformer.h.{bid}.mlp.fc_in", # gpt-j

@@ -277,6 +282,7 @@ class TensorNameMap:
"encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2
"model.layers.{bid}.residual_mlp.w3", # arctic
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
"transformer.h.{bid}.mlp.c_fc_1", # exaone
),

MODEL_TENSOR.FFN_UP_EXP: (

@@ -308,6 +314,7 @@ class TensorNameMap:
"encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2
"transformer.h.{bid}.mlp.linear_1", # refact
"model.layers.{bid}.residual_mlp.w1", # arctic
"transformer.h.{bid}.mlp.c_fc_0", # exaone
),

MODEL_TENSOR.FFN_GATE_EXP: (

@@ -329,7 +336,7 @@ class TensorNameMap:
"transformer.blocks.{bid}.ffn.down_proj", # mpt
"transformer.h.{bid}.mlp.dense_4h_to_h", # falcon
"h.{bid}.mlp.dense_4h_to_h", # bloom
"model.layers.{bid}.mlp.down_proj", # llama-hf
"model.layers.{bid}.mlp.down_proj", # llama-hf nemotron
"layers.{bid}.feed_forward.w2", # llama-pth
"encoder.layer.{bid}.output.dense", # bert
"transformer.h.{bid}.mlp.fc_out", # gpt-j

@@ -347,6 +354,7 @@ class TensorNameMap:
"model.layers.{bid}.residual_mlp.w2", # arctic
"encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
"model.layers.h.{bid}.mlp.c_proj", # exaone
),

MODEL_TENSOR.FFN_DOWN_EXP: (

@@ -1,6 +1,6 @@
[tool.poetry]
name = "gguf"
version = "0.9.1"
version = "0.10.0"
description = "Read and write ML models in GGUF for GGML"
authors = ["GGML <ggml@ggml.ai>"]
packages = [

@@ -95,6 +95,7 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22,
LLAMA_VOCAB_PRE_TYPE_BLOOM = 23,
LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24,
LLAMA_VOCAB_PRE_TYPE_EXAONE = 25,
};

enum llama_rope_type {

@@ -388,6 +388,7 @@ struct llm_tokenizer_bpe {
case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
case LLAMA_VOCAB_PRE_TYPE_SMOLLM:
case LLAMA_VOCAB_PRE_TYPE_CODESHELL:
case LLAMA_VOCAB_PRE_TYPE_EXAONE:
regex_exprs = {
"\\p{N}",
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",

417
src/llama.cpp
@@ -210,6 +210,8 @@ enum llm_arch {
LLM_ARCH_T5,
LLM_ARCH_T5ENCODER,
LLM_ARCH_JAIS,
LLM_ARCH_NEMOTRON,
LLM_ARCH_EXAONE,
LLM_ARCH_UNKNOWN,
};

@@ -255,6 +257,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_T5, "t5" },
{ LLM_ARCH_T5ENCODER, "t5encoder" },
{ LLM_ARCH_JAIS, "jais" },
{ LLM_ARCH_NEMOTRON, "nemotron" },
{ LLM_ARCH_EXAONE, "exaone" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};

@@ -1296,6 +1300,43 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
},
},
{
LLM_ARCH_NEMOTRON,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_EXAONE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_UNKNOWN,
{

@@ -5235,6 +5276,23 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_NEMOTRON:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
switch (hparams.n_layer) {
case 32: model.type = e_model::MODEL_4B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_EXAONE:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

switch (hparams.n_layer) {
case 32: model.type = e_model::MODEL_8B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
default: (void)0;
}

@@ -5473,6 +5531,9 @@ static void llm_load_vocab(
} else if (
tokenizer_pre == "gpt3-finnish") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH;
} else if (
tokenizer_pre == "exaone") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_EXAONE;
} else {
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
}

@@ -6111,9 +6172,9 @@ static bool llm_load_tensors(
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});

// optional MLP bias
layer.ffn_gate_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_down_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_up_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_gate_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
} else {
layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});

@@ -6437,7 +6498,7 @@ static bool llm_load_tensors(
layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa});

layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); //output_dens
layer.bo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); //output_dens
layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); //output_dens

layer.attn_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}); //output_norm
layer.attn_out_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd});

@@ -7568,6 +7629,78 @@ static bool llm_load_tensors(
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
}
} break;
case LLM_ARCH_NEMOTRON:
{
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});

// output
{
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd});
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
}

for (int i = 0; i < n_layer; ++i) {
ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = ctx_for_layer_split(i);

auto & layer = model.layers[i];

layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd});

layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});

// optional bias tensors
layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);

layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd});

layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});

// optional MLP bias
layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
}
} break;
case LLM_ARCH_EXAONE:
{
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});

// output
{
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
}

for (int i = 0; i < n_layer; ++i) {
ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = ctx_for_layer_split(i);

auto & layer = model.layers[i];

layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});

layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head});
layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa});
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa});
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd});

layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
layer.rope_freqs = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_embd/n_head/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
}
} break;
default:
throw std::runtime_error("unknown architecture");
}

@@ -8254,7 +8387,7 @@ static struct ggml_tensor * llm_build_kqv(
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il);

if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 || model.arch == LLM_ARCH_NEMOTRON) {
// for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
// ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);

@@ -13755,6 +13888,254 @@ struct llm_build_context {

return gf;
}

struct ggml_cgraph * build_nemotron() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
//GGML_ASSERT(n_embd_head == hparams.n_rot);

struct ggml_tensor * cur;
struct ggml_tensor * inpL;

inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);

// inp_pos - contains the positions
struct ggml_tensor * inp_pos = build_inp_pos();

// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();

for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;

// norm
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm,
model.layers[il].attn_norm_b,
LLM_NORM, cb, il);
cb(cur, "attn_norm", il);

// self-attention
{
// compute Q and K and RoPE them
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
}

struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
}

struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
}

Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);

Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);

cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}

if (il == n_layer - 1) {
// skip computing output for unused tokens
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}

struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);

// feed-forward network
cur = llm_build_norm(ctx0, ffn_inp, hparams,
model.layers[il].ffn_norm,
model.layers[il].ffn_norm_b,
LLM_NORM, cb, il);
cb(cur, "ffn_norm", il);

cur = llm_build_ffn(ctx0, lctx, cur,
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
NULL, NULL, NULL,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
NULL,
LLM_FFN_RELU_SQR, LLM_FFN_SEQ, cb, il);

cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "ffn_out", il);

cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "l_out", il);

// input for next layer
inpL = cur;
}

cur = inpL;

cur = llm_build_norm(ctx0, cur, hparams,
model.output_norm, model.output_norm_b,
LLM_NORM, cb, -1);
cb(cur, "result_norm", -1);

// lm_head
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
cb(cur, "result_output", -1);

ggml_build_forward_expand(gf, cur);

return gf;
}

struct ggml_cgraph * build_exaone() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

// mutable variable, needed during the last layer of the computation to skip unused tokens
int32_t n_tokens = this->n_tokens;

const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);

struct ggml_tensor * cur;
struct ggml_tensor * inpL;

inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);

// inp_pos - contains the positions
struct ggml_tensor * inp_pos = build_inp_pos();

// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();

for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;

// norm
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);

// self-attention
{
// rope freq factors for llama3; may return nullptr for llama2 and other models
struct ggml_tensor * rope_factors = build_rope_factors(il);

// compute Q and K and RoPE them
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
}

struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
}

struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
}

Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);

Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);

cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}

if (il == n_layer - 1) {
// skip computing output for unused tokens
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
n_tokens = n_outputs;
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}

struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);

// feed-forward network
cur = llm_build_norm(ctx0, ffn_inp, hparams,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);

cur = llm_build_ffn(ctx0, lctx, cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il);

cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "ffn_out", il);

cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "l_out", il);

// input for next layer
inpL = cur;
}

cur = inpL;

cur = llm_build_norm(ctx0, cur, hparams,
model.output_norm, NULL,
LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1);

// lm_head
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
cb(cur, "result_output", -1);

ggml_build_forward_expand(gf, cur);

return gf;
}
};

static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {

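As a side note on the two graph builders added above: build_nemotron wires its FFN as LLM_FFN_RELU_SQR + LLM_FFN_SEQ, while build_exaone uses LLM_FFN_SILU + LLM_FFN_PAR. A scalar sketch of how I read those enum combinations (illustration only, not code from the tree):

```cpp
#include <cmath>

// nemotron: LLM_FFN_RELU_SQR + LLM_FFN_SEQ -> y = W_down * relu(W_up * x)^2            (no gate)
// exaone:   LLM_FFN_SILU     + LLM_FFN_PAR -> y = W_down * (silu(W_gate * x) * (W_up * x))
static inline float relu_sqr(float v) { const float r = v > 0.0f ? v : 0.0f; return r * r; }
static inline float silu(float v)     { return v / (1.0f + std::exp(-v)); }
```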
@@ -14010,6 +14391,14 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_jais();
} break;
case LLM_ARCH_NEMOTRON:
{
result = llm.build_nemotron();
} break;
case LLM_ARCH_EXAONE:
{
result = llm.build_exaone();
} break;
default:
GGML_ABORT("fatal error");
}

@@ -17080,6 +17469,8 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_OPENELM:
case LLM_ARCH_GPTNEOX:
case LLM_ARCH_CODESHELL:
case LLM_ARCH_NEMOTRON:
case LLM_ARCH_EXAONE:
return LLAMA_ROPE_TYPE_NEOX;

// all model arches should be listed explicitly here

@@ -19010,6 +19401,22 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "Assistant:";
}
} else if (tmpl == "exaone3" || (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]"))) {
// ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
// EXAONE-3.0-7.8B-Instruct
for (auto message : chat) {
std::string role(message->role);
if (role == "system") {
ss << "[|system|]" << trim(message->content) << "[|endofturn|]\n";
} else if (role == "user") {
ss << "[|user|]" << trim(message->content) << "\n";
} else if (role == "assistant") {
ss << "[|assistant|]" << trim(message->content) << "[|endofturn|]\n";
}
}
if (add_ass) {
ss << "[|assistant|]";
}
} else {
// template not supported
return -1;

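A standalone sketch of what the new EXAONE template branch above renders for a short conversation (whitespace trimming via trim() is omitted here; the authoritative behavior is the llama.cpp code itself):

```cpp
#include <cstdio>
#include <string>
#include <vector>

struct chat_msg { std::string role, content; };

// Mirrors the EXAONE-3.0 branch of llama_chat_apply_template_internal shown above.
static std::string render_exaone(const std::vector<chat_msg> & chat, bool add_ass) {
    std::string ss;
    for (const auto & m : chat) {
        if (m.role == "system") {
            ss += "[|system|]" + m.content + "[|endofturn|]\n";
        } else if (m.role == "user") {
            ss += "[|user|]" + m.content + "\n";
        } else if (m.role == "assistant") {
            ss += "[|assistant|]" + m.content + "[|endofturn|]\n";
        }
    }
    if (add_ass) {
        ss += "[|assistant|]"; // generation prompt for the next assistant turn
    }
    return ss;
}

int main() {
    const std::vector<chat_msg> chat = {
        {"system", "You are a helpful assistant."},
        {"user",   "Hello!"},
    };
    // Expected output:
    // [|system|]You are a helpful assistant.[|endofturn|]
    // [|user|]Hello!
    // [|assistant|]
    printf("%s", render_exaone(chat, /*add_ass=*/true).c_str());
    return 0;
}
```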
@@ -503,7 +503,7 @@ static void test_special_chars() {
"aaaaabcccc",
"aaaabccc",
"aaaabccccc",
"🔵🟠✅❌abc❌✅🟠🔵"
"🔵🟠✅❌abc❌✅🟠🔵",
"🔵🟠abc🟠🔵"
}
);

139
tests/test-lora-conversion-inference.sh
Executable file
@@ -0,0 +1,139 @@
#!/bin/bash
set -e

# Array of models to iterate over
declare -a params=(
"Gemma2ForCausalLM 64"
"LlamaForCausalLM 64"
"Phi3ForCausalLM 64"
)

MODELS_REPO=lora-tests
MODELS_REPO_URL=https://huggingface.co/ggml-org/$MODELS_REPO

# Clone the Hugging Face repository if the directory does not exist
if [ ! -d "$MODELS_REPO" ]; then
echo "Cloning the Hugging Face repository..."
git clone $MODELS_REPO_URL
else
echo "Repository already exists. Skipping clone."
fi

# Array to store results to print
results=()

trim_leading_whitespace() {
local input_string="$1"
echo "${input_string#"${input_string%%[![:space:]]*}"}"
}

extract_starting_substring() {
local reference_string="$1"
local target_string="$2"

local target_length=${#target_string}
echo "${reference_string:0:$target_length}"
}

get_first_word() {
local input_string="$1"
read -r first_word _ <<< "$input_string"
echo "$first_word"
}

# Load the expected strings
EXPECTED_BASE_FULL=$(cat $MODELS_REPO/data/pale_blue_dot.txt)
EXPECTED_LORA_FULL=$(cat $MODELS_REPO/data/bohemian_rhapsody.txt)
EXPECTED_BASE_FIRST_WORD=$(get_first_word "$EXPECTED_BASE_FULL")
EXPECTED_LORA_FIRST_WORD=$(get_first_word "$EXPECTED_LORA_FULL")

run_conversion_and_inference_lora() {
local model_name=$1
local hidden_size=$2

echo -e "\n\n-------- RUNNING TEST FOR MODEL $model_name --------\n\n"

# Convert safetensors to gguf
echo "Running convert_hf_to_gguf.py for $model_name with hidden_size $hidden_size..."
python convert_hf_to_gguf.py $MODELS_REPO/$model_name/hidden_size=$hidden_size/base \
--outfile $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32.gguf \
--outtype f32

echo -e "\n\n---------------------------\n\n"
echo "Running convert_lora_to_gguf.py for $model_name with hidden_size $hidden_size..."
python3 convert_lora_to_gguf.py $MODELS_REPO/$model_name/hidden_size=$hidden_size/lora \
--base $MODELS_REPO/$model_name/hidden_size=$hidden_size/base \
--outtype f32

echo -e "\n\n---------------------------\n\n"
echo "Running llama-export-lora with lora for $model_name with hidden_size $hidden_size..."
./llama-export-lora \
-m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32.gguf \
-o $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32-lora-merged.gguf \
--lora $MODELS_REPO/$model_name/hidden_size=$hidden_size/lora/Lora-F32-LoRA.gguf

# Run inference
echo -e "\n\n---------------------------\n\n"
echo "Running llama-cli without lora for $model_name with hidden_size $hidden_size..."
OUTPUT_BASE=$(./llama-cli -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32.gguf \
-p "$EXPECTED_BASE_FIRST_WORD" -n 50 --seed 42 --temp 0)

echo -e "\n\n---------------------------\n\n"
echo "Running llama-cli with hot lora for $model_name with hidden_size $hidden_size..."
OUTPUT_LORA_HOT=$(./llama-cli -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32.gguf \
--lora $MODELS_REPO/$model_name/hidden_size=$hidden_size/lora/Lora-F32-LoRA.gguf \
-p "$EXPECTED_LORA_FIRST_WORD" -n 50 --seed 42 --temp 0)

echo -e "\n\n---------------------------\n\n"
echo "Running llama-cli with merged lora for $model_name with hidden_size $hidden_size..."
OUTPUT_LORA_MERGED=$(./llama-cli -m $MODELS_REPO/$model_name/hidden_size=$hidden_size/base/Base-F32-lora-merged.gguf \
-p "$EXPECTED_LORA_FIRST_WORD" -n 50 --seed 42 --temp 0)

# Remove any initial white space
OUTPUT_BASE=$(trim_leading_whitespace "$OUTPUT_BASE")
OUTPUT_LORA_HOT=$(trim_leading_whitespace "$OUTPUT_LORA_HOT")
OUTPUT_LORA_MERGED=$(trim_leading_whitespace "$OUTPUT_LORA_MERGED")
# Extract the corresponding substring from full string
EXPECTED_BASE=$(extract_starting_substring "$EXPECTED_BASE_FULL" "$OUTPUT_BASE")
EXPECTED_LORA=$(extract_starting_substring "$EXPECTED_LORA_FULL" "$OUTPUT_LORA_HOT")

# Assert output equals the expected output
if [[ "$OUTPUT_BASE" != "$EXPECTED_BASE" ]]; then
echo "Error: $model_name OUTPUT_BASE does not start with the expected string."
echo -e "Out=$OUTPUT_BASE\n\nExp=$EXPECTED_BASE"
exit 1
fi
if [[ "$OUTPUT_LORA_HOT" != "$EXPECTED_LORA" ]]; then
echo "Error: $model_name OUTPUT_LORA_HOT does not start with the expected string."
echo -e "Out=$OUTPUT_LORA_HOT\n\nExp=$EXPECTED_LORA"
exit 1
fi
if [[ "$OUTPUT_LORA_MERGED" != "$EXPECTED_LORA" ]]; then
echo "Error: $model_name OUTPUT_LORA_MERGED does not start with the expected string."
echo -e "Out=$OUTPUT_LORA_MERGED\n\nExp=$EXPECTED_LORA"
exit 1
fi

# Store the results
results+=("
\n\033[1mResults for $model_name with hidden_size $hidden_size:\033[0m
\n\033[32m  • Base:\n$OUTPUT_BASE
\n\033[34m  • Lora hot:\n$OUTPUT_LORA_HOT
\n\033[36m  • Lora merged:\n$OUTPUT_LORA_MERGED
\n \033[0m
")

echo "All tests passed for $model_name with hidden_size $hidden_size!"
}

# Run test for each model
for param in "${params[@]}"; do
run_conversion_and_inference_lora $param
done

# Print results
echo -e "\n\n---------------------------\n\n"
echo -e "\n\033[1mSummary of All Results:\033[0m"
for result in "${results[@]}"; do
echo -e "$result"
done