Merge branch 'gg/flash-attn' of https://github.com/ggerganov/llama.cpp into flash-attn-cuda
commit 09db1a7cf3

24 changed files with 1255 additions and 325 deletions
ggml-cuda.cu | 12 +++++++++---
@@ -5131,10 +5131,10 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void *
     const block_q_t * x = (const block_q_t *) vx;
     const block_q8_1 * y = (const block_q8_1 *) vy;
 
-    for (int i = 0; i < blocks_per_row; i += blocks_per_warp) {
-        const int ibx = row*blocks_per_row + i + threadIdx.x / (qi/vdr); // x block index
+    for (int i = threadIdx.x / (qi/vdr); i < blocks_per_row; i += blocks_per_warp) {
+        const int ibx = row*blocks_per_row + i; // x block index
 
-        const int iby = (i + threadIdx.x / (qi/vdr)) * (qk/QK8_1); // y block index that aligns with ibx
+        const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
 
         const int iqs = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int
 
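Review note: the hunk above folds the per-thread offset `threadIdx.x / (qi/vdr)` into the loop variable itself, so the bound `i < blocks_per_row` now also covers the offset, and the simplified `ibx`/`iby` expressions can no longer form a block index past the end of a row. Below is a minimal, self-contained sketch of the same strided access pattern, with hypothetical toy values standing in for the quantization parameters `qi` and `vdr`; it is an illustration of the loop shape, not the llama.cpp kernel.

```cuda
// Toy illustration (hypothetical, not the llama.cpp kernel): one warp sums one
// row made of `blocks_per_row` blocks of `qi` floats each. As in the refactored
// loop above, the per-thread offset threadIdx.x/(qi/vdr) lives in the induction
// variable, so the bound i < blocks_per_row also guards the offset and no
// thread forms a block index past the end of the row.
#include <cstdio>
#include <cuda_runtime.h>

#define WARP_SIZE 32

__global__ void strided_row_sum(const float * x, float * out, const int blocks_per_row) {
    const int qi  = 8; // toy stand-ins for the quantization parameters
    const int vdr = 2;
    const int blocks_per_warp = vdr * WARP_SIZE / qi; // 8 block indices covered per pass

    const int iqs = vdr * (threadIdx.x % (qi/vdr)); // this thread's offset inside a block

    float tmp = 0.0f;
    for (int i = threadIdx.x / (qi/vdr); i < blocks_per_row; i += blocks_per_warp) {
        for (int k = 0; k < vdr; ++k) { // each thread covers vdr consecutive values
            tmp += x[(blockIdx.x*blocks_per_row + i)*qi + iqs + k];
        }
    }
    for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1) { // warp-level reduction
        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, WARP_SIZE);
    }
    if (threadIdx.x == 0) {
        out[blockIdx.x] = tmp;
    }
}

int main() {
    const int rows = 2, blocks_per_row = 5, qi = 8; // 5 is deliberately not a multiple of the stride
    const int n = rows*blocks_per_row*qi;
    float *x, *out;
    cudaMallocManaged(&x,   n*sizeof(float));
    cudaMallocManaged(&out, rows*sizeof(float));
    for (int i = 0; i < n; ++i) x[i] = 1.0f;
    strided_row_sum<<<rows, WARP_SIZE>>>(x, out, blocks_per_row);
    cudaDeviceSynchronize();
    printf("%.1f %.1f\n", out[0], out[1]); // expect 40.0 40.0
    cudaFree(x); cudaFree(out);
    return 0;
}
```

With the pre-refactor formulation (`i` starting at 0 and the offset added inside the body), a thread whose offset lands beyond the last block would compute an out-of-range index at `i = 0`; with the offset in the induction variable it simply skips the iteration.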
@@ -11058,6 +11058,12 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 if (a->ne[3] != b->ne[3]) {
                     return false;
                 }
+                ggml_type a_type = a->type;
+                if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS) {
+                    if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
+                        return false;
+                    }
+                }
                 return true;
             } break;
         case GGML_OP_GET_ROWS:
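Review note: the added check makes the CUDA backend report `GGML_OP_MUL_MAT` as unsupported for the IQ2_XXS/IQ2_XS quant types when `b` is a batch of single-column vectors (`b->ne[1] == 1` while `ggml_nrows(b) > 1`), so the scheduler can route those cases elsewhere. A stripped-down sketch of the same gate follows, using hypothetical toy types in place of the ggml API.

```cuda
// Hypothetical sketch of the capability gate added above (toy types, not the
// ggml API): a batch of single-column vectors would be dispatched down the
// mat-vec path, which is not available here for the IQ2 quant types, so the
// backend declines the op and lets the caller fall back.
#include <cstdint>

enum toy_type { TOY_TYPE_F16, TOY_TYPE_IQ2_XXS, TOY_TYPE_IQ2_XS };

struct toy_tensor {
    toy_type type;
    int64_t  ne[4]; // shape; ne[0] is the innermost dimension
};

// Mirrors ggml_nrows: everything except the innermost dimension counts as a row.
static int64_t toy_nrows(const toy_tensor * t) {
    return t->ne[1]*t->ne[2]*t->ne[3];
}

static bool toy_supports_mul_mat(const toy_tensor * a, const toy_tensor * b) {
    if (a->ne[3] != b->ne[3]) {
        return false; // outermost (batch) dimensions must match
    }
    if (a->type == TOY_TYPE_IQ2_XXS || a->type == TOY_TYPE_IQ2_XS) {
        if (b->ne[1] == 1 && toy_nrows(b) > 1) {
            return false; // batched mat-vec not handled for these quant types
        }
    }
    return true;
}
```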