cuda : fix build

This commit is contained in:
Georgi Gerganov 2024-03-27 10:31:52 +02:00
parent 013721df2b
commit 6be02b5969
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
3 changed files with 44 additions and 16 deletions

View file

@@ -2384,17 +2384,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_argsort(ctx, dst);
break;
case GGML_OP_FLASH_ATTN_EXT:
ggml_cuda_flash_attn_ext(ctx, dst);
break;
default:
return false;
}
if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
ggml_cuda_flash_attn_ext(ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
} else {
func(ctx, tensor->src[0], tensor->src[1], tensor);
}
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
fprintf(stderr, "%s: %s failed\n", __func__, ggml_op_desc(dst));