Merge branch 'master' into metal-memory-reduction
This commit is contained in:
commit
c846c451d1
11 changed files with 644 additions and 27 deletions
32
.devops/server-cuda.Dockerfile
Normal file
32
.devops/server-cuda.Dockerfile
Normal file
|
@ -0,0 +1,32 @@
|
|||
ARG UBUNTU_VERSION=22.04
|
||||
# This needs to generally match the container host's environment.
|
||||
ARG CUDA_VERSION=11.7.1
|
||||
# Target the CUDA build image
|
||||
ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
|
||||
# Target the CUDA runtime image
|
||||
ARG BASE_CUDA_RUN_CONTAINER=nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}
|
||||
|
||||
FROM ${BASE_CUDA_DEV_CONTAINER} as build
|
||||
|
||||
# Unless otherwise specified, we make a fat build.
|
||||
ARG CUDA_DOCKER_ARCH=all
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y build-essential git
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
# Set nvcc architecture
|
||||
ENV CUDA_DOCKER_ARCH=${CUDA_DOCKER_ARCH}
|
||||
# Enable cuBLAS
|
||||
ENV LLAMA_CUBLAS=1
|
||||
|
||||
RUN make
|
||||
|
||||
FROM ${BASE_CUDA_RUN_CONTAINER} as runtime
|
||||
|
||||
COPY --from=build /app/server /server
|
||||
|
||||
ENTRYPOINT [ "/server" ]
|
25
.devops/server-intel.Dockerfile
Normal file
25
.devops/server-intel.Dockerfile
Normal file
|
@ -0,0 +1,25 @@
|
|||
ARG ONEAPI_VERSION=2024.0.1-devel-ubuntu22.04
|
||||
ARG UBUNTU_VERSION=22.04
|
||||
|
||||
FROM intel/hpckit:$ONEAPI_VERSION as build
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y git
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
# for some reasons, "-DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=Intel10_64lp -DLLAMA_NATIVE=ON" give worse performance
|
||||
RUN mkdir build && \
|
||||
cd build && \
|
||||
cmake .. -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx && \
|
||||
cmake --build . --config Release --target main server
|
||||
|
||||
FROM ubuntu:$UBUNTU_VERSION as runtime
|
||||
|
||||
COPY --from=build /app/build/bin/server /server
|
||||
|
||||
ENV LC_ALL=C.utf8
|
||||
|
||||
ENTRYPOINT [ "/server" ]
|
45
.devops/server-rocm.Dockerfile
Normal file
45
.devops/server-rocm.Dockerfile
Normal file
|
@ -0,0 +1,45 @@
|
|||
ARG UBUNTU_VERSION=22.04
|
||||
|
||||
# This needs to generally match the container host's environment.
|
||||
ARG ROCM_VERSION=5.6
|
||||
|
||||
# Target the CUDA build image
|
||||
ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete
|
||||
|
||||
FROM ${BASE_ROCM_DEV_CONTAINER} as build
|
||||
|
||||
# Unless otherwise specified, we make a fat build.
|
||||
# List from https://github.com/ggerganov/llama.cpp/pull/1087#issuecomment-1682807878
|
||||
# This is mostly tied to rocBLAS supported archs.
|
||||
ARG ROCM_DOCKER_ARCH=\
|
||||
gfx803 \
|
||||
gfx900 \
|
||||
gfx906 \
|
||||
gfx908 \
|
||||
gfx90a \
|
||||
gfx1010 \
|
||||
gfx1030 \
|
||||
gfx1100 \
|
||||
gfx1101 \
|
||||
gfx1102
|
||||
|
||||
COPY requirements.txt requirements.txt
|
||||
COPY requirements requirements
|
||||
|
||||
RUN pip install --upgrade pip setuptools wheel \
|
||||
&& pip install -r requirements.txt
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
# Set nvcc architecture
|
||||
ENV GPU_TARGETS=${ROCM_DOCKER_ARCH}
|
||||
# Enable ROCm
|
||||
ENV LLAMA_HIPBLAS=1
|
||||
ENV CC=/opt/rocm/llvm/bin/clang
|
||||
ENV CXX=/opt/rocm/llvm/bin/clang++
|
||||
|
||||
RUN make
|
||||
|
||||
ENTRYPOINT [ "/app/server" ]
|
20
.devops/server.Dockerfile
Normal file
20
.devops/server.Dockerfile
Normal file
|
@ -0,0 +1,20 @@
|
|||
ARG UBUNTU_VERSION=22.04
|
||||
|
||||
FROM ubuntu:$UBUNTU_VERSION as build
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y build-essential git
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN make
|
||||
|
||||
FROM ubuntu:$UBUNTU_VERSION as runtime
|
||||
|
||||
COPY --from=build /app/server /server
|
||||
|
||||
ENV LC_ALL=C.utf8
|
||||
|
||||
ENTRYPOINT [ "/server" ]
|
4
.github/workflows/docker.yml
vendored
4
.github/workflows/docker.yml
vendored
|
@ -28,14 +28,18 @@ jobs:
|
|||
config:
|
||||
- { tag: "light", dockerfile: ".devops/main.Dockerfile", platforms: "linux/amd64,linux/arm64" }
|
||||
- { tag: "full", dockerfile: ".devops/full.Dockerfile", platforms: "linux/amd64,linux/arm64" }
|
||||
- { tag: "server", dockerfile: ".devops/server.Dockerfile", platforms: "linux/amd64,linux/arm64" }
|
||||
# NOTE(canardletter): The CUDA builds on arm64 are very slow, so I
|
||||
# have disabled them for now until the reason why
|
||||
# is understood.
|
||||
- { tag: "light-cuda", dockerfile: ".devops/main-cuda.Dockerfile", platforms: "linux/amd64" }
|
||||
- { tag: "full-cuda", dockerfile: ".devops/full-cuda.Dockerfile", platforms: "linux/amd64" }
|
||||
- { tag: "server-cuda", dockerfile: ".devops/server-cuda.Dockerfile", platforms: "linux/amd64" }
|
||||
- { tag: "light-rocm", dockerfile: ".devops/main-rocm.Dockerfile", platforms: "linux/amd64,linux/arm64" }
|
||||
- { tag: "full-rocm", dockerfile: ".devops/full-rocm.Dockerfile", platforms: "linux/amd64,linux/arm64" }
|
||||
- { tag: "server-rocm", dockerfile: ".devops/server-rocm.Dockerfile", platforms: "linux/amd64,linux/arm64" }
|
||||
- { tag: "light-intel", dockerfile: ".devops/main-intel.Dockerfile", platforms: "linux/amd64" }
|
||||
- { tag: "server-intel", dockerfile: ".devops/server-intel.Dockerfile", platforms: "linux/amd64" }
|
||||
steps:
|
||||
- name: Check out the repo
|
||||
uses: actions/checkout@v3
|
||||
|
|
17
README.md
17
README.md
|
@ -122,7 +122,8 @@ as the main playground for developing new features for the [ggml](https://github
|
|||
- Node.js: [withcatai/node-llama-cpp](https://github.com/withcatai/node-llama-cpp)
|
||||
- JS/TS (llama.cpp server client): [lgrammel/modelfusion](https://modelfusion.dev/integration/model-provider/llamacpp)
|
||||
- Ruby: [yoshoku/llama_cpp.rb](https://github.com/yoshoku/llama_cpp.rb)
|
||||
- Rust: [mdrokz/rust-llama.cpp](https://github.com/mdrokz/rust-llama.cpp)
|
||||
- Rust (nicer API): [mdrokz/rust-llama.cpp](https://github.com/mdrokz/rust-llama.cpp)
|
||||
- Rust (more direct bindings): [utilityai/llama-cpp-rs](https://github.com/utilityai/llama-cpp-rs)
|
||||
- C#/.NET: [SciSharp/LLamaSharp](https://github.com/SciSharp/LLamaSharp)
|
||||
- Scala 3: [donderom/llm4s](https://github.com/donderom/llm4s)
|
||||
- Clojure: [phronmophobic/llama.clj](https://github.com/phronmophobic/llama.clj)
|
||||
|
@ -931,17 +932,20 @@ Place your desired model into the `~/llama.cpp/models/` directory and execute th
|
|||
* Create a folder to store big models & intermediate files (ex. /llama/models)
|
||||
|
||||
#### Images
|
||||
We have two Docker images available for this project:
|
||||
We have three Docker images available for this project:
|
||||
|
||||
1. `ghcr.io/ggerganov/llama.cpp:full`: This image includes both the main executable file and the tools to convert LLaMA models into ggml and convert into 4-bit quantization. (platforms: `linux/amd64`, `linux/arm64`)
|
||||
2. `ghcr.io/ggerganov/llama.cpp:light`: This image only includes the main executable file. (platforms: `linux/amd64`, `linux/arm64`)
|
||||
3. `ghcr.io/ggerganov/llama.cpp:server`: This image only includes the server executabhle file. (platforms: `linux/amd64`, `linux/arm64`)
|
||||
|
||||
Additionally, there the following images, similar to the above:
|
||||
|
||||
- `ghcr.io/ggerganov/llama.cpp:full-cuda`: Same as `full` but compiled with CUDA support. (platforms: `linux/amd64`)
|
||||
- `ghcr.io/ggerganov/llama.cpp:light-cuda`: Same as `light` but compiled with CUDA support. (platforms: `linux/amd64`)
|
||||
- `ghcr.io/ggerganov/llama.cpp:server-cuda`: Same as `server` but compiled with CUDA support. (platforms: `linux/amd64`)
|
||||
- `ghcr.io/ggerganov/llama.cpp:full-rocm`: Same as `full` but compiled with ROCm support. (platforms: `linux/amd64`, `linux/arm64`)
|
||||
- `ghcr.io/ggerganov/llama.cpp:light-rocm`: Same as `light` but compiled with ROCm support. (platforms: `linux/amd64`, `linux/arm64`)
|
||||
- `ghcr.io/ggerganov/llama.cpp:server-rocm`: Same as `server` but compiled with ROCm support. (platforms: `linux/amd64`, `linux/arm64`)
|
||||
|
||||
The GPU enabled images are not currently tested by CI beyond being built. They are not built with any variation from the ones in the Dockerfiles defined in [.devops/](.devops/) and the GitHub Action defined in [.github/workflows/docker.yml](.github/workflows/docker.yml). If you need different settings (for example, a different CUDA or ROCm library, you'll need to build the images locally for now).
|
||||
|
||||
|
@ -967,6 +971,12 @@ or with a light image:
|
|||
docker run -v /path/to/models:/models ghcr.io/ggerganov/llama.cpp:light -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512
|
||||
```
|
||||
|
||||
or with a server image:
|
||||
|
||||
```bash
|
||||
docker run -v /path/to/models:/models -p 8000:8000 ghcr.io/ggerganov/llama.cpp:server -m /models/7B/ggml-model-q4_0.gguf --port 8000 --host 0.0.0.0 -n 512
|
||||
```
|
||||
|
||||
### Docker With CUDA
|
||||
|
||||
Assuming one has the [nvidia-container-toolkit](https://github.com/NVIDIA/nvidia-container-toolkit) properly installed on Linux, or is using a GPU enabled cloud, `cuBLAS` should be accessible inside the container.
|
||||
|
@ -976,6 +986,7 @@ Assuming one has the [nvidia-container-toolkit](https://github.com/NVIDIA/nvidia
|
|||
```bash
|
||||
docker build -t local/llama.cpp:full-cuda -f .devops/full-cuda.Dockerfile .
|
||||
docker build -t local/llama.cpp:light-cuda -f .devops/main-cuda.Dockerfile .
|
||||
docker build -t local/llama.cpp:server-cuda -f .devops/server-cuda.Dockerfile .
|
||||
```
|
||||
|
||||
You may want to pass in some different `ARGS`, depending on the CUDA environment supported by your container host, as well as the GPU architecture.
|
||||
|
@ -989,6 +1000,7 @@ The resulting images, are essentially the same as the non-CUDA images:
|
|||
|
||||
1. `local/llama.cpp:full-cuda`: This image includes both the main executable file and the tools to convert LLaMA models into ggml and convert into 4-bit quantization.
|
||||
2. `local/llama.cpp:light-cuda`: This image only includes the main executable file.
|
||||
3. `local/llama.cpp:server-cuda`: This image only includes the server executable file.
|
||||
|
||||
#### Usage
|
||||
|
||||
|
@ -997,6 +1009,7 @@ After building locally, Usage is similar to the non-CUDA examples, but you'll ne
|
|||
```bash
|
||||
docker run --gpus all -v /path/to/models:/models local/llama.cpp:full-cuda --run -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512 --n-gpu-layers 1
|
||||
docker run --gpus all -v /path/to/models:/models local/llama.cpp:light-cuda -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512 --n-gpu-layers 1
|
||||
docker run --gpus all -v /path/to/models:/models local/llama.cpp:server-cuda -m /models/7B/ggml-model-q4_0.gguf --port 8000 --host 0.0.0.0 -n 512 --n-gpu-layers 1
|
||||
```
|
||||
|
||||
### Contributing
|
||||
|
|
|
@ -201,6 +201,8 @@ class Model:
|
|||
return PlamoModel
|
||||
if model_architecture == "CodeShellForCausalLM":
|
||||
return CodeShellModel
|
||||
if model_architecture == "OrionForCausalLM":
|
||||
return OrionModel
|
||||
return Model
|
||||
|
||||
def _is_model_safetensors(self) -> bool:
|
||||
|
@ -250,6 +252,8 @@ class Model:
|
|||
return gguf.MODEL_ARCH.PLAMO
|
||||
if arch == "CodeShellForCausalLM":
|
||||
return gguf.MODEL_ARCH.CODESHELL
|
||||
if arch == "OrionForCausalLM":
|
||||
return gguf.MODEL_ARCH.ORION
|
||||
|
||||
raise NotImplementedError(f'Architecture "{arch}" not supported!')
|
||||
|
||||
|
@ -572,6 +576,83 @@ class MPTModel(Model):
|
|||
self.gguf_writer.add_tensor("output.weight", data)
|
||||
|
||||
|
||||
class OrionModel(Model):
|
||||
def set_vocab(self):
|
||||
self._set_vocab_sentencepiece()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
head_count = self.hparams["num_attention_heads"]
|
||||
head_count_kv = self.hparams.get("num_key_value_heads", head_count)
|
||||
hf_repo = self.hparams.get("_name_or_path", "")
|
||||
|
||||
ctx_length = 0
|
||||
if "max_sequence_length" in self.hparams:
|
||||
ctx_length = self.hparams["max_sequence_length"]
|
||||
elif "max_position_embeddings" in self.hparams:
|
||||
ctx_length = self.hparams["max_position_embeddings"]
|
||||
elif "model_max_length" in self.hparams:
|
||||
ctx_length = self.hparams["model_max_length"]
|
||||
else:
|
||||
print("gguf: can not find ctx length parameter.")
|
||||
sys.exit()
|
||||
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
self.gguf_writer.add_name(self.dir_model.name)
|
||||
self.gguf_writer.add_source_hf_repo(hf_repo)
|
||||
self.gguf_writer.add_tensor_data_layout("Meta AI original pth")
|
||||
self.gguf_writer.add_context_length(ctx_length)
|
||||
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
|
||||
self.gguf_writer.add_head_count(head_count)
|
||||
self.gguf_writer.add_head_count_kv(head_count_kv)
|
||||
self.gguf_writer.add_layer_norm_eps(self.hparams["rms_norm_eps"])
|
||||
|
||||
def write_tensors(self):
|
||||
# Collect tensors from generator object
|
||||
model_kv = dict(self.get_tensors())
|
||||
block_count = self.hparams["num_hidden_layers"]
|
||||
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
|
||||
|
||||
for name, data_torch in model_kv.items():
|
||||
# we don't need these
|
||||
if name.endswith(".rotary_emb.inv_freq"):
|
||||
continue
|
||||
|
||||
old_dtype = data_torch.dtype
|
||||
|
||||
# convert any unsupported data types to float32
|
||||
if data_torch.dtype not in (torch.float16, torch.float32):
|
||||
data_torch = data_torch.to(torch.float32)
|
||||
|
||||
data = data_torch.squeeze().numpy()
|
||||
|
||||
# map tensor names
|
||||
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
|
||||
if new_name is None:
|
||||
print(f"Can not map tensor {name!r}")
|
||||
sys.exit()
|
||||
|
||||
n_dims = len(data.shape)
|
||||
data_dtype = data.dtype
|
||||
|
||||
# if f32 desired, convert any float16 to float32
|
||||
if self.ftype == 0 and data_dtype == np.float16:
|
||||
data = data.astype(np.float32)
|
||||
|
||||
# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
|
||||
if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
|
||||
data = data.astype(np.float32)
|
||||
|
||||
# if f16 desired, convert any float32 2-dim weight tensors to float16
|
||||
if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
|
||||
data = data.astype(np.float16)
|
||||
|
||||
print(f"{name} -> {new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
|
||||
self.gguf_writer.add_tensor(new_name, data)
|
||||
|
||||
|
||||
class BaichuanModel(Model):
|
||||
def set_vocab(self):
|
||||
self._set_vocab_sentencepiece()
|
||||
|
|
|
@ -66,6 +66,14 @@ server.exe -m models\7B\ggml-model.gguf -c 2048
|
|||
The above command will start a server that by default listens on `127.0.0.1:8080`.
|
||||
You can consume the endpoints with Postman or NodeJS with axios library. You can visit the web front end at the same url.
|
||||
|
||||
### Docker:
|
||||
```bash
|
||||
docker run -p 8080:8080 -v /path/to/models:/models ggerganov/llama.cpp:server -m models/7B/ggml-model.gguf -c 512 --host 0.0.0.0 --port 8080
|
||||
|
||||
# or, with CUDA:
|
||||
docker run -p 8080:8080 -v /path/to/models:/models --gpus all ggerganov/llama.cpp:server-cuda -m models/7B/ggml-model.gguf -c 512 --host 0.0.0.0 --port 8080 --n-gpu-layers 99
|
||||
```
|
||||
|
||||
## Testing with CURL
|
||||
|
||||
Using [curl](https://curl.se/). On Windows `curl.exe` should be available in the base OS.
|
||||
|
|
|
@ -101,6 +101,7 @@ class MODEL_ARCH(IntEnum):
|
|||
PHI2 = auto()
|
||||
PLAMO = auto()
|
||||
CODESHELL = auto()
|
||||
ORION = auto()
|
||||
|
||||
|
||||
class MODEL_TENSOR(IntEnum):
|
||||
|
@ -151,6 +152,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
|||
MODEL_ARCH.PHI2: "phi2",
|
||||
MODEL_ARCH.PLAMO: "plamo",
|
||||
MODEL_ARCH.CODESHELL: "codeshell",
|
||||
MODEL_ARCH.ORION: "orion",
|
||||
}
|
||||
|
||||
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
|
@ -427,7 +429,23 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
]
|
||||
],
|
||||
MODEL_ARCH.ORION: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
# TODO
|
||||
}
|
||||
|
||||
|
@ -452,6 +470,10 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
],
|
||||
MODEL_ARCH.ORION: [
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
],
|
||||
}
|
||||
|
||||
#
|
||||
|
|
246
llama.cpp
246
llama.cpp
|
@ -52,6 +52,7 @@
|
|||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cassert>
|
||||
#include <cfloat>
|
||||
#include <cinttypes>
|
||||
#include <climits>
|
||||
#include <cmath>
|
||||
|
@ -196,6 +197,7 @@ enum llm_arch {
|
|||
LLM_ARCH_PHI2,
|
||||
LLM_ARCH_PLAMO,
|
||||
LLM_ARCH_CODESHELL,
|
||||
LLM_ARCH_ORION,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
|
@ -217,6 +219,7 @@ static std::map<llm_arch, std::string> LLM_ARCH_NAMES = {
|
|||
{ LLM_ARCH_PHI2, "phi2" },
|
||||
{ LLM_ARCH_PLAMO, "plamo" },
|
||||
{ LLM_ARCH_CODESHELL, "codeshell" },
|
||||
{ LLM_ARCH_ORION, "orion" },
|
||||
};
|
||||
|
||||
enum llm_kv {
|
||||
|
@ -641,6 +644,25 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
|
|||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_ORION,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
LLM_ARCH_UNKNOWN,
|
||||
|
@ -1332,6 +1354,7 @@ enum e_model {
|
|||
MODEL_7B,
|
||||
MODEL_8B,
|
||||
MODEL_13B,
|
||||
MODEL_14B,
|
||||
MODEL_15B,
|
||||
MODEL_30B,
|
||||
MODEL_34B,
|
||||
|
@ -2683,6 +2706,7 @@ static const char * llama_model_type_name(e_model type) {
|
|||
case MODEL_7B: return "7B";
|
||||
case MODEL_8B: return "8B";
|
||||
case MODEL_13B: return "13B";
|
||||
case MODEL_14B: return "14B";
|
||||
case MODEL_15B: return "15B";
|
||||
case MODEL_30B: return "30B";
|
||||
case MODEL_34B: return "34B";
|
||||
|
@ -2950,7 +2974,15 @@ static void llm_load_hparams(
|
|||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_ORION:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 40: model.type = e_model::MODEL_14B; break;
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
default: (void)0;
|
||||
}
|
||||
|
||||
|
@ -3933,6 +3965,38 @@ static bool llm_load_tensors(
|
|||
layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff});
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_ORION:
|
||||
{
|
||||
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
||||
{
|
||||
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
|
||||
model.output_norm_b = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd});
|
||||
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
|
||||
}
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
ggml_context * ctx_layer = ctx_for_layer(i);
|
||||
ggml_context * ctx_split = ctx_for_layer_split(i);
|
||||
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
|
||||
layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd});
|
||||
|
||||
layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
|
||||
layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
|
||||
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
|
||||
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
|
||||
|
||||
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
|
||||
layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd});
|
||||
|
||||
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
|
||||
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
|
||||
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
|
||||
}
|
||||
} break;
|
||||
|
||||
|
||||
default:
|
||||
throw std::runtime_error("unknown architecture");
|
||||
}
|
||||
|
@ -4563,6 +4627,126 @@ struct llm_build_context {
|
|||
ctx0 = nullptr;
|
||||
}
|
||||
}
|
||||
struct ggml_cgraph * build_orion() {
|
||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
|
||||
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
|
||||
struct ggml_tensor * cur;
|
||||
struct ggml_tensor * inpL;
|
||||
|
||||
inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
|
||||
cb(inpL, "inp_embd", -1);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
|
||||
cb(inp_pos, "inp_pos", -1);
|
||||
|
||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
|
||||
cb(KQ_mask, "KQ_mask", -1);
|
||||
|
||||
// shift the entire K-cache if needed
|
||||
if (do_rope_shift) {
|
||||
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
|
||||
}
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * inpSA = inpL;
|
||||
|
||||
// norm
|
||||
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||
model.layers[il].attn_norm, model.layers[il].attn_norm_b,
|
||||
LLM_NORM, cb, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
// compute Q and K and RoPE them
|
||||
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
// if (model.layers[il].bq) {
|
||||
// Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
||||
// cb(Qcur, "Qcur", il);
|
||||
// }
|
||||
|
||||
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
// if (model.layers[il].bk) {
|
||||
// Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
||||
// cb(Kcur, "Kcur", il);
|
||||
// }
|
||||
|
||||
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
// if (model.layers[il].bv) {
|
||||
// Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
||||
// cb(Vcur, "Vcur", il);
|
||||
// }
|
||||
|
||||
Qcur = ggml_rope_custom(
|
||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
|
||||
hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
Kcur = ggml_rope_custom(
|
||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
|
||||
hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
|
||||
model.layers[il].wo, NULL,
|
||||
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// feed-forward network
|
||||
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
||||
model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
|
||||
LLM_NORM, cb, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
cur = llm_build_ffn(ctx0, cur,
|
||||
model.layers[il].ffn_up, NULL,
|
||||
model.layers[il].ffn_gate, NULL,
|
||||
model.layers[il].ffn_down, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
cur = llm_build_norm(ctx0, cur, hparams,
|
||||
model.output_norm, model.output_norm_b,
|
||||
LLM_NORM, cb, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
|
||||
// lm_head
|
||||
cur = ggml_mul_mat(ctx0, model.output, cur);
|
||||
cb(cur, "result_output", -1);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
|
||||
|
||||
struct ggml_cgraph * build_llama() {
|
||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
|
||||
|
@ -6520,6 +6704,10 @@ static struct ggml_cgraph * llama_build_graph(
|
|||
{
|
||||
result = llm.build_codeshell();
|
||||
} break;
|
||||
case LLM_ARCH_ORION:
|
||||
{
|
||||
result = llm.build_orion();
|
||||
} break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
|
@ -7946,6 +8134,11 @@ void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * c
|
|||
}
|
||||
|
||||
void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
|
||||
// TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
|
||||
// if (k >= (int32_t)candidates->size) {
|
||||
// return;
|
||||
// }
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
k = std::max(k, (int) min_keep);
|
||||
|
@ -8054,21 +8247,56 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can
|
|||
return;
|
||||
}
|
||||
|
||||
llama_sample_softmax(ctx, candidates);
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
float scale = candidates->data[0].p; // scale by max prob
|
||||
size_t i = 1; // first token always matches
|
||||
bool min_p_applied = false;
|
||||
|
||||
for (; i < candidates->size; ++i) {
|
||||
if (candidates->data[i].p < p * scale && i >= min_keep) {
|
||||
break; // prob too small
|
||||
// if the candidates aren't sorted, try the unsorted implementation first
|
||||
if (!candidates->sorted) {
|
||||
std::vector<llama_token_data> filtered_tokens;
|
||||
|
||||
float max_logit = -FLT_MAX;
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
max_logit = std::max(max_logit, candidates->data[i].logit);
|
||||
}
|
||||
const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max
|
||||
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
if (candidates->data[i].logit >= min_logit) {
|
||||
filtered_tokens.push_back(candidates->data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// if we have enough values the operation was a success
|
||||
if (filtered_tokens.size() >= min_keep) {
|
||||
memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
|
||||
candidates->size = filtered_tokens.size();
|
||||
min_p_applied = true;
|
||||
}
|
||||
}
|
||||
|
||||
// Resize the output vector to keep only the matching tokens
|
||||
candidates->size = i;
|
||||
// if the candidates are sorted or the unsorted implementation failed, use this implementation
|
||||
if (!min_p_applied) {
|
||||
// Sort the logits in descending order
|
||||
if (!candidates->sorted) {
|
||||
std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
|
||||
return a.logit > b.logit;
|
||||
});
|
||||
candidates->sorted = true;
|
||||
}
|
||||
|
||||
const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max
|
||||
size_t i = 1; // first token always matches
|
||||
|
||||
for (; i < candidates->size; ++i) {
|
||||
if (candidates->data[i].logit < min_logit && i >= min_keep) {
|
||||
break; // prob too small
|
||||
}
|
||||
}
|
||||
|
||||
// Resize the output vector to keep only the matching tokens
|
||||
candidates->size = i;
|
||||
}
|
||||
|
||||
if (ctx) {
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
|
|
|
@ -5,11 +5,10 @@
|
|||
#undef NDEBUG
|
||||
#endif
|
||||
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
#include <cassert>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
static void dump(const llama_token_data_array * candidates) {
|
||||
for (size_t i = 0; i < candidates->size; i++) {
|
||||
|
@ -20,11 +19,11 @@ static void dump(const llama_token_data_array * candidates) {
|
|||
#define DUMP(__candidates) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__candidates)); printf("-\n"); } while(0)
|
||||
|
||||
static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) {
|
||||
size_t n_vocab = probs.size();
|
||||
const size_t n_vocab = probs.size();
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
|
||||
float logit = log(probs[token_id]);
|
||||
const float logit = logf(probs[token_id]);
|
||||
candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
||||
}
|
||||
|
||||
|
@ -41,11 +40,11 @@ static void test_top_k(const std::vector<float> & probs, const std::vector<float
|
|||
}
|
||||
|
||||
static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
|
||||
size_t n_vocab = probs.size();
|
||||
const size_t n_vocab = probs.size();
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
|
||||
float logit = log(probs[token_id]);
|
||||
const float logit = logf(probs[token_id]);
|
||||
candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
||||
}
|
||||
|
||||
|
@ -62,11 +61,11 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float
|
|||
}
|
||||
|
||||
static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) {
|
||||
size_t n_vocab = probs.size();
|
||||
const size_t n_vocab = probs.size();
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
|
||||
float logit = log(probs[token_id]);
|
||||
const float logit = logf(probs[token_id]);
|
||||
candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
||||
}
|
||||
|
||||
|
@ -81,12 +80,33 @@ static void test_tfs(const std::vector<float> & probs, const std::vector<float>
|
|||
}
|
||||
}
|
||||
|
||||
static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
|
||||
size_t n_vocab = probs.size();
|
||||
static void test_min_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
|
||||
const size_t n_vocab = probs.size();
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
|
||||
float logit = log(probs[token_id]);
|
||||
const float logit = logf(probs[token_id]);
|
||||
candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
||||
}
|
||||
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
DUMP(&candidates_p);
|
||||
llama_sample_min_p(nullptr, &candidates_p, p, 1);
|
||||
DUMP(&candidates_p);
|
||||
llama_sample_softmax(nullptr, &candidates_p);
|
||||
|
||||
GGML_ASSERT(candidates_p.size == expected_probs.size());
|
||||
for (size_t i = 0; i < candidates_p.size; i++) {
|
||||
GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
|
||||
}
|
||||
}
|
||||
|
||||
static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
|
||||
const size_t n_vocab = probs.size();
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
|
||||
const float logit = logf(probs[token_id]);
|
||||
candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
||||
}
|
||||
|
||||
|
@ -107,11 +127,11 @@ static void test_repetition_penalties(
|
|||
) {
|
||||
GGML_ASSERT(probs.size() == expected_probs.size());
|
||||
|
||||
size_t n_vocab = probs.size();
|
||||
const size_t n_vocab = probs.size();
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
|
||||
float logit = log(probs[token_id]);
|
||||
const float logit = logf(probs[token_id]);
|
||||
candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
||||
}
|
||||
|
||||
|
@ -128,6 +148,88 @@ static void test_repetition_penalties(
|
|||
}
|
||||
}
|
||||
|
||||
static void test_sampler_queue(
|
||||
const size_t n_vocab, const std::string samplers_sequence, const int top_k, const float top_p, const float min_p
|
||||
) {
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
|
||||
const float logit = logf(token_id);
|
||||
candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
||||
}
|
||||
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
|
||||
llama_token min_token_id = 0;
|
||||
const llama_token max_token_id = n_vocab-1;
|
||||
|
||||
for (auto s : samplers_sequence) {
|
||||
switch (s){
|
||||
case 'k': llama_sample_top_k (nullptr, &candidates_p, top_k, 1); break;
|
||||
case 'f': GGML_ASSERT(false && "tail_free test not implemented"); break;
|
||||
case 'y': GGML_ASSERT(false && "typical test not implemented"); break;
|
||||
case 'p': llama_sample_top_p (nullptr, &candidates_p, top_p, 1); break;
|
||||
case 'm': llama_sample_min_p (nullptr, &candidates_p, min_p, 1); break;
|
||||
case 't': GGML_ASSERT(false && "temperature test not implemented"); break;
|
||||
default : GGML_ASSERT(false && "Unknown sampler"); break;
|
||||
}
|
||||
|
||||
llama_sample_softmax(nullptr, &candidates_p); // make sure tokens are sorted for tests
|
||||
|
||||
const int size = candidates_p.size;
|
||||
|
||||
if (s == 'k') {
|
||||
const int expected_size = std::min(size, top_k);
|
||||
min_token_id = std::max(min_token_id, (llama_token)(n_vocab - top_k));
|
||||
|
||||
GGML_ASSERT(size == expected_size);
|
||||
GGML_ASSERT(candidates_p.data[0].id == max_token_id);
|
||||
GGML_ASSERT(candidates_p.data[expected_size-1].id == min_token_id);
|
||||
} else if (s == 'p') {
|
||||
const int softmax_divisor = n_vocab * (n_vocab-1) / 2 - min_token_id * (min_token_id-1) / 2;
|
||||
const int softmax_numerator_target = ceilf(top_p * softmax_divisor);
|
||||
|
||||
min_token_id = n_vocab;
|
||||
int expected_size = 0;
|
||||
int cumsum = 0;
|
||||
do { // do-while because always at least one token is sampled
|
||||
min_token_id--;
|
||||
expected_size++;
|
||||
|
||||
cumsum += min_token_id;
|
||||
} while (cumsum < softmax_numerator_target);
|
||||
|
||||
// token 0 has p == 0, need special consideration for cumsum because top_p immediately returns
|
||||
if (min_token_id == 1) {
|
||||
min_token_id--;
|
||||
expected_size += 1;
|
||||
}
|
||||
|
||||
GGML_ASSERT(size == expected_size);
|
||||
GGML_ASSERT(candidates_p.data[0].id == max_token_id);
|
||||
GGML_ASSERT(candidates_p.data[expected_size-1].id == min_token_id);
|
||||
} else if (s == 'm') {
|
||||
int expected_size = ceilf((1.0f-min_p) * n_vocab);
|
||||
expected_size = std::max(expected_size, 1);
|
||||
expected_size = std::min(expected_size, size);
|
||||
|
||||
min_token_id = floorf(min_p * n_vocab);
|
||||
min_token_id = std::max(min_token_id, 1);
|
||||
min_token_id = std::max(min_token_id, (llama_token)(n_vocab - size));
|
||||
min_token_id = std::min(min_token_id, (llama_token)(n_vocab - 1));
|
||||
|
||||
GGML_ASSERT(size == expected_size);
|
||||
GGML_ASSERT(candidates_p.data[0].id == max_token_id);
|
||||
GGML_ASSERT(candidates_p.data[expected_size-1].id == min_token_id);
|
||||
} else {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
}
|
||||
|
||||
printf("Sampler queue %3s OK with n_vocab=%05ld top_k=%05d top_p=%f min_p=%f\n",
|
||||
samplers_sequence.c_str(), n_vocab, top_k, top_p, min_p);
|
||||
}
|
||||
|
||||
int main(void) {
|
||||
ggml_time_init();
|
||||
|
||||
|
@ -139,6 +241,15 @@ int main(void) {
|
|||
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 0.8f);
|
||||
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1);
|
||||
|
||||
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.00f);
|
||||
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/1.0f, 0.3f/1.0f, 0.2f/1.0f, 0.1f/1.0f}, 0.24f);
|
||||
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.9f, 0.3f/0.9f, 0.2f/0.9f}, 0.26f);
|
||||
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.9f, 0.3f/0.9f, 0.2f/0.9f}, 0.49f);
|
||||
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.7f, 0.3f/0.7f}, 0.51f);
|
||||
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.7f, 0.3f/0.7f}, 0.74f);
|
||||
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 0.76f);
|
||||
test_min_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f/0.4f}, 1.00f);
|
||||
|
||||
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
|
||||
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f);
|
||||
test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.99f);
|
||||
|
@ -154,6 +265,34 @@ int main(void) {
|
|||
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
|
||||
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);
|
||||
|
||||
test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
|
||||
test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);
|
||||
test_sampler_queue(10000, "p", 10000, 1.0f, 1.0f);
|
||||
test_sampler_queue(10000, "p", 10000, 0.0f, 1.0f);
|
||||
test_sampler_queue(10000, "m", 10000, 1.0f, 1.0f);
|
||||
test_sampler_queue(10000, "m", 10000, 1.0f, 1e-12);
|
||||
|
||||
test_sampler_queue(10000, "k", 100, 1.0000f, 1.0f);
|
||||
test_sampler_queue(10000, "p", 10000, 0.0002f, 1.0f);
|
||||
test_sampler_queue(10000, "p", 10000, 0.8000f, 1.0f);
|
||||
test_sampler_queue(10000, "m", 10000, 1.0000f, 9997.9f/9999.0f);
|
||||
test_sampler_queue(10000, "m", 10000, 1.0000f, 0.1f);
|
||||
|
||||
test_sampler_queue(10000, "kp", 100, 0.8f, 0.1f);
|
||||
test_sampler_queue(10000, "km", 100, 0.8f, 0.1f);
|
||||
test_sampler_queue(10000, "pk", 100, 0.8f, 0.1f);
|
||||
test_sampler_queue(10000, "pm", 100, 0.8f, 0.1f);
|
||||
test_sampler_queue(10000, "mk", 100, 0.8f, 0.1f);
|
||||
test_sampler_queue(10000, "mp", 100, 0.8f, 9997.9f/9999.0f);
|
||||
test_sampler_queue(10000, "mp", 100, 0.8f, 0.1f);
|
||||
|
||||
test_sampler_queue(10000, "kpm", 100, 0.8f, 0.1f);
|
||||
test_sampler_queue(10000, "kmp", 100, 0.8f, 0.1f);
|
||||
test_sampler_queue(10000, "pkm", 100, 0.8f, 0.1f);
|
||||
test_sampler_queue(10000, "pmk", 100, 0.8f, 0.1f);
|
||||
test_sampler_queue(10000, "mkp", 100, 0.8f, 0.1f);
|
||||
test_sampler_queue(10000, "mpk", 100, 0.8f, 0.1f);
|
||||
|
||||
printf("OK\n");
|
||||
|
||||
return 0;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue