cuda : fix build
This commit is contained in:
parent
013721df2b
commit
6be02b5969
3 changed files with 44 additions and 16 deletions
|
@ -2384,17 +2384,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|||
ggml_cuda_op_argsort(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
ggml_cuda_flash_attn_ext(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
|
||||
ggml_cuda_flash_attn_ext(ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
|
||||
} else {
|
||||
func(ctx, tensor->src[0], tensor->src[1], tensor);
|
||||
}
|
||||
|
||||
cudaError_t err = cudaGetLastError();
|
||||
if (err != cudaSuccess) {
|
||||
fprintf(stderr, "%s: %s failed\n", __func__, ggml_op_desc(dst));
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue