llama : per-layer KV cache + quantum K cache (#4309)
* per-layer KV * remove unnecessary copies * less code duplication, offload k and v separately * llama : offload KV cache per-layer * llama : offload K shift tensors * llama : offload for rest of the model arches * llama : enable offload debug temporarily * llama : keep the KV related layers on the device * llama : remove mirrors, perform Device -> Host when partial offload * common : add command-line arg to disable KV cache offloading * llama : update session save/load * llama : support quantum K cache (#4312) * llama : support quantum K cache (wip) * metal : add F32 -> Q8_0 copy kernel * cuda : add F32 -> Q8_0 copy kernel ggml-ci * cuda : use mmv kernel for quantum cache ops * llama : pass KV cache type through API * llama : fix build ggml-ci * metal : add F32 -> Q4_0 copy kernel * metal : add F32 -> Q4_1 copy kernel * cuda : wip * cuda : add F32 -> Q4_0 and F32 -> Q4_1 copy kernels * llama-bench : support type_k/type_v * metal : use mm kernel only for quantum KV cache * cuda : add comment * llama : remove memory_f16 and kv_f16 flags --------- Co-authored-by: slaren <slarengh@gmail.com> * readme : add API change notice --------- Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
parent
81bc9214a3
commit
bcc0eb4591
11 changed files with 747 additions and 287 deletions
32
ggml-metal.m
32
ggml-metal.m
|
@ -118,6 +118,11 @@ struct ggml_metal_context {
|
|||
GGML_METAL_DECL_KERNEL(im2col_f16);
|
||||
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
||||
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
|
||||
GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
|
||||
GGML_METAL_DECL_KERNEL(cpy_f32_q4_0);
|
||||
GGML_METAL_DECL_KERNEL(cpy_f32_q4_1);
|
||||
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
|
||||
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
|
||||
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
|
||||
GGML_METAL_DECL_KERNEL(concat);
|
||||
GGML_METAL_DECL_KERNEL(sqr);
|
||||
|
@ -324,6 +329,11 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||
GGML_METAL_ADD_KERNEL(im2col_f16);
|
||||
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
|
||||
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
|
||||
GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
|
||||
GGML_METAL_ADD_KERNEL(cpy_f32_q4_0);
|
||||
GGML_METAL_ADD_KERNEL(cpy_f32_q4_1);
|
||||
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
|
||||
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
|
||||
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
|
||||
GGML_METAL_ADD_KERNEL(concat);
|
||||
GGML_METAL_ADD_KERNEL(sqr);
|
||||
|
@ -425,6 +435,11 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
|
|||
GGML_METAL_DEL_KERNEL(im2col_f16);
|
||||
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
|
||||
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
|
||||
GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
|
||||
GGML_METAL_DEL_KERNEL(cpy_f32_q4_0);
|
||||
GGML_METAL_DEL_KERNEL(cpy_f32_q4_1);
|
||||
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
|
||||
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
|
||||
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
|
||||
GGML_METAL_DEL_KERNEL(concat);
|
||||
GGML_METAL_DEL_KERNEL(sqr);
|
||||
|
@ -1114,7 +1129,7 @@ void ggml_metal_graph_compute(
|
|||
!ggml_is_transposed(src1) &&
|
||||
src1t == GGML_TYPE_F32 &&
|
||||
ne00 % 32 == 0 && ne00 >= 64 &&
|
||||
ne11 > ne11_mm_min) {
|
||||
(ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
|
||||
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
|
||||
|
@ -1549,14 +1564,23 @@ void ggml_metal_graph_compute(
|
|||
case GGML_OP_CPY:
|
||||
case GGML_OP_CONT:
|
||||
{
|
||||
const int nth = MIN(1024, ne00);
|
||||
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
|
||||
|
||||
int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
|
||||
|
||||
switch (src0t) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
|
||||
|
||||
switch (dstt) {
|
||||
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
|
||||
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
|
||||
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
|
||||
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
|
||||
case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break;
|
||||
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break;
|
||||
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break;
|
||||
//case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break;
|
||||
//case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break;
|
||||
default: GGML_ASSERT(false && "not implemented");
|
||||
};
|
||||
} break;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue