merge
This commit is contained in:
parent
062f256e6b
commit
fe81134954
23 changed files with 699 additions and 597 deletions
|
@ -3,23 +3,36 @@ ARG UBUNTU_VERSION=22.04
|
|||
FROM ubuntu:$UBUNTU_VERSION AS build
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y build-essential python3 python3-pip git libcurl4-openssl-dev libgomp1
|
||||
|
||||
COPY requirements.txt requirements.txt
|
||||
COPY requirements requirements
|
||||
|
||||
RUN pip install --upgrade pip setuptools wheel \
|
||||
&& pip install -r requirements.txt
|
||||
apt-get install -y build-essential git cmake libcurl4-openssl-dev
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
ENV LLAMA_CURL=1
|
||||
RUN cmake -S . -B build -DGGML_BACKEND_DL=ON -DGGML_NATIVE=OFF -DGGML_CPU_ALL_VARIANTS=ON -DLLAMA_CURL=ON -DCMAKE_BUILD_TYPE=Release && \
|
||||
cmake --build build -j $(nproc) && \
|
||||
mkdir -p /app/lib && \
|
||||
find build -name "*.so" -exec cp {} /app/lib/ \;
|
||||
|
||||
FROM ubuntu:$UBUNTU_VERSION as runtime
|
||||
|
||||
RUN make -j$(nproc)
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y build-essential python3 python3-pip git libcurl4-openssl-dev libgomp1
|
||||
|
||||
COPY requirements.txt /app/requirements.txt
|
||||
COPY requirements /app/requirements
|
||||
COPY .devops/tools.sh /app/tools.sh
|
||||
|
||||
RUN pip install --upgrade pip setuptools wheel && \
|
||||
pip install -r /app/requirements.txt
|
||||
|
||||
COPY --from=build /app/build/bin/ /app/
|
||||
COPY --from=build /app/lib/ /app/
|
||||
COPY --from=build /app/convert_hf_to_gguf.py /app/
|
||||
COPY --from=build /app/gguf-py /app/gguf-py
|
||||
|
||||
ENV LC_ALL=C.utf8
|
||||
|
||||
ENTRYPOINT ["/app/.devops/tools.sh"]
|
||||
ENTRYPOINT ["/app/tools.sh"]
|
||||
|
|
|
@ -3,21 +3,27 @@ ARG UBUNTU_VERSION=22.04
|
|||
FROM ubuntu:$UBUNTU_VERSION AS build
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y build-essential git
|
||||
apt-get install -y build-essential git cmake libcurl4-openssl-dev
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN make -j$(nproc) llama-cli
|
||||
RUN cmake -S . -B build -DGGML_BACKEND_DL=ON -DGGML_NATIVE=OFF -DGGML_CPU_ALL_VARIANTS=ON -DLLAMA_CURL=ON -DCMAKE_BUILD_TYPE=Release && \
|
||||
cmake --build build -j $(nproc) && \
|
||||
mkdir -p /app/lib && \
|
||||
find build -name "*.so" -exec cp {} /app/lib/ \;
|
||||
|
||||
FROM ubuntu:$UBUNTU_VERSION AS runtime
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y libgomp1
|
||||
WORKDIR /app
|
||||
|
||||
COPY --from=build /app/llama-cli /llama-cli
|
||||
RUN apt-get update && \
|
||||
apt-get install -y libcurl4-openssl-dev libgomp1 curl
|
||||
|
||||
COPY --from=build /app/build/bin/llama-cli /app/
|
||||
COPY --from=build /app/lib/ /app/
|
||||
|
||||
ENV LC_ALL=C.utf8
|
||||
|
||||
ENTRYPOINT [ "/llama-cli" ]
|
||||
ENTRYPOINT [ "/app/llama-cli" ]
|
||||
|
|
|
@ -9,28 +9,20 @@ WORKDIR /app
|
|||
|
||||
COPY . .
|
||||
|
||||
|
||||
RUN \
|
||||
# Build multiple versions of the CPU backend
|
||||
scripts/build-cpu.sh avx -DGGML_AVX=ON -DGGML_AVX2=OFF && \
|
||||
scripts/build-cpu.sh avx2 -DGGML_AVX=ON -DGGML_AVX2=ON && \
|
||||
scripts/build-cpu.sh avx512 -DGGML_AVX=ON -DGGML_AVX2=ON -DGGML_AVX512=ON && \
|
||||
scripts/build-cpu.sh amx -DGGML_AVX=ON -DGGML_AVX2=ON -DGGML_AVX512=ON -DGGML_AVX_VNNI=ON -DGGML_AVX512_VNNI=ON -DGGML_AMX_TILE=ON -DGGML_AMX_INT8=ON && \
|
||||
# Build llama-server
|
||||
cmake -S . -B build -DGGML_BACKEND_DL=ON -DGGML_NATIVE=OFF -DLLAMA_CURL=ON -DCMAKE_BUILD_TYPE=Release && \
|
||||
cmake --build build --target llama-server -j $(nproc) && \
|
||||
# Copy the built libraries to /app/lib
|
||||
RUN cmake -S . -B build -DGGML_BACKEND_DL=ON -DGGML_NATIVE=OFF -DGGML_CPU_ALL_VARIANTS=ON -DLLAMA_CURL=ON -DCMAKE_BUILD_TYPE=Release && \
|
||||
cmake --build build -j $(nproc) && \
|
||||
mkdir -p /app/lib && \
|
||||
mv libggml-cpu* /app/lib/ && \
|
||||
find build -name "*.so" -exec cp {} /app/lib/ \;
|
||||
|
||||
FROM ubuntu:$UBUNTU_VERSION AS runtime
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y libcurl4-openssl-dev libgomp1 curl
|
||||
|
||||
COPY --from=build /app/build/bin/llama-server /llama-server
|
||||
COPY --from=build /app/lib/ /
|
||||
COPY --from=build /app/build/bin/llama-server /app/
|
||||
COPY --from=build /app/lib/ /app/
|
||||
|
||||
ENV LC_ALL=C.utf8
|
||||
# Must be set to 0.0.0.0 so it can listen to requests from host machine
|
||||
|
@ -38,4 +30,4 @@ ENV LLAMA_ARG_HOST=0.0.0.0
|
|||
|
||||
HEALTHCHECK CMD [ "curl", "-f", "http://localhost:8080/health" ]
|
||||
|
||||
ENTRYPOINT [ "/llama-server" ]
|
||||
ENTRYPOINT [ "/app/llama-server" ]
|
||||
|
|
|
@ -1831,29 +1831,40 @@ class MiniCPMModel(Model):
|
|||
model_arch = gguf.MODEL_ARCH.MINICPM
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
|
||||
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
||||
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
super().set_gguf_parameters()
|
||||
embedding_scale = float(self.hparams["scale_emb"])
|
||||
self.gguf_writer.add_embedding_scale(embedding_scale)
|
||||
logger.info(f"gguf: (minicpm) embedding_scale = {embedding_scale}")
|
||||
residual_scale = self.hparams["scale_depth"] / self.hparams["num_hidden_layers"] ** 0.5
|
||||
self.gguf_writer.add_residual_scale(residual_scale)
|
||||
logger.info(f"gguf: (minicpm) residual_scale = {residual_scale}")
|
||||
logit_scale = self.hparams["hidden_size"] / self.hparams["dim_model_base"]
|
||||
self.gguf_writer.add_logit_scale(logit_scale)
|
||||
logger.info(f"gguf: (minicpm) logit_scale = {logit_scale}")
|
||||
if self.hparams.get("rope_scaling") is not None:
|
||||
if self.hparams["rope_scaling"].get("type") == "longrope":
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LONGROPE)
|
||||
logger.info(f"gguf: (minicpm) rope_scaling_type = {gguf.RopeScalingType.LONGROPE}")
|
||||
|
||||
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
|
||||
rope_dims = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
|
||||
|
||||
rope_scaling = self.find_hparam(['rope_scaling'], True)
|
||||
if rope_scaling is not None:
|
||||
long_factors = rope_scaling.get('long_factor', None)
|
||||
short_factors = rope_scaling.get('short_factor', None)
|
||||
|
||||
if long_factors is None or short_factors is None:
|
||||
raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor')
|
||||
|
||||
if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2:
|
||||
raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}')
|
||||
|
||||
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32))
|
||||
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32))
|
||||
|
||||
def set_vocab(self):
|
||||
self._set_vocab_llama_hf()
|
||||
|
||||
def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
|
||||
if n_kv_head is not None and n_head != n_kv_head:
|
||||
n_head //= n_kv_head
|
||||
|
||||
return (
|
||||
weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
|
||||
.swapaxes(1, 2)
|
||||
.reshape(weights.shape)
|
||||
)
|
||||
self._set_vocab_sentencepiece()
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
del bid # unused
|
||||
|
@ -1863,9 +1874,9 @@ class MiniCPMModel(Model):
|
|||
|
||||
# HF models permute some of the tensors, so we need to undo that
|
||||
if name.endswith(("q_proj.weight")):
|
||||
data_torch = self._reverse_hf_permute(data_torch, n_head, n_head)
|
||||
data_torch = LlamaModel.permute(data_torch, n_head, n_head)
|
||||
if name.endswith(("k_proj.weight")):
|
||||
data_torch = self._reverse_hf_permute(data_torch, n_head, n_kv_head)
|
||||
data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
|
|
@ -921,6 +921,8 @@ struct server_context {
|
|||
slot.params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
|
||||
|
||||
slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min);
|
||||
slot.params.speculative.n_min = std::max(slot.params.speculative.n_min, 2);
|
||||
slot.params.speculative.n_max = std::max(slot.params.speculative.n_max, 0);
|
||||
|
||||
if (slot.params.sampling.dry_base < 1.0f) {
|
||||
slot.params.sampling.dry_base = defaults.sampling.dry_base;
|
||||
|
@ -2322,10 +2324,29 @@ struct server_context {
|
|||
continue;
|
||||
}
|
||||
|
||||
// determine the max draft that fits the current slot state
|
||||
int n_draft_max = slot.params.speculative.n_max;
|
||||
|
||||
// note: n_past is not yet increased for the `id` token sampled above
|
||||
// also, need to leave space for 1 extra token to allow context shifts
|
||||
n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
|
||||
|
||||
if (slot.n_remaining > 0) {
|
||||
n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
|
||||
}
|
||||
|
||||
SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
|
||||
|
||||
if (n_draft_max < slot.params.speculative.n_min) {
|
||||
SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
llama_token id = slot.sampled;
|
||||
|
||||
struct common_speculative_params params_spec;
|
||||
params_spec.n_draft = slot.params.speculative.n_max;
|
||||
params_spec.n_draft = n_draft_max;
|
||||
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
|
||||
params_spec.p_min = slot.params.speculative.p_min;
|
||||
|
||||
|
@ -2333,6 +2354,8 @@ struct server_context {
|
|||
|
||||
// ignore small drafts
|
||||
if (slot.params.speculative.n_min > (int) draft.size()) {
|
||||
SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -2344,6 +2367,8 @@ struct server_context {
|
|||
common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
|
||||
}
|
||||
|
||||
SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
|
||||
|
||||
llama_decode(ctx, slot.batch_spec);
|
||||
|
||||
// the accepted tokens from the speculation
|
||||
|
@ -2372,7 +2397,7 @@ struct server_context {
|
|||
}
|
||||
}
|
||||
|
||||
SRV_DBG("accepted %d/%d draft tokens\n", (int) ids.size() - 1, (int) draft.size());
|
||||
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -82,6 +82,37 @@ def test_different_draft_min_draft_max():
|
|||
last_content = res.body["content"]
|
||||
|
||||
|
||||
def test_slot_ctx_not_exceeded():
|
||||
global server
|
||||
server.n_ctx = 64
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "Hello " * 56,
|
||||
"temperature": 0.0,
|
||||
"top_k": 1,
|
||||
"speculative.p_min": 0.0,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert len(res.body["content"]) > 0
|
||||
|
||||
|
||||
def test_with_ctx_shift():
|
||||
global server
|
||||
server.n_ctx = 64
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "Hello " * 56,
|
||||
"temperature": 0.0,
|
||||
"top_k": 1,
|
||||
"n_predict": 64,
|
||||
"speculative.p_min": 0.0,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert len(res.body["content"]) > 0
|
||||
assert res.body["tokens_predicted"] == 64
|
||||
assert res.body["truncated"] == True
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_slots,n_requests", [
|
||||
(1, 2),
|
||||
(2, 2),
|
||||
|
|
|
@ -94,28 +94,31 @@ endif()
|
|||
|
||||
option(GGML_CPU_HBM "ggml: use memkind for CPU HBM" OFF)
|
||||
option(GGML_CPU_AARCH64 "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON)
|
||||
|
||||
option(GGML_AVX "ggml: enable AVX" ${INS_ENB})
|
||||
option(GGML_AVX_VNNI "ggml: enable AVX-VNNI" OFF)
|
||||
option(GGML_AVX2 "ggml: enable AVX2" ${INS_ENB})
|
||||
option(GGML_AVX512 "ggml: enable AVX512" OFF)
|
||||
option(GGML_AVX512 "ggml: enable AVX512F" OFF)
|
||||
option(GGML_AVX512_VBMI "ggml: enable AVX512-VBMI" OFF)
|
||||
option(GGML_AVX512_VNNI "ggml: enable AVX512-VNNI" OFF)
|
||||
option(GGML_AVX512_BF16 "ggml: enable AVX512-BF16" OFF)
|
||||
if (NOT MSVC)
|
||||
# in MSVC F16C and FMA is implied with AVX2/AVX512
|
||||
option(GGML_FMA "ggml: enable FMA" ${INS_ENB})
|
||||
option(GGML_F16C "ggml: enable F16C" ${INS_ENB})
|
||||
# MSVC does not seem to support AMX
|
||||
option(GGML_AMX_TILE "ggml: enable AMX-TILE" OFF)
|
||||
option(GGML_AMX_INT8 "ggml: enable AMX-INT8" OFF)
|
||||
option(GGML_AMX_BF16 "ggml: enable AMX-BF16" OFF)
|
||||
option(GGML_FMA "ggml: enable FMA" ${INS_ENB})
|
||||
if (NOT MSVC)
|
||||
option(GGML_F16C "ggml: enable F16C" ${INS_ENB}) # in MSVC F16C is implied with AVX2/AVX512
|
||||
endif()
|
||||
option(GGML_LASX "ggml: enable lasx" ON)
|
||||
option(GGML_LSX "ggml: enable lsx" ON)
|
||||
option(GGML_RVV "ggml: enable rvv" ON)
|
||||
option(GGML_SVE "ggml: enable SVE" OFF)
|
||||
option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF)
|
||||
|
||||
|
||||
if (WIN32)
|
||||
set(GGML_WIN_VER "0x602" CACHE STRING "ggml: Windows Version")
|
||||
set(GGML_WIN_VER "0x602" CACHE STRING "ggml: Windows version")
|
||||
endif()
|
||||
|
||||
# ggml core
|
||||
|
@ -180,11 +183,7 @@ option(GGML_BUILD_EXAMPLES "ggml: build examples" ${GGML_STANDALONE})
|
|||
set(CMAKE_C_STANDARD 11)
|
||||
set(CMAKE_C_STANDARD_REQUIRED true)
|
||||
|
||||
if (GGML_SYCL)
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
else()
|
||||
set(CMAKE_CXX_STANDARD 11)
|
||||
endif()
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED true)
|
||||
|
||||
set(THREADS_PREFER_PTHREAD_FLAG ON)
|
||||
|
|
|
@ -269,7 +269,42 @@ function(ggml_add_backend backend)
|
|||
endif()
|
||||
endfunction()
|
||||
|
||||
function(ggml_add_cpu_backend_variant tag_name)
|
||||
set(GGML_CPU_TAG_NAME ${tag_name})
|
||||
# other: OPENMP LLAMAFILE CPU_HBM
|
||||
foreach (feat NATIVE
|
||||
AVX AVX2 AVX_VNNI FMA F16C
|
||||
AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16
|
||||
AMX_TILE AMX_INT8 AMX_BF16)
|
||||
set(GGML_${feat} OFF)
|
||||
endforeach()
|
||||
|
||||
foreach (feat ${ARGN})
|
||||
set(GGML_${feat} ON)
|
||||
endforeach()
|
||||
|
||||
ggml_add_cpu_backend_variant_impl(${tag_name})
|
||||
endfunction()
|
||||
|
||||
ggml_add_backend(CPU)
|
||||
|
||||
if (GGML_CPU_ALL_VARIANTS)
|
||||
if (NOT GGML_BACKEND_DL)
|
||||
message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS requires GGML_BACKEND_DL")
|
||||
endif()
|
||||
ggml_add_cpu_backend_variant(sandybridge AVX)
|
||||
ggml_add_cpu_backend_variant(haswell AVX F16C AVX2 FMA)
|
||||
ggml_add_cpu_backend_variant(skylakex AVX F16C AVX2 FMA AVX512)
|
||||
ggml_add_cpu_backend_variant(icelake AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
|
||||
if (NOT MSVC)
|
||||
# MSVC doesn't support AVX-VNNI or AMX
|
||||
ggml_add_cpu_backend_variant(alderlake AVX F16C AVX2 FMA AVX_VNNI)
|
||||
ggml_add_cpu_backend_variant(sapphirerapids AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
|
||||
endif()
|
||||
else ()
|
||||
ggml_add_cpu_backend_variant_impl("")
|
||||
endif()
|
||||
|
||||
ggml_add_backend(BLAS)
|
||||
ggml_add_backend(CANN)
|
||||
ggml_add_backend(CUDA)
|
||||
|
|
|
@ -483,6 +483,10 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent)
|
|||
best_score = s;
|
||||
best_path = entry.path().string();
|
||||
}
|
||||
} else {
|
||||
if (!silent) {
|
||||
GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, entry.path().string().c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -505,15 +509,21 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent)
|
|||
}
|
||||
|
||||
void ggml_backend_load_all() {
|
||||
ggml_backend_load_best("blas", true);
|
||||
ggml_backend_load_best("cann", true);
|
||||
ggml_backend_load_best("cuda", true);
|
||||
ggml_backend_load_best("hip", true);
|
||||
ggml_backend_load_best("kompute", true);
|
||||
ggml_backend_load_best("metal", true);
|
||||
ggml_backend_load_best("rpc", true);
|
||||
ggml_backend_load_best("sycl", true);
|
||||
ggml_backend_load_best("vulkan", true);
|
||||
ggml_backend_load_best("musa", true);
|
||||
ggml_backend_load_best("cpu", true);
|
||||
#ifdef NDEBUG
|
||||
bool silent = true;
|
||||
#else
|
||||
bool silent = false;
|
||||
#endif
|
||||
|
||||
ggml_backend_load_best("blas", silent);
|
||||
ggml_backend_load_best("cann", silent);
|
||||
ggml_backend_load_best("cuda", silent);
|
||||
ggml_backend_load_best("hip", silent);
|
||||
ggml_backend_load_best("kompute", silent);
|
||||
ggml_backend_load_best("metal", silent);
|
||||
ggml_backend_load_best("rpc", silent);
|
||||
ggml_backend_load_best("sycl", silent);
|
||||
ggml_backend_load_best("vulkan", silent);
|
||||
ggml_backend_load_best("musa", silent);
|
||||
ggml_backend_load_best("cpu", silent);
|
||||
}
|
||||
|
|
|
@ -1,32 +1,39 @@
|
|||
ggml_add_backend_library(ggml-cpu)
|
||||
function(ggml_add_cpu_backend_variant_impl tag_name)
|
||||
if (tag_name)
|
||||
set(GGML_CPU_NAME ggml-cpu-${tag_name})
|
||||
else()
|
||||
set(GGML_CPU_NAME ggml-cpu)
|
||||
endif()
|
||||
|
||||
ggml_add_backend_library(${GGML_CPU_NAME})
|
||||
|
||||
list (APPEND GGML_CPU_SOURCES
|
||||
ggml-cpu.c
|
||||
ggml-cpu.cpp
|
||||
ggml-cpu-aarch64.c
|
||||
ggml-cpu-aarch64.h
|
||||
ggml-cpu-quants.c
|
||||
ggml-cpu-quants.h
|
||||
amx/amx.cpp
|
||||
amx/amx.h
|
||||
amx/mmq.cpp
|
||||
amx/mmq.h
|
||||
ggml-cpu-impl.h
|
||||
ggml-cpu/ggml-cpu.c
|
||||
ggml-cpu/ggml-cpu.cpp
|
||||
ggml-cpu/ggml-cpu-aarch64.c
|
||||
ggml-cpu/ggml-cpu-aarch64.h
|
||||
ggml-cpu/ggml-cpu-quants.c
|
||||
ggml-cpu/ggml-cpu-quants.h
|
||||
ggml-cpu/amx/amx.cpp
|
||||
ggml-cpu/amx/amx.h
|
||||
ggml-cpu/amx/mmq.cpp
|
||||
ggml-cpu/amx/mmq.h
|
||||
ggml-cpu/ggml-cpu-impl.h
|
||||
)
|
||||
|
||||
target_compile_features(ggml-cpu PRIVATE c_std_11 cxx_std_17)
|
||||
target_include_directories(ggml-cpu PRIVATE .)
|
||||
target_compile_features(${GGML_CPU_NAME} PRIVATE c_std_11 cxx_std_17)
|
||||
target_include_directories(${GGML_CPU_NAME} PRIVATE . ggml-cpu)
|
||||
|
||||
if (APPLE AND GGML_ACCELERATE)
|
||||
find_library(ACCELERATE_FRAMEWORK Accelerate)
|
||||
if (ACCELERATE_FRAMEWORK)
|
||||
message(STATUS "Accelerate framework found")
|
||||
|
||||
target_compile_definitions(ggml-cpu PRIVATE GGML_USE_ACCELERATE)
|
||||
target_compile_definitions(ggml-cpu PRIVATE ACCELERATE_NEW_LAPACK)
|
||||
target_compile_definitions(ggml-cpu PRIVATE ACCELERATE_LAPACK_ILP64)
|
||||
target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_ACCELERATE)
|
||||
target_compile_definitions(${GGML_CPU_NAME} PRIVATE ACCELERATE_NEW_LAPACK)
|
||||
target_compile_definitions(${GGML_CPU_NAME} PRIVATE ACCELERATE_LAPACK_ILP64)
|
||||
|
||||
target_link_libraries(ggml-cpu PRIVATE ${ACCELERATE_FRAMEWORK})
|
||||
target_link_libraries(${GGML_CPU_NAME} PRIVATE ${ACCELERATE_FRAMEWORK})
|
||||
else()
|
||||
message(WARNING "Accelerate framework not found")
|
||||
endif()
|
||||
|
@ -35,24 +42,20 @@ endif()
|
|||
if (GGML_OPENMP)
|
||||
find_package(OpenMP)
|
||||
if (OpenMP_FOUND)
|
||||
message(STATUS "OpenMP found")
|
||||
target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_OPENMP)
|
||||
|
||||
target_compile_definitions(ggml-cpu PRIVATE GGML_USE_OPENMP)
|
||||
|
||||
target_link_libraries(ggml-cpu PRIVATE OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
|
||||
target_link_libraries(${GGML_CPU_NAME} PRIVATE OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
|
||||
else()
|
||||
message(WARNING "OpenMP not found")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (GGML_LLAMAFILE)
|
||||
message(STATUS "Using llamafile")
|
||||
|
||||
target_compile_definitions(ggml-cpu PRIVATE GGML_USE_LLAMAFILE)
|
||||
target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_LLAMAFILE)
|
||||
|
||||
list(APPEND GGML_CPU_SOURCES
|
||||
llamafile/sgemm.cpp
|
||||
llamafile/sgemm.h)
|
||||
ggml-cpu/llamafile/sgemm.cpp
|
||||
ggml-cpu/llamafile/sgemm.h)
|
||||
endif()
|
||||
|
||||
if (GGML_CPU_HBM)
|
||||
|
@ -60,9 +63,9 @@ if (GGML_CPU_HBM)
|
|||
|
||||
message(STATUS "Using memkind for CPU HBM")
|
||||
|
||||
target_compile_definitions(ggml-cpu PRIVATE GGML_USE_CPU_HBM)
|
||||
target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_HBM)
|
||||
|
||||
target_link_libraries(ggml-cpu PUBLIC memkind)
|
||||
target_link_libraries(${GGML_CPU_NAME} PUBLIC memkind)
|
||||
endif()
|
||||
|
||||
if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR
|
||||
|
@ -173,18 +176,19 @@ if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR
|
|||
elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LWR MATCHES "^(x86_64|i686|amd64|x64|win32)$" OR
|
||||
(NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
|
||||
CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64)$"))
|
||||
message(STATUS "x86 detected")
|
||||
if (MSVC)
|
||||
# instruction set detection for MSVC only
|
||||
if (GGML_NATIVE)
|
||||
include(cmake/FindSIMD.cmake)
|
||||
include(ggml-cpu/cmake/FindSIMD.cmake)
|
||||
endif ()
|
||||
if (GGML_AVX512)
|
||||
list(APPEND ARCH_FLAGS /arch:AVX512)
|
||||
# /arch:AVX512 includes: __AVX512F__, __AVX512CD__, __AVX512BW__, __AVX512DQ__, and __AVX512VL__
|
||||
# MSVC has no compile-time flags enabling specific
|
||||
# AVX512 extensions, neither it defines the
|
||||
# macros corresponding to the extensions.
|
||||
# Do it manually.
|
||||
list(APPEND ARCH_DEFINITIONS GGML_AVX512)
|
||||
if (GGML_AVX512_VBMI)
|
||||
list(APPEND ARCH_DEFINITIONS __AVX512VBMI__)
|
||||
if (CMAKE_C_COMPILER_ID STREQUAL "Clang")
|
||||
|
@ -192,78 +196,98 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
|
|||
endif()
|
||||
endif()
|
||||
if (GGML_AVX512_VNNI)
|
||||
list(APPEND ARCH_DEFINITIONS __AVX512VNNI__)
|
||||
list(APPEND ARCH_DEFINITIONS __AVX512VNNI__ GGML_AVX512_VNNI)
|
||||
if (CMAKE_C_COMPILER_ID STREQUAL "Clang")
|
||||
list(APPEND ARCH_FLAGS -mavx512vnni)
|
||||
endif()
|
||||
endif()
|
||||
if (GGML_AVX512_BF16)
|
||||
list(APPEND ARCH_DEFINITIONS __AVX512BF16__)
|
||||
list(APPEND ARCH_DEFINITIONS __AVX512BF16__ GGML_AVX512_BF16)
|
||||
if (CMAKE_C_COMPILER_ID STREQUAL "Clang")
|
||||
list(APPEND ARCH_FLAGS -mavx512bf16)
|
||||
endif()
|
||||
endif()
|
||||
if (GGML_AMX_TILE)
|
||||
list(APPEND ARCH_DEFINITIONS __AMX_TILE__)
|
||||
list(APPEND ARCH_DEFINITIONS __AMX_TILE__ GGML_AMX_TILE)
|
||||
endif()
|
||||
if (GGML_AMX_INT8)
|
||||
list(APPEND ARCH_DEFINITIONS __AMX_INT8__)
|
||||
list(APPEND ARCH_DEFINITIONS __AMX_INT8__ GGML_AMX_INT8)
|
||||
endif()
|
||||
if (GGML_AMX_BF16)
|
||||
list(APPEND ARCH_DEFINITIONS __AMX_BF16__)
|
||||
list(APPEND ARCH_DEFINITIONS __AMX_BF16__ GGML_AMX_BF16)
|
||||
endif()
|
||||
elseif (GGML_AVX2)
|
||||
list(APPEND ARCH_FLAGS /arch:AVX2)
|
||||
list(APPEND ARCH_DEFINITIONS GGML_AVX2 GGML_FMA GGML_F16C)
|
||||
elseif (GGML_AVX)
|
||||
list(APPEND ARCH_FLAGS /arch:AVX)
|
||||
list(APPEND ARCH_DEFINITIONS GGML_AVX)
|
||||
else ()
|
||||
list(APPEND ARCH_FLAGS /arch:SSE4.2)
|
||||
list(APPEND ARCH_DEFINITIONS GGML_SSE42)
|
||||
endif()
|
||||
if (GGML_AVX_VNNI)
|
||||
list(APPEND ARCH_DEFINITIONS __AVXVNNI__)
|
||||
if (CMAKE_C_COMPILER_ID STREQUAL "Clang")
|
||||
list(APPEND ARCH_FLAGS -mavxvnni)
|
||||
endif()
|
||||
# MSVC generates AVX512 with AVX-VNNI intrinsics even with /arch:AVX2
|
||||
#list(APPEND ARCH_DEFINITIONS __AVXVNNI__ GGML_AVX_VNNI)
|
||||
endif()
|
||||
else ()
|
||||
if (GGML_NATIVE)
|
||||
list(APPEND ARCH_FLAGS -march=native)
|
||||
endif()
|
||||
else ()
|
||||
list(APPEND ARCH_FLAGS -msse4.2)
|
||||
list(APPEND ARCH_DEFINITIONS GGML_SSE42)
|
||||
if (GGML_F16C)
|
||||
list(APPEND ARCH_FLAGS -mf16c)
|
||||
list(APPEND ARCH_DEFINITIONS GGML_F16C)
|
||||
endif()
|
||||
if (GGML_FMA)
|
||||
list(APPEND ARCH_FLAGS -mfma)
|
||||
list(APPEND ARCH_DEFINITIONS GGML_FMA)
|
||||
endif()
|
||||
if (GGML_AVX)
|
||||
list(APPEND ARCH_FLAGS -mavx)
|
||||
list(APPEND ARCH_DEFINITIONS GGML_AVX)
|
||||
endif()
|
||||
if (GGML_AVX2)
|
||||
list(APPEND ARCH_FLAGS -mavx2)
|
||||
list(APPEND ARCH_DEFINITIONS GGML_AVX2)
|
||||
endif()
|
||||
if (GGML_AVX_VNNI)
|
||||
list(APPEND ARCH_FLAGS -mavxvnni)
|
||||
list(APPEND ARCH_DEFINITIONS GGML_AVX_VNNI)
|
||||
endif()
|
||||
if (GGML_AVX512)
|
||||
list(APPEND ARCH_FLAGS -mavx512f)
|
||||
list(APPEND ARCH_FLAGS -mavx512cd)
|
||||
list(APPEND ARCH_FLAGS -mavx512vl)
|
||||
list(APPEND ARCH_FLAGS -mavx512dq)
|
||||
list(APPEND ARCH_FLAGS -mavx512bw)
|
||||
list(APPEND ARCH_DEFINITIONS GGML_AVX512)
|
||||
endif()
|
||||
if (GGML_AVX512_VBMI)
|
||||
list(APPEND ARCH_FLAGS -mavx512vbmi)
|
||||
list(APPEND ARCH_DEFINITIONS GGML_AVX512_VBMI)
|
||||
endif()
|
||||
if (GGML_AVX512_VNNI)
|
||||
list(APPEND ARCH_FLAGS -mavx512vnni)
|
||||
list(APPEND ARCH_DEFINITIONS GGML_AVX512_VNNI)
|
||||
endif()
|
||||
if (GGML_AVX512_BF16)
|
||||
list(APPEND ARCH_FLAGS -mavx512bf16)
|
||||
list(APPEND ARCH_DEFINITIONS GGML_AVX512_BF16)
|
||||
endif()
|
||||
if (GGML_AMX_TILE)
|
||||
list(APPEND ARCH_FLAGS -mamx-tile)
|
||||
list(APPEND ARCH_DEFINITIONS GGML_AMX_TILE)
|
||||
endif()
|
||||
if (GGML_AMX_INT8)
|
||||
list(APPEND ARCH_FLAGS -mamx-int8)
|
||||
list(APPEND ARCH_DEFINITIONS GGML_AMX_INT8)
|
||||
endif()
|
||||
if (GGML_AMX_BF16)
|
||||
list(APPEND ARCH_FLAGS -mamx-bf16)
|
||||
list(APPEND ARCH_DEFINITIONS GGML_AMX_BF16)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
|
||||
|
@ -302,18 +326,29 @@ else()
|
|||
endif()
|
||||
|
||||
if (GGML_CPU_AARCH64)
|
||||
message(STATUS "Using runtime weight conversion of Q4_0 to Q4_0_x_x to enable optimized GEMM/GEMV kernels")
|
||||
target_compile_definitions(ggml-cpu PRIVATE GGML_USE_CPU_AARCH64)
|
||||
target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_AARCH64)
|
||||
endif()
|
||||
|
||||
target_sources(ggml-cpu PRIVATE ${GGML_CPU_SOURCES})
|
||||
set_source_files_properties(${GGML_CPU_SOURCES} PROPERTIES COMPILE_OPTIONS "${ARCH_FLAGS}")
|
||||
set_source_files_properties(${GGML_CPU_SOURCES} PROPERTIES COMPILE_DEFINITIONS "${ARCH_DEFINITIONS}")
|
||||
message(STATUS "Adding CPU backend variant ${GGML_CPU_NAME}: ${ARCH_FLAGS} ${ARCH_DEFINITIONS}")
|
||||
target_sources(${GGML_CPU_NAME} PRIVATE ${GGML_CPU_SOURCES})
|
||||
target_compile_options(${GGML_CPU_NAME} PRIVATE ${ARCH_FLAGS})
|
||||
target_compile_definitions(${GGML_CPU_NAME} PRIVATE ${ARCH_DEFINITIONS})
|
||||
|
||||
# the feature detection code must be compiled without any architecture flags
|
||||
target_sources(ggml-cpu PRIVATE cpu-feats-x86.cpp)
|
||||
# target_sources(ggml-cpu PRIVATE cpu-feats-arm.cpp) # TODO: ARM feature detection
|
||||
if (GGML_BACKEND_DL)
|
||||
# The feature detection code is compiled as a separate target so that
|
||||
# it can be built without the architecture flags
|
||||
# Since multiple variants of the CPU backend may be included in the same
|
||||
# build, using set_source_files_properties() to set the arch flags is not possible
|
||||
set(GGML_CPU_FEATS_NAME ${GGML_CPU_NAME}-feats)
|
||||
add_library(${GGML_CPU_FEATS_NAME} OBJECT ggml-cpu/cpu-feats-x86.cpp)
|
||||
target_include_directories(${GGML_CPU_FEATS_NAME} PRIVATE . .. ../include)
|
||||
target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE ${ARCH_DEFINITIONS})
|
||||
target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE GGML_BACKEND_DL GGML_BACKEND_BUILD GGML_BACKEND_SHARED)
|
||||
set_target_properties(${GGML_CPU_FEATS_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
target_link_libraries(${GGML_CPU_NAME} PRIVATE ${GGML_CPU_FEATS_NAME})
|
||||
endif()
|
||||
|
||||
if (EMSCRIPTEN)
|
||||
set_target_properties(ggml-cpu PROPERTIES COMPILE_FLAGS "-msimd128")
|
||||
set_target_properties(${GGML_CPU_NAME} PROPERTIES COMPILE_FLAGS "-msimd128")
|
||||
endif()
|
||||
endfunction()
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
#include "ggml-cpu.h"
|
||||
#include "ggml-backend-impl.h"
|
||||
|
||||
#if defined(__x86_64__) || (defined(_MSC_VER) && defined(_M_AMD64))
|
||||
|
@ -13,6 +12,7 @@
|
|||
#include <array>
|
||||
#include <string>
|
||||
|
||||
// ref: https://cdrdv2-public.intel.com/782156/325383-sdm-vol-2abcd.pdf
|
||||
struct cpuid_x86 {
|
||||
bool SSE3(void) { return f_1_ecx[0]; }
|
||||
bool PCLMULQDQ(void) { return f_1_ecx[1]; }
|
||||
|
@ -50,11 +50,15 @@ struct cpuid_x86 {
|
|||
bool INVPCID(void) { return f_7_ebx[10]; }
|
||||
bool RTM(void) { return is_intel && f_7_ebx[11]; }
|
||||
bool AVX512F(void) { return f_7_ebx[16]; }
|
||||
bool AVX512DQ(void) { return f_7_ebx[17]; }
|
||||
bool RDSEED(void) { return f_7_ebx[18]; }
|
||||
bool ADX(void) { return f_7_ebx[19]; }
|
||||
bool AVX512PF(void) { return f_7_ebx[26]; }
|
||||
bool AVX512ER(void) { return f_7_ebx[27]; }
|
||||
bool AVX512CD(void) { return f_7_ebx[28]; }
|
||||
bool AVX512BW(void) { return f_7_ebx[30]; }
|
||||
bool AVX512VL(void) { return f_7_ebx[31]; }
|
||||
|
||||
bool SHA(void) { return f_7_ebx[29]; }
|
||||
|
||||
bool PREFETCHWT1(void) { return f_7_ecx[0]; }
|
||||
|
@ -259,36 +263,57 @@ void test_x86_is() {
|
|||
static int ggml_backend_cpu_x86_score() {
|
||||
// FIXME: this does not check for OS support
|
||||
|
||||
cpuid_x86 is;
|
||||
// if the CPU backend was built with any features not supported by the current CPU, it cannot be used
|
||||
if (ggml_cpu_has_fma() && !is.FMA()) { return 0; }
|
||||
if (ggml_cpu_has_f16c() && !is.F16C()) { return 0; }
|
||||
if (ggml_cpu_has_ssse3() && !is.SSSE3()) { return 0; }
|
||||
if (ggml_cpu_has_sse3() && !is.SSE3()) { return 0; }
|
||||
if (ggml_cpu_has_avx() && !is.AVX()) { return 0; }
|
||||
if (ggml_cpu_has_avx_vnni() && !is.AVX_VNNI()) { return 0; }
|
||||
if (ggml_cpu_has_avx2() && !is.AVX2()) { return 0; }
|
||||
if (ggml_cpu_has_avx512() && !is.AVX512F()) { return 0; }
|
||||
if (ggml_cpu_has_avx512_vbmi() && !is.AVX512_VBMI()) { return 0; }
|
||||
if (ggml_cpu_has_avx512_bf16() && !is.AVX512_BF16()) { return 0; }
|
||||
if (ggml_cpu_has_avx512_vnni() && !is.AVX512_VNNI()) { return 0; }
|
||||
if (ggml_cpu_has_amx_int8() && !is.AMX_INT8()) { return 0; }
|
||||
|
||||
// calculate a backend score based on the supported features
|
||||
// more important features have a higher weight
|
||||
int score = 0;
|
||||
score += ggml_cpu_has_fma () * 1;
|
||||
score += ggml_cpu_has_f16c () * 1<<1;
|
||||
score += ggml_cpu_has_ssse3 () * 1<<2;
|
||||
score += ggml_cpu_has_sse3 () * 1<<3;
|
||||
score += ggml_cpu_has_avx_vnni () * 1<<4;
|
||||
score += ggml_cpu_has_avx () * 1<<5;
|
||||
score += ggml_cpu_has_avx2 () * 1<<6;
|
||||
score += ggml_cpu_has_avx512 () * 1<<7;
|
||||
// score += ggml_cpu_has_avx512_vbmi() * 1<<8; // not used
|
||||
score += ggml_cpu_has_avx512_bf16() * 1<<9;
|
||||
score += ggml_cpu_has_avx512_vnni() * 1<<10;
|
||||
score += ggml_cpu_has_amx_int8 () * 1<<11;
|
||||
cpuid_x86 is;
|
||||
|
||||
#ifdef GGML_FMA
|
||||
if (!is.FMA()) { return 0; }
|
||||
score += 1;
|
||||
#endif
|
||||
#ifdef GGML_F16C
|
||||
if (!is.F16C()) { return 0; }
|
||||
score += 1<<1;
|
||||
#endif
|
||||
#ifdef GGML_SSE42
|
||||
if (!is.SSE42()) { return 0; }
|
||||
score += 1<<2;
|
||||
#endif
|
||||
#ifdef GGML_AVX
|
||||
if (!is.AVX()) { return 0; }
|
||||
score += 1<<4;
|
||||
#endif
|
||||
#ifdef GGML_AVX2
|
||||
if (!is.AVX2()) { return 0; }
|
||||
score += 1<<5;
|
||||
#endif
|
||||
#ifdef GGML_AVX_VNNI
|
||||
if (!is.AVX_VNNI()) { return 0; }
|
||||
score += 1<<6;
|
||||
#endif
|
||||
#ifdef GGML_AVX512
|
||||
if (!is.AVX512F()) { return 0; }
|
||||
if (!is.AVX512CD()) { return 0; }
|
||||
if (!is.AVX512VL()) { return 0; }
|
||||
if (!is.AVX512DQ()) { return 0; }
|
||||
if (!is.AVX512BW()) { return 0; }
|
||||
score += 1<<7;
|
||||
#endif
|
||||
#ifdef GGML_AVX512_VBMI
|
||||
if (!is.AVX512_VBMI()) { return 0; }
|
||||
score += 1<<8;
|
||||
#endif
|
||||
#ifdef GGML_AVX512_BF16
|
||||
if (!is.AVX512_BF16()) { return 0; }
|
||||
score += 1<<9;
|
||||
#endif
|
||||
#ifdef GGML_AVX512_VNNI
|
||||
if (!is.AVX512_VNNI()) { return 0; }
|
||||
score += 1<<10;
|
||||
#endif
|
||||
#ifdef GGML_AMX_INT8
|
||||
if (!is.AMX_INT8()) { return 0; }
|
||||
score += 1<<11;
|
||||
#endif
|
||||
|
||||
return score;
|
||||
}
|
||||
|
|
|
@ -756,7 +756,7 @@ do { \
|
|||
#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
|
||||
#define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
|
||||
#else
|
||||
static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {
|
||||
static inline __m256 __avx_f32cx8_load(const ggml_fp16_t * x) {
|
||||
float tmp[8];
|
||||
|
||||
for (int i = 0; i < 8; i++) {
|
||||
|
@ -2425,7 +2425,7 @@ bool ggml_is_numa(void) {
|
|||
#endif
|
||||
|
||||
#if !defined(HWCAP2_I8MM)
|
||||
#define HWCAP2_I8MM 0
|
||||
#define HWCAP2_I8MM (1 << 13)
|
||||
#endif
|
||||
|
||||
static void ggml_init_arm_arch_features(void) {
|
||||
|
|
|
@ -641,7 +641,15 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
|
|||
if (ggml_cpu_has_llamafile()) {
|
||||
features.push_back({ "LLAMAFILE", "1" });
|
||||
}
|
||||
// TODO: rename this
|
||||
#ifdef GGML_USE_ACCELERATE
|
||||
features.push_back({ "ACCELERATE", "1" });
|
||||
#endif
|
||||
#ifdef GGML_USE_CPU_HBM
|
||||
features.push_back({ "CPU_HBM", "1" });
|
||||
#endif
|
||||
#ifdef GGML_USE_OPENMP
|
||||
features.push_back({ "OPENMP", "1" });
|
||||
#endif
|
||||
#ifdef GGML_USE_CPU_AARCH64
|
||||
features.push_back({ "AARCH64_REPACK", "1" });
|
||||
#endif
|
||||
|
|
|
@ -353,7 +353,45 @@ struct vk_op_unary_push_constants {
|
|||
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
|
||||
uint32_t d_offset;
|
||||
float param1; float param2;
|
||||
uint32_t ne0_012mp; uint32_t ne0_012L;
|
||||
uint32_t ne0_01mp; uint32_t ne0_01L;
|
||||
uint32_t ne0_0mp; uint32_t ne0_0L;
|
||||
uint32_t ne1_012mp; uint32_t ne1_012L;
|
||||
uint32_t ne1_01mp; uint32_t ne1_01L;
|
||||
uint32_t ne1_0mp; uint32_t ne1_0L;
|
||||
};
|
||||
static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128");
|
||||
|
||||
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
|
||||
// Precompute mp (m' in the paper) and L such that division
|
||||
// can be computed using a multiply (high 32b of 64b result)
|
||||
// and a shift:
|
||||
//
|
||||
// n/d = (mulhi(n, mp) + n) >> L;
|
||||
void init_fastdiv_values(uint32_t d, uint32_t &mp, uint32_t &L)
|
||||
{
|
||||
// compute L = ceil(log2(d));
|
||||
L = 0;
|
||||
while (L < 32 && (uint32_t{1} << L) < d) {
|
||||
L++;
|
||||
}
|
||||
|
||||
mp = (uint32_t)((uint64_t{1} << 32) * ((uint64_t{1} << L) - d) / d + 1);
|
||||
}
|
||||
|
||||
template <typename T> void init_pushconst_fastdiv(T &p) {
|
||||
static_assert(!std::is_const<T>::value, "unexpected type");
|
||||
}
|
||||
|
||||
template <> void init_pushconst_fastdiv(vk_op_unary_push_constants &p) {
|
||||
// Compute magic values to divide by these six numbers.
|
||||
init_fastdiv_values(p.ne02*p.ne01*p.ne00, p.ne0_012mp, p.ne0_012L);
|
||||
init_fastdiv_values(p.ne01*p.ne00, p.ne0_01mp, p.ne0_01L);
|
||||
init_fastdiv_values(p.ne00, p.ne0_0mp, p.ne0_0L);
|
||||
init_fastdiv_values(p.ne12*p.ne11*p.ne10, p.ne1_012mp, p.ne1_012L);
|
||||
init_fastdiv_values(p.ne11*p.ne10, p.ne1_01mp, p.ne1_01L);
|
||||
init_fastdiv_values(p.ne10, p.ne1_0mp, p.ne1_0L);
|
||||
}
|
||||
|
||||
struct vk_op_binary_push_constants {
|
||||
uint32_t ne;
|
||||
|
@ -2914,13 +2952,14 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
|
|||
elements = { ne, 1, 1 };
|
||||
}
|
||||
|
||||
const vk_op_unary_push_constants pc = {
|
||||
vk_op_unary_push_constants pc = {
|
||||
(uint32_t)ne,
|
||||
(uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size, (uint32_t)tensor->nb[3] / tensor_type_size,
|
||||
(uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], 1 , (uint32_t)tensor->ne[0] , (uint32_t)(tensor->ne[0] * tensor->ne[1]) , (uint32_t)(tensor->ne[0] * tensor->ne[1] * tensor->ne[2]),
|
||||
0,
|
||||
0.0f, 0.0f,
|
||||
};
|
||||
init_pushconst_fastdiv(pc);
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements);
|
||||
}
|
||||
|
@ -4125,7 +4164,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
|
|||
}
|
||||
|
||||
template<typename PC>
|
||||
static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, const PC&& pc, bool dryrun = false) {
|
||||
static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, PC&& pc, bool dryrun = false) {
|
||||
VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
|
||||
if (src1 != nullptr) {
|
||||
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
|
||||
|
@ -4165,6 +4204,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|||
const uint64_t ned3 = dst->ne[3];
|
||||
const uint64_t ned = ned0 * ned1;
|
||||
|
||||
init_pushconst_fastdiv(pc);
|
||||
|
||||
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op);
|
||||
|
||||
if (pipeline == nullptr) {
|
||||
|
|
|
@ -31,7 +31,7 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
|||
}
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
|
||||
return (vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, (vui >> 12) & 0xF) - 8.0f);
|
||||
return (vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12) - 8.0f);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
@ -46,7 +46,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
|||
const float d = float(data_a_packed16[a_offset + ib].d);
|
||||
const float m = float(data_a_packed16[a_offset + ib].m);
|
||||
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
|
||||
return vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, (vui >> 12) & 0xF) * d + m;
|
||||
return vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12) * d + m;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
@ -63,7 +63,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
|||
const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
|
||||
const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10);
|
||||
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
|
||||
return (vec4(((vui >> 0) & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) - 16.0f);
|
||||
return (vec4(((vui >> 0) & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
@ -83,7 +83,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
|||
const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
|
||||
const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10);
|
||||
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
|
||||
return vec4(((vui >> 0) & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) * d + m;
|
||||
return vec4(((vui >> 0) & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
@ -95,16 +95,11 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
|||
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||
uint32_t v0 = data_a_packed16[a_offset + ib].qs[iqs/2];
|
||||
uint32_t v1 = data_a_packed16[a_offset + ib].qs[iqs/2 + 1];
|
||||
return vec4(int8_t(v0 & 0xFF), int8_t((v0 >> 8) & 0xFF), int8_t(v1 & 0xFF), int8_t((v1 >> 8) & 0xFF));
|
||||
return vec4(int8_t(v0 & 0xFF), int8_t(v0 >> 8), int8_t(v1 & 0xFF), int8_t(v1 >> 8));
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_IQ4_NL)
|
||||
float iq_helper(uint i) {
|
||||
const float x = float(i);
|
||||
return round(((0.080958*x-1.875836)*x+25.907107)*x-127.663571);
|
||||
}
|
||||
|
||||
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||
const float d = float(data_a[a_offset + ib].d);
|
||||
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
|
||||
|
@ -112,6 +107,6 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
|||
}
|
||||
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
|
||||
return vec4(iq_helper(vui & 0xF), iq_helper((vui >> 4) & 0xF), iq_helper((vui >> 8) & 0xF), iq_helper((vui >> 12) & 0xF));
|
||||
return vec4(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[(vui >> 4) & 0xF], kvalues_iq4nl[(vui >> 8) & 0xF], kvalues_iq4nl[vui >> 12]);
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -8,6 +8,13 @@ layout (push_constant) uniform parameter
|
|||
uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
|
||||
uint d_offset;
|
||||
float param1; float param2;
|
||||
|
||||
uint ne0_012mp; uint ne0_012L;
|
||||
uint ne0_01mp; uint ne0_01L;
|
||||
uint ne0_0mp; uint ne0_0L;
|
||||
uint ne1_012mp; uint ne1_012L;
|
||||
uint ne1_01mp; uint ne1_01L;
|
||||
uint ne1_0mp; uint ne1_0L;
|
||||
} p;
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||
|
@ -17,22 +24,30 @@ uint get_idx() {
|
|||
return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
||||
}
|
||||
|
||||
// see init_fastdiv_values in ggml-vulkan.cpp
|
||||
uint fastdiv(uint n, uint mp, uint L) {
|
||||
uint msbs, lsbs;
|
||||
// msbs = mulhi(n, mp)
|
||||
umulExtended(n, mp, msbs, lsbs);
|
||||
return (msbs + n) >> L;
|
||||
}
|
||||
|
||||
uint src0_idx(uint idx) {
|
||||
const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
|
||||
const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L);
|
||||
const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
|
||||
const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00);
|
||||
const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L);
|
||||
const uint i02_offset = i02*p.ne01*p.ne00;
|
||||
const uint i01 = (idx - i03_offset - i02_offset) / p.ne00;
|
||||
const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L);
|
||||
const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
|
||||
return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
|
||||
}
|
||||
|
||||
uint dst_idx(uint idx) {
|
||||
const uint i13 = idx / (p.ne12*p.ne11*p.ne10);
|
||||
const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
|
||||
const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
|
||||
const uint i12 = (idx - i13_offset) / (p.ne11*p.ne10);
|
||||
const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
|
||||
const uint i12_offset = i12*p.ne11*p.ne10;
|
||||
const uint i11 = (idx - i13_offset - i12_offset) / p.ne10;
|
||||
const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
|
||||
const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
|
||||
return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10;
|
||||
}
|
||||
|
|
|
@ -59,7 +59,6 @@ void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_
|
|||
ibi += p.ncols;
|
||||
|
||||
#if K_PER_ITER == 8
|
||||
// TODO: can we dequant as f16 instead of as vec?
|
||||
const vec4 v = dequantize4(ib, iqs, a_offset);
|
||||
const vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset);
|
||||
FLOAT_TYPE rowtmp = 0;
|
||||
|
|
|
@ -896,6 +896,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ROPE_FACTORS_LONG,
|
||||
MODEL_TENSOR.ROPE_FACTORS_SHORT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
|
@ -1391,6 +1393,7 @@ class RopeScalingType(Enum):
|
|||
NONE = 'none'
|
||||
LINEAR = 'linear'
|
||||
YARN = 'yarn'
|
||||
LONGROPE = 'longrope'
|
||||
|
||||
|
||||
class PoolingType(IntEnum):
|
||||
|
|
|
@ -185,7 +185,8 @@ extern "C" {
|
|||
LLAMA_ROPE_SCALING_TYPE_NONE = 0,
|
||||
LLAMA_ROPE_SCALING_TYPE_LINEAR = 1,
|
||||
LLAMA_ROPE_SCALING_TYPE_YARN = 2,
|
||||
LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_YARN,
|
||||
LLAMA_ROPE_SCALING_TYPE_LONGROPE = 3,
|
||||
LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_LONGROPE,
|
||||
};
|
||||
|
||||
enum llama_pooling_type {
|
||||
|
|
|
@ -1,12 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
name="$1"
|
||||
args="${@:2}"
|
||||
|
||||
echo "Building $name with args: $args"
|
||||
|
||||
rm -fr build-cpu-$1
|
||||
cmake -S . -B build-cpu-$1 -DGGML_BACKEND_DL=ON -DGGML_NATIVE=OFF $args
|
||||
cmake --build build-cpu-$1 --config Release -t ggml-cpu -j $(nproc)
|
||||
cp build-cpu-$1/bin/libggml-cpu.so ./libggml-cpu-$1.so
|
||||
rm -fr build-cpu-$1
|
167
src/llama.cpp
167
src/llama.cpp
|
@ -1036,6 +1036,8 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
|||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
|
||||
{ LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" },
|
||||
{ LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
|
@ -1686,6 +1688,7 @@ static const std::map<llama_rope_scaling_type, const char *> LLAMA_ROPE_SCALING_
|
|||
{ LLAMA_ROPE_SCALING_TYPE_NONE, "none" },
|
||||
{ LLAMA_ROPE_SCALING_TYPE_LINEAR, "linear" },
|
||||
{ LLAMA_ROPE_SCALING_TYPE_YARN, "yarn" },
|
||||
{ LLAMA_ROPE_SCALING_TYPE_LONGROPE, "longrope" },
|
||||
};
|
||||
|
||||
static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) {
|
||||
|
@ -5580,8 +5583,12 @@ static void llm_load_hparams(
|
|||
case LLM_ARCH_MINICPM:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale);
|
||||
ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale);
|
||||
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 52: model.type = e_model::MODEL_1B; break;
|
||||
case 40: model.type = e_model::MODEL_2B; break;
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
|
@ -7065,7 +7072,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
|
|||
LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
|
||||
}
|
||||
|
||||
if (model.arch == LLM_ARCH_GRANITE || model.arch == LLM_ARCH_GRANITE_MOE) {
|
||||
if (model.arch == LLM_ARCH_MINICPM || model.arch == LLM_ARCH_GRANITE || model.arch == LLM_ARCH_GRANITE_MOE) {
|
||||
LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale);
|
||||
LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale);
|
||||
LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
|
||||
|
@ -7690,7 +7697,13 @@ static bool llm_load_tensors(
|
|||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) {
|
||||
layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
||||
layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
||||
}
|
||||
else {
|
||||
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
||||
}
|
||||
|
||||
if (n_expert == 0) {
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||
|
@ -13497,153 +13510,6 @@ struct llm_build_context {
|
|||
return gf;
|
||||
}
|
||||
|
||||
// ref: https://arxiv.org/abs/2203.03466
|
||||
// https://github.com/ggerganov/llama.cpp/issues/5276#issuecomment-1925774738
|
||||
// based on the original build_llama() function
|
||||
struct ggml_cgraph * build_minicpm() {
|
||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
||||
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
//TODO: if the model varies, these parameters need to be read from the model
|
||||
const int64_t n_embd_base = 256;
|
||||
const float scale_embd = 12.0f;
|
||||
const float scale_depth = 1.4f;
|
||||
|
||||
struct ggml_tensor * cur;
|
||||
struct ggml_tensor * inpL;
|
||||
|
||||
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
||||
|
||||
// scale the input embeddings
|
||||
inpL = ggml_scale(ctx0, inpL, scale_embd);
|
||||
cb(inpL, "inp_scaled", -1);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
struct ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * inpSA = inpL;
|
||||
|
||||
// norm
|
||||
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||
model.layers[il].attn_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
// compute Q and K and RoPE them
|
||||
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
if (model.layers[il].bq) {
|
||||
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
||||
cb(Qcur, "Qcur", il);
|
||||
}
|
||||
|
||||
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
if (model.layers[il].bk) {
|
||||
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
||||
cb(Kcur, "Kcur", il);
|
||||
}
|
||||
|
||||
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
if (model.layers[il].bv) {
|
||||
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
||||
cb(Vcur, "Vcur", il);
|
||||
}
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
// skip computing output for unused tokens
|
||||
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
// scale_res - scale the hidden states for residual connection
|
||||
const float scale_res = scale_depth/sqrtf(float(n_layer));
|
||||
cur = ggml_scale(ctx0, cur, scale_res);
|
||||
cb(cur, "hidden_scaled", -1);
|
||||
|
||||
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// feed-forward network
|
||||
{
|
||||
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
cur = llm_build_ffn(ctx0, lctx, cur,
|
||||
model.layers[il].ffn_up, NULL, NULL,
|
||||
model.layers[il].ffn_gate, NULL, NULL,
|
||||
model.layers[il].ffn_down, NULL, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
}
|
||||
|
||||
// scale the hidden states for residual connection
|
||||
cur = ggml_scale(ctx0, cur, scale_res);
|
||||
cb(cur, "hidden_scaled_ffn", -1);
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
cur = llm_build_norm(ctx0, cur, hparams,
|
||||
model.output_norm, NULL,
|
||||
LLM_NORM_RMS, cb, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
|
||||
// lm_head scaling
|
||||
const float scale_lmhead = float(n_embd_base)/float(n_embd);
|
||||
cur = ggml_scale(ctx0, cur, scale_lmhead);
|
||||
cb(cur, "lmhead_scaling", -1);
|
||||
|
||||
// lm_head
|
||||
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
||||
cb(cur, "result_output", -1);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
struct ggml_cgraph * build_minicpm3() {
|
||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
||||
|
||||
|
@ -16742,6 +16608,7 @@ static struct ggml_cgraph * llama_build_graph(
|
|||
|
||||
switch (model.arch) {
|
||||
case LLM_ARCH_LLAMA:
|
||||
case LLM_ARCH_MINICPM:
|
||||
case LLM_ARCH_GRANITE:
|
||||
case LLM_ARCH_GRANITE_MOE:
|
||||
{
|
||||
|
@ -16825,10 +16692,6 @@ static struct ggml_cgraph * llama_build_graph(
|
|||
{
|
||||
result = llm.build_internlm2();
|
||||
} break;
|
||||
case LLM_ARCH_MINICPM:
|
||||
{
|
||||
result = llm.build_minicpm();
|
||||
} break;
|
||||
case LLM_ARCH_MINICPM3:
|
||||
{
|
||||
result = llm.build_minicpm3();
|
||||
|
|
|
@ -3862,6 +3862,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
|
|||
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));
|
||||
|
||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F16, {512, 3072, 1, 1}));
|
||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3}));
|
||||
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
|
||||
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, 1.0f, 0.0f));
|
||||
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, 1.0f, 0.0f));
|
||||
|
|
|
@ -10,11 +10,16 @@ declare -a params=(
|
|||
|
||||
MODELS_REPO=lora-tests
|
||||
MODELS_REPO_URL=https://huggingface.co/ggml-org/$MODELS_REPO
|
||||
COMMIT=c26d5fb85b4070a9e9c4e65d132c783b98086890
|
||||
|
||||
# Clone the Hugging Face repository if the directory does not exist
|
||||
if [ ! -d "$MODELS_REPO" ]; then
|
||||
echo "Cloning the Hugging Face repository..."
|
||||
git clone $MODELS_REPO_URL --depth 1
|
||||
cd $MODELS_REPO
|
||||
git fetch --depth=1 origin $COMMIT
|
||||
git reset --hard $COMMIT
|
||||
cd -
|
||||
else
|
||||
echo "Repository already exists. Skipping clone."
|
||||
fi
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue