Merge branch 'master' into concedo_experimental

# Conflicts:
#	.github/workflows/build.yml
#	ggml-opencl.cpp
This commit is contained in:
Concedo 2023-10-06 17:58:08 +08:00
commit b5cd935cdb
5 changed files with 138 additions and 93 deletions

View file

@ -44,9 +44,12 @@ let package = Package(
cSettings: [ cSettings: [
.unsafeFlags(["-Wno-shorten-64-to-32"]), .unsafeFlags(["-Wno-shorten-64-to-32"]),
.define("GGML_USE_K_QUANTS"), .define("GGML_USE_K_QUANTS"),
.define("GGML_USE_ACCELERATE"), .define("GGML_USE_ACCELERATE")
.define("ACCELERATE_NEW_LAPACK"), // NOTE: NEW_LAPACK will required iOS version 16.4+
.define("ACCELERATE_LAPACK_ILP64") // We should consider add this in the future when we drop support for iOS 14
// (ref: ref: https://developer.apple.com/documentation/accelerate/1513264-cblas_sgemm?language=objc)
// .define("ACCELERATE_NEW_LAPACK"),
// .define("ACCELERATE_LAPACK_ILP64")
] + additionalSettings, ] + additionalSettings,
linkerSettings: [ linkerSettings: [
.linkedFramework("Accelerate") .linkedFramework("Accelerate")

View file

@ -361,7 +361,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.lora_adapter.push_back({argv[i], 1.0f}); params.lora_adapter.push_back(std::make_tuple(argv[i], 1.0f));
params.use_mmap = false; params.use_mmap = false;
} else if (arg == "--lora-scaled") { } else if (arg == "--lora-scaled") {
if (++i >= argc) { if (++i >= argc) {
@ -373,7 +373,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.lora_adapter.push_back({lora_adapter, std::stof(argv[i])}); params.lora_adapter.push_back(std::make_tuple(lora_adapter, std::stof(argv[i])));
params.use_mmap = false; params.use_mmap = false;
} else if (arg == "--lora-base") { } else if (arg == "--lora-base") {
if (++i >= argc) { if (++i >= argc) {
@ -616,6 +616,9 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
process_escapes(params.prompt); process_escapes(params.prompt);
process_escapes(params.input_prefix); process_escapes(params.input_prefix);
process_escapes(params.input_suffix); process_escapes(params.input_suffix);
for (auto & antiprompt : params.antiprompt) {
process_escapes(antiprompt);
}
} }
return true; return true;

View file

@ -4,6 +4,7 @@
from __future__ import annotations from __future__ import annotations
import argparse import argparse
import contextlib
import json import json
import os import os
import struct import struct
@ -20,10 +21,10 @@ if 'NO_LOCAL_GGUF' not in os.environ:
import gguf import gguf
def count_model_parts(dir_model: Path) -> int: def count_model_parts(dir_model: Path, prefix: str) -> int:
num_parts = 0 num_parts = 0
for filename in os.listdir(dir_model): for filename in os.listdir(dir_model):
if filename.startswith("pytorch_model-"): if filename.startswith(prefix):
num_parts += 1 num_parts += 1
if num_parts > 0: if num_parts > 0:
@ -77,20 +78,26 @@ print("gguf: loading model "+dir_model.name)
with open(dir_model / "config.json", "r", encoding="utf-8") as f: with open(dir_model / "config.json", "r", encoding="utf-8") as f:
hparams = json.load(f) hparams = json.load(f)
if hparams["architectures"][0] != "RWForCausalLM": if hparams["architectures"][0] != "FalconForCausalLM":
print("Model architecture not supported: " + hparams["architectures"][0]) print("Model architecture not supported: " + hparams["architectures"][0])
sys.exit(1) sys.exit(1)
# get number of model parts # get number of model parts
num_parts = count_model_parts(dir_model) num_parts = count_model_parts(dir_model, "model-00")
if num_parts:
is_safetensors = True
from safetensors import safe_open
else:
is_safetensors = False
num_parts = count_model_parts(dir_model, "pytorch_model-")
ARCH=gguf.MODEL_ARCH.FALCON ARCH=gguf.MODEL_ARCH.FALCON
gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH]) gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])
print("gguf: get model metadata") print("gguf: get model metadata")
block_count = hparams["n_layer"] block_count = hparams["num_hidden_layers"]
gguf_writer.add_name("Falcon") gguf_writer.add_name("Falcon")
gguf_writer.add_context_length(2048) # not in config.json gguf_writer.add_context_length(2048) # not in config.json
@ -98,9 +105,9 @@ gguf_writer.add_tensor_data_layout("jploski") # qkv tensor transform
gguf_writer.add_embedding_length(hparams["hidden_size"]) gguf_writer.add_embedding_length(hparams["hidden_size"])
gguf_writer.add_feed_forward_length(4 * hparams["hidden_size"]) gguf_writer.add_feed_forward_length(4 * hparams["hidden_size"])
gguf_writer.add_block_count(block_count) gguf_writer.add_block_count(block_count)
gguf_writer.add_head_count(hparams["n_head"]) gguf_writer.add_head_count(hparams["num_attention_heads"])
if "n_head_kv" in hparams: if "num_kv_heads" in hparams:
gguf_writer.add_head_count_kv(hparams["n_head_kv"]) gguf_writer.add_head_count_kv(hparams["num_kv_heads"])
else: else:
gguf_writer.add_head_count_kv(1) gguf_writer.add_head_count_kv(1)
gguf_writer.add_layer_norm_eps(hparams["layer_norm_epsilon"]) gguf_writer.add_layer_norm_eps(hparams["layer_norm_epsilon"])
@ -146,8 +153,8 @@ special_vocab.add_to_gguf(gguf_writer)
tensor_map = gguf.get_tensor_name_map(ARCH,block_count) tensor_map = gguf.get_tensor_name_map(ARCH,block_count)
# params for qkv transform # params for qkv transform
n_head = hparams["n_head"] n_head = hparams["num_attention_heads"]
n_head_kv = hparams["n_head_kv"] if "n_head_kv" in hparams else 1 n_head_kv = hparams["num_kv_heads"] if "num_kv_heads" in hparams else 1
head_dim = hparams["hidden_size"] // n_head head_dim = hparams["hidden_size"] // n_head
@ -156,6 +163,10 @@ print("gguf: get tensor metadata")
if num_parts == 0: if num_parts == 0:
part_names = iter(("pytorch_model.bin",)) part_names = iter(("pytorch_model.bin",))
elif is_safetensors:
part_names = (
f"model-{n:05}-of-{num_parts:05}.safetensors" for n in range(1, num_parts + 1)
)
else: else:
part_names = ( part_names = (
f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1) f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1)
@ -165,60 +176,64 @@ for part_name in part_names:
if args.vocab_only: if args.vocab_only:
break break
print("gguf: loading model part '" + part_name + "'") print("gguf: loading model part '" + part_name + "'")
model_part = torch.load(dir_model / part_name, map_location="cpu") if is_safetensors:
ctx = safe_open(dir_model / part_name, framework="pt", device="cpu")
else:
ctx = contextlib.nullcontext(torch.load(dir_model / part_name, map_location="cpu"))
for name in model_part.keys(): with ctx as model_part:
data = model_part[name] for name in model_part.keys():
data = model_part.get_tensor(name) if is_safetensors else model_part[name]
old_dtype = data.dtype old_dtype = data.dtype
# convert any unsupported data types to float32 # convert any unsupported data types to float32
if data.dtype != torch.float16 and data.dtype != torch.float32: if data.dtype != torch.float16 and data.dtype != torch.float32:
data = data.to(torch.float32) data = data.to(torch.float32)
# QKV tensor transform # QKV tensor transform
# The original query_key_value tensor contains n_head_kv "kv groups", # The original query_key_value tensor contains n_head_kv "kv groups",
# each consisting of n_head/n_head_kv query weights followed by one key # each consisting of n_head/n_head_kv query weights followed by one key
# and one value weight (shared by all query heads in the kv group). # and one value weight (shared by all query heads in the kv group).
# This layout makes it a big pain to work with in GGML. # This layout makes it a big pain to work with in GGML.
# So we rearrange them here,, so that we have n_head query weights # So we rearrange them here,, so that we have n_head query weights
# followed by n_head_kv key weights followed by n_head_kv value weights, # followed by n_head_kv key weights followed by n_head_kv value weights,
# in contiguous fashion. # in contiguous fashion.
# ref: https://github.com/jploski/ggml/blob/falcon40b/examples/falcon/convert-hf-to-ggml.py # ref: https://github.com/jploski/ggml/blob/falcon40b/examples/falcon/convert-hf-to-ggml.py
if "query_key_value" in name: if "query_key_value" in name:
qkv = data.view(n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head) qkv = data.view(n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head)
q = qkv[:, :-2 ].reshape(n_head * head_dim, head_dim * n_head) q = qkv[:, :-2 ].reshape(n_head * head_dim, head_dim * n_head)
k = qkv[:, [-2]].reshape(n_head_kv * head_dim, head_dim * n_head) k = qkv[:, [-2]].reshape(n_head_kv * head_dim, head_dim * n_head)
v = qkv[:, [-1]].reshape(n_head_kv * head_dim, head_dim * n_head) v = qkv[:, [-1]].reshape(n_head_kv * head_dim, head_dim * n_head)
data = torch.cat((q,k,v)).reshape_as(data) data = torch.cat((q,k,v)).reshape_as(data)
data = data.squeeze().numpy() data = data.squeeze().numpy()
# map tensor names # map tensor names
new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias")) new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
if new_name is None: if new_name is None:
print("Can not map tensor '" + name + "'") print("Can not map tensor '" + name + "'")
sys.exit() sys.exit()
n_dims = len(data.shape) n_dims = len(data.shape)
data_dtype = data.dtype data_dtype = data.dtype
# if f32 desired, convert any float16 to float32 # if f32 desired, convert any float16 to float32
if ftype == 0 and data_dtype == np.float16: if ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32) data = data.astype(np.float32)
# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32 # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
if ftype == 1 and data_dtype == np.float16 and n_dims == 1: if ftype == 1 and data_dtype == np.float16 and n_dims == 1:
data = data.astype(np.float32) data = data.astype(np.float32)
# if f16 desired, convert any float32 2-dim weight tensors to float16 # if f16 desired, convert any float32 2-dim weight tensors to float16
if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
data = data.astype(np.float16) data = data.astype(np.float16)
print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype)) print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
gguf_writer.add_tensor(new_name, data) gguf_writer.add_tensor(new_name, data)
print("gguf: write header") print("gguf: write header")

View file

@ -504,9 +504,11 @@ struct llama_server_context
}); });
} }
bool tg = true;
while (n_past < embd.size()) while (n_past < embd.size())
{ {
int n_eval = (int)embd.size() - n_past; int n_eval = (int)embd.size() - n_past;
tg = n_eval == 1;
if (n_eval > params.n_batch) if (n_eval > params.n_batch)
{ {
n_eval = params.n_batch; n_eval = params.n_batch;
@ -633,7 +635,9 @@ struct llama_server_context
last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(result.tok); last_n_tokens.push_back(result.tok);
num_tokens_predicted++; if (tg) {
num_tokens_predicted++;
}
} }
// add it to the context // add it to the context
@ -1011,7 +1015,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
invalid_param = true; invalid_param = true;
break; break;
} }
params.lora_adapter.push_back({argv[i], 1.0f}); params.lora_adapter.push_back(std::make_tuple(argv[i], 1.0f));
params.use_mmap = false; params.use_mmap = false;
} }
else if (arg == "--lora-scaled") else if (arg == "--lora-scaled")
@ -1027,7 +1031,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
invalid_param = true; invalid_param = true;
break; break;
} }
params.lora_adapter.push_back({lora_adapter, std::stof(argv[i])}); params.lora_adapter.push_back(std::make_tuple(lora_adapter, std::stof(argv[i])));
params.use_mmap = false; params.use_mmap = false;
} }
else if (arg == "--lora-base") else if (arg == "--lora-base")
@ -1124,8 +1128,6 @@ static json format_timings(llama_server_context &llama)
{ {
const auto timings = llama_get_timings(llama.ctx); const auto timings = llama_get_timings(llama.ctx);
assert(timings.n_eval == ptrdiff_t(llama.num_tokens_predicted));
return json{ return json{
{"prompt_n", timings.n_p_eval}, {"prompt_n", timings.n_p_eval},
{"prompt_ms", timings.t_p_eval_ms}, {"prompt_ms", timings.t_p_eval_ms},

View file

@ -203,14 +203,14 @@ inline void get_scale_min_k4(int j, const __global uint8_t *q, uint8_t *d, uint8
__kernel void dequantize_block_q2_K(__global const struct block_q2_K *x, __global float *yy) __kernel void dequantize_block_q2_K(__global const struct block_q2_K *x, __global float *yy)
{ {
const int i = get_group_id(0); const int i = get_group_id(0) + get_global_offset(0);
const int tid = get_local_id(0); const int tid = get_local_id(0);
const int n = tid / 32; const int n = tid / 32;
const int l = tid - 32 * n; const int l = tid - 32 * n;
const int is = 8 * n + l / 16; const int is = 8 * n + l / 16;
const uint8_t q = x[i].qs[32 * n + l]; const uint8_t q = x[i].qs[32 * n + l];
__global float *y = yy + i * QK_K + 128 * n; __global float *y = yy + get_group_id(0) * QK_K + 128 * n;
const float dall = vload_half(0, &x[i].d); const float dall = vload_half(0, &x[i].d);
const float dmin = vload_half(0, &x[i].dmin); const float dmin = vload_half(0, &x[i].dmin);
@ -224,7 +224,7 @@ __kernel void dequantize_block_q2_K(__global const struct block_q2_K *x, __globa
__kernel void dequantize_block_q3_K(__global const struct block_q3_K *x, __global float *yy) __kernel void dequantize_block_q3_K(__global const struct block_q3_K *x, __global float *yy)
{ {
int r = get_local_id(0) / 4; int r = get_local_id(0) / 4;
int i = get_group_id(0); int i = get_group_id(0) + get_global_offset(0);
int tid = r / 2; int tid = r / 2;
int is0 = r % 2; int is0 = r % 2;
int l0 = 16 * is0 + 4 * (get_local_id(0) % 4); int l0 = 16 * is0 + 4 * (get_local_id(0) % 4);
@ -242,7 +242,7 @@ __kernel void dequantize_block_q3_K(__global const struct block_q3_K *x, __globa
float d_all = vload_half(0, &x[i].d); float d_all = vload_half(0, &x[i].d);
float dl = d_all * (us - 32); float dl = d_all * (us - 32);
__global float *y = yy + i * QK_K + 128 * n + 32 * j; __global float *y = yy + get_group_id(0) * QK_K + 128 * n + 32 * j;
const __global uint8_t *q = x[i].qs + 32 * n; const __global uint8_t *q = x[i].qs + 32 * n;
const __global uint8_t *hm = x[i].hmask; const __global uint8_t *hm = x[i].hmask;
@ -252,14 +252,14 @@ __kernel void dequantize_block_q3_K(__global const struct block_q3_K *x, __globa
__kernel void dequantize_block_q4_K(__global const struct block_q4_K *x, __global float *yy) __kernel void dequantize_block_q4_K(__global const struct block_q4_K *x, __global float *yy)
{ {
const int i = get_group_id(0); const int i = get_group_id(0) + get_global_offset(0);
const int tid = get_local_id(0); const int tid = get_local_id(0);
const int il = tid / 8; const int il = tid / 8;
const int ir = tid % 8; const int ir = tid % 8;
const int is = 2 * il; const int is = 2 * il;
const int n = 4; const int n = 4;
__global float *y = yy + i * QK_K + 64 * il + n * ir; __global float *y = yy + get_group_id(0) * QK_K + 64 * il + n * ir;
const float dall = vload_half(0, &x[i].d); const float dall = vload_half(0, &x[i].d);
const float dmin = vload_half(0, &x[i].dmin); const float dmin = vload_half(0, &x[i].dmin);
@ -282,13 +282,13 @@ __kernel void dequantize_block_q4_K(__global const struct block_q4_K *x, __globa
__kernel void dequantize_block_q5_K(__global const struct block_q5_K *x, __global float *yy) __kernel void dequantize_block_q5_K(__global const struct block_q5_K *x, __global float *yy)
{ {
const int i = get_group_id(0); const int i = get_group_id(0) + get_global_offset(0);
const int tid = get_local_id(0); const int tid = get_local_id(0);
const int il = tid / 16; const int il = tid / 16;
const int ir = tid % 16; const int ir = tid % 16;
const int is = 2 * il; const int is = 2 * il;
__global float *y = yy + i * QK_K + 64 * il + 2 * ir; __global float *y = yy + get_group_id(0) * QK_K + 64 * il + 2 * ir;
const float dall = vload_half(0, &x[i].d); const float dall = vload_half(0, &x[i].d);
const float dmin = vload_half(0, &x[i].dmin); const float dmin = vload_half(0, &x[i].dmin);
@ -314,13 +314,13 @@ __kernel void dequantize_block_q5_K(__global const struct block_q5_K *x, __globa
__kernel void dequantize_block_q6_K(__global const struct block_q6_K *x, __global float *yy) __kernel void dequantize_block_q6_K(__global const struct block_q6_K *x, __global float *yy)
{ {
const int i = get_group_id(0); const int i = get_group_id(0) + get_global_offset(0);
const int tid = get_local_id(0); const int tid = get_local_id(0);
const int ip = tid / 32; const int ip = tid / 32;
const int il = tid - 32 * ip; const int il = tid - 32 * ip;
const int is = 8 * ip + il / 16; const int is = 8 * ip + il / 16;
__global float *y = yy + i * QK_K + 128 * ip + il; __global float *y = yy + get_group_id(0) * QK_K + 128 * ip + il;
const float d = vload_half(0, &x[i].d); const float d = vload_half(0, &x[i].d);
@ -731,7 +731,7 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __global float* y) {
const uint qk = QUANT_K; const uint qk = QUANT_K;
const uint qr = QUANT_R; const uint qr = QUANT_R;
const int ib = i/qk; // block index const int ib = i/qk + get_global_offset(0); // block index
const int iqs = (i%qk)/qr; // quant index const int iqs = (i%qk)/qr; // quant index
const int iybs = i - i%qk; // y block start index const int iybs = i - i%qk; // y block start index
const int y_offset = qr == 1 ? 1 : qk/2; const int y_offset = qr == 1 ? 1 : qk/2;
@ -1357,30 +1357,42 @@ static cl_int ggml_cl_h2d_tensor_2d(cl_command_queue queue, cl_mem dst, size_t o
const enum ggml_type type = src->type; const enum ggml_type type = src->type;
const size_t ts = ggml_type_size(type); const size_t ts = ggml_type_size(type);
const size_t bs = ggml_blck_size(type); const size_t bs = ggml_blck_size(type);
const uint64_t row_size = ts*ne0/bs;
const void * x = (const void *) ((const char *) src->data + i2*nb2 + i3*nb3); const char * x = (const char *) src->data + i2*nb2 + i3*nb3;
if (nb0 == ts && nb1 == ts*ne0/bs) { if (nb0 == ts && nb1 == row_size) {
err = clEnqueueWriteBuffer(queue, dst, CL_FALSE, offset, ne1*nb1, x, 0, NULL, ev); return clEnqueueWriteBuffer(queue, dst, CL_FALSE, offset, ne1*row_size, x, 0, NULL, ev);
return err;
} }
if (nb0 == ts) { if (nb0 == ts) {
const size_t buffer_origin[3] = { offset, 0, 0 }; const size_t buffer_origin[3] = { offset, 0, 0 };
const size_t host_origin[3] = { 0, 0, 0 }; const size_t host_origin[3] = { 0, 0, 0 };
const size_t region[3] = { ts*ne0/bs, ne1, 1 }; const size_t region[3] = { row_size, ne1, 1 };
err = clEnqueueWriteBufferRect(queue, dst, CL_FALSE, buffer_origin, host_origin, region, ts*ne0/bs, 0, nb1, 0, x, 0, NULL, ev); return clEnqueueWriteBufferRect(queue, dst, CL_FALSE, buffer_origin, host_origin, region, row_size, 0, nb1, 0, x, 0, NULL, ev);
return err;
} }
std::vector<cl_event> events;
if (ev && ne1>1) events.reserve(ne1-1);
for (uint64_t i1 = 0; i1 < ne1; i1++) { for (uint64_t i1 = 0; i1 < ne1; i1++) {
// pretend the row is a matrix with cols=1 // pretend the row is a matrix with cols=1
const size_t buffer_origin[3] = { offset, i1, 0 }; const size_t buffer_origin[3] = { offset + i1*row_size, 0, 0 };
const size_t host_origin[3] = { 0, 0, 0 }; const size_t host_origin[3] = { 0, 0, 0 };
const size_t region[3] = { ts/bs, ne0, 1 }; const size_t region[3] = { ts, ne0/bs, 1 };
err = clEnqueueWriteBufferRect(queue, dst, CL_FALSE, buffer_origin, host_origin, region, 0, 0, nb0, 0, ((const char *)x) + i1*nb0, 0, NULL, ev); // if an event is requested, make the last write wait for all previous writes to complete
if (ev && i1) {
events.push_back(*ev);
}
cl_uint nevents = i1 == ne1-1 ? events.size() : 0U;
err = clEnqueueWriteBufferRect(queue, dst, CL_FALSE, buffer_origin, host_origin, region, ts, 0, nb0, 0, x + i1*nb1, nevents, nevents ? events.data() : nullptr, ev);
if (err != CL_SUCCESS) { if (err != CL_SUCCESS) {
break; for (auto event : events) {
clReleaseEvent(event);
}
return err;
} }
} }
return err; for (auto event : events) {
CL_CHECK(clReleaseEvent(event));
}
return CL_SUCCESS;
} }
static void ggml_cl_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { static void ggml_cl_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@ -1511,6 +1523,7 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
cl_mem d_Y = ggml_cl_pool_malloc(sizeof(float) * y_ne, &y_size); cl_mem d_Y = ggml_cl_pool_malloc(sizeof(float) * y_ne, &y_size);
cl_mem d_D = ggml_cl_pool_malloc(sizeof(float) * d_ne, &d_size); cl_mem d_D = ggml_cl_pool_malloc(sizeof(float) * d_ne, &d_size);
size_t x_offset = 0;
int64_t pi02 = -1; int64_t pi02 = -1;
int64_t pi03 = -1; int64_t pi03 = -1;
@ -1521,7 +1534,9 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
int64_t i02 = i12 / r2; int64_t i02 = i12 / r2;
// copy data to device // copy data to device
if (src0->backend != GGML_BACKEND_GPU && (i02 != pi02 || i03 != pi03)) { if (src0->backend == GGML_BACKEND_GPU) {
x_offset = (i03 * ne02 + i02) * x_ne;
} else if (i02 != pi02 || i03 != pi03) {
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL)); CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL));
pi02 = i02; pi02 = i02;
pi03 = i03; pi03 = i03;
@ -1536,7 +1551,7 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
(CLBlastTranspose)clblast::Transpose::kYes, (CLBlastTranspose)clblast::Transpose::kNo, (CLBlastTranspose)clblast::Transpose::kYes, (CLBlastTranspose)clblast::Transpose::kNo,
ne01, ne11, ne10, ne01, ne11, ne10,
alpha, alpha,
d_X, 0, ne00, d_X, x_offset, ne00,
d_Y, 0, ne10, d_Y, 0, ne10,
beta, beta,
d_D, 0, ne01, d_D, 0, ne01,
@ -1605,6 +1620,7 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
bool src1_cont_rows = nb10 == sizeof(float); bool src1_cont_rows = nb10 == sizeof(float);
bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float); bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float);
size_t x_offset = 0;
int64_t pi02 = -1; int64_t pi02 = -1;
int64_t pi03 = -1; int64_t pi03 = -1;
@ -1615,7 +1631,9 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
int64_t i02 = i12 / r2; int64_t i02 = i12 / r2;
// copy src0 to device // copy src0 to device
if (src0->backend != GGML_BACKEND_GPU && (i02 != pi02 || i03 != pi03)) { if (src0->backend == GGML_BACKEND_GPU) {
x_offset = (i03 * ne02 + i02) * x_ne;
} else if (i02 != pi02 || i03 != pi03) {
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL)); CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL));
pi02 = i02; pi02 = i02;
pi03 = i03; pi03 = i03;
@ -1655,7 +1673,7 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
(CLBlastTranspose)clblast::Transpose::kYes, (CLBlastTranspose)clblast::Transpose::kNo, (CLBlastTranspose)clblast::Transpose::kYes, (CLBlastTranspose)clblast::Transpose::kNo,
ne01, ne11, ne10, ne01, ne11, ne10,
alpha, alpha,
d_X, 0, ne00, d_X, x_offset, ne00,
d_Y, 0, ne10, d_Y, 0, ne10,
beta, beta,
d_D, 0, ne01, d_D, 0, ne01,
@ -1706,7 +1724,8 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
const int x_ne = ne01 * ne00; const int x_ne = ne01 * ne00;
const int y_ne = ne11 * ne10; const int y_ne = ne11 * ne10;
const int d_ne = ne11 * ne01; const int d_ne = ne11 * ne01;
const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type); const int x_bps = x_ne / ggml_blck_size(type); // blocks per 2D slice
const size_t q_sz = ggml_type_size(type) * x_bps;
size_t x_size; size_t x_size;
size_t y_size; size_t y_size;
@ -1774,9 +1793,10 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
} else { // general dequantization kernel + CLBlast matrix matrix multiplication } else { // general dequantization kernel + CLBlast matrix matrix multiplication
// convert src0 to fp32 on device // convert src0 to fp32 on device
const size_t global = x_ne / global_denom; const size_t global = x_ne / global_denom;
const size_t offset = src0->backend == GGML_BACKEND_GPU ? (i03 * ne02 + i02) * x_bps : 0;
CL_CHECK(clSetKernelArg(*to_fp32_cl, 0, sizeof(cl_mem), &d_Q)); CL_CHECK(clSetKernelArg(*to_fp32_cl, 0, sizeof(cl_mem), &d_Q));
CL_CHECK(clSetKernelArg(*to_fp32_cl, 1, sizeof(cl_mem), &d_X)); CL_CHECK(clSetKernelArg(*to_fp32_cl, 1, sizeof(cl_mem), &d_X));
CL_CHECK(clEnqueueNDRangeKernel(queue, *to_fp32_cl, 1, NULL, &global, local > 0 ? &local : NULL, events.size(), !events.empty() ? events.data() : NULL, NULL)); CL_CHECK(clEnqueueNDRangeKernel(queue, *to_fp32_cl, 1, offset > 0 ? &offset : NULL, &global, local > 0 ? &local : NULL, events.size(), !events.empty() ? events.data() : NULL, NULL));
// copy src1 to device // copy src1 to device
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, NULL)); CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, NULL));
@ -1908,17 +1928,19 @@ void ggml_cl_transform_tensor(void * data, ggml_tensor * tensor) {
const int64_t ne3 = tensor->ne[3]; const int64_t ne3 = tensor->ne[3];
const ggml_type type = tensor->type; const ggml_type type = tensor->type;
const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type); const size_t s_sz = ggml_type_size(type) * (size_t) (ne0 * ne1 / ggml_blck_size(type));
const size_t q_sz = s_sz * (size_t) (ne2 * ne3);
size_t q_size; size_t q_size;
cl_mem dst = ggml_cl_pool_malloc(q_sz, &q_size); cl_mem dst = ggml_cl_pool_malloc(q_sz, &q_size);
tensor->data = data; tensor->data = data;
// copy tensor to device // copy tensor to device
size_t offset = 0;
for (int64_t i3 = 0; i3 < ne3; i3++) { for (int64_t i3 = 0; i3 < ne3; i3++) {
for (int64_t i2 = 0; i2 < ne2; i2++) { for (int64_t i2 = 0; i2 < ne2; i2++) {
int i = i3*ne2 + i2; CL_CHECK(ggml_cl_h2d_tensor_2d(queue, dst, offset, tensor, i3, i2, NULL));
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, dst, i*ne0*ne1, tensor, i3, i2, NULL)); offset += s_sz;
} }
} }