k_quants: fixed issue caused by merging with master
This commit is contained in:
parent
53e81ca289
commit
5fd83379ff
1 changed files with 7 additions and 4 deletions
11
ggml-cuda.cu
11
ggml-cuda.cu
|
@ -706,7 +706,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
|
||||||
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (tid == 0) {
|
if (threadIdx.x == 0) {
|
||||||
dst[row] = tmp;
|
dst[row] = tmp;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -823,9 +823,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
|
||||||
const int num_blocks_per_row = ncols / QK_K;
|
const int num_blocks_per_row = ncols / QK_K;
|
||||||
const int ib0 = row*num_blocks_per_row;
|
const int ib0 = row*num_blocks_per_row;
|
||||||
|
|
||||||
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
|
|
||||||
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
|
|
||||||
|
|
||||||
const block_q4_K * x = (const block_q4_K *)vx + ib0;
|
const block_q4_K * x = (const block_q4_K *)vx + ib0;
|
||||||
|
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
|
@ -833,6 +830,9 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
|
||||||
const uint16_t kmask2 = 0x0f0f;
|
const uint16_t kmask2 = 0x0f0f;
|
||||||
const uint16_t kmask3 = 0xc0c0;
|
const uint16_t kmask3 = 0xc0c0;
|
||||||
|
|
||||||
|
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
|
||||||
|
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
|
||||||
|
|
||||||
const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4
|
const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4
|
||||||
|
|
||||||
const int il = tid/step; // 0...3
|
const int il = tid/step; // 0...3
|
||||||
|
@ -878,6 +878,9 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
|
||||||
|
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
|
const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15
|
||||||
|
const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION);
|
||||||
|
|
||||||
const int step = tid * K_QUANTS_PER_ITERATION;
|
const int step = tid * K_QUANTS_PER_ITERATION;
|
||||||
|
|
||||||
uint16_t aux16[2];
|
uint16_t aux16[2];
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue