Merge branch 'ggerganov:master' into master
This commit is contained in:
commit
fee7936705
13 changed files with 911 additions and 739 deletions
5
.gitignore
vendored
5
.gitignore
vendored
|
@ -61,6 +61,11 @@ llama-batched-swift
|
||||||
out/
|
out/
|
||||||
tmp/
|
tmp/
|
||||||
|
|
||||||
|
# Deprecated
|
||||||
|
|
||||||
|
/main
|
||||||
|
/server
|
||||||
|
|
||||||
# CI
|
# CI
|
||||||
|
|
||||||
!.github/workflows/*.yml
|
!.github/workflows/*.yml
|
||||||
|
|
8
Makefile
8
Makefile
|
@ -547,11 +547,17 @@ ifdef GGML_OPENBLAS64
|
||||||
endif # GGML_OPENBLAS64
|
endif # GGML_OPENBLAS64
|
||||||
|
|
||||||
ifdef GGML_BLIS
|
ifdef GGML_BLIS
|
||||||
MK_CPPFLAGS += -DGGML_USE_BLAS -I/usr/local/include/blis -I/usr/include/blis
|
MK_CPPFLAGS += -DGGML_USE_BLAS -DGGML_BLAS_USE_BLIS -I/usr/local/include/blis -I/usr/include/blis
|
||||||
MK_LDFLAGS += -lblis -L/usr/local/lib
|
MK_LDFLAGS += -lblis -L/usr/local/lib
|
||||||
OBJ_GGML += ggml/src/ggml-blas.o
|
OBJ_GGML += ggml/src/ggml-blas.o
|
||||||
endif # GGML_BLIS
|
endif # GGML_BLIS
|
||||||
|
|
||||||
|
ifdef GGML_NVPL
|
||||||
|
MK_CPPFLAGS += -DGGML_USE_BLAS -DGGML_BLAS_USE_NVPL -DNVPL_ILP64 -I/usr/local/include/nvpl_blas -I/usr/include/nvpl_blas
|
||||||
|
MK_LDFLAGS += -L/usr/local/lib -lnvpl_blas_core -lnvpl_blas_ilp64_gomp
|
||||||
|
OBJ_GGML += ggml/src/ggml-blas.o
|
||||||
|
endif # GGML_NVPL
|
||||||
|
|
||||||
ifndef GGML_NO_LLAMAFILE
|
ifndef GGML_NO_LLAMAFILE
|
||||||
MK_CPPFLAGS += -DGGML_USE_LLAMAFILE
|
MK_CPPFLAGS += -DGGML_USE_LLAMAFILE
|
||||||
OBJ_GGML += ggml/src/llamafile/sgemm.o
|
OBJ_GGML += ggml/src/llamafile/sgemm.o
|
||||||
|
|
|
@ -29,6 +29,7 @@ static void print_usage_information(const char * argv0, FILE * stream) {
|
||||||
fprintf(stream, " -p PROMPT, --prompt PROMPT read prompt from the argument.\n");
|
fprintf(stream, " -p PROMPT, --prompt PROMPT read prompt from the argument.\n");
|
||||||
fprintf(stream, " --stdin read prompt from standard input.\n");
|
fprintf(stream, " --stdin read prompt from standard input.\n");
|
||||||
fprintf(stream, " --no-bos do not ever add a BOS token to the prompt, even if normally the model uses a BOS token.\n");
|
fprintf(stream, " --no-bos do not ever add a BOS token to the prompt, even if normally the model uses a BOS token.\n");
|
||||||
|
fprintf(stream, " --no-parse-special do not parse control tokens.\n");
|
||||||
fprintf(stream, " --log-disable disable logs. Makes stderr quiet when loading the model.\n");
|
fprintf(stream, " --log-disable disable logs. Makes stderr quiet when loading the model.\n");
|
||||||
fprintf(stream, " --show-count print the total number of tokens.\n");
|
fprintf(stream, " --show-count print the total number of tokens.\n");
|
||||||
}
|
}
|
||||||
|
@ -195,6 +196,7 @@ int main(int raw_argc, char ** raw_argv) {
|
||||||
// variables where to put any arguments we see.
|
// variables where to put any arguments we see.
|
||||||
bool printing_ids = false;
|
bool printing_ids = false;
|
||||||
bool no_bos = false;
|
bool no_bos = false;
|
||||||
|
bool no_parse_special = false;
|
||||||
bool disable_logging = false;
|
bool disable_logging = false;
|
||||||
bool show_token_count = false;
|
bool show_token_count = false;
|
||||||
const char * model_path = NULL;
|
const char * model_path = NULL;
|
||||||
|
@ -229,6 +231,9 @@ int main(int raw_argc, char ** raw_argv) {
|
||||||
else if (arg == "--no-bos") {
|
else if (arg == "--no-bos") {
|
||||||
no_bos = true;
|
no_bos = true;
|
||||||
}
|
}
|
||||||
|
else if (arg == "--no-parse-special") {
|
||||||
|
no_parse_special = true;
|
||||||
|
}
|
||||||
else if (arg == "-p" || arg == "--prompt") {
|
else if (arg == "-p" || arg == "--prompt") {
|
||||||
if (prompt_set) {
|
if (prompt_set) {
|
||||||
fprintf(stderr, "Error: -p or --prompt specified multiple times.\n");
|
fprintf(stderr, "Error: -p or --prompt specified multiple times.\n");
|
||||||
|
@ -359,9 +364,10 @@ int main(int raw_argc, char ** raw_argv) {
|
||||||
|
|
||||||
const bool model_wants_add_bos = llama_should_add_bos_token(model);
|
const bool model_wants_add_bos = llama_should_add_bos_token(model);
|
||||||
const bool add_bos = model_wants_add_bos && !no_bos;
|
const bool add_bos = model_wants_add_bos && !no_bos;
|
||||||
|
const bool parse_special = !no_parse_special;
|
||||||
|
|
||||||
std::vector<llama_token> tokens;
|
std::vector<llama_token> tokens;
|
||||||
tokens = ::llama_tokenize(model, prompt, add_bos, true);
|
tokens = ::llama_tokenize(model, prompt, add_bos, parse_special);
|
||||||
|
|
||||||
if (printing_ids) {
|
if (printing_ids) {
|
||||||
printf("[");
|
printf("[");
|
||||||
|
|
|
@ -394,7 +394,7 @@ void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event)
|
||||||
|
|
||||||
// backend registry
|
// backend registry
|
||||||
|
|
||||||
#define GGML_REG_MAX_BACKENDS 16
|
#define GGML_REG_MAX_BACKENDS 64
|
||||||
|
|
||||||
struct ggml_backend_reg {
|
struct ggml_backend_reg {
|
||||||
char name[128];
|
char name[128];
|
||||||
|
|
|
@ -8,11 +8,12 @@
|
||||||
# include <Accelerate/Accelerate.h>
|
# include <Accelerate/Accelerate.h>
|
||||||
#elif defined(GGML_BLAS_USE_MKL)
|
#elif defined(GGML_BLAS_USE_MKL)
|
||||||
# include <mkl.h>
|
# include <mkl.h>
|
||||||
|
#elif defined(GGML_BLAS_USE_BLIS)
|
||||||
|
# include <blis.h>
|
||||||
|
#elif defined(GGML_BLAS_USE_NVPL)
|
||||||
|
# include <nvpl_blas.h>
|
||||||
#else
|
#else
|
||||||
# include <cblas.h>
|
# include <cblas.h>
|
||||||
# ifdef BLIS_ENABLE_CBLAS
|
|
||||||
# include <blis.h>
|
|
||||||
# endif
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
struct ggml_backend_blas_context {
|
struct ggml_backend_blas_context {
|
||||||
|
@ -140,10 +141,14 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg
|
||||||
openblas_set_num_threads(ctx->n_threads);
|
openblas_set_num_threads(ctx->n_threads);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(BLIS_ENABLE_CBLAS)
|
#if defined(GGML_BLAS_USE_BLIS)
|
||||||
bli_thread_set_num_threads(ctx->n_threads);
|
bli_thread_set_num_threads(ctx->n_threads);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if defined(GGML_BLAS_USE_NVPL)
|
||||||
|
nvpl_blas_set_num_threads(ctx->n_threads);
|
||||||
|
#endif
|
||||||
|
|
||||||
for (int64_t i13 = 0; i13 < ne13; i13++) {
|
for (int64_t i13 = 0; i13 < ne13; i13++) {
|
||||||
for (int64_t i12 = 0; i12 < ne12; i12++) {
|
for (int64_t i12 = 0; i12 < ne12; i12++) {
|
||||||
const int64_t i03 = i13/r3;
|
const int64_t i03 = i13/r3;
|
||||||
|
|
|
@ -104,7 +104,7 @@
|
||||||
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
|
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
|
||||||
#define cudaStream_t hipStream_t
|
#define cudaStream_t hipStream_t
|
||||||
#define cudaSuccess hipSuccess
|
#define cudaSuccess hipSuccess
|
||||||
#define __trap abort
|
#define __trap() do { abort(); __builtin_unreachable(); } while(0)
|
||||||
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
|
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
|
||||||
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
|
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
|
||||||
#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
|
#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
|
||||||
|
|
|
@ -70,6 +70,10 @@ struct mma_int_A_I16K8 {
|
||||||
}
|
}
|
||||||
#endif // defined(INT8_MMA_AVAILABLE)
|
#endif // defined(INT8_MMA_AVAILABLE)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__device__ __forceinline__ void load_low(const int * __restrict__ xs0, const int & stride) {
|
||||||
|
((mma_int_A_I16K4 *) x)[0].load(xs0, stride);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct mma_int_B_J8K4 {
|
struct mma_int_B_J8K4 {
|
||||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -37,47 +37,92 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
|
||||||
reinterpret_cast<half&>(y[ib].ds.y) = sum;
|
reinterpret_cast<half&>(y[ib].ds.y) = sum;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <bool need_sum>
|
template <mmq_q8_1_ds_layout ds_layout>
|
||||||
static __global__ void quantize_mmq_q8_1(
|
static __global__ void quantize_mmq_q8_1(
|
||||||
const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
|
const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
|
||||||
|
|
||||||
const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
|
constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
|
||||||
|
constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
|
||||||
|
|
||||||
|
const int64_t ix0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
|
||||||
|
|
||||||
if (ix0 >= kx0_padded) {
|
if (ix0 >= kx0_padded) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const float4 * x4 = (const float4 *) x;
|
||||||
|
|
||||||
const int64_t ix1 = kx1*blockIdx.z + blockIdx.y;
|
const int64_t ix1 = kx1*blockIdx.z + blockIdx.y;
|
||||||
|
|
||||||
block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
|
block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
|
||||||
|
|
||||||
const int64_t ib0 = blockIdx.z*(gridDim.y*gridDim.x*blockDim.x/(4*QK8_1)); // first block of channel
|
const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
|
||||||
const int64_t ib = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y; // block index in channel
|
const int64_t ib = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y; // block index in channel
|
||||||
const int64_t iqs = ix0 % (4*QK8_1); // quant index in block
|
const int64_t iqs = ix0 % (4*QK8_1); // quant index in block
|
||||||
|
|
||||||
const float xi = ix0 < kx0 ? x[ix1*kx0 + ix0] : 0.0f;
|
// Load 4 floats per thread and calculate max. abs. value between them:
|
||||||
float amax = fabsf(xi);
|
const float4 xi = ix0 < kx0 ? x4[(ix1*kx0 + ix0)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
|
||||||
|
float amax = fabsf(xi.x);
|
||||||
|
amax = fmaxf(amax, fabsf(xi.y));
|
||||||
|
amax = fmaxf(amax, fabsf(xi.z));
|
||||||
|
amax = fmaxf(amax, fabsf(xi.w));
|
||||||
|
|
||||||
amax = warp_reduce_max(amax);
|
// Exchange max. abs. value between vals_per_scale/4 threads.
|
||||||
|
#pragma unroll
|
||||||
float sum;
|
for (int mask = vals_per_scale/8; mask > 0; mask >>= 1) {
|
||||||
if (need_sum) {
|
amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
|
||||||
sum = warp_reduce_sum(xi);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const float d = amax / 127;
|
float sum;
|
||||||
const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
|
if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) {
|
||||||
|
sum = xi.x + xi.y + xi.z + xi.w;
|
||||||
|
|
||||||
y[ib].qs[iqs] = q;
|
// Exchange calculate sum across vals_per_sum/4 threads.
|
||||||
|
#pragma unroll
|
||||||
|
for (int mask = vals_per_sum/8; mask > 0; mask >>= 1) {
|
||||||
|
sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, WARP_SIZE);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const float d_inv = 127.0f / amax;
|
||||||
|
char4 q;
|
||||||
|
q.x = roundf(xi.x*d_inv);
|
||||||
|
q.y = roundf(xi.y*d_inv);
|
||||||
|
q.z = roundf(xi.z*d_inv);
|
||||||
|
q.w = roundf(xi.w*d_inv);
|
||||||
|
|
||||||
|
// Write back 4 int8 values as a single 32 bit value for better memroy bandwidth:
|
||||||
|
char4 * yqs4 = (char4 *) y[ib].qs;
|
||||||
|
yqs4[iqs/4] = q;
|
||||||
|
|
||||||
|
if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6) {
|
||||||
|
if (iqs % 16 != 0 || iqs >= 96) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
y[ib].d2s6[2 + iqs/16] = sum;
|
||||||
|
|
||||||
|
if (iqs % 64 != 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const float d = 1.0f / d_inv;
|
||||||
|
|
||||||
|
y[ib].d2s6[iqs/64] = d;
|
||||||
|
|
||||||
if (iqs % QK8_1 != 0) {
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (need_sum) {
|
if (iqs % 32 != 0) {
|
||||||
y[ib].ds[iqs/QK8_1] = make_half2(d, sum);
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const float d = 1.0f / d_inv;
|
||||||
|
|
||||||
|
if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) {
|
||||||
|
y[ib].ds4[iqs/32] = make_half2(d, sum);
|
||||||
} else {
|
} else {
|
||||||
((float *) y[ib].ds)[iqs/QK8_1] = d;
|
y[ib].d4[iqs/32] = d;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -101,12 +146,24 @@ void quantize_mmq_q8_1_cuda(
|
||||||
|
|
||||||
GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
|
GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
|
||||||
|
|
||||||
const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
|
const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
|
||||||
const dim3 num_blocks(block_num_x, kx1, channels);
|
const dim3 num_blocks(block_num_x, kx1, channels);
|
||||||
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
|
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
|
||||||
if (mmq_need_sum(type_x)) {
|
switch (mmq_get_q8_1_ds_layout(type_x)) {
|
||||||
quantize_mmq_q8_1<true><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
|
case MMQ_Q8_1_DS_LAYOUT_D4:
|
||||||
} else {
|
quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
|
||||||
quantize_mmq_q8_1<false><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
|
<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
|
||||||
|
break;
|
||||||
|
case MMQ_Q8_1_DS_LAYOUT_DS4:
|
||||||
|
quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
|
||||||
|
<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
|
||||||
|
break;
|
||||||
|
case MMQ_Q8_1_DS_LAYOUT_D2S6:
|
||||||
|
quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
|
||||||
|
<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
GGML_ASSERT(false);
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,7 +5,11 @@
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
|
||||||
#define CUDA_QUANTIZE_BLOCK_SIZE 256
|
#define CUDA_QUANTIZE_BLOCK_SIZE 256
|
||||||
|
#define CUDA_QUANTIZE_BLOCK_SIZE_MMQ 128
|
||||||
|
|
||||||
|
static_assert(MATRIX_ROW_PADDING % CUDA_QUANTIZE_BLOCK_SIZE == 0, "Risk of out-of-bounds access.");
|
||||||
|
static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access.");
|
||||||
|
|
||||||
typedef void (*quantize_cuda_t)(
|
typedef void (*quantize_cuda_t)(
|
||||||
const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
|
const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
|
||||||
|
|
|
@ -189,7 +189,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
|
||||||
}
|
}
|
||||||
|
|
||||||
#define VDR_Q2_K_Q8_1_MMVQ 1
|
#define VDR_Q2_K_Q8_1_MMVQ 1
|
||||||
#define VDR_Q2_K_Q8_1_MMQ 2
|
#define VDR_Q2_K_Q8_1_MMQ 4
|
||||||
|
|
||||||
// contiguous v/x values
|
// contiguous v/x values
|
||||||
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
|
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
|
||||||
|
@ -219,32 +219,56 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
|
||||||
return dm2f.x*sumf_d - dm2f.y*sumf_m;
|
return dm2f.x*sumf_d - dm2f.y*sumf_m;
|
||||||
}
|
}
|
||||||
|
|
||||||
// contiguous u/y values
|
// contiguous v/x + u/y values
|
||||||
|
template <int ns8>
|
||||||
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
|
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
|
||||||
const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8) {
|
const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8, const half2 * s8) {
|
||||||
|
|
||||||
float sumf_d = 0.0f;
|
float sumf = 0.0f;
|
||||||
float sumf_m = 0.0f;
|
float sumf_d8 = 0.0f;
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) {
|
for (int i0 = 0; i0 < QR2_K*VDR_Q2_K_Q8_1_MMQ; i0 += QI8_1) {
|
||||||
const float2 dm2f = __half22float2(dm2[i0/(QI8_1/2)]);
|
const float2 dm2f0 = __half22float2(dm2[i0/(QI8_1/2) + 0]);
|
||||||
int sumi_d = 0;
|
int sumi_d0 = 0;
|
||||||
int sumi_m = 0;
|
|
||||||
|
const float2 dm2f1 = __half22float2(dm2[i0/(QI8_1/2) + 1]);
|
||||||
|
int sumi_d1 = 0;
|
||||||
|
|
||||||
const int vi0 = v[i0/(QI8_1/2)];
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = i0; i < i0 + QI8_1/2; ++i) {
|
for (int i = i0; i < i0 + QI8_1/2; ++i) {
|
||||||
const int vi = (vi0 >> (2*(i % (QI8_1/2)))) & 0x03030303;
|
sumi_d0 = ggml_cuda_dp4a(v[i], u[i], sumi_d0);
|
||||||
sumi_d = ggml_cuda_dp4a(vi, u[i], sumi_d); // SIMD dot product
|
|
||||||
sumi_m = ggml_cuda_dp4a(0x01010101, u[i], sumi_m);
|
|
||||||
}
|
}
|
||||||
|
sumf_d8 += dm2f0.x * sumi_d0;
|
||||||
|
|
||||||
sumf_d += dm2f.x * sumi_d;
|
#pragma unroll
|
||||||
sumf_m += dm2f.y * sumi_m;
|
for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {
|
||||||
|
sumi_d1 = ggml_cuda_dp4a(v[i], u[i], sumi_d1);
|
||||||
|
}
|
||||||
|
sumf_d8 += dm2f1.x * sumi_d1;
|
||||||
|
|
||||||
|
if (i0/QI8_1 < ns8) {
|
||||||
|
const float2 s8f = __half22float2(s8[i0/QI8_1]);
|
||||||
|
sumf -= dm2f0.y*s8f.x;
|
||||||
|
sumf -= dm2f1.y*s8f.y;
|
||||||
|
} else {
|
||||||
|
int sumi_m0 = 0;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = i0; i < i0 + QI8_1/2; ++i) {
|
||||||
|
sumi_m0 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m0);
|
||||||
|
}
|
||||||
|
sumf_d8 -= dm2f0.y * sumi_m0;
|
||||||
|
|
||||||
|
int sumi_m1 = 0;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {
|
||||||
|
sumi_m1 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m1);
|
||||||
|
}
|
||||||
|
sumf_d8 -= dm2f1.y * sumi_m1;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return d8*(sumf_d - sumf_m);
|
return sumf + d8*sumf_d8;
|
||||||
}
|
}
|
||||||
|
|
||||||
#define VDR_Q3_K_Q8_1_MMVQ 1
|
#define VDR_Q3_K_Q8_1_MMVQ 1
|
||||||
|
@ -283,7 +307,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
|
||||||
return d3 * sumf;
|
return d3 * sumf;
|
||||||
}
|
}
|
||||||
|
|
||||||
// contiguous u/y values
|
// contiguous v/x + u/y values
|
||||||
static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
|
static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
|
||||||
const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales,
|
const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales,
|
||||||
const float & d3, const float & d8) {
|
const float & d3, const float & d8) {
|
||||||
|
@ -296,8 +320,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = i0; i < i0 + QI8_1/2; ++i) {
|
for (int i = i0; i < i0 + QI8_1/2; ++i) {
|
||||||
const int vi = __vsubss4((v[i/2] >> (4*(i%2))) & 0x0F0F0F0F, 0x04040404);
|
sumi_sc = ggml_cuda_dp4a(v[i], u[i], sumi_sc); // SIMD dot product
|
||||||
sumi_sc = ggml_cuda_dp4a(vi, u[i], sumi_sc); // SIMD dot product
|
|
||||||
}
|
}
|
||||||
|
|
||||||
sumi += sumi_sc * scales[i0 / (QI8_1/2)];
|
sumi += sumi_sc * scales[i0 / (QI8_1/2)];
|
||||||
|
@ -334,7 +357,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
|
||||||
return dm4f.x*sumf_d - dm4f.y*sumf_m;
|
return dm4f.x*sumf_d - dm4f.y*sumf_m;
|
||||||
}
|
}
|
||||||
|
|
||||||
// contiguous u/y values
|
// contiguous v/x + u/y values
|
||||||
static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
|
static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
|
||||||
const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
|
const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
|
||||||
const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
|
const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
|
||||||
|
@ -397,7 +420,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
|
||||||
return dm5f.x*sumf_d - dm5f.y*sumf_m;
|
return dm5f.x*sumf_d - dm5f.y*sumf_m;
|
||||||
}
|
}
|
||||||
|
|
||||||
// contiguous u/y values
|
// contiguous v/x + u/y values
|
||||||
static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
|
static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
|
||||||
const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
|
const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
|
||||||
const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
|
const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
|
||||||
|
@ -451,13 +474,16 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
|
||||||
return d*sumf;
|
return d*sumf;
|
||||||
}
|
}
|
||||||
|
|
||||||
// contiguous u/y values
|
// contiguous v/x + u/y values
|
||||||
static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
|
static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
|
||||||
const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc,
|
const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc,
|
||||||
const float & d6, const float * __restrict__ d8) {
|
const float & d6, const float * __restrict__ d8) {
|
||||||
|
|
||||||
float sumf_d = 0.0f;
|
float sumf_d = 0.0f;
|
||||||
|
|
||||||
|
const int sc_packed = get_int_b4(sc, 0);
|
||||||
|
const int8_t * sc_reg = (const int8_t *) &sc_packed;
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
|
for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
|
||||||
int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale
|
int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale
|
||||||
|
@ -471,7 +497,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
|
||||||
sumi_d.y = ggml_cuda_dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product
|
sumi_d.y = ggml_cuda_dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product
|
||||||
}
|
}
|
||||||
|
|
||||||
sumf_d += d8[i0/4] * (sc[i0/2+0]*sumi_d.x + sc[i0/2+1]*sumi_d.y);
|
sumf_d += d8[i0/4] * (sc_reg[i0/2+0]*sumi_d.x + sc_reg[i0/2+1]*sumi_d.y);
|
||||||
}
|
}
|
||||||
|
|
||||||
return d6 * sumf_d;
|
return d6 * sumf_d;
|
||||||
|
|
|
@ -3768,37 +3768,13 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
|
||||||
stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
|
stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
|
||||||
SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
|
SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
|
||||||
|
|
||||||
const ggml_tensor_extra_gpu *src0_extra =
|
|
||||||
(const ggml_tensor_extra_gpu *)src0->extra;
|
|
||||||
const ggml_tensor_extra_gpu *src1_extra =
|
|
||||||
(const ggml_tensor_extra_gpu *)src1->extra;
|
|
||||||
const ggml_tensor_extra_gpu *dst_extra =
|
|
||||||
(const ggml_tensor_extra_gpu *)dst->extra;
|
|
||||||
|
|
||||||
ggml_tensor_extra_gpu src0_row_extra;
|
|
||||||
ggml_tensor_extra_gpu src1_row_extra;
|
|
||||||
ggml_tensor_extra_gpu dst_row_extra;
|
|
||||||
|
|
||||||
ggml_tensor src0_row = *src0;
|
ggml_tensor src0_row = *src0;
|
||||||
ggml_tensor src1_row = *src1;
|
ggml_tensor src1_row = *src1;
|
||||||
ggml_tensor dst_row = *dst;
|
ggml_tensor dst_row = *dst;
|
||||||
|
|
||||||
src1_row.backend = GGML_BACKEND_TYPE_GPU;
|
char *src0_original = (char *)src0->data;
|
||||||
dst_row.backend = GGML_BACKEND_TYPE_GPU;
|
char *src1_original = (char *)src1->data;
|
||||||
|
char *dst_original = (char *)dst->data;
|
||||||
src0_row.extra = &src0_row_extra;
|
|
||||||
src1_row.extra = &src1_row_extra;
|
|
||||||
dst_row.extra = &dst_row_extra;
|
|
||||||
|
|
||||||
char *src0_original = src1->backend == GGML_BACKEND_TYPE_CPU
|
|
||||||
? (char *)src0->data
|
|
||||||
: (char *)src0_extra->data_device[ctx.device];
|
|
||||||
char *src1_original = src1->backend == GGML_BACKEND_TYPE_CPU
|
|
||||||
? (char *)src1->data
|
|
||||||
: (char *)src1_extra->data_device[ctx.device];
|
|
||||||
char *dst_original = dst->backend == GGML_BACKEND_TYPE_CPU
|
|
||||||
? (char *)dst->data
|
|
||||||
: (char *)dst_extra->data_device[ctx.device];
|
|
||||||
|
|
||||||
src0_row.ne[2] = 1;
|
src0_row.ne[2] = 1;
|
||||||
src0_row.ne[3] = 1;
|
src0_row.ne[3] = 1;
|
||||||
|
@ -3827,12 +3803,9 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
|
||||||
const int64_t i1 = id;
|
const int64_t i1 = id;
|
||||||
const int64_t i2 = i12;
|
const int64_t i2 = i12;
|
||||||
|
|
||||||
src0_row_extra.data_device[ctx.device] =
|
src0_row.data = src0_original + i02*nb02;
|
||||||
src0_original + i02*nb02;
|
src1_row.data = src1_original + + i11*nb11 + i12*nb12;
|
||||||
src1_row_extra.data_device[ctx.device] =
|
dst_row.data = dst_original + i1*nb1 + i2*nb2;
|
||||||
src1_original + + i11*nb11 + i12*nb12;
|
|
||||||
dst_row_extra.data_device[ctx.device] =
|
|
||||||
dst_original + i1*nb1 + i2*nb2;
|
|
||||||
|
|
||||||
ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
|
ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
|
||||||
}
|
}
|
||||||
|
@ -3841,8 +3814,8 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
|
||||||
ggml_sycl_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
|
ggml_sycl_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
|
||||||
ggml_sycl_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
|
ggml_sycl_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
|
||||||
|
|
||||||
src1_row_extra.data_device[ctx.device] = src1_contiguous.get();
|
src1_row.data = src1_contiguous.get();
|
||||||
dst_row_extra.data_device[ctx.device] = dst_contiguous.get();
|
dst_row.data = dst_contiguous.get();
|
||||||
|
|
||||||
for (int64_t i02 = 0; i02 < n_as; i02++) {
|
for (int64_t i02 = 0; i02 < n_as; i02++) {
|
||||||
int64_t num_src1_rows = 0;
|
int64_t num_src1_rows = 0;
|
||||||
|
@ -3898,7 +3871,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
src0_row_extra.data_device[ctx.device] = src0_original + i02*nb02;
|
src0_row.data = src0_original + i02*nb02;
|
||||||
|
|
||||||
GGML_ASSERT(nb11 == sizeof(float)*ne10);
|
GGML_ASSERT(nb11 == sizeof(float)*ne10);
|
||||||
GGML_ASSERT(nb1 == sizeof(float)*ne0);
|
GGML_ASSERT(nb1 == sizeof(float)*ne0);
|
||||||
|
@ -5221,6 +5194,10 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
ggml_type src0_type = op->src[0]->type;
|
||||||
|
if (src0_type == GGML_TYPE_BF16) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
return true;
|
return true;
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_GET_ROWS:
|
case GGML_OP_GET_ROWS:
|
||||||
|
|
|
@ -5883,13 +5883,6 @@ static bool llm_load_tensors(
|
||||||
|
|
||||||
auto & hparams = model.hparams;
|
auto & hparams = model.hparams;
|
||||||
|
|
||||||
#ifdef GGML_USE_SYCL
|
|
||||||
// disable MoE with SYCL until mul_mat_id is updated
|
|
||||||
if (hparams.n_expert > 0) {
|
|
||||||
n_gpu_layers = 0;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
model.split_mode = split_mode;
|
model.split_mode = split_mode;
|
||||||
model.main_gpu = main_gpu;
|
model.main_gpu = main_gpu;
|
||||||
model.n_gpu_layers = n_gpu_layers;
|
model.n_gpu_layers = n_gpu_layers;
|
||||||
|
@ -8134,7 +8127,7 @@ static struct ggml_tensor * llm_build_kqv(
|
||||||
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
|
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
|
||||||
cb(kq, "kq", il);
|
cb(kq, "kq", il);
|
||||||
|
|
||||||
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
|
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
|
||||||
// for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
|
// for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
|
||||||
// ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
|
// ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
|
||||||
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
|
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue