Merge branch 'master' into pr/8836

commit cfe866e152

26 changed files with 546 additions and 300 deletions
@@ -106,6 +106,7 @@ Typically finetunes of the base models below are supported as well.
 - [x] [ChatGLM3-6b](https://huggingface.co/THUDM/chatglm3-6b) + [ChatGLM4-9b](https://huggingface.co/THUDM/glm-4-9b)
 - [x] [SmolLM](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966)
 - [x] [EXAONE-3.0-7.8B-Instruct](https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct)
+- [x] [FalconMamba Models](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a)
 
 (instructions for supporting more models: [HOWTO-add-model.md](./docs/development/HOWTO-add-model.md))
 
@@ -77,6 +77,41 @@
 using json = nlohmann::ordered_json;
 
+//
+// Environment variable utils
+//
+
+template<typename T>
+static typename std::enable_if<std::is_same<T, std::string>::value, void>::type
+get_env(std::string name, T & target) {
+    char * value = std::getenv(name.c_str());
+    target = value ? std::string(value) : target;
+}
+
+template<typename T>
+static typename std::enable_if<!std::is_same<T, bool>::value && std::is_integral<T>::value, void>::type
+get_env(std::string name, T & target) {
+    char * value = std::getenv(name.c_str());
+    target = value ? std::stoi(value) : target;
+}
+
+template<typename T>
+static typename std::enable_if<std::is_floating_point<T>::value, void>::type
+get_env(std::string name, T & target) {
+    char * value = std::getenv(name.c_str());
+    target = value ? std::stof(value) : target;
+}
+
+template<typename T>
+static typename std::enable_if<std::is_same<T, bool>::value, void>::type
+get_env(std::string name, T & target) {
+    char * value = std::getenv(name.c_str());
+    if (value) {
+        std::string val(value);
+        target = val == "1" || val == "true";
+    }
+}
+
 //
 // CPU utils
 //
 
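Note: the four `get_env` overloads above select on the target's type via `std::enable_if`, so a single call site works for strings, integers, floats, and booleans, and an unset variable leaves the existing default untouched. A minimal sketch of the dispatch (the `MYAPP_*` names and the `demo_params` struct are invented for illustration, not part of the patch):

```cpp
// Sketch only: exercises the get_env() overloads defined in the hunk above.
// The MYAPP_* variable names are hypothetical, not part of the patch.
#include <cstdlib>
#include <string>

struct demo_params {
    std::string model   = "default.gguf"; // string overload
    int         threads = 4;              // integral (non-bool) overload
    float       temp    = 0.8f;           // floating-point overload
    bool        verbose = false;          // bool overload ("1" or "true")
};

static void demo_parse_from_env(demo_params & p) {
    get_env("MYAPP_MODEL",   p.model);   // unset variables leave the default as-is
    get_env("MYAPP_THREADS", p.threads);
    get_env("MYAPP_TEMP",    p.temp);
    get_env("MYAPP_VERBOSE", p.verbose);
}
```

One caveat worth noting: the integral and floating-point overloads inherit `std::stoi`/`std::stof` semantics, so a malformed value throws rather than falling back to the default.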
@@ -220,12 +255,6 @@ int32_t cpu_get_num_math() {
 // CLI argument parsing
 //
 
-void gpt_params_handle_hf_token(gpt_params & params) {
-    if (params.hf_token.empty() && std::getenv("HF_TOKEN")) {
-        params.hf_token = std::getenv("HF_TOKEN");
-    }
-}
-
 void gpt_params_handle_model_default(gpt_params & params) {
     if (!params.hf_repo.empty()) {
         // short-hand to avoid specifying --hf-file -> default it to --model

@@ -273,7 +302,9 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
 
     gpt_params_handle_model_default(params);
 
-    gpt_params_handle_hf_token(params);
+    if (params.hf_token.empty()) {
+        get_env("HF_TOKEN", params.hf_token);
+    }
 
     if (params.escape) {
         string_process_escapes(params.prompt);

@@ -293,6 +324,25 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
     return true;
 }
 
+void gpt_params_parse_from_env(gpt_params & params) {
+    // we only care about server-related params for now
+    get_env("LLAMA_ARG_MODEL", params.model);
+    get_env("LLAMA_ARG_THREADS", params.n_threads);
+    get_env("LLAMA_ARG_CTX_SIZE", params.n_ctx);
+    get_env("LLAMA_ARG_N_PARALLEL", params.n_parallel);
+    get_env("LLAMA_ARG_BATCH", params.n_batch);
+    get_env("LLAMA_ARG_UBATCH", params.n_ubatch);
+    get_env("LLAMA_ARG_N_GPU_LAYERS", params.n_gpu_layers);
+    get_env("LLAMA_ARG_THREADS_HTTP", params.n_threads_http);
+    get_env("LLAMA_ARG_CHAT_TEMPLATE", params.chat_template);
+    get_env("LLAMA_ARG_N_PREDICT", params.n_predict);
+    get_env("LLAMA_ARG_ENDPOINT_METRICS", params.endpoint_metrics);
+    get_env("LLAMA_ARG_ENDPOINT_SLOTS", params.endpoint_slots);
+    get_env("LLAMA_ARG_EMBEDDINGS", params.embedding);
+    get_env("LLAMA_ARG_FLASH_ATTN", params.flash_attn);
+    get_env("LLAMA_ARG_DEFRAG_THOLD", params.defrag_thold);
+}
+
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
     const auto params_org = params; // the example can modify the default params
 
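Combined with the server hunk further down, the call order fixes the precedence: defaults first, then CLI flags, then the `LLAMA_ARG_*` variables. A sketch of the resulting flow (error handling elided; this is not the server's literal `main`, just the order of operations it implies):

```cpp
// Sketch of the resulting precedence, mirroring the server's call order:
// defaults -> CLI arguments -> LLAMA_ARG_* environment variables.
int main(int argc, char ** argv) {
    gpt_params params;

    if (!gpt_params_parse(argc, argv, params)) {
        return 1;
    }

    // any LLAMA_ARG_* variable that is set now overrides the CLI value
    gpt_params_parse_from_env(params);

    // ... run the server with `params` ...
    return 0;
}
```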
@@ -267,7 +267,7 @@ struct gpt_params {
     std::string lora_outfile = "ggml-lora-merged-f16.gguf";
 };
 
-void gpt_params_handle_hf_token(gpt_params & params);
+void gpt_params_parse_from_env(gpt_params & params);
 void gpt_params_handle_model_default(gpt_params & params);
 
 bool gpt_params_parse_ex (int argc, char ** argv, gpt_params & params);
 
@@ -295,6 +295,7 @@ class Model:
                     gguf.MODEL_TENSOR.FFN_GATE_INP,
                     gguf.MODEL_TENSOR.POS_EMBD,
                     gguf.MODEL_TENSOR.TOKEN_TYPES,
+                    gguf.MODEL_TENSOR.SSM_CONV1D,
                 )
             )
             or not name.endswith(".weight")

@@ -2711,7 +2712,7 @@ class StarCoder2Model(Model):
     model_arch = gguf.MODEL_ARCH.STARCODER2
 
 
-@Model.register("MambaForCausalLM", "MambaLMHeadModel")
+@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
 class MambaModel(Model):
     model_arch = gguf.MODEL_ARCH.MAMBA

@@ -2742,7 +2743,10 @@ class MambaModel(Model):
         # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
         dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16)
         rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
 
+        use_dt_b_c_norm = False
+        # For falconmamba we do apply RMS norm on B / DT and C layers
+        if self.find_hparam(["model_type"], optional=True) in ("falcon_mamba",):
+            use_dt_b_c_norm = True
         # Fail early for models which don't have a block expansion factor of 2
         assert d_inner == 2 * d_model

@@ -2750,12 +2754,13 @@ class MambaModel(Model):
         self.gguf_writer.add_embedding_length(d_model)
         self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
         self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
-        self.gguf_writer.add_block_count(self.hparams["n_layer"])
+        self.gguf_writer.add_block_count(self.block_count)
         self.gguf_writer.add_ssm_conv_kernel(d_conv)
         self.gguf_writer.add_ssm_inner_size(d_inner)
         self.gguf_writer.add_ssm_state_size(d_state)
         self.gguf_writer.add_ssm_time_step_rank(dt_rank)
         self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
+        self.gguf_writer.add_ssm_dt_b_c_rms(use_dt_b_c_norm) # For classic Mamba we don't apply rms norm on B / DT layers
         self.gguf_writer.add_file_type(self.ftype)
 
     _tok_embd = None

@@ -2782,23 +2787,6 @@ class MambaModel(Model):
 
         return [(new_name, data_torch)]
 
-    def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
-        if bid is not None and new_name in (
-            self.format_tensor_name(
-                n, bid, ".weight" if name.endswith(".weight") else ""
-            )
-            for n in [
-                gguf.MODEL_TENSOR.SSM_CONV1D,
-                gguf.MODEL_TENSOR.SSM_X,
-                gguf.MODEL_TENSOR.SSM_DT,
-                gguf.MODEL_TENSOR.SSM_A,
-                gguf.MODEL_TENSOR.SSM_D,
-            ]
-        ):
-            return gguf.GGMLQuantizationType.F32
-
-        return super().tensor_force_quant(name, new_name, bid, n_dims)
-
 
 @Model.register("CohereForCausalLM")
 class CommandR2Model(Model):

@@ -3792,7 +3780,7 @@ class ExaoneModel(Model):
     def set_gguf_parameters(self):
         hparams = self.hparams
 
-        assert(hparams["activation_function"] == "silu")
+        assert (hparams["activation_function"] == "silu")
 
         max_position_embeddings = hparams["max_position_embeddings"]
         embed_dim = hparams["hidden_size"]

@@ -3855,8 +3843,8 @@ class ExaoneModel(Model):
 
         super().prepare_tensors()
 
-###### CONVERSION LOGIC ######
 
+###### CONVERSION LOGIC ######
 
 # tree of lazy tensors
 class LazyTorchTensor(gguf.LazyBase):
 
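Note on the Mamba hunk above: when the checkpoint does not provide `time_step_rank`, the fallback `-(d_model // -16)` is Python's idiom for ceiling division — floor division rounds toward negative infinity, so `-(-d_model // 16)` equals `ceil(d_model / 16)`. For a typical `d_model = 2560` this gives `dt_rank = 160`; the ceiling only matters when `d_model` is not a multiple of 16.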
@@ -20,6 +20,10 @@
 #include "ggml-cann.h"
 #endif
 
+#ifdef GGML_USE_VULKAN
+#include "ggml-vulkan.h"
+#endif
+
 #define STB_IMAGE_IMPLEMENTATION
 #include "stb_image.h"

@@ -1108,7 +1112,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
         }
     }
 
-    clip_ctx * new_clip = new clip_ctx;
+    clip_ctx * new_clip = new clip_ctx{};
 
     // update projector type
     {

@@ -1142,6 +1146,10 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
     LOG_TEE("%s: CLIP using CANN backend\n", __func__);
 #endif
 
+#ifdef GGML_USE_VULKAN
+    new_clip->backend = ggml_backend_vk_init(0);
+    LOG_TEE("%s: CLIP using Vulkan backend\n", __func__);
+#endif
 
     if (!new_clip->backend) {
         new_clip->backend = ggml_backend_cpu_init();
@@ -247,6 +247,25 @@ logging:
   --log-append            Don't truncate the old log file.
 ```
 
+Available environment variables (if specified, these variables will override parameters specified in arguments):
+
+- `LLAMA_CACHE` (cache directory, used by `--hf-repo`)
+- `HF_TOKEN` (Hugging Face access token, used when accessing a gated model with `--hf-repo`)
+- `LLAMA_ARG_MODEL`
+- `LLAMA_ARG_THREADS`
+- `LLAMA_ARG_CTX_SIZE`
+- `LLAMA_ARG_N_PARALLEL`
+- `LLAMA_ARG_BATCH`
+- `LLAMA_ARG_UBATCH`
+- `LLAMA_ARG_N_GPU_LAYERS`
+- `LLAMA_ARG_THREADS_HTTP`
+- `LLAMA_ARG_CHAT_TEMPLATE`
+- `LLAMA_ARG_N_PREDICT`
+- `LLAMA_ARG_ENDPOINT_METRICS`
+- `LLAMA_ARG_ENDPOINT_SLOTS`
+- `LLAMA_ARG_EMBEDDINGS`
+- `LLAMA_ARG_FLASH_ATTN`
+- `LLAMA_ARG_DEFRAG_THOLD`
+
 ## Build
 
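For context, a hypothetical invocation using these variables might look like the following (model path and values invented for the example; each variable overrides the corresponding flag):

```sh
# Hypothetical example: configure the server via environment variables.
LLAMA_ARG_MODEL=/models/example-7b.gguf \
LLAMA_ARG_CTX_SIZE=4096 \
LLAMA_ARG_N_GPU_LAYERS=99 \
./llama-server
```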
@@ -2507,6 +2507,9 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    // parse arguments from environment variables
+    gpt_params_parse_from_env(params);
+
     // TODO: not great to use extern vars
     server_log_json = params.log_json;
     server_verbose = params.verbosity > 0;
 
@@ -893,43 +893,6 @@ static void clamp_f32(const float * x, float * dst, const float min, const float
         dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
 }
 
-template <typename T>
-static void im2col_kernel(const float *x, T *dst, int offset_delta,
-                          int IW, int IH, int OW, int KW, int KH,
-                          int pelements, int CHW, int s0, int s1, int p0,
-                          int p1, int d0, int d1,
-                          const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_id(2) +
-                  item_ct1.get_group(2) * item_ct1.get_local_range(2);
-    if (i >= pelements) {
-        return;
-    }
-
-    const int ksize = OW * (KH > 1 ? KW : 1);
-    const int kx = i / ksize;
-    const int kd = kx * ksize;
-    const int ky = (i - kd) / OW;
-    const int ix = i % OW;
-
-    const int64_t iiw = ix * s0 + kx * d0 - p0;
-    const int64_t iih = item_ct1.get_group(1) * s1 + ky * d1 - p1;
-
-    const int64_t offset_dst =
-        (item_ct1.get_group(1) * OW + ix) * CHW +
-        (item_ct1.get_group(0) * (KW * KH) + ky * KW + kx);
-
-    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
-        dst[offset_dst] =
-            sycl::vec<float, 1>(0.0f)
-                .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
-    } else {
-        const int64_t offset_src = item_ct1.get_group(0) * offset_delta;
-        dst[offset_dst] =
-            sycl::vec<float, 1>(x[offset_src + iih * IW + iiw])
-                .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
-    }
-}
-
 template <typename Ti, typename To>
 static void pool2d_nchw_kernel(
         const int ih, const int iw, const int oh, const int ow,

@@ -1742,32 +1705,6 @@ static void diag_mask_inf_f32_sycl(const float *x, float *dst,
         });
 }
 
-template <typename T>
-static void im2col_sycl(const float *x, T *dst, int IW, int IH,
-                        int OW, int OH, int KW, int KH, int IC,
-                        int offset_delta, int s0, int s1, int p0,
-                        int p1, int d0, int d1,
-                        queue_ptr stream) {
-    const int parallel_elements = OW * KW * KH;
-    const int num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE;
-    sycl::range<3> block_nums(IC, OH, num_blocks);
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums *
-                                  sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE),
-                              sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE)),
-            [=](sycl::nd_item<3> item_ct1) {
-                im2col_kernel(x, dst, offset_delta, IW, IH, OW, KW, KH,
-                              parallel_elements, (IC * KH * KW), s0, s1, p0,
-                              p1, d0, d1, item_ct1);
-            });
-    }
-}
-
-
 static bool g_sycl_loaded = false;
 
 bool ggml_sycl_loaded(void) {

@@ -2636,47 +2573,6 @@ static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tensor
     (void) src1_dd;
 }
 
-inline void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                const ggml_tensor *src1, ggml_tensor *dst,
-                                const float *src0_dd, const float *src1_dd,
-                                float *dst_dd,
-                                const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
-
-    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
-    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
-    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
-    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
-    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
-    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
-
-    const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
-
-    const int64_t IC = src1->ne[is_2D ? 2 : 1];
-    const int64_t IH = is_2D ? src1->ne[1] : 1;
-    const int64_t IW = src1->ne[0];
-
-    const int64_t KH = is_2D ? src0->ne[1] : 1;
-    const int64_t KW = src0->ne[0];
-
-    const int64_t OH = is_2D ? dst->ne[2] : 1;
-    const int64_t OW = dst->ne[1];
-
-    const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
-
-    if (dst->type == GGML_TYPE_F16) {
-        im2col_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
-    } else {
-        im2col_sycl(src1_dd, (float *)dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
-    }
-
-    (void) src0;
-    (void) src0_dd;
-}
-
 inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                                   const ggml_tensor *src1, ggml_tensor *dst,
                                   const float *src0_dd, const float *src1_dd,

@@ -3581,7 +3477,8 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
 
     bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE
+        && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda || src1->ne[1] > MMVQ_MIN_BATCH_SIZE);
 
     bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
 
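The reworked condition narrows when the quantized mat-vec (MMVQ) path is taken: on the oneAPI CUDA backend it is used for any batch up to `MMVQ_MAX_BATCH_SIZE` (8), while on other SYCL backends it now also requires the batch to exceed `MMVQ_MIN_BATCH_SIZE` (4). A simplified restatement of the predicate (the backend check is reduced to a bool for illustration):

```cpp
// Simplified form of the use_mul_mat_vec_q predicate above;
// `is_cuda_backend` stands in for the ext_oneapi_cuda backend check.
static bool demo_use_mmvq(bool quantized, bool f32_io, int64_t batch, bool is_cuda_backend) {
    return quantized && f32_io
        && batch <= 8 /* MMVQ_MAX_BATCH_SIZE */
        && (is_cuda_backend || batch > 4 /* MMVQ_MIN_BATCH_SIZE */);
}
```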
@@ -25,5 +25,6 @@
 #include "norm.hpp"
 #include "softmax.hpp"
 #include "tsembd.hpp"
+#include "im2col.hpp"
 
 #endif // GGML_SYCL_BACKEND_HPP
 
@@ -51,3 +51,14 @@ void ggml_sycl_host_free(void* ptr) try {
               << ", line:" << __LINE__ << std::endl;
     std::exit(1);
 }
+
+int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size) {
+    const int64_t max_range = std::numeric_limits<int>::max();
+    int64_t sycl_down_blk_size = block_size;
+    int64_t global_range = accumulate_block_num * sycl_down_blk_size;
+    while(global_range > max_range) {
+        sycl_down_blk_size /= 2;
+        global_range = accumulate_block_num * sycl_down_blk_size;
+    }
+    return sycl_down_blk_size;
+}
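To make the helper's behavior concrete, a numeric walkthrough (values invented for illustration): with 20,000,000 blocks and a block size of 256, the initial global range is 5.12e9, which exceeds `INT_MAX` (2^31 - 1), so the block size is halved to 128 (still 2.56e9, too big) and then to 64, giving a range of 1.28e9 that fits. The same loop, stripped to its essentials:

```cpp
// Illustration only: how downsample_sycl_global_range() shrinks the
// work-group size until (num_blocks * block_size) fits in a 32-bit int.
#include <cstdint>
#include <limits>

int64_t downsample_demo(int64_t num_blocks, int64_t block_size) {
    while (num_blocks * block_size > std::numeric_limits<int>::max()) {
        block_size /= 2; // 256 -> 128 -> 64 for num_blocks = 20'000'000
    }
    return block_size;
}
```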
@@ -130,6 +130,7 @@ typedef sycl::float2 dfloat2;
 #endif // GGML_SYCL_F16
 
 #define MMVQ_MAX_BATCH_SIZE 8
+#define MMVQ_MIN_BATCH_SIZE 4
 
 static const int8_t kvalues_iq4nl[16]={-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};

@@ -352,4 +353,6 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor<Tp, dim> acc) {
     return acc.template get_multi_ptr<sycl::access::decorated::no>().get();
 }
 
+int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size);
+
 #endif // GGML_SYCL_COMMON_HPP
 
@@ -3,19 +3,19 @@
 #include "presets.hpp"
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k,
+static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
                              const sycl::nd_item<3> &item_ct1) {
-    const int i = 2 * (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+    const int64_t i = 2 * (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
                        item_ct1.get_local_id(2));
 
     if (i >= k) {
         return;
     }
 
-    const int ib = i/qk; // block index
-    const int iqs = (i%qk)/qr; // quant index
-    const int iybs = i - i%qk; // y block start index
-    const int y_offset = qr == 1 ? 1 : qk/2;
+    const int64_t ib = i/qk; // block index
+    const int64_t iqs = (i%qk)/qr; // quant index
+    const int64_t iybs = i - i%qk; // y block start index
+    const int64_t y_offset = qr == 1 ? 1 : qk/2;
 
     // dequantize
     dfloat2 v;

@@ -27,9 +27,9 @@ static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 static void dequantize_block_sycl(const void *__restrict__ vx,
-                                  dst_t *__restrict__ y, const int k,
+                                  dst_t *__restrict__ y, const int64_t k,
                                   dpct::queue_ptr stream) {
-    const int num_blocks = (k + 2*SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2*SYCL_DEQUANTIZE_BLOCK_SIZE);
+    const int64_t num_blocks = (k + 2*SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2*SYCL_DEQUANTIZE_BLOCK_SIZE);
     {
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});

@@ -45,9 +45,9 @@ static void dequantize_block_sycl(const void *__restrict__ vx,
 }
 
 template <typename dst_t>
-static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int64_t k,
                                      dpct::queue_ptr stream) {
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
 #if QK_K == 256
     {
         dpct::has_capability_or_fail(stream->get_device(),

@@ -77,9 +77,9 @@ static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int k,
 }
 
 template <typename dst_t>
-static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int64_t k,
                                      dpct::queue_ptr stream) {
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
 #if QK_K == 256
     {
         dpct::has_capability_or_fail(stream->get_device(),

@@ -108,10 +108,10 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k,
 }
 
 template <typename dst_t>
-static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k,
                                      dpct::queue_ptr stream) {
-    const int nb32 = k / 32;
-    const int nb = (k + 255) / 256;
+    const int64_t nb32 = k / 32;
+    const int64_t nb = (k + 255) / 256;
     {
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});

@@ -126,10 +126,10 @@ static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k,
 }
 
 template <typename dst_t>
-static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k,
                                      dpct::queue_ptr stream) {
-    const int nb32 = k / 32;
-    const int nb = (k + 255) / 256;
+    const int64_t nb32 = k / 32;
+    const int64_t nb = (k + 255) / 256;
     {
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});

@@ -145,9 +145,9 @@ static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int k,
 
 
 template <typename dst_t>
-static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k,
                                      dpct::queue_ptr stream) {
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
     {
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});

@@ -165,9 +165,9 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k,
 }
 
 template <typename dst_t>
-static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k,
                                      dpct::queue_ptr stream) {
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
 #if QK_K == 256
     {
         dpct::has_capability_or_fail(stream->get_device(),

@@ -197,9 +197,9 @@ static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int k,
 }
 
 template <typename dst_t>
-static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k,
                                      dpct::queue_ptr stream) {
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
 #if QK_K == 256
     {
         dpct::has_capability_or_fail(stream->get_device(),

@@ -229,9 +229,9 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k,
 }
 
 template <typename dst_t>
-static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k,
                                       dpct::queue_ptr stream) {
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
     {
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});

@@ -250,9 +250,9 @@ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
 }
 
 template <typename dst_t>
-static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int64_t k,
                                       dpct::queue_ptr stream) {
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
     {
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});

@@ -271,9 +271,9 @@ static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int k,
 }
 
 template <typename dst_t>
-static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int64_t k,
                                         dpct::queue_ptr stream) {
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
     {
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});

@@ -292,9 +292,9 @@ static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int k,
 }
 
 template <typename dst_t>
-static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int64_t k,
                                        dpct::queue_ptr stream) {
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
     {
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});

@@ -313,9 +313,9 @@ static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int k,
 }
 
 template <typename dst_t>
-static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int64_t k,
                                       dpct::queue_ptr stream) {
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
     {
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});

@@ -333,9 +333,9 @@ static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int k,
 
 
 template <typename dst_t>
-static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int64_t k,
                                         dpct::queue_ptr stream) {
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
     {
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});

@@ -354,9 +354,9 @@ static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k,
 }
 
 template <typename dst_t>
-static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int64_t k,
                                       dpct::queue_ptr stream) {
-    const int nb = k / QK_K;
+    const int64_t nb = k / QK_K;
     {
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});

@@ -374,9 +374,9 @@ static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k,
 }
 
 template <typename dst_t>
-static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int64_t k,
                                        dpct::queue_ptr stream) {
-    const int nb = (k + QK_K - 1) / QK_K;
+    const int64_t nb = (k + QK_K - 1) / QK_K;
 #if QK_K == 64
     dequantize_row_iq4_nl_sycl(vx, y, k, stream);
 #else

@@ -398,9 +398,9 @@ static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int k,
 }
 
 template <typename dst_t>
-static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int k,
+static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k,
                                        dpct::queue_ptr stream) {
-    const int nb = (k + QK_K - 1) / QK_K;
+    const int64_t nb = (k + QK_K - 1) / QK_K;
     {
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});

@@ -418,34 +418,34 @@ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int k,
 }
 
 template <typename src_t, typename dst_t>
-static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int k,
+static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
                           const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
+    const int64_t work_group_size = item_ct1.get_local_range(2);
+    const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);
 
+    // make each work-item deal with more elements since sycl global range can not exceed max int
     const src_t * x = (src_t *) vx;
 
-    y[i] = x[i];
+    for (int64_t i = global_id; i < k; i += work_group_size * item_ct1.get_group_range(2)) {
+        y[i] = x[i];
+    }
 }
 
 template <typename src_t, typename dst_t>
 static void convert_unary_sycl(const void *__restrict__ vx,
-                               dst_t *__restrict__ y, const int k,
+                               dst_t *__restrict__ y, const int64_t k,
                                dpct::queue_ptr stream) {
-    const int num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE;
+    const int64_t num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE;
+
+    // decrease global range when it exceeds the max int
+    int64_t local_size = downsample_sycl_global_range(num_blocks, SYCL_DEQUANTIZE_BLOCK_SIZE);
+    sycl::range<3> block_nums(1, 1, num_blocks);
+    sycl::range<3> local_range(1, 1, local_size);
     {
         dpct::has_capability_or_fail(stream->get_device(),
                                      {sycl::aspect::fp16});
 
         stream->parallel_for(
-            sycl::nd_range<3>(
-                sycl::range<3>(1, 1, num_blocks) *
-                    sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
-                sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
+            sycl::nd_range<3>(block_nums * local_range, local_range),
             [=](sycl::nd_item<3> item_ct1) {
                 convert_unary<src_t>(vx, y, k, item_ct1);
             });
 
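The `convert_unary` rewrite above is the standard grid-stride-loop idiom: instead of launching one work-item per element (which caps `k` at `INT_MAX`, since a SYCL global range must fit in an int here), each work-item strides over the buffer. A stripped-down sketch of the same pattern (names invented for illustration):

```cpp
// Grid-stride loop sketch: `total` may exceed the maximum launchable
// global range, so each work-item processes roughly total/stride elements.
void copy_grid_stride(const float * src, float * dst, int64_t total,
                      int64_t global_id, int64_t stride /* = global range */) {
    for (int64_t i = global_id; i < total; i += stride) {
        dst[i] = src[i];
    }
}
```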
@@ -17,7 +17,7 @@
 
 template <typename T>
 using to_t_sycl_t = void (*)(const void *__restrict__ x, T *__restrict__ y,
-                             int k, dpct::queue_ptr stream);
+                             int64_t k, dpct::queue_ptr stream);
 typedef to_t_sycl_t<float> to_fp32_sycl_t;
 typedef to_t_sycl_t<sycl::half> to_fp16_sycl_t;
 
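The typedef has to change in lockstep with the converters above, because a function converts to this pointer type only when the signatures match exactly. A minimal illustration (the converter name is hypothetical):

```cpp
// A converter is assignable to to_fp32_sycl_t only with the int64_t signature;
// the same declaration with `int k` would no longer compile.
static void my_to_fp32(const void *x, float *y, int64_t k, dpct::queue_ptr stream);

to_fp32_sycl_t fn = my_to_fp32;
```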
@@ -15,9 +15,9 @@
 
 #include "common.hpp"
 
-typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
+typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
 
-static __dpct_inline__ void dequantize_q4_0(const void *vx, const int ib,
+static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib,
                                             const int iqs, dfloat2 &v) {
     const block_q4_0 * x = (const block_q4_0 *) vx;
 

@@ -40,7 +40,7 @@ static __dpct_inline__ void dequantize_q4_0(const void *vx, const int ib,
 #endif // GGML_SYCL_F16
 }
 
-static __dpct_inline__ void dequantize_q4_1(const void *vx, const int ib,
+static __dpct_inline__ void dequantize_q4_1(const void *vx, const int64_t ib,
                                             const int iqs, dfloat2 &v) {
     const block_q4_1 * x = (const block_q4_1 *) vx;
 

@@ -64,7 +64,7 @@ static __dpct_inline__ void dequantize_q4_1(const void *vx, const int ib,
 #endif // GGML_SYCL_F16
 }
 
-static __dpct_inline__ void dequantize_q5_0(const void *vx, const int ib,
+static __dpct_inline__ void dequantize_q5_0(const void *vx, const int64_t ib,
                                             const int iqs, dfloat2 &v) {
     const block_q5_0 * x = (const block_q5_0 *) vx;
 

@@ -91,7 +91,7 @@ static __dpct_inline__ void dequantize_q5_0(const void *vx, const int ib,
 #endif // GGML_SYCL_F16
 }
 
-static __dpct_inline__ void dequantize_q5_1(const void *vx, const int ib,
+static __dpct_inline__ void dequantize_q5_1(const void *vx, const int64_t ib,
                                             const int iqs, dfloat2 &v) {
     const block_q5_1 * x = (const block_q5_1 *) vx;
 

@@ -118,7 +118,7 @@ static __dpct_inline__ void dequantize_q5_1(const void *vx, const int ib,
 #endif // GGML_SYCL_F16
 }
 
-static __dpct_inline__ void dequantize_q8_0(const void *vx, const int ib,
+static __dpct_inline__ void dequantize_q8_0(const void *vx, const int64_t ib,
                                             const int iqs, dfloat2 &v) {
     const block_q8_0 * x = (const block_q8_0 *) vx;
 

@@ -138,16 +138,16 @@ static __dpct_inline__ void dequantize_q8_0(const void *vx, const int ib,
 }
 
 template<typename dst_t>
-static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32,
+static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32,
                                   const sycl::nd_item<3> &item_ct1) {
 
-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
 
     // assume 32 threads
-    const int tid = item_ct1.get_local_id(2);
-    const int il = tid/8;
-    const int ir = tid%8;
-    const int ib = 8*i + ir;
+    const int64_t tid = item_ct1.get_local_id(2);
+    const int64_t il = tid/8;
+    const int64_t ir = tid%8;
+    const int64_t ib = 8*i + ir;
     if (ib >= nb32) {
         return;
     }

@@ -168,16 +168,16 @@ static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy,
 }
 
 template<typename dst_t>
-static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32,
+static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32,
                                   const sycl::nd_item<3> &item_ct1) {
 
-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
 
     // assume 32 threads
-    const int tid = item_ct1.get_local_id(2);
-    const int il = tid/8;
-    const int ir = tid%8;
-    const int ib = 8*i + ir;
+    const int64_t tid = item_ct1.get_local_id(2);
+    const int64_t il = tid/8;
+    const int64_t ir = tid%8;
+    const int64_t ib = 8*i + ir;
     if (ib >= nb32) {
         return;
     }

@@ -203,14 +203,14 @@ template<typename dst_t>
 static void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
                                   const sycl::nd_item<3> &item_ct1) {
 
-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
     const block_q2_K * x = (const block_q2_K *) vx;
 
-    const int tid = item_ct1.get_local_id(2);
+    const int64_t tid = item_ct1.get_local_id(2);
 #if QK_K == 256
-    const int n = tid/32;
-    const int l = tid - 32*n;
-    const int is = 8*n + l/16;
+    const int64_t n = tid/32;
+    const int64_t l = tid - 32*n;
+    const int64_t is = 8*n + l/16;
 
     const uint8_t q = x[i].qs[32*n + l];
     dst_t * y = yy + i*QK_K + 128*n;

@@ -222,8 +222,8 @@ static void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
     y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
     y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
 #else
-    const int is = tid/16;  // 0 or 1
-    const int il = tid%16;  // 0...15
+    const int64_t is = tid/16;  // 0 or 1
+    const int64_t il = tid%16;  // 0...15
     const uint8_t q = x[i].qs[il] >> (2*is);
     dst_t * y = yy + i*QK_K + 16*is + il;
 

@@ -239,19 +239,19 @@ template<typename dst_t>
 static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
                                   const sycl::nd_item<3> &item_ct1) {
 
-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
     const block_q3_K * x = (const block_q3_K *) vx;
 
 #if QK_K == 256
-    const int r = item_ct1.get_local_id(2) / 4;
-    const int tid = r/2;
-    const int is0 = r%2;
-    const int l0 = 16 * is0 + 4 * (item_ct1.get_local_id(2) % 4);
-    const int n = tid / 4;
-    const int j = tid - 4*n;
+    const int64_t r = item_ct1.get_local_id(2) / 4;
+    const int64_t tid = r/2;
+    const int64_t is0 = r%2;
+    const int64_t l0 = 16 * is0 + 4 * (item_ct1.get_local_id(2) % 4);
+    const int64_t n = tid / 4;
+    const int64_t j = tid - 4*n;
 
     uint8_t m = 1 << (4*n + j);
-    int is = 8*n + 2*j + is0;
+    int64_t is = 8*n + 2*j + is0;
     int shift = 2*j;
 
     int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :

@@ -267,11 +267,11 @@ static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
 
     for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
 #else
-    const int tid = item_ct1.get_local_id(2);
-    const int is = tid/16;  // 0 or 1
-    const int il = tid%16;  // 0...15
-    const int im = il/8;    // 0...1
-    const int in = il%8;    // 0...7
+    const int64_t tid = item_ct1.get_local_id(2);
+    const int64_t is = tid/16;  // 0 or 1
+    const int64_t il = tid%16;  // 0...15
+    const int64_t im = il/8;    // 0...1
+    const int64_t in = il%8;    // 0...7
 
     dst_t * y = yy + i*QK_K + 16*is + il;
 

@@ -307,15 +307,15 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
                                   uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) {
     const block_q4_K * x = (const block_q4_K *) vx;
 
-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
 
 #if QK_K == 256
     // assume 32 threads
-    const int tid = item_ct1.get_local_id(2);
-    const int il = tid/8;
-    const int ir = tid%8;
-    const int is = 2*il;
-    const int n = 4;
+    const int64_t tid = item_ct1.get_local_id(2);
+    const int64_t il = tid/8;
+    const int64_t ir = tid%8;
+    const int64_t is = 2*il;
+    const int64_t n = 4;
 
     dst_t * y = yy + i*QK_K + 64*il + n*ir;
 

@@ -341,7 +341,7 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
         y[l +32] = d2 * (q_vec[l] >> 4) - m2;
     }
 #else
-    const int tid = item_ct1.get_local_id(2);
+    const int64_t tid = item_ct1.get_local_id(2);
     const uint8_t * q = x[i].qs;
     dst_t * y = yy + i*QK_K;
     const float d = (float)x[i].dm[0];

@@ -356,14 +356,14 @@ static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
                                   const sycl::nd_item<3> &item_ct1) {
     const block_q5_K * x = (const block_q5_K *) vx;
 
-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
 
 #if QK_K == 256
     // assume 64 threads - this is very slightly better than the one below
-    const int tid = item_ct1.get_local_id(2);
-    const int il = tid/16;   // il is in 0...3
-    const int ir = tid%16;   // ir is in 0...15
-    const int is = 2*il;     // is is in 0...6
+    const int64_t tid = item_ct1.get_local_id(2);
+    const int64_t il = tid/16;   // il is in 0...3
+    const int64_t ir = tid%16;   // ir is in 0...15
+    const int64_t is = 2*il;     // is is in 0...6
 
     dst_t * y = yy + i*QK_K + 64*il + 2*ir;
 

@@ -386,11 +386,11 @@ static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
     y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
     y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
 #else
-    const int tid = item_ct1.get_local_id(2);
+    const int64_t tid = item_ct1.get_local_id(2);
     const uint8_t q = x[i].qs[tid];
-    const int im = tid/8;  // 0...3
-    const int in = tid%8;  // 0...7
-    const int is = tid/16; // 0 or 1
+    const int64_t im = tid/8;  // 0...3
+    const int64_t in = tid%8;  // 0...7
+    const int64_t is = tid/16; // 0 or 1
     const uint8_t h = x[i].qh[in] >> im;
     const float d = x[i].d;
     dst_t * y = yy + i*QK_K + tid;

@@ -404,14 +404,14 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
                                   const sycl::nd_item<3> &item_ct1) {
     const block_q6_K * x = (const block_q6_K *) vx;
 
-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
 #if QK_K == 256
 
     // assume 64 threads - this is very slightly better than the one below
-    const int tid = item_ct1.get_local_id(2);
-    const int ip = tid/32;      // ip is 0 or 1
-    const int il = tid - 32*ip; // 0...32
-    const int is = 8*ip + il/16;
+    const int64_t tid = item_ct1.get_local_id(2);
+    const int64_t ip = tid/32;      // ip is 0 or 1
+    const int64_t il = tid - 32*ip; // 0...32
+    const int64_t is = 8*ip + il/16;
 
     dst_t * y = yy + i*QK_K + 128*ip + il;
 

@@ -428,9 +428,9 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
 #else
 
     // assume 32 threads
-    const int tid = item_ct1.get_local_id(2);
-    const int ip = tid/16;       // 0 or 1
-    const int il = tid - 16*ip;  // 0...15
+    const int64_t tid = item_ct1.get_local_id(2);
+    const int64_t ip = tid/16;       // 0 or 1
+    const int64_t il = tid - 16*ip;  // 0...15
 
     dst_t * y = yy + i*QK_K + 16*ip + il;
 

@@ -452,13 +452,13 @@ static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy,
                                      const uint8_t *ksigns_iq2xs_ptr,
                                      const uint8_t *kmask_iq2xs_ptr) {
 
-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
     const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
 
-    const int tid = item_ct1.get_local_id(2);
+    const int64_t tid = item_ct1.get_local_id(2);
 #if QK_K == 256
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
     const uint16_t * q2 = x[i].qs + 4*ib;
     const uint8_t * aux8 = (const uint8_t *)q2;

@@ -480,13 +480,13 @@ static void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy,
                                     const uint8_t *ksigns_iq2xs,
                                     const uint8_t *kmask_iq2xs) {
 
-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
     const block_iq2_xs * x = (const block_iq2_xs *) vx;
 
-    const int tid = item_ct1.get_local_id(2);
+    const int64_t tid = item_ct1.get_local_id(2);
 #if QK_K == 256
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
     const uint16_t * q2 = x[i].qs + 4*ib;
     const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));

@@ -504,13 +504,13 @@ __dpct_inline__ static void
 dequantize_block_iq2_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
                        const sycl::nd_item<3> &item_ct1) {
 
-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
     const block_iq2_s * x = (const block_iq2_s *) vx;
 
-    const int tid = item_ct1.get_local_id(2);
+    const int64_t tid = item_ct1.get_local_id(2);
 #if QK_K == 256
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
     const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
     const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;

@@ -532,13 +532,13 @@ static void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy,
                                      const uint8_t *ksigns_iq2xs,
                                      const uint8_t *kmask_iq2xs) {
 
-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
     const block_iq3_xxs * x = (const block_iq3_xxs *) vx;
 
-    const int tid = item_ct1.get_local_id(2);
+    const int64_t tid = item_ct1.get_local_id(2);
 #if QK_K == 256
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
     const uint8_t * q3 = x[i].qs + 8*ib;
     const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;

@@ -563,13 +563,13 @@ dequantize_block_iq3_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
                        const sycl::nd_item<3> &item_ct1,
                        const uint8_t *kmask_iq2xs, const uint32_t *iq3s_grid) {
 
-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
     const block_iq3_s * x = (const block_iq3_s *) vx;
 
-    const int tid = item_ct1.get_local_id(2);
+    const int64_t tid = item_ct1.get_local_id(2);
 #if QK_K == 256
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
     const uint8_t * qs = x[i].qs + 8*ib;
     const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));

@@ -593,13 +593,13 @@ dequantize_block_iq1_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
                        const sycl::nd_item<3> &item_ct1,
                        const uint32_t *iq1s_grid_gpu) {
 
-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
     const block_iq1_s * x = (const block_iq1_s *) vx;
 
-    const int tid = item_ct1.get_local_id(2);
+    const int64_t tid = item_ct1.get_local_id(2);
 #if QK_K == 256
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
     const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
     const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);

@@ -623,13 +623,13 @@ dequantize_block_iq1_m(const void *__restrict__ vx, dst_t *__restrict__ yy,
                        const sycl::nd_item<3> &item_ct1,
                        const uint32_t *iq1s_grid_gpu) {
 
-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
     const block_iq1_m * x = (const block_iq1_m *) vx;
 
-    const int tid = item_ct1.get_local_id(2);
+    const int64_t tid = item_ct1.get_local_id(2);
 #if QK_K == 256
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
     const uint16_t * sc = (const uint16_t *)x[i].scales;
     iq1m_scale_t scale;

@@ -656,12 +656,12 @@ __dpct_inline__ static void
 dequantize_block_iq4_nl(const void *__restrict__ vx, dst_t *__restrict__ yy,
                         const sycl::nd_item<3> &item_ct1) {
 
-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
     const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
 
-    const int tid = item_ct1.get_local_id(2);
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t tid = item_ct1.get_local_id(2);
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 4*il;
     const uint8_t * q4 = x[ib].qs + 4*il;
     const float d = (float)x[ib].d;

@@ -678,12 +678,12 @@ template <typename dst_t>
 __dpct_inline__ static void
 dequantize_block_iq4_xs(const void *__restrict__ vx, dst_t *__restrict__ yy,
                         const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_group(2);
+    const int64_t i = item_ct1.get_group(2);
     const block_iq4_xs * x = (const block_iq4_xs *)vx;
 
-    const int tid = item_ct1.get_local_id(2);
-    const int il = tid/8; // 0...3
-    const int ib = tid%8; // 0...7
+    const int64_t tid = item_ct1.get_local_id(2);
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 4*il;
     const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
     const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
 
@@ -4,7 +4,7 @@
 #include "presets.hpp"
 
 
-static void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const sycl::half *x = (const sycl::half *)vx;
 
     // automatic half -> float type cast if dfloat == float

@@ -12,7 +12,7 @@ static void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 &
     v.y() = x[ib + iqs + 1];
 }
 
-static void convert_f32(const void * vx, const int ib, const int iqs, dfloat2 & v){
+static void convert_f32(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const float * x = (const float *) vx;
 
     // automatic half -> float type cast if dfloat == float
 
ggml/src/ggml-sycl/im2col.cpp (new file, 125 lines)

@@ -0,0 +1,125 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#include "im2col.hpp"
+
+template <typename T>
+static void im2col_kernel(
+        const float *x, T *dst, int64_t batch_offset, int64_t offset_delta,
+        int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH,
+        int64_t pelements, int64_t CHW, int s0, int s1, int p0, int p1, int d0, int d1,
+        const sycl::nd_item<3> &item_ct1) {
+    const int64_t work_group_size = item_ct1.get_local_range(2);
+    const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2);
+
+    // make each work-item deal with more elements since sycl global range can not exceed max int
+    for (int64_t i = global_id; i < pelements; i += work_group_size * item_ct1.get_group_range(2)) {
+
+        const int64_t ksize = OW * (KH > 1 ? KW : 1);
+        const int64_t kx = i / ksize;
+        const int64_t kd = kx * ksize;
+        const int64_t ky = (i - kd) / OW;
+        const int64_t ix = i % OW;
+
+        const int64_t oh = item_ct1.get_group(1);
+        const int64_t batch = item_ct1.get_group(0) / IC;
+        const int64_t ic = item_ct1.get_group(0) % IC;
+
+        const int64_t iiw = ix * s0 + kx * d0 - p0;
+        const int64_t iih = oh * s1 + ky * d1 - p1;
+
+        const int64_t offset_dst =
+            ((batch * OH + oh) * OW + ix) * CHW +
+            (ic * (KW * KH) + ky * KW + kx);
+
+        if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+            dst[offset_dst] =
+                sycl::vec<float, 1>(0.0f)
+                    .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
+        } else {
+            const int64_t offset_src = ic * offset_delta + batch * batch_offset;
+            dst[offset_dst] =
+                sycl::vec<float, 1>(x[offset_src + iih * IW + iiw])
+                    .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
+        }
+    }
+}
+
+template <typename T>
+static void im2col_sycl(
+        const float *x, T *dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW,
+        int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta,
+        int s0, int s1, int p0, int p1, int d0, int d1,
+        queue_ptr stream) {
+    const int64_t parallel_elements = OW * KW * KH;
+    const int64_t num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE;
+
+    // decrease global range when it exceeds the max int
+    int64_t local_size = downsample_sycl_global_range(batch * IC * OH * num_blocks, SYCL_IM2COL_BLOCK_SIZE);
+    sycl::range<3> block_nums(batch * IC, OH, num_blocks);
+    sycl::range<3> local_range(1, 1, local_size);
+
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * local_range, local_range),
+            [=](sycl::nd_item<3> item_ct1) {
+                im2col_kernel(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH,
+                              parallel_elements, (IC * KH * KW), s0, s1, p0,
+                              p1, d0, d1, item_ct1);
+            });
+    }
+}
+
+void ggml_sycl_op_im2col(
+        ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+        ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd,
+        const queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
+
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
+
+    const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
+
+    const int64_t IC = src1->ne[is_2D ? 2 : 1];
+    const int64_t IH = is_2D ? src1->ne[1] : 1;
+    const int64_t IW = src1->ne[0];
+
+    const int64_t KH = is_2D ? src0->ne[1] : 1;
+    const int64_t KW = src0->ne[0];
+
+    const int64_t OH = is_2D ? dst->ne[2] : 1;
+    const int64_t OW = dst->ne[1];
+
+    const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+    const int64_t batch = src1->ne[3];
+    const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
+
+    if (dst->type == GGML_TYPE_F16) {
+        im2col_sycl(src1_dd, (sycl::half *)dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+    } else {
+        im2col_sycl(src1_dd, (float *)dst_dd, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
+    }
+
+    (void) src0;
+    (void) src0_dd;
+}
ggml/src/ggml-sycl/im2col.hpp (new file, 23 lines)

@@ -0,0 +1,23 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_IM2COL_HPP
+#define GGML_SYCL_IM2COL_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_op_im2col(
+    ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+    ggml_tensor *dst, const float *src0_dd, const float *src1_dd, float *dst_dd,
+    const queue_ptr &main_stream);
+
+#endif // GGML_SYCL_IM2COL_HPP
@ -180,6 +180,7 @@ struct vk_device_struct {
|
|||
vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
|
||||
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
|
||||
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
|
||||
vk_pipeline pipeline_acc_f32;
|
||||
vk_pipeline pipeline_add_f32, pipeline_add_f16_f32_f16;
|
||||
vk_pipeline pipeline_mul_f32;
|
||||
vk_pipeline pipeline_div_f32;
|
||||
|
@ -1687,6 +1688,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
||||
ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
|
@ -3971,6 +3974,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|||
return ctx->device->pipeline_get_rows_f32[src0->type];
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_ACC:
|
||||
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_acc_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_ADD:
|
||||
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_add_f32;
|
||||
|
@ -4463,6 +4471,28 @@ static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
|
|||
}, dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
||||
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
|
||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||
const uint32_t src1_type_size = ggml_type_size(src1->type);
|
||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||
const uint32_t d_offset = ((extra->offset + dst->view_offs) % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
|
||||
|
||||
int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
|
||||
int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
|
||||
// int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
|
||||
int offset = dst->op_params[3] / 4; // offset in bytes
|
||||
|
||||
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ACC, {
|
||||
(uint32_t)ggml_nelements(src0),
|
||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size,
|
||||
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
|
||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] / dst_type_size,
|
||||
d_offset,
|
||||
0.0f, 0.0f, offset,
|
||||
}, dryrun);
|
||||
}
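The push constants above reuse the binary-op layout: GGML_OP_ACC stores its view strides and start offset in dst->op_params as byte counts, which is why everything is divided by 4 (the size of a float32) before being handed to the shader. As a hedged sketch of how such a node is typically created on the graph side (the tensor names here are illustrative, not from the patch; ggml_acc and ggml_new_tensor_2d are existing ggml API):

// Illustration only: accumulate a 3x2 block of `b` into `a` starting at row 1.
// ggml_acc takes the destination view strides (nb1, nb2, nb3) and the start
// offset in *bytes*, matching the op_params that ggml_vk_acc unpacks above.
struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 8); // dst-shaped source
struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 3, 2); // block to accumulate
struct ggml_tensor * r = ggml_acc(ctx, a, b,
                                  a->nb[1],      // nb1: row stride of the view, in bytes
                                  a->nb[2],      // nb2
                                  a->nb[3],      // nb3 (unused by the Vulkan kernel)
                                  1 * a->nb[1]); // offset: skip one row, in bytes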
static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
    const uint32_t src0_type_size = ggml_type_size(src0->type);
    const uint32_t src1_type_size = ggml_type_size(src1->type);

@@ -5621,6 +5651,7 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
    case GGML_OP_REPEAT:
    case GGML_OP_GET_ROWS:
    case GGML_OP_ADD:
    case GGML_OP_ACC:
    case GGML_OP_MUL:
    case GGML_OP_DIV:
    case GGML_OP_CONCAT:

@@ -5668,6 +5699,10 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
    case GGML_OP_REPEAT:
        ggml_vk_repeat(ctx, compute_ctx, src0, node, dryrun);

        break;
    case GGML_OP_ACC:
        ggml_vk_acc(ctx, compute_ctx, src0, src1, node, dryrun);

        break;
    case GGML_OP_GET_ROWS:
        ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node, dryrun);

@@ -5808,6 +5843,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *

    switch (tensor->op) {
    case GGML_OP_ADD:
    case GGML_OP_ACC:
    case GGML_OP_GET_ROWS:
    case GGML_OP_MUL:
    case GGML_OP_DIV:

@@ -6539,6 +6575,7 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
    case GGML_OP_GROUP_NORM:
    case GGML_OP_RMS_NORM:
    case GGML_OP_ADD:
    case GGML_OP_ACC:
    case GGML_OP_MUL:
    case GGML_OP_DIV:
    case GGML_OP_CONCAT:

@@ -6995,6 +7032,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
        tensor_clone = ggml_repeat(ggml_ctx, src0_clone, src1_clone);
    } else if (tensor->op == GGML_OP_ADD) {
        tensor_clone = ggml_add(ggml_ctx, src0_clone, src1_clone);
    } else if (tensor->op == GGML_OP_ACC) {
        tensor_clone = ggml_acc(ggml_ctx, src0_clone, src1_clone, tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
    } else if (tensor->op == GGML_OP_NORM) {
        tensor_clone = ggml_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
    } else if (tensor->op == GGML_OP_GROUP_NORM) {
ggml/src/vulkan-shaders/acc.comp (new file, 24 lines)
@@ -0,0 +1,24 @@
#version 450

#include "types.comp"
#include "generic_binary_head.comp"

void main() {
    const uint idx = gl_GlobalInvocationID.x;
    if (idx >= p.ne) {
        return;
    }

    const uint offset = p.param3;
    const uint src1_i = idx - offset;
    const uint oz = src1_i / p.nb02;
    const uint oy = (src1_i - (oz * p.nb02)) / p.nb01;
    const uint ox = src1_i % p.nb01;

    if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) {
        data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) + FLOAT_TYPE(data_b[ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
    } else {
        data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]));
    }
}
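Each invocation handles one flat destination element: the shader subtracts the element offset, decomposes the remainder into (ox, oy, oz) coordinates using the destination's element strides (p.nb01, p.nb02), and only adds the src1 contribution for threads whose coordinates fall inside src1's extents. The same decomposition in host code, as a hedged sketch (names are illustrative, not from the patch):

// Illustrative mirror of the shader's index math: recover (ox, oy, oz) from a
// flat index given row stride nb01 and plane stride nb02 (both in elements).
struct Coord3 { unsigned ox, oy, oz; };

static Coord3 unflatten(unsigned src1_i, unsigned nb01, unsigned nb02) {
    Coord3 c;
    c.oz = src1_i / nb02;                 // which plane
    c.oy = (src1_i - c.oz * nb02) / nb01; // which row within the plane
    c.ox = src1_i % nb01;                 // which column within the row
    return c;
}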
@@ -368,6 +368,10 @@ void process_shaders(std::vector<std::future<void>>& tasks) {
        string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
    }));

    tasks.push_back(std::async(std::launch::async, [] {
        string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
    }));

    tasks.push_back(std::async(std::launch::async, [] {
        string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
    }));
@@ -130,6 +130,7 @@ class Keys:
        INNER_SIZE     = "{arch}.ssm.inner_size"
        STATE_SIZE     = "{arch}.ssm.state_size"
        TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
        DT_B_C_RMS     = "{arch}.ssm.dt_b_c_rms"

    class Tokenizer:
        MODEL = "tokenizer.ggml.model"

@@ -1380,6 +1381,7 @@ KEY_SSM_CONV_KERNEL = Keys.SSM.CONV_KERNEL
KEY_SSM_INNER_SIZE     = Keys.SSM.INNER_SIZE
KEY_SSM_STATE_SIZE     = Keys.SSM.STATE_SIZE
KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK
KEY_SSM_DT_B_C_RMS     = Keys.SSM.DT_B_C_RMS

# tokenization
KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL
@@ -730,6 +730,9 @@ class GGUFWriter:
    def add_ssm_time_step_rank(self, value: int) -> None:
        self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value)

    def add_ssm_dt_b_c_rms(self, value: bool) -> None:
        self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)

    def add_tokenizer_model(self, model: str) -> None:
        self.add_string(Keys.Tokenizer.MODEL, model)
@@ -321,6 +321,21 @@ private:

// TODO: there are a lot of common parts between spm and bpe tokenizers, should be refactored and reused

template<typename T, typename Container = std::vector<T>, typename Compare = std::less<typename Container::value_type>>
class llama_priority_queue : public std::priority_queue<T, Container, Compare> {
public:
    using std::priority_queue<T, Container, Compare>::priority_queue;

    T pop_move() {
        T item = std::move(this->c.front());
        std::pop_heap(this->c.begin(), this->c.end(), this->comp);
        this->c.pop_back();
        return item;
    }

    void pop() = delete;
};
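The point of pop_move() is to avoid the copy that the usual top() + pop() pair forces: top() returns a const reference, so the element cannot be moved from, which matters when T owns a std::string as llm_bigram_bpe does. Note that it moves from the front element before std::pop_heap shuffles it to the back, so the comparator must tolerate a moved-from value. A hedged, self-contained sketch of the pattern (movable_pq is an illustrative name, not from the patch):

// Minimal sketch of the pop_move pattern on a standalone queue (illustration only).
#include <algorithm>
#include <queue>
#include <string>
#include <vector>

template <typename T, typename Container = std::vector<T>, typename Compare = std::less<T>>
class movable_pq : public std::priority_queue<T, Container, Compare> {
public:
    T pop_move() {
        T item = std::move(this->c.front());                       // steal the top element
        std::pop_heap(this->c.begin(), this->c.end(), this->comp); // moved-from shell goes to the back
        this->c.pop_back();                                        // and is discarded
        return item;
    }
};

int main() {
    movable_pq<std::string> q;
    q.push("bigram-a");
    q.push("bigram-b");
    std::string top = q.pop_move(); // transfers the string's buffer instead of copying it
    (void) top;
    return 0;
}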
struct llm_bigram_bpe {
    struct comparator {
        bool operator()(const llm_bigram_bpe & l, const llm_bigram_bpe & r) const {

@@ -329,7 +344,7 @@ struct llm_bigram_bpe {
    };

    using queue_storage = std::vector<llm_bigram_bpe>;
-   using queue = std::priority_queue<llm_bigram_bpe, queue_storage, comparator>;
+   using queue = llama_priority_queue<llm_bigram_bpe, queue_storage, comparator>;
    llm_symbol::index left;
    llm_symbol::index right;
    std::string text;

@@ -520,8 +535,7 @@ struct llm_tokenizer_bpe {

    // build token(s)
    while (!work_queue.empty()) {
-       auto bigram = work_queue.top();
-       work_queue.pop();
+       auto bigram = work_queue.pop_move();

        auto & left_symbol = symbols[bigram.left];
        auto & right_symbol = symbols[bigram.right];
@@ -328,6 +328,7 @@ enum llm_kv {
    LLM_KV_SSM_CONV_KERNEL,
    LLM_KV_SSM_STATE_SIZE,
    LLM_KV_SSM_TIME_STEP_RANK,
    LLM_KV_SSM_DT_B_C_RMS,

    LLM_KV_TOKENIZER_MODEL,
    LLM_KV_TOKENIZER_PRE,

@@ -426,6 +427,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
    { LLM_KV_SSM_INNER_SIZE,     "%s.ssm.inner_size" },
    { LLM_KV_SSM_STATE_SIZE,     "%s.ssm.state_size" },
    { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" },
    { LLM_KV_SSM_DT_B_C_RMS,     "%s.ssm.dt_b_c_rms" },

    { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
    { LLM_KV_TOKENIZER_PRE,   "tokenizer.ggml.pre" },

@@ -2237,6 +2239,7 @@ struct llama_hparams {
    uint32_t ssm_d_inner = 0;
    uint32_t ssm_d_state = 0;
    uint32_t ssm_dt_rank = 0;
    bool ssm_dt_b_c_rms = false;

    float f_clamp_kqv      = 0.0f;
    float f_max_alibi_bias = 0.0f;

@@ -2286,6 +2289,7 @@ struct llama_hparams {
        if (this->ssm_d_inner != other.ssm_d_inner) return true;
        if (this->ssm_d_state != other.ssm_d_state) return true;
        if (this->ssm_dt_rank != other.ssm_dt_rank) return true;
        if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true;

        if (this->dec_start_token_id != other.dec_start_token_id) return true;

@@ -5060,6 +5064,7 @@ static void llm_load_hparams(
    ml.get_key(LLM_KV_SSM_INNER_SIZE,     hparams.ssm_d_inner);
    ml.get_key(LLM_KV_SSM_STATE_SIZE,     hparams.ssm_d_state);
    ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
    ml.get_key(LLM_KV_SSM_DT_B_C_RMS,     hparams.ssm_dt_b_c_rms, false);

    ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

@@ -5915,6 +5920,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
    LLAMA_LOG_INFO("%s: ssm_d_inner    = %u\n", __func__, hparams.ssm_d_inner);
    LLAMA_LOG_INFO("%s: ssm_d_state    = %u\n", __func__, hparams.ssm_d_state);
    LLAMA_LOG_INFO("%s: ssm_dt_rank    = %u\n", __func__, hparams.ssm_dt_rank);
    LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms);
}

LLAMA_LOG_INFO("%s: model type       = %s\n", __func__, llama_model_type_name(model.type));

@@ -12169,6 +12175,10 @@ struct llm_build_context {
    GGML_ASSERT(2 * d_model == d_inner);
    const int64_t d_state = hparams.ssm_d_state;
    const int64_t dt_rank = hparams.ssm_dt_rank;
    // Some Mamba variants (e.g. FalconMamba) apply layer norm on the B, C and Dt layers
    const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
    // use the same RMS norm as the final layer norm
    const float norm_rms_eps = hparams.f_norm_rms_eps;

    struct ggml_tensor * cur;
    struct ggml_tensor * inpL;

@@ -12249,6 +12259,13 @@ struct llm_build_context {
    struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*dt_rank);
    struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, n_tokens, x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));

    // Some Mamba variants (e.g. FalconMamba) apply RMS norm on the B, C and Dt layers
    if (ssm_dt_b_c_rms) {
        dt = ggml_rms_norm(ctx0, dt, norm_rms_eps);
        B  = ggml_rms_norm(ctx0, B,  norm_rms_eps);
        C  = ggml_rms_norm(ctx0, C,  norm_rms_eps);
    }

    // {dt_rank, d_inner} * {dt_rank, n_tokens} => {d_inner, n_tokens}
    dt = llm_build_lora_mm(lctx, ctx0, model.layers[il].ssm_dt, dt);
    dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
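For reference, ggml_rms_norm normalizes each row by its root mean square, with no mean subtraction and no learned weight (when a model has a norm weight, it is applied as a separate multiply). A hedged scalar sketch of what FalconMamba's extra normalization does to each dt/B/C row (names are illustrative):

// Illustrative RMS-norm reference over one row of n floats, matching
// y[i] = x[i] / sqrt(mean(x^2) + eps) as computed by ggml_rms_norm.
#include <cmath>
#include <cstddef>

static void rms_norm_row(const float * x, float * y, size_t n, float eps) {
    float sum_sq = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        sum_sq += x[i] * x[i];
    }
    const float scale = 1.0f / std::sqrt(sum_sq / n + eps);
    for (size_t i = 0; i < n; ++i) {
        y[i] = x[i] * scale;
    }
}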
@@ -16353,6 +16370,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
        case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break;
        default: throw std::runtime_error("\nUnsupported tensor size encountered\n");
    }
    if (tensor->ne[0] % ggml_blck_size(new_type) != 0) {
        new_type = GGML_TYPE_F16;
    }
    LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type));
    ++qs.n_fallback;
}

@@ -16698,8 +16718,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
    // do not quantize Mamba's small yet 2D weights
    // NOTE: can't use LLM_TN here because the layer number is not known
    quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
-   quantize &= name.find("ssm_x.weight") == std::string::npos;
-   quantize &= name.find("ssm_dt.weight") == std::string::npos;

    // do not quantize relative position bias (T5)
    quantize &= name.find("attn_rel_b.weight") == std::string::npos;
@@ -2145,6 +2145,13 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));

    // the sycl backend limits the task global_range to < MAX_INT
    // test cases for 2D im2col with large input W and H (occurs in stable-diffusion)
    // however, these cases need to allocate more memory, which can fail on some devices (Intel Arc A770, etc.)
    // these cases are verified to pass on an Intel(R) Data Center GPU Max 1100 (sycl backend) and an NV A30 (cuda backend)
    // test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {1024, 1024, 256, 1}, {3, 3, 256, 1}, 1, 1, 1, 1, 1, 1, true));
    // test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {1024, 1024, 256, 1}, {3, 3, 256, 1}, 1, 1, 1, 1, 1, 1, true));

    test_cases.emplace_back(new test_conv_transpose_1d());
    test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 3, 0, 1));
    test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 2, 0, 1));

@@ -2287,6 +2294,12 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64, 45, 128, { 8, 1}, {4, 1}));
    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 45,  64, { 8, 1}, {4, 1}));

    // the sycl backend limits the task global_range to < MAX_INT
    // test case for the f16-to-f32 type-convert kernel with a large k under an fp32 compute dtype (occurs in stable-diffusion)
    // however, this case needs to allocate more memory, which can fail on some devices (Intel Arc A770, etc.)
    // this case is verified to pass on an Intel(R) Data Center GPU Max 1100 (sycl backend) and an NV A30 (cuda backend)
    // test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 512, 262144, 9216, {1, 1}, {1, 1}));

    for (ggml_type type_a : base_types) {
        for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
            for (int n_mats : {4, 8}) {
@@ -503,7 +503,7 @@ static void test_special_chars() {
        "aaaaabcccc",
        "aaaabccc",
        "aaaabccccc",
-       "🔵🟠✅❌abc❌✅🟠🔵"
+       "🔵🟠✅❌abc❌✅🟠🔵",
+       "🔵🟠abc🟠🔵"
    }
);