wip
This commit is contained in:
parent
06c2d0d117
commit
035c4f01e6
3 changed files with 55 additions and 52 deletions
|
@ -2183,6 +2183,7 @@ static bool ggml_metal_graph_compute(
|
|||
struct ggml_tensor * src3 = gf->nodes[i]->src[3];
|
||||
|
||||
GGML_ASSERT(ggml_are_same_shape(src1, src2));
|
||||
GGML_ASSERT(src3);
|
||||
|
||||
size_t offs_src2 = 0;
|
||||
size_t offs_src3 = 0;
|
||||
|
@ -2252,11 +2253,11 @@ static bool ggml_metal_graph_compute(
|
|||
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
|
||||
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
|
||||
|
||||
const int64_t nwarps = 8;
|
||||
const int64_t nwarps = 4;
|
||||
const int64_t nhptg = 2; // heads per threadgroup !! sync with kernel template arguments !!
|
||||
const int64_t nqptg = 4; // queries per threadgroup !! sync with kernel template arguments !!
|
||||
|
||||
const size_t smem = nqptg*(nhptg*ne00 + nwarps*(nhptg*ne00 + 32))*(sizeof(float)/2);
|
||||
const size_t smem = nqptg*(nhptg*ne00 + nwarps*(nhptg*ne00 + 256))*(sizeof(float)/2);
|
||||
|
||||
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
|
||||
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue