llama : avoid ggml_cast, use F32 query
commit f9ca5dcbe8
parent 40ea8cd1ac
6 changed files with 44 additions and 17 deletions
@@ -2177,7 +2177,7 @@ static bool ggml_metal_graph_compute(
             case GGML_OP_FLASH_ATTN_EXT:
                 {
                     GGML_ASSERT(ne00 % 4 == 0);
-                    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+                    GGML_ASSERT(src0->type == GGML_TYPE_F32);

                     struct ggml_tensor * src2 = gf->nodes[i]->src[2];
                     struct ggml_tensor * src3 = gf->nodes[i]->src[3];
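The commit title refers to the llama.cpp side of this change: the query tensor is no longer cast to F16 before the flash-attention call, which is why the Metal backend now asserts an F32 src0. Below is a hedged sketch of what that looks like at graph-build time; the helper name, tensor names, and the exact ggml_flash_attn_ext argument list are assumptions based on the flash-attn branch at the time, not copied from this commit's other files.

#include "ggml.h"

// Sketch only, not the verbatim llama.cpp patch: before this commit the query
// was converted to F16 with ggml_cast before ggml_flash_attn_ext; afterwards
// the F32 query is passed directly, matching the new backend assertion
// (src0->type == GGML_TYPE_F32).
static struct ggml_tensor * build_flash_attn(
        struct ggml_context * ctx,
        struct ggml_tensor  * q,        // F32 query
        struct ggml_tensor  * k,        // key cache view
        struct ggml_tensor  * v,        // value cache view
        struct ggml_tensor  * kq_mask,  // attention mask
        float                 kq_scale) {
    // before: struct ggml_tensor * q16 = ggml_cast(ctx, q, GGML_TYPE_F16);
    //         return ggml_flash_attn_ext(ctx, q16, k, v, kq_mask, kq_scale);

    // after: no ggml_cast node in the graph, Q stays F32
    return ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale);
}

Dropping the cast removes an extra node (and an extra F16 copy of Q) from the compute graph, at the cost of requiring the backend kernel to accept an F32 query.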
@@ -2254,7 +2254,7 @@ static bool ggml_metal_graph_compute(
                     [encoder setBytes:&scale length:sizeof( float) atIndex:27];

                     // for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
-                    const int64_t nsg = ne01 < 4 ? 4 : 2; // simdgroups per threadgroup (a.k.a. warps)
+                    const int64_t nsg = ne01 < 4 ? 12 : 2; // simdgroups per threadgroup (a.k.a. warps)

                     const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
                     const int64_t ncpsg = 32; // cache values per simdgroup (does not work for other values)
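For context, nsg controls how many simdgroups each Metal threadgroup gets for the flash-attention kernel. A minimal C sketch of the relationship, assuming the usual 32-thread simdgroup width on Apple GPUs (the helper name is illustrative, not part of the patch):

#include <stdint.h>

// Illustrative helper (not in the patch): total threads per threadgroup for
// the flash-attention kernel, given the batch-size heuristic in the hunk above.
static int64_t flash_attn_threads_per_threadgroup(int64_t ne01) {
    // small batches (ne01 < 4) now get 12 simdgroups instead of 4;
    // larger batches keep 2 simdgroups per threadgroup
    const int64_t nsg = ne01 < 4 ? 12 : 2;

    // each simdgroup is 32 threads wide on Apple GPUs
    return 32*nsg;
}

Under the new value, a small batch (ne01 < 4) runs with 12*32 = 384 threads per threadgroup instead of 128; as the in-code comment notes, this is a heuristic that still needs more testing.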