metal : move mm_id indices to shared mem (#5982)
This commit is contained in:
parent
7ab7b733bb
commit
bb6d00bbf9
2 changed files with 6 additions and 6 deletions
|
@@ -5386,7 +5386,7 @@ template<typename block_q, short nl, void (*dequantize_func)(device const block_
|
|||
void kernel_mul_mm_id_impl(
|
||||
device const uchar * src0,
|
||||
device const uchar * src1,
|
||||
thread short * src1ids,
|
||||
threadgroup short * src1ids,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne02,
|
||||
|
@@ -5589,9 +5589,9 @@ kernel void kernel_mul_mm_id(
|
|||
tgpig.z = tgpig.z%(ne12*ne13);
|
||||
|
||||
// row indices of src1 for expert id
|
||||
int64_t _ne1 = 0;
|
||||
short src1ids[512];
|
||||
threadgroup short * src1ids = (threadgroup short *)(shared_memory + 8192);
|
||||
|
||||
int64_t _ne1 = 0;
|
||||
for (int64_t i1 = 0; i1 < ne1; i1++) {
|
||||
if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
|
||||
src1ids[_ne1++] = i1;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue