Merge branch 'ggerganov:master' into fix-#2023
This commit is contained in:
commit
c04a42de5b
4 changed files with 266 additions and 93 deletions
92
examples/make-ggml.py
Normal file
92
examples/make-ggml.py
Normal file
|
@ -0,0 +1,92 @@
|
|||
"""
|
||||
This script converts Hugging Face llama models to GGML and quantizes them.
|
||||
|
||||
Usage:
|
||||
python make-ggml.py --model {model_dir_or_hf_repo_name} [--outname {output_name} (Optional)] [--outdir {output_directory} (Optional)] [--quants {quant_types} (Optional)] [--keep_fp16 (Optional)]
|
||||
|
||||
Arguments:
|
||||
- --model: (Required) The directory of the downloaded Hugging Face model or the name of the Hugging Face model repository. If the model directory does not exist, it will be downloaded from the Hugging Face model hub.
|
||||
- --outname: (Optional) The name of the output model. If not specified, the last part of the model directory path or the Hugging Face model repo name will be used.
|
||||
- --outdir: (Optional) The directory where the output model(s) will be stored. If not specified, '../models/{outname}' will be used.
|
||||
- --quants: (Optional) The types of quantization to apply. This should be a space-separated list. The default is 'Q4_K_M Q5_K_S'.
|
||||
- --keep_fp16: (Optional) If specified, the FP16 model will not be deleted after the quantized models are created.
|
||||
|
||||
Quant types:
|
||||
- Q4_0: small, very high quality loss - legacy, prefer using Q3_K_M
|
||||
- Q4_1: small, substantial quality loss - legacy, prefer using Q3_K_L
|
||||
- Q5_0: medium, balanced quality - legacy, prefer using Q4_K_M
|
||||
- Q5_1: medium, low quality loss - legacy, prefer using Q5_K_M
|
||||
- Q2_K: smallest, extreme quality loss - not recommended
|
||||
- Q3_K: alias for Q3_K_M
|
||||
- Q3_K_S: very small, very high quality loss
|
||||
- Q3_K_M: very small, very high quality loss
|
||||
- Q3_K_L: small, substantial quality loss
|
||||
- Q4_K: alias for Q4_K_M
|
||||
- Q4_K_S: small, significant quality loss
|
||||
- Q4_K_M: medium, balanced quality - recommended
|
||||
- Q5_K: alias for Q5_K_M
|
||||
- Q5_K_S: large, low quality loss - recommended
|
||||
- Q5_K_M: large, very low quality loss - recommended
|
||||
- Q6_K: very large, extremely low quality loss
|
||||
- Q8_0: very large, extremely low quality loss - not recommended
|
||||
- F16: extremely large, virtually no quality loss - not recommended
|
||||
- F32: absolutely huge, lossless - not recommended
|
||||
"""
|
||||
import subprocess
|
||||
subprocess.run(f"pip install huggingface-hub==0.16.4", shell=True, check=True)
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
def main(model, outname, outdir, quants, keep_fp16):
|
||||
ggml_version = "v3"
|
||||
|
||||
if not os.path.isdir(model):
|
||||
print(f"Model not found at {model}. Downloading...")
|
||||
try:
|
||||
if outname is None:
|
||||
outname = model.split('/')[-1]
|
||||
model = snapshot_download(repo_id=model, cache_dir='../models/hf_cache')
|
||||
except Exception as e:
|
||||
raise Exception(f"Could not download the model: {e}")
|
||||
|
||||
if outdir is None:
|
||||
outdir = f'../models/{outname}'
|
||||
|
||||
if not os.path.isfile(f"{model}/config.json"):
|
||||
raise Exception(f"Could not find config.json in {model}")
|
||||
|
||||
os.makedirs(outdir, exist_ok=True)
|
||||
|
||||
print("Building llama.cpp")
|
||||
subprocess.run(f"cd .. && make quantize", shell=True, check=True)
|
||||
|
||||
fp16 = f"{outdir}/{outname}.ggml{ggml_version}.fp16.bin"
|
||||
|
||||
print(f"Making unquantised GGML at {fp16}")
|
||||
if not os.path.isfile(fp16):
|
||||
subprocess.run(f"python3 ../convert.py {model} --outtype f16 --outfile {fp16}", shell=True, check=True)
|
||||
else:
|
||||
print(f"Unquantised GGML already exists at: {fp16}")
|
||||
|
||||
print("Making quants")
|
||||
for type in quants:
|
||||
outfile = f"{outdir}/{outname}.ggml{ggml_version}.{type}.bin"
|
||||
print(f"Making {type} : {outfile}")
|
||||
subprocess.run(f"../quantize {fp16} {outfile} {type}", shell=True, check=True)
|
||||
|
||||
if not keep_fp16:
|
||||
os.remove(fp16)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='Convert/Quantize HF to GGML. If you have the HF model downloaded already, pass the path to the model dir. Otherwise, pass the Hugging Face model repo name. You need to be in the /examples folder for it to work.')
|
||||
parser.add_argument('--model', required=True, help='Downloaded model dir or Hugging Face model repo name')
|
||||
parser.add_argument('--outname', default=None, help='Output model(s) name')
|
||||
parser.add_argument('--outdir', default=None, help='Output directory')
|
||||
parser.add_argument('--quants', nargs='*', default=["Q4_K_M", "Q5_K_S"], help='Quant types')
|
||||
parser.add_argument('--keep_fp16', action='store_true', help='Keep fp16 model', default=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args.model, args.outname, args.outdir, args.quants, args.keep_fp16)
|
60
ggml-cuda.cu
60
ggml-cuda.cu
|
@ -2423,20 +2423,53 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
|
|||
scoped_spin_lock lock(g_cuda_pool_lock);
|
||||
int id;
|
||||
CUDA_CHECK(cudaGetDevice(&id));
|
||||
|
||||
#ifdef DEBUG_CUDA_MALLOC
|
||||
int nnz = 0;
|
||||
size_t max_size = 0, tot_size = 0;
|
||||
#endif
|
||||
size_t best_diff = 1ull << 36;
|
||||
int ibest = -1;
|
||||
for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
|
||||
cuda_buffer& b = g_cuda_buffer_pool[id][i];
|
||||
if (b.size >= size && b.ptr != nullptr) {
|
||||
void * ptr = b.ptr;
|
||||
*actual_size = b.size;
|
||||
b.ptr = nullptr;
|
||||
b.size = 0;
|
||||
return ptr;
|
||||
if (b.ptr != nullptr) {
|
||||
#ifdef DEBUG_CUDA_MALLOC
|
||||
++nnz;
|
||||
tot_size += b.size;
|
||||
if (b.size > max_size) max_size = b.size;
|
||||
#endif
|
||||
if (b.size >= size) {
|
||||
size_t diff = b.size - size;
|
||||
if (diff < best_diff) {
|
||||
best_diff = diff;
|
||||
ibest = i;
|
||||
if (!best_diff) {
|
||||
void * ptr = b.ptr;
|
||||
*actual_size = b.size;
|
||||
b.ptr = nullptr;
|
||||
b.size = 0;
|
||||
return ptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (ibest >= 0) {
|
||||
cuda_buffer& b = g_cuda_buffer_pool[id][ibest];
|
||||
void * ptr = b.ptr;
|
||||
*actual_size = b.size;
|
||||
b.ptr = nullptr;
|
||||
b.size = 0;
|
||||
return ptr;
|
||||
}
|
||||
#ifdef DEBUG_CUDA_MALLOC
|
||||
fprintf(stderr, "%s: %d buffers, max_size = %u MB, tot_size = %u MB, requested %u MB\n", __func__, nnz,
|
||||
(uint32_t)(max_size/1024/1024), (uint32_t)(tot_size/1024/1024), (uint32_t)(size/1024/1024));
|
||||
#endif
|
||||
void * ptr;
|
||||
CUDA_CHECK(cudaMalloc((void **) &ptr, size));
|
||||
*actual_size = size;
|
||||
size_t look_ahead_size = (size_t) (1.05 * size);
|
||||
look_ahead_size = 256 * ((look_ahead_size + 255)/256);
|
||||
CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size));
|
||||
*actual_size = look_ahead_size;
|
||||
return ptr;
|
||||
}
|
||||
|
||||
|
@ -2955,8 +2988,13 @@ inline void ggml_cuda_op_rope(
|
|||
const int mode = ((int32_t *) src1->data)[2];
|
||||
const int n_ctx = ((int32_t *) src1->data)[3];
|
||||
|
||||
const float theta_scale = powf(10000.0, -2.0f/n_dims);
|
||||
const float p = ((mode & 1) == 0 ? n_past + i02 : i02);
|
||||
// RoPE alteration for extended context
|
||||
float freq_base, freq_scale;
|
||||
memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
|
||||
memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
|
||||
|
||||
const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
||||
const float p = (((mode & 1) == 0 ? n_past + i02 : i02)) * freq_scale;
|
||||
|
||||
bool is_glm = mode & 4;
|
||||
|
||||
|
|
15
ggml-metal.m
15
ggml-metal.m
|
@ -685,8 +685,8 @@ void ggml_metal_graph_compute(
|
|||
GGML_ASSERT(ne02 == 1);
|
||||
GGML_ASSERT(ne12 == 1);
|
||||
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
nth0 = 2;
|
||||
nth1 = 32;
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
|
||||
} break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
|
@ -743,15 +743,18 @@ void ggml_metal_graph_compute(
|
|||
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src0t == GGML_TYPE_Q3_K) {
|
||||
#ifdef GGML_QKK_64
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
#else
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
#endif
|
||||
}
|
||||
else if (src0t == GGML_TYPE_Q5_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src0t == GGML_TYPE_Q6_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src0t == GGML_TYPE_Q3_K) {
|
||||
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
} else {
|
||||
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
|
|
192
ggml-metal.metal
192
ggml-metal.metal
|
@ -351,7 +351,7 @@ kernel void kernel_rms_norm(
|
|||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
// broadcast, simd group number is ntg / 32
|
||||
for (int i = ntg / 32 / 2; i > 0; i /= 2) {
|
||||
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
|
||||
if (tpitg < i) {
|
||||
sum[tpitg] += sum[tpitg + i];
|
||||
}
|
||||
|
@ -1339,6 +1339,7 @@ kernel void kernel_mul_mat_q2_K_f32(
|
|||
}
|
||||
}
|
||||
|
||||
#if QK_K == 256
|
||||
kernel void kernel_mul_mat_q3_K_f32(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
|
@ -1347,40 +1348,41 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|||
constant int64_t & ne10,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
threadgroup float * sum [[threadgroup(0)]],
|
||||
uint2 tgpig[[threadgroup_position_in_grid]],
|
||||
uint2 tpitg[[thread_position_in_threadgroup]],
|
||||
uint2 tptg[[threads_per_threadgroup]]) {
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
const int nb = ne00/QK_K;
|
||||
|
||||
const int64_t r0 = tgpig.x;
|
||||
const int64_t r1 = tgpig.y;
|
||||
|
||||
device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb;
|
||||
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
||||
|
||||
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb;
|
||||
device const float * yy = (device const float *) src1 + r1*ne10;
|
||||
|
||||
const int nth = tptg.x*tptg.y;
|
||||
const int ith = tptg.y*tpitg.x + tpitg.y;
|
||||
|
||||
#if QK_K == 256
|
||||
|
||||
const uint8_t m3 = 3;
|
||||
const int8_t m4 = 4;
|
||||
float yl[16];
|
||||
|
||||
const uint16_t kmask1 = 0x0303;
|
||||
const uint16_t kmask2 = 0x0f0f;
|
||||
|
||||
const int tid = tpitg.y; // expecting 16
|
||||
const int tid = tiisg/2;
|
||||
const int ix = tiisg%2;
|
||||
const int ip = tid/8; // 0 or 1
|
||||
const int il = tid/2 - 4*ip; // 0...3
|
||||
const int ir = tid%2;
|
||||
const int n = 8;
|
||||
const int l0 = n*ir;
|
||||
|
||||
const uint8_t m = 1 << (4*ip + il);
|
||||
const uint16_t m1 = 1 << (4*ip + il);
|
||||
const uint16_t m2 = m1 << 8;
|
||||
|
||||
const int shift = 2*il;
|
||||
const uint16_t qm1 = 0x0003 << shift;
|
||||
const uint16_t qm2 = 0x0300 << shift;
|
||||
const int32_t v1 = 4 << shift;
|
||||
const int32_t v2 = 1024 << shift;
|
||||
|
||||
const uint16_t s_shift1 = 4*ip;
|
||||
const uint16_t s_shift2 = s_shift1 + 2*(il/2);
|
||||
|
@ -1389,93 +1391,132 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|||
const int q_offset = 32*ip + l0;
|
||||
const int y_offset = 128*ip + 32*il + l0;
|
||||
|
||||
//float sumf = 0;
|
||||
float sumf1 = 0, sumf2 = 0;
|
||||
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
||||
const int step = sizeof(block_q3_K) * nb / 2;
|
||||
|
||||
const float d_all = (float)(x[i].d);
|
||||
device const float * y1 = yy + ix*QK_K + y_offset;
|
||||
|
||||
device const uint8_t * q = x[i].qs + q_offset;
|
||||
device const uint8_t * h = x[i].hmask + l0;
|
||||
device const float * y = yy + i * QK_K + y_offset;
|
||||
float sumf1[2] = {0.f}, sumf2[2] = {0.f};
|
||||
for (int i = ix; i < nb; i += 2) {
|
||||
|
||||
device const uint16_t * a = (device const uint16_t *)x[i].scales;
|
||||
const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
|
||||
|
||||
float s = 0;
|
||||
for (int l = 0; l < n; ++l) {
|
||||
s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & m3) - ((h[l+ 0] & m) ? 0 : m4));
|
||||
for (int l = 0; l < 8; ++l) {
|
||||
yl[l+0] = y1[l+ 0];
|
||||
yl[l+8] = y1[l+16];
|
||||
}
|
||||
float d = d_all * s;
|
||||
sumf1 += d * scales[0];
|
||||
sumf2 += d;
|
||||
//sumf += d_all * s * (scales[0] - 32);
|
||||
|
||||
s = 0;
|
||||
for (int l = 0; l < n; ++l) {
|
||||
s += y[l+16] * ((int8_t)((q[l+16] >> shift) & m3) - ((h[l+16] & m) ? 0 : m4));
|
||||
device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
|
||||
device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
|
||||
device const uint16_t * a = (device const uint16_t *)(x[i].scales);
|
||||
device const half * dh = &x[i].d;
|
||||
|
||||
for (int row = 0; row < 2; ++row) {
|
||||
|
||||
const float d_all = (float)dh[0];
|
||||
const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
|
||||
|
||||
float s1 = 0, s2 = 0;
|
||||
for (int l = 0; l < n; l += 2) {
|
||||
const uint16_t qs = q[l/2];
|
||||
s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
|
||||
s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
|
||||
}
|
||||
float d = d_all * (s1 + 1.f/256.f * s2);
|
||||
sumf1[row] += d * scales[0];
|
||||
sumf2[row] += d;
|
||||
|
||||
s1 = s2 = 0;
|
||||
for (int l = 0; l < n; l += 2) {
|
||||
const uint16_t qs = q[l/2+8];
|
||||
s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
|
||||
s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
|
||||
}
|
||||
d = d_all * (s1 + 1.f/256.f * s2);
|
||||
sumf1[row] += d * scales[1];
|
||||
sumf2[row] += d;
|
||||
|
||||
q += step;
|
||||
h += step;
|
||||
a += step;
|
||||
dh += step;
|
||||
|
||||
}
|
||||
d = d_all * s;
|
||||
sumf1 += d * scales[1];
|
||||
sumf2 += d;
|
||||
//sumf += d_all * s * (scales[1] - 32);
|
||||
|
||||
y1 += 2 * QK_K;
|
||||
|
||||
}
|
||||
|
||||
//sum[ith] = sumf;
|
||||
sum[ith] = sumf1 - 32.f*sumf2;
|
||||
for (int row = 0; row < 2; ++row) {
|
||||
const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
|
||||
const float tot = simd_sum(sumf);
|
||||
if (tiisg == 0) {
|
||||
dst[r1*ne0 + first_row + row] = tot;
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
const int il = 4 * tpitg.x; // 0, 4, 8, 12
|
||||
kernel void kernel_mul_mat_q3_K_f32(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
uint2 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
const int nb = ne00/QK_K;
|
||||
|
||||
const int64_t r0 = tgpig.x;
|
||||
const int64_t r1 = tgpig.y;
|
||||
|
||||
const int row = 2 * r0 + sgitg;
|
||||
|
||||
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb;
|
||||
device const float * yy = (device const float *) src1 + r1*ne10;
|
||||
const int ix = tiisg/4;
|
||||
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
|
||||
const int im = il/8; // 0, 0, 1, 1
|
||||
const int in = il%8; // 0, 4, 0, 4
|
||||
|
||||
float sumf = 0;
|
||||
float2 sum = {0.f, 0.f};
|
||||
|
||||
for (int i = tpitg.y; i < nb; i += tptg.y) {
|
||||
for (int i = ix; i < nb; i += 8) {
|
||||
|
||||
const float d_all = (float)(x[i].d);
|
||||
|
||||
device const uint8_t * q = x[i].qs + il;
|
||||
device const uint8_t * h = x[i].hmask + in;
|
||||
device const float * y = yy + i * QK_K + il;
|
||||
device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
|
||||
device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
|
||||
device const uint16_t * s = (device const uint16_t *)(x[i].scales);
|
||||
device const float * y = yy + i * QK_K + il;
|
||||
|
||||
const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
|
||||
const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
|
||||
const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
|
||||
const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
|
||||
const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
|
||||
const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
|
||||
const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
|
||||
const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
|
||||
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
const uint8_t hm = h[l] >> im;
|
||||
sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4))
|
||||
+ y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4))
|
||||
+ y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4))
|
||||
+ y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4));
|
||||
for (int l = 0; l < 4; l += 2) {
|
||||
const uint16_t hm = h[l/2] >> im;
|
||||
sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
|
||||
+ y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
|
||||
+ y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
|
||||
+ y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
|
||||
sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
|
||||
+ y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
|
||||
+ y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
|
||||
+ y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
|
||||
}
|
||||
|
||||
}
|
||||
const float sumf = sum[0] + sum[1] * 1.f/256.f;
|
||||
|
||||
sum[ith] = sumf;
|
||||
|
||||
#endif
|
||||
|
||||
//
|
||||
// Accumulate the sum from all threads in the threadgroup
|
||||
//
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (ith%4 == 0) {
|
||||
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (ith%16 == 0) {
|
||||
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (ith == 0) {
|
||||
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
|
||||
dst[r1*ne0 + r0] = sum[0];
|
||||
const float tot = simd_sum(sumf);
|
||||
if (tiisg == 0) {
|
||||
dst[r1*ne0 + row] = tot;
|
||||
}
|
||||
|
||||
}
|
||||
#endif
|
||||
|
||||
#if QK_K == 256
|
||||
kernel void kernel_mul_mat_q4_K_f32(
|
||||
|
@ -1773,7 +1814,6 @@ kernel void kernel_mul_mat_q5_K_f32(
|
|||
|
||||
for (int i = ix; i < nb; i += 8) {
|
||||
|
||||
float4 sumy = {0.f, 0.f, 0.f, 0.f};
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
yl[l+0] = y[l+ 0];
|
||||
yl[l+4] = y[l+16];
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue