Merge branch 'master' into compilade/batch-splits

This commit is contained in:
Francis Couture-Harpin 2024-08-21 04:17:29 -04:00
commit 80d9d2a551
58 changed files with 2499 additions and 657 deletions

View file

@ -1018,10 +1018,6 @@ static bool ggml_is_view_op(enum ggml_op op) {
#define GGML_SCHED_MAX_BACKENDS 16
#endif
#ifndef GGML_SCHED_MAX_SPLITS
#define GGML_SCHED_MAX_SPLITS 2048
#endif
#ifndef GGML_SCHED_MAX_SPLIT_INPUTS
#define GGML_SCHED_MAX_SPLIT_INPUTS GGML_MAX_SRC
#endif
@ -1125,7 +1121,8 @@ static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, co
}
#if 0
static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS][128]; // debug only
#define GGML_SCHED_MAX_SPLITS_DEBUG 4096
static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_SCHED_MAX_SPLITS_DEBUG*GGML_SCHED_MAX_SPLIT_INPUTS][128]; // debug only
#define SET_CAUSE(node, ...) sprintf(causes[hash_id(node)], __VA_ARGS__)
#define GET_CAUSE(node) causes[hash_id(node)]
#else
@ -1549,7 +1546,6 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
sched->splits = realloc(sched->splits, sched->splits_capacity * sizeof(struct ggml_backend_sched_split));
GGML_ASSERT(sched->splits != NULL);
}
GGML_ASSERT(i_split < GGML_SCHED_MAX_SPLITS);
split = &sched->splits[i_split];
split->backend_id = node_backend_id;
split->i_start = i;
@ -1865,13 +1861,14 @@ ggml_backend_sched_t ggml_backend_sched_new(
sched->hv_tensor_backend_ids = malloc(sched->hash_set.size * sizeof(sched->hv_tensor_backend_ids[0]));
sched->hv_tensor_copies = malloc(sched->hash_set.size * sched->n_backends * sched->n_copies * sizeof(struct ggml_tensor *));
const size_t nodes_size = graph_size + GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2;
const size_t ggml_sched_max_splits = graph_size; // at most there is one split for each node in the graph
const size_t nodes_size = graph_size + ggml_sched_max_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2;
sched->node_backend_ids = calloc(nodes_size, sizeof(sched->node_backend_ids[0]));
sched->leaf_backend_ids = calloc(nodes_size, sizeof(sched->leaf_backend_ids[0]));
sched->prev_node_backend_ids = calloc(nodes_size, sizeof(sched->prev_node_backend_ids[0]));
sched->prev_leaf_backend_ids = calloc(nodes_size, sizeof(sched->prev_leaf_backend_ids[0]));
sched->context_buffer_size = GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + ggml_graph_overhead_custom(graph_size, false);
sched->context_buffer_size = ggml_sched_max_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + ggml_graph_overhead_custom(graph_size, false);
sched->context_buffer = malloc(sched->context_buffer_size);
const int initial_splits_capacity = 16;

View file

@ -82,17 +82,18 @@ static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of
// RPC commands
enum rpc_cmd {
ALLOC_BUFFER = 0,
GET_ALIGNMENT,
GET_MAX_SIZE,
BUFFER_GET_BASE,
FREE_BUFFER,
BUFFER_CLEAR,
SET_TENSOR,
GET_TENSOR,
COPY_TENSOR,
GRAPH_COMPUTE,
GET_DEVICE_MEMORY,
RPC_CMD_ALLOC_BUFFER = 0,
RPC_CMD_GET_ALIGNMENT,
RPC_CMD_GET_MAX_SIZE,
RPC_CMD_BUFFER_GET_BASE,
RPC_CMD_FREE_BUFFER,
RPC_CMD_BUFFER_CLEAR,
RPC_CMD_SET_TENSOR,
RPC_CMD_GET_TENSOR,
RPC_CMD_COPY_TENSOR,
RPC_CMD_GRAPH_COMPUTE,
RPC_CMD_GET_DEVICE_MEMORY,
RPC_CMD_COUNT,
};
// RPC data structures
@ -330,7 +331,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t
uint64_t remote_ptr = ctx->remote_ptr;
memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
std::vector<uint8_t> output;
bool status = send_rpc_cmd(ctx->sock, FREE_BUFFER, input, output);
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.empty());
delete ctx;
@ -346,7 +347,7 @@ GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t b
uint64_t remote_ptr = ctx->remote_ptr;
memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
std::vector<uint8_t> output;
bool status = send_rpc_cmd(ctx->sock, BUFFER_GET_BASE, input, output);
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == sizeof(uint64_t));
// output serialization format: | base_ptr (8 bytes) |
@ -405,7 +406,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t b
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
std::vector<uint8_t> output;
bool status = send_rpc_cmd(ctx->sock, SET_TENSOR, input, output);
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input, output);
GGML_ASSERT(status);
}
@ -419,7 +420,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t b
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size));
std::vector<uint8_t> output;
bool status = send_rpc_cmd(ctx->sock, GET_TENSOR, input, output);
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == size);
// output serialization format: | data (size bytes) |
@ -444,7 +445,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t b
memcpy(input.data(), &rpc_src, sizeof(rpc_src));
memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst));
std::vector<uint8_t> output;
bool status = send_rpc_cmd(ctx->sock, COPY_TENSOR, input, output);
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, input, output);
GGML_ASSERT(status);
// output serialization format: | result (1 byte) |
GGML_ASSERT(output.size() == 1);
@ -459,7 +460,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer
memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr));
memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value));
std::vector<uint8_t> output;
bool status = send_rpc_cmd(ctx->sock, BUFFER_CLEAR, input, output);
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, input, output);
GGML_ASSERT(status);
}
@ -488,7 +489,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
memcpy(input.data(), &size, sizeof(size));
std::vector<uint8_t> output;
auto sock = get_socket(buft_ctx->endpoint);
bool status = send_rpc_cmd(sock, ALLOC_BUFFER, input, output);
bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
@ -511,7 +512,7 @@ static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
// input serialization format: | 0 bytes |
std::vector<uint8_t> input;
std::vector<uint8_t> output;
bool status = send_rpc_cmd(sock, GET_ALIGNMENT, input, output);
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == sizeof(uint64_t));
// output serialization format: | alignment (8 bytes) |
@ -529,7 +530,7 @@ static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
// input serialization format: | 0 bytes |
std::vector<uint8_t> input;
std::vector<uint8_t> output;
bool status = send_rpc_cmd(sock, GET_MAX_SIZE, input, output);
bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == sizeof(uint64_t));
// output serialization format: | max_size (8 bytes) |
@ -622,7 +623,7 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t
serialize_graph(cgraph, input);
std::vector<uint8_t> output;
auto sock = get_socket(rpc_ctx->endpoint);
bool status = send_rpc_cmd(sock, GRAPH_COMPUTE, input, output);
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == 1);
return (enum ggml_status)output[0];
@ -636,7 +637,7 @@ GGML_CALL static bool ggml_backend_rpc_supports_op(ggml_backend_t backend, const
}
GGML_CALL static bool ggml_backend_rpc_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
if (buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
return false;
}
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
@ -678,6 +679,7 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const
}
auto sock = get_socket(endpoint);
if (sock == nullptr) {
fprintf(stderr, "Failed to connect to %s\n", endpoint);
return nullptr;
}
size_t alignment = get_alignment(sock);
@ -719,7 +721,7 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f
// input serialization format: | 0 bytes |
std::vector<uint8_t> input;
std::vector<uint8_t> output;
bool status = send_rpc_cmd(sock, GET_DEVICE_MEMORY, input, output);
bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
// output serialization format: | free (8 bytes) | total (8 bytes) |
@ -1098,59 +1100,69 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
if (!recv_data(sockfd, &cmd, 1)) {
break;
}
if (cmd >= RPC_CMD_COUNT) {
// fail fast if the command is invalid
fprintf(stderr, "Unknown command: %d\n", cmd);
break;
}
std::vector<uint8_t> input;
std::vector<uint8_t> output;
uint64_t input_size;
if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
break;
}
input.resize(input_size);
try {
input.resize(input_size);
} catch (const std::bad_alloc & e) {
fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", input_size);
break;
}
if (!recv_data(sockfd, input.data(), input_size)) {
break;
}
bool ok = true;
switch (cmd) {
case ALLOC_BUFFER: {
case RPC_CMD_ALLOC_BUFFER: {
ok = server.alloc_buffer(input, output);
break;
}
case GET_ALIGNMENT: {
case RPC_CMD_GET_ALIGNMENT: {
server.get_alignment(output);
break;
}
case GET_MAX_SIZE: {
case RPC_CMD_GET_MAX_SIZE: {
server.get_max_size(output);
break;
}
case BUFFER_GET_BASE: {
case RPC_CMD_BUFFER_GET_BASE: {
ok = server.buffer_get_base(input, output);
break;
}
case FREE_BUFFER: {
case RPC_CMD_FREE_BUFFER: {
ok = server.free_buffer(input);
break;
}
case BUFFER_CLEAR: {
case RPC_CMD_BUFFER_CLEAR: {
ok = server.buffer_clear(input);
break;
}
case SET_TENSOR: {
case RPC_CMD_SET_TENSOR: {
ok = server.set_tensor(input);
break;
}
case GET_TENSOR: {
case RPC_CMD_GET_TENSOR: {
ok = server.get_tensor(input, output);
break;
}
case COPY_TENSOR: {
case RPC_CMD_COPY_TENSOR: {
ok = server.copy_tensor(input, output);
break;
}
case GRAPH_COMPUTE: {
case RPC_CMD_GRAPH_COMPUTE: {
ok = server.graph_compute(input, output);
break;
}
case GET_DEVICE_MEMORY: {
case RPC_CMD_GET_DEVICE_MEMORY: {
// output serialization format: | free (8 bytes) | total (8 bytes) |
output.resize(2*sizeof(uint64_t), 0);
memcpy(output.data(), &free_mem, sizeof(free_mem));
@ -1203,8 +1215,10 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free
return;
}
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
fflush(stdout);
rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
printf("Client connection closed\n");
fflush(stdout);
}
#ifdef _WIN32
WSACleanup();

View file

@ -893,43 +893,6 @@ static void clamp_f32(const float * x, float * dst, const float min, const float
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
}
template <typename T>
static void im2col_kernel(const float *x, T *dst, int offset_delta,
int IW, int IH, int OW, int KW, int KH,
int pelements, int CHW, int s0, int s1, int p0,
int p1, int d0, int d1,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_id(2) +
item_ct1.get_group(2) * item_ct1.get_local_range(2);
if (i >= pelements) {
return;
}
const int ksize = OW * (KH > 1 ? KW : 1);
const int kx = i / ksize;
const int kd = kx * ksize;
const int ky = (i - kd) / OW;
const int ix = i % OW;
const int64_t iiw = ix * s0 + kx * d0 - p0;
const int64_t iih = item_ct1.get_group(1) * s1 + ky * d1 - p1;
const int64_t offset_dst =
(item_ct1.get_group(1) * OW + ix) * CHW +
(item_ct1.get_group(0) * (KW * KH) + ky * KW + kx);
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] =
sycl::vec<float, 1>(0.0f)
.convert<sycl::half, sycl::rounding_mode::automatic>()[0];
} else {
const int64_t offset_src = item_ct1.get_group(0) * offset_delta;
dst[offset_dst] =
sycl::vec<float, 1>(x[offset_src + iih * IW + iiw])
.convert<sycl::half, sycl::rounding_mode::automatic>()[0];
}
}
template <typename Ti, typename To>
static void pool2d_nchw_kernel(
const int ih, const int iw, const int oh, const int ow,
@ -1742,32 +1705,6 @@ static void diag_mask_inf_f32_sycl(const float *x, float *dst,
});
}
template <typename T>
static void im2col_sycl(const float *x, T *dst, int IW, int IH,
int OW, int OH, int KW, int KH, int IC,
int offset_delta, int s0, int s1, int p0,
int p1, int d0, int d1,
queue_ptr stream) {
const int parallel_elements = OW * KW * KH;
const int num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE;
sycl::range<3> block_nums(IC, OH, num_blocks);
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(
sycl::nd_range<3>(block_nums *
sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
im2col_kernel(x, dst, offset_delta, IW, IH, OW, KW, KH,
parallel_elements, (IC * KH * KW), s0, s1, p0,
p1, d0, d1, item_ct1);
});
}
}
static bool g_sycl_loaded = false;
bool ggml_sycl_loaded(void) {
@ -2636,47 +2573,6 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tens
(void) src1_dd;
}
inline void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
float *dst_dd,
const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
const int64_t IC = src1->ne[is_2D ? 2 : 1];
const int64_t IH = is_2D ? src1->ne[1] : 1;
const int64_t IW = src1->ne[0];
const int64_t KH = is_2D ? src0->ne[1] : 1;
const int64_t KW = src0->ne[0];
const int64_t OH = is_2D ? dst->ne[2] : 1;
const int64_t OW = dst->ne[1];
const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
if (dst->type == GGML_TYPE_F16) {
im2col_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
} else {
im2col_sycl(src1_dd, (float *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
}
(void) src0;
(void) src0_dd;
}
inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
const ggml_tensor *src1, ggml_tensor *dst,
const float *src0_dd, const float *src1_dd,
@ -3581,7 +3477,8 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE
&& (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda || src1->ne[1] > MMVQ_MIN_BATCH_SIZE);
bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;

View file

@ -25,5 +25,6 @@
#include "norm.hpp"
#include "softmax.hpp"
#include "tsembd.hpp"
#include "im2col.hpp"
#endif // GGML_SYCL_BACKEND_HPP

View file

@ -51,3 +51,14 @@ void ggml_sycl_host_free(void* ptr) try {
<< ", line:" << __LINE__ << std::endl;
std::exit(1);
}
int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size) {
const int64_t max_range = std::numeric_limits<int>::max();
int64_t sycl_down_blk_size = block_size;
int64_t global_range = accumulate_block_num * sycl_down_blk_size;
while(global_range > max_range) {
sycl_down_blk_size /= 2;
global_range = accumulate_block_num * sycl_down_blk_size;
}
return sycl_down_blk_size;
}

View file

@ -130,6 +130,7 @@ typedef sycl::float2 dfloat2;
#endif // GGML_SYCL_F16
#define MMVQ_MAX_BATCH_SIZE 8
#define MMVQ_MIN_BATCH_SIZE 4
static const int8_t kvalues_iq4nl[16]={-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
@ -352,4 +353,6 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor<Tp, dim> acc) {
return acc.template get_multi_ptr<sycl::access::decorated::no>().get();
}
int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size);
#endif // GGML_SYCL_COMMON_HPP

View file

@ -3,19 +3,19 @@
#include "presets.hpp"
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k,
static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
const sycl::nd_item<3> &item_ct1) {
const int i = 2 * (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
const int64_t i = 2 * (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2));
if (i >= k) {
return;
}
const int ib = i/qk; // block index
const int iqs = (i%qk)/qr; // quant index
const int iybs = i - i%qk; // y block start index
const int y_offset = qr == 1 ? 1 : qk/2;
const int64_t ib = i/qk; // block index
const int64_t iqs = (i%qk)/qr; // quant index
const int64_t iybs = i - i%qk; // y block start index
const int64_t y_offset = qr == 1 ? 1 : qk/2;
// dequantize
dfloat2 v;
@ -27,9 +27,9 @@ static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_sycl(const void *__restrict__ vx,
dst_t *__restrict__ y, const int k,
dst_t *__restrict__ y, const int64_t k,
dpct::queue_ptr stream) {
const int num_blocks = (k + 2*SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2*SYCL_DEQUANTIZE_BLOCK_SIZE);
const int64_t num_blocks = (k + 2*SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2*SYCL_DEQUANTIZE_BLOCK_SIZE);
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
@ -45,9 +45,9 @@ static void dequantize_block_sycl(const void *__restrict__ vx,
}
template <typename dst_t>
static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int k,
static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
const int nb = k / QK_K;
const int64_t nb = k / QK_K;
#if QK_K == 256
{
dpct::has_capability_or_fail(stream->get_device(),
@ -77,9 +77,9 @@ static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int k,
}
template <typename dst_t>
static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k,
static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
const int nb = k / QK_K;
const int64_t nb = k / QK_K;
#if QK_K == 256
{
dpct::has_capability_or_fail(stream->get_device(),
@ -108,10 +108,10 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k,
}
template <typename dst_t>
static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k,
static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
const int nb32 = k / 32;
const int nb = (k + 255) / 256;
const int64_t nb32 = k / 32;
const int64_t nb = (k + 255) / 256;
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
@ -126,10 +126,10 @@ static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k,
}
template <typename dst_t>
static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int k,
static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
const int nb32 = k / 32;
const int nb = (k + 255) / 256;
const int64_t nb32 = k / 32;
const int64_t nb = (k + 255) / 256;
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
@ -145,9 +145,9 @@ static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int k,
template <typename dst_t>
static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k,
static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
const int nb = k / QK_K;
const int64_t nb = k / QK_K;
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
@ -165,9 +165,9 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k,
}
template <typename dst_t>
static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int k,
static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
const int nb = k / QK_K;
const int64_t nb = k / QK_K;
#if QK_K == 256
{
dpct::has_capability_or_fail(stream->get_device(),
@ -197,9 +197,9 @@ static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int k,
}
template <typename dst_t>
static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k,
static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
const int nb = k / QK_K;
const int64_t nb = k / QK_K;
#if QK_K == 256
{
dpct::has_capability_or_fail(stream->get_device(),
@ -229,9 +229,9 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k,
}
template <typename dst_t>
static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
const int nb = k / QK_K;
const int64_t nb = k / QK_K;
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
@ -250,9 +250,9 @@ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
}
template <typename dst_t>
static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int k,
static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
const int nb = k / QK_K;
const int64_t nb = k / QK_K;
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
@ -271,9 +271,9 @@ static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int k,
}
template <typename dst_t>
static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int k,
static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
const int nb = k / QK_K;
const int64_t nb = k / QK_K;
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
@ -292,9 +292,9 @@ static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int k,
}
template <typename dst_t>
static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int k,
static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
const int nb = k / QK_K;
const int64_t nb = k / QK_K;
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
@ -313,9 +313,9 @@ static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int k,
}
template <typename dst_t>
static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int k,
static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
const int nb = k / QK_K;
const int64_t nb = k / QK_K;
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
@ -333,9 +333,9 @@ static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int k,
template <typename dst_t>
static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k,
static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
const int nb = k / QK_K;
const int64_t nb = k / QK_K;
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
@ -354,9 +354,9 @@ static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k,
}
template <typename dst_t>
static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k,
static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
const int nb = k / QK_K;
const int64_t nb = k / QK_K;
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
@ -374,9 +374,9 @@ static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k,
}
template <typename dst_t>
static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int k,
static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
const int nb = (k + QK_K - 1) / QK_K;
const int64_t nb = (k + QK_K - 1) / QK_K;
#if QK_K == 64
dequantize_row_iq4_nl_sycl(vx, y, k, stream);
#else
@ -398,9 +398,9 @@ static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int k,
}
template <typename dst_t>
static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int k,
static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
const int nb = (k + QK_K - 1) / QK_K;
const int64_t nb = (k + QK_K - 1) / QK_K;
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
@ -418,34 +418,34 @@ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int k,
}
template <typename src_t, typename dst_t>
static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int k,
static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
if (i >= k) {
return;
}
const int64_t work_group_size = item_ct1.get_local_range(2);
const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);
// make each work-item deal with more elements since sycl global range can not exceed max int
const src_t * x = (src_t *) vx;
y[i] = x[i];
for (int64_t i = global_id; i < k; i += work_group_size * item_ct1.get_group_range(2)) {
y[i] = x[i];
}
}
template <typename src_t, typename dst_t>
static void convert_unary_sycl(const void *__restrict__ vx,
dst_t *__restrict__ y, const int k,
dst_t *__restrict__ y, const int64_t k,
dpct::queue_ptr stream) {
const int num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE;
const int64_t num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE;
// decrease global range when it exceeds the max int
int64_t local_size = downsample_sycl_global_range(num_blocks, SYCL_DEQUANTIZE_BLOCK_SIZE);
sycl::range<3> block_nums(1, 1, num_blocks);
sycl::range<3> local_range(1, 1, local_size);
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(
sycl::nd_range<3>(
sycl::range<3>(1, 1, num_blocks) *
sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
sycl::nd_range<3>(block_nums * local_range, local_range),
[=](sycl::nd_item<3> item_ct1) {
convert_unary<src_t>(vx, y, k, item_ct1);
});

View file

@ -17,7 +17,7 @@
template <typename T>
using to_t_sycl_t = void (*)(const void *__restrict__ x, T *__restrict__ y,
int k, dpct::queue_ptr stream);
int64_t k, dpct::queue_ptr stream);
typedef to_t_sycl_t<float> to_fp32_sycl_t;
typedef to_t_sycl_t<sycl::half> to_fp16_sycl_t;

View file

@ -15,9 +15,9 @@
#include "common.hpp"
typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
static __dpct_inline__ void dequantize_q4_0(const void *vx, const int ib,
static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib,
const int iqs, dfloat2 &v) {
const block_q4_0 * x = (const block_q4_0 *) vx;
@ -40,7 +40,7 @@ static __dpct_inline__ void dequantize_q4_0(const void *vx, const int ib,
#endif // GGML_SYCL_F16
}
static __dpct_inline__ void dequantize_q4_1(const void *vx, const int ib,
static __dpct_inline__ void dequantize_q4_1(const void *vx, const int64_t ib,
const int iqs, dfloat2 &v) {
const block_q4_1 * x = (const block_q4_1 *) vx;
@ -64,7 +64,7 @@ static __dpct_inline__ void dequantize_q4_1(const void *vx, const int ib,
#endif // GGML_SYCL_F16
}
static __dpct_inline__ void dequantize_q5_0(const void *vx, const int ib,
static __dpct_inline__ void dequantize_q5_0(const void *vx, const int64_t ib,
const int iqs, dfloat2 &v) {
const block_q5_0 * x = (const block_q5_0 *) vx;
@ -91,7 +91,7 @@ static __dpct_inline__ void dequantize_q5_0(const void *vx, const int ib,
#endif // GGML_SYCL_F16
}
static __dpct_inline__ void dequantize_q5_1(const void *vx, const int ib,
static __dpct_inline__ void dequantize_q5_1(const void *vx, const int64_t ib,
const int iqs, dfloat2 &v) {
const block_q5_1 * x = (const block_q5_1 *) vx;
@ -118,7 +118,7 @@ static __dpct_inline__ void dequantize_q5_1(const void *vx, const int ib,
#endif // GGML_SYCL_F16
}
static __dpct_inline__ void dequantize_q8_0(const void *vx, const int ib,
static __dpct_inline__ void dequantize_q8_0(const void *vx, const int64_t ib,
const int iqs, dfloat2 &v) {
const block_q8_0 * x = (const block_q8_0 *) vx;
@ -138,16 +138,16 @@ static __dpct_inline__ void dequantize_q8_0(const void *vx, const int ib,
}
template<typename dst_t>
static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32,
static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_group(2);
const int64_t i = item_ct1.get_group(2);
// assume 32 threads
const int tid = item_ct1.get_local_id(2);
const int il = tid/8;
const int ir = tid%8;
const int ib = 8*i + ir;
const int64_t tid = item_ct1.get_local_id(2);
const int64_t il = tid/8;
const int64_t ir = tid%8;
const int64_t ib = 8*i + ir;
if (ib >= nb32) {
return;
}
@ -168,16 +168,16 @@ static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restri
}
template<typename dst_t>
static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32,
static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_group(2);
const int64_t i = item_ct1.get_group(2);
// assume 32 threads
const int tid = item_ct1.get_local_id(2);
const int il = tid/8;
const int ir = tid%8;
const int ib = 8*i + ir;
const int64_t tid = item_ct1.get_local_id(2);
const int64_t il = tid/8;
const int64_t ir = tid%8;
const int64_t ib = 8*i + ir;
if (ib >= nb32) {
return;
}
@ -203,14 +203,14 @@ template<typename dst_t>
static void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_group(2);
const int64_t i = item_ct1.get_group(2);
const block_q2_K * x = (const block_q2_K *) vx;
const int tid = item_ct1.get_local_id(2);
const int64_t tid = item_ct1.get_local_id(2);
#if QK_K == 256
const int n = tid/32;
const int l = tid - 32*n;
const int is = 8*n + l/16;
const int64_t n = tid/32;
const int64_t l = tid - 32*n;
const int64_t is = 8*n + l/16;
const uint8_t q = x[i].qs[32*n + l];
dst_t * y = yy + i*QK_K + 128*n;
@ -222,8 +222,8 @@ static void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restri
y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
#else
const int is = tid/16; // 0 or 1
const int il = tid%16; // 0...15
const int64_t is = tid/16; // 0 or 1
const int64_t il = tid%16; // 0...15
const uint8_t q = x[i].qs[il] >> (2*is);
dst_t * y = yy + i*QK_K + 16*is + il;
@ -239,19 +239,19 @@ template<typename dst_t>
static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_group(2);
const int64_t i = item_ct1.get_group(2);
const block_q3_K * x = (const block_q3_K *) vx;
#if QK_K == 256
const int r = item_ct1.get_local_id(2) / 4;
const int tid = r/2;
const int is0 = r%2;
const int l0 = 16 * is0 + 4 * (item_ct1.get_local_id(2) % 4);
const int n = tid / 4;
const int j = tid - 4*n;
const int64_t r = item_ct1.get_local_id(2) / 4;
const int64_t tid = r/2;
const int64_t is0 = r%2;
const int64_t l0 = 16 * is0 + 4 * (item_ct1.get_local_id(2) % 4);
const int64_t n = tid / 4;
const int64_t j = tid - 4*n;
uint8_t m = 1 << (4*n + j);
int is = 8*n + 2*j + is0;
int64_t is = 8*n + 2*j + is0;
int shift = 2*j;
int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
@ -267,11 +267,11 @@ static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restri
for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
#else
const int tid = item_ct1.get_local_id(2);
const int is = tid/16; // 0 or 1
const int il = tid%16; // 0...15
const int im = il/8; // 0...1
const int in = il%8; // 0...7
const int64_t tid = item_ct1.get_local_id(2);
const int64_t is = tid/16; // 0 or 1
const int64_t il = tid%16; // 0...15
const int64_t im = il/8; // 0...1
const int64_t in = il%8; // 0...7
dst_t * y = yy + i*QK_K + 16*is + il;
@ -307,15 +307,15 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) {
const block_q4_K * x = (const block_q4_K *) vx;
const int i = item_ct1.get_group(2);
const int64_t i = item_ct1.get_group(2);
#if QK_K == 256
// assume 32 threads
const int tid = item_ct1.get_local_id(2);
const int il = tid/8;
const int ir = tid%8;
const int is = 2*il;
const int n = 4;
const int64_t tid = item_ct1.get_local_id(2);
const int64_t il = tid/8;
const int64_t ir = tid%8;
const int64_t is = 2*il;
const int64_t n = 4;
dst_t * y = yy + i*QK_K + 64*il + n*ir;
@ -341,7 +341,7 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri
y[l +32] = d2 * (q_vec[l] >> 4) - m2;
}
#else
const int tid = item_ct1.get_local_id(2);
const int64_t tid = item_ct1.get_local_id(2);
const uint8_t * q = x[i].qs;
dst_t * y = yy + i*QK_K;
const float d = (float)x[i].dm[0];
@ -356,14 +356,14 @@ static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restri
const sycl::nd_item<3> &item_ct1) {
const block_q5_K * x = (const block_q5_K *) vx;
const int i = item_ct1.get_group(2);
const int64_t i = item_ct1.get_group(2);
#if QK_K == 256
// assume 64 threads - this is very slightly better than the one below
const int tid = item_ct1.get_local_id(2);
const int il = tid/16; // il is in 0...3
const int ir = tid%16; // ir is in 0...15
const int is = 2*il; // is is in 0...6
const int64_t tid = item_ct1.get_local_id(2);
const int64_t il = tid/16; // il is in 0...3
const int64_t ir = tid%16; // ir is in 0...15
const int64_t is = 2*il; // is is in 0...6
dst_t * y = yy + i*QK_K + 64*il + 2*ir;
@ -386,11 +386,11 @@ static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restri
y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
#else
const int tid = item_ct1.get_local_id(2);
const int64_t tid = item_ct1.get_local_id(2);
const uint8_t q = x[i].qs[tid];
const int im = tid/8; // 0...3
const int in = tid%8; // 0...7
const int is = tid/16; // 0 or 1
const int64_t im = tid/8; // 0...3
const int64_t in = tid%8; // 0...7
const int64_t is = tid/16; // 0 or 1
const uint8_t h = x[i].qh[in] >> im;
const float d = x[i].d;
dst_t * y = yy + i*QK_K + tid;
@ -404,14 +404,14 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri
const sycl::nd_item<3> &item_ct1) {
const block_q6_K * x = (const block_q6_K *) vx;
const int i = item_ct1.get_group(2);
const int64_t i = item_ct1.get_group(2);
#if QK_K == 256
// assume 64 threads - this is very slightly better than the one below
const int tid = item_ct1.get_local_id(2);
const int ip = tid/32; // ip is 0 or 1
const int il = tid - 32*ip; // 0...32
const int is = 8*ip + il/16;
const int64_t tid = item_ct1.get_local_id(2);
const int64_t ip = tid/32; // ip is 0 or 1
const int64_t il = tid - 32*ip; // 0...32
const int64_t is = 8*ip + il/16;
dst_t * y = yy + i*QK_K + 128*ip + il;
@ -428,9 +428,9 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri
#else
// assume 32 threads
const int tid = item_ct1.get_local_id(2);
const int ip = tid/16; // 0 or 1
const int il = tid - 16*ip; // 0...15
const int64_t tid = item_ct1.get_local_id(2);
const int64_t ip = tid/16; // 0 or 1
const int64_t il = tid - 16*ip; // 0...15
dst_t * y = yy + i*QK_K + 16*ip + il;
@ -452,13 +452,13 @@ static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __res
const uint8_t *ksigns_iq2xs_ptr,
const uint8_t *kmask_iq2xs_ptr) {
const int i = item_ct1.get_group(2);
const int64_t i = item_ct1.get_group(2);
const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
const int tid = item_ct1.get_local_id(2);
const int64_t tid = item_ct1.get_local_id(2);
#if QK_K == 256
const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7
const int64_t il = tid/8; // 0...3
const int64_t ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const uint16_t * q2 = x[i].qs + 4*ib;
const uint8_t * aux8 = (const uint8_t *)q2;
@ -480,13 +480,13 @@ static void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __rest
const uint8_t *ksigns_iq2xs,
const uint8_t *kmask_iq2xs) {
const int i = item_ct1.get_group(2);
const int64_t i = item_ct1.get_group(2);
const block_iq2_xs * x = (const block_iq2_xs *) vx;
const int tid = item_ct1.get_local_id(2);
const int64_t tid = item_ct1.get_local_id(2);
#if QK_K == 256
const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7
const int64_t il = tid/8; // 0...3
const int64_t ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const uint16_t * q2 = x[i].qs + 4*ib;
const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
@ -504,13 +504,13 @@ __dpct_inline__ static void
dequantize_block_iq2_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_group(2);
const int64_t i = item_ct1.get_group(2);
const block_iq2_s * x = (const block_iq2_s *) vx;
const int tid = item_ct1.get_local_id(2);
const int64_t tid = item_ct1.get_local_id(2);
#if QK_K == 256
const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7
const int64_t il = tid/8; // 0...3
const int64_t ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
@ -532,13 +532,13 @@ static void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __res
const uint8_t *ksigns_iq2xs,
const uint8_t *kmask_iq2xs) {
const int i = item_ct1.get_group(2);
const int64_t i = item_ct1.get_group(2);
const block_iq3_xxs * x = (const block_iq3_xxs *) vx;
const int tid = item_ct1.get_local_id(2);
const int64_t tid = item_ct1.get_local_id(2);
#if QK_K == 256
const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7
const int64_t il = tid/8; // 0...3
const int64_t ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const uint8_t * q3 = x[i].qs + 8*ib;
const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
@ -563,13 +563,13 @@ dequantize_block_iq3_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
const sycl::nd_item<3> &item_ct1,
const uint8_t *kmask_iq2xs, const uint32_t *iq3s_grid) {
const int i = item_ct1.get_group(2);
const int64_t i = item_ct1.get_group(2);
const block_iq3_s * x = (const block_iq3_s *) vx;
const int tid = item_ct1.get_local_id(2);
const int64_t tid = item_ct1.get_local_id(2);
#if QK_K == 256
const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7
const int64_t il = tid/8; // 0...3
const int64_t ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const uint8_t * qs = x[i].qs + 8*ib;
const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
@ -593,13 +593,13 @@ dequantize_block_iq1_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
const sycl::nd_item<3> &item_ct1,
const uint32_t *iq1s_grid_gpu) {
const int i = item_ct1.get_group(2);
const int64_t i = item_ct1.get_group(2);
const block_iq1_s * x = (const block_iq1_s *) vx;
const int tid = item_ct1.get_local_id(2);
const int64_t tid = item_ct1.get_local_id(2);
#if QK_K == 256
const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7
const int64_t il = tid/8; // 0...3
const int64_t ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
@ -623,13 +623,13 @@ dequantize_block_iq1_m(const void *__restrict__ vx, dst_t *__restrict__ yy,
const sycl::nd_item<3> &item_ct1,
const uint32_t *iq1s_grid_gpu) {
const int i = item_ct1.get_group(2);
const int64_t i = item_ct1.get_group(2);
const block_iq1_m * x = (const block_iq1_m *) vx;
const int tid = item_ct1.get_local_id(2);
const int64_t tid = item_ct1.get_local_id(2);
#if QK_K == 256
const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7
const int64_t il = tid/8; // 0...3
const int64_t ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const uint16_t * sc = (const uint16_t *)x[i].scales;
iq1m_scale_t scale;
@ -656,12 +656,12 @@ __dpct_inline__ static void
dequantize_block_iq4_nl(const void *__restrict__ vx, dst_t *__restrict__ yy,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_group(2);
const int64_t i = item_ct1.get_group(2);
const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
const int tid = item_ct1.get_local_id(2);
const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7
const int64_t tid = item_ct1.get_local_id(2);
const int64_t il = tid/8; // 0...3
const int64_t ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 4*il;
const uint8_t * q4 = x[ib].qs + 4*il;
const float d = (float)x[ib].d;
@ -678,12 +678,12 @@ template <typename dst_t>
__dpct_inline__ static void
dequantize_block_iq4_xs(const void *__restrict__ vx, dst_t *__restrict__ yy,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_group(2);
const int64_t i = item_ct1.get_group(2);
const block_iq4_xs * x = (const block_iq4_xs *)vx;
const int tid = item_ct1.get_local_id(2);
const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7
const int64_t tid = item_ct1.get_local_id(2);
const int64_t il = tid/8; // 0...3
const int64_t ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 4*il;
const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);

View file

@ -4,7 +4,7 @@
#include "presets.hpp"
static void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
const sycl::half *x = (const sycl::half *)vx;
// automatic half -> float type cast if dfloat == float
@ -12,7 +12,7 @@ static void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 &
v.y() = x[ib + iqs + 1];
}
static void convert_f32(const void * vx, const int ib, const int iqs, dfloat2 & v){
static void convert_f32(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
const float * x = (const float *) vx;
// automatic half -> float type cast if dfloat == float

View file

@ -0,0 +1,125 @@
//
// MIT license
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: MIT
//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
#include "im2col.hpp"
template <typename T>
static void im2col_kernel(
const float *x, T *dst, int64_t batch_offset, int64_t offset_delta,
int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH,
int64_t pelements, int64_t CHW, int s0, int s1, int p0, int p1, int d0, int d1,
const sycl::nd_item<3> &item_ct1) {
const int64_t work_group_size = item_ct1.get_local_range(2);
const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);
// make each work-item deal with more elements since sycl global range can not exceed max int
for (int64_t i = global_id; i < pelements; i += work_group_size * item_ct1.get_group_range(2)) {
const int64_t ksize = OW * (KH > 1 ? KW : 1);
const int64_t kx = i / ksize;
const int64_t kd = kx * ksize;
const int64_t ky = (i - kd) / OW;
const int64_t ix = i % OW;
const int64_t oh = item_ct1.get_group(1);
const int64_t batch = item_ct1.get_group(0) / IC;
const int64_t ic = item_ct1.get_group(0) % IC;
const int64_t iiw = ix * s0 + kx * d0 - p0;
const int64_t iih = oh * s1 + ky * d1 - p1;
const int64_t offset_dst =
((batch * OH + oh) * OW + ix) * CHW +
(ic * (KW * KH) + ky * KW + kx);
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] =
sycl::vec<float, 1>(0.0f)
.convert<sycl::half, sycl::rounding_mode::automatic>()[0];
} else {
const int64_t offset_src = ic * offset_delta + batch * batch_offset;
dst[offset_dst] =
sycl::vec<float, 1>(x[offset_src + iih * IW + iiw])
.convert<sycl::half, sycl::rounding_mode::automatic>()[0];
}
}
}
template <typename T>
static void im2col_sycl(
const float *x, T *dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW,
int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta,
int s0, int s1, int p0, int p1, int d0, int d1,
queue_ptr stream) {
const int64_t parallel_elements = OW * KW * KH;
const int64_t num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE;
// decrease global range when it exceeds the max int
int64_t local_size = downsample_sycl_global_range(batch * IC * OH * num_blocks, SYCL_IM2COL_BLOCK_SIZE);
sycl::range<3> block_nums(batch * IC, OH, num_blocks);
sycl::range<3> local_range(1, 1, local_size);
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(
sycl::nd_range<3>(block_nums * local_range, local_range),
[=](sycl::nd_item<3> item_ct1) {
im2col_kernel(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH,
parallel_elements, (IC * KH * KW), s0, s1, p0,
p1, d0, d1, item_ct1);
});
}
}
void ggml_sycl_op_im2col(
ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
const int64_t IC = src1->ne[is_2D ? 2 : 1];
const int64_t IH = is_2D ? src1->ne[1] : 1;
const int64_t IW = src1->ne[0];
const int64_t KH = is_2D ? src0->ne[1] : 1;
const int64_t KW = src0->ne[0];
const int64_t OH = is_2D ? dst->ne[2] : 1;
const int64_t OW = dst->ne[1];
const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
const int64_t batch = src1->ne[3];
const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
if (dst->type == GGML_TYPE_F16) {
im2col_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
} else {
im2col_sycl(src1_dd, (float *)dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
}
(void) src0;
(void) src0_dd;
}

View file

@ -0,0 +1,23 @@
//
// MIT license
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: MIT
//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
#ifndef GGML_SYCL_IM2COL_HPP
#define GGML_SYCL_IM2COL_HPP
#include "common.hpp"
void ggml_sycl_op_im2col(
ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd,
const queue_ptr &main_stream);
#endif // GGML_SYCL_IM2COL_HPP

View file

@ -180,6 +180,7 @@ struct vk_device_struct {
vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
vk_pipeline pipeline_acc_f32;
vk_pipeline pipeline_add_f32, pipeline_add_f16_f32_f16;
vk_pipeline pipeline_mul_f32;
vk_pipeline pipeline_div_f32;
@ -1687,6 +1688,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@ -3971,6 +3974,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_get_rows_f32[src0->type];
}
return nullptr;
case GGML_OP_ACC:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_acc_f32;
}
return nullptr;
case GGML_OP_ADD:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_add_f32;
@ -4463,6 +4471,28 @@ static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
}, dryrun);
}
static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
const uint32_t d_offset = ((extra->offset + dst->view_offs) % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
// int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
int offset = dst->op_params[3] / 4; // offset in bytes
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ACC, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] / dst_type_size,
d_offset,
0.0f, 0.0f, offset,
}, dryrun);
}
static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
@ -5621,6 +5651,7 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_REPEAT:
case GGML_OP_GET_ROWS:
case GGML_OP_ADD:
case GGML_OP_ACC:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_CONCAT:
@ -5668,6 +5699,10 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_REPEAT:
ggml_vk_repeat(ctx, compute_ctx, src0, node, dryrun);
break;
case GGML_OP_ACC:
ggml_vk_acc(ctx, compute_ctx, src0, src1, node, dryrun);
break;
case GGML_OP_GET_ROWS:
ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node, dryrun);
@ -5808,6 +5843,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
switch (tensor->op) {
case GGML_OP_ADD:
case GGML_OP_ACC:
case GGML_OP_GET_ROWS:
case GGML_OP_MUL:
case GGML_OP_DIV:
@ -6539,6 +6575,7 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
case GGML_OP_GROUP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_ADD:
case GGML_OP_ACC:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_CONCAT:
@ -6995,6 +7032,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
tensor_clone = ggml_repeat(ggml_ctx, src0_clone, src1_clone);
} else if (tensor->op == GGML_OP_ADD) {
tensor_clone = ggml_add(ggml_ctx, src0_clone, src1_clone);
} else if (tensor->op == GGML_OP_ACC) {
tensor_clone = ggml_acc(ggml_ctx, src0_clone, src1_clone, tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
} else if (tensor->op == GGML_OP_NORM) {
tensor_clone = ggml_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
} else if (tensor->op == GGML_OP_GROUP_NORM) {

View file

@ -0,0 +1,24 @@
#version 450
#include "types.comp"
#include "generic_binary_head.comp"
void main() {
const uint idx = gl_GlobalInvocationID.x;
if (idx >= p.ne) {
return;
}
const uint offset = p.param3;
const uint src1_i = idx - offset;
const uint oz = src1_i / p.nb02;
const uint oy = (src1_i - (oz * p.nb02)) / p.nb01;
const uint ox = src1_i % p.nb01;
if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) {
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) + FLOAT_TYPE(data_b[ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
} else {
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]));
}
}

View file

@ -368,6 +368,10 @@ void process_shaders(std::vector<std::future<void>>& tasks) {
string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
}));
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
}));
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
}));