Merge branch 'master' into compilade/batch-splits

commit 22504ec67e
182 changed files with 18526 additions and 149599 deletions

ggml/.gitignore (new file, vendored, 2 additions)
@@ -0,0 +1,2 @@
+src/ggml-vulkan-shaders.hpp
+src/ggml-vulkan-shaders.cpp
@@ -104,7 +104,7 @@ option(GGML_ACCELERATE "ggml: enable Accelerate framework"
 option(GGML_BLAS "ggml: use BLAS" ${GGML_BLAS_DEFAULT})
 set(GGML_BLAS_VENDOR ${GGML_BLAS_VENDOR_DEFAULT} CACHE STRING
     "ggml: BLAS library vendor")
-option(GGML_LLAMAFILE "ggml: use ggml SGEMM" OFF)
+option(GGML_LLAMAFILE "ggml: use LLAMAFILE" OFF)
 
 option(GGML_CUDA "ggml: use CUDA" OFF)
 option(GGML_CUDA_FORCE_DMMV "ggml: use dmmv instead of mmvq CUDA kernels" OFF)
|
@ -1,220 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import logging
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
from tempfile import gettempdir
|
||||
|
||||
logger = logging.getLogger("ggml-vk-generate-shaders")
|
||||
|
||||
GLSLC = "glslc"
|
||||
|
||||
type_names = [
|
||||
"f32",
|
||||
"f16",
|
||||
"q4_0",
|
||||
"q4_1",
|
||||
"q5_0",
|
||||
"q5_1",
|
||||
"q8_0",
|
||||
"q2_k",
|
||||
"q3_k",
|
||||
"q4_k",
|
||||
"q5_k",
|
||||
"q6_k",
|
||||
]
|
||||
|
||||
ASYNCIO_CONCURRENCY = 64
|
||||
|
||||
input_dir = "vulkan-shaders"
|
||||
output_dir = gettempdir()
|
||||
|
||||
lock = asyncio.Lock()
|
||||
shader_fnames = []
|
||||
|
||||
|
||||
async def string_to_spv(name, in_fname, defines, fp16=True):
|
||||
name = f"{name}{'_fp32' if not fp16 else ''}"
|
||||
out_fname = os.path.join(output_dir, f"{name}.spv")
|
||||
|
||||
in_path = os.path.join(input_dir, in_fname)
|
||||
|
||||
cmd = [GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", in_path, "-o", out_fname]
|
||||
|
||||
cmd.extend([f"-D{key}={value}" for key, value in defines.items()])
|
||||
|
||||
proc = await asyncio.create_subprocess_exec(*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)
|
||||
|
||||
stdout, stderr = await proc.communicate()
|
||||
|
||||
stdout = stdout.decode()
|
||||
error = stderr.decode()
|
||||
|
||||
if proc.returncode:
|
||||
cmd = " ".join(cmd)
|
||||
logger.error(f"cannot compile {name}\n\n{cmd}\n\n{error}")
|
||||
return
|
||||
|
||||
async with lock:
|
||||
shader_fnames.append((name, out_fname))
|
||||
|
||||
|
||||
def matmul_shaders(tasks, fp16, matmul_id):
|
||||
if fp16:
|
||||
load_vec = "8"
|
||||
aligned_b_type_f32 = "mat2x4"
|
||||
aligned_b_type_f16 = "f16mat2x4"
|
||||
else:
|
||||
load_vec = "4"
|
||||
aligned_b_type_f32 = "vec4"
|
||||
aligned_b_type_f16 = "f16vec4"
|
||||
|
||||
base_dict = {"FLOAT_TYPE": "float" if not fp16 else "float16_t"}
|
||||
shader_name = "matmul"
|
||||
|
||||
if matmul_id:
|
||||
base_dict["MUL_MAT_ID"] = "1"
|
||||
shader_name = "matmul_id"
|
||||
|
||||
if fp16:
|
||||
base_dict["FLOAT16"] = "1"
|
||||
|
||||
# Shaders with f16 B_TYPE
|
||||
tasks.append(string_to_spv(f"{shader_name}_f32_f16", "mul_mm.comp", base_dict | {"DATA_A_F32": "1", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16))
|
||||
tasks.append(string_to_spv(f"{shader_name}_f32_f16_aligned", "mul_mm.comp", base_dict | {"DATA_A_F32": "1", "LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "B_TYPE": aligned_b_type_f16, "D_TYPE": "float"}, fp16))
|
||||
|
||||
tasks.append(string_to_spv(f"{shader_name}_f16", "mul_mm.comp", base_dict | {"DATA_A_F16": "1", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16))
|
||||
tasks.append(string_to_spv(f"{shader_name}_f16_aligned", "mul_mm.comp", base_dict | {"DATA_A_F16": "1", "LOAD_VEC_A": load_vec, "LOAD_VEC_B": load_vec, "B_TYPE": aligned_b_type_f16, "D_TYPE": "float"}, fp16))
|
||||
|
||||
for tname in type_names:
|
||||
data_a_key = f"DATA_A_{tname.upper()}"
|
||||
load_vec_a = load_vec if tname in ("f32", "f16") else "2"
|
||||
tasks.append(string_to_spv(f"{shader_name}_{tname}_f32", "mul_mm.comp", base_dict | {data_a_key: "1", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
|
||||
tasks.append(string_to_spv(f"{shader_name}_{tname}_f32_aligned", "mul_mm.comp", base_dict | {data_a_key: "2", "LOAD_VEC_A": load_vec_a, "LOAD_VEC_B": load_vec, "B_TYPE": aligned_b_type_f32, "D_TYPE": "float"}, fp16))
|
||||
|
||||
|
||||
async def main():
|
||||
logger.info("ggml_vulkan: Generating and compiling shaders to SPIR-V")
|
||||
|
||||
tasks = []
|
||||
|
||||
for fp16 in (False, True):
|
||||
# MUL_MAT
|
||||
matmul_shaders(tasks, fp16, False)
|
||||
# MUL_MAT_ID
|
||||
matmul_shaders(tasks, fp16, True)
|
||||
|
||||
for tname in type_names:
|
||||
base_dict = {"FLOAT_TYPE": "float"}
|
||||
|
||||
# mul mat vec
|
||||
data_a_key = f"DATA_A_{tname.upper()}"
|
||||
shader = f"mul_mat_vec_{tname}.comp" if tname.endswith("_k") else "mul_mat_vec.comp"
|
||||
|
||||
tasks.append(string_to_spv(f"mul_mat_vec_{tname}_f32_f32", shader, base_dict | {data_a_key: "1", "B_TYPE": "float", "D_TYPE": "float"}))
|
||||
tasks.append(string_to_spv(f"mul_mat_vec_{tname}_f16_f32", shader, base_dict | {data_a_key: "1", "B_TYPE": "float16_t", "D_TYPE": "float"}))
|
||||
|
||||
tasks.append(string_to_spv(f"mul_mat_vec_id_{tname}_f32", shader, base_dict | {"MUL_MAT_ID": "1", data_a_key: "1", "B_TYPE": "float", "D_TYPE": "float"}))
|
||||
|
||||
# Dequant shaders
|
||||
if tname != "f16":
|
||||
tasks.append(string_to_spv(f"dequant_{tname}", f"dequant_{tname}.comp", base_dict | {data_a_key: "1", "D_TYPE": "float16_t"}))
|
||||
|
||||
# get_rows
|
||||
if not tname.endswith("_k"):
|
||||
shader = "get_rows.comp" if tname in ("f32", "f16") else "get_rows_quant.comp"
|
||||
|
||||
if tname == "f16":
|
||||
tasks.append(string_to_spv(f"get_rows_{tname}", shader, {data_a_key: "1", "B_TYPE": "int", "D_TYPE": "float16_t", "OPTIMIZATION_ERROR_WORKAROUND": "1"}))
|
||||
else:
|
||||
tasks.append(string_to_spv(f"get_rows_{tname}", shader, {data_a_key: "1", "B_TYPE": "int", "D_TYPE": "float16_t"}))
|
||||
tasks.append(string_to_spv(f"get_rows_{tname}_f32", shader, {data_a_key: "1", "B_TYPE": "int", "D_TYPE": "float"}))
|
||||
|
||||
tasks.append(string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}))
|
||||
tasks.append(string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}))
|
||||
|
||||
# Norms
|
||||
tasks.append(string_to_spv("norm_f32", "norm.comp", base_dict | {"A_TYPE": "float", "D_TYPE": "float"}))
|
||||
tasks.append(string_to_spv("rms_norm_f32", "rms_norm.comp", base_dict | {"A_TYPE": "float", "D_TYPE": "float"}))
|
||||
|
||||
tasks.append(string_to_spv("cpy_f32_f32", "copy.comp", {"A_TYPE": "float", "D_TYPE": "float"}))
|
||||
tasks.append(string_to_spv("cpy_f32_f16", "copy.comp", {"A_TYPE": "float", "D_TYPE": "float16_t"}))
|
||||
tasks.append(string_to_spv("cpy_f16_f16", "copy.comp", {"A_TYPE": "float16_t", "D_TYPE": "float16_t", "OPTIMIZATION_ERROR_WORKAROUND": "1"}))
|
||||
|
||||
tasks.append(string_to_spv("add_f32", "add.comp", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
|
||||
|
||||
tasks.append(string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {}))
|
||||
|
||||
tasks.append(string_to_spv("mul_f32", "mul.comp", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
|
||||
|
||||
tasks.append(string_to_spv("div_f32", "div.comp", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
|
||||
|
||||
tasks.append(string_to_spv("scale_f32", "scale.comp", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
|
||||
|
||||
tasks.append(string_to_spv("sqr_f32", "square.comp", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
|
||||
|
||||
tasks.append(string_to_spv("clamp_f32", "clamp.comp", {"A_TYPE": "float", "D_TYPE": "float", "FLOAT_TYPE": "float"}))
|
||||
|
||||
tasks.append(string_to_spv("gelu_f32", "gelu.comp", {"A_TYPE": "float", "D_TYPE": "float"}))
|
||||
tasks.append(string_to_spv("silu_f32", "silu.comp", {"A_TYPE": "float", "D_TYPE": "float"}))
|
||||
tasks.append(string_to_spv("relu_f32", "relu.comp", {"A_TYPE": "float", "D_TYPE": "float"}))
|
||||
|
||||
tasks.append(string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {"A_TYPE": "float", "D_TYPE": "float"}))
|
||||
|
||||
tasks.append(string_to_spv("soft_max_f32", "soft_max.comp", base_dict | {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}))
|
||||
tasks.append(string_to_spv("soft_max_f32_f16", "soft_max.comp", base_dict | {"A_TYPE": "float", "B_TYPE": "float16_t", "D_TYPE": "float"}))
|
||||
|
||||
tasks.append(string_to_spv("rope_norm_f32", "rope_norm.comp", {"A_TYPE": "float", "D_TYPE": "float"}))
|
||||
tasks.append(string_to_spv("rope_norm_f16", "rope_norm.comp", {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
|
||||
|
||||
tasks.append(string_to_spv("rope_neox_f32", "rope_neox.comp", {"A_TYPE": "float", "D_TYPE": "float"}))
|
||||
tasks.append(string_to_spv("rope_neox_f16", "rope_neox.comp", {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
|
||||
|
||||
tasks.append(string_to_spv("argsort_f32", "argsort.comp", {"A_TYPE": "float"}))
|
||||
|
||||
tasks.append(string_to_spv("sum_rows_f32", "sum_rows.comp", base_dict | {"A_TYPE": "float", "D_TYPE": "float"}))
|
||||
|
||||
# Helper to decorate tasks with semaphore acquisition.
|
||||
async def withSemaphore(sem, task):
|
||||
async with sem:
|
||||
return await task
|
||||
|
||||
# Run tasks concurrently guarded by a concurrency limit.
|
||||
sem = asyncio.Semaphore(ASYNCIO_CONCURRENCY)
|
||||
await asyncio.gather(*(withSemaphore(sem, task) for task in tasks))
|
||||
|
||||
with open("ggml-vulkan-shaders.hpp", "w") as f:
|
||||
f.write("#include <cstdint>\n\n")
|
||||
for name, path in sorted(shader_fnames):
|
||||
|
||||
with open(path, "rb") as spv:
|
||||
counter = 0
|
||||
newline_counter = 0
|
||||
f.write(f"unsigned char {name}_data[] = {{\n")
|
||||
for val in spv.read():
|
||||
f.write(f"0x{val:02x},")
|
||||
newline_counter += 1
|
||||
counter += 1
|
||||
if newline_counter >= 12:
|
||||
newline_counter = 0
|
||||
f.write("\n")
|
||||
f.write("\n};\n")
|
||||
f.write(f"const uint64_t {name}_len = {counter};\n\n")
|
||||
os.remove(path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="GGML Vulkan Shader Generator")
|
||||
|
||||
parser.add_argument("--glslc", help="Path to glslc")
|
||||
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
|
||||
|
||||
if args.glslc:
|
||||
GLSLC = args.glslc
|
||||
|
||||
asyncio.run(main())
|
|
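Note: the deleted shader-generator script above compiled each `.comp` file with glslc and embedded the resulting SPIR-V into a generated header as `unsigned char <name>_data[]` plus a `<name>_len` constant; that role now moves to the `vulkan-shaders-gen` CMake target further down in this commit. A minimal sketch of how such an embedded blob can be turned into a Vulkan shader module; the symbol names and the 4-byte payload here are illustrative placeholders, not the real generated arrays:

```cpp
#include <vulkan/vulkan.h>
#include <cstdint>
#include <cstring>
#include <vector>

// Stand-ins for one generated pair of symbols from ggml-vulkan-shaders.hpp
// (e.g. <shader>_data / <shader>_len). The bytes are only the SPIR-V magic
// number, not a real shader.
static const unsigned char example_shader_data[] = {0x03, 0x02, 0x23, 0x07};
static const uint64_t      example_shader_len    = sizeof(example_shader_data);

// Turn an embedded SPIR-V byte array into a VkShaderModule. The blob is
// copied into a uint32_t vector because pCode must point at 32-bit words.
static VkShaderModule make_shader_module(VkDevice device) {
    std::vector<uint32_t> words((example_shader_len + 3) / 4, 0);
    std::memcpy(words.data(), example_shader_data, example_shader_len);

    VkShaderModuleCreateInfo info = {};
    info.sType    = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
    info.codeSize = (size_t) example_shader_len; // size in bytes
    info.pCode    = words.data();

    VkShaderModule module = VK_NULL_HANDLE;
    if (vkCreateShaderModule(device, &info, nullptr, &module) != VK_SUCCESS) {
        return VK_NULL_HANDLE;
    }
    return module;
}
```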
@@ -383,6 +383,9 @@ extern "C" {
         GGML_TYPE_F64     = 28,
         GGML_TYPE_IQ1_M   = 29,
         GGML_TYPE_BF16    = 30,
+        GGML_TYPE_Q4_0_4_4 = 31,
+        GGML_TYPE_Q4_0_4_8 = 32,
+        GGML_TYPE_Q4_0_8_8 = 33,
         GGML_TYPE_COUNT,
     };
 
@@ -424,6 +427,9 @@ extern "C" {
         GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ1_M  = 23, // except 1d tensors
         GGML_FTYPE_MOSTLY_BF16   = 24, // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q4_0_4_4 = 25, // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q4_0_4_8 = 26, // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q4_0_8_8 = 27, // except 1d tensors
     };
 
     // available tensor operations:
@@ -708,9 +714,9 @@ extern "C" {
     GGML_API GGML_CALL size_t ggml_nbytes    (const struct ggml_tensor * tensor);
     GGML_API           size_t ggml_nbytes_pad(const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN
 
-    GGML_API GGML_CALL int     ggml_blck_size(enum ggml_type type);
-    GGML_API GGML_CALL size_t  ggml_type_size(enum ggml_type type);             // size in bytes for all elements in a block
-    GGML_API GGML_CALL size_t  ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row
+    GGML_API GGML_CALL int64_t ggml_blck_size(enum ggml_type type);
+    GGML_API GGML_CALL size_t  ggml_type_size(enum ggml_type type);             // size in bytes for all elements in a block
+    GGML_API GGML_CALL size_t  ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row
 
     GGML_DEPRECATED(
     GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as float
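With `ggml_blck_size()` now returning `int64_t`, row-size arithmetic stays in 64-bit throughout. A minimal sketch of how the three accessors relate, mirroring what `ggml_row_size()` computes; the standalone function and the Q4_0 constants used here are illustrative stand-ins:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Illustrative stand-ins for ggml_blck_size()/ggml_type_size():
// Q4_0 stores 32 elements per block in 18 bytes (2-byte delta + 16 bytes of nibbles).
static const int64_t blck_size = 32;
static const size_t  type_size = 18;

// Bytes needed for one row of ne elements of a block-quantized type.
// ne must be a whole number of blocks, as in ggml_row_size(type, ne).
static size_t row_size(int64_t ne) {
    assert(ne % blck_size == 0);
    return type_size * (size_t)(ne / blck_size);
}

int main() {
    // A 4096-wide row of Q4_0 data occupies 4096/32 * 18 = 2304 bytes.
    return row_size(4096) == 2304 ? 0 : 1;
}
```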
@@ -2401,20 +2407,31 @@ extern "C" {
 #endif
     typedef void (*ggml_to_float_t)  (const void  * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
     typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void  * GGML_RESTRICT y, int64_t k);
-    typedef void (*ggml_vec_dot_t)   (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx,
-                                      const void * GGML_RESTRICT y, size_t by, int nrc);
+    typedef void (*ggml_from_float_to_mat_t)
+                                     (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nr, int64_t k, int64_t bs);
+    typedef void (*ggml_vec_dot_t)   (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx,
+                                      const void * GGML_RESTRICT y, size_t by, int nrc);
+    typedef void (*ggml_gemv_t)      (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x,
+                                      const void * GGML_RESTRICT y, int nr, int nc);
+    typedef void (*ggml_gemm_t)      (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x,
+                                      const void * GGML_RESTRICT y, int nr, int nc);
 
     typedef struct {
-        const char      * type_name;
-        int               blck_size;
-        size_t            type_size;
-        bool              is_quantized;
-        ggml_to_float_t   to_float;
-        ggml_from_float_t from_float;
-        ggml_from_float_t from_float_reference;
-        ggml_vec_dot_t    vec_dot;
-        enum ggml_type    vec_dot_type;
-        int64_t           nrows; // number of rows to process simultaneously;
+        const char             * type_name;
+        int64_t                  blck_size;
+        int64_t                  blck_size_interleave; // interleave elements in blocks
+        size_t                   type_size;
+        bool                     is_quantized;
+        ggml_to_float_t          to_float;
+        ggml_from_float_t        from_float;
+        ggml_from_float_t        from_float_ref;
+        ggml_from_float_to_mat_t from_float_to_mat;
+        ggml_vec_dot_t           vec_dot;
+        enum ggml_type           vec_dot_type;
+        int64_t                  nrows; // number of rows to process simultaneously
+        int64_t                  ncols; // number of columns to process simultaneously
+        ggml_gemv_t              gemv;
+        ggml_gemm_t              gemm;
     } ggml_type_traits_t;
 
     GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type);
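The extended `ggml_type_traits_t` lets the CPU backend route a matrix multiplication either through the classic per-output `vec_dot`, or, when a type provides them, through the new `gemv`/`gemm` kernels together with `from_float_to_mat` for quantizing several source rows at once. A rough sketch of that decision with heavily simplified signatures; the dispatch function, field subset, and strides here are illustrative, not the library's actual code path:

```cpp
#include <cstddef>
#include <cstdint>

// Simplified stand-ins for the function-pointer types declared above.
typedef void (*gemm_fn)   (int n, float * s, const void * x, const void * y, int nr, int nc);
typedef void (*vec_dot_fn)(int n, float * s, const void * x, const void * y);

// Cut-down view of ggml_type_traits_t with only the fields this sketch uses.
struct type_traits_view {
    int64_t    ncols;   // columns handled per gemm call; > 1 only for interleaved types
    gemm_fn    gemm;    // optional tiled kernel (e.g. ggml_gemm_q4_0_4x4_q8_0), may be null
    vec_dot_fn vec_dot; // classic one-output-per-call kernel
};

// Hypothetical dispatch: use the tiled GEMM when the type provides one and the
// batch is wide enough, otherwise fall back to per-element dot products.
// dst is column-major (nr rows, nc columns); strides are in bytes.
static void mul_mat_tile(const type_traits_view & tt, int n, float * dst,
                         const char * x, size_t x_row_bytes,
                         const char * y, size_t y_col_bytes, int nr, int nc) {
    if (tt.gemm != nullptr && nc >= tt.ncols) {
        tt.gemm(n, dst, x, y, nr, nc); // one call fills the whole nr x nc tile
        return;
    }
    for (int c = 0; c < nc; ++c) {
        for (int r = 0; r < nr; ++r) {
            tt.vec_dot(n, dst + (size_t) c*nr + r,
                       x + (size_t) r*x_row_bytes,
                       y + (size_t) c*y_col_bytes);
        }
    }
}
```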
@@ -238,12 +238,12 @@ if (GGML_BLAS)
     endif()
 
 if (GGML_LLAMAFILE)
-    message(STATUS "Using ggml SGEMM")
+    message(STATUS "Using llamafile")
 
     add_compile_definitions(GGML_USE_LLAMAFILE)
 
-    set(GGML_HEADERS_LLAMAFILE sgemm.h)
-    set(GGML_SOURCES_LLAMAFILE sgemm.cpp)
+    set(GGML_HEADERS_LLAMAFILE llamafile/sgemm.h)
+    set(GGML_SOURCES_LLAMAFILE llamafile/sgemm.cpp)
 endif()
 
 if (GGML_CUDA)
@@ -440,6 +440,10 @@ if (GGML_HIPBLAS)
         add_compile_definitions(GGML_CUDA_FORCE_MMQ)
     endif()
 
+    if (GGML_CUDA_FORCE_CUBLAS)
+        add_compile_definitions(GGML_CUDA_FORCE_CUBLAS)
+    endif()
+
     if (GGML_CUDA_NO_PEER_COPY)
         add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
     endif()
@@ -490,7 +494,7 @@ if (GGML_SYCL)
         set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=nvptx64-nvidia-cuda")
         add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
     else()
-        add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
+        add_compile_definitions(GGML_SYCL_WARP_SIZE=16)
     endif()
 
     file(GLOB GGML_HEADERS_SYCL "ggml-sycl/*.hpp")
@@ -527,14 +531,11 @@ if (GGML_RPC)
 endif()
 
 if (GGML_VULKAN)
-    find_package(Vulkan)
+    find_package(Vulkan COMPONENTS glslc REQUIRED)
 
     if (Vulkan_FOUND)
         message(STATUS "Vulkan found")
 
-        set(GGML_HEADERS_VULKAN ../include/ggml-vulkan.h)
-        set(GGML_SOURCES_VULKAN ggml-vulkan.cpp)
-
         list(APPEND GGML_CDEF_PUBLIC GGML_USE_VULKAN)
 
         # Workaround to the "can't dereference invalidated vector iterator" bug in clang-cl debug build
@@ -563,7 +564,37 @@ if (GGML_VULKAN)
             add_compile_definitions(GGML_VULKAN_RUN_TESTS)
         endif()
 
-        set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} Vulkan::Vulkan)
+        add_subdirectory(vulkan-shaders)
+
+        set (_ggml_vk_genshaders_cmd vulkan-shaders-gen)
+        set (_ggml_vk_header     ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp)
+        set (_ggml_vk_source     ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp)
+        set (_ggml_vk_input_dir  ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders)
+        set (_ggml_vk_output_dir ${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv)
+
+        file(GLOB _ggml_vk_shader_deps "${_ggml_vk_input_dir}/*.comp")
+
+        add_custom_command(
+            OUTPUT ${_ggml_vk_header}
+                   ${_ggml_vk_source}
+
+            COMMAND ${_ggml_vk_genshaders_cmd}
+                --glslc      ${Vulkan_GLSLC_EXECUTABLE}
+                --input-dir  ${_ggml_vk_input_dir}
+                --output-dir ${_ggml_vk_output_dir}
+                --target-hpp ${_ggml_vk_header}
+                --target-cpp ${_ggml_vk_source}
+                --no-clean
+
+            DEPENDS ${_ggml_vk_shader_deps}
+            COMMENT "Generate vulkan shaders"
+        )
+
+        set(GGML_HEADERS_VULKAN ${CMAKE_CURRENT_SOURCE_DIR}/../include/ggml-vulkan.h ${_ggml_vk_header})
+        set(GGML_SOURCES_VULKAN ggml-vulkan.cpp ${_ggml_vk_source})
+
+        set(GGML_EXTRA_LIBS     ${GGML_EXTRA_LIBS} Vulkan::Vulkan)
+        set(GGML_EXTRA_INCLUDES ${GGML_EXTRA_INCLUDES} ${CMAKE_CURRENT_BINARY_DIR})
     else()
         message(WARNING "Vulkan not found")
     endif()
@@ -1153,6 +1184,7 @@ add_library(ggml
             ${GGML_SOURCES_ROCM}      ${GGML_HEADERS_ROCM}
             ${GGML_SOURCES_BLAS}      ${GGML_HEADERS_BLAS}
             ${GGML_SOURCES_LLAMAFILE} ${GGML_HEADERS_LLAMAFILE}
+            ggml-aarch64.c ggml-aarch64.h
             )
 
 if (EMSCRIPTEN)
@@ -1175,4 +1207,5 @@ endif()
 
 if (BUILD_SHARED_LIBS)
     set_target_properties(ggml PROPERTIES POSITION_INDEPENDENT_CODE ON)
+    target_compile_definitions(ggml PRIVATE GGML_SHARED GGML_BUILD)
 endif()
ggml/src/ggml-aarch64.c (new file, 2193 additions; diff suppressed because it is too large)
ggml/src/ggml-aarch64.h (new file, 39 additions)
@@ -0,0 +1,39 @@
+// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd.
+#pragma once
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+
+#include "ggml.h"
+
+// GGML internal header
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// Quantization
+void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
+void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nrows, int64_t n_per_row, int64_t blck_size_interleave);
+
+// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
+size_t quantize_q4_0_4x4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q4_0_4x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+size_t quantize_q4_0_8x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+
+// GEMV
+void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+
+// GEMM
+void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+
+#ifdef __cplusplus
+}
+#endif
@@ -394,7 +394,7 @@ void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event)
 
 // backend registry
 
-#define GGML_REG_MAX_BACKENDS 16
+#define GGML_REG_MAX_BACKENDS 64
 
 struct ggml_backend_reg {
     char name[128];
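For context, the table this constant sizes is a fixed-capacity static array, so the bump from 16 to 64 simply reserves more registration slots. A toy sketch of the same pattern; every field besides `name` and the helper itself are illustrative, not the backend's real code:

```cpp
#include <cstring>

#define REG_MAX_BACKENDS 64  // was 16 before this change

// Minimal fixed-capacity registry in the same spirit as ggml_backend_reg.
struct backend_entry {
    char   name[128];
    void * user_data;  // illustrative payload
};

static backend_entry g_registry[REG_MAX_BACKENDS];
static int           g_registry_count = 0;

// Returns the slot index, or -1 once all 64 slots are taken.
static int register_backend(const char * name, void * user_data) {
    if (g_registry_count >= REG_MAX_BACKENDS) {
        return -1;
    }
    backend_entry & e = g_registry[g_registry_count];
    std::strncpy(e.name, name, sizeof(e.name) - 1);
    e.name[sizeof(e.name) - 1] = '\0';
    e.user_data = user_data;
    return g_registry_count++;
}
```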
@@ -8,11 +8,12 @@
 #  include <Accelerate/Accelerate.h>
 #elif defined(GGML_BLAS_USE_MKL)
 #  include <mkl.h>
+#elif defined(GGML_BLAS_USE_BLIS)
+#  include <blis.h>
+#elif defined(GGML_BLAS_USE_NVPL)
+#  include <nvpl_blas.h>
 #else
 #  include <cblas.h>
-#  ifdef BLIS_ENABLE_CBLAS
-#    include <blis.h>
-#  endif
 #endif
 
 struct ggml_backend_blas_context {
@@ -140,10 +141,14 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg
     openblas_set_num_threads(ctx->n_threads);
 #endif
 
-#if defined(BLIS_ENABLE_CBLAS)
+#if defined(GGML_BLAS_USE_BLIS)
     bli_thread_set_num_threads(ctx->n_threads);
 #endif
 
+#if defined(GGML_BLAS_USE_NVPL)
+    nvpl_blas_set_num_threads(ctx->n_threads);
+#endif
+
     for (int64_t i13 = 0; i13 < ne13; i13++) {
         for (int64_t i12 = 0; i12 < ne12; i12++) {
             const int64_t i03 = i13/r3;
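Collected in one place, the per-vendor thread plumbing above amounts to a small helper like the following sketch. The three set-threads calls are the ones the diff itself uses; the wrapper function and the OpenBLAS guard macro (`OPENBLAS_VERSION`) are assumptions of this sketch, and the snippet presumes the same vendor headers the backend already includes:

```cpp
// Hedged consolidation of the thread setup done before the sgemm loop above.
static void blas_set_num_threads(int n_threads) {
#if defined(GGML_BLAS_USE_BLIS)
    bli_thread_set_num_threads(n_threads);
#elif defined(GGML_BLAS_USE_NVPL)
    nvpl_blas_set_num_threads(n_threads);
#elif defined(OPENBLAS_VERSION)
    openblas_set_num_threads(n_threads);
#else
    (void) n_threads; // plain CBLAS exposes no portable thread-count control
#endif
}
```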
@@ -199,6 +199,30 @@
 } block_q8_1;
 static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_half) + QK8_1, "wrong q8_1 block size/padding");
 
+typedef struct {
+    ggml_half d[4];        // deltas for 4 q4_0 blocks
+    uint8_t qs[QK4_0 * 2]; // nibbles / quants for 4 q4_0 blocks
+} block_q4_0x4;
+static_assert(sizeof(block_q4_0x4) == 4 * sizeof(ggml_half) + QK4_0 * 2, "wrong q4_0x4 block size/padding");
+
+typedef struct {
+    ggml_half d[8];        // deltas for 8 q4_0 blocks
+    uint8_t qs[QK4_0 * 4]; // nibbles / quants for 8 q4_0 blocks
+} block_q4_0x8;
+static_assert(sizeof(block_q4_0x8) == 8 * sizeof(ggml_half) + QK4_0 * 4, "wrong q4_0x8 block size/padding");
+
+typedef struct {
+    ggml_half d[4];        // deltas for 4 q8_0 blocks
+    int8_t qs[QK8_0 * 4];  // quants for 4 q8_0 blocks
+} block_q8_0x4;
+static_assert(sizeof(block_q8_0x4) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong q8_0x4 block size/padding");
+
+typedef struct {
+    ggml_half d[8];        // deltas for 8 q8_0 blocks
+    int8_t qs[QK8_0 * 8];  // quants for 8 q8_0 blocks
+} block_q8_0x8;
+static_assert(sizeof(block_q8_0x8) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong q8_0x8 block size/padding");
+
 //
 // Super-block quantization structures
 //
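These `x4`/`x8` containers exist so that four (or eight) rows' worth of q4_0/q8_0 data can be streamed through one pointer in the new aarch64 GEMM/GEMV paths: the grouped blocks' deltas sit together, followed by their quant bytes interleaved in small chunks (the `blck_size_interleave` field added to the type traits). A hedged sketch of packing four `block_q4_0` into one `block_q4_0x4`, assuming a 4-byte interleave chunk; the packing routine is illustrative, the real helper lives in ggml-aarch64.c and picks the chunk size per variant:

```cpp
#include <cstdint>
#include <cstring>

typedef uint16_t ggml_half; // fp16 bits, as in ggml-common.h
#define QK4_0 32

// Plain q4_0 block: one fp16 delta + 32 4-bit quants packed into 16 bytes.
typedef struct {
    ggml_half d;
    uint8_t   qs[QK4_0 / 2];
} block_q4_0;

// Four q4_0 blocks re-packed for row-interleaved GEMM (from the diff above).
typedef struct {
    ggml_half d[4];
    uint8_t   qs[QK4_0 * 2];
} block_q4_0x4;

// Hypothetical packing routine: gather the four deltas, then interleave the
// quant bytes in `chunk`-byte slices, cycling over the four source blocks.
static block_q4_0x4 make_q4_0x4(const block_q4_0 src[4], int chunk /* e.g. 4 */) {
    block_q4_0x4 out;
    for (int i = 0; i < 4; ++i) {
        out.d[i] = src[i].d;
    }
    const int slices = (QK4_0 / 2) / chunk; // 16 / 4 = 4 slices per block
    for (int s = 0; s < slices; ++s) {
        for (int i = 0; i < 4; ++i) {
            std::memcpy(out.qs + (s*4 + i)*chunk, src[i].qs + s*chunk, chunk);
        }
    }
    return out;
}
```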
@@ -29,6 +29,7 @@
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
+#include "ggml-cuda/conv-transpose-1d.cuh"
 
 #include <algorithm>
 #include <array>
@@ -1875,7 +1876,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
 
     bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->ne[1] == 1;
+        && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src0->ne[0] >= GGML_CUDA_DMMV_X*2
+        && src1->ne[1] == 1;
     bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
         && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
@@ -2261,6 +2263,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_IM2COL:
             ggml_cuda_op_im2col(ctx, dst);
             break;
+        case GGML_OP_CONV_TRANSPOSE_1D:
+            ggml_cuda_op_conv_transpose_1d(ctx,dst);
+            break;
         case GGML_OP_POOL_2D:
             ggml_cuda_op_pool2d(ctx, dst);
             break;
@@ -2804,6 +2809,15 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 ggml_type src0_type = op->src[0]->type;
                 return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
             } break;
+        case GGML_OP_CONV_TRANSPOSE_1D:
+            {
+                ggml_type src0_type = op->src[0]->type;
+                ggml_type src1_type = op->src[1]->type;
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
+                return false;
+            } break;
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
@@ -104,7 +104,7 @@
 #define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
 #define cudaStream_t hipStream_t
 #define cudaSuccess hipSuccess
-#define __trap abort
+#define __trap() do { abort(); __builtin_unreachable(); } while(0)
 #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
 #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
 #define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
ggml/src/ggml-cuda/conv-transpose-1d.cu (new file, 87 additions)
@ -0,0 +1,87 @@
|
|||
#include "conv-transpose-1d.cuh"
|
||||
|
||||
static __global__ void conv_transpose_1d_kernel(
|
||||
const int s0, const int p0, const int d0, const int output_size,
|
||||
const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
|
||||
const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3,
|
||||
const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
|
||||
const float * src0, const float * src1, float * dst) {
|
||||
int global_index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (global_index >= output_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
int out_index = global_index / dst_ne0;
|
||||
|
||||
float accumulator = 0;
|
||||
|
||||
for (int c = 0; c < src0_ne2; c++) {
|
||||
int idx = global_index % dst_ne0;
|
||||
|
||||
int kernel_offset = (src0_ne0 * src0_ne1 * c) + (out_index * src0_ne0);
|
||||
int input_offset = src1_ne0 * c;
|
||||
|
||||
for (int i = 0; i < src1_ne0; i++) {
|
||||
if (!(idx >= i*s0 && idx < i*s0 + src0_ne0)) {
|
||||
continue;
|
||||
}
|
||||
int weight_idx = idx - i*s0;
|
||||
|
||||
float kernel_weight = src0[kernel_offset + weight_idx];
|
||||
float input_value = src1[input_offset+i];
|
||||
|
||||
accumulator += kernel_weight * input_value;
|
||||
}
|
||||
}
|
||||
dst[global_index] = accumulator;
|
||||
}
|
||||
|
||||
static void conv_transpose_1d_f32_f32_cuda(
|
||||
const int s0, const int p0, const int d0, const int output_size,
|
||||
const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
|
||||
const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3,
|
||||
const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
|
||||
const float * src0, const float * src1, float * dst,
|
||||
cudaStream_t stream) {
|
||||
|
||||
const int num_blocks = (output_size + CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE;
|
||||
conv_transpose_1d_kernel<<<num_blocks,CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE, 0, stream>>>(
|
||||
s0,p0,d0,output_size,
|
||||
src0_ne0, src0_ne1, src0_ne2, src0_ne3,
|
||||
src1_ne0, src1_ne1, src1_ne2, src1_ne3,
|
||||
dst_ne0, dst_ne1, dst_ne2, dst_ne3,
|
||||
src0,src1, dst);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *)src0->data;
|
||||
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
const float * src1_d = (const float *)src1->data;
|
||||
|
||||
float * dst_d = (float *)dst->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
|
||||
const int32_t * opts = (const int32_t *)dst->op_params;
|
||||
|
||||
const int s0 = opts[0];
|
||||
const int p0 = 0;//opts[3];
|
||||
const int d0 = 1;//opts[4];
|
||||
|
||||
const int64_t kernel_size = ggml_nelements(src0);
|
||||
const int64_t input_size = ggml_nelements(src1);
|
||||
const int64_t output_size = ggml_nelements(dst);
|
||||
|
||||
conv_transpose_1d_f32_f32_cuda(s0, p0, d0, output_size,
|
||||
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
|
||||
src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
|
||||
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
|
||||
src0_d, src1_d, dst_d, stream);
|
||||
}
|
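The new kernel assigns one thread per output element and, for that element, accumulates over every (input position, input channel) pair whose stride-s0 placement of the kernel window covers it; padding and dilation are hard-coded to 0 and 1 for now, as the commented-out op_params reads show. A scalar CPU reference of the same computation for f32 data, with shapes named after the kernel's ne0/ne1/ne2 arguments (this reference function is an editorial sketch, not part of the commit):

```cpp
#include <vector>

// Scalar reference for the 1D transposed convolution computed by the kernel
// above (stride s0, no padding, no dilation).
//   kernel:  [K  = src0_ne0] x [OC = src0_ne1] x [IC = src0_ne2]
//   input:   [L  = src1_ne0] x [IC]
//   output:  [Lo = (L - 1)*s0 + K] x [OC]
static std::vector<float> conv_transpose_1d_ref(
        const std::vector<float> & kernel, const std::vector<float> & input,
        int K, int OC, int IC, int L, int s0) {
    const int Lo = (L - 1)*s0 + K;
    std::vector<float> out((size_t) OC * Lo, 0.0f);

    for (int oc = 0; oc < OC; ++oc) {
        for (int o = 0; o < Lo; ++o) {      // one output element, like one CUDA thread
            float acc = 0.0f;
            for (int c = 0; c < IC; ++c) {
                for (int i = 0; i < L; ++i) {
                    const int w = o - i*s0; // position inside the kernel window
                    if (w < 0 || w >= K) {
                        continue;
                    }
                    acc += kernel[(size_t) c*OC*K + (size_t) oc*K + w]
                         * input[(size_t) c*L + i];
                }
            }
            out[(size_t) oc*Lo + o] = acc;
        }
    }
    return out;
}
```

On the graph side the op is still built with ggml_conv_transpose_1d(ctx, kernel, input, s0, p0, d0); per the supports_op change earlier in this diff, the CUDA backend only accepts the f32/f32 case.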
ggml/src/ggml-cuda/conv-transpose-1d.cuh (new file, 5 additions)
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE 256
+
+void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
@ -68,7 +68,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
|
|||
const int iqs4 = k_KQ % QI4_0;
|
||||
const int shift = k_KQ & (QI8_1/2);
|
||||
|
||||
const int v = (get_int_from_uint8(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
|
||||
const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
|
||||
const int u = Q_q8[k_KQ_0/WARP_SIZE];
|
||||
|
||||
const int sumi = ggml_cuda_dp4a(v, u, 0);
|
||||
|
@ -108,7 +108,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
|
|||
const int iqs4 = k_KQ % QI4_1;
|
||||
const int shift = k_KQ & (QI8_1/2);
|
||||
|
||||
const int v = (get_int_from_uint8_aligned(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
|
||||
const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
|
||||
const int u = Q_q8[k_KQ_0/WARP_SIZE];
|
||||
|
||||
const int sumi = ggml_cuda_dp4a(v, u, 0);
|
||||
|
@ -153,8 +153,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
|
|||
const int iqs8 = k_KQ % QI8_1;
|
||||
const int shift = k_KQ & (QI8_1/2);
|
||||
|
||||
int v = (get_int_from_uint8(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
|
||||
const int vh = get_int_from_uint8(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0);
|
||||
int v = (get_int_b2(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
|
||||
const int vh = get_int_b2(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0);
|
||||
v |= (vh << 4) & 0x00000010; // 0 -> 4
|
||||
v |= (vh << 11) & 0x00001000; // 1 -> 12
|
||||
v |= (vh << 18) & 0x00100000; // 2 -> 20
|
||||
|
@ -200,8 +200,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
|
|||
const int iqs8 = k_KQ % QI8_1;
|
||||
const int shift = k_KQ & (QI8_1/2);
|
||||
|
||||
int v = (get_int_from_uint8(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
|
||||
const int vh = get_int_from_uint8(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1);
|
||||
int v = (get_int_b2(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
|
||||
const int vh = get_int_b2(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1);
|
||||
v |= (vh << 4) & 0x00000010; // 0 -> 4
|
||||
v |= (vh << 11) & 0x00001000; // 1 -> 12
|
||||
v |= (vh << 18) & 0x00100000; // 2 -> 20
|
||||
|
@ -249,7 +249,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
|
|||
const int ib = k_KQ / QI8_0;
|
||||
const int iqs = k_KQ % QI8_0;
|
||||
|
||||
const int v = get_int_from_int8(K_q8_0[ib].qs, iqs);
|
||||
const int v = get_int_b2(K_q8_0[ib].qs, iqs);
|
||||
|
||||
T Q_d;
|
||||
if (std::is_same<T, half>::value) {
|
||||
|
@ -408,7 +408,7 @@ static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__
|
|||
|
||||
const T d = x[ib].d;
|
||||
const int ql0 = x[ib].qs[iqs];
|
||||
const int qh0 = get_int_from_uint8(x[ib].qh, 0);
|
||||
const int qh0 = get_int_b2(x[ib].qh, 0);
|
||||
const int ql = ((ql0 >> (4*shift)) & 0x0F);
|
||||
const int qh = ((qh0 >> idq) << 4) & 0x10;
|
||||
const int q = (ql | qh) - 16;
|
||||
|
@ -433,7 +433,7 @@ static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__
|
|||
|
||||
const half2 dm = x[ib].dm;
|
||||
const int ql0 = x[ib].qs[iqs];
|
||||
const int qh0 = get_int_from_uint8_aligned(x[ib].qh, 0);
|
||||
const int qh0 = get_int_b4(x[ib].qh, 0);
|
||||
const int ql = ((ql0 >> (4*shift)) & 0x0F);
|
||||
const int qh = ((qh0 >> idq) << 4) & 0x10;
|
||||
const int q = (ql | qh);
|
||||
|
|
|
@ -70,6 +70,10 @@ struct mma_int_A_I16K8 {
|
|||
}
|
||||
#endif // defined(INT8_MMA_AVAILABLE)
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void load_low(const int * __restrict__ xs0, const int & stride) {
|
||||
((mma_int_A_I16K4 *) x)[0].load(xs0, stride);
|
||||
}
|
||||
};
|
||||
|
||||
struct mma_int_B_J8K4 {
|
||||
|
|
|
@ -59,6 +59,12 @@ void ggml_cuda_op_mul_mat_q(
|
|||
case GGML_TYPE_Q6_K:
|
||||
mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
mul_mat_q_case<GGML_TYPE_IQ4_XS>(ctx, args, stream);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
mul_mat_q_case<GGML_TYPE_IQ4_NL>(ctx, args, stream);
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
break;
|
||||
|
@ -87,6 +93,8 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
|
|||
case GGML_TYPE_Q4_K:
|
||||
case GGML_TYPE_Q5_K:
|
||||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
mmq_supported = true;
|
||||
break;
|
||||
default:
|
||||
|
|
File diff suppressed because it is too large
|
@ -37,47 +37,92 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
|
|||
reinterpret_cast<half&>(y[ib].ds.y) = sum;
|
||||
}
|
||||
|
||||
template <bool need_sum>
|
||||
template <mmq_q8_1_ds_layout ds_layout>
|
||||
static __global__ void quantize_mmq_q8_1(
|
||||
const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
|
||||
|
||||
const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
|
||||
constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
|
||||
constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
|
||||
|
||||
const int64_t ix0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
|
||||
|
||||
if (ix0 >= kx0_padded) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float4 * x4 = (const float4 *) x;
|
||||
|
||||
const int64_t ix1 = kx1*blockIdx.z + blockIdx.y;
|
||||
|
||||
block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
|
||||
|
||||
const int64_t ib0 = blockIdx.z*(gridDim.y*gridDim.x*blockDim.x/(4*QK8_1)); // first block of channel
|
||||
const int64_t ib = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y; // block index in channel
|
||||
const int64_t iqs = ix0 % (4*QK8_1); // quant index in block
|
||||
const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
|
||||
const int64_t ib = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y; // block index in channel
|
||||
const int64_t iqs = ix0 % (4*QK8_1); // quant index in block
|
||||
|
||||
const float xi = ix0 < kx0 ? x[ix1*kx0 + ix0] : 0.0f;
|
||||
float amax = fabsf(xi);
|
||||
// Load 4 floats per thread and calculate max. abs. value between them:
|
||||
const float4 xi = ix0 < kx0 ? x4[(ix1*kx0 + ix0)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
|
||||
float amax = fabsf(xi.x);
|
||||
amax = fmaxf(amax, fabsf(xi.y));
|
||||
amax = fmaxf(amax, fabsf(xi.z));
|
||||
amax = fmaxf(amax, fabsf(xi.w));
|
||||
|
||||
amax = warp_reduce_max(amax);
|
||||
|
||||
float sum;
|
||||
if (need_sum) {
|
||||
sum = warp_reduce_sum(xi);
|
||||
// Exchange max. abs. value between vals_per_scale/4 threads.
|
||||
#pragma unroll
|
||||
for (int mask = vals_per_scale/8; mask > 0; mask >>= 1) {
|
||||
amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
|
||||
}
|
||||
|
||||
const float d = amax / 127;
|
||||
const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
|
||||
float sum;
|
||||
if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) {
|
||||
sum = xi.x + xi.y + xi.z + xi.w;
|
||||
|
||||
y[ib].qs[iqs] = q;
|
||||
// Exchange calculate sum across vals_per_sum/4 threads.
|
||||
#pragma unroll
|
||||
for (int mask = vals_per_sum/8; mask > 0; mask >>= 1) {
|
||||
sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, WARP_SIZE);
|
||||
}
|
||||
}
|
||||
|
||||
const float d_inv = 127.0f / amax;
|
||||
char4 q;
|
||||
q.x = roundf(xi.x*d_inv);
|
||||
q.y = roundf(xi.y*d_inv);
|
||||
q.z = roundf(xi.z*d_inv);
|
||||
q.w = roundf(xi.w*d_inv);
|
||||
|
||||
// Write back 4 int8 values as a single 32 bit value for better memory bandwidth:
|
||||
char4 * yqs4 = (char4 *) y[ib].qs;
|
||||
yqs4[iqs/4] = q;
|
||||
|
||||
if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6) {
|
||||
if (iqs % 16 != 0 || iqs >= 96) {
|
||||
return;
|
||||
}
|
||||
|
||||
y[ib].d2s6[2 + iqs/16] = sum;
|
||||
|
||||
if (iqs % 64 != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float d = 1.0f / d_inv;
|
||||
|
||||
y[ib].d2s6[iqs/64] = d;
|
||||
|
||||
if (iqs % QK8_1 != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (need_sum) {
|
||||
y[ib].ds[iqs/QK8_1] = make_half2(d, sum);
|
||||
if (iqs % 32 != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float d = 1.0f / d_inv;
|
||||
|
||||
if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) {
|
||||
y[ib].ds4[iqs/32] = make_half2(d, sum);
|
||||
} else {
|
||||
((float *) y[ib].ds)[iqs/QK8_1] = d;
|
||||
y[ib].d4[iqs/32] = d;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -101,12 +146,24 @@ void quantize_mmq_q8_1_cuda(
|
|||
|
||||
GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
|
||||
|
||||
const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
|
||||
const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
|
||||
const dim3 num_blocks(block_num_x, kx1, channels);
|
||||
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
|
||||
if (mmq_need_sum(type_x)) {
|
||||
quantize_mmq_q8_1<true><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
|
||||
} else {
|
||||
quantize_mmq_q8_1<false><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
|
||||
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
|
||||
switch (mmq_get_q8_1_ds_layout(type_x)) {
|
||||
case MMQ_Q8_1_DS_LAYOUT_D4:
|
||||
quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
|
||||
<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
|
||||
break;
|
||||
case MMQ_Q8_1_DS_LAYOUT_DS4:
|
||||
quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
|
||||
<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
|
||||
break;
|
||||
case MMQ_Q8_1_DS_LAYOUT_D2S6:
|
||||
quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
|
||||
<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
|
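Underneath the layout shuffling, the per-block math is unchanged: each thread now loads a float4, the 32-value block's maximum absolute value fixes the scale, the int8 quants are written as packed 32-bit words, and some layouts also keep the block sum. A scalar sketch of that quantization step; the struct and function here are illustrative, the block size and rounding follow the kernel above:

```cpp
#include <cmath>
#include <cstdint>

// Scalar version of the q8_1-style quantization done per 32-value block:
// d = amax/127, q[i] = round(x[i]/d); the sum is stored only by the layouts
// that need it (e.g. the ds4 / d2s6 variants in the diff above).
struct q8_block_example {
    float  d;       // scale
    float  sum;     // sum of the original values
    int8_t qs[32];  // quants
};

static q8_block_example quantize_block_q8_1_ref(const float x[32]) {
    q8_block_example b;
    float amax = 0.0f;
    float sum  = 0.0f;
    for (int i = 0; i < 32; ++i) {
        amax = std::fmax(amax, std::fabs(x[i]));
        sum += x[i];
    }
    const float d     = amax / 127.0f;
    const float d_inv = amax == 0.0f ? 0.0f : 127.0f / amax; // kernel multiplies by 1/d
    for (int i = 0; i < 32; ++i) {
        b.qs[i] = (int8_t) std::roundf(x[i] * d_inv);
    }
    b.d   = d;
    b.sum = sum;
    return b;
}
```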
@ -5,7 +5,11 @@
|
|||
|
||||
#include <cstdint>
|
||||
|
||||
#define CUDA_QUANTIZE_BLOCK_SIZE 256
|
||||
#define CUDA_QUANTIZE_BLOCK_SIZE 256
|
||||
#define CUDA_QUANTIZE_BLOCK_SIZE_MMQ 128
|
||||
|
||||
static_assert(MATRIX_ROW_PADDING % CUDA_QUANTIZE_BLOCK_SIZE == 0, "Risk of out-of-bounds access.");
|
||||
static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access.");
|
||||
|
||||
typedef void (*quantize_cuda_t)(
|
||||
const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
|
||||
|
|
|
@ -22,7 +22,8 @@ SOURCE_FATTN_WMMA_CASE = "DECL_FATTN_WMMA_F16_CASE({head_size}, {cols_per_block}
|
|||
|
||||
TYPES_MMQ = [
|
||||
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
|
||||
"GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K"
|
||||
"GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
|
||||
"GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS"
|
||||
]
|
||||
|
||||
SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../mmq.cuh"
|
||||
|
||||
DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
|
|
@ -0,0 +1,5 @@
|
|||
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
||||
|
||||
#include "../mmq.cuh"
|
||||
|
||||
DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
|
|
@ -1,36 +1,8 @@
|
|||
#include "common.cuh"
|
||||
#include <cstdint>
|
||||
|
||||
static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) {
|
||||
const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
|
||||
|
||||
int x32 = 0;
|
||||
x32 |= x16[0] << 0;
|
||||
x32 |= x16[1] << 16;
|
||||
|
||||
return x32;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, const int & i32) {
|
||||
const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
|
||||
|
||||
int x32 = 0;
|
||||
x32 |= x16[0] << 0;
|
||||
x32 |= x16[1] << 16;
|
||||
|
||||
return x32;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_int_from_int8_aligned(const int8_t * x8, const int & i32) {
|
||||
return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * x8, const int & i32) {
|
||||
return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) {
|
||||
const uint16_t * x16 = (const uint16_t *) x;
|
||||
const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment
|
||||
|
||||
int x32 = x16[2*i32 + 0] << 0;
|
||||
x32 |= x16[2*i32 + 1] << 16;
|
||||
|
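The helpers consolidated above all exist to feed `ggml_cuda_dp4a`, which multiplies four packed int8 values against four others and accumulates into a 32-bit integer; `get_int_b2` assembles such a 4-byte group from data that is only guaranteed 2-byte alignment. A plain C++ sketch of both operations, standing in for the device code (editorial illustration, not the CUDA intrinsics themselves):

```cpp
#include <cstdint>
#include <cstring>

// Scalar equivalent of get_int_b2(): read the i32-th group of 4 bytes from a
// buffer that may only be 2-byte aligned, as two 16-bit halves.
static int get_int_b2_ref(const void * x, int i32) {
    const uint16_t * x16 = (const uint16_t *) x;
    uint16_t lo, hi;
    std::memcpy(&lo, &x16[2*i32 + 0], sizeof(lo));
    std::memcpy(&hi, &x16[2*i32 + 1], sizeof(hi));
    return (int)((uint32_t) lo | ((uint32_t) hi << 16));
}

// Scalar equivalent of a dp4a: treat a and b as four signed bytes each,
// multiply pairwise and accumulate into c.
static int dp4a_ref(int a, int b, int c) {
    int8_t a8[4], b8[4];
    std::memcpy(a8, &a, 4);
    std::memcpy(b8, &b, 4);
    for (int i = 0; i < 4; ++i) {
        c += (int) a8[i] * (int) b8[i];
    }
    return c;
}
```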
@ -217,7 +189,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
|
|||
}
|
||||
|
||||
#define VDR_Q2_K_Q8_1_MMVQ 1
|
||||
#define VDR_Q2_K_Q8_1_MMQ 2
|
||||
#define VDR_Q2_K_Q8_1_MMQ 4
|
||||
|
||||
// contiguous v/x values
|
||||
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
|
||||
|
@ -247,32 +219,56 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
|
|||
return dm2f.x*sumf_d - dm2f.y*sumf_m;
|
||||
}
|
||||
|
||||
// contiguous u/y values
|
||||
// contiguous v/x + u/y values
|
||||
template <int ns8>
|
||||
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
|
||||
const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8) {
|
||||
const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8, const half2 * s8) {
|
||||
|
||||
float sumf_d = 0.0f;
|
||||
float sumf_m = 0.0f;
|
||||
float sumf = 0.0f;
|
||||
float sumf_d8 = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) {
|
||||
const float2 dm2f = __half22float2(dm2[i0/(QI8_1/2)]);
|
||||
int sumi_d = 0;
|
||||
int sumi_m = 0;
|
||||
for (int i0 = 0; i0 < QR2_K*VDR_Q2_K_Q8_1_MMQ; i0 += QI8_1) {
|
||||
const float2 dm2f0 = __half22float2(dm2[i0/(QI8_1/2) + 0]);
|
||||
int sumi_d0 = 0;
|
||||
|
||||
const float2 dm2f1 = __half22float2(dm2[i0/(QI8_1/2) + 1]);
|
||||
int sumi_d1 = 0;
|
||||
|
||||
const int vi0 = v[i0/(QI8_1/2)];
|
||||
#pragma unroll
|
||||
for (int i = i0; i < i0 + QI8_1/2; ++i) {
|
||||
const int vi = (vi0 >> (2*(i % (QI8_1/2)))) & 0x03030303;
|
||||
sumi_d = ggml_cuda_dp4a(vi, u[i], sumi_d); // SIMD dot product
|
||||
sumi_m = ggml_cuda_dp4a(0x01010101, u[i], sumi_m);
|
||||
sumi_d0 = ggml_cuda_dp4a(v[i], u[i], sumi_d0);
|
||||
}
|
||||
sumf_d8 += dm2f0.x * sumi_d0;
|
||||
|
||||
sumf_d += dm2f.x * sumi_d;
|
||||
sumf_m += dm2f.y * sumi_m;
|
||||
#pragma unroll
|
||||
for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {
|
||||
sumi_d1 = ggml_cuda_dp4a(v[i], u[i], sumi_d1);
|
||||
}
|
||||
sumf_d8 += dm2f1.x * sumi_d1;
|
||||
|
||||
if (i0/QI8_1 < ns8) {
|
||||
const float2 s8f = __half22float2(s8[i0/QI8_1]);
|
||||
sumf -= dm2f0.y*s8f.x;
|
||||
sumf -= dm2f1.y*s8f.y;
|
||||
} else {
|
||||
int sumi_m0 = 0;
|
||||
#pragma unroll
|
||||
for (int i = i0; i < i0 + QI8_1/2; ++i) {
|
||||
sumi_m0 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m0);
|
||||
}
|
||||
sumf_d8 -= dm2f0.y * sumi_m0;
|
||||
|
||||
int sumi_m1 = 0;
|
||||
#pragma unroll
|
||||
for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {
|
||||
sumi_m1 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m1);
|
||||
}
|
||||
sumf_d8 -= dm2f1.y * sumi_m1;
|
||||
}
|
||||
}
|
||||
|
||||
return d8*(sumf_d - sumf_m);
|
||||
return sumf + d8*sumf_d8;
|
||||
}
|
||||
|
||||
#define VDR_Q3_K_Q8_1_MMVQ 1
|
||||
|
@ -311,7 +307,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
|
|||
return d3 * sumf;
|
||||
}
|
||||
|
||||
// contiguous u/y values
|
||||
// contiguous v/x + u/y values
|
||||
static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
|
||||
const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales,
|
||||
const float & d3, const float & d8) {
|
||||
|
@ -324,8 +320,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
|
|||
|
||||
#pragma unroll
|
||||
for (int i = i0; i < i0 + QI8_1/2; ++i) {
|
||||
const int vi = __vsubss4((v[i/2] >> (4*(i%2))) & 0x0F0F0F0F, 0x04040404);
|
||||
sumi_sc = ggml_cuda_dp4a(vi, u[i], sumi_sc); // SIMD dot product
|
||||
sumi_sc = ggml_cuda_dp4a(v[i], u[i], sumi_sc); // SIMD dot product
|
||||
}
|
||||
|
||||
sumi += sumi_sc * scales[i0 / (QI8_1/2)];
|
||||
|
@ -362,7 +357,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
|
|||
return dm4f.x*sumf_d - dm4f.y*sumf_m;
|
||||
}
|
||||
|
||||
// contiguous u/y values
|
||||
// contiguous v/x + u/y values
|
||||
static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
|
||||
const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
|
||||
const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
|
||||
|
@ -425,7 +420,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
|
|||
return dm5f.x*sumf_d - dm5f.y*sumf_m;
|
||||
}
|
||||
|
||||
// contiguous u/y values
|
||||
// contiguous v/x + u/y values
|
||||
static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
|
||||
const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
|
||||
const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
|
||||
|
@ -479,13 +474,16 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
|
|||
return d*sumf;
|
||||
}
|
||||
|
||||
// contiguous u/y values
|
||||
// contiguous v/x + u/y values
|
||||
static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
|
||||
const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc,
|
||||
const float & d6, const float * __restrict__ d8) {
|
||||
|
||||
float sumf_d = 0.0f;
|
||||
|
||||
const int sc_packed = get_int_b4(sc, 0);
|
||||
const int8_t * sc_reg = (const int8_t *) &sc_packed;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
|
||||
int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale
|
||||
|
@ -499,7 +497,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
|
|||
sumi_d.y = ggml_cuda_dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product
|
||||
}
|
||||
|
||||
sumf_d += d8[i0/4] * (sc[i0/2+0]*sumi_d.x + sc[i0/2+1]*sumi_d.y);
|
||||
sumf_d += d8[i0/4] * (sc_reg[i0/2+0]*sumi_d.x + sc_reg[i0/2+1]*sumi_d.y);
|
||||
}
|
||||
|
||||
return d6 * sumf_d;
|
||||
|
@ -768,6 +766,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
|
|||
}
|
||||
|
||||
#define VDR_IQ2_XXS_Q8_1_MMVQ 2
|
||||
#define VDR_IQ2_XXS_Q8_1_MMQ 2
|
||||
|
||||
static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
|
||||
|
@ -802,6 +801,7 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
|
|||
}
|
||||
|
||||
#define VDR_IQ2_XS_Q8_1_MMVQ 2
|
||||
#define VDR_IQ2_XS_Q8_1_MMQ 2
|
||||
|
||||
static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
|
||||
|
@ -840,6 +840,7 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
|
|||
}
|
||||
|
||||
#define VDR_IQ2_S_Q8_1_MMVQ 2
|
||||
#define VDR_IQ2_S_Q8_1_MMQ 2
|
||||
|
||||
static __device__ __forceinline__ float vec_dot_iq2_s_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
|
||||
|
@ -887,6 +888,7 @@ static __device__ __forceinline__ float vec_dot_iq2_s_q8_1(
|
|||
}
|
||||
|
||||
#define VDR_IQ3_XXS_Q8_1_MMVQ 2
|
||||
#define VDR_IQ3_XXS_Q8_1_MMQ 2
|
||||
|
||||
static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
|
||||
|
@ -921,6 +923,7 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
|
|||
}
|
||||
|
||||
#define VDR_IQ3_S_Q8_1_MMVQ 2
|
||||
#define VDR_IQ3_S_Q8_1_MMQ 2
|
||||
|
||||
// TODO: don't use lookup table for signs
|
||||
static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
|
||||
|
@ -962,6 +965,9 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
|
|||
return d * sumi;
|
||||
}
|
||||
|
||||
#define VDR_IQ1_S_Q8_1_MMVQ 1
|
||||
#define VDR_IQ1_S_Q8_1_MMQ 1
|
||||
|
||||
static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
|
||||
const block_iq1_s * bq1 = (const block_iq1_s *) vbq + kbx;
|
||||
|
@ -992,6 +998,9 @@ static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
|
|||
return d1q * (ds.x*sumi + ds.y*delta);
|
||||
}
|
||||
|
||||
#define VDR_IQ1_M_Q8_1_MMVQ 1
|
||||
#define VDR_IQ1_M_Q8_1_MMQ 1
|
||||
|
||||
static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
|
||||
|
||||
|
@ -1051,6 +1060,7 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
|
|||
}
|
||||
|
||||
#define VDR_IQ4_NL_Q8_1_MMVQ 2
|
||||
#define VDR_IQ4_NL_Q8_1_MMQ 4
|
||||
|
||||
static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
|
||||
|
@ -1074,6 +1084,7 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
|
|||
}
|
||||
|
||||
#define VDR_IQ4_XS_Q8_1_MMVQ 4
|
||||
#define VDR_IQ4_XS_Q8_1_MMQ 4
|
||||
|
||||
static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
|
||||
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
|
||||
|
|
|
@ -609,6 +609,10 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
|
|||
|
||||
#endif // defined(__ARM_NEON) && (!defined(__MSC_VER)
|
||||
|
||||
#ifdef __ARM_FEATURE_SVE
|
||||
#include <arm_sve.h>
|
||||
#endif // __ARM_FEATURE_SVE
|
||||
|
||||
// precomputed f32 table for f16 (256 KB)
|
||||
// defined in ggml.c, initialized in ggml_init()
|
||||
extern float ggml_table_f32_f16[1 << 16];
|
||||
|
|
|
@ -193,16 +193,16 @@ enum ggml_metal_kernel_type {
|
|||
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
||||
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
|
||||
GGML_METAL_KERNEL_TYPE_CONCAT,
|
||||
GGML_METAL_KERNEL_TYPE_SQR,
|
||||
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
||||
|
@ -651,14 +651,14 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|||
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
||||
|
@ -810,8 +810,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|||
switch (op->src[0]->type) {
|
||||
case GGML_TYPE_F32:
|
||||
switch (op->type) {
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_Q8_0:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
|
@ -824,8 +824,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|||
}
|
||||
case GGML_TYPE_F16:
|
||||
switch (op->type) {
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_F32:
|
||||
case GGML_TYPE_F16:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
@ -837,7 +837,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|||
case GGML_OP_DIAG_MASK_INF:
|
||||
case GGML_OP_GET_ROWS:
|
||||
{
|
||||
return op->src[0]->type != GGML_TYPE_BF16 && op->ne[3] == 1;
|
||||
return op->ne[3] == 1;
|
||||
}
|
||||
default:
|
||||
return false;
|
||||
|
@ -1580,8 +1580,8 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
// some Metal matrix data types require aligned pointers
|
||||
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
|
||||
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
|
||||
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
|
||||
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
|
||||
default: break;
|
||||
}
|
||||
|
||||
|
@ -2775,8 +2775,8 @@ static enum ggml_status ggml_metal_graph_compute(
GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);

switch (dstt) {
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;

@ -2789,8 +2789,8 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_TYPE_F16:
{
switch (dstt) {
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
default: GGML_ASSERT(false && "not implemented");
};
} break;

@ -1219,9 +1219,10 @@ kernel void kernel_mul_mv_q8_0_f32(
kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
}

#define N_F32_F32 4
#define N_MV_T_T 4

void kernel_mul_mv_f32_f32_impl(
template<typename T0, typename T04, typename T1, typename T14>
void kernel_mul_mv_impl(
device const char * src0,
device const char * src1,
device float * dst,
@ -1239,13 +1240,12 @@ void kernel_mul_mv_f32_f32_impl(
uint64_t nb12,
int64_t ne0,
int64_t ne1,
uint r2,
uint r3,
uint3 tgpig,
uint tiisg) {

uint r2,
uint r3,
uint3 tgpig,
uint tiisg) {
const int64_t r0 = tgpig.x;
const int64_t rb = tgpig.y*N_F32_F32;
const int64_t rb = tgpig.y*N_MV_T_T;
const int64_t im = tgpig.z;

const uint i12 = im%ne12;
@ -1253,20 +1253,20 @@ void kernel_mul_mv_f32_f32_impl(

const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;

device const float * x = (device const float *) (src0 + offset0);
device const T0 * x = (device const T0 *) (src0 + offset0);

if (ne00 < 128) {
for (int row = 0; row < N_F32_F32; ++row) {
for (int row = 0; row < N_MV_T_T; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
}

device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);

float sumf = 0;
for (int i = tiisg; i < ne00; i += 32) {
sumf += (float) x[i] * (float) y[i];
sumf += (T0) x[i] * (T1) y[i];
}

float all_sum = simd_sum(sumf);
@ -1275,32 +1275,32 @@ void kernel_mul_mv_f32_f32_impl(
}
}
} else {
device const float4 * x4 = (device const float4 *)x;
for (int row = 0; row < N_F32_F32; ++row) {
device const T04 * x4 = (device const T04 *) x;
for (int row = 0; row < N_MV_T_T; ++row) {
int r1 = rb + row;
if (r1 >= ne11) {
break;
}

device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
device const float4 * y4 = (device const float4 *) y;
device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);
device const T14 * y4 = (device const T14 *) y;

float sumf = 0;
for (int i = tiisg; i < ne00/4; i += 32) {
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
}

float all_sum = simd_sum(sumf);
if (tiisg == 0) {
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
}
}

[[host_name("kernel_mul_mv_f32_f32")]]
kernel void kernel_mul_mv_f32_f32(
template<typename T0, typename T04, typename T1, typename T14>
kernel void kernel_mul_mv(
device const char * src0,
device const char * src1,
device float * dst,
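
This hunk replaces the hand-written f32_f32 kernel with a single kernel_mul_mv templated over the element types. The same collapse, sketched in plain C++ with explicit instantiations standing in for Metal's [[host_name(...)]] attribute (names here are illustrative only):

```cpp
#include <cstdint>

// One generic row dot-product templated over the two source element types,
// in the spirit of kernel_mul_mv_impl<T0, T04, T1, T14> above
// (the vectorized T04/T14 path is left out to keep the sketch short).
template <typename T0, typename T1>
void mul_mv_row(const T0 * x, const T1 * y, float * dst, int64_t n) {
    float sum = 0.0f;
    for (int64_t i = 0; i < n; ++i) {
        sum += (float) x[i] * (float) y[i];   // accumulate in float, as the kernels do
    }
    *dst = sum;
}

// Explicit instantiations play the role of the
// template [[host_name("kernel_mul_mv_f32_f32")]] ... declarations:
// one body, several named entry points.
template void mul_mv_row<float, float>(const float *, const float *, float *, int64_t);
template void mul_mv_row<double, float>(const double *, const float *, float *, int64_t); // stand-in for the half variants
```
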
@ -1322,90 +1322,38 @@ kernel void kernel_mul_mv_f32_f32(
|
|||
constant uint & r3,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||
kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
||||
kernel_mul_mv_impl<T0, T04, T1, T14>(
|
||||
src0,
|
||||
src1,
|
||||
dst,
|
||||
ne00,
|
||||
ne01,
|
||||
ne02,
|
||||
nb00,
|
||||
nb01,
|
||||
nb02,
|
||||
ne10,
|
||||
ne11,
|
||||
ne12,
|
||||
nb10,
|
||||
nb11,
|
||||
nb12,
|
||||
ne0,
|
||||
ne1,
|
||||
r2,
|
||||
r3,
|
||||
tgpig,
|
||||
tiisg);
|
||||
}
|
||||
|
||||
#define N_F16_F16 4
|
||||
typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;
|
||||
|
||||
kernel void kernel_mul_mv_f16_f16(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne11,
|
||||
constant int64_t & ne12,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||
template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv<float, float4, float, float4>;
|
||||
template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv<half, half4, float, float4>;
|
||||
template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv<half, half4, half, half4>;
|
||||
|
||||
const int64_t r0 = tgpig.x;
|
||||
const int64_t rb = tgpig.y*N_F16_F16;
|
||||
const int64_t im = tgpig.z;
|
||||
|
||||
const uint i12 = im%ne12;
|
||||
const uint i13 = im/ne12;
|
||||
|
||||
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
||||
|
||||
device const half * x = (device const half *) (src0 + offset0);
|
||||
|
||||
if (ne00 < 128) {
|
||||
for (int row = 0; row < N_F16_F16; ++row) {
|
||||
int r1 = rb + row;
|
||||
if (r1 >= ne11) {
|
||||
break;
|
||||
}
|
||||
|
||||
device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
|
||||
|
||||
float sumf = 0;
|
||||
for (int i = tiisg; i < ne00; i += 32) {
|
||||
sumf += (half) x[i] * (half) y[i];
|
||||
}
|
||||
|
||||
float all_sum = simd_sum(sumf);
|
||||
if (tiisg == 0) {
|
||||
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
device const half4 * x4 = (device const half4 *)x;
|
||||
for (int row = 0; row < N_F16_F16; ++row) {
|
||||
int r1 = rb + row;
|
||||
if (r1 >= ne11) {
|
||||
break;
|
||||
}
|
||||
|
||||
device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
|
||||
device const half4 * y4 = (device const half4 *) y;
|
||||
|
||||
float sumf = 0;
|
||||
for (int i = tiisg; i < ne00/4; i += 32) {
|
||||
for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
|
||||
}
|
||||
|
||||
float all_sum = simd_sum(sumf);
|
||||
if (tiisg == 0) {
|
||||
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
|
||||
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void kernel_mul_mv_f16_f32_1row_impl(
|
||||
template<typename T, typename T4>
|
||||
kernel void kernel_mul_mv_1row(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
|
@ -1437,7 +1385,7 @@ void kernel_mul_mv_f16_f32_1row_impl(
|
|||
|
||||
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
||||
|
||||
device const half * x = (device const half *) (src0 + offset0);
|
||||
device const T * x = (device const T *) (src0 + offset0);
|
||||
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
||||
|
||||
float sumf = 0;
|
||||
|
@ -1450,153 +1398,29 @@ void kernel_mul_mv_f16_f32_1row_impl(
|
|||
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||
}
|
||||
} else {
|
||||
device const half4 * x4 = (device const half4 *) x;
|
||||
device const T4 * x4 = (device const T4 *) x;
|
||||
device const float4 * y4 = (device const float4 *) y;
|
||||
|
||||
for (int i = tiisg; i < ne00/4; i += 32) {
|
||||
for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
|
||||
for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
|
||||
}
|
||||
|
||||
float all_sum = simd_sum(sumf);
|
||||
|
||||
if (tiisg == 0) {
|
||||
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
|
||||
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
|
||||
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[[host_name("kernel_mul_mv_f16_f32_1row")]]
|
||||
kernel void kernel_mul_mv_f16_f32_1row(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne11,
|
||||
constant int64_t & ne12,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||
kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
||||
}
|
||||
typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
|
||||
|
||||
#define N_F16_F32 4
|
||||
|
||||
void kernel_mul_mv_f16_f32_impl(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
int64_t ne00,
|
||||
int64_t ne01,
|
||||
int64_t ne02,
|
||||
uint64_t nb00,
|
||||
uint64_t nb01,
|
||||
uint64_t nb02,
|
||||
int64_t ne10,
|
||||
int64_t ne11,
|
||||
int64_t ne12,
|
||||
uint64_t nb10,
|
||||
uint64_t nb11,
|
||||
uint64_t nb12,
|
||||
int64_t ne0,
|
||||
int64_t ne1,
|
||||
uint r2,
|
||||
uint r3,
|
||||
uint3 tgpig,
|
||||
uint tiisg) {
|
||||
|
||||
const int64_t r0 = tgpig.x;
|
||||
const int64_t rb = tgpig.y*N_F16_F32;
|
||||
const int64_t im = tgpig.z;
|
||||
|
||||
const uint i12 = im%ne12;
|
||||
const uint i13 = im/ne12;
|
||||
|
||||
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
||||
|
||||
device const half * x = (device const half *) (src0 + offset0);
|
||||
|
||||
if (ne00 < 128) {
|
||||
for (int row = 0; row < N_F16_F32; ++row) {
|
||||
int r1 = rb + row;
|
||||
if (r1 >= ne11) {
|
||||
break;
|
||||
}
|
||||
|
||||
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
||||
|
||||
float sumf = 0;
|
||||
for (int i = tiisg; i < ne00; i += 32) {
|
||||
sumf += (float) x[i] * (float) y[i];
|
||||
}
|
||||
|
||||
float all_sum = simd_sum(sumf);
|
||||
if (tiisg == 0) {
|
||||
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
device const half4 * x4 = (device const half4 *)x;
|
||||
for (int row = 0; row < N_F16_F32; ++row) {
|
||||
int r1 = rb + row;
|
||||
if (r1 >= ne11) {
|
||||
break;
|
||||
}
|
||||
|
||||
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
||||
device const float4 * y4 = (device const float4 *) y;
|
||||
|
||||
float sumf = 0;
|
||||
for (int i = tiisg; i < ne00/4; i += 32) {
|
||||
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
|
||||
}
|
||||
|
||||
float all_sum = simd_sum(sumf);
|
||||
if (tiisg == 0) {
|
||||
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
|
||||
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[[host_name("kernel_mul_mv_f16_f32")]]
|
||||
kernel void kernel_mul_mv_f16_f32(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne11,
|
||||
constant int64_t & ne12,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||
kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
|
||||
}
|
||||
template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<half, half4>;
|
||||
|
||||
// Assumes row size (ne00) is a multiple of 4
|
||||
kernel void kernel_mul_mv_f16_f32_l4(
|
||||
template<typename T, typename T4>
|
||||
kernel void kernel_mul_mv_l4(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
|
@ -1628,14 +1452,14 @@ kernel void kernel_mul_mv_f16_f32_l4(
|
|||
|
||||
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
|
||||
|
||||
device const half4 * x4 = (device const half4 *) (src0 + offset0);
|
||||
device const T4 * x4 = (device const T4 *) (src0 + offset0);
|
||||
|
||||
for (int r1 = 0; r1 < nrows; ++r1) {
|
||||
device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
|
||||
|
||||
float sumf = 0;
|
||||
for (int i = tiisg; i < ne00/4; i += 32) {
|
||||
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
|
||||
for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
|
||||
}
|
||||
|
||||
float all_sum = simd_sum(sumf);
|
||||
|
@ -1645,6 +1469,10 @@ kernel void kernel_mul_mv_f16_f32_l4(
|
|||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;
|
||||
|
||||
template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>;
|
||||
|
||||
static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
||||
const float y = (i0 / 2 - low) / max(0.001f, high - low);
|
||||
return 1.0f - min(1.0f, max(0.0f, y));
|
||||
|
@ -2765,9 +2593,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|||
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
|
||||
//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
|
||||
|
||||
kernel void kernel_cpy_f16_f16(
|
||||
device const half * src0,
|
||||
device half * dst,
|
||||
template<typename T0, typename T1>
|
||||
kernel void kernel_cpy(
|
||||
device const void * src0,
|
||||
device void * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
|
@ -2798,138 +2627,20 @@ kernel void kernel_cpy_f16_f16(
|
|||
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
||||
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
||||
|
||||
device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
device T1 * dst_data = (device T1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
||||
device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
||||
dst_data[i00] = src[0];
|
||||
device const T0 * src = (device T0 *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
||||
dst_data[i00] = (T1) src[0];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_cpy_f16_f32(
|
||||
device const half * src0,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne03,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant uint64_t & nb03,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int64_t & ne3,
|
||||
constant uint64_t & nb0,
|
||||
constant uint64_t & nb1,
|
||||
constant uint64_t & nb2,
|
||||
constant uint64_t & nb3,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t i03 = tgpig[2];
|
||||
const int64_t i02 = tgpig[1];
|
||||
const int64_t i01 = tgpig[0];
|
||||
typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
|
||||
|
||||
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
|
||||
const int64_t i3 = n / (ne2*ne1*ne0);
|
||||
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
||||
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
||||
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
||||
|
||||
device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
||||
device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
||||
dst_data[i00] = src[0];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_cpy_f32_f16(
|
||||
device const float * src0,
|
||||
device half * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne03,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant uint64_t & nb03,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int64_t & ne3,
|
||||
constant uint64_t & nb0,
|
||||
constant uint64_t & nb1,
|
||||
constant uint64_t & nb2,
|
||||
constant uint64_t & nb3,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t i03 = tgpig[2];
|
||||
const int64_t i02 = tgpig[1];
|
||||
const int64_t i01 = tgpig[0];
|
||||
|
||||
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
|
||||
const int64_t i3 = n / (ne2*ne1*ne0);
|
||||
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
||||
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
||||
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
||||
|
||||
device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
||||
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
||||
|
||||
dst_data[i00] = src[0];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_cpy_f32_f32(
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne03,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant uint64_t & nb03,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & ne2,
|
||||
constant int64_t & ne3,
|
||||
constant uint64_t & nb0,
|
||||
constant uint64_t & nb1,
|
||||
constant uint64_t & nb2,
|
||||
constant uint64_t & nb3,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t i03 = tgpig[2];
|
||||
const int64_t i02 = tgpig[1];
|
||||
const int64_t i01 = tgpig[0];
|
||||
|
||||
const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||
|
||||
const int64_t i3 = n / (ne2*ne1*ne0);
|
||||
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
|
||||
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
|
||||
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
|
||||
|
||||
device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
||||
|
||||
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
|
||||
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
|
||||
|
||||
dst_data[i00] = src[0];
|
||||
}
|
||||
}
|
||||
template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
|
||||
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
|
||||
template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
|
||||
template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
|
||||
|
||||
kernel void kernel_cpy_f32_q8_0(
|
||||
device const float * src0,
|
||||
|
@ -5730,9 +5441,9 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
|
|||
}
|
||||
|
||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
||||
kernel void kernel_get_rows(
|
||||
kernel void kernel_get_rows_q(
|
||||
device const void * src0,
|
||||
device const char * src1,
|
||||
device const void * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant uint64_t & nb01,
|
||||
|
@ -5745,27 +5456,24 @@ kernel void kernel_get_rows(
|
|||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint3 tptg [[threads_per_threadgroup]]) {
|
||||
//const int64_t i = tgpig;
|
||||
//const int64_t r = ((device int32_t *) src1)[i];
|
||||
|
||||
const int64_t i10 = tgpig.x;
|
||||
const int64_t i11 = tgpig.y;
|
||||
|
||||
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
||||
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
|
||||
|
||||
const int64_t i02 = i11;
|
||||
|
||||
for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
|
||||
float4x4 temp;
|
||||
dequantize_func(
|
||||
((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
|
||||
dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
|
||||
*(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_get_rows_f32(
|
||||
template<typename T>
|
||||
kernel void kernel_get_rows_f(
|
||||
device const void * src0,
|
||||
device const char * src1,
|
||||
device const void * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant uint64_t & nb01,
|
||||
|
@ -5781,47 +5489,19 @@ kernel void kernel_get_rows_f32(
|
|||
const int64_t i10 = tgpig.x;
|
||||
const int64_t i11 = tgpig.y;
|
||||
|
||||
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
||||
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
|
||||
|
||||
const int64_t i02 = i11;
|
||||
|
||||
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
||||
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
||||
((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_get_rows_f16(
|
||||
device const void * src0,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne10,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb1,
|
||||
constant uint64_t & nb2,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint3 tptg [[threads_per_threadgroup]]) {
|
||||
const int64_t i10 = tgpig.x;
|
||||
const int64_t i11 = tgpig.y;
|
||||
|
||||
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
||||
|
||||
const int64_t i02 = i11;
|
||||
|
||||
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
||||
((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
||||
((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
||||
(( device float *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
||||
((const device T *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_get_rows_i32(
|
||||
device const void * src0,
|
||||
device const char * src1,
|
||||
device const void * src1,
|
||||
device int32_t * dst,
|
||||
constant int64_t & ne00,
|
||||
constant uint64_t & nb01,
|
||||
|
@ -5837,13 +5517,13 @@ kernel void kernel_get_rows_i32(
|
|||
const int64_t i10 = tgpig.x;
|
||||
const int64_t i11 = tgpig.y;
|
||||
|
||||
const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
|
||||
const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
|
||||
|
||||
const int64_t i02 = i11;
|
||||
|
||||
for (int ind = tiitg; ind < ne00; ind += tptg.x) {
|
||||
((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
||||
((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
|
||||
(( device int32_t *) (( device char *) dst + i11*nb2 + i10*nb1))[ind] =
|
||||
((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -5860,28 +5540,28 @@ kernel void kernel_get_rows_i32(
|
|||
#define SG_MAT_ROW 8
|
||||
|
||||
// each block_q contains 16*nl weights
|
||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
||||
void kernel_mul_mm_impl(device const uchar * src0,
|
||||
device const uchar * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne02,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne12,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
|
||||
kernel void kernel_mul_mm(device const uchar * src0,
|
||||
device const uchar * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne02,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne12,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
threadgroup half * sa = (threadgroup half *)(shared_memory);
|
||||
threadgroup T * sa = (threadgroup T *)(shared_memory);
|
||||
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
|
||||
|
||||
const uint r0 = tgpig.y;
|
||||
|
@ -5896,7 +5576,7 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|||
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
||||
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
||||
|
||||
simdgroup_half8x8 ma[4];
|
||||
simdgroup_T8x8 ma[4];
|
||||
simdgroup_float8x8 mb[2];
|
||||
simdgroup_float8x8 c_res[8];
|
||||
for (int i = 0; i < 8; i++){
|
||||
|
@ -5919,7 +5599,7 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|||
|
||||
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
|
||||
// load data and store to threadgroup memory
|
||||
half4x4 temp_a;
|
||||
T4x4 temp_a;
|
||||
dequantize_func(x, il, temp_a);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
|
@ -5939,7 +5619,7 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// load matrices from threadgroup memory and conduct outer products
|
||||
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
|
||||
threadgroup T * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
|
||||
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
|
||||
|
||||
#pragma unroll(4)
|
||||
|
@ -6115,48 +5795,6 @@ void kernel_mul_mm_id_impl(
|
|||
}
|
||||
}
|
||||
|
||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
||||
kernel void kernel_mul_mm(device const uchar * src0,
|
||||
device const uchar * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne02,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne12,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
kernel_mul_mm_impl<block_q, nl, dequantize_func>(
|
||||
src0,
|
||||
src1,
|
||||
dst,
|
||||
ne00,
|
||||
ne02,
|
||||
nb01,
|
||||
nb02,
|
||||
ne12,
|
||||
nb10,
|
||||
nb11,
|
||||
nb12,
|
||||
ne0,
|
||||
ne1,
|
||||
r2,
|
||||
r3,
|
||||
shared_memory,
|
||||
tgpig,
|
||||
tiitg,
|
||||
sgitg);
|
||||
}
|
||||
|
||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
||||
kernel void kernel_mul_mm_id(
|
||||
device const uchar * src0s,
|
||||
|
@ -6237,69 +5875,60 @@ kernel void kernel_mul_mm_id(
|
|||
// get rows
|
||||
//
|
||||
|
||||
typedef void (get_rows_t)(
|
||||
device const void * src0,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne10,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb1,
|
||||
constant uint64_t & nb2,
|
||||
uint3, uint, uint3);
|
||||
typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
|
||||
|
||||
//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
|
||||
//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
||||
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
|
||||
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
|
||||
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
|
||||
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
|
||||
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
|
||||
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
|
||||
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
|
||||
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
|
||||
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
||||
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
||||
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
||||
template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
||||
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
||||
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
||||
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
||||
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
|
||||
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
||||
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float>;
|
||||
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half>;
|
||||
|
||||
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
|
||||
|
||||
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
|
||||
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
|
||||
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
|
||||
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
|
||||
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
|
||||
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
|
||||
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_K, QK_NL, dequantize_q5_K>;
|
||||
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q6_K, QK_NL, dequantize_q6_K>;
|
||||
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
||||
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
||||
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
||||
template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
||||
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
||||
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
||||
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
||||
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
|
||||
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
||||
|
||||
//
|
||||
// matrix-matrix multiplication
|
||||
//
|
||||
|
||||
typedef decltype(kernel_mul_mm<float4x4, 1, dequantize_f32>) mat_mm_t;
|
||||
typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mat_mm_t;
|
||||
|
||||
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
|
||||
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
||||
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
|
||||
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
|
||||
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
|
||||
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
|
||||
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
|
||||
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
|
||||
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
|
||||
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
||||
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
||||
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
||||
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
||||
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
||||
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
||||
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
||||
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
||||
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
|
||||
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
||||
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
|
||||
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
|
||||
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
|
||||
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
|
||||
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
|
||||
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
|
||||
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
|
||||
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
|
||||
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
|
||||
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
|
||||
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
||||
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
||||
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
||||
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
|
||||
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
|
||||
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
|
||||
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
|
||||
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
|
||||
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
||||
|
||||
//
|
||||
// indirect matrix-matrix multiplication
|
||||
|
@ -6436,7 +6065,7 @@ void mmv_fn(
|
|||
impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
|
||||
}
|
||||
|
||||
typedef decltype(mmv_fn<kernel_mul_mv_f32_f32_impl>) mul_mv_impl_fn_t;
|
||||
typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4>>) mul_mv_impl_fn_t;
|
||||
|
||||
template<mul_mv_impl_fn_t impl_fn>
|
||||
kernel void kernel_mul_mv_id(
|
||||
|
@ -6514,20 +6143,20 @@ kernel void kernel_mul_mv_id(
|
|||
sgitg);
|
||||
}
|
||||
|
||||
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>) kernel_mul_mv_id_t;
|
||||
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>) kernel_mul_mv_id_t;
|
||||
|
||||
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f16_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
|
||||
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>;
|
||||
template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>;
|
||||
|
|
|
@ -658,7 +658,7 @@ static inline __m128i packNibbles( __m256i bytes ) {
#endif //__loongarch_asx

// reference implementation for deterministic creation of model files
void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) {
void quantize_row_q4_0_ref(const float * restrict x, block_q4_0 * restrict y, int64_t k) {
static const int qk = QK4_0;

assert(k % qk == 0);

@ -696,11 +696,11 @@ void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict
}

void quantize_row_q4_0(const float * restrict x, void * restrict y, int64_t k) {
quantize_row_q4_0_reference(x, y, k);
quantize_row_q4_0_ref(x, y, k);
}


void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int64_t k) {
void quantize_row_q4_1_ref(const float * restrict x, block_q4_1 * restrict y, int64_t k) {
const int qk = QK4_1;

assert(k % qk == 0);

@ -738,10 +738,10 @@ void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict
}

void quantize_row_q4_1(const float * restrict x, void * restrict y, int64_t k) {
quantize_row_q4_1_reference(x, y, k);
quantize_row_q4_1_ref(x, y, k);
}

void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int64_t k) {
void quantize_row_q5_0_ref(const float * restrict x, block_q5_0 * restrict y, int64_t k) {
static const int qk = QK5_0;

assert(k % qk == 0);

@ -786,10 +786,10 @@ void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict
}

void quantize_row_q5_0(const float * restrict x, void * restrict y, int64_t k) {
quantize_row_q5_0_reference(x, y, k);
quantize_row_q5_0_ref(x, y, k);
}

void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int64_t k) {
void quantize_row_q5_1_ref(const float * restrict x, block_q5_1 * restrict y, int64_t k) {
const int qk = QK5_1;

assert(k % qk == 0);

@ -834,11 +834,11 @@ void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict
}

void quantize_row_q5_1(const float * restrict x, void * restrict y, int64_t k) {
quantize_row_q5_1_reference(x, y, k);
quantize_row_q5_1_ref(x, y, k);
}

// reference implementation for deterministic creation of model files
void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int64_t k) {
void quantize_row_q8_0_ref(const float * restrict x, block_q8_0 * restrict y, int64_t k) {
assert(k % QK8_0 == 0);
const int nb = k / QK8_0;

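
For context on what the renamed quantize_row_q8_0_ref computes: each block of 32 floats is stored as one scale d = amax/127 plus 32 int8 values. A freestanding sketch of that scheme (simplified struct; the real block_q8_0 stores the scale as fp16):

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

constexpr int QK8_0_SKETCH = 32;    // block size, matches QK8_0

struct block_q8_0_sketch {
    float  d;                       // scale (ggml stores this as ggml_half)
    int8_t qs[QK8_0_SKETCH];        // quantized weights
};

// Reference-style quantization of k floats into k/32 blocks:
// d = amax/127, qs[j] = round(x[j]/d).
void quantize_row_q8_0_sketch(const float * x, block_q8_0_sketch * y, int64_t k) {
    assert(k % QK8_0_SKETCH == 0);
    const int64_t nb = k / QK8_0_SKETCH;

    for (int64_t i = 0; i < nb; ++i) {
        float amax = 0.0f;
        for (int j = 0; j < QK8_0_SKETCH; ++j) {
            amax = std::fmax(amax, std::fabs(x[i*QK8_0_SKETCH + j]));
        }

        const float d  = amax / 127.0f;
        const float id = d != 0.0f ? 1.0f/d : 0.0f;

        y[i].d = d;
        for (int j = 0; j < QK8_0_SKETCH; ++j) {
            y[i].qs[j] = (int8_t) std::round(x[i*QK8_0_SKETCH + j] * id);
        }
    }
}
```

Decoding is the inverse: each weight is recovered as d * qs[j], which is why a single per-block scale is enough.
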
@ -1144,12 +1144,12 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
#else
GGML_UNUSED(nb);
// scalar
quantize_row_q8_0_reference(x, y, k);
quantize_row_q8_0_ref(x, y, k);
#endif
}

// reference implementation for deterministic creation of model files
void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int64_t k) {
void quantize_row_q8_1_ref(const float * restrict x, block_q8_1 * restrict y, int64_t k) {
assert(QK8_1 == 32);
assert(k % QK8_1 == 0);
const int nb = k / QK8_1;

@ -1508,7 +1508,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
#else
GGML_UNUSED(nb);
// scalar
quantize_row_q8_1_reference(x, y, k);
quantize_row_q8_1_ref(x, y, k);
#endif
}

@ -1899,7 +1899,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t *

//========================- 2-bit (de)-quantization

void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int64_t k) {
void quantize_row_q2_K_ref(const float * restrict x, block_q2_K * restrict y, int64_t k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;

@ -2002,7 +2002,7 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int6
}

void quantize_row_q2_K(const float * restrict x, void * restrict vy, int64_t k) {
quantize_row_q2_K_reference(x, vy, k);
quantize_row_q2_K_ref(x, vy, k);
}

static float make_qkx3_quants(int n, int nmax, const float * restrict x, const float * restrict weights,

@ -2226,7 +2226,7 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri
size_t quantize_q2_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
size_t row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row);
if (!quant_weights) {
quantize_row_q2_K_reference(src, dst, (int64_t)nrow*n_per_row);
quantize_row_q2_K_ref(src, dst, (int64_t)nrow*n_per_row);
}
else {
char * qrow = (char *)dst;

@ -2241,7 +2241,7 @@ size_t quantize_q2_K(const float * restrict src, void * restrict dst, int64_t nr

//========================= 3-bit (de)-quantization

void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int64_t k) {
void quantize_row_q3_K_ref(const float * restrict x, block_q3_K * restrict y, int64_t k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;

@ -2368,7 +2368,7 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int6
}

void quantize_row_q3_K(const float * restrict x, void * restrict vy, int64_t k) {
quantize_row_q3_K_reference(x, vy, k);
quantize_row_q3_K_ref(x, vy, k);
}

static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restrict y, int64_t n_per_row, const float * restrict quant_weights) {

@ -2458,7 +2458,7 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri
size_t quantize_q3_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
size_t row_size = ggml_row_size(GGML_TYPE_Q3_K, n_per_row);
if (!quant_weights) {
quantize_row_q3_K_reference(src, dst, (int64_t)nrow*n_per_row);
quantize_row_q3_K_ref(src, dst, (int64_t)nrow*n_per_row);
}
else {
char * qrow = (char *)dst;

@ -2473,7 +2473,7 @@ size_t quantize_q3_K(const float * restrict src, void * restrict dst, int64_t nr

// ====================== 4-bit (de)-quantization

void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int64_t k) {
void quantize_row_q4_K_ref(const float * restrict x, block_q4_K * restrict y, int64_t k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;

@ -2572,7 +2572,7 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int6
|
|||
void quantize_row_q4_K(const float * restrict x, void * restrict vy, int64_t k) {
|
||||
assert(k % QK_K == 0);
|
||||
block_q4_K * restrict y = vy;
|
||||
quantize_row_q4_K_reference(x, y, k);
|
||||
quantize_row_q4_K_ref(x, y, k);
|
||||
}
|
||||
|
||||
static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restrict y, int64_t n_per_row, const float * quant_weights) {
|
||||
|
@ -2651,7 +2651,7 @@ static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restri
|
|||
size_t quantize_q4_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
|
||||
size_t row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row);
|
||||
if (!quant_weights) {
|
||||
quantize_row_q4_K_reference(src, dst, (int64_t)nrow*n_per_row);
|
||||
quantize_row_q4_K_ref(src, dst, (int64_t)nrow*n_per_row);
|
||||
}
|
||||
else {
|
||||
char * qrow = (char *)dst;
|
||||
|
@ -2666,7 +2666,7 @@ size_t quantize_q4_K(const float * restrict src, void * restrict dst, int64_t nr
|
|||
|
||||
// ====================== 5-bit (de)-quantization
|
||||
|
||||
void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int64_t k) {
|
||||
void quantize_row_q5_K_ref(const float * restrict x, block_q5_K * restrict y, int64_t k) {
|
||||
assert(k % QK_K == 0);
|
||||
const int64_t nb = k / QK_K;
|
||||
|
||||
|
@ -2783,7 +2783,7 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int6
|
|||
void quantize_row_q5_K(const float * restrict x, void * restrict vy, int64_t k) {
|
||||
assert(k % QK_K == 0);
|
||||
block_q5_K * restrict y = vy;
|
||||
quantize_row_q5_K_reference(x, y, k);
|
||||
quantize_row_q5_K_ref(x, y, k);
|
||||
}
|
||||
|
||||
static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restrict y, int64_t n_per_row, const float * quant_weights) {
|
||||
|
@ -2882,7 +2882,7 @@ static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restri
|
|||
size_t quantize_q5_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
|
||||
size_t row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row);
|
||||
if (!quant_weights) {
|
||||
quantize_row_q5_K_reference(src, dst, (int64_t)nrow*n_per_row);
|
||||
quantize_row_q5_K_ref(src, dst, (int64_t)nrow*n_per_row);
|
||||
}
|
||||
else {
|
||||
char * qrow = (char *)dst;
|
||||
|
@ -2897,7 +2897,7 @@ size_t quantize_q5_K(const float * restrict src, void * restrict dst, int64_t nr
|
|||
|
||||
// ====================== 6-bit (de)-quantization
|
||||
|
||||
void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int64_t k) {
|
||||
void quantize_row_q6_K_ref(const float * restrict x, block_q6_K * restrict y, int64_t k) {
|
||||
assert(k % QK_K == 0);
|
||||
const int64_t nb = k / QK_K;
|
||||
|
||||
|
@ -3001,7 +3001,7 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int6
|
|||
void quantize_row_q6_K(const float * restrict x, void * restrict vy, int64_t k) {
|
||||
assert(k % QK_K == 0);
|
||||
block_q6_K * restrict y = vy;
|
||||
quantize_row_q6_K_reference(x, y, k);
|
||||
quantize_row_q6_K_ref(x, y, k);
|
||||
}
|
||||
|
||||
static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restrict y, int64_t n_per_row, const float * quant_weights) {
|
||||
|
@ -3091,7 +3091,7 @@ static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restri
|
|||
size_t quantize_q6_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
|
||||
size_t row_size = ggml_row_size(GGML_TYPE_Q6_K, n_per_row);
|
||||
if (!quant_weights) {
|
||||
quantize_row_q6_K_reference(src, dst, (int64_t)nrow*n_per_row);
|
||||
quantize_row_q6_K_ref(src, dst, (int64_t)nrow*n_per_row);
|
||||
}
|
||||
else {
|
||||
char * qrow = (char *)dst;
|
||||
|
@@ -3108,7 +3108,7 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri
    static_assert(QK4_0 == 32, "QK4_0 must be 32");

    if (!quant_weights) {
        quantize_row_q4_0_reference(x, y, n_per_row);
        quantize_row_q4_0_ref(x, y, n_per_row);
        return;
    }

@@ -3134,7 +3134,7 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri

size_t quantize_q4_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
    if (!quant_weights) {
        quantize_row_q4_0_reference(src, dst, (int64_t)nrow*n_per_row);
        quantize_row_q4_0_ref(src, dst, (int64_t)nrow*n_per_row);
        return nrow * ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
    }
    size_t row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row);

@@ -3151,7 +3151,7 @@ static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restri
    static_assert(QK4_1 == 32, "QK4_1 must be 32");

    if (!quant_weights) {
        quantize_row_q4_1_reference(x, y, n_per_row);
        quantize_row_q4_1_ref(x, y, n_per_row);
        return;
    }

@@ -3179,7 +3179,7 @@ static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restri

size_t quantize_q4_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
    if (!quant_weights) {
        quantize_row_q4_1_reference(src, dst, (int64_t)nrow*n_per_row);
        quantize_row_q4_1_ref(src, dst, (int64_t)nrow*n_per_row);
        return nrow * ggml_row_size(GGML_TYPE_Q4_1, n_per_row);
    }
    size_t row_size = ggml_row_size(GGML_TYPE_Q4_1, n_per_row);

@@ -3196,7 +3196,7 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri
    static_assert(QK5_0 == 32, "QK5_0 must be 32");

    if (!quant_weights) {
        quantize_row_q5_0_reference(x, y, n_per_row);
        quantize_row_q5_0_ref(x, y, n_per_row);
        return;
    }

@@ -3233,7 +3233,7 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri

size_t quantize_q5_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
    if (!quant_weights) {
        quantize_row_q5_0_reference(src, dst, (int64_t)nrow*n_per_row);
        quantize_row_q5_0_ref(src, dst, (int64_t)nrow*n_per_row);
        return nrow * ggml_row_size(GGML_TYPE_Q5_0, n_per_row);
    }
    size_t row_size = ggml_row_size(GGML_TYPE_Q5_0, n_per_row);

@@ -3250,7 +3250,7 @@ static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restri
    static_assert(QK5_1 == 32, "QK5_1 must be 32");

    if (!quant_weights) {
        quantize_row_q5_1_reference(x, y, n_per_row);
        quantize_row_q5_1_ref(x, y, n_per_row);
        return;
    }

@@ -3286,7 +3286,7 @@ static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restri

size_t quantize_q5_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
    if (!quant_weights) {
        quantize_row_q5_1_reference(src, dst, (int64_t)nrow*n_per_row);
        quantize_row_q5_1_ref(src, dst, (int64_t)nrow*n_per_row);
        return nrow * ggml_row_size(GGML_TYPE_Q5_1, n_per_row);
    }
    size_t row_size = ggml_row_size(GGML_TYPE_Q5_1, n_per_row);

@@ -3302,7 +3302,7 @@ size_t quantize_q5_1(const float * restrict src, void * restrict dst, int64_t nr
size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
    (void)quant_weights; // not used
    const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0, n_per_row);
    quantize_row_q8_0_reference(src, dst, (int64_t)nrow*n_per_row);
    quantize_row_q8_0_ref(src, dst, (int64_t)nrow*n_per_row);
    return nrow * row_size;
}

@@ -3590,7 +3590,7 @@ void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y,

//===================================== Q8_K ==============================================

void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int64_t k) {
void quantize_row_q8_K_ref(const float * restrict x, block_q8_K * restrict y, int64_t k) {
    assert(k % QK_K == 0);
    const int64_t nb = k / QK_K;

@@ -3641,7 +3641,7 @@ void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int6
}

void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) {
    quantize_row_q8_K_reference(x, y, k);
    quantize_row_q8_K_ref(x, y, k);
}

//===================================== Dot ptoducts =================================
@ -3814,43 +3814,47 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
|||
}
|
||||
#endif
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
const svbool_t ptrueh = svptrue_pat_b8(SV_VL16);
|
||||
const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh);
|
||||
if (svcntb() == QK8_0) {
|
||||
const svbool_t ptrueh = svptrue_pat_b8(SV_VL16);
|
||||
const svbool_t ptruel = svnot_b_z(svptrue_b8(), ptrueh);
|
||||
|
||||
svfloat32_t sumv0 = svdup_n_f32(0.0f);
|
||||
svfloat32_t sumv1 = svdup_n_f32(0.0f);
|
||||
svfloat32_t sumv0 = svdup_n_f32(0.0f);
|
||||
svfloat32_t sumv1 = svdup_n_f32(0.0f);
|
||||
|
||||
assert(nb % 2 == 0); // TODO: handle odd nb
|
||||
assert(nb % 2 == 0); // TODO: handle odd nb
|
||||
|
||||
for (int i = 0; i < nb; i += 2) {
|
||||
const block_q4_0 * restrict x0 = &x[i + 0];
|
||||
const block_q4_0 * restrict x1 = &x[i + 1];
|
||||
const block_q8_0 * restrict y0 = &y[i + 0];
|
||||
const block_q8_0 * restrict y1 = &y[i + 1];
|
||||
for (int i = 0; i < nb; i += 2) {
|
||||
const block_q4_0 * restrict x0 = &x[i + 0];
|
||||
const block_q4_0 * restrict x1 = &x[i + 1];
|
||||
const block_q8_0 * restrict y0 = &y[i + 0];
|
||||
const block_q8_0 * restrict y1 = &y[i + 1];
|
||||
|
||||
// load x
|
||||
const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
|
||||
const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
|
||||
// load x
|
||||
const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs);
|
||||
const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs);
|
||||
|
||||
// 4-bit -> 8-bit
|
||||
const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04));
|
||||
const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04));
|
||||
// 4-bit -> 8-bit
|
||||
const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx0r, 0x0F), 0x04));
|
||||
const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(ptruel, svand_n_u8_m(ptrueh, qx1r, 0x0F), 0x04));
|
||||
|
||||
// sub 8
|
||||
const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
|
||||
const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
|
||||
// sub 8
|
||||
const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8);
|
||||
const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8);
|
||||
|
||||
// load y
|
||||
const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
|
||||
const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
|
||||
// load y
|
||||
const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
|
||||
const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
|
||||
|
||||
// dot product
|
||||
sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
|
||||
sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
|
||||
// dot product
|
||||
sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
|
||||
sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
|
||||
}
|
||||
|
||||
*s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
|
||||
return;
|
||||
}
|
||||
|
||||
*s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
|
||||
#elif defined(__ARM_NEON)
|
||||
#endif
|
||||
#if defined(__ARM_NEON)
|
||||
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
||||
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
||||
|
||||
|
@ -5422,31 +5426,35 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
|
|||
}
|
||||
#endif
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
svfloat32_t sumv0 = svdup_n_f32(0.0f);
|
||||
svfloat32_t sumv1 = svdup_n_f32(0.0f);
|
||||
if (svcntb() == QK8_0) {
|
||||
svfloat32_t sumv0 = svdup_n_f32(0.0f);
|
||||
svfloat32_t sumv1 = svdup_n_f32(0.0f);
|
||||
|
||||
assert(nb % 2 == 0); // TODO: handle odd nb
|
||||
assert(nb % 2 == 0); // TODO: handle odd nb
|
||||
|
||||
for (int i = 0; i < nb; i += 2) {
|
||||
const block_q8_0 * restrict x0 = &x[i + 0];
|
||||
const block_q8_0 * restrict x1 = &x[i + 1];
|
||||
const block_q8_0 * restrict y0 = &y[i + 0];
|
||||
const block_q8_0 * restrict y1 = &y[i + 1];
|
||||
for (int i = 0; i < nb; i += 2) {
|
||||
const block_q8_0 * restrict x0 = &x[i + 0];
|
||||
const block_q8_0 * restrict x1 = &x[i + 1];
|
||||
const block_q8_0 * restrict y0 = &y[i + 0];
|
||||
const block_q8_0 * restrict y1 = &y[i + 1];
|
||||
|
||||
// load x
|
||||
const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
|
||||
const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
|
||||
// load x
|
||||
const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs);
|
||||
const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs);
|
||||
|
||||
// load y
|
||||
const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
|
||||
const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
|
||||
// load y
|
||||
const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs);
|
||||
const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs);
|
||||
|
||||
sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
|
||||
sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
|
||||
sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d));
|
||||
sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d));
|
||||
}
|
||||
|
||||
*s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
|
||||
return;
|
||||
}
|
||||
|
||||
*s = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
|
||||
#elif defined(__ARM_NEON)
|
||||
#endif
|
||||
#if defined(__ARM_NEON)
|
||||
float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
||||
float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
||||
|
||||
|
@@ -13522,10 +13530,10 @@ size_t quantize_iq3_xxs(const float * restrict src, void * restrict dst, int64_t
void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int64_t k) {
    assert(k % QK_K == 0);
    block_iq3_xxs * restrict y = vy;
    quantize_row_iq3_xxs_reference(x, y, k);
    quantize_row_iq3_xxs_ref(x, y, k);
}

void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int64_t k) {
void quantize_row_iq3_xxs_ref(const float * restrict x, block_iq3_xxs * restrict y, int64_t k) {
    assert(k % QK_K == 0);
    quantize_row_iq3_xxs_impl(256, x, y, k, NULL);
}

@@ -13738,10 +13746,10 @@ size_t quantize_iq3_s(const float * restrict src, void * restrict dst, int64_t n
void quantize_row_iq3_s(const float * restrict x, void * restrict vy, int64_t k) {
    assert(k % QK_K == 0);
    block_iq3_s * restrict y = vy;
    quantize_row_iq3_s_reference(x, y, k);
    quantize_row_iq3_s_ref(x, y, k);
}

void quantize_row_iq3_s_reference(const float * restrict x, block_iq3_s * restrict y, int64_t k) {
void quantize_row_iq3_s_ref(const float * restrict x, block_iq3_s * restrict y, int64_t k) {
    assert(k % QK_K == 0);
    quantize_iq3_s(x, y, 1, k, NULL);
}

@@ -14479,7 +14487,7 @@ void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int64_t k
    }
}

void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * restrict y, int64_t k) {
void quantize_row_iq4_nl_ref(const float * restrict x, block_iq4_nl * restrict y, int64_t k) {
    assert(k % QK4_NL == 0);
    quantize_row_iq4_nl(x, y, k);
}

@@ -14507,10 +14515,10 @@ size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int64_t
void quantize_row_iq4_xs(const float * restrict x, void * restrict vy, int64_t k) {
    assert(k % QK_K == 0);
    block_iq4_xs * restrict y = vy;
    quantize_row_iq4_xs_reference(x, y, k);
    quantize_row_iq4_xs_ref(x, y, k);
}

void quantize_row_iq4_xs_reference(const float * restrict x, block_iq4_xs * restrict y, int64_t k) {
void quantize_row_iq4_xs_ref(const float * restrict x, block_iq4_xs * restrict y, int64_t k) {
    assert(k % QK_K == 0);
    quantize_iq4_xs(x, y, 1, k, NULL);
}

@@ -14697,7 +14705,7 @@ size_t quantize_iq2_s(const float * restrict src, void * restrict dst, int64_t n
    return nrow * nblock * sizeof(block_iq2_s);
}

void quantize_row_iq2_s_reference(const float * restrict x, block_iq2_s * restrict y, int64_t k) {
void quantize_row_iq2_s_ref(const float * restrict x, block_iq2_s * restrict y, int64_t k) {
    assert(k % QK_K == 0);
    quantize_iq2_s(x, y, 1, k, NULL);
}

@@ -14705,7 +14713,7 @@ void quantize_row_iq2_s_reference(const float * restrict x, block_iq2_s * restri
void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int64_t k) {
    assert(k % QK_K == 0);
    block_iq2_s * restrict y = vy;
    quantize_row_iq2_s_reference(x, y, k);
    quantize_row_iq2_s_ref(x, y, k);
}

static bool validate_float(float f, size_t i) {

@@ -14760,6 +14768,16 @@ static bool validate_fp16(ggml_fp16_t f, size_t i) {
        } \
    }

#define VALIDATE_ROW_DATA_DVEC_F16_IMPL(type, data, nb, nr) \
    const type * q = (const type *) (data); \
    for (size_t i = 0; i < (nb); ++i) { \
        for (size_t j = 0; j < (nr); ++j) { \
            if (!validate_fp16(q[i].d[j], i)) { \
                return false; \
            } \
        } \
    }

bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes) {
    if (type < 0 || type >= GGML_TYPE_COUNT) {
        fprintf(stderr, "%s: invalid type %d\n", __func__, type);

@@ -14977,6 +14995,16 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
            {
                VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb);
            } break;
        case GGML_TYPE_Q4_0_4_4:
        case GGML_TYPE_Q4_0_4_8:
            {
                VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x4, data, nbytes / sizeof(block_q4_0x4), 4);
            } break;
        case GGML_TYPE_Q4_0_8_8:
            {
                VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x8, data, nbytes / sizeof(block_q4_0x8), 8);
            } break;

        case GGML_TYPE_I8:
        case GGML_TYPE_I16:
        case GGML_TYPE_I32:

@@ -12,25 +12,25 @@ extern "C" {
#endif

// Quantization
void quantize_row_q4_0_reference(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_1_reference(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k);
void quantize_row_q5_0_reference(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k);
void quantize_row_q5_1_reference(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_0_reference(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_1_reference(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_1_ref(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k);
void quantize_row_q5_0_ref(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k);
void quantize_row_q5_1_ref(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k);

void quantize_row_q2_K_reference(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
void quantize_row_q3_K_reference(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_K_reference(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k);
void quantize_row_q5_K_reference(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k);
void quantize_row_q6_K_reference(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k);
void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k);
void quantize_row_q5_K_ref(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k);
void quantize_row_q6_K_ref(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_K_ref(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k);

void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k);
void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k);
void quantize_row_iq4_xs_reference (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k);
void quantize_row_iq3_s_reference (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k);
void quantize_row_iq2_s_reference (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k);
void quantize_row_iq3_xxs_ref(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k);
void quantize_row_iq4_nl_ref (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k);
void quantize_row_iq4_xs_ref (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k);
void quantize_row_iq3_s_ref (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k);
void quantize_row_iq2_s_ref (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k);

void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);

@ -49,7 +49,7 @@ bool ggml_backend_is_sycl(ggml_backend_t backend);
|
|||
int ggml_backend_sycl_get_device(ggml_backend_t backend);
|
||||
static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer);
|
||||
static inline int get_sycl_env(const char *env_name, int default_val);
|
||||
static inline int get_work_group_size(const sycl::device& device);
|
||||
|
||||
|
||||
void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,
|
||||
const void *ptr_src, size_t size) {
|
||||
|
@ -291,29 +291,6 @@ static void sqr_f32(const float * x, float * dst, const int k,
|
|||
dst[i] = x[i] * x[i];
|
||||
}
|
||||
|
||||
static void concat_f32(const float *x,const float *y, float *dst, const int ne0, const int ne02,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
int nidx = item_ct1.get_local_id(2) +
|
||||
item_ct1.get_group(2) * item_ct1.get_local_range(2);
|
||||
if (nidx >= ne0) {
|
||||
return;
|
||||
}
|
||||
// operation
|
||||
int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
|
||||
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
|
||||
if (item_ct1.get_group(0) < ne02) { // src0
|
||||
int offset_src =
|
||||
nidx + item_ct1.get_group(1) * ne0 +
|
||||
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
|
||||
dst[offset_dst] = x[offset_src];
|
||||
} else {
|
||||
int offset_src =
|
||||
nidx + item_ct1.get_group(1) * ne0 +
|
||||
(item_ct1.get_group(0) - ne02) * ne0 * item_ct1.get_group_range(1);
|
||||
dst[offset_dst] = y[offset_src];
|
||||
}
|
||||
}
|
||||
|
||||
static void upscale_f32(const float *x, float *dst, const int nb00, const int nb01,
|
||||
const int nb02, const int nb03, const int ne10, const int ne11,
|
||||
const int ne12, const int ne13, const float sf0, const float sf1,
|
||||
|
@ -892,117 +869,6 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con
|
|||
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
|
||||
}
|
||||
|
||||
|
||||
template <bool vals_smem, int ncols_template, int block_size_template>
|
||||
static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
|
||||
const int nrows_y, const float scale, const float max_bias, const float m0,
|
||||
const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
|
||||
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
|
||||
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
const int rowx = item_ct1.get_group(2);
|
||||
const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
|
||||
|
||||
const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template;
|
||||
|
||||
const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
|
||||
const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
|
||||
|
||||
float slope = 1.0f;
|
||||
|
||||
// ALiBi
|
||||
if (max_bias > 0.0f) {
|
||||
const uint32_t h = rowx/nrows_y; // head index
|
||||
|
||||
const float base = h < n_head_log2 ? m0 : m1;
|
||||
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
|
||||
slope = sycl::pow(base, float(exp));
|
||||
}
|
||||
|
||||
float * vals = vals_smem ? buf + WARP_SIZE : dst + rowx*ncols;
|
||||
float max_val = -INFINITY;
|
||||
|
||||
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
||||
const int col = col0 + tid;
|
||||
|
||||
if (ncols_template == 0 && col >= ncols) {
|
||||
break;
|
||||
}
|
||||
|
||||
const int ix = rowx*ncols + col;
|
||||
const int iy = rowy*ncols + col;
|
||||
|
||||
const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
|
||||
|
||||
vals[col] = val;
|
||||
max_val = sycl::max(max_val, val);
|
||||
}
|
||||
|
||||
// find the max value in the block
|
||||
max_val = warp_reduce_max(max_val, item_ct1);
|
||||
if (block_size > WARP_SIZE) {
|
||||
if (warp_id == 0) {
|
||||
buf[lane_id] = -INFINITY;
|
||||
}
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
if (lane_id == 0) {
|
||||
buf[warp_id] = max_val;
|
||||
}
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
max_val = buf[lane_id];
|
||||
max_val = warp_reduce_max(max_val, item_ct1);
|
||||
}
|
||||
|
||||
float tmp = 0.f;
|
||||
|
||||
#pragma unroll
|
||||
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
||||
const int col = col0 + tid;
|
||||
if (ncols_template == 0 && col >= ncols) {
|
||||
break;
|
||||
}
|
||||
|
||||
const float val = sycl::native::exp(vals[col] - max_val);
|
||||
tmp += val;
|
||||
vals[col] = val;
|
||||
}
|
||||
|
||||
// find the sum of exps in the block
|
||||
tmp = warp_reduce_sum(tmp, item_ct1);
|
||||
if (block_size > WARP_SIZE) {
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
if (warp_id == 0) {
|
||||
buf[lane_id] = 0.f;
|
||||
}
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
if (lane_id == 0) {
|
||||
buf[warp_id] = tmp;
|
||||
}
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
tmp = buf[lane_id];
|
||||
tmp = warp_reduce_sum(tmp, item_ct1);
|
||||
}
|
||||
|
||||
const float inv_sum = 1.f / tmp;
|
||||
|
||||
#pragma unroll
|
||||
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
||||
const int col = col0 + tid;
|
||||
|
||||
if (ncols_template == 0 && col >= ncols) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int idst = rowx*ncols + col;
|
||||
dst[idst] = vals[col] * inv_sum;
|
||||
}
|
||||
}
|
||||
|
||||
static void scale_f32(const float * x, float * dst, const float scale, const int k,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
|
||||
|
@ -1458,20 +1324,6 @@ static void sqr_f32_sycl(const float *x, float *dst, const int k,
|
|||
});
|
||||
}
|
||||
|
||||
static void concat_f32_sycl(const float *x, const float *y, float *dst,
|
||||
const int ne0, int ne1, int ne2, int ne02,
|
||||
queue_ptr stream) {
|
||||
int num_blocks = (ne0 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE;
|
||||
sycl::range<3> gridDim(ne2, ne1, num_blocks);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(gridDim *
|
||||
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
concat_f32(x, y, dst, ne0, ne02, item_ct1);
|
||||
});
|
||||
}
|
||||
|
||||
static void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
|
||||
const int nb02, const int nb03, const int ne10, const int ne11,
|
||||
const int ne12, const int ne13, const float sf0, const float sf1,
|
||||
|
@ -1890,106 +1742,6 @@ static void diag_mask_inf_f32_sycl(const float *x, float *dst,
|
|||
});
|
||||
}
|
||||
|
||||
template <bool vals_smem, int ncols_template, int block_size_template>
|
||||
static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
|
||||
const int nrows_y, const float scale, const float max_bias, const float m0,
|
||||
const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
|
||||
const size_t n_local_scratch, queue_ptr stream) {
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
|
||||
nrows_y, scale, max_bias, m0,
|
||||
m1, n_head_log2, item_ct1,
|
||||
local_buf_acc.get_pointer());
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
static void soft_max_f32_sycl(const float * x, const float * mask,
|
||||
float * dst, const int ncols_x, const int nrows_x,
|
||||
const int nrows_y, const float scale, const float max_bias,
|
||||
queue_ptr stream) {
|
||||
int nth = WARP_SIZE;
|
||||
int max_block_size = get_work_group_size(stream->get_device());
|
||||
while (nth < ncols_x && nth < max_block_size) nth *= 2;
|
||||
if (nth>max_block_size) nth = max_block_size;
|
||||
|
||||
const sycl::range<3> block_dims(1, 1, nth);
|
||||
const sycl::range<3> block_nums(1, 1, nrows_x);
|
||||
const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE);
|
||||
|
||||
const uint32_t n_head_kv = nrows_x/nrows_y;
|
||||
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
|
||||
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
|
||||
if (n_local_scratch*sizeof(float) < local_mem_size) {
|
||||
if (ncols_x > max_block_size) {
|
||||
soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
return;
|
||||
}
|
||||
switch (ncols_x) {
|
||||
case 32:
|
||||
soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
case 64:
|
||||
soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
case 128:
|
||||
soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
case 256:
|
||||
soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
case 512:
|
||||
soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
case 1024:
|
||||
soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
case 2048:
|
||||
soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
case 4096:
|
||||
soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
default:
|
||||
soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, WARP_SIZE, stream);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void im2col_sycl(const float *x, T *dst, int IW, int IH,
|
||||
int OW, int OH, int KW, int KH, int IC,
|
||||
|
@ -2156,6 +1908,8 @@ static ggml_sycl_device_info ggml_sycl_init() {
|
|||
|
||||
info.devices[i].cc =
|
||||
100 * prop.get_major_version() + 10 * prop.get_minor_version();
|
||||
|
||||
info.max_work_group_sizes[i] = prop.get_max_work_group_size();
|
||||
}
|
||||
|
||||
for (int id = 0; id < info.device_count; ++id) {
|
||||
|
@ -2638,28 +2392,6 @@ inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor
|
|||
(void) src1_dd;
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_dd, const float *src1_dd,
|
||||
float *dst_dd,
|
||||
const queue_ptr &main_stream) {
|
||||
#pragma message("TODO: generalize concat kernel for dim != 2")
|
||||
#pragma message(" https://github.com/ggerganov/llama.cpp/pull/7563")
|
||||
int dim = dst->op_params[0];
|
||||
GGML_ASSERT(dim == 2);
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
|
||||
for (int i3 = 0; i3 < dst->ne[3]; i3++) {
|
||||
concat_f32_sycl(src0_dd + i3 * (src0->nb[3] / 4), src1_dd + i3 * (src1->nb[3] / 4), dst_dd + i3 * (dst->nb[3] / 4), dst->ne[0], dst->ne[1], dst->ne[2], src0->ne[2], main_stream);
|
||||
}
|
||||
|
||||
(void) src1;
|
||||
(void) dst;
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_dd, const float *src1_dd,
|
||||
|
@ -3007,33 +2739,6 @@ inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, const gg
|
|||
(void) src1_dd;
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_dd, const float *src1_dd,
|
||||
float *dst_dd,
|
||||
const queue_ptr &main_stream) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
|
||||
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
|
||||
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t nrows_x = ggml_nrows(src0);
|
||||
const int64_t nrows_y = src0->ne[1];
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
|
||||
memcpy(&scale, dst->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, dst->op_params + 1, sizeof(float));
|
||||
|
||||
soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
|
||||
nrows_x, nrows_y, scale, max_bias, main_stream);
|
||||
}
|
||||
|
||||
inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
|
||||
ggml_tensor *dst, const float *src0_dd,
|
||||
const float *src1_dd, float *dst_dd,
|
||||
|
@ -3595,12 +3300,6 @@ static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, const ggml_ten
|
|||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
static void ggml_sycl_concat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_concat);
|
||||
GGML_SYCL_DEBUG("call %s done\n", __func__);
|
||||
}
|
||||
|
||||
static void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_SYCL_DEBUG("call %s\n", __func__);
|
||||
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_upscale);
|
||||
|
@ -3729,10 +3428,6 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
|
|||
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
|
||||
queue_ptr main_stream = ctx.stream();;
|
||||
|
||||
bool no_mixed_dtypes = main_stream->get_backend() == sycl::backend::ext_oneapi_cuda ||
|
||||
main_stream->get_backend() == sycl::backend::ext_oneapi_hip;
|
||||
|
||||
|
||||
void * src0_ddq = src0->data;
|
||||
sycl::half *src0_as_f16 = (sycl::half *)src0_ddq;
|
||||
float * src1_ddf = (float *) src1->data;
|
||||
|
@ -3750,15 +3445,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
|
|||
sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
|
||||
: src1_f16_alloc.get();
|
||||
|
||||
ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
|
||||
char * dst_t;
|
||||
|
||||
dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
|
||||
dpct::library_data_t cu_data_type = dpct::library_data_t::real_float;
|
||||
if (no_mixed_dtypes) {
|
||||
cu_compute_type = dpct::library_data_t::real_half;
|
||||
cu_data_type = dpct::library_data_t::real_half;
|
||||
}
|
||||
|
||||
// dst strides
|
||||
size_t nbd2 = dst->nb[2];
|
||||
|
@ -3767,26 +3457,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
|
|||
const float alpha_f32 = 1.0f;
|
||||
const float beta_f32 = 0.0f;
|
||||
|
||||
const sycl::half alpha_f16 = 1.0f;
|
||||
const sycl::half beta_f16 = 0.0f;
|
||||
|
||||
const void * alpha = &alpha_f32;
|
||||
const void * beta = &beta_f32;
|
||||
if (no_mixed_dtypes) {
|
||||
alpha = &alpha_f16;
|
||||
beta = &beta_f16;
|
||||
}
|
||||
|
||||
// TODO: Renable (dst->op_params[0] =! GGML_PREC_DEFAULT) pathway
|
||||
// when oneMKL open source supports half, half, float, float: datatypes
|
||||
|
||||
dst_t = (char *) dst_ddf;
|
||||
if (no_mixed_dtypes) {
|
||||
dst_t = (char *) dst_f16.alloc(ne_dst);
|
||||
|
||||
nbd2 /= sizeof(float) / sizeof(sycl::half);
|
||||
nbd3 /= sizeof(float) / sizeof(sycl::half);
|
||||
}
|
||||
|
||||
GGML_ASSERT(ne12 % ne02 == 0);
|
||||
GGML_ASSERT(ne13 % ne03 == 0);
|
||||
|
@ -3848,11 +3522,6 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
|
|||
(void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
|
||||
cu_compute_type)));
|
||||
}
|
||||
|
||||
if (no_mixed_dtypes) {
|
||||
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
|
||||
to_fp32_sycl(dst_f16.get(), dst_ddf, ne_dst, main_stream);
|
||||
}
|
||||
}
|
||||
catch (sycl::exception const &exc) {
|
||||
std::cerr << exc.what() << "Exception caught at file:" << __FILE__
|
||||
|
@ -3924,6 +3593,10 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
|
|||
use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
|
||||
#endif // SYCL_USE_XMX
|
||||
|
||||
// mmvq path is faster in the CUDA backend.
|
||||
if (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda)
|
||||
use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
|
||||
|
||||
if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
|
||||
// KQ single-batch
|
||||
ggml_sycl_mul_mat_vec_p021(ctx, src0, src1, dst);
|
||||
|
@ -4030,37 +3703,13 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
|
|||
stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
|
||||
|
||||
const ggml_tensor_extra_gpu *src0_extra =
|
||||
(const ggml_tensor_extra_gpu *)src0->extra;
|
||||
const ggml_tensor_extra_gpu *src1_extra =
|
||||
(const ggml_tensor_extra_gpu *)src1->extra;
|
||||
const ggml_tensor_extra_gpu *dst_extra =
|
||||
(const ggml_tensor_extra_gpu *)dst->extra;
|
||||
|
||||
ggml_tensor_extra_gpu src0_row_extra;
|
||||
ggml_tensor_extra_gpu src1_row_extra;
|
||||
ggml_tensor_extra_gpu dst_row_extra;
|
||||
|
||||
ggml_tensor src0_row = *src0;
|
||||
ggml_tensor src1_row = *src1;
|
||||
ggml_tensor dst_row = *dst;
|
||||
|
||||
src1_row.backend = GGML_BACKEND_TYPE_GPU;
|
||||
dst_row.backend = GGML_BACKEND_TYPE_GPU;
|
||||
|
||||
src0_row.extra = &src0_row_extra;
|
||||
src1_row.extra = &src1_row_extra;
|
||||
dst_row.extra = &dst_row_extra;
|
||||
|
||||
char *src0_original = src1->backend == GGML_BACKEND_TYPE_CPU
|
||||
? (char *)src0->data
|
||||
: (char *)src0_extra->data_device[ctx.device];
|
||||
char *src1_original = src1->backend == GGML_BACKEND_TYPE_CPU
|
||||
? (char *)src1->data
|
||||
: (char *)src1_extra->data_device[ctx.device];
|
||||
char *dst_original = dst->backend == GGML_BACKEND_TYPE_CPU
|
||||
? (char *)dst->data
|
||||
: (char *)dst_extra->data_device[ctx.device];
|
||||
char *src0_original = (char *)src0->data;
|
||||
char *src1_original = (char *)src1->data;
|
||||
char *dst_original = (char *)dst->data;
|
||||
|
||||
src0_row.ne[2] = 1;
|
||||
src0_row.ne[3] = 1;
|
||||
|
@ -4089,12 +3738,9 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
|
|||
const int64_t i1 = id;
|
||||
const int64_t i2 = i12;
|
||||
|
||||
src0_row_extra.data_device[ctx.device] =
|
||||
src0_original + i02*nb02;
|
||||
src1_row_extra.data_device[ctx.device] =
|
||||
src1_original + + i11*nb11 + i12*nb12;
|
||||
dst_row_extra.data_device[ctx.device] =
|
||||
dst_original + i1*nb1 + i2*nb2;
|
||||
src0_row.data = src0_original + i02*nb02;
|
||||
src1_row.data = src1_original + + i11*nb11 + i12*nb12;
|
||||
dst_row.data = dst_original + i1*nb1 + i2*nb2;
|
||||
|
||||
ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
|
||||
}
|
||||
|
@ -4103,8 +3749,8 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
|
|||
ggml_sycl_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
|
||||
ggml_sycl_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
|
||||
|
||||
src1_row_extra.data_device[ctx.device] = src1_contiguous.get();
|
||||
dst_row_extra.data_device[ctx.device] = dst_contiguous.get();
|
||||
src1_row.data = src1_contiguous.get();
|
||||
dst_row.data = dst_contiguous.get();
|
||||
|
||||
for (int64_t i02 = 0; i02 < n_as; i02++) {
|
||||
int64_t num_src1_rows = 0;
|
||||
|
@ -4160,7 +3806,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
|
|||
});
|
||||
}
|
||||
|
||||
src0_row_extra.data_device[ctx.device] = src0_original + i02*nb02;
|
||||
src0_row.data = src0_original + i02*nb02;
|
||||
|
||||
GGML_ASSERT(nb11 == sizeof(float)*ne10);
|
||||
GGML_ASSERT(nb1 == sizeof(float)*ne0);
|
||||
|
@ -4390,7 +4036,7 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
|
|||
func = ggml_sycl_group_norm;
|
||||
break;
|
||||
case GGML_OP_CONCAT:
|
||||
func = ggml_sycl_concat;
|
||||
func = ggml_sycl_op_concat;
|
||||
break;
|
||||
case GGML_OP_UPSCALE:
|
||||
func = ggml_sycl_upscale;
|
||||
|
@ -5483,6 +5129,10 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
|
|||
return false;
|
||||
}
|
||||
}
|
||||
ggml_type src0_type = op->src[0]->type;
|
||||
if (src0_type == GGML_TYPE_BF16) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
} break;
|
||||
case GGML_OP_GET_ROWS:
|
||||
|
@ -5530,7 +5180,8 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
|
|||
case GGML_OP_CONCAT:
|
||||
{
|
||||
ggml_type src0_type = op->src[0]->type;
|
||||
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
|
||||
int dim = op->op_params[0];
|
||||
return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16 && dim == 2;
|
||||
} break;
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_NONE:
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
#ifndef GGML_SYCL_BACKEND_HPP
|
||||
#define GGML_SYCL_BACKEND_HPP
|
||||
|
||||
#include "concat.hpp"
|
||||
#include "common.hpp"
|
||||
#include "convert.hpp"
|
||||
#include "dequantize.hpp"
|
||||
|
@ -21,5 +22,6 @@
|
|||
#include "mmvq.hpp"
|
||||
#include "rope.hpp"
|
||||
#include "norm.hpp"
|
||||
#include "softmax.hpp"
|
||||
|
||||
#endif // GGML_SYCL_BACKEND_HPP
|
||||
|
|
|
@ -47,10 +47,6 @@ static int g_ggml_sycl_debug = 0;
|
|||
} \
|
||||
}()
|
||||
|
||||
// #define DEBUG_SYCL_MALLOC
|
||||
|
||||
static int g_work_group_size = 0;
|
||||
// typedef sycl::half ggml_fp16_t;
|
||||
|
||||
#define __SYCL_ARCH__ DPCT_COMPATIBILITY_TEMP
|
||||
#define VER_4VEC 610 // todo for hardward optimize.
|
||||
|
@ -193,6 +189,8 @@ struct ggml_sycl_device_info {
|
|||
sycl_device_info devices[GGML_SYCL_MAX_DEVICES] = {};
|
||||
|
||||
std::array<float, GGML_SYCL_MAX_DEVICES> default_tensor_split = {};
|
||||
|
||||
int max_work_group_sizes[GGML_SYCL_MAX_DEVICES] = {0};
|
||||
};
|
||||
|
||||
const ggml_sycl_device_info & ggml_sycl_info();
|
||||
|
@ -295,15 +293,6 @@ struct ggml_backend_sycl_context {
|
|||
}
|
||||
};
|
||||
|
||||
// common host functions
|
||||
|
||||
static inline int get_work_group_size(const sycl::device& device) {
|
||||
dpct::device_info prop;
|
||||
dpct::get_device_info(prop, device);
|
||||
return prop.get_max_work_group_size();
|
||||
}
|
||||
|
||||
|
||||
// common device functions
|
||||
|
||||
static __dpct_inline__ float warp_reduce_sum(float x,
|
||||
|
@ -357,4 +346,10 @@ inline sycl::vec<Tp, n> vec_aligned_load(const Tp* aligned_ptr) {
|
|||
return *reinterpret_cast<const sycl::vec<Tp, n>*>(aligned_ptr);
|
||||
}
|
||||
|
||||
// Helper for accessing pointers with no warnings
|
||||
template <typename Tp, int dim>
|
||||
static __dpct_inline__ Tp* get_pointer(sycl::local_accessor<Tp, dim> acc) {
|
||||
return acc.template get_multi_ptr<sycl::access::decorated::no>().get();
|
||||
}
|
||||
|
||||
#endif // GGML_SYCL_COMMON_HPP
|
||||
|
|
195
ggml/src/ggml-sycl/concat.cpp
Normal file
|
@ -0,0 +1,195 @@
|
|||
//
|
||||
// MIT license
|
||||
// Copyright (C) 2024 Intel Corporation
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
|
||||
#include "concat.hpp"
|
||||
#include "common.hpp"
|
||||
|
||||
static void concat_f32_dim0(const float *x, const float *y, float *dst,
|
||||
const int ne0, const int ne00,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
int nidx = item_ct1.get_local_id(2) +
|
||||
item_ct1.get_group(2) * item_ct1.get_local_range(2);
|
||||
if (nidx >= ne0) {
|
||||
return;
|
||||
}
|
||||
// operation
|
||||
int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
|
||||
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
|
||||
if (nidx < ne00) { // src0
|
||||
int offset_src = nidx + item_ct1.get_group(1) * ne00 +
|
||||
item_ct1.get_group(0) * ne00 * item_ct1.get_group_range(1);
|
||||
dst[offset_dst] = x[offset_src];
|
||||
} else {
|
||||
int offset_src =
|
||||
nidx - ne00 + item_ct1.get_group(1) * (ne0 - ne00) +
|
||||
item_ct1.get_group(0) * (ne0 - ne00) * item_ct1.get_group_range(1);
|
||||
dst[offset_dst] = y[offset_src];
|
||||
}
|
||||
}
|
||||
|
||||
static void concat_f32_dim1(const float *x, const float *y, float *dst,
|
||||
const int ne0, const int ne01,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
int nidx = item_ct1.get_local_id(2) +
|
||||
item_ct1.get_group(2) * item_ct1.get_local_range(2);
|
||||
if (nidx >= ne0) {
|
||||
return;
|
||||
}
|
||||
// operation
|
||||
int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
|
||||
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
|
||||
if (item_ct1.get_group(1) < ne01) { // src0
|
||||
int offset_src =
|
||||
nidx + item_ct1.get_group(1) * ne0 + item_ct1.get_group(0) * ne0 * ne01;
|
||||
dst[offset_dst] = x[offset_src];
|
||||
} else {
|
||||
int offset_src =
|
||||
nidx + (item_ct1.get_group(1) - ne01) * ne0 +
|
||||
item_ct1.get_group(0) * ne0 * (item_ct1.get_group_range(1) - ne01);
|
||||
dst[offset_dst] = y[offset_src];
|
||||
}
|
||||
}
|
||||
|
||||
static void concat_f32_dim2(const float *x, const float *y, float *dst,
|
||||
const int ne0, const int ne02,
|
||||
const sycl::nd_item<3> &item_ct1) {
|
||||
int nidx = item_ct1.get_local_id(2) +
|
||||
item_ct1.get_group(2) * item_ct1.get_local_range(2);
|
||||
if (nidx >= ne0) {
|
||||
return;
|
||||
}
|
||||
// operation
|
||||
int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
|
||||
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
|
||||
if (item_ct1.get_group(0) < ne02) { // src0
|
||||
int offset_src = nidx + item_ct1.get_group(1) * ne0 +
|
||||
item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
|
||||
dst[offset_dst] = x[offset_src];
|
||||
} else {
|
||||
int offset_src =
|
||||
nidx + item_ct1.get_group(1) * ne0 +
|
||||
(item_ct1.get_group(0) - ne02) * ne0 * item_ct1.get_group_range(1);
|
||||
dst[offset_dst] = y[offset_src];
|
||||
}
|
||||
}
|
||||
|
||||
static void concat_f32_sycl(const float *x, const float *y, float *dst,
|
||||
int ne00, int ne01, int ne02, int ne0, int ne1,
|
||||
int ne2, int dim, queue_ptr stream) {
|
||||
int num_blocks = (ne0 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE;
|
||||
sycl::range<3> gridDim(ne2, ne1, num_blocks);
|
||||
switch (dim) {
|
||||
case 0:
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(gridDim *
|
||||
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
concat_f32_dim0(x, y, dst, ne0, ne00, item_ct1);
|
||||
});
|
||||
break;
|
||||
case 1:
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(gridDim *
|
||||
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1);
|
||||
});
|
||||
break;
|
||||
default:
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(gridDim *
|
||||
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
|
||||
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
concat_f32_dim2(x, y, dst, ne0, ne02, item_ct1);
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// non-contiguous kernel (slow)
|
||||
static void concat_f32_sycl_non_cont(
|
||||
queue_ptr stream, const char *src0, const char *src1, char *dst,
|
||||
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03, uint64_t nb00,
|
||||
uint64_t nb01, uint64_t nb02, uint64_t nb03, int64_t /*ne10*/,
|
||||
int64_t /*ne11*/, int64_t /*ne12*/, int64_t /*ne13*/, uint64_t nb10,
|
||||
uint64_t nb11, uint64_t nb12, uint64_t nb13, int64_t ne0, int64_t ne1,
|
||||
int64_t ne2, int64_t ne3, uint64_t nb0, uint64_t nb1, uint64_t nb2,
|
||||
uint64_t nb3, int32_t dim) {
|
||||
sycl::range<3> gridDim(ne3, ne2, ne1);
|
||||
stream->parallel_for(
|
||||
sycl::nd_range<3>(gridDim, sycl::range<3>(1, 1, 1)),
|
||||
[=](sycl::nd_item<3> item_ct1) {
|
||||
int64_t i3 = item_ct1.get_group(0);
|
||||
int64_t i2 = item_ct1.get_group(1);
|
||||
int64_t i1 = item_ct1.get_group(2);
|
||||
|
||||
int64_t o[4] = {0, 0, 0, 0};
|
||||
o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
|
||||
|
||||
const float *x;
|
||||
|
||||
for (int i0 = item_ct1.get_local_id(2); i0 < ne0;
|
||||
i0 += item_ct1.get_local_range(2)) {
|
||||
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
|
||||
x = (const float *)(src0 + (i3)*nb03 + (i2)*nb02 + (i1)*nb01 +
|
||||
(i0)*nb00);
|
||||
} else {
|
||||
x = (const float *)(src1 + (i3 - o[3]) * nb13 + (i2 - o[2]) * nb12 +
|
||||
(i1 - o[1]) * nb11 + (i0 - o[0]) * nb10);
|
||||
}
|
||||
|
||||
float *y = (float *)(dst + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0);
|
||||
|
||||
*y = *x;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst) {
|
||||
queue_ptr stream = ctx.stream();
|
||||
|
||||
const int32_t dim = ((int32_t *)dst->op_params)[0];
|
||||
|
||||
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
|
||||
const float *src0_d = (const float *)src0->data;
|
||||
const float *src1_d = (const float *)src1->data;
|
||||
|
||||
float *dst_d = (float *)dst->data;
|
||||
|
||||
if (dim != 3) {
|
||||
for (int i3 = 0; i3 < dst->ne[3]; i3++) {
|
||||
concat_f32_sycl(
|
||||
src0_d + i3 * (src0->nb[3] / 4), src1_d + i3 * (src1->nb[3] / 4),
|
||||
dst_d + i3 * (dst->nb[3] / 4), src0->ne[0], src0->ne[1],
|
||||
src0->ne[2], dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
|
||||
}
|
||||
} else {
|
||||
const size_t size0 = ggml_nbytes(src0);
|
||||
const size_t size1 = ggml_nbytes(src1);
|
||||
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(dst_d, src0_d, size0).wait()));
|
||||
SYCL_CHECK(CHECK_TRY_ERROR(
|
||||
stream->memcpy(dst_d + size0 / 4, src1_d, size1).wait()));
|
||||
}
|
||||
} else
|
||||
concat_f32_sycl_non_cont(
|
||||
stream, (const char *)src0->data, (const char *)src1->data,
|
||||
(char *)dst->data, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
|
||||
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src1->ne[0],
|
||||
src1->ne[1], src1->ne[2], src1->ne[3], src1->nb[0], src1->nb[1],
|
||||
src1->nb[2], src1->nb[3], dst->ne[0], dst->ne[1], dst->ne[2],
|
||||
dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], dim);
|
||||
}
|
21
ggml/src/ggml-sycl/concat.hpp
Normal file
@ -0,0 +1,21 @@
//
// MIT license
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: MIT
//

//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//

#ifndef GGML_SYCL_CONCAT_HPP
#define GGML_SYCL_CONCAT_HPP

#include "common.hpp"

void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                         const ggml_tensor *src1, ggml_tensor *dst);

#endif // GGML_SYCL_CONCAT_HPP
@ -158,7 +158,7 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k,
                              sycl::range<3>(1, 1, 32),
                              sycl::range<3>(1, 1, 32)),
            [=](sycl::nd_item<3> item_ct1) {
                dequantize_block_q4_K(vx, y, scale_local_acc.get_pointer(), item_ct1);
                dequantize_block_q4_K(vx, y, get_pointer(scale_local_acc), item_ct1);
            });
    });
}

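The `scale_local_acc.get_pointer()` → `get_pointer(scale_local_acc)` change above (repeated throughout mmq.cpp, norm.cpp and softmax.cpp below) routes local-accessor pointer access through a free helper rather than the accessor member call, which SYCL 2020 deprecates. A minimal sketch of what such a helper could look like; the actual definition lives in the backend's common header and its exact shape here is an assumption:

    // Illustrative only: return the raw pointer behind a local_accessor via the
    // SYCL 2020 get_multi_ptr() interface instead of the deprecated get_pointer().
    template <typename Acc>
    static inline auto get_pointer(Acc & acc) {
        return acc.template get_multi_ptr<sycl::access::decorated::no>().get();
    }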
@ -3,6 +3,7 @@
#include "dequantize.hpp"
#include "presets.hpp"


static void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
    const sycl::half *x = (const sycl::half *)vx;

@ -227,7 +228,7 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,

    // sum up partial sums and write back result
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
        tmp +=
            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
    }

@ -346,7 +347,7 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx,

    // sum up partial sums and write back result
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
        tmp +=
            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
    }

@ -499,7 +500,7 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,

    // sum up partial sums and write back result
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
        tmp +=
            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
    }

@ -633,7 +634,7 @@ static void dequantize_mul_mat_vec_q5_k(const void *__restrict__ vx,

    // sum up partial sums and write back result
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
        tmp +=
            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
    }

@ -748,7 +749,7 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa

    // sum up partial sums and write back result
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
        tmp +=
            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
    }

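The loops above (now iterating from `QK_WARP_SIZE / 2`) implement a sub-group butterfly reduction: each step XOR-shuffles partial sums across lanes, and after log2(width) steps every lane holds the total. A plain C++ simulation of the same pattern over a 32-lane sub-group, illustrative only and not SYCL code:

    #include <cstdio>

    int main() {
        const int WARP = 32;
        float lane[WARP];
        for (int i = 0; i < WARP; ++i) lane[i] = float(i);   // partial sum held by each lane

        // butterfly reduction: after log2(WARP) steps every lane holds the full sum
        for (int mask = WARP / 2; mask > 0; mask >>= 1) {
            float shuffled[WARP];
            for (int i = 0; i < WARP; ++i) shuffled[i] = lane[i ^ mask]; // permute_sub_group_by_xor
            for (int i = 0; i < WARP; ++i) lane[i] += shuffled[i];
        }

        printf("sum = %.1f (expected %.1f)\n", lane[0], 31.0f * 32.0f / 2.0f);
        return 0;
    }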
@ -873,10 +874,10 @@ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
    const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
    const int block_num_y = (nrows + ny - 1) / ny;
    const sycl::range<3> block_nums(1, 1, block_num_y);
    const sycl::range<3> block_dims(1, ny, WARP_SIZE);
    const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
    stream->parallel_for(
        sycl::nd_range<3>(block_nums * block_dims, block_dims),
        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
            dequantize_mul_mat_vec_q2_k(vx, y, dst, ncols, nrows, item_ct1);
        });
}

@ -889,10 +890,10 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
    const int ny = 2 / K_QUANTS_PER_ITERATION;
    const int block_num_y = (nrows + ny - 1) / ny;
    const sycl::range<3> block_nums(1, 1, block_num_y);
    const sycl::range<3> block_dims(1, ny, WARP_SIZE);
    const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
    stream->parallel_for(
        sycl::nd_range<3>(block_nums * block_dims, block_dims),
        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
            dequantize_mul_mat_vec_q3_k(vx, y, dst, ncols, nrows, item_ct1);
        });
}

@ -905,10 +906,10 @@ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
    const int ny = 2 / K_QUANTS_PER_ITERATION;
    const int block_num_y = (nrows + ny - 1) / ny;
    const sycl::range<3> block_nums(1, 1, block_num_y);
    const sycl::range<3> block_dims(1, ny, WARP_SIZE);
    const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
    stream->parallel_for(
        sycl::nd_range<3>(block_nums * block_dims, block_dims),
        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
            dequantize_mul_mat_vec_q4_k(vx, y, dst, ncols, nrows, item_ct1);
        });
}

@ -918,10 +919,10 @@ static void dequantize_mul_mat_vec_q5_K_sycl(const void *vx, const float *y,
                                              const int nrows,
                                              dpct::queue_ptr stream) {
    GGML_ASSERT(ncols % QK_K == 0);
    const sycl::range<3> block_dims(1, 1, WARP_SIZE);
    const sycl::range<3> block_dims(1, 1, QK_WARP_SIZE);
    stream->parallel_for(
        sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, block_dims),
        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
            dequantize_mul_mat_vec_q5_k(vx, y, dst, ncols, item_ct1);
        });
}

@ -934,10 +935,10 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
    const int ny = 2 / K_QUANTS_PER_ITERATION;
    const int block_num_y = (nrows + ny - 1) / ny;
    const sycl::range<3> block_nums(1, 1, block_num_y);
    const sycl::range<3> block_dims(1, ny, WARP_SIZE);
    const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
    stream->parallel_for(
        sycl::nd_range<3>(block_nums * block_dims, block_dims),
        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
        [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
            dequantize_mul_mat_vec_q6_k(vx, y, dst, ncols, nrows, item_ct1);
        });
}

@ -2426,6 +2426,7 @@ namespace dpct
                    b, ldb, beta, c, ldc, batch_size);
                break;
            }
#endif
            case detail::get_type_combination_id(
                library_data_t::real_int8, library_data_t::real_int8,
                library_data_t::real_int32, library_data_t::real_int32):

@ -2458,7 +2459,6 @@ namespace dpct
                    batch_size);
                break;
            }
#endif
            case detail::get_type_combination_id(
                library_data_t::real_half, library_data_t::real_half,
                library_data_t::real_half, library_data_t::real_float):

@ -2595,6 +2595,7 @@ namespace dpct
                    stride_c, batch_size);
                break;
            }
#endif
            case detail::get_type_combination_id(
                library_data_t::real_int8, library_data_t::real_int8,
                library_data_t::real_int32, library_data_t::real_int32):

@ -2623,7 +2624,6 @@ namespace dpct
                    beta, c, ldc, stride_c, batch_size);
                break;
            }
#endif
            case detail::get_type_combination_id(
                library_data_t::real_half, library_data_t::real_half,
                library_data_t::real_half, library_data_t::real_float):

@ -1835,10 +1835,10 @@ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
|
|||
mul_mat_q4_0<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_qs_q4_0_acc_ct1.get_pointer(),
|
||||
tile_x_d_q4_0_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_qs_q4_0_acc_ct1),
|
||||
get_pointer(tile_x_d_q4_0_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
@ -1870,10 +1870,10 @@ static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
|
|||
mul_mat_q4_0<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_qs_q4_0_acc_ct1.get_pointer(),
|
||||
tile_x_d_q4_0_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_qs_q4_0_acc_ct1),
|
||||
get_pointer(tile_x_d_q4_0_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
@ -1950,10 +1950,10 @@ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
|
|||
mul_mat_q4_1<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_qs_q4_1_acc_ct1.get_pointer(),
|
||||
tile_x_dm_q4_1_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_qs_q4_1_acc_ct1),
|
||||
get_pointer(tile_x_dm_q4_1_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
@ -1985,10 +1985,10 @@ static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
|
|||
mul_mat_q4_1<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_qs_q4_1_acc_ct1.get_pointer(),
|
||||
tile_x_dm_q4_1_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_qs_q4_1_acc_ct1),
|
||||
get_pointer(tile_x_dm_q4_1_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
@ -2065,10 +2065,10 @@ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
|
|||
mul_mat_q5_0<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_ql_q5_0_acc_ct1.get_pointer(),
|
||||
tile_x_d_q5_0_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_ql_q5_0_acc_ct1),
|
||||
get_pointer(tile_x_d_q5_0_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
@ -2100,10 +2100,10 @@ static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
|
|||
mul_mat_q5_0<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_ql_q5_0_acc_ct1.get_pointer(),
|
||||
tile_x_d_q5_0_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_ql_q5_0_acc_ct1),
|
||||
get_pointer(tile_x_d_q5_0_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
@ -2180,10 +2180,10 @@ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
|
|||
mul_mat_q5_1<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_ql_q5_1_acc_ct1.get_pointer(),
|
||||
tile_x_dm_q5_1_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_ql_q5_1_acc_ct1),
|
||||
get_pointer(tile_x_dm_q5_1_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
@ -2215,10 +2215,10 @@ static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
|
|||
mul_mat_q5_1<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_ql_q5_1_acc_ct1.get_pointer(),
|
||||
tile_x_dm_q5_1_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_ql_q5_1_acc_ct1),
|
||||
get_pointer(tile_x_dm_q5_1_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
@ -2295,10 +2295,10 @@ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
|
|||
mul_mat_q8_0<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_qs_q8_0_acc_ct1.get_pointer(),
|
||||
tile_x_d_q8_0_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_qs_q8_0_acc_ct1),
|
||||
get_pointer(tile_x_d_q8_0_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
@ -2330,10 +2330,10 @@ static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
|
|||
mul_mat_q8_0<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_qs_q8_0_acc_ct1.get_pointer(),
|
||||
tile_x_d_q8_0_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_qs_q8_0_acc_ct1),
|
||||
get_pointer(tile_x_d_q8_0_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
@ -2412,11 +2412,11 @@ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
|
|||
mul_mat_q2_K<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_ql_q2_K_acc_ct1.get_pointer(),
|
||||
tile_x_dm_q2_K_acc_ct1.get_pointer(),
|
||||
tile_x_sc_q2_K_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_ql_q2_K_acc_ct1),
|
||||
get_pointer(tile_x_dm_q2_K_acc_ct1),
|
||||
get_pointer(tile_x_sc_q2_K_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
@ -2450,11 +2450,11 @@ static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
|
|||
mul_mat_q2_K<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_ql_q2_K_acc_ct1.get_pointer(),
|
||||
tile_x_dm_q2_K_acc_ct1.get_pointer(),
|
||||
tile_x_sc_q2_K_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_ql_q2_K_acc_ct1),
|
||||
get_pointer(tile_x_dm_q2_K_acc_ct1),
|
||||
get_pointer(tile_x_sc_q2_K_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
@ -2537,12 +2537,12 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
|
|||
mul_mat_q3_K<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_ql_q3_K_acc_ct1.get_pointer(),
|
||||
tile_x_dm_q3_K_acc_ct1.get_pointer(),
|
||||
tile_x_qh_q3_K_acc_ct1.get_pointer(),
|
||||
tile_x_sc_q3_K_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_ql_q3_K_acc_ct1),
|
||||
get_pointer(tile_x_dm_q3_K_acc_ct1),
|
||||
get_pointer(tile_x_qh_q3_K_acc_ct1),
|
||||
get_pointer(tile_x_sc_q3_K_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
@ -2578,12 +2578,12 @@ static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
|
|||
mul_mat_q3_K<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_ql_q3_K_acc_ct1.get_pointer(),
|
||||
tile_x_dm_q3_K_acc_ct1.get_pointer(),
|
||||
tile_x_qh_q3_K_acc_ct1.get_pointer(),
|
||||
tile_x_sc_q3_K_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_ql_q3_K_acc_ct1),
|
||||
get_pointer(tile_x_dm_q3_K_acc_ct1),
|
||||
get_pointer(tile_x_qh_q3_K_acc_ct1),
|
||||
get_pointer(tile_x_sc_q3_K_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
@ -2663,11 +2663,11 @@ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
|
|||
mul_mat_q4_K<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_ql_q4_K_acc_ct1.get_pointer(),
|
||||
tile_x_dm_q4_K_acc_ct1.get_pointer(),
|
||||
tile_x_sc_q4_K_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_ql_q4_K_acc_ct1),
|
||||
get_pointer(tile_x_dm_q4_K_acc_ct1),
|
||||
get_pointer(tile_x_sc_q4_K_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
@ -2701,11 +2701,11 @@ static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
|
|||
mul_mat_q4_K<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_ql_q4_K_acc_ct1.get_pointer(),
|
||||
tile_x_dm_q4_K_acc_ct1.get_pointer(),
|
||||
tile_x_sc_q4_K_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_ql_q4_K_acc_ct1),
|
||||
get_pointer(tile_x_dm_q4_K_acc_ct1),
|
||||
get_pointer(tile_x_sc_q4_K_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
@ -2784,11 +2784,11 @@ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
|
|||
mul_mat_q5_K<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_ql_q5_K_acc_ct1.get_pointer(),
|
||||
tile_x_dm_q5_K_acc_ct1.get_pointer(),
|
||||
tile_x_sc_q5_K_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_ql_q5_K_acc_ct1),
|
||||
get_pointer(tile_x_dm_q5_K_acc_ct1),
|
||||
get_pointer(tile_x_sc_q5_K_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
@ -2822,11 +2822,11 @@ static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
|
|||
mul_mat_q5_K<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_ql_q5_K_acc_ct1.get_pointer(),
|
||||
tile_x_dm_q5_K_acc_ct1.get_pointer(),
|
||||
tile_x_sc_q5_K_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_ql_q5_K_acc_ct1),
|
||||
get_pointer(tile_x_dm_q5_K_acc_ct1),
|
||||
get_pointer(tile_x_sc_q5_K_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
@ -2905,11 +2905,11 @@ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
|
|||
mul_mat_q6_K<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_ql_acc_ct1.get_pointer(),
|
||||
tile_x_dm_acc_ct1.get_pointer(),
|
||||
tile_x_sc_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_ql_acc_ct1),
|
||||
get_pointer(tile_x_dm_acc_ct1),
|
||||
get_pointer(tile_x_sc_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
@ -2943,11 +2943,11 @@ static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
|
|||
mul_mat_q6_K<need_check>(
|
||||
vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
|
||||
nrows_dst, item_ct1,
|
||||
tile_x_ql_acc_ct1.get_pointer(),
|
||||
tile_x_dm_acc_ct1.get_pointer(),
|
||||
tile_x_sc_acc_ct1.get_pointer(),
|
||||
tile_y_qs_acc_ct1.get_pointer(),
|
||||
tile_y_ds_acc_ct1.get_pointer());
|
||||
get_pointer(tile_x_ql_acc_ct1),
|
||||
get_pointer(tile_x_dm_acc_ct1),
|
||||
get_pointer(tile_x_sc_acc_ct1),
|
||||
get_pointer(tile_y_qs_acc_ct1),
|
||||
get_pointer(tile_y_ds_acc_ct1));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
|
|
@ -57,6 +57,7 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
|
|||
const int nwarps = nthreads / WARP_SIZE;
|
||||
assert(nwarps % WARP_SIZE == 0);
|
||||
start += item_ct1.get_local_id(2);
|
||||
int nreduce = nwarps / WARP_SIZE;
|
||||
|
||||
if (end >= ne_elements) {
|
||||
end = ne_elements;
|
||||
|
@ -87,7 +88,6 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
|
|||
*/
|
||||
item_ct1.barrier();
|
||||
tmp = 0.f;
|
||||
int nreduce = nwarps / WARP_SIZE;
|
||||
for (size_t i = 0; i < nreduce; i += 1)
|
||||
{
|
||||
tmp += s_sum[lane_id + i * WARP_SIZE];
|
||||
|
@ -122,7 +122,11 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
|
|||
better performance if there is no access to global memory.
|
||||
*/
|
||||
item_ct1.barrier();
|
||||
tmp = s_sum[lane_id];
|
||||
tmp = 0.f;
|
||||
for (size_t i = 0; i < nreduce; i += 1)
|
||||
{
|
||||
tmp += s_sum[lane_id + i * WARP_SIZE];
|
||||
}
|
||||
tmp = warp_reduce_sum(tmp, item_ct1);
|
||||
}
|
||||
|
||||
|
@ -181,7 +185,7 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
|
|||
|
||||
static void norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||
const int nrows, const float eps,
|
||||
queue_ptr stream) {
|
||||
queue_ptr stream, int device) {
|
||||
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||
if (ncols < 1024) {
|
||||
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
||||
|
@ -197,7 +201,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
|
|||
});
|
||||
}
|
||||
else {
|
||||
const int work_group_size = get_work_group_size(stream->get_device());
|
||||
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
|
||||
const sycl::range<3> block_dims(1, 1, work_group_size);
|
||||
/*
|
||||
DPCT1049:17: The work-group size passed to the SYCL kernel may exceed
|
||||
|
@ -214,7 +218,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
|
|||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
norm_f32(x, dst, ncols, eps, item_ct1,
|
||||
s_sum_acc_ct1.get_pointer(), work_group_size);
|
||||
get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
@ -222,7 +226,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
|
|||
|
||||
static void group_norm_f32_sycl(const float* x, float* dst,
|
||||
const int num_groups, const int group_size,
|
||||
const int ne_elements, queue_ptr stream) {
|
||||
const int ne_elements, queue_ptr stream, int device) {
|
||||
static const float eps = 1e-6f;
|
||||
if (group_size < 1024) {
|
||||
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
|
||||
|
@ -240,7 +244,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
|
|||
});
|
||||
}
|
||||
else {
|
||||
const int work_group_size = get_work_group_size(stream->get_device());
|
||||
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
|
||||
const sycl::range<3> block_dims(1, 1, work_group_size);
|
||||
/*
|
||||
DPCT1049:18: The work-group size passed to the SYCL kernel may exceed
|
||||
|
@ -261,7 +265,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
|
|||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
group_norm_f32(x, dst, group_size, ne_elements,
|
||||
eps_ct4, item_ct1,
|
||||
s_sum_acc_ct1.get_pointer(), work_group_size);
|
||||
get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
@ -269,7 +273,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
|
|||
|
||||
static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
||||
const int nrows, const float eps,
|
||||
queue_ptr stream) {
|
||||
queue_ptr stream, int device) {
|
||||
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||
// printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE);
|
||||
if (ncols < 1024) {
|
||||
|
@ -286,7 +290,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
|||
});
|
||||
}
|
||||
else {
|
||||
const int work_group_size = get_work_group_size(stream->get_device());
|
||||
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
|
||||
const sycl::range<3> block_dims(1, 1, work_group_size);
|
||||
/*
|
||||
DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
|
||||
|
@ -302,7 +306,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
|||
[=](sycl::nd_item<3> item_ct1)
|
||||
[[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
rms_norm_f32(x, dst, ncols, eps, item_ct1,
|
||||
s_sum_acc_ct1.get_pointer(), work_group_size);
|
||||
get_pointer(s_sum_acc_ct1), work_group_size);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
@ -322,7 +326,7 @@ void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
|
|||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
|
||||
norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream);
|
||||
norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
|
||||
|
||||
(void)src1;
|
||||
(void)dst;
|
||||
|
@ -340,7 +344,7 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor*
|
|||
|
||||
int num_groups = dst->op_params[0];
|
||||
int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
|
||||
group_norm_f32_sycl(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream);
|
||||
group_norm_f32_sycl(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device);
|
||||
|
||||
(void)src1;
|
||||
(void)dst;
|
||||
|
@ -362,7 +366,7 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* sr
|
|||
float eps;
|
||||
memcpy(&eps, dst->op_params, sizeof(float));
|
||||
|
||||
rms_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream);
|
||||
rms_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
|
||||
|
||||
(void)src1;
|
||||
(void)dst;
|
||||
|
|
|
@ -62,4 +62,5 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA

#define MUL_MAT_SRC1_COL_STRIDE 128

#define QK_WARP_SIZE 32
#endif // GGML_SYCL_PRESETS_HPP

@ -55,7 +55,7 @@ static void rope_norm(
    const int i = row*ne0 + i0;
    const int i2 = row/p_delta_rows;

    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
    const float theta_base = pos[i2] * sycl::pow(theta_scale, i0 / 2.0f);

    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

@ -98,7 +98,7 @@ static void rope_neox(
    const int i = row*ne0 + i0/2;
    const int i2 = row/p_delta_rows;

    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
    const float theta_base = pos[i2] * sycl::pow(theta_scale, i0 / 2.0f);

    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

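Both hunks swap host-style `powf` for `sycl::pow` in the rotary-embedding angle computation. Per the code, the angle for element pair `i0` of a token at position `pos[i2]` is `theta_base = pos * theta_scale^(i0/2)`, combined with an optional per-frequency factor. A small standalone C++ sketch of that computation, with placeholder values that are assumptions rather than anything taken from the diff:

    #include <cmath>
    #include <cstdio>

    int main() {
        const int   n_dims      = 128;                       // head dimension (example value)
        const float freq_base   = 10000.0f;                  // typical RoPE base (example value)
        const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
        const int   pos         = 42;                        // token position (example value)

        for (int i0 = 0; i0 < 8; i0 += 2) {
            const float theta_base  = pos * std::pow(theta_scale, i0 / 2.0f);
            const float freq_factor = 1.0f;                  // no frequency scaling in this sketch
            const float theta       = theta_base / freq_factor;
            printf("i0=%d  theta=%f  cos=%f  sin=%f\n", i0, theta, std::cos(theta), std::sin(theta));
        }
        return 0;
    }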
250
ggml/src/ggml-sycl/softmax.cpp
Normal file
|
@ -0,0 +1,250 @@
|
|||
#include "norm.hpp"
|
||||
|
||||
template <bool vals_smem, int ncols_template, int block_size_template>
|
||||
static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
|
||||
const int nrows_y, const float scale, const float max_bias, const float m0,
|
||||
const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
|
||||
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
|
||||
|
||||
const int tid = item_ct1.get_local_id(2);
|
||||
const int rowx = item_ct1.get_group(2);
|
||||
const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
|
||||
|
||||
const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template;
|
||||
|
||||
const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
|
||||
const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
|
||||
const int nthreads = block_size;
|
||||
const int nwarps = nthreads / WARP_SIZE;
|
||||
int nreduce = nwarps / WARP_SIZE;
|
||||
float slope = 1.0f;
|
||||
|
||||
// ALiBi
|
||||
if (max_bias > 0.0f) {
|
||||
const uint32_t h = rowx/nrows_y; // head index
|
||||
|
||||
const float base = h < n_head_log2 ? m0 : m1;
|
||||
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
||||
|
||||
slope = sycl::pow(base, float(exp));
|
||||
}
|
||||
|
||||
float *vals = vals_smem ? buf + std::max(nwarps, WARP_SIZE) : dst + rowx * ncols;
|
||||
float max_val = -INFINITY;
|
||||
|
||||
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
||||
const int col = col0 + tid;
|
||||
|
||||
if (ncols_template == 0 && col >= ncols) {
|
||||
break;
|
||||
}
|
||||
|
||||
const int ix = rowx*ncols + col;
|
||||
const int iy = rowy*ncols + col;
|
||||
|
||||
const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
|
||||
|
||||
vals[col] = val;
|
||||
max_val = sycl::max(max_val, val);
|
||||
}
|
||||
|
||||
// find the max value in the block
|
||||
max_val = warp_reduce_max(max_val, item_ct1);
|
||||
if (block_size > WARP_SIZE) {
|
||||
if (warp_id == 0) {
|
||||
buf[lane_id] = -INFINITY;
|
||||
for (size_t i = 1; i < nreduce; i += 1)
|
||||
buf[lane_id + i * WARP_SIZE] = -INFINITY;
|
||||
}
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
if (lane_id == 0) {
|
||||
buf[warp_id] = max_val;
|
||||
}
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
max_val = buf[lane_id];
|
||||
for (size_t i = 1; i < nreduce; i += 1)
|
||||
{
|
||||
max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]);
|
||||
}
|
||||
max_val = warp_reduce_max(max_val, item_ct1);
|
||||
}
|
||||
|
||||
float tmp = 0.f;
|
||||
#pragma unroll
|
||||
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
||||
const int col = col0 + tid;
|
||||
if (ncols_template == 0 && col >= ncols) {
|
||||
break;
|
||||
}
|
||||
|
||||
const float val = sycl::native::exp(vals[col] - max_val);
|
||||
tmp += val;
|
||||
vals[col] = val;
|
||||
}
|
||||
|
||||
// find the sum of exps in the block
|
||||
tmp = warp_reduce_sum(tmp, item_ct1);
|
||||
if (block_size > WARP_SIZE) {
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
if (warp_id == 0) {
|
||||
buf[lane_id] = 0.f;
|
||||
for (size_t i = 1; i < nreduce; i += 1)
|
||||
buf[lane_id + i * WARP_SIZE] = 0.f;
|
||||
}
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
if (lane_id == 0) {
|
||||
buf[warp_id] = tmp;
|
||||
}
|
||||
item_ct1.barrier(sycl::access::fence_space::local_space);
|
||||
|
||||
tmp = buf[lane_id];
|
||||
for (size_t i = 1; i < nreduce; i += 1)
|
||||
{
|
||||
tmp += buf[lane_id + i * WARP_SIZE];
|
||||
}
|
||||
tmp = warp_reduce_sum(tmp, item_ct1);
|
||||
}
|
||||
|
||||
const float inv_sum = 1.f / tmp;
|
||||
|
||||
#pragma unroll
|
||||
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
||||
const int col = col0 + tid;
|
||||
|
||||
if (ncols_template == 0 && col >= ncols) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int idst = rowx*ncols + col;
|
||||
dst[idst] = vals[col] * inv_sum;
|
||||
}
|
||||
}
|
||||
|
||||
template <bool vals_smem, int ncols_template, int block_size_template>
|
||||
static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
|
||||
const int nrows_y, const float scale, const float max_bias, const float m0,
|
||||
const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
|
||||
const size_t n_local_scratch, queue_ptr stream) {
|
||||
stream->submit([&](sycl::handler &cgh) {
|
||||
sycl::local_accessor<float, 1> local_buf_acc(n_local_scratch, cgh);
|
||||
|
||||
cgh.parallel_for(
|
||||
sycl::nd_range<3>(block_nums * block_dims, block_dims),
|
||||
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
|
||||
soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
|
||||
nrows_y, scale, max_bias, m0,
|
||||
m1, n_head_log2, item_ct1,
|
||||
get_pointer(local_buf_acc));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
static void soft_max_f32_sycl(const float * x, const float * mask,
|
||||
float * dst, const int ncols_x, const int nrows_x,
|
||||
const int nrows_y, const float scale, const float max_bias,
|
||||
queue_ptr stream, int device) {
|
||||
int nth = WARP_SIZE;
|
||||
int max_block_size = ggml_sycl_info().max_work_group_sizes[device];
|
||||
while (nth < ncols_x && nth < max_block_size) nth *= 2;
|
||||
if (nth>max_block_size) nth = max_block_size;
|
||||
|
||||
const sycl::range<3> block_dims(1, 1, nth);
|
||||
const sycl::range<3> block_nums(1, 1, nrows_x);
|
||||
const size_t n_local_scratch = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE);
|
||||
|
||||
const uint32_t n_head_kv = nrows_x/nrows_y;
|
||||
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
|
||||
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
|
||||
if (n_local_scratch*sizeof(float) < local_mem_size) {
|
||||
if (ncols_x > max_block_size) {
|
||||
soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
return;
|
||||
}
|
||||
switch (ncols_x) {
|
||||
case 32:
|
||||
soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
case 64:
|
||||
soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
case 128:
|
||||
soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
case 256:
|
||||
soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
case 512:
|
||||
soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
case 1024:
|
||||
soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
case 2048:
|
||||
soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
case 4096:
|
||||
soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
default:
|
||||
soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, n_local_scratch, stream);
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
|
||||
max_bias, m0, m1, n_head_log2, block_nums,
|
||||
block_dims, WARP_SIZE, stream);
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
|
||||
const ggml_tensor *src1, ggml_tensor *dst,
|
||||
const float *src0_dd, const float *src1_dd,
|
||||
float *dst_dd,
|
||||
const queue_ptr &main_stream) {
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
|
||||
#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
|
||||
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t nrows_x = ggml_nrows(src0);
|
||||
const int64_t nrows_y = src0->ne[1];
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
|
||||
memcpy(&scale, dst->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, dst->op_params + 1, sizeof(float));
|
||||
|
||||
soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
|
||||
nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device);
|
||||
}
|
24
ggml/src/ggml-sycl/softmax.hpp
Normal file
@ -0,0 +1,24 @@
//
// MIT license
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: MIT
//

//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//

#ifndef GGML_SYCL_SOFTMAX_HPP
#define GGML_SYCL_SOFTMAX_HPP

#include "common.hpp"

void ggml_sycl_op_soft_max(ggml_backend_sycl_context &ctx, const ggml_tensor *src0,
                           const ggml_tensor *src1, ggml_tensor *dst,
                           const float *src0_dd, const float *src1_dd,
                           float *dst_dd,
                           const queue_ptr &main_stream);

#endif // GGML_SYCL_SOFTMAX_HPP
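The new softmax.cpp computes a scaled, optionally masked softmax with an ALiBi slope per head, subtracting the row maximum before exponentiating for numerical stability. A compact CPU reference of the same per-row math, plain C++ and illustrative only; the SYCL kernel additionally tiles columns across a work-group and reduces the max and sum across sub-groups:

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // One row of softmax(x*scale + slope*mask), with max-subtraction for stability.
    static void soft_max_row(const float * x, const float * mask, float * dst,
                             int ncols, float scale, float slope) {
        std::vector<float> vals(ncols);
        float max_val = -INFINITY;
        for (int c = 0; c < ncols; ++c) {
            vals[c] = x[c]*scale + (mask ? slope*mask[c] : 0.0f);
            max_val = std::max(max_val, vals[c]);
        }
        float sum = 0.0f;
        for (int c = 0; c < ncols; ++c) {
            vals[c] = std::exp(vals[c] - max_val);
            sum += vals[c];
        }
        for (int c = 0; c < ncols; ++c) {
            dst[c] = vals[c] / sum;
        }
    }

    int main() {
        const float x[4] = {0.5f, 1.0f, -2.0f, 3.0f};
        float y[4];
        soft_max_row(x, nullptr, y, 4, /*scale=*/1.0f, /*slope=*/1.0f);
        for (float v : y) printf("%f ", v);   // probabilities summing to 1
        printf("\n");
        return 0;
    }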
File diff suppressed because it is too large
|
@ -6561,7 +6561,7 @@ static void ggml_vk_print_tensor(ggml_backend_vk_context * ctx, const ggml_tenso
|
|||
ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
|
||||
|
||||
vk_buffer buffer_gpu = extra->buffer_gpu.lock();
|
||||
ggml_vk_buffer_read(ctx, buffer_gpu, extra->offset + tensor->view_offs, tensor_data, tensor_size);
|
||||
ggml_vk_buffer_read(buffer_gpu, extra->offset + tensor->view_offs, tensor_data, tensor_size);
|
||||
}
|
||||
|
||||
std::cerr << "TENSOR CHECK " << name << " (" << tensor->name << "): " << ggml_op_name(tensor->op) << std::endl;
|
||||
|
@ -6645,7 +6645,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor *
|
|||
for (int i3 = 0; i3 < src0->ne[3]; i3++) {
|
||||
for (int i2 = 0; i2 < src0->ne[2]; i2++) {
|
||||
const int idx = i3*src0->ne[2] + i2;
|
||||
ggml_vk_buffer_read(ctx, buffer_gpu, offset + idx * src0->nb[2], ((char *)src0_clone->data + idx * src0_clone->nb[2]), src0->ne[1] * src0->nb[1]);
|
||||
ggml_vk_buffer_read(buffer_gpu, offset + idx * src0->nb[2], ((char *)src0_clone->data + idx * src0_clone->nb[2]), src0->ne[1] * src0->nb[1]);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -6658,7 +6658,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor *
|
|||
if (offset + src0_size >= buffer_gpu->size) {
|
||||
src0_size = buffer_gpu->size - offset;
|
||||
}
|
||||
ggml_vk_buffer_read(ctx, buffer_gpu, offset, src0_clone->data, src0_size);
|
||||
ggml_vk_buffer_read(buffer_gpu, offset, src0_clone->data, src0_size);
|
||||
memcpy(src0_clone->nb, src0->nb, sizeof(size_t) * GGML_MAX_DIMS);
|
||||
}
|
||||
} else {
|
||||
|
@ -6687,7 +6687,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor *
|
|||
for (int i3 = 0; i3 < src1->ne[3]; i3++) {
|
||||
for (int i2 = 0; i2 < src1->ne[2]; i2++) {
|
||||
const int idx = i3*src1->ne[2] + i2;
|
||||
ggml_vk_buffer_read(ctx, buffer_gpu, offset + idx * src1->nb[2], ((char *)src1_clone->data + idx * src1_clone->nb[2]), src1->ne[1] * src1->nb[1]);
|
||||
ggml_vk_buffer_read(buffer_gpu, offset + idx * src1->nb[2], ((char *)src1_clone->data + idx * src1_clone->nb[2]), src1->ne[1] * src1->nb[1]);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -6700,7 +6700,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor *
|
|||
if (offset + src1_size >= buffer_gpu->size) {
|
||||
src1_size = buffer_gpu->size - offset;
|
||||
}
|
||||
ggml_vk_buffer_read(ctx, buffer_gpu, offset, src1_clone->data, src1_size);
|
||||
ggml_vk_buffer_read(buffer_gpu, offset, src1_clone->data, src1_size);
|
||||
memcpy(src1_clone->nb, src1->nb, sizeof(size_t) * GGML_MAX_DIMS);
|
||||
}
|
||||
} else {
|
||||
|
@ -6745,7 +6745,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor *
|
|||
for (int i3 = 0; i3 < src2->ne[3]; i3++) {
|
||||
for (int i2 = 0; i2 < src2->ne[2]; i2++) {
|
||||
const int idx = i3*src2->ne[2] + i2;
|
||||
ggml_vk_buffer_read(ctx, buffer_gpu, offset + idx * src2->nb[2], ((char *)src2_clone->data + idx * src2_clone->nb[2]), src2->ne[1] * src2->nb[1]);
|
||||
ggml_vk_buffer_read(buffer_gpu, offset + idx * src2->nb[2], ((char *)src2_clone->data + idx * src2_clone->nb[2]), src2->ne[1] * src2->nb[1]);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -6758,7 +6758,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_tensor *
|
|||
if (offset + src2_size >= buffer_gpu->size) {
|
||||
src2_size = buffer_gpu->size - offset;
|
||||
}
|
||||
ggml_vk_buffer_read(ctx, buffer_gpu, offset, src2_clone->data, src2_size);
|
||||
ggml_vk_buffer_read(buffer_gpu, offset, src2_clone->data, src2_size);
|
||||
memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS);
|
||||
}
|
||||
} else {
|
||||
|
@ -6922,7 +6922,7 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_tensor *
|
|||
tensor_size = buffer_gpu->size - (extra->offset + tensor->view_offs);
|
||||
}
|
||||
|
||||
ggml_vk_buffer_read(ctx, buffer_gpu, extra->offset + tensor->view_offs, tensor_data, tensor_size);
|
||||
ggml_vk_buffer_read(buffer_gpu, extra->offset + tensor->view_offs, tensor_data, tensor_size);
|
||||
}
|
||||
|
||||
float first_error_result = -1.0f;
|
||||
|
|
232
ggml/src/ggml.c
|
@ -4,7 +4,7 @@
|
|||
#include "ggml-impl.h"
|
||||
#include "ggml-quants.h"
|
||||
#include "ggml.h"
|
||||
|
||||
#include "ggml-aarch64.h"
|
||||
|
||||
#if defined(_MSC_VER) || defined(__MINGW32__)
|
||||
#include <malloc.h> // using malloc.h with MSC/MINGW
|
||||
|
@ -37,12 +37,12 @@
|
|||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
#ifdef __ARM_FEATURE_MATMUL_INT8
|
||||
#if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
#undef GGML_USE_LLAMAFILE
|
||||
#endif
|
||||
|
||||
#ifdef GGML_USE_LLAMAFILE
|
||||
#include "sgemm.h"
|
||||
#include <llamafile/sgemm.h>
|
||||
#endif
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
|
@ -592,7 +592,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.is_quantized = false,
|
||||
.to_float = (ggml_to_float_t) ggml_fp16_to_fp32_row,
|
||||
.from_float = (ggml_from_float_t) ggml_fp32_to_fp16_row,
|
||||
.from_float_reference = (ggml_from_float_t) ggml_fp32_to_fp16_row,
|
||||
.from_float_ref = (ggml_from_float_t) ggml_fp32_to_fp16_row,
|
||||
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16,
|
||||
.vec_dot_type = GGML_TYPE_F16,
|
||||
.nrows = 1,
|
||||
|
@ -604,7 +604,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_q4_0,
|
||||
.from_float = quantize_row_q4_0,
|
||||
.from_float_reference = (ggml_from_float_t) quantize_row_q4_0_reference,
|
||||
.from_float_ref = (ggml_from_float_t) quantize_row_q4_0_ref,
|
||||
.vec_dot = ggml_vec_dot_q4_0_q8_0,
|
||||
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||
#if defined (__ARM_FEATURE_MATMUL_INT8)
|
||||
|
@ -620,7 +620,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_q4_1,
|
||||
.from_float = quantize_row_q4_1,
|
||||
.from_float_reference = (ggml_from_float_t) quantize_row_q4_1_reference,
|
||||
.from_float_ref = (ggml_from_float_t) quantize_row_q4_1_ref,
|
||||
.vec_dot = ggml_vec_dot_q4_1_q8_1,
|
||||
.vec_dot_type = GGML_TYPE_Q8_1,
|
||||
#if defined (__ARM_FEATURE_MATMUL_INT8)
|
||||
|
@ -636,7 +636,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.is_quantized = false,
|
||||
.to_float = NULL,
|
||||
.from_float = NULL,
|
||||
.from_float_reference = NULL,
|
||||
.from_float_ref = NULL,
|
||||
.vec_dot = NULL,
|
||||
.vec_dot_type = GGML_TYPE_COUNT,
|
||||
.nrows = 1,
|
||||
|
@ -648,7 +648,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.is_quantized = false,
|
||||
.to_float = NULL,
|
||||
.from_float = NULL,
|
||||
.from_float_reference = NULL,
|
||||
.from_float_ref = NULL,
|
||||
.vec_dot = NULL,
|
||||
.vec_dot_type = GGML_TYPE_COUNT,
|
||||
.nrows = 1,
|
||||
|
@ -660,7 +660,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_q5_0,
|
||||
.from_float = quantize_row_q5_0,
|
||||
.from_float_reference = (ggml_from_float_t) quantize_row_q5_0_reference,
|
||||
.from_float_ref = (ggml_from_float_t) quantize_row_q5_0_ref,
|
||||
.vec_dot = ggml_vec_dot_q5_0_q8_0,
|
||||
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||
.nrows = 1,
|
||||
|
@ -672,7 +672,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_q5_1,
|
||||
.from_float = quantize_row_q5_1,
|
||||
.from_float_reference = (ggml_from_float_t) quantize_row_q5_1_reference,
|
||||
.from_float_ref = (ggml_from_float_t) quantize_row_q5_1_ref,
|
||||
.vec_dot = ggml_vec_dot_q5_1_q8_1,
|
||||
.vec_dot_type = GGML_TYPE_Q8_1,
|
||||
.nrows = 1,
|
||||
|
@ -684,7 +684,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_q8_0,
|
||||
.from_float = quantize_row_q8_0,
|
||||
.from_float_reference = (ggml_from_float_t) quantize_row_q8_0_reference,
|
||||
.from_float_ref = (ggml_from_float_t) quantize_row_q8_0_ref,
|
||||
.from_float_to_mat = quantize_mat_q8_0,
|
||||
.vec_dot = ggml_vec_dot_q8_0_q8_0,
|
||||
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||
#if defined (__ARM_FEATURE_MATMUL_INT8)
|
||||
|
@ -699,7 +700,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.type_size = sizeof(block_q8_1),
|
||||
.is_quantized = true,
|
||||
.from_float = quantize_row_q8_1,
|
||||
.from_float_reference = (ggml_from_float_t) quantize_row_q8_1_reference,
|
||||
.from_float_ref = (ggml_from_float_t) quantize_row_q8_1_ref,
|
||||
.vec_dot_type = GGML_TYPE_Q8_1,
|
||||
.nrows = 1,
|
||||
},
|
||||
|
@ -710,7 +711,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_q2_K,
|
||||
.from_float = quantize_row_q2_K,
|
||||
.from_float_reference = (ggml_from_float_t) quantize_row_q2_K_reference,
|
||||
.from_float_ref = (ggml_from_float_t) quantize_row_q2_K_ref,
|
||||
.vec_dot = ggml_vec_dot_q2_K_q8_K,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
.nrows = 1,
|
||||
|
@ -722,7 +723,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_q3_K,
|
||||
.from_float = quantize_row_q3_K,
|
||||
.from_float_reference = (ggml_from_float_t) quantize_row_q3_K_reference,
|
||||
.from_float_ref = (ggml_from_float_t) quantize_row_q3_K_ref,
|
||||
.vec_dot = ggml_vec_dot_q3_K_q8_K,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
.nrows = 1,
|
||||
|
@ -734,7 +735,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_q4_K,
|
||||
.from_float = quantize_row_q4_K,
|
||||
.from_float_reference = (ggml_from_float_t) quantize_row_q4_K_reference,
|
||||
.from_float_ref = (ggml_from_float_t) quantize_row_q4_K_ref,
|
||||
.vec_dot = ggml_vec_dot_q4_K_q8_K,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
.nrows = 1,
|
||||
|
@ -746,7 +747,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_q5_K,
|
||||
.from_float = quantize_row_q5_K,
|
||||
.from_float_reference = (ggml_from_float_t) quantize_row_q5_K_reference,
|
||||
.from_float_ref = (ggml_from_float_t) quantize_row_q5_K_ref,
|
||||
.vec_dot = ggml_vec_dot_q5_K_q8_K,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
.nrows = 1,
|
||||
|
@ -758,7 +759,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_q6_K,
|
||||
.from_float = quantize_row_q6_K,
|
||||
.from_float_reference = (ggml_from_float_t) quantize_row_q6_K_reference,
|
||||
.from_float_ref = (ggml_from_float_t) quantize_row_q6_K_ref,
|
||||
.vec_dot = ggml_vec_dot_q6_K_q8_K,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
.nrows = 1,
|
||||
|
@ -770,7 +771,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_iq2_xxs,
|
||||
.from_float = NULL,
|
||||
.from_float_reference = NULL,
|
||||
.from_float_ref = NULL,
|
||||
.vec_dot = ggml_vec_dot_iq2_xxs_q8_K,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
.nrows = 1,
|
||||
|
@ -782,7 +783,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_iq2_xs,
|
||||
.from_float = NULL,
|
||||
.from_float_reference = NULL,
|
||||
.from_float_ref = NULL,
|
||||
.vec_dot = ggml_vec_dot_iq2_xs_q8_K,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
.nrows = 1,
|
||||
|
@ -794,7 +795,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_iq3_xxs,
|
||||
.from_float = quantize_row_iq3_xxs,
|
||||
.from_float_reference = (ggml_from_float_t)quantize_row_iq3_xxs_reference,
|
||||
.from_float_ref = (ggml_from_float_t)quantize_row_iq3_xxs_ref,
|
||||
.vec_dot = ggml_vec_dot_iq3_xxs_q8_K,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
.nrows = 1,
|
||||
|
@ -806,7 +807,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_iq3_s,
|
||||
.from_float = quantize_row_iq3_s,
|
||||
.from_float_reference = (ggml_from_float_t)quantize_row_iq3_s_reference,
|
||||
.from_float_ref = (ggml_from_float_t)quantize_row_iq3_s_ref,
|
||||
.vec_dot = ggml_vec_dot_iq3_s_q8_K,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
.nrows = 1,
|
||||
|
@ -818,7 +819,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_iq2_s,
|
||||
.from_float = quantize_row_iq2_s,
|
||||
.from_float_reference = (ggml_from_float_t)quantize_row_iq2_s_reference,
|
||||
.from_float_ref = (ggml_from_float_t)quantize_row_iq2_s_ref,
|
||||
.vec_dot = ggml_vec_dot_iq2_s_q8_K,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
.nrows = 1,
|
||||
|
@ -830,7 +831,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_iq1_s,
|
||||
.from_float = NULL,
|
||||
.from_float_reference = NULL,
|
||||
.from_float_ref = NULL,
|
||||
.vec_dot = ggml_vec_dot_iq1_s_q8_K,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
.nrows = 1,
|
||||
|
@ -842,7 +843,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_iq1_m,
|
||||
.from_float = NULL,
|
||||
.from_float_reference = NULL,
|
||||
.from_float_ref = NULL,
|
||||
.vec_dot = ggml_vec_dot_iq1_m_q8_K,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
.nrows = 1,
|
||||
|
@ -854,7 +855,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_iq4_nl,
|
||||
.from_float = quantize_row_iq4_nl,
|
||||
.from_float_reference = (ggml_from_float_t)quantize_row_iq4_nl_reference,
|
||||
.from_float_ref = (ggml_from_float_t)quantize_row_iq4_nl_ref,
|
||||
.vec_dot = ggml_vec_dot_iq4_nl_q8_0,
|
||||
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||
.nrows = 1,
|
||||
|
@ -866,7 +867,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_iq4_xs,
|
||||
.from_float = quantize_row_iq4_xs,
|
||||
.from_float_reference = (ggml_from_float_t)quantize_row_iq4_xs_reference,
|
||||
.from_float_ref = (ggml_from_float_t)quantize_row_iq4_xs_ref,
|
||||
.vec_dot = ggml_vec_dot_iq4_xs_q8_K,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
.nrows = 1,
|
||||
|
@ -885,10 +886,58 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.is_quantized = false,
|
||||
.to_float = (ggml_to_float_t) ggml_bf16_to_fp32_row,
|
||||
.from_float = (ggml_from_float_t) ggml_fp32_to_bf16_row,
|
||||
.from_float_reference = (ggml_from_float_t) ggml_fp32_to_bf16_row,
|
||||
.from_float_ref = (ggml_from_float_t) ggml_fp32_to_bf16_row,
|
||||
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16,
|
||||
.vec_dot_type = GGML_TYPE_BF16,
|
||||
.nrows = 1,
|
||||
},
|
||||
[GGML_TYPE_Q4_0_4_4] = {
|
||||
.type_name = "q4_0_4x4",
|
||||
.blck_size = QK4_0,
|
||||
.blck_size_interleave = 4,
|
||||
.type_size = sizeof(block_q4_0),
|
||||
.is_quantized = true,
|
||||
.to_float = NULL,
|
||||
.from_float = NULL,
|
||||
.from_float_ref = NULL,
|
||||
.vec_dot = NULL,
|
||||
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||
.nrows = 1,
|
||||
.ncols = 4,
|
||||
.gemv = ggml_gemv_q4_0_4x4_q8_0,
|
||||
.gemm = ggml_gemm_q4_0_4x4_q8_0,
|
||||
},
|
||||
[GGML_TYPE_Q4_0_4_8] = {
|
||||
.type_name = "q4_0_4x8",
|
||||
.blck_size = QK4_0,
|
||||
.blck_size_interleave = 8,
|
||||
.type_size = sizeof(block_q4_0),
|
||||
.is_quantized = true,
|
||||
.to_float = NULL,
|
||||
.from_float = NULL,
|
||||
.from_float_ref = NULL,
|
||||
.vec_dot = NULL,
|
||||
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||
.nrows = 1,
|
||||
.ncols = 4,
|
||||
.gemv = ggml_gemv_q4_0_4x8_q8_0,
|
||||
.gemm = ggml_gemm_q4_0_4x8_q8_0,
|
||||
},
|
||||
[GGML_TYPE_Q4_0_8_8] = {
|
||||
.type_name = "q4_0_8x8",
|
||||
.blck_size = QK4_0,
|
||||
.blck_size_interleave = 8,
|
||||
.type_size = sizeof(block_q4_0),
|
||||
.is_quantized = true,
|
||||
.to_float = NULL,
|
||||
.from_float = NULL,
|
||||
.from_float_ref = NULL,
|
||||
.vec_dot = NULL,
|
||||
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||
.nrows = 1,
|
||||
.ncols = 8,
|
||||
.gemv = ggml_gemv_q4_0_8x8_q8_0,
|
||||
.gemm = ggml_gemm_q4_0_8x8_q8_0,
|
||||
}
|
||||
};
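The three new entries register interleaved Q4_0 variants (`q4_0_4x4`, `q4_0_4x8`, `q4_0_8x8`) whose `blck_size_interleave`, `ncols`, `gemv` and `gemm` fields let mul_mat produce several output rows per pass. As a generic illustration of what block interleaving means, under the assumption of a simple round-robin layout (not necessarily ggml's exact on-disk format): the per-row blocks of N consecutive rows are stored alternately so one kernel pass streams N output rows from contiguous memory.

    #include <cstdio>
    #include <vector>

    // Hypothetical illustration: interleave the per-row blocks of `rows` consecutive
    // rows so that block b of rows 0..rows-1 sits contiguously in memory.
    int main() {
        const int rows = 4, blocks_per_row = 3;           // tiny example sizes
        std::vector<int> flat, interleaved;

        for (int r = 0; r < rows; ++r)
            for (int b = 0; b < blocks_per_row; ++b)
                flat.push_back(r * 100 + b);              // "block" labelled by (row, index)

        for (int b = 0; b < blocks_per_row; ++b)          // round-robin over rows
            for (int r = 0; r < rows; ++r)
                interleaved.push_back(flat[r * blocks_per_row + b]);

        for (int v : interleaved) printf("%03d ", v);     // 000 100 200 300 001 101 ...
        printf("\n");
        return 0;
    }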
|
||||
|
||||
|
@ -3066,7 +3115,7 @@ size_t ggml_nbytes_pad(const struct ggml_tensor * tensor) {
|
|||
return GGML_PAD(ggml_nbytes(tensor), GGML_MEM_ALIGN);
|
||||
}
|
||||
|
||||
GGML_CALL int ggml_blck_size(enum ggml_type type) {
|
||||
GGML_CALL int64_t ggml_blck_size(enum ggml_type type) {
|
||||
return type_traits[type].blck_size;
|
||||
}
|
||||
|
||||
|
@ -3188,6 +3237,9 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
|
|||
case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break;
|
||||
case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break;
|
||||
case GGML_FTYPE_MOSTLY_IQ2_S: wtype = GGML_TYPE_IQ2_S; break;
|
||||
case GGML_FTYPE_MOSTLY_Q4_0_4_4: wtype = GGML_TYPE_Q4_0_4_4; break;
|
||||
case GGML_FTYPE_MOSTLY_Q4_0_4_8: wtype = GGML_TYPE_Q4_0_4_8; break;
|
||||
case GGML_FTYPE_MOSTLY_Q4_0_8_8: wtype = GGML_TYPE_Q4_0_8_8; break;
|
||||
case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
|
||||
case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
|
||||
}
|
||||
|
@ -9425,6 +9477,9 @@ static void ggml_compute_forward_add(
|
|||
case GGML_TYPE_IQ4_XS:
|
||||
case GGML_TYPE_IQ3_S:
|
||||
case GGML_TYPE_IQ2_S:
|
||||
case GGML_TYPE_Q4_0_4_4:
|
||||
case GGML_TYPE_Q4_0_4_8:
|
||||
case GGML_TYPE_Q4_0_8_8:
|
||||
{
|
||||
ggml_compute_forward_add_q_f32(params, dst);
|
||||
} break;
|
||||
|
@@ -9800,6 +9855,9 @@ static void ggml_compute_forward_add1(
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
case GGML_TYPE_Q4_0_8_8:
{
ggml_compute_forward_add1_q_f32(params, dst);
} break;
@@ -9925,6 +9983,9 @@ static void ggml_compute_forward_acc(
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
case GGML_TYPE_Q4_0_8_8:
default:
{
GGML_ASSERT(false);
@@ -12119,9 +12180,14 @@ static void ggml_compute_forward_mul_mat(
const enum ggml_type type = src0->type;

enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
int64_t const vec_dot_num_rows = type_traits[type].nrows;
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
ggml_from_float_t const from_float = type_traits[vec_dot_type].from_float;
ggml_from_float_to_mat_t const from_float_to_mat = type_traits[vec_dot_type].from_float_to_mat;
int64_t const vec_dot_num_rows = type_traits[type].nrows;
int64_t const matmul_num_cols = type_traits[type].ncols;
int64_t const blck_size_interleave = type_traits[type].blck_size_interleave;
ggml_gemv_t const gemv = type_traits[type].gemv;
ggml_gemm_t const gemm = type_traits[type].gemm;

GGML_ASSERT(ne0 == ne01);
GGML_ASSERT(ne1 == ne11);
@@ -12180,10 +12246,19 @@ UseGgmlGemm1:;

for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
int64_t i11_processed = 0;
if ((ggml_n_dims(src1) == 2) && from_float_to_mat && gemm) {
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
ne10);
4, ne10, blck_size_interleave);
}
i11_processed = ne11 - ne11 % 4;
}
for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
ne10);
}
}
}
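Note: in the hunk above, src1 quantization is split into a batched pass (groups of 4 rows through from_float_to_mat) and a per-row tail through from_float. The standalone sketch below reproduces only the index arithmetic with made-up sizes, to show which rows each thread would touch in each pass.

    /* minimal sketch of the row partitioning above; example sizes, not the real ggml code */
    #include <stdio.h>

    int main(void) {
        const int ne11 = 10, nth = 2;             /* example: 10 src1 rows, 2 threads */
        for (int ith = 0; ith < nth; ith++) {
            /* batched pass: whole groups of 4 rows */
            for (int i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
                printf("thread %d: rows %d..%d via from_float_to_mat\n", ith, i11, i11 + 3);
            }
            /* tail pass: the remaining ne11 %% 4 rows, one at a time */
            int i11_processed = ne11 - ne11 % 4;
            for (int i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
                printf("thread %d: row %d via from_float\n", ith, i11);
            }
        }
        return 0;
    }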
@@ -12261,6 +12336,28 @@ UseGgmlGemm2:;
const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;

if ((ggml_n_dims(src0) == 2) && gemv) {
const void * src1_wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const size_t src1_col_stride = ggml_is_contiguous(src1) || src1->type != vec_dot_type ? ggml_row_size(vec_dot_type, ne10) : nb11;
int64_t src0_start = (ith * ne01) / nth;
int64_t src0_end = ((ith + 1) * ne01) / nth;
src0_start = (src0_start % matmul_num_cols) ? src0_start + matmul_num_cols - (src0_start % matmul_num_cols): src0_start;
src0_end = (src0_end % matmul_num_cols) ? src0_end + matmul_num_cols - (src0_end % matmul_num_cols): src0_end;
if (src0_start >= src0_end) return;

// If there are more than three rows in src1, use gemm; otherwise, use gemv.
if (gemm && (ne11 > 3)) {
gemm(ne00, (float *)((char *) dst->data) + src0_start, ne01, (const char *) src0->data + src0_start * nb01,
(const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
}
for (int iter = gemm ? ne11 - ne11 % 4 : 0; iter < ne11; iter++) {
gemv(ne00, (float *)((char *) dst->data + (iter * nb1)) + src0_start, ne01,
(const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1,
src0_end - src0_start);
}
return;
}

// The first chunk comes from our thread_id, the rest will get auto-assigned.
int current_chunk = ith;
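Note: the gemv/gemm fast path above splits src0 rows across threads and then rounds each thread's [src0_start, src0_end) range up to a multiple of matmul_num_cols, so every thread works on whole interleaved row groups. The small standalone sketch below shows that alignment with example sizes only; it assumes ne01 is a multiple of the column-group size, as it is for the interleaved Q4_0 types.

    /* sketch of the thread-range alignment used by the gemv/gemm path above */
    #include <stdio.h>

    static long round_up(long x, long m) {
        return (x % m) ? x + m - (x % m) : x;
    }

    int main(void) {
        const long ne01 = 24, nth = 4, matmul_num_cols = 4;  /* example sizes */
        for (long ith = 0; ith < nth; ith++) {
            long src0_start = round_up((ith * ne01) / nth, matmul_num_cols);
            long src0_end   = round_up(((ith + 1) * ne01) / nth, matmul_num_cols);
            if (src0_start >= src0_end) {
                printf("thread %ld: no work\n", ith);
            } else {
                printf("thread %ld: rows [%ld, %ld)\n", ith, src0_start, src0_end);
            }
        }
        return 0;
    }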
@@ -12303,9 +12400,11 @@ static void ggml_compute_forward_mul_mat_id(

const bool src1_cont = ggml_is_contiguous(src1);

ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
ggml_from_float_t const from_float = type_traits[vec_dot_type].from_float;
int64_t const matmul_num_cols = type_traits[type].ncols;
ggml_gemv_t const gemv = type_traits[type].gemv;

// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == ggml_type_size(type));
@@ -12346,9 +12445,9 @@ static void ggml_compute_forward_mul_mat_id(
for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
ne10);
from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
(void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
ne10);
}
}
}
@@ -12391,6 +12490,34 @@ static void ggml_compute_forward_mul_mat_id(
const int64_t nr0 = ne01; // src0 rows
const int64_t nr1 = cne1; // src1 rows

if (((ggml_n_dims(src0) - 1) == 2) && gemv) {
int64_t src0_cur_start = (ith * ne01) / nth;
int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
src0_cur_start = (src0_cur_start % matmul_num_cols) ? src0_cur_start + matmul_num_cols - (src0_cur_start % matmul_num_cols): src0_cur_start;
src0_cur_end = (src0_cur_end % matmul_num_cols) ? src0_cur_end + matmul_num_cols - (src0_cur_end % matmul_num_cols): src0_cur_end;
if (src0_cur_start >= src0_cur_end) return;

for (int ir1 = 0; ir1 < nr1; ir1++) {
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
const int id = row_mapping.i1; // selected expert index

const int64_t i11 = id % ne11;
const int64_t i12 = row_mapping.i2; // row index in src1

const int64_t i1 = id; // selected expert index
const int64_t i2 = i12; // row

const char * src1_col = (const char *) wdata +
(src1_cont || src1->type != vec_dot_type
? (i11 + i12 * ne11) * row_size
: (i11 * nb11 + i12 * nb12));

gemv(ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
(const char *) src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start);
}
continue;
}

// distribute the thread work across the inner or outer loop based on which one is larger

const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
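Note: the mul_mat_id variant above routes each mapped src1 row through the same gemv kernel, one expert at a time. The toy sketch below only replays the index bookkeeping visible in the hunk (with i1 and i2 interpreted as the diff's own comments describe them); it is not the real ggml code and does no matrix multiply.

    /* toy illustration of the per-expert row bookkeeping shown above */
    #include <stdio.h>

    struct mmid_row_mapping { int i1, i2; }; /* same field names as in the diff */

    int main(void) {
        const long ne11 = 4;                                    /* example: src1 rows per matrix   */
        struct mmid_row_mapping mapping[] = { {0, 0}, {2, 1} }; /* example mappings for one expert */
        for (int ir1 = 0; ir1 < 2; ir1++) {
            const int  id  = mapping[ir1].i1;  /* "selected expert index" in the diff's comment */
            const long i11 = id % ne11;        /* column into the converted src1 data           */
            const long i12 = mapping[ir1].i2;  /* row index in src1                             */
            printf("mapped row %d: dst (i1=%d, i2=%ld), src1 (i11=%ld, i12=%ld)\n",
                   ir1, id, i12, i11, i12);
        }
        return 0;
    }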
@@ -12692,6 +12819,9 @@ static void ggml_compute_forward_out_prod(
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
case GGML_TYPE_Q4_0_8_8:
{
ggml_compute_forward_out_prod_q_f32(params, dst);
} break;
@@ -12877,6 +13007,9 @@ static void ggml_compute_forward_set(
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
case GGML_TYPE_Q4_0_8_8:
default:
{
GGML_ASSERT(false);
@@ -13136,6 +13269,9 @@ static void ggml_compute_forward_get_rows(
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
case GGML_TYPE_Q4_0_8_8:
{
ggml_compute_forward_get_rows_q(params, dst);
} break;
@@ -13722,6 +13858,9 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_S:
case GGML_TYPE_Q8_K:
case GGML_TYPE_Q4_0_4_4:
case GGML_TYPE_Q4_0_4_8:
case GGML_TYPE_Q4_0_8_8:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:
@@ -19246,7 +19385,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph

fprintf(fp, "digraph G {\n");
fprintf(fp, "  newrank = true;\n");
fprintf(fp, "  rankdir = LR;\n");
fprintf(fp, "  rankdir = TB;\n");

for (int i = 0; i < gb->n_nodes; i++) {
struct ggml_tensor * node = gb->nodes[i];
@@ -19308,7 +19447,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
}

fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]);
if (ggml_nelements(node) < 5) {
if (ggml_nelements(node) < 5 && node->data != NULL) {
fprintf(fp, " | (");
for (int j = 0; j < ggml_nelements(node); j++) {
if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) {
@@ -20364,6 +20503,9 @@ size_t ggml_quantize_chunk(
case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q4_0_4_4: result = quantize_q4_0_4x4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q4_0_4_8: result = quantize_q4_0_4x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q4_0_8_8: result = quantize_q4_0_8x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_F16:
{
size_t elemsize = sizeof(ggml_fp16_t);
@@ -20827,8 +20969,8 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
(int64_t) info->ne[3];

if (ne % ggml_blck_size(info->type) != 0) {
fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%d)\n",
__func__, info->name.data, (int)info->type, ggml_type_name(info->type), ne, ggml_blck_size(info->type));
fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%" PRId64 ")\n",
__func__, info->name.data, (int) info->type, ggml_type_name(info->type), ne, ggml_blck_size(info->type));
fclose(file);
gguf_free(ctx);
return NULL;
@@ -21666,8 +21808,6 @@ int ggml_cpu_has_neon(void) {

int ggml_cpu_has_sve(void) {
#if defined(__ARM_FEATURE_SVE)
// TODO: Currently, SVE 256 bit is only supported.
GGML_ASSERT(svcntb() == QK8_0);
return 1;
#else
return 0;

5
ggml/src/vulkan-shaders/CMakeLists.txt
Normal file

@@ -0,0 +1,5 @@
set(TARGET vulkan-shaders-gen)
add_executable(${TARGET} vulkan-shaders-gen.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_compile_features(${TARGET} PRIVATE cxx_std_11)

524
ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp
Normal file

@@ -0,0 +1,524 @@
#include <iostream>
#include <fstream>
#include <sstream>
#include <string>
#include <stdexcept>
#include <array>
#include <vector>
#include <map>
#include <thread>
#include <mutex>
#include <future>
#include <queue>
#include <condition_variable>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <sys/stat.h>
#include <sys/types.h>

#ifdef _WIN32
#include <windows.h>
#include <direct.h> // For _mkdir on Windows
#else
#include <unistd.h>
#include <sys/wait.h>
#include <fcntl.h>
#endif

#define ASYNCIO_CONCURRENCY 64

std::mutex lock;
std::vector<std::pair<std::string, std::string>> shader_fnames;

std::string GLSLC = "glslc";
std::string input_dir = "vulkan-shaders";
std::string output_dir = "/tmp";
std::string target_hpp = "ggml-vulkan-shaders.hpp";
std::string target_cpp = "ggml-vulkan-shaders.cpp";
bool no_clean = false;

const std::vector<std::string> type_names = {
    "f32",
    "f16",
    "q4_0",
    "q4_1",
    "q5_0",
    "q5_1",
    "q8_0",
    "q2_k",
    "q3_k",
    "q4_k",
    "q5_k",
    "q6_k"
};

void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) {
#ifdef _WIN32
    HANDLE stdout_read, stdout_write;
    HANDLE stderr_read, stderr_write;
    SECURITY_ATTRIBUTES sa = { sizeof(SECURITY_ATTRIBUTES), NULL, TRUE };

    if (!CreatePipe(&stdout_read, &stdout_write, &sa, 0) ||
        !SetHandleInformation(stdout_read, HANDLE_FLAG_INHERIT, 0)) {
        throw std::runtime_error("Failed to create stdout pipe");
    }

    if (!CreatePipe(&stderr_read, &stderr_write, &sa, 0) ||
        !SetHandleInformation(stderr_read, HANDLE_FLAG_INHERIT, 0)) {
        throw std::runtime_error("Failed to create stderr pipe");
    }

    PROCESS_INFORMATION pi;
    STARTUPINFOA si = { sizeof(STARTUPINFOA) };
    si.dwFlags = STARTF_USESTDHANDLES;
    si.hStdOutput = stdout_write;
    si.hStdError = stderr_write;

    std::vector<char> cmd(command.begin(), command.end());
    cmd.push_back('\0');

    if (!CreateProcessA(NULL, cmd.data(), NULL, NULL, TRUE, 0, NULL, NULL, &si, &pi)) {
        throw std::runtime_error("Failed to create process");
    }

    CloseHandle(stdout_write);
    CloseHandle(stderr_write);

    std::array<char, 128> buffer;
    DWORD bytes_read;

    while (ReadFile(stdout_read, buffer.data(), buffer.size(), &bytes_read, NULL) && bytes_read > 0) {
        stdout_str.append(buffer.data(), bytes_read);
    }

    while (ReadFile(stderr_read, buffer.data(), buffer.size(), &bytes_read, NULL) && bytes_read > 0) {
        stderr_str.append(buffer.data(), bytes_read);
    }

    CloseHandle(stdout_read);
    CloseHandle(stderr_read);
    WaitForSingleObject(pi.hProcess, INFINITE);
    CloseHandle(pi.hProcess);
    CloseHandle(pi.hThread);
#else
    int stdout_pipe[2];
    int stderr_pipe[2];

    if (pipe(stdout_pipe) != 0 || pipe(stderr_pipe) != 0) {
        throw std::runtime_error("Failed to create pipes");
    }

    pid_t pid = fork();
    if (pid < 0) {
        throw std::runtime_error("Failed to fork process");
    }

    if (pid == 0) {
        close(stdout_pipe[0]);
        close(stderr_pipe[0]);
        dup2(stdout_pipe[1], STDOUT_FILENO);
        dup2(stderr_pipe[1], STDERR_FILENO);
        close(stdout_pipe[1]);
        close(stderr_pipe[1]);
        execl("/bin/sh", "sh", "-c", command.c_str(), (char*) nullptr);
        _exit(EXIT_FAILURE);
    } else {
        close(stdout_pipe[1]);
        close(stderr_pipe[1]);

        std::array<char, 128> buffer;
        ssize_t bytes_read;

        while ((bytes_read = read(stdout_pipe[0], buffer.data(), buffer.size())) > 0) {
            stdout_str.append(buffer.data(), bytes_read);
        }

        while ((bytes_read = read(stderr_pipe[0], buffer.data(), buffer.size())) > 0) {
            stderr_str.append(buffer.data(), bytes_read);
        }

        close(stdout_pipe[0]);
        close(stderr_pipe[0]);
        waitpid(pid, nullptr, 0);
    }
#endif
}

bool directory_exists(const std::string& path) {
    struct stat info;
    if (stat(path.c_str(), &info) != 0) {
        return false; // Path doesn't exist or can't be accessed
    }
    return (info.st_mode & S_IFDIR) != 0; // Check if it is a directory
}

bool create_directory(const std::string& path) {
#ifdef _WIN32
    return _mkdir(path.c_str()) == 0 || errno == EEXIST; // EEXIST means the directory already exists
#else
    return mkdir(path.c_str(), 0755) == 0 || errno == EEXIST; // 0755 is the directory permissions
#endif
}

std::string to_uppercase(const std::string& input) {
    std::string result = input;
    for (char& c : result) {
        c = std::toupper(c);
    }
    return result;
}

bool string_ends_with(const std::string& str, const std::string& suffix) {
    if (suffix.size() > str.size()) {
        return false;
    }
    return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
}

#ifdef _WIN32
static const char path_separator = '\\';
#else
static const char path_separator = '/';
#endif

std::string join_paths(const std::string& path1, const std::string& path2) {
    return path1 + path_separator + path2;
}

std::string basename(const std::string &path) {
    return path.substr(path.find_last_of("/\\") + 1);
}

void string_to_spv(const std::string& _name, const std::string& in_fname, const std::map<std::string, std::string>& defines, bool fp16 = true) {
    std::string name = _name + (fp16 ? "" : "_fp32");
    std::string out_fname = join_paths(output_dir, name + ".spv");
    std::string in_path = join_paths(input_dir, in_fname);

    std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", in_path, "-o", out_fname};
    for (const auto& define : defines) {
        cmd.push_back("-D" + define.first + "=" + define.second);
    }

    std::string command;
    for (const auto& part : cmd) {
        command += part + " ";
    }

    std::string stdout_str, stderr_str;
    try {
        // std::cout << "Executing command: ";
        // for (const auto& part : cmd) {
        //     std::cout << part << " ";
        // }
        // std::cout << std::endl;

        execute_command(command, stdout_str, stderr_str);
        if (!stderr_str.empty()) {
            std::cerr << "cannot compile " << name << "\n\n" << command << "\n\n" << stderr_str << std::endl;
            return;
        }

        std::lock_guard<std::mutex> guard(lock);
        shader_fnames.push_back(std::make_pair(name, out_fname));
    } catch (const std::exception& e) {
        std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl;
    }
}

std::map<std::string, std::string> merge_maps(const std::map<std::string, std::string>& a, const std::map<std::string, std::string>& b) {
    std::map<std::string, std::string> result = a;
    result.insert(b.begin(), b.end());
    return result;
}

void matmul_shaders(std::vector<std::future<void>>& tasks, bool fp16, bool matmul_id) {
    std::string load_vec = fp16 ? "8" : "4";
    std::string aligned_b_type_f32 = fp16 ? "mat2x4" : "vec4";
    std::string aligned_b_type_f16 = fp16 ? "f16mat2x4" : "f16vec4";

    std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", fp16 ? "float16_t" : "float"}};
    std::string shader_name = "matmul";

    if (matmul_id) {
        base_dict["MUL_MAT_ID"] = "1";
        shader_name = "matmul_id";
    }

    if (fp16) {
        base_dict["FLOAT16"] = "1";
    }

    // Shaders with f16 B_TYPE
    tasks.push_back(std::async(std::launch::async, [=] {
        string_to_spv(shader_name + "_f32_f16", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16);
    }));
    tasks.push_back(std::async(std::launch::async, [=] {
        string_to_spv(shader_name + "_f32_f16_aligned", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}}), fp16);
    }));

    tasks.push_back(std::async(std::launch::async, [=] {
        string_to_spv(shader_name + "_f16", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16);
    }));
    tasks.push_back(std::async(std::launch::async, [=] {
        string_to_spv(shader_name + "_f16_aligned", "mul_mm.comp", merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}}), fp16);
    }));

    for (const auto& tname : type_names) {
        std::string data_a_key = "DATA_A_" + to_uppercase(tname);
        std::string load_vec_a = (tname == "f32" || tname == "f16") ? load_vec : "2";
        tasks.push_back(std::async(std::launch::async, [=] {
            string_to_spv(shader_name + "_" + tname + "_f32", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16);
        }));
        tasks.push_back(std::async(std::launch::async, [=] {
            string_to_spv(shader_name + "_" + tname + "_f32_aligned", "mul_mm.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}}), fp16);
        }));
    }
}

void process_shaders(std::vector<std::future<void>>& tasks) {
    std::cout << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl;
    std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};

    for (const auto& fp16 : {false, true}) {
        matmul_shaders(tasks, fp16, false);
        matmul_shaders(tasks, fp16, true);
    }

    for (const auto& tname : type_names) {
        // mul mat vec
        std::string data_a_key = "DATA_A_" + to_uppercase(tname);
        std::string shader = (string_ends_with(tname, "_k")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";

        tasks.push_back(std::async(std::launch::async, [=] {
            string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
        }));
        tasks.push_back(std::async(std::launch::async, [=] {
            string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
        }));

        tasks.push_back(std::async(std::launch::async, [=] {
            string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
        }));

        // Dequant shaders
        if (tname != "f16") {
            tasks.push_back(std::async(std::launch::async, [=] {
                string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
            }));
        }

        if (!string_ends_with(tname, "_k")) {
            shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp";

            if (tname == "f16") {
                tasks.push_back(std::async(std::launch::async, [=] {
                    string_to_spv("get_rows_" + tname, shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
                }));
            } else {
                tasks.push_back(std::async(std::launch::async, [=] {
                    string_to_spv("get_rows_" + tname, shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}});
                }));
            }
            tasks.push_back(std::async(std::launch::async, [=] {
                string_to_spv("get_rows_" + tname + "_f32", shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}});
            }));
        }
    }

    tasks.push_back(std::async(std::launch::async, [] {
        string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
    }));
    tasks.push_back(std::async(std::launch::async, [] {
        string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
    }));

    // Norms
    tasks.push_back(std::async(std::launch::async, [=] {
        string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
    }));
    tasks.push_back(std::async(std::launch::async, [=] {
        string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
    }));

    tasks.push_back(std::async(std::launch::async, [] {
        string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
    }));
    tasks.push_back(std::async(std::launch::async, [] {
        string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
    }));
    tasks.push_back(std::async(std::launch::async, [] {
        string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
    }));

    tasks.push_back(std::async(std::launch::async, [] {
        string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
    }));

    tasks.push_back(std::async(std::launch::async, [] {
        string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
    }));

    tasks.push_back(std::async(std::launch::async, [] {
        string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
    }));

    tasks.push_back(std::async(std::launch::async, [] {
        string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
    }));

    tasks.push_back(std::async(std::launch::async, [] {
        string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
    }));

    tasks.push_back(std::async(std::launch::async, [] {
        string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
    }));

    tasks.push_back(std::async(std::launch::async, [] {
        string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
    }));

    tasks.push_back(std::async(std::launch::async, [] {
        string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
    }));
    tasks.push_back(std::async(std::launch::async, [] {
        string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
    }));
    tasks.push_back(std::async(std::launch::async, [] {
        string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
    }));

    tasks.push_back(std::async(std::launch::async, [] {
        string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
    }));

    tasks.push_back(std::async(std::launch::async, [=] {
        string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
    }));
    tasks.push_back(std::async(std::launch::async, [=] {
        string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
    }));

    tasks.push_back(std::async(std::launch::async, [] {
        string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
    }));
    tasks.push_back(std::async(std::launch::async, [] {
        string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
    }));

    tasks.push_back(std::async(std::launch::async, [] {
        string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
    }));
    tasks.push_back(std::async(std::launch::async, [] {
        string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
    }));

    tasks.push_back(std::async(std::launch::async, [] {
        string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
    }));

    tasks.push_back(std::async(std::launch::async, [=] {
        string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
    }));
}

void write_output_files() {
    FILE* hdr = fopen(target_hpp.c_str(), "w");
    FILE* src = fopen(target_cpp.c_str(), "w");

    fprintf(hdr, "#include <cstdint>\n\n");
    fprintf(src, "#include \"%s\"\n\n", basename(target_hpp).c_str());

    for (const auto& pair : shader_fnames) {
        const std::string& name = pair.first;
        const std::string& path = pair.second;
        FILE* spv = fopen(path.c_str(), "rb");
        if (!spv) {
            std::cerr << "Error opening SPIR-V file: " << path << "\n";
            continue;
        }

        fseek(spv, 0, SEEK_END);
        size_t size = ftell(spv);
        fseek(spv, 0, SEEK_SET);

        std::vector<unsigned char> data(size);
        size_t read_size = fread(data.data(), 1, size, spv);
        fclose(spv);
        if (read_size != size) {
            std::cerr << "Error reading SPIR-V file: " << path << "\n";
            continue;
        }

        fprintf(hdr, "extern unsigned char %s_data[%zu];\n", name.c_str(), size);
        fprintf(hdr, "const uint64_t %s_len = %zu;\n\n", name.c_str(), size);

        fprintf(src, "unsigned char %s_data[%zu] = {\n", name.c_str(), size);
        for (size_t i = 0; i < size; ++i) {
            fprintf(src, "0x%02x,", data[i]);
            if ((i + 1) % 12 == 0) fprintf(src, "\n");
        }
        fprintf(src, "\n};\n\n");

        if (!no_clean) {
            std::remove(path.c_str());
        }
    }

    fclose(hdr);
    fclose(src);
}

int main(int argc, char** argv) {
    std::map<std::string, std::string> args;
    for (int i = 1; i < argc; i += 2) {
        if (i + 1 < argc) {
            args[argv[i]] = argv[i + 1];
        }
    }

    if (args.find("--glslc") != args.end()) {
        GLSLC = args["--glslc"]; // Path to glslc
    }
    if (args.find("--input-dir") != args.end()) {
        input_dir = args["--input-dir"]; // Directory containing shader sources
    }
    if (args.find("--output-dir") != args.end()) {
        output_dir = args["--output-dir"]; // Directory for containing SPIR-V output
    }
    if (args.find("--target-hpp") != args.end()) {
        target_hpp = args["--target-hpp"]; // Path to generated header file
    }
    if (args.find("--target-cpp") != args.end()) {
        target_cpp = args["--target-cpp"]; // Path to generated cpp file
    }
    if (args.find("--no-clean") != args.end()) {
        no_clean = true; // Keep temporary SPIR-V files in output-dir after build
    }

    if (!directory_exists(input_dir)) {
        std::cerr << "\"" << input_dir << "\" must be a valid directory containing shader sources" << std::endl;
        return EXIT_FAILURE;
    }

    if (!directory_exists(output_dir)) {
        if (!create_directory(output_dir)) {
            std::cerr << "Error creating output directory: " << output_dir << "\n";
            return EXIT_FAILURE;
        }
    }

    std::vector<std::future<void>> tasks;
    process_shaders(tasks);

    for (auto& task : tasks) {
        task.get();
    }

    write_output_files();

    return EXIT_SUCCESS;
}
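Note: judging from the argument parsing in main above, the generator is presumably invoked by the build system along the lines of `vulkan-shaders-gen --glslc glslc --input-dir ggml/src/vulkan-shaders --output-dir <tmp dir> --target-hpp ggml-vulkan-shaders.hpp --target-cpp ggml-vulkan-shaders.cpp`; the exact paths are supplied by CMake and are not shown in this diff.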