Merge branch 'ggerganov:master' into master

commit fee7936705
Author: Daniel Han
Date:   2024-07-11 22:21:03 -07:00 (committed via GitHub)
13 changed files with 911 additions and 739 deletions

.gitignore (vendored)

@@ -61,6 +61,11 @@ llama-batched-swift
 out/
 tmp/
 
+# Deprecated
+/main
+/server
+
 # CI
 !.github/workflows/*.yml


@@ -547,11 +547,17 @@ ifdef GGML_OPENBLAS64
 endif # GGML_OPENBLAS64
 
 ifdef GGML_BLIS
-    MK_CPPFLAGS += -DGGML_USE_BLAS -I/usr/local/include/blis -I/usr/include/blis
+    MK_CPPFLAGS += -DGGML_USE_BLAS -DGGML_BLAS_USE_BLIS -I/usr/local/include/blis -I/usr/include/blis
     MK_LDFLAGS  += -lblis -L/usr/local/lib
     OBJ_GGML    += ggml/src/ggml-blas.o
 endif # GGML_BLIS
 
+ifdef GGML_NVPL
+    MK_CPPFLAGS += -DGGML_USE_BLAS -DGGML_BLAS_USE_NVPL -DNVPL_ILP64 -I/usr/local/include/nvpl_blas -I/usr/include/nvpl_blas
+    MK_LDFLAGS  += -L/usr/local/lib -lnvpl_blas_core -lnvpl_blas_ilp64_gomp
+    OBJ_GGML    += ggml/src/ggml-blas.o
+endif # GGML_NVPL
+
 ifndef GGML_NO_LLAMAFILE
     MK_CPPFLAGS += -DGGML_USE_LLAMAFILE
     OBJ_GGML    += ggml/src/llamafile/sgemm.o


@@ -29,6 +29,7 @@ static void print_usage_information(const char * argv0, FILE * stream) {
     fprintf(stream, "    -p PROMPT, --prompt PROMPT  read prompt from the argument.\n");
     fprintf(stream, "    --stdin                     read prompt from standard input.\n");
     fprintf(stream, "    --no-bos                    do not ever add a BOS token to the prompt, even if normally the model uses a BOS token.\n");
+    fprintf(stream, "    --no-parse-special          do not parse control tokens.\n");
     fprintf(stream, "    --log-disable               disable logs. Makes stderr quiet when loading the model.\n");
     fprintf(stream, "    --show-count                print the total number of tokens.\n");
 }
@@ -195,6 +196,7 @@ int main(int raw_argc, char ** raw_argv) {
     // variables where to put any arguments we see.
     bool printing_ids = false;
     bool no_bos = false;
+    bool no_parse_special = false;
     bool disable_logging = false;
     bool show_token_count = false;
     const char * model_path = NULL;
@@ -229,6 +231,9 @@ int main(int raw_argc, char ** raw_argv) {
         else if (arg == "--no-bos") {
             no_bos = true;
         }
+        else if (arg == "--no-parse-special") {
+            no_parse_special = true;
+        }
         else if (arg == "-p" || arg == "--prompt") {
             if (prompt_set) {
                 fprintf(stderr, "Error: -p or --prompt specified multiple times.\n");
@@ -359,9 +364,10 @@ int main(int raw_argc, char ** raw_argv) {
 
     const bool model_wants_add_bos = llama_should_add_bos_token(model);
     const bool add_bos = model_wants_add_bos && !no_bos;
+    const bool parse_special = !no_parse_special;
 
     std::vector<llama_token> tokens;
-    tokens = ::llama_tokenize(model, prompt, add_bos, true);
+    tokens = ::llama_tokenize(model, prompt, add_bos, parse_special);
 
     if (printing_ids) {
         printf("[");


@@ -394,7 +394,7 @@ void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event)
 // backend registry
 
-#define GGML_REG_MAX_BACKENDS 16
+#define GGML_REG_MAX_BACKENDS 64
 
 struct ggml_backend_reg {
     char name[128];


@@ -8,11 +8,12 @@
 #  include <Accelerate/Accelerate.h>
 #elif defined(GGML_BLAS_USE_MKL)
 #  include <mkl.h>
+#elif defined(GGML_BLAS_USE_BLIS)
+#  include <blis.h>
+#elif defined(GGML_BLAS_USE_NVPL)
+#  include <nvpl_blas.h>
 #else
 #  include <cblas.h>
-#  ifdef BLIS_ENABLE_CBLAS
-#    include <blis.h>
-#  endif
 #endif
 
 struct ggml_backend_blas_context {
@@ -140,10 +141,14 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg
     openblas_set_num_threads(ctx->n_threads);
 #endif
 
-#if defined(BLIS_ENABLE_CBLAS)
+#if defined(GGML_BLAS_USE_BLIS)
     bli_thread_set_num_threads(ctx->n_threads);
 #endif
 
+#if defined(GGML_BLAS_USE_NVPL)
+    nvpl_blas_set_num_threads(ctx->n_threads);
+#endif
+
     for (int64_t i13 = 0; i13 < ne13; i13++) {
         for (int64_t i12 = 0; i12 < ne12; i12++) {
             const int64_t i03 = i13/r3;


@@ -104,7 +104,7 @@
 #define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
 #define cudaStream_t hipStream_t
 #define cudaSuccess hipSuccess
-#define __trap abort
+#define __trap() do { abort(); __builtin_unreachable(); } while(0)
 #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
 #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
 #define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
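Note on the hunk above, with a small host-compilable sketch (the macro and functions here are illustrative, not part of the patch; a GCC/Clang-style compiler is assumed for __builtin_unreachable). Wrapping abort() together with __builtin_unreachable() in a do/while(0) marks the branch as never falling through, so a switch default that only traps needs no dummy return value.

#include <cstdlib>

// Illustrative stand-in for the patched __trap() definition (not the real header):
#define TRAP_LIKE() do { abort(); __builtin_unreachable(); } while (0)

static int lookup_shift(int type) {
    switch (type) {
        case 0: return 0;
        case 1: return 4;
        default: TRAP_LIKE(); // no fall-through, so no bogus return is needed here
    }
}

int main() {
    return lookup_shift(1) == 4 ? 0 : 1;
}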


@@ -70,6 +70,10 @@ struct mma_int_A_I16K8 {
         }
 #endif // defined(INT8_MMA_AVAILABLE)
     }
+
+    __device__ __forceinline__ void load_low(const int * __restrict__ xs0, const int & stride) {
+        ((mma_int_A_I16K4 *) x)[0].load(xs0, stride);
+    }
 };
 
 struct mma_int_B_J8K4 {

(File diff suppressed because it is too large.)


@@ -37,47 +37,92 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
     reinterpret_cast<half&>(y[ib].ds.y) = sum;
 }
 
-template <bool need_sum>
+template <mmq_q8_1_ds_layout ds_layout>
 static __global__ void quantize_mmq_q8_1(
     const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
 
-    const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
+    constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
+    constexpr int vals_per_sum   = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
+
+    const int64_t ix0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
 
     if (ix0 >= kx0_padded) {
         return;
     }
 
+    const float4 * x4 = (const float4 *) x;
+
     const int64_t ix1 = kx1*blockIdx.z + blockIdx.y;
 
     block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
 
-    const int64_t ib0 = blockIdx.z*(gridDim.y*gridDim.x*blockDim.x/(4*QK8_1)); // first block of channel
-    const int64_t ib  = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y;              // block index in channel
-    const int64_t iqs = ix0 % (4*QK8_1);                                       // quant index in block
+    const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
+    const int64_t ib  = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y;                   // block index in channel
+    const int64_t iqs = ix0 % (4*QK8_1);                                            // quant index in block
 
-    const float xi = ix0 < kx0 ? x[ix1*kx0 + ix0] : 0.0f;
-    float amax = fabsf(xi);
-
-    amax = warp_reduce_max(amax);
-
-    float sum;
-    if (need_sum) {
-        sum = warp_reduce_sum(xi);
-    }
-
-    const float d = amax / 127;
-    const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
-
-    y[ib].qs[iqs] = q;
-
-    if (iqs % QK8_1 != 0) {
-        return;
-    }
-
-    if (need_sum) {
-        y[ib].ds[iqs/QK8_1] = make_half2(d, sum);
-    } else {
-        ((float *) y[ib].ds)[iqs/QK8_1] = d;
-    }
+    // Load 4 floats per thread and calculate max. abs. value between them:
+    const float4 xi = ix0 < kx0 ? x4[(ix1*kx0 + ix0)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
+    float amax = fabsf(xi.x);
+    amax = fmaxf(amax, fabsf(xi.y));
+    amax = fmaxf(amax, fabsf(xi.z));
+    amax = fmaxf(amax, fabsf(xi.w));
+
+    // Exchange max. abs. value between vals_per_scale/4 threads.
+#pragma unroll
+    for (int mask = vals_per_scale/8; mask > 0; mask >>= 1) {
+        amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
+    }
+
+    float sum;
+    if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) {
+        sum = xi.x + xi.y + xi.z + xi.w;
+
+        // Exchange calculate sum across vals_per_sum/4 threads.
+#pragma unroll
+        for (int mask = vals_per_sum/8; mask > 0; mask >>= 1) {
+            sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, WARP_SIZE);
+        }
+    }
+
+    const float d_inv = 127.0f / amax;
+    char4 q;
+    q.x = roundf(xi.x*d_inv);
+    q.y = roundf(xi.y*d_inv);
+    q.z = roundf(xi.z*d_inv);
+    q.w = roundf(xi.w*d_inv);
+
+    // Write back 4 int8 values as a single 32 bit value for better memroy bandwidth:
+    char4 * yqs4 = (char4 *) y[ib].qs;
+    yqs4[iqs/4] = q;
+
+    if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6) {
+        if (iqs % 16 != 0 || iqs >= 96) {
+            return;
+        }
+
+        y[ib].d2s6[2 + iqs/16] = sum;
+
+        if (iqs % 64 != 0) {
+            return;
+        }
+
+        const float d = 1.0f / d_inv;
+
+        y[ib].d2s6[iqs/64] = d;
+
+        return;
+    }
+
+    if (iqs % 32 != 0) {
+        return;
+    }
+
+    const float d = 1.0f / d_inv;
+
+    if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) {
+        y[ib].ds4[iqs/32] = make_half2(d, sum);
+    } else {
+        y[ib].d4[iqs/32] = d;
+    }
 }
@@ -101,12 +146,24 @@
     GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
 
-    const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+    const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
     const dim3 num_blocks(block_num_x, kx1, channels);
-    const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
-    if (mmq_need_sum(type_x)) {
-        quantize_mmq_q8_1<true><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
-    } else {
-        quantize_mmq_q8_1<false><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+    const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
+    switch (mmq_get_q8_1_ds_layout(type_x)) {
+        case MMQ_Q8_1_DS_LAYOUT_D4:
+            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+            break;
+        case MMQ_Q8_1_DS_LAYOUT_DS4:
+            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+            break;
+        case MMQ_Q8_1_DS_LAYOUT_D2S6:
+            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+            break;
+        default:
+            GGML_ASSERT(false);
+            break;
     }
 }
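The per-group math behind the rewritten kernel, restated as a plain C++ reference (a sketch, not the CUDA code): each group shares one scale d = amax/127, the kernel keeps the inverse d_inv = 127/amax so the four values a thread handles each cost one multiply instead of a divide, and the four resulting int8 values are written back as a single 32-bit store. The zero-amax guard is added for the host sketch.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Reference quantization of one q8_1 group: q[i] = round(x[i] / d) with d = amax / 127.
static float quantize_group_q8_ref(const float * x, int8_t * q, int n) {
    float amax = 0.0f;
    for (int i = 0; i < n; ++i) {
        amax = std::max(amax, std::fabs(x[i]));
    }
    const float d     = amax / 127.0f;
    const float d_inv = amax == 0.0f ? 0.0f : 127.0f / amax; // guard added for the host sketch
    for (int i = 0; i < n; ++i) {
        q[i] = (int8_t) std::round(x[i] * d_inv);
    }
    return d; // the scale stored alongside the quants
}

int main() {
    const float x[4] = {0.5f, -1.0f, 0.25f, 2.0f};
    int8_t q[4];
    const float d = quantize_group_q8_ref(x, q, 4);
    std::printf("d = %f, q = %d %d %d %d\n", d, q[0], q[1], q[2], q[3]);
    return 0;
}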


@@ -5,7 +5,11 @@
 
 #include <cstdint>
 
 #define CUDA_QUANTIZE_BLOCK_SIZE 256
+#define CUDA_QUANTIZE_BLOCK_SIZE_MMQ 128
+
+static_assert(MATRIX_ROW_PADDING % CUDA_QUANTIZE_BLOCK_SIZE == 0, "Risk of out-of-bounds access.");
+static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access.");
 
 typedef void (*quantize_cuda_t)(
     const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
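Why the second static_assert holds: in the MMQ path each thread now quantizes four values, so one block of CUDA_QUANTIZE_BLOCK_SIZE_MMQ = 128 threads covers 512 values; as long as rows are padded to a multiple of that, no block can run past the end of a row. A host-side restatement of the same check follows (MATRIX_ROW_PADDING = 512 is an assumed value here, not quoted from the patch).

// Host-side restatement of the new compile-time check (values assumed, see note above).
constexpr int MATRIX_ROW_PADDING           = 512;
constexpr int CUDA_QUANTIZE_BLOCK_SIZE_MMQ = 128;
constexpr int VALS_PER_MMQ_QUANTIZE_BLOCK  = 4 * CUDA_QUANTIZE_BLOCK_SIZE_MMQ; // 4 values per thread

static_assert(MATRIX_ROW_PADDING % VALS_PER_MMQ_QUANTIZE_BLOCK == 0, "Risk of out-of-bounds access.");

int main() { return 0; }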


@@ -189,7 +189,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
 }
 
 #define VDR_Q2_K_Q8_1_MMVQ 1
-#define VDR_Q2_K_Q8_1_MMQ  2
+#define VDR_Q2_K_Q8_1_MMQ  4
 
 // contiguous v/x values
 static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
@@ -219,32 +219,56 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
     return dm2f.x*sumf_d - dm2f.y*sumf_m;
 }
 
-// contiguous u/y values
+// contiguous v/x + u/y values
+template <int ns8>
 static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
-    const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8) {
+    const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8, const half2 * s8) {
 
-    float sumf_d = 0.0f;
-    float sumf_m = 0.0f;
+    float sumf    = 0.0f;
+    float sumf_d8 = 0.0f;
 
 #pragma unroll
-    for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) {
-        const float2 dm2f = __half22float2(dm2[i0/(QI8_1/2)]);
-        int sumi_d = 0;
-        int sumi_m = 0;
+    for (int i0 = 0; i0 < QR2_K*VDR_Q2_K_Q8_1_MMQ; i0 += QI8_1) {
+        const float2 dm2f0 = __half22float2(dm2[i0/(QI8_1/2) + 0]);
+        int sumi_d0 = 0;
 
-        const int vi0 = v[i0/(QI8_1/2)];
+        const float2 dm2f1 = __half22float2(dm2[i0/(QI8_1/2) + 1]);
+        int sumi_d1 = 0;
 
 #pragma unroll
         for (int i = i0; i < i0 + QI8_1/2; ++i) {
-            const int vi = (vi0 >> (2*(i % (QI8_1/2)))) & 0x03030303;
-            sumi_d = ggml_cuda_dp4a(vi, u[i], sumi_d); // SIMD dot product
-            sumi_m = ggml_cuda_dp4a(0x01010101, u[i], sumi_m);
+            sumi_d0 = ggml_cuda_dp4a(v[i], u[i], sumi_d0);
         }
+        sumf_d8 += dm2f0.x * sumi_d0;
 
-        sumf_d += dm2f.x * sumi_d;
-        sumf_m += dm2f.y * sumi_m;
+#pragma unroll
+        for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {
+            sumi_d1 = ggml_cuda_dp4a(v[i], u[i], sumi_d1);
+        }
+        sumf_d8 += dm2f1.x * sumi_d1;
+
+        if (i0/QI8_1 < ns8) {
+            const float2 s8f = __half22float2(s8[i0/QI8_1]);
+            sumf -= dm2f0.y*s8f.x;
+            sumf -= dm2f1.y*s8f.y;
+        } else {
+            int sumi_m0 = 0;
+#pragma unroll
+            for (int i = i0; i < i0 + QI8_1/2; ++i) {
+                sumi_m0 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m0);
+            }
+            sumf_d8 -= dm2f0.y * sumi_m0;
+
+            int sumi_m1 = 0;
+#pragma unroll
+            for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {
+                sumi_m1 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m1);
+            }
+            sumf_d8 -= dm2f1.y * sumi_m1;
+        }
     }
 
-    return d8*(sumf_d - sumf_m);
+    return sumf + d8*sumf_d8;
 }
 
 #define VDR_Q3_K_Q8_1_MMVQ 1
@@ -283,7 +307,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
     return d3 * sumf;
 }
 
-// contiguous u/y values
+// contiguous v/x + u/y values
 static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
     const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales,
     const float & d3, const float & d8) {
@@ -296,8 +320,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
 #pragma unroll
         for (int i = i0; i < i0 + QI8_1/2; ++i) {
-            const int vi = __vsubss4((v[i/2] >> (4*(i%2))) & 0x0F0F0F0F, 0x04040404);
-            sumi_sc = ggml_cuda_dp4a(vi, u[i], sumi_sc); // SIMD dot product
+            sumi_sc = ggml_cuda_dp4a(v[i], u[i], sumi_sc); // SIMD dot product
         }
 
         sumi += sumi_sc * scales[i0 / (QI8_1/2)];
@@ -334,7 +357,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
     return dm4f.x*sumf_d - dm4f.y*sumf_m;
 }
 
-// contiguous u/y values
+// contiguous v/x + u/y values
 static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
     const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
     const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
@@ -397,7 +420,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
     return dm5f.x*sumf_d - dm5f.y*sumf_m;
 }
 
-// contiguous u/y values
+// contiguous v/x + u/y values
 static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
     const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
     const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
@@ -451,13 +474,16 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
     return d*sumf;
 }
 
-// contiguous u/y values
+// contiguous v/x + u/y values
 static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
     const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc,
     const float & d6, const float * __restrict__ d8) {
 
     float sumf_d = 0.0f;
 
+    const int      sc_packed = get_int_b4(sc, 0);
+    const int8_t * sc_reg    = (const int8_t *) &sc_packed;
+
 #pragma unroll
     for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
         int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale
@@ -471,7 +497,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
             sumi_d.y = ggml_cuda_dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product
         }
 
-        sumf_d += d8[i0/4] * (sc[i0/2+0]*sumi_d.x + sc[i0/2+1]*sumi_d.y);
+        sumf_d += d8[i0/4] * (sc_reg[i0/2+0]*sumi_d.x + sc_reg[i0/2+1]*sumi_d.y);
     }
 
     return d6 * sumf_d;
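The building block used throughout these vec_dot hunks is ggml_cuda_dp4a, which on the device maps to the dp4a instruction: each 32-bit operand is treated as four signed 8-bit lanes, multiplied lane-wise, and accumulated. A scalar C++ reference of that semantic, for reading the diffs (a sketch, not the device implementation):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Scalar reference for dp4a(a, b, c): four int8 lane products accumulated into c.
static int dp4a_ref(int a, int b, int c) {
    int8_t va[4], vb[4];
    std::memcpy(va, &a, 4);
    std::memcpy(vb, &b, 4);
    for (int i = 0; i < 4; ++i) {
        c += int(va[i]) * int(vb[i]);
    }
    return c;
}

int main() {
    // 0x01010101 has all four lanes equal to 1, so this sums b's four int8 lanes: 2 + (-2) + 5 + 3 = 8.
    std::printf("%d\n", dp4a_ref(0x01010101, 0x02FE0503, 0));
    return 0;
}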


@@ -3768,37 +3768,13 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
         stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
     SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
 
-    const ggml_tensor_extra_gpu *src0_extra =
-        (const ggml_tensor_extra_gpu *)src0->extra;
-    const ggml_tensor_extra_gpu *src1_extra =
-        (const ggml_tensor_extra_gpu *)src1->extra;
-    const ggml_tensor_extra_gpu *dst_extra =
-        (const ggml_tensor_extra_gpu *)dst->extra;
-
-    ggml_tensor_extra_gpu src0_row_extra;
-    ggml_tensor_extra_gpu src1_row_extra;
-    ggml_tensor_extra_gpu dst_row_extra;
-
     ggml_tensor src0_row = *src0;
     ggml_tensor src1_row = *src1;
     ggml_tensor dst_row = *dst;
 
-    src1_row.backend = GGML_BACKEND_TYPE_GPU;
-    dst_row.backend = GGML_BACKEND_TYPE_GPU;
-
-    src0_row.extra = &src0_row_extra;
-    src1_row.extra = &src1_row_extra;
-    dst_row.extra = &dst_row_extra;
-
-    char *src0_original = src1->backend == GGML_BACKEND_TYPE_CPU
-                              ? (char *)src0->data
-                              : (char *)src0_extra->data_device[ctx.device];
-    char *src1_original = src1->backend == GGML_BACKEND_TYPE_CPU
-                              ? (char *)src1->data
-                              : (char *)src1_extra->data_device[ctx.device];
-    char *dst_original = dst->backend == GGML_BACKEND_TYPE_CPU
-                             ? (char *)dst->data
-                             : (char *)dst_extra->data_device[ctx.device];
+    char *src0_original = (char *)src0->data;
+    char *src1_original = (char *)src1->data;
+    char *dst_original = (char *)dst->data;
 
     src0_row.ne[2] = 1;
     src0_row.ne[3] = 1;
@@ -3827,12 +3803,9 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
             const int64_t i1 = id;
             const int64_t i2 = i12;
 
-            src0_row_extra.data_device[ctx.device] =
-                src0_original + i02*nb02;
-            src1_row_extra.data_device[ctx.device] =
-                src1_original + + i11*nb11 + i12*nb12;
-            dst_row_extra.data_device[ctx.device] =
-                dst_original + i1*nb1 + i2*nb2;
+            src0_row.data = src0_original + i02*nb02;
+            src1_row.data = src1_original + + i11*nb11 + i12*nb12;
+            dst_row.data = dst_original + i1*nb1 + i2*nb2;
 
             ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
         }
@@ -3841,8 +3814,8 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
         ggml_sycl_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
         ggml_sycl_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
 
-        src1_row_extra.data_device[ctx.device] = src1_contiguous.get();
-        dst_row_extra.data_device[ctx.device] = dst_contiguous.get();
+        src1_row.data = src1_contiguous.get();
+        dst_row.data = dst_contiguous.get();
 
         for (int64_t i02 = 0; i02 < n_as; i02++) {
             int64_t num_src1_rows = 0;
@@ -3898,7 +3871,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
                 });
             }
 
-            src0_row_extra.data_device[ctx.device] = src0_original + i02*nb02;
+            src0_row.data = src0_original + i02*nb02;
 
             GGML_ASSERT(nb11 == sizeof(float)*ne10);
             GGML_ASSERT(nb1 == sizeof(float)*ne0);
@@ -5221,6 +5194,10 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
                     return false;
                 }
             }
+            ggml_type src0_type = op->src[0]->type;
+            if (src0_type == GGML_TYPE_BF16) {
+                return false;
+            }
             return true;
         } break;
         case GGML_OP_GET_ROWS:


@@ -5883,13 +5883,6 @@ static bool llm_load_tensors(
 
     auto & hparams = model.hparams;
 
-#ifdef GGML_USE_SYCL
-    // disable MoE with SYCL until mul_mat_id is updated
-    if (hparams.n_expert > 0) {
-        n_gpu_layers = 0;
-    }
-#endif
-
     model.split_mode = split_mode;
     model.main_gpu = main_gpu;
     model.n_gpu_layers = n_gpu_layers;
@@ -8134,7 +8127,7 @@ static struct ggml_tensor * llm_build_kqv(
         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
         cb(kq, "kq", il);
 
-        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
+        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
            // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
            // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
            ggml_mul_mat_set_prec(kq, GGML_PREC_F32);