Merge branch 'ggerganov:master' into master
Commit 504ee68716
11 changed files with 249 additions and 83 deletions

.github/workflows/build.yml (vendored, 41 changes)

@@ -184,6 +184,47 @@ jobs:
          cmake -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx ..
          cmake --build . --config Release -j $(nproc)

  ubuntu-22-cmake-sycl-fp16:
    runs-on: ubuntu-22.04

    continue-on-error: true

    steps:
      - uses: actions/checkout@v2

      - name: add oneAPI to apt
        shell: bash
        run: |
          cd /tmp
          wget https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB
          sudo apt-key add GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB
          rm GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB
          sudo add-apt-repository "deb https://apt.repos.intel.com/oneapi all main"

      - name: install oneAPI dpcpp compiler
        shell: bash
        run: |
          sudo apt update
          sudo apt install intel-oneapi-compiler-dpcpp-cpp

      - name: install oneAPI MKL library
        shell: bash
        run: |
          sudo apt install intel-oneapi-mkl-devel

      - name: Clone
        id: checkout
        uses: actions/checkout@v3

      - name: Build
        id: cmake_build
        run: |
          source /opt/intel/oneapi/setvars.sh
          mkdir build
          cd build
          cmake -DLLAMA_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DLLAMA_SYCL_F16=ON ..
          cmake --build . --config Release -j $(nproc)

  # TODO: build with LLAMA_NO_METAL because test-backend-ops fail on "Apple Paravirtual device" and I don't know
  # how to debug it.
  # ref: https://github.com/ggerganov/llama.cpp/actions/runs/7131777249/job/19420981052#step:5:1124

@@ -850,7 +850,9 @@ endif()

set(ARCH_FLAGS "")

if ((${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm") OR (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64") OR ("${CMAKE_GENERATOR_PLATFORM_LWR}" MATCHES "arm64"))
if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR CMAKE_GENERATOR_PLATFORM_LWR STREQUAL "arm64" OR
    (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
     CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm.*|ARM64)$"))
    message(STATUS "ARM detected")
    if (MSVC)
        add_compile_definitions(__ARM_NEON)

@@ -876,7 +878,9 @@ if ((${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm") OR (${CMAKE_SYSTEM_PROCESSOR} MATC
            list(APPEND ARCH_FLAGS -mno-unaligned-access)
        endif()
    endif()
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$" OR "${CMAKE_GENERATOR_PLATFORM_LWR}" MATCHES "^(x86_64|i686|amd64|x64)$" )
elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LWR MATCHES "^(x86_64|i686|amd64|x64|win32)$" OR
        (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
         CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64)$"))
    message(STATUS "x86 detected")
    if (MSVC)
        # instruction set detection for MSVC only

@@ -150,6 +150,7 @@ Unless otherwise noted these projects are open-source with permissive licensing:

- [ollama/ollama](https://github.com/ollama/ollama)
- [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui) (AGPL)
- [psugihara/FreeChat](https://github.com/psugihara/FreeChat)
- [cztomsik/ava](https://github.com/cztomsik/ava) (MIT)
- [ptsochantaris/emeltal](https://github.com/ptsochantaris/emeltal)
- [pythops/tenere](https://github.com/pythops/tenere) (AGPL)
- [semperai/amica](https://github.com/semperai/amica)

@@ -679,7 +680,7 @@ python3 -m pip install -r requirements.txt

python3 convert.py models/mymodel/

# [Optional] for models using BPE tokenizers
python convert.py models/mymodel/ --vocabtype bpe
python convert.py models/mymodel/ --vocab-type bpe

# quantize the model to 4-bits (using Q4_K_M method)
./quantize ./models/mymodel/ggml-model-f16.gguf ./models/mymodel/ggml-model-Q4_K_M.gguf Q4_K_M

@@ -132,7 +132,7 @@ static void sampler_queue(

    const float         temp              = params.temp;
    const float         dynatemp_range    = params.dynatemp_range;
    const float         dynatemp_exponent = params.dynatemp_exponent;
    const int32_t       top_k             = params.top_k <= 0 ? n_vocab : params.top_k;
    const int32_t       top_k             = params.top_k;
    const float         top_p             = params.top_p;
    const float         min_p             = params.min_p;
    const float         tfs_z             = params.tfs_z;

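The `top_k <= 0 ? n_vocab : params.top_k` fallback is dropped from `sampler_queue`; the llama.cpp hunk later in this diff moves that handling into `llama_sample_top_k`, which now treats a non-positive k as "keep every candidate". A minimal host-side sketch of the resulting clamping behaviour (the `candidates_t` struct and `resolve_top_k` helper are illustrative names, not llama.cpp API):

```cpp
#include <algorithm>
#include <cstdio>

// Hypothetical stand-in for llama_token_data_array: only the size matters here.
struct candidates_t { size_t size; };

// Mirrors the clamping now done inside llama_sample_top_k after this change:
// k <= 0 means "keep every candidate", then k is bounded by min_keep and the
// candidate count. Illustration only, not the actual llama.cpp function.
static int resolve_top_k(int k, const candidates_t & candidates, size_t min_keep) {
    if (k <= 0) {
        k = (int) candidates.size;
    }
    k = std::max(k, (int) min_keep);
    k = std::min(k, (int) candidates.size);
    return k;
}

int main() {
    candidates_t cands = { 32000 };          // e.g. a 32k-token vocabulary
    printf("top_k = 0  -> keep %d\n", resolve_top_k(0,  cands, 1)); // keeps all 32000
    printf("top_k = 40 -> keep %d\n", resolve_top_k(40, cands, 1)); // keeps 40
    return 0;
}
```

With this split, callers such as `sampler_queue` can pass `params.top_k` through unchanged; the new `test_top_k` case with `k = 0` at the end of this diff exercises exactly this path.
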
@@ -1078,17 +1078,76 @@ class MiniCPMModel(Model):

        self.gguf_writer.add_name("MiniCPM")
        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
        self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
        self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
        self.gguf_writer.add_file_type(self.ftype)
        self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])

    def set_vocab(self):
        self._set_vocab_hf()

    def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
        if n_kv_head is not None and n_head != n_kv_head:
            n_head //= n_kv_head

        return (
            weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
            .swapaxes(1, 2)
            .reshape(weights.shape)
        )

    def write_tensors(self):
        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        n_head = self.hparams.get("num_attention_heads")
        n_kv_head = self.hparams.get("num_key_value_heads")
        for name, data_torch in self.get_tensors():
            # we don't need these
            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            # HF models permute some of the tensors, so we need to undo that
            if name.endswith(("q_proj.weight")):
                data_torch = self._reverse_hf_permute(data_torch, n_head, n_head)
            if name.endswith(("k_proj.weight")):
                data_torch = self._reverse_hf_permute(data_torch, n_head, n_kv_head)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)


class QwenModel(Model):
    @staticmethod

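The new `_reverse_hf_permute` undoes the head-wise interleaving that Hugging Face checkpoints apply to `q_proj`/`k_proj` weights for rotary embeddings: within each head the rows are viewed as a (2, head_dim/2) block and the two axes are swapped back. A small C++ sketch of the same index transformation on a flat row-major matrix, purely for illustration (the function name and layout assumptions are mine, not part of the converter):

```cpp
#include <cstdio>
#include <vector>

// Undo the HF rotary interleave for one 2-D weight, mirroring the Python
// reshape(n_head, 2, head_dim/2, cols).swapaxes(1, 2).reshape(...) above.
// Assumes a row-major matrix of shape (n_head * head_dim, n_cols).
static std::vector<float> reverse_hf_permute(const std::vector<float> & w,
                                             int n_head, int head_dim, int n_cols) {
    std::vector<float> out(w.size());
    const int half = head_dim / 2;
    for (int h = 0; h < n_head; ++h) {
        for (int a = 0; a < 2; ++a) {
            for (int b = 0; b < half; ++b) {
                const int src_row = h*head_dim + a*half + b; // layout in the HF checkpoint
                const int dst_row = h*head_dim + b*2   + a;  // layout written to GGUF
                for (int c = 0; c < n_cols; ++c) {
                    out[(size_t) dst_row*n_cols + c] = w[(size_t) src_row*n_cols + c];
                }
            }
        }
    }
    return out;
}

int main() {
    // 1 head, head_dim = 4, 1 column: rows 0,1,2,3 become 0,2,1,3 after the permute
    std::vector<float> w = {0, 1, 2, 3};
    std::vector<float> p = reverse_hf_permute(w, 1, 4, 1);
    for (float v : p) printf("%g ", v);
    printf("\n"); // prints: 0 2 1 3
    return 0;
}
```
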
@@ -14,14 +14,14 @@ Build with cmake or run `make llava-cli` to build it.

After building, run: `./llava-cli` to see the usage. For example:

```sh
./llava-cli -m llava-v1.5-7b/ggml-model-q5_k.gguf --mmproj llava-v1.5-7b/mmproj-model-f16.gguf --image path/to/an/image.jpg
./llava-cli -m ../llava-v1.5-7b/ggml-model-f16.gguf --mmproj ../llava-v1.5-7b/mmproj-model-f16.gguf --image path/to/an/image.jpg
```

**note**: A lower temperature like 0.1 is recommended for better quality. add `--temp 0.1` to the command to do so.

## Model conversion

- Clone `llava-v15-7b`` and `clip-vit-large-patch14-336`` locally:
- Clone `llava-v15-7b` and `clip-vit-large-patch14-336` locally:

```sh
git clone https://huggingface.co/liuhaotian/llava-v1.5-7b

@@ -38,7 +38,7 @@ python ./examples/llava/llava-surgery.py -m ../llava-v1.5-7b

3. Use `convert-image-encoder-to-gguf.py` to convert the LLaVA image encoder to GGUF:

```sh
python ./examples/llava/convert-image-encoder-to-gguf -m ../clip-vit-large-patch14-336 --llava-projector ../llava-v1.5-7b/llava.projector --output-dir ../llava-v1.5-7b
python ./examples/llava/convert-image-encoder-to-gguf.py -m ../clip-vit-large-patch14-336 --llava-projector ../llava-v1.5-7b/llava.projector --output-dir ../llava-v1.5-7b
```

4. Use `convert.py` to convert the LLaMA part of LLaVA to GGUF:

ggml-cuda.cu (181 changes)

@@ -5310,22 +5310,26 @@ template <bool need_check> static __global__ void

#endif // __CUDA_ARCH__ >= CC_VOLTA
}

template <int ncols_y_template, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
#define MMVQ_NWARPS_NVIDIA    4
#define MMVQ_NWARPS_AMD_RDNA2 1
#define MMVQ_NWARPS_AMD_OLD   4

template <int nwarps, int ncols_y_template, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1) // tells the compiler to use as many registers as it wants
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
static __global__ void mul_mat_vec_q(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y_par) {
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y_par, const int nrows_dst) {

    const int ncols_y = ncols_y_template != 0 ? ncols_y_template : ncols_y_par;

    const int row = blockIdx.x*blockDim.y + threadIdx.y;

    if (row >= nrows_x) {
        return;
    }
    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
    const int row = blockIdx.x;

    const int blocks_per_row_x = ncols_x / qk;
    const int blocks_per_col_y = nrows_y / QK8_1;
    const int blocks_per_warp = vdr * WARP_SIZE / qi;
    const int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;

    // partial sum for each thread
    float tmp[ncols_y_template != 0 ? ncols_y_template : 8] = {0.0f};

@@ -5333,12 +5337,12 @@ static __global__ void mul_mat_vec_q(

    const block_q_t  * x = (const block_q_t  *) vx;
    const block_q8_1 * y = (const block_q8_1 *) vy;

    for (int i = threadIdx.x / (qi/vdr); i < blocks_per_row_x; i += blocks_per_warp) {
    for (int i = tid / (qi/vdr); i < blocks_per_row_x; i += blocks_per_iter) {
        const int ibx = row*blocks_per_row_x + i; // x block index

        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx

        const int iqs = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int
        const int iqs = vdr * (tid % (qi/vdr)); // x block quant index when casting the quants to int

#pragma unroll
        for (int j = 0; j < ncols_y; ++j) {

@@ -5346,13 +5350,29 @@ static __global__ void mul_mat_vec_q(

        }
    }

    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y_template != 0 ? ncols_y_template : 8][WARP_SIZE];
    if (threadIdx.y > 0) {
#pragma unroll
        for (int j = 0; j < ncols_y; ++j) {
            tmp_shared[threadIdx.y-1][j][threadIdx.x] = tmp[j];
        }
    }
    __syncthreads();
    if (threadIdx.y > 0) {
        return;
    }

    // sum up partial sums and write back result
#pragma unroll
    for (int j = 0; j < ncols_y; ++j) {
#pragma unroll
        for (int i = 0; i < nwarps-1; ++i) {
            tmp[j] += tmp_shared[i][j][threadIdx.x];
        }
        tmp[j] = warp_reduce_sum(tmp[j]);

        if (threadIdx.x == 0) {
            dst[j*nrows_x + row] = tmp[j];
            dst[j*nrows_dst + row] = tmp[j];
        }
    }
}

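With the kernel now launched with `nwarps` warps per row, each warp accumulates its own partial dot products; warps other than warp 0 stash their partials in the `tmp_shared` buffer, and after `__syncthreads()` warp 0 folds them in before the final `warp_reduce_sum`. Below is a plain C++ sketch of that two-stage reduction, with vectors standing in for shared memory and lanes (an illustration of the data flow only, not the CUDA kernel):

```cpp
#include <cstdio>
#include <vector>

// Simulated two-stage reduction: partial[w][l] plays the role of the per-thread
// accumulator of warp w, lane l; tmp_shared plays the role of the __shared__ buffer.
static float reduce_two_stage(const std::vector<std::vector<float>> & partial) {
    const int nwarps    = (int) partial.size();
    const int warp_size = (int) partial[0].size();

    // stage 1: warps 1..nwarps-1 write their partials to "shared memory"
    std::vector<std::vector<float>> tmp_shared(nwarps - 1, std::vector<float>(warp_size));
    for (int w = 1; w < nwarps; ++w) {
        for (int l = 0; l < warp_size; ++l) {
            tmp_shared[w - 1][l] = partial[w][l];
        }
    }

    // stage 2: warp 0 folds the stashed partials into its own lanes ...
    std::vector<float> lane = partial[0];
    for (int w = 0; w < nwarps - 1; ++w) {
        for (int l = 0; l < warp_size; ++l) {
            lane[l] += tmp_shared[w][l];
        }
    }

    // ... then the equivalent of warp_reduce_sum, and lane 0 writes the result
    float sum = 0.0f;
    for (int l = 0; l < warp_size; ++l) {
        sum += lane[l];
    }
    return sum;
}

int main() {
    // 4 warps of 32 lanes, each lane holding 1.0f -> total 128
    std::vector<std::vector<float>> partial(4, std::vector<float>(32, 1.0f));
    printf("reduced sum = %g\n", reduce_two_stage(partial)); // prints 128
    return 0;
}
```
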
@@ -6828,51 +6848,70 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa

template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot>
static void mul_mat_vec_q_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, cudaStream_t stream) {
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    GGML_ASSERT(ncols_x % qk == 0);
    GGML_ASSERT(ncols_y <= 4);

    const int block_num_y = (nrows_x + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(block_num_y, 1, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
    switch (ncols_y) {
        case 1:
            mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot>
                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
            break;
        case 2:
            mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot>
                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
            break;
        case 3:
            mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot>
                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
            break;
        case 4:
            mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot>
                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
            break;
        // case 5:
        //     mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
        //         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
        //     break;
        // case 6:
        //     mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
        //         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
        //     break;
        // case 7:
        //     mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
        //         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
        //     break;
        // case 8:
        //     mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
        //         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
        //     break;
    int id;
    CUDA_CHECK(cudaGetDevice(&id));

    int nwarps;
    if (g_device_caps[id].cc >= CC_OFFSET_AMD) {
        nwarps = g_device_caps[id].cc >= CC_RDNA2 ? MMVQ_NWARPS_AMD_RDNA2 : MMVQ_NWARPS_AMD_OLD;
    } else {
        nwarps = MMVQ_NWARPS_NVIDIA;
    }

    const dim3 block_nums(nrows_x, 1, 1);
    const dim3 block_dims(WARP_SIZE, nwarps, 1);

    switch (nwarps) {
        case 1: switch(ncols_y) {
            case 1:
                mul_mat_vec_q<1, 1, qk, qi, block_q_t, vdr, vec_dot>
                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
                break;
            case 2:
                mul_mat_vec_q<1, 2, qk, qi, block_q_t, vdr, vec_dot>
                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
                break;
            case 3:
                mul_mat_vec_q<1, 3, qk, qi, block_q_t, vdr, vec_dot>
                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
                break;
            case 4:
                mul_mat_vec_q<1, 4, qk, qi, block_q_t, vdr, vec_dot>
                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
                break;
            default:
                GGML_ASSERT(false);
                break;
        } break;
        case 4: switch(ncols_y) {
            case 1:
                mul_mat_vec_q<4, 1, qk, qi, block_q_t, vdr, vec_dot>
                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
                break;
            case 2:
                mul_mat_vec_q<4, 2, qk, qi, block_q_t, vdr, vec_dot>
                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
                break;
            case 3:
                mul_mat_vec_q<4, 3, qk, qi, block_q_t, vdr, vec_dot>
                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
                break;
            case 4:
                mul_mat_vec_q<4, 4, qk, qi, block_q_t, vdr, vec_dot>
                    <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
                break;
            default:
                GGML_ASSERT(false);
                break;
        } break;

        default:
            GGML_ASSERT(false);
            // mul_mat_vec_q<0, qk, qi, block_q_t, vdr, vec_dot>
            //     <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
            break;
    }
}

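The launcher no longer tiles rows with `GGML_CUDA_MMV_Y`; it launches one block per row of `x` and picks the warps-per-block count by device class (`MMVQ_NWARPS_NVIDIA` and `MMVQ_NWARPS_AMD_OLD` are 4, `MMVQ_NWARPS_AMD_RDNA2` is 1). A small C++ sketch of that launch-geometry choice, with a hypothetical `device_class` enum standing in for the `g_device_caps[id].cc` check:

```cpp
#include <cstdio>

// Hypothetical stand-in for the device classification done via g_device_caps[id].cc.
enum class device_class { nvidia, amd_rdna2, amd_old };

struct launch_dims { int grid_x; int block_x; int block_y; };

// Mirrors the new geometry: one block per row, WARP_SIZE x nwarps threads per block.
static launch_dims mmvq_launch_dims(int nrows_x, device_class dev, int warp_size = 32) {
    int nwarps;
    switch (dev) {
        case device_class::amd_rdna2: nwarps = 1; break; // MMVQ_NWARPS_AMD_RDNA2
        case device_class::amd_old:   nwarps = 4; break; // MMVQ_NWARPS_AMD_OLD
        default:                      nwarps = 4; break; // MMVQ_NWARPS_NVIDIA
    }
    return { nrows_x, warp_size, nwarps };
}

int main() {
    launch_dims d = mmvq_launch_dims(4096, device_class::nvidia);
    printf("grid = (%d, 1, 1), block = (%d, %d, 1)\n", d.grid_x, d.block_x, d.block_y);
    return 0;
}
```
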
@@ -8391,7 +8430,7 @@ static void ggml_cuda_op_mul_mat_q(

    CUDA_CHECK(cudaGetDevice(&id));

    // the main device has a larger memory buffer to hold the results from all GPUs
    // nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into
    // nrows_dst == nrows of the matrix that the kernel writes into
    const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;

    switch (src0->type) {

@@ -8525,58 +8564,70 @@ static void ggml_cuda_op_mul_mat_vec_q(

    const int64_t ne00 = src0->ne[0];
    const int64_t row_diff = row_high - row_low;

    const int64_t ne10 = src1->ne[0];
    GGML_ASSERT(ne10 % QK8_1 == 0);

    const int64_t ne0 = dst->ne[0];

    int id;
    CUDA_CHECK(cudaGetDevice(&id));

    // the main device has a larger memory buffer to hold the results from all GPUs
    // nrows_dst == nrows of the matrix that the kernel writes into
    const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;

    switch (src0->type) {
        case GGML_TYPE_Q4_0:
            mul_mat_vec_q_cuda<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q4_1:
            mul_mat_vec_q_cuda<QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q5_0:
            mul_mat_vec_q_cuda<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q5_1:
            mul_mat_vec_q_cuda<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q8_0:
            mul_mat_vec_q_cuda<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q2_K:
            mul_mat_vec_q_cuda<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q3_K:
            mul_mat_vec_q_cuda<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q4_K:
            mul_mat_vec_q_cuda<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q5_K:
            mul_mat_vec_q_cuda<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q6_K:
            mul_mat_vec_q_cuda<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ2_XXS:
            mul_mat_vec_q_cuda<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ2_XS:
            mul_mat_vec_q_cuda<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ3_XXS:
            mul_mat_vec_q_cuda<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
                (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        default:
            GGML_ASSERT(false);

@@ -9909,7 +9960,7 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1

            ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
        }
    } else {
        if (src1->ne[1] <= 4 && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type)) {
        if (src1->ne[1] <= 4 && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32) {
            ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
        } else if (use_mul_mat_q) {
            ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);

@@ -12148,7 +12148,8 @@ inline void ggml_sycl_op_dequantize_mul_mat_vec(

    const int64_t src1_ncols, const int64_t src1_padded_row_size,
    const dpct::queue_ptr &stream) {

    const int64_t ne00 = src0->ne[0];
    GGML_TENSOR_BINARY_OP_LOCALS

    const int64_t row_diff = row_high - row_low;

    // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics

@@ -12167,8 +12168,9 @@ inline void ggml_sycl_op_dequantize_mul_mat_vec(

    } else {
        src1_dfloat = src1_dfloat_a.alloc(ne00);
        ggml_cpy_f32_f16_sycl((const char *)src1_ddf_i, (char *)src1_dfloat,
                              ne00, ne00, 1, sizeof(float), 0, 0, ne00, 1,
                              sizeof(sycl::half), 0, 0, stream);
                              ne00, ne00, ne01, ne02, nb00, nb01, nb02,
                              nb03, ne10, ne11, ne12, nb10, nb11, nb12,
                              nb13, stream);
    }
}
#else

llama.cpp (16 changes)

@@ -2947,6 +2947,8 @@ static void llm_load_hparams(

            } break;
        case LLM_ARCH_MINICPM:
            {
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

                switch (hparams.n_layer) {
                    case 40: model.type = e_model::MODEL_2B; break;
                    default: model.type = e_model::MODEL_UNKNOWN;

@@ -4207,8 +4209,7 @@ static bool llm_load_tensors(

        ctx_bufs.emplace_back(ctx, buf);
    }

    // print memory requirements
    {
        if (llama_supports_gpu_offload()) {
            const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));

            LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_gpu);

@@ -4220,10 +4221,11 @@ static bool llm_load_tensors(

        const int max_offloadable_layers = hparams.n_layer + 1;

        LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers);
    }

    for (ggml_backend_buffer_t buf : model.bufs) {
        LLAMA_LOG_INFO("%s: %10s buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf) / 1024.0 / 1024.0);
    }
    // print memory requirements
    for (ggml_backend_buffer_t buf : model.bufs) {
        LLAMA_LOG_INFO("%s: %10s buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf) / 1024.0 / 1024.0);
    }

    // populate tensors_by_name

@@ -8586,6 +8588,10 @@ void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * can

    const int64_t t_start_sample_us = ggml_time_us();

    if (k <= 0) {
        k = candidates->size;
    }

    k = std::max(k, (int) min_keep);
    k = std::min(k, (int) candidates->size);

tests/.gitignore (vendored, 2 changes)

@@ -1,3 +1,3 @@

*
!*.*
test-c.o
*.o

@@ -235,6 +235,8 @@ int main(void) {

    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 1);
    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 3);
    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0);

    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 0);
    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f}, 0.7f);