Merge branch 'master' into pr/8836
This commit is contained in:
commit
fb2b9ea667
23 changed files with 1653 additions and 775 deletions
|
@ -28,6 +28,7 @@
|
||||||
{ "name": "release", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" } },
|
{ "name": "release", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" } },
|
||||||
{ "name": "reldbg", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo" } },
|
{ "name": "reldbg", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo" } },
|
||||||
{ "name": "static", "hidden": true, "cacheVariables": { "GGML_STATIC": "ON" } },
|
{ "name": "static", "hidden": true, "cacheVariables": { "GGML_STATIC": "ON" } },
|
||||||
|
{ "name": "sycl_f16", "hidden": true, "cacheVariables": { "GGML_SYCL_F16": "ON" } },
|
||||||
|
|
||||||
{
|
{
|
||||||
"name": "arm64-windows-msvc", "hidden": true,
|
"name": "arm64-windows-msvc", "hidden": true,
|
||||||
|
@ -60,6 +61,8 @@
|
||||||
{ "name": "x64-windows-msvc+static-release", "inherits": [ "base", "reldbg", "static" ] },
|
{ "name": "x64-windows-msvc+static-release", "inherits": [ "base", "reldbg", "static" ] },
|
||||||
|
|
||||||
{ "name": "x64-windows-sycl-debug" , "inherits": [ "sycl-base", "debug" ] },
|
{ "name": "x64-windows-sycl-debug" , "inherits": [ "sycl-base", "debug" ] },
|
||||||
{ "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] }
|
{ "name": "x64-windows-sycl-debug-f16", "inherits": [ "sycl-base", "debug", "sycl_f16" ] },
|
||||||
|
{ "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] },
|
||||||
|
{ "name": "x64-windows-sycl-release-f16", "inherits": [ "sycl-base", "release", "sycl_f16" ] }
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
|
@ -63,6 +63,7 @@ class Model:
|
||||||
model_name: str | None
|
model_name: str | None
|
||||||
metadata_override: Path | None
|
metadata_override: Path | None
|
||||||
dir_model_card: Path
|
dir_model_card: Path
|
||||||
|
is_lora: bool
|
||||||
|
|
||||||
# subclasses should define this!
|
# subclasses should define this!
|
||||||
model_arch: gguf.MODEL_ARCH
|
model_arch: gguf.MODEL_ARCH
|
||||||
|
@ -70,7 +71,7 @@ class Model:
|
||||||
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False,
|
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False,
|
||||||
use_temp_file: bool = False, eager: bool = False,
|
use_temp_file: bool = False, eager: bool = False,
|
||||||
metadata_override: Path | None = None, model_name: str | None = None,
|
metadata_override: Path | None = None, model_name: str | None = None,
|
||||||
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
|
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False, is_lora: bool = False):
|
||||||
if type(self) is Model:
|
if type(self) is Model:
|
||||||
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
|
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
|
||||||
|
|
||||||
|
@ -92,6 +93,7 @@ class Model:
|
||||||
self.metadata_override = metadata_override
|
self.metadata_override = metadata_override
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py
|
self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py
|
||||||
|
self.is_lora = is_lora # true if model is used inside convert_lora_to_gguf.py
|
||||||
|
|
||||||
# Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
|
# Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
|
||||||
if self.ftype == gguf.LlamaFileType.GUESSED:
|
if self.ftype == gguf.LlamaFileType.GUESSED:
|
||||||
|
@ -1593,7 +1595,8 @@ class LlamaModel(Model):
|
||||||
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
|
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
|
||||||
rope_factors.append(1 / ((1 - smooth) / factor + smooth))
|
rope_factors.append(1 / ((1 - smooth) / factor + smooth))
|
||||||
|
|
||||||
self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32))
|
if not self.is_lora:
|
||||||
|
self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32))
|
||||||
|
|
||||||
super().prepare_tensors()
|
super().prepare_tensors()
|
||||||
|
|
||||||
|
@ -2140,8 +2143,9 @@ class Phi3MiniModel(Model):
|
||||||
if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2:
|
if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2:
|
||||||
raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}')
|
raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}')
|
||||||
|
|
||||||
self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG] + ".weight", np.array(long_factors, dtype=np.float32))
|
if not self.is_lora:
|
||||||
self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32))
|
self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG] + ".weight", np.array(long_factors, dtype=np.float32))
|
||||||
|
self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32))
|
||||||
|
|
||||||
|
|
||||||
@Model.register("PlamoForCausalLM")
|
@Model.register("PlamoForCausalLM")
|
||||||
|
@ -3839,7 +3843,8 @@ class ExaoneModel(Model):
|
||||||
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
|
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
|
||||||
rope_factors.append(1 / ((1 - smooth) / factor + smooth))
|
rope_factors.append(1 / ((1 - smooth) / factor + smooth))
|
||||||
|
|
||||||
self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32))
|
if not self.is_lora:
|
||||||
|
self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32))
|
||||||
|
|
||||||
super().prepare_tensors()
|
super().prepare_tensors()
|
||||||
|
|
||||||
|
|
|
@ -386,6 +386,7 @@ if __name__ == '__main__':
|
||||||
dry_run=args.dry_run,
|
dry_run=args.dry_run,
|
||||||
dir_lora_model=dir_lora,
|
dir_lora_model=dir_lora,
|
||||||
lora_alpha=alpha,
|
lora_alpha=alpha,
|
||||||
|
is_lora=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("Exporting model...")
|
logger.info("Exporting model...")
|
||||||
|
|
|
@ -20,7 +20,7 @@
|
||||||
**oneAPI** is an open ecosystem and a standard-based specification, supporting multiple architectures including but not limited to intel CPUs, GPUs and FPGAs. The key components of the oneAPI ecosystem include:
|
**oneAPI** is an open ecosystem and a standard-based specification, supporting multiple architectures including but not limited to intel CPUs, GPUs and FPGAs. The key components of the oneAPI ecosystem include:
|
||||||
|
|
||||||
- **DPCPP** *(Data Parallel C++)*: The primary oneAPI SYCL implementation, which includes the icpx/icx Compilers.
|
- **DPCPP** *(Data Parallel C++)*: The primary oneAPI SYCL implementation, which includes the icpx/icx Compilers.
|
||||||
- **oneAPI Libraries**: A set of highly optimized libraries targeting multiple domains *(e.g. oneMKL - Math Kernel Library)*.
|
- **oneAPI Libraries**: A set of highly optimized libraries targeting multiple domains *(e.g. oneMKL and oneDNN)*.
|
||||||
- **oneAPI LevelZero**: A high performance low level interface for fine-grained control over intel iGPUs and dGPUs.
|
- **oneAPI LevelZero**: A high performance low level interface for fine-grained control over intel iGPUs and dGPUs.
|
||||||
- **Nvidia & AMD Plugins**: These are plugins extending oneAPI's DPCPP support to SYCL on Nvidia and AMD GPU targets.
|
- **Nvidia & AMD Plugins**: These are plugins extending oneAPI's DPCPP support to SYCL on Nvidia and AMD GPU targets.
|
||||||
|
|
||||||
|
@ -28,10 +28,6 @@
|
||||||
|
|
||||||
The llama.cpp SYCL backend is designed to support **Intel GPU** firstly. Based on the cross-platform feature of SYCL, it could support other vendor GPUs: Nvidia GPU (*AMD GPU coming*).
|
The llama.cpp SYCL backend is designed to support **Intel GPU** firstly. Based on the cross-platform feature of SYCL, it could support other vendor GPUs: Nvidia GPU (*AMD GPU coming*).
|
||||||
|
|
||||||
When targeting **Intel CPU**, it is recommended to use llama.cpp for [Intel oneMKL](README.md#intel-onemkl) backend.
|
|
||||||
|
|
||||||
It has the similar design of other llama.cpp BLAS-based paths such as *OpenBLAS, cuBLAS, etc..*. In beginning work, the oneAPI's [SYCLomatic](https://github.com/oneapi-src/SYCLomatic) open-source migration tool (Commercial release [Intel® DPC++ Compatibility Tool](https://www.intel.com/content/www/us/en/developer/tools/oneapi/dpc-compatibility-tool.html)) was used for this purpose.
|
|
||||||
|
|
||||||
## Recommended Release
|
## Recommended Release
|
||||||
|
|
||||||
The SYCL backend would be broken by some PRs due to no online CI.
|
The SYCL backend would be broken by some PRs due to no online CI.
|
||||||
|
@ -45,6 +41,10 @@ The following release is verified with good quality:
|
||||||
|
|
||||||
## News
|
## News
|
||||||
|
|
||||||
|
|
||||||
|
- 2024.8
|
||||||
|
- Use oneDNN as the default GEMM library, improve the compatibility for new Intel GPUs.
|
||||||
|
|
||||||
- 2024.5
|
- 2024.5
|
||||||
- Performance is increased: 34 -> 37 tokens/s of llama-2-7b.Q4_0 on Arc770.
|
- Performance is increased: 34 -> 37 tokens/s of llama-2-7b.Q4_0 on Arc770.
|
||||||
- Arch Linux is verified successfully.
|
- Arch Linux is verified successfully.
|
||||||
|
@ -196,7 +196,7 @@ Please follow the instructions for downloading and installing the Toolkit for Li
|
||||||
|
|
||||||
Following guidelines/code snippets assume the default installation values. Otherwise, please make sure the necessary changes are reflected where applicable.
|
Following guidelines/code snippets assume the default installation values. Otherwise, please make sure the necessary changes are reflected where applicable.
|
||||||
|
|
||||||
Upon a successful installation, SYCL is enabled for the available intel devices, along with relevant libraries such as oneAPI MKL for intel GPUs.
|
Upon a successful installation, SYCL is enabled for the available intel devices, along with relevant libraries such as oneAPI oneDNN for Intel GPUs.
|
||||||
|
|
||||||
- **Adding support to Nvidia GPUs**
|
- **Adding support to Nvidia GPUs**
|
||||||
|
|
||||||
|
@ -255,8 +255,6 @@ or
|
||||||
# Export relevant ENV variables
|
# Export relevant ENV variables
|
||||||
source /opt/intel/oneapi/setvars.sh
|
source /opt/intel/oneapi/setvars.sh
|
||||||
|
|
||||||
# Build LLAMA with MKL BLAS acceleration for intel GPU
|
|
||||||
|
|
||||||
# Option 1: Use FP32 (recommended for better performance in most cases)
|
# Option 1: Use FP32 (recommended for better performance in most cases)
|
||||||
cmake -B build -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx
|
cmake -B build -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx
|
||||||
|
|
||||||
|
|
|
@ -111,7 +111,7 @@ static void usage(const char * executable) {
|
||||||
printf(" --exclude-weights tensor_name: use importance matrix for this/these tensor(s)\n");
|
printf(" --exclude-weights tensor_name: use importance matrix for this/these tensor(s)\n");
|
||||||
printf(" --output-tensor-type ggml_type: use this ggml_type for the output.weight tensor\n");
|
printf(" --output-tensor-type ggml_type: use this ggml_type for the output.weight tensor\n");
|
||||||
printf(" --token-embedding-type ggml_type: use this ggml_type for the token embeddings tensor\n");
|
printf(" --token-embedding-type ggml_type: use this ggml_type for the token embeddings tensor\n");
|
||||||
printf(" --keep-split: will generate quatized model in the same shards as input");
|
printf(" --keep-split: will generate quantized model in the same shards as input\n");
|
||||||
printf(" --override-kv KEY=TYPE:VALUE\n");
|
printf(" --override-kv KEY=TYPE:VALUE\n");
|
||||||
printf(" Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n");
|
printf(" Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n");
|
||||||
printf("Note: --include-weights and --exclude-weights cannot be used together\n");
|
printf("Note: --include-weights and --exclude-weights cannot be used together\n");
|
||||||
|
|
|
@ -1760,7 +1760,8 @@ extern "C" {
|
||||||
struct ggml_tensor * v,
|
struct ggml_tensor * v,
|
||||||
struct ggml_tensor * mask,
|
struct ggml_tensor * mask,
|
||||||
float scale,
|
float scale,
|
||||||
float max_bias);
|
float max_bias,
|
||||||
|
float logit_softcap);
|
||||||
|
|
||||||
GGML_API void ggml_flash_attn_ext_set_prec(
|
GGML_API void ggml_flash_attn_ext_set_prec(
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
|
@ -1777,10 +1778,8 @@ extern "C" {
|
||||||
|
|
||||||
GGML_API struct ggml_tensor * ggml_ssm_conv(
|
GGML_API struct ggml_tensor * ggml_ssm_conv(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * s,
|
struct ggml_tensor * sx,
|
||||||
struct ggml_tensor * x,
|
struct ggml_tensor * c);
|
||||||
struct ggml_tensor * c,
|
|
||||||
struct ggml_tensor * sq);
|
|
||||||
|
|
||||||
GGML_API struct ggml_tensor * ggml_ssm_scan(
|
GGML_API struct ggml_tensor * ggml_ssm_scan(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
|
@ -1789,8 +1788,7 @@ extern "C" {
|
||||||
struct ggml_tensor * dt,
|
struct ggml_tensor * dt,
|
||||||
struct ggml_tensor * A,
|
struct ggml_tensor * A,
|
||||||
struct ggml_tensor * B,
|
struct ggml_tensor * B,
|
||||||
struct ggml_tensor * C,
|
struct ggml_tensor * C);
|
||||||
struct ggml_tensor * sq);
|
|
||||||
|
|
||||||
// partition into non-overlapping windows with padding if needed
|
// partition into non-overlapping windows with padding if needed
|
||||||
// example:
|
// example:
|
||||||
|
|
|
@ -549,6 +549,13 @@ if (GGML_SYCL)
|
||||||
file(GLOB GGML_SOURCES_SYCL "ggml-sycl/*.cpp")
|
file(GLOB GGML_SOURCES_SYCL "ggml-sycl/*.cpp")
|
||||||
list(APPEND GGML_SOURCES_SYCL "ggml-sycl.cpp")
|
list(APPEND GGML_SOURCES_SYCL "ggml-sycl.cpp")
|
||||||
|
|
||||||
|
find_package(DNNL)
|
||||||
|
message("-- DNNL found:" ${DNNL_FOUND})
|
||||||
|
if (GGML_SYCL_TARGET STREQUAL "INTEL")
|
||||||
|
add_compile_definitions(GGML_SYCL_DNNL=${DNNL_FOUND})
|
||||||
|
else()
|
||||||
|
add_compile_definitions(GGML_SYCL_DNNL=0)
|
||||||
|
endif()
|
||||||
if (WIN32)
|
if (WIN32)
|
||||||
find_package(IntelSYCL REQUIRED)
|
find_package(IntelSYCL REQUIRED)
|
||||||
find_package(MKL REQUIRED)
|
find_package(MKL REQUIRED)
|
||||||
|
@ -561,6 +568,9 @@ if (GGML_SYCL)
|
||||||
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} -fsycl pthread m dl onemkl)
|
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} -fsycl pthread m dl onemkl)
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL")
|
||||||
|
list(APPEND GGML_EXTRA_LIBS DNNL::dnnl)
|
||||||
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (GGML_RPC)
|
if (GGML_RPC)
|
||||||
|
|
|
@ -22,6 +22,7 @@ typedef void (* fattn_kernel_t)(
|
||||||
const float m0,
|
const float m0,
|
||||||
const float m1,
|
const float m1,
|
||||||
const uint32_t n_head_log2,
|
const uint32_t n_head_log2,
|
||||||
|
const float logit_softcap,
|
||||||
const int ne00,
|
const int ne00,
|
||||||
const int ne01,
|
const int ne01,
|
||||||
const int ne02,
|
const int ne02,
|
||||||
|
@ -657,11 +658,17 @@ void launch_fattn(
|
||||||
const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
|
const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
|
||||||
const int shmem = 0;
|
const int shmem = 0;
|
||||||
|
|
||||||
float scale = 1.0f;
|
float scale = 1.0f;
|
||||||
float max_bias = 0.0f;
|
float max_bias = 0.0f;
|
||||||
|
float logit_softcap = 0.0f;
|
||||||
|
|
||||||
memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
|
memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
|
||||||
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
|
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
|
||||||
|
memcpy(&logit_softcap, (float *) KQV->op_params + 2, sizeof(float));
|
||||||
|
|
||||||
|
if (logit_softcap != 0.0f) {
|
||||||
|
scale /= logit_softcap;
|
||||||
|
}
|
||||||
|
|
||||||
const uint32_t n_head = Q->ne[2];
|
const uint32_t n_head = Q->ne[2];
|
||||||
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
||||||
|
@ -675,7 +682,7 @@ void launch_fattn(
|
||||||
V_data,
|
V_data,
|
||||||
mask ? ((const char *) mask->data) : nullptr,
|
mask ? ((const char *) mask->data) : nullptr,
|
||||||
(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
|
(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
|
||||||
scale, max_bias, m0, m1, n_head_log2,
|
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
|
|
||||||
#define FATTN_KQ_STRIDE_TILE_F16 64
|
#define FATTN_KQ_STRIDE_TILE_F16 64
|
||||||
|
|
||||||
template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
|
template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
|
||||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||||
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
||||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||||
|
@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||||
const float m0,
|
const float m0,
|
||||||
const float m1,
|
const float m1,
|
||||||
const uint32_t n_head_log2,
|
const uint32_t n_head_log2,
|
||||||
|
const float logit_softcap,
|
||||||
const int ne00,
|
const int ne00,
|
||||||
const int ne01,
|
const int ne01,
|
||||||
const int ne02,
|
const int ne02,
|
||||||
|
@ -44,6 +45,12 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||||
const int ne2,
|
const int ne2,
|
||||||
const int ne3) {
|
const int ne3) {
|
||||||
#ifdef FP16_AVAILABLE
|
#ifdef FP16_AVAILABLE
|
||||||
|
// Skip unused kernel variants for faster compilation:
|
||||||
|
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||||
|
|
||||||
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
|
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
|
||||||
|
@ -154,7 +161,13 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||||
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
|
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
|
||||||
const int j_KQ = j_KQ_0 + threadIdx.y;
|
const int j_KQ = j_KQ_0 + threadIdx.y;
|
||||||
|
|
||||||
half sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
|
half sum;
|
||||||
|
if (use_logit_softcap) {
|
||||||
|
const float2 tmp = __half22float2(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
|
||||||
|
sum = logit_softcap * tanhf(tmp.x + tmp.y);
|
||||||
|
} else {
|
||||||
|
sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
|
||||||
|
}
|
||||||
sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
|
sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
|
||||||
|
|
||||||
kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum);
|
kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum);
|
||||||
|
@ -270,20 +283,20 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||||
#endif // FP16_AVAILABLE
|
#endif // FP16_AVAILABLE
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int cols_per_block, int parallel_blocks>
|
template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
|
||||||
void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
const ggml_tensor * Q = dst->src[0];
|
const ggml_tensor * Q = dst->src[0];
|
||||||
switch (Q->ne[0]) {
|
switch (Q->ne[0]) {
|
||||||
case 64: {
|
case 64: {
|
||||||
constexpr int D = 64;
|
constexpr int D = 64;
|
||||||
constexpr int nwarps = 8;
|
constexpr int nwarps = 8;
|
||||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
|
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
||||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||||
} break;
|
} break;
|
||||||
case 128: {
|
case 128: {
|
||||||
constexpr int D = 128;
|
constexpr int D = 128;
|
||||||
constexpr int nwarps = 8;
|
constexpr int nwarps = 8;
|
||||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
|
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
||||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||||
} break;
|
} break;
|
||||||
default: {
|
default: {
|
||||||
|
@ -296,24 +309,45 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten
|
||||||
const ggml_tensor * KQV = dst;
|
const ggml_tensor * KQV = dst;
|
||||||
const ggml_tensor * Q = dst->src[0];
|
const ggml_tensor * Q = dst->src[0];
|
||||||
|
|
||||||
const int32_t precision = KQV->op_params[2];
|
const int32_t precision = KQV->op_params[3];
|
||||||
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
|
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
|
||||||
|
|
||||||
|
float logit_softcap;
|
||||||
|
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||||
|
|
||||||
if (Q->ne[1] <= 16) {
|
if (Q->ne[1] <= 16) {
|
||||||
constexpr int cols_per_block = 16;
|
constexpr int cols_per_block = 16;
|
||||||
constexpr int parallel_blocks = 4;
|
constexpr int parallel_blocks = 4;
|
||||||
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
|
if (logit_softcap == 0.0f) {
|
||||||
|
constexpr bool use_logit_softcap = false;
|
||||||
|
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||||
|
} else {
|
||||||
|
constexpr bool use_logit_softcap = true;
|
||||||
|
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||||
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Q->ne[1] <= 32) {
|
if (Q->ne[1] <= 32) {
|
||||||
constexpr int cols_per_block = 32;
|
constexpr int cols_per_block = 32;
|
||||||
constexpr int parallel_blocks = 4;
|
constexpr int parallel_blocks = 4;
|
||||||
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
|
if (logit_softcap == 0.0f) {
|
||||||
|
constexpr bool use_logit_softcap = false;
|
||||||
|
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||||
|
} else {
|
||||||
|
constexpr bool use_logit_softcap = true;
|
||||||
|
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||||
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
constexpr int cols_per_block = 32;
|
constexpr int cols_per_block = 32;
|
||||||
constexpr int parallel_blocks = 1;
|
constexpr int parallel_blocks = 1;
|
||||||
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
|
if (logit_softcap == 0.0f) {
|
||||||
|
constexpr bool use_logit_softcap = false;
|
||||||
|
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||||
|
} else {
|
||||||
|
constexpr bool use_logit_softcap = true;
|
||||||
|
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
|
|
||||||
#define FATTN_KQ_STRIDE_TILE_F32 32
|
#define FATTN_KQ_STRIDE_TILE_F32 32
|
||||||
|
|
||||||
template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
|
template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
|
||||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||||
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
||||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||||
|
@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||||
const float m0,
|
const float m0,
|
||||||
const float m1,
|
const float m1,
|
||||||
const uint32_t n_head_log2,
|
const uint32_t n_head_log2,
|
||||||
|
const float logit_softcap,
|
||||||
const int ne00,
|
const int ne00,
|
||||||
const int ne01,
|
const int ne01,
|
||||||
const int ne02,
|
const int ne02,
|
||||||
|
@ -43,6 +44,12 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||||
const int ne1,
|
const int ne1,
|
||||||
const int ne2,
|
const int ne2,
|
||||||
const int ne3) {
|
const int ne3) {
|
||||||
|
// Skip unused kernel variants for faster compilation:
|
||||||
|
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||||
|
|
||||||
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
|
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
|
||||||
|
@ -151,6 +158,10 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||||
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
|
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
|
||||||
const int j_KQ = j_KQ_0 + threadIdx.y;
|
const int j_KQ = j_KQ_0 + threadIdx.y;
|
||||||
|
|
||||||
|
if (use_logit_softcap) {
|
||||||
|
sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] = logit_softcap * tanhf(sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
|
||||||
|
}
|
||||||
|
|
||||||
sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
|
sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
|
||||||
|
|
||||||
kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
|
kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
|
||||||
|
@ -267,20 +278,20 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int cols_per_block, int parallel_blocks>
|
template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
|
||||||
void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
const ggml_tensor * Q = dst->src[0];
|
const ggml_tensor * Q = dst->src[0];
|
||||||
switch (Q->ne[0]) {
|
switch (Q->ne[0]) {
|
||||||
case 64: {
|
case 64: {
|
||||||
constexpr int D = 64;
|
constexpr int D = 64;
|
||||||
constexpr int nwarps = 8;
|
constexpr int nwarps = 8;
|
||||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
|
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
||||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||||
} break;
|
} break;
|
||||||
case 128: {
|
case 128: {
|
||||||
constexpr int D = 128;
|
constexpr int D = 128;
|
||||||
constexpr int nwarps = 8;
|
constexpr int nwarps = 8;
|
||||||
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
|
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
||||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||||
} break;
|
} break;
|
||||||
default: {
|
default: {
|
||||||
|
@ -290,23 +301,45 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
const ggml_tensor * KQV = dst;
|
||||||
const ggml_tensor * Q = dst->src[0];
|
const ggml_tensor * Q = dst->src[0];
|
||||||
|
|
||||||
|
float logit_softcap;
|
||||||
|
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||||
|
|
||||||
if (Q->ne[1] <= 16) {
|
if (Q->ne[1] <= 16) {
|
||||||
constexpr int cols_per_block = 16;
|
constexpr int cols_per_block = 16;
|
||||||
constexpr int parallel_blocks = 4;
|
constexpr int parallel_blocks = 4;
|
||||||
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
|
if (logit_softcap == 0.0f) {
|
||||||
|
constexpr bool use_logit_softcap = false;
|
||||||
|
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||||
|
} else {
|
||||||
|
constexpr bool use_logit_softcap = true;
|
||||||
|
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||||
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Q->ne[1] <= 32) {
|
if (Q->ne[1] <= 32) {
|
||||||
constexpr int cols_per_block = 32;
|
constexpr int cols_per_block = 32;
|
||||||
constexpr int parallel_blocks = 4;
|
constexpr int parallel_blocks = 4;
|
||||||
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
|
if (logit_softcap == 0.0f) {
|
||||||
|
constexpr bool use_logit_softcap = false;
|
||||||
|
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||||
|
} else {
|
||||||
|
constexpr bool use_logit_softcap = true;
|
||||||
|
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||||
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
constexpr int cols_per_block = 32;
|
constexpr int cols_per_block = 32;
|
||||||
constexpr int parallel_blocks = 1;
|
constexpr int parallel_blocks = 1;
|
||||||
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
|
if (logit_softcap == 0.0f) {
|
||||||
|
constexpr bool use_logit_softcap = false;
|
||||||
|
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||||
|
} else {
|
||||||
|
constexpr bool use_logit_softcap = true;
|
||||||
|
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
#include "common.cuh"
|
#include "common.cuh"
|
||||||
#include "fattn-common.cuh"
|
#include "fattn-common.cuh"
|
||||||
|
|
||||||
template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V> // D == head size
|
template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
|
||||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||||
__launch_bounds__(D, 1)
|
__launch_bounds__(D, 1)
|
||||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||||
|
@ -17,6 +17,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||||
const float m0,
|
const float m0,
|
||||||
const float m1,
|
const float m1,
|
||||||
const uint32_t n_head_log2,
|
const uint32_t n_head_log2,
|
||||||
|
const float logit_softcap,
|
||||||
const int ne00,
|
const int ne00,
|
||||||
const int ne01,
|
const int ne01,
|
||||||
const int ne02,
|
const int ne02,
|
||||||
|
@ -41,6 +42,12 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||||
const int ne2,
|
const int ne2,
|
||||||
const int ne3) {
|
const int ne3) {
|
||||||
#ifdef FP16_AVAILABLE
|
#ifdef FP16_AVAILABLE
|
||||||
|
// Skip unused kernel variants for faster compilation:
|
||||||
|
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||||
|
|
||||||
constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D>(type_K);
|
constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D>(type_K);
|
||||||
|
@ -190,6 +197,11 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||||
for (int j = 0; j < ncols; ++j) {
|
for (int j = 0; j < ncols; ++j) {
|
||||||
half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
|
half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
|
||||||
sum = warp_reduce_sum(sum);
|
sum = warp_reduce_sum(sum);
|
||||||
|
|
||||||
|
if (use_logit_softcap) {
|
||||||
|
sum = logit_softcap*tanhf(sum);
|
||||||
|
}
|
||||||
|
|
||||||
sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
|
sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
|
||||||
|
|
||||||
if (ncols == 1) {
|
if (ncols == 1) {
|
||||||
|
@ -286,10 +298,10 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||||
#endif // FP16_AVAILABLE
|
#endif // FP16_AVAILABLE
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V>
|
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
|
||||||
void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
constexpr int nwarps = D/WARP_SIZE;
|
constexpr int nwarps = D/WARP_SIZE;
|
||||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V>;
|
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
|
||||||
constexpr bool need_f16_K = D != 128;
|
constexpr bool need_f16_K = D != 128;
|
||||||
constexpr bool need_f16_V = D != 128 && D != 64;
|
constexpr bool need_f16_V = D != 128 && D != 64;
|
||||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
|
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
|
||||||
|
@ -297,48 +309,81 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,
|
||||||
|
|
||||||
template <int D, ggml_type type_K, ggml_type type_V>
|
template <int D, ggml_type type_K, ggml_type type_V>
|
||||||
void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
ggml_tensor * KQV = dst;
|
const ggml_tensor * KQV = dst;
|
||||||
ggml_tensor * Q = dst->src[0];
|
const ggml_tensor * Q = dst->src[0];
|
||||||
ggml_tensor * K = dst->src[1];
|
const ggml_tensor * K = dst->src[1];
|
||||||
ggml_tensor * V = dst->src[2];
|
const ggml_tensor * V = dst->src[2];
|
||||||
|
|
||||||
const int32_t precision = KQV->op_params[2];
|
const int32_t precision = KQV->op_params[3];
|
||||||
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
|
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
|
||||||
|
|
||||||
GGML_ASSERT(K->type == type_K);
|
GGML_ASSERT(K->type == type_K);
|
||||||
GGML_ASSERT(V->type == type_V);
|
GGML_ASSERT(V->type == type_V);
|
||||||
|
|
||||||
|
float logit_softcap;
|
||||||
|
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||||
|
|
||||||
if (Q->ne[1] == 1) {
|
if (Q->ne[1] == 1) {
|
||||||
constexpr int cols_per_block = 1;
|
constexpr int cols_per_block = 1;
|
||||||
constexpr int parallel_blocks = 4;
|
constexpr int parallel_blocks = 4;
|
||||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
|
if (logit_softcap == 0.0f) {
|
||||||
|
constexpr bool use_logit_softcap = false;
|
||||||
|
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||||
|
} else {
|
||||||
|
constexpr bool use_logit_softcap = true;
|
||||||
|
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||||
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Q->ne[1] == 2) {
|
if (Q->ne[1] == 2) {
|
||||||
constexpr int cols_per_block = 2;
|
constexpr int cols_per_block = 2;
|
||||||
constexpr int parallel_blocks = 4;
|
constexpr int parallel_blocks = 4;
|
||||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
|
if (logit_softcap == 0.0f) {
|
||||||
|
constexpr bool use_logit_softcap = false;
|
||||||
|
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||||
|
} else {
|
||||||
|
constexpr bool use_logit_softcap = true;
|
||||||
|
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||||
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Q->ne[1] <= 4) {
|
if (Q->ne[1] <= 4) {
|
||||||
constexpr int cols_per_block = 4;
|
constexpr int cols_per_block = 4;
|
||||||
constexpr int parallel_blocks = 4;
|
constexpr int parallel_blocks = 4;
|
||||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
|
if (logit_softcap == 0.0f) {
|
||||||
|
constexpr bool use_logit_softcap = false;
|
||||||
|
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||||
|
} else {
|
||||||
|
constexpr bool use_logit_softcap = true;
|
||||||
|
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||||
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Q->ne[1] <= 8) {
|
if (Q->ne[1] <= 8) {
|
||||||
constexpr int cols_per_block = 8;
|
constexpr int cols_per_block = 8;
|
||||||
constexpr int parallel_blocks = 4;
|
constexpr int parallel_blocks = 4;
|
||||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
|
if (logit_softcap == 0.0f) {
|
||||||
|
constexpr bool use_logit_softcap = false;
|
||||||
|
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||||
|
} else {
|
||||||
|
constexpr bool use_logit_softcap = true;
|
||||||
|
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||||
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
constexpr int cols_per_block = 8;
|
constexpr int cols_per_block = 8;
|
||||||
constexpr int parallel_blocks = 1;
|
constexpr int parallel_blocks = 1;
|
||||||
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
|
if (logit_softcap == 0.0f) {
|
||||||
|
constexpr bool use_logit_softcap = false;
|
||||||
|
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||||
|
} else {
|
||||||
|
constexpr bool use_logit_softcap = true;
|
||||||
|
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V) \
|
#define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V) \
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
#include "common.cuh"
|
#include "common.cuh"
|
||||||
#include "fattn-common.cuh"
|
#include "fattn-common.cuh"
|
||||||
|
|
||||||
template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V> // D == head size
|
template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
|
||||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||||
__launch_bounds__(D, 1)
|
__launch_bounds__(D, 1)
|
||||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||||
|
@ -17,6 +17,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||||
const float m0,
|
const float m0,
|
||||||
const float m1,
|
const float m1,
|
||||||
const uint32_t n_head_log2,
|
const uint32_t n_head_log2,
|
||||||
|
const float logit_softcap,
|
||||||
const int ne00,
|
const int ne00,
|
||||||
const int ne01,
|
const int ne01,
|
||||||
const int ne02,
|
const int ne02,
|
||||||
|
@ -40,6 +41,12 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||||
const int ne1,
|
const int ne1,
|
||||||
const int ne2,
|
const int ne2,
|
||||||
const int ne3) {
|
const int ne3) {
|
||||||
|
// Skip unused kernel variants for faster compilation:
|
||||||
|
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||||
|
|
||||||
constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32<D>(type_K);
|
constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32<D>(type_K);
|
||||||
|
@ -180,6 +187,11 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||||
for (int j = 0; j < ncols; ++j) {
|
for (int j = 0; j < ncols; ++j) {
|
||||||
float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
|
float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
|
||||||
sum = warp_reduce_sum(sum);
|
sum = warp_reduce_sum(sum);
|
||||||
|
|
||||||
|
if (use_logit_softcap) {
|
||||||
|
sum = logit_softcap*tanhf(sum);
|
||||||
|
}
|
||||||
|
|
||||||
sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
|
sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
|
||||||
|
|
||||||
kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
|
kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
|
||||||
|
@ -267,10 +279,10 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V>
|
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
|
||||||
void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
constexpr int nwarps = D/WARP_SIZE;
|
constexpr int nwarps = D/WARP_SIZE;
|
||||||
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V>;
|
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
|
||||||
constexpr bool need_f16_K = D != 128;
|
constexpr bool need_f16_K = D != 128;
|
||||||
constexpr bool need_f16_V = D != 128 && D != 64;
|
constexpr bool need_f16_V = D != 128 && D != 64;
|
||||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
|
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
|
||||||
|
@ -278,44 +290,78 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
|
||||||
|
|
||||||
template <int D, ggml_type type_K, ggml_type type_V>
|
template <int D, ggml_type type_K, ggml_type type_V>
|
||||||
void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
ggml_tensor * Q = dst->src[0];
|
const ggml_tensor * KQV = dst;
|
||||||
ggml_tensor * K = dst->src[1];
|
const ggml_tensor * Q = dst->src[0];
|
||||||
ggml_tensor * V = dst->src[2];
|
const ggml_tensor * K = dst->src[1];
|
||||||
|
const ggml_tensor * V = dst->src[2];
|
||||||
|
|
||||||
GGML_ASSERT(K->type == type_K);
|
GGML_ASSERT(K->type == type_K);
|
||||||
GGML_ASSERT(V->type == type_V);
|
GGML_ASSERT(V->type == type_V);
|
||||||
|
|
||||||
|
float logit_softcap;
|
||||||
|
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||||
|
|
||||||
if (Q->ne[1] == 1) {
|
if (Q->ne[1] == 1) {
|
||||||
constexpr int cols_per_block = 1;
|
constexpr int cols_per_block = 1;
|
||||||
constexpr int parallel_blocks = 4;
|
constexpr int parallel_blocks = 4;
|
||||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
|
if (logit_softcap == 0.0f) {
|
||||||
|
constexpr bool use_logit_softcap = false;
|
||||||
|
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||||
|
} else {
|
||||||
|
constexpr bool use_logit_softcap = true;
|
||||||
|
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||||
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Q->ne[1] == 2) {
|
if (Q->ne[1] == 2) {
|
||||||
constexpr int cols_per_block = 2;
|
constexpr int cols_per_block = 2;
|
||||||
constexpr int parallel_blocks = 4;
|
constexpr int parallel_blocks = 4;
|
||||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
|
if (logit_softcap == 0.0f) {
|
||||||
|
constexpr bool use_logit_softcap = false;
|
||||||
|
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||||
|
} else {
|
||||||
|
constexpr bool use_logit_softcap = true;
|
||||||
|
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||||
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Q->ne[1] <= 4) {
|
if (Q->ne[1] <= 4) {
|
||||||
constexpr int cols_per_block = 4;
|
constexpr int cols_per_block = 4;
|
||||||
constexpr int parallel_blocks = 4;
|
constexpr int parallel_blocks = 4;
|
||||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
|
if (logit_softcap == 0.0f) {
|
||||||
|
constexpr bool use_logit_softcap = false;
|
||||||
|
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||||
|
} else {
|
||||||
|
constexpr bool use_logit_softcap = true;
|
||||||
|
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||||
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Q->ne[1] <= 8) {
|
if (Q->ne[1] <= 8) {
|
||||||
constexpr int cols_per_block = 8;
|
constexpr int cols_per_block = 8;
|
||||||
constexpr int parallel_blocks = 4;
|
constexpr int parallel_blocks = 4;
|
||||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
|
if (logit_softcap == 0.0f) {
|
||||||
|
constexpr bool use_logit_softcap = false;
|
||||||
|
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||||
|
} else {
|
||||||
|
constexpr bool use_logit_softcap = true;
|
||||||
|
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||||
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
constexpr int cols_per_block = 8;
|
constexpr int cols_per_block = 8;
|
||||||
constexpr int parallel_blocks = 1;
|
constexpr int parallel_blocks = 1;
|
||||||
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
|
if (logit_softcap == 0.0f) {
|
||||||
|
constexpr bool use_logit_softcap = false;
|
||||||
|
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||||
|
} else {
|
||||||
|
constexpr bool use_logit_softcap = true;
|
||||||
|
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#define DECL_FATTN_VEC_F32_CASE(D, type_K, type_V) \
|
#define DECL_FATTN_VEC_F32_CASE(D, type_K, type_V) \
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
#endif // FP16_MMA_AVAILABLE
|
#endif // FP16_MMA_AVAILABLE
|
||||||
|
|
||||||
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
|
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
|
||||||
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
|
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
|
||||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||||
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
||||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||||
|
@ -22,6 +22,7 @@ static __global__ void flash_attn_ext_f16(
|
||||||
const float m0,
|
const float m0,
|
||||||
const float m1,
|
const float m1,
|
||||||
const uint32_t n_head_log2,
|
const uint32_t n_head_log2,
|
||||||
|
const float logit_softcap,
|
||||||
const int ne00,
|
const int ne00,
|
||||||
const int ne01,
|
const int ne01,
|
||||||
const int ne02,
|
const int ne02,
|
||||||
|
@ -46,6 +47,12 @@ static __global__ void flash_attn_ext_f16(
|
||||||
const int ne2,
|
const int ne2,
|
||||||
const int ne3) {
|
const int ne3) {
|
||||||
#ifdef FP16_MMA_AVAILABLE
|
#ifdef FP16_MMA_AVAILABLE
|
||||||
|
// Skip unused kernel variants for faster compilation:
|
||||||
|
if (use_logit_softcap && !(D == 128 || D == 256)) {
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||||
|
|
||||||
const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
|
const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
|
||||||
|
@ -85,6 +92,8 @@ static __global__ void flash_attn_ext_f16(
|
||||||
const half slopeh = __float2half(slopef);
|
const half slopeh = __float2half(slopef);
|
||||||
const half2 slope2 = make_half2(slopef, slopef);
|
const half2 slope2 = make_half2(slopef, slopef);
|
||||||
|
|
||||||
|
const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap);
|
||||||
|
|
||||||
frag_b Q_b[D/16][ncols/frag_n];
|
frag_b Q_b[D/16][ncols/frag_n];
|
||||||
|
|
||||||
// A single buffer for temporarily holding tiles of KQ and VKQ parts:
|
// A single buffer for temporarily holding tiles of KQ and VKQ parts:
|
||||||
|
@ -194,6 +203,10 @@ static __global__ void flash_attn_ext_f16(
|
||||||
const int k = k0 + threadIdx.x;
|
const int k = k0 + threadIdx.x;
|
||||||
|
|
||||||
KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
|
KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
|
||||||
|
|
||||||
|
if (use_logit_softcap) {
|
||||||
|
KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
float KQ_max_new = KQ_max_f[j0/nwarps];
|
float KQ_max_new = KQ_max_f[j0/nwarps];
|
||||||
|
@ -237,6 +250,15 @@ static __global__ void flash_attn_ext_f16(
|
||||||
const int k = k0 + threadIdx.x;
|
const int k = k0 + threadIdx.x;
|
||||||
|
|
||||||
KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
|
KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
|
||||||
|
|
||||||
|
if (use_logit_softcap) {
|
||||||
|
// There is no dedicated tangens hyperbolicus function for half2.
|
||||||
|
KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
|
||||||
|
KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
|
||||||
|
/(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
|
||||||
|
|
||||||
|
KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
half2 KQ_max_new = KQ_max_h2[j0/nwarps];
|
half2 KQ_max_new = KQ_max_h2[j0/nwarps];
|
||||||
|
@ -427,7 +449,8 @@ static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
|
||||||
|
|
||||||
template <int D, int cols_per_block, typename KQ_acc_t>
|
template <int D, int cols_per_block, typename KQ_acc_t>
|
||||||
void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
const ggml_tensor * Q = dst->src[0];
|
const ggml_tensor * KQV = dst;
|
||||||
|
const ggml_tensor * Q = dst->src[0];
|
||||||
|
|
||||||
constexpr int nwarps = 4;
|
constexpr int nwarps = 4;
|
||||||
|
|
||||||
|
@ -435,20 +458,50 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
|
||||||
const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
|
const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
|
||||||
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
|
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
|
||||||
|
|
||||||
|
float logit_softcap;
|
||||||
|
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||||
|
|
||||||
if (4*blocks_num_pb1 < 2*nsm) {
|
if (4*blocks_num_pb1 < 2*nsm) {
|
||||||
constexpr int parallel_blocks = 4;
|
constexpr int parallel_blocks = 4;
|
||||||
fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
|
fattn_kernel_t fattn_kernel;
|
||||||
|
if (logit_softcap == 0.0f) {
|
||||||
|
constexpr bool use_logit_softcap = false;
|
||||||
|
fattn_kernel = flash_attn_ext_f16<
|
||||||
|
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||||
|
} else {
|
||||||
|
constexpr bool use_logit_softcap = true;
|
||||||
|
fattn_kernel = flash_attn_ext_f16<
|
||||||
|
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||||
|
}
|
||||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (2*blocks_num_pb1 < 2*nsm) {
|
if (2*blocks_num_pb1 < 2*nsm) {
|
||||||
constexpr int parallel_blocks = 2;
|
constexpr int parallel_blocks = 2;
|
||||||
fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
|
fattn_kernel_t fattn_kernel;
|
||||||
|
if (logit_softcap == 0.0f) {
|
||||||
|
constexpr bool use_logit_softcap = false;
|
||||||
|
fattn_kernel = flash_attn_ext_f16<
|
||||||
|
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||||
|
} else {
|
||||||
|
constexpr bool use_logit_softcap = true;
|
||||||
|
fattn_kernel = flash_attn_ext_f16<
|
||||||
|
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||||
|
}
|
||||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
constexpr int parallel_blocks = 1;
|
constexpr int parallel_blocks = 1;
|
||||||
fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
|
fattn_kernel_t fattn_kernel;
|
||||||
|
if (logit_softcap == 0.0f) {
|
||||||
|
constexpr bool use_logit_softcap = false;
|
||||||
|
fattn_kernel = flash_attn_ext_f16<
|
||||||
|
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||||
|
} else {
|
||||||
|
constexpr bool use_logit_softcap = true;
|
||||||
|
fattn_kernel = flash_attn_ext_f16<
|
||||||
|
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
||||||
|
}
|
||||||
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
|
||||||
const ggml_tensor * KQV = dst;
|
const ggml_tensor * KQV = dst;
|
||||||
const ggml_tensor * Q = dst->src[0];
|
const ggml_tensor * Q = dst->src[0];
|
||||||
|
|
||||||
const int32_t precision = KQV->op_params[2];
|
const int32_t precision = KQV->op_params[3];
|
||||||
|
|
||||||
if (precision != GGML_PREC_DEFAULT) {
|
if (precision != GGML_PREC_DEFAULT) {
|
||||||
if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
|
if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
|
||||||
|
@ -301,7 +301,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||||
|
|
||||||
ggml_cuda_set_device(ctx.device);
|
ggml_cuda_set_device(ctx.device);
|
||||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||||
const int32_t precision = KQV->op_params[2];
|
const int32_t precision = KQV->op_params[3];
|
||||||
|
|
||||||
// On AMD the tile kernels perform poorly, use the vec kernel instead:
|
// On AMD the tile kernels perform poorly, use the vec kernel instead:
|
||||||
if (cc >= CC_OFFSET_AMD) {
|
if (cc >= CC_OFFSET_AMD) {
|
||||||
|
|
|
@ -802,6 +802,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
|
||||||
if (op->src[0]->ne[0] == 256) {
|
if (op->src[0]->ne[0] == 256) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
{
|
||||||
|
float logit_softcap;
|
||||||
|
|
||||||
|
memcpy(&logit_softcap, ((const float *) op->op_params) + 2, sizeof(logit_softcap));
|
||||||
|
|
||||||
|
if (logit_softcap != 0.0f) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
||||||
case GGML_OP_MUL_MAT:
|
case GGML_OP_MUL_MAT:
|
||||||
case GGML_OP_MUL_MAT_ID:
|
case GGML_OP_MUL_MAT_ID:
|
||||||
|
|
|
@ -38,6 +38,7 @@
|
||||||
|
|
||||||
#include "ggml-sycl/backend.hpp"
|
#include "ggml-sycl/backend.hpp"
|
||||||
#include "ggml-sycl/presets.hpp"
|
#include "ggml-sycl/presets.hpp"
|
||||||
|
#include "ggml-sycl/gemm.hpp"
|
||||||
|
|
||||||
bool ggml_sycl_loaded(void);
|
bool ggml_sycl_loaded(void);
|
||||||
void ggml_sycl_free_data(struct ggml_tensor * tensor);
|
void ggml_sycl_free_data(struct ggml_tensor * tensor);
|
||||||
|
@ -2482,6 +2483,7 @@ inline void ggml_sycl_op_mul_mat_sycl(
|
||||||
|
|
||||||
const sycl::half alpha_f16 = 1.0f;
|
const sycl::half alpha_f16 = 1.0f;
|
||||||
const sycl::half beta_f16 = 0.0f;
|
const sycl::half beta_f16 = 0.0f;
|
||||||
|
#if !GGML_SYCL_DNNL
|
||||||
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
|
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
|
||||||
*stream, oneapi::mkl::transpose::trans,
|
*stream, oneapi::mkl::transpose::trans,
|
||||||
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
|
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
|
||||||
|
@ -2491,6 +2493,13 @@ inline void ggml_sycl_op_mul_mat_sycl(
|
||||||
dpct::library_data_t::real_half)));
|
dpct::library_data_t::real_half)));
|
||||||
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
|
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
|
||||||
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
|
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
|
||||||
|
#else
|
||||||
|
auto dnnl_stream = ctx.stream_dnnl(stream);
|
||||||
|
DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
|
||||||
|
src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
|
||||||
|
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
|
||||||
|
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
// GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
|
// GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
|
||||||
|
@ -2513,13 +2522,18 @@ inline void ggml_sycl_op_mul_mat_sycl(
|
||||||
|
|
||||||
const float alpha = 1.0f;
|
const float alpha = 1.0f;
|
||||||
const float beta = 0.0f;
|
const float beta = 0.0f;
|
||||||
|
#if !GGML_SYCL_DNNL
|
||||||
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
|
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
|
||||||
*stream, oneapi::mkl::transpose::trans,
|
*stream, oneapi::mkl::transpose::trans,
|
||||||
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
|
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
|
||||||
dpct::get_value(&alpha, *stream), src0_ddf_i, ne00,
|
dpct::get_value(&alpha, *stream), src0_ddf_i, ne00,
|
||||||
src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
|
src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
|
||||||
dst_dd_i, ldc)));
|
dst_dd_i, ldc)));
|
||||||
|
#else
|
||||||
|
auto dnnl_stream = ctx.stream_dnnl(stream);
|
||||||
|
DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
|
||||||
|
src0_ddf_i, DnnlGemmWrapper::to_dt<float>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>());
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
(void) dst;
|
(void) dst;
|
||||||
(void) src1_ddq_i;
|
(void) src1_ddq_i;
|
||||||
|
|
|
@ -19,6 +19,10 @@
|
||||||
#include "dpct/helper.hpp"
|
#include "dpct/helper.hpp"
|
||||||
#include "ggml-sycl.h"
|
#include "ggml-sycl.h"
|
||||||
#include "presets.hpp"
|
#include "presets.hpp"
|
||||||
|
#if GGML_SYCL_DNNL
|
||||||
|
#include "dnnl.hpp"
|
||||||
|
#include "dnnl_sycl.hpp"
|
||||||
|
#endif
|
||||||
|
|
||||||
#define GGML_COMMON_DECL_SYCL
|
#define GGML_COMMON_DECL_SYCL
|
||||||
#define GGML_COMMON_IMPL_SYCL
|
#define GGML_COMMON_IMPL_SYCL
|
||||||
|
@ -277,6 +281,52 @@ struct ggml_backend_sycl_context {
|
||||||
return stream(device, 0);
|
return stream(device, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if GGML_SYCL_DNNL
|
||||||
|
dnnl::engine make_engine(sycl::queue* q) {
|
||||||
|
// Get the device associated with the queue
|
||||||
|
sycl::device dev = q->get_device();
|
||||||
|
// Get the context associated with the queue
|
||||||
|
sycl::context ctx = q->get_context();
|
||||||
|
const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
|
||||||
|
return eng;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unordered_map<sycl::queue*, dnnl::stream> stream_map;
|
||||||
|
std::unordered_map<sycl::queue*, dnnl::engine> engine_map;
|
||||||
|
dnnl::stream stream_dnnl(int device, int _stream) {
|
||||||
|
auto q = stream(device, _stream);
|
||||||
|
return stream_dnnl(q);
|
||||||
|
}
|
||||||
|
dnnl::engine engine_dnnl(sycl::queue* qptr) {
|
||||||
|
auto it = engine_map.find(qptr);
|
||||||
|
if (it == engine_map.end()) {
|
||||||
|
auto eng = make_engine(qptr);
|
||||||
|
engine_map[qptr] = eng;
|
||||||
|
return eng;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dnnl::stream stream_dnnl(sycl::queue* qptr) {
|
||||||
|
auto it = stream_map.find(qptr);
|
||||||
|
if (it == stream_map.end()) {
|
||||||
|
auto eng = engine_dnnl(qptr);
|
||||||
|
auto stream = dnnl::sycl_interop::make_stream(eng, *qptr);
|
||||||
|
stream_map[qptr] = stream;
|
||||||
|
return stream;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dnnl::stream stream_dnnl() {
|
||||||
|
return stream_dnnl(device, 0);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
// pool
|
// pool
|
||||||
std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
|
std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
|
||||||
|
|
||||||
|
|
101
ggml/src/ggml-sycl/gemm.hpp
Normal file
101
ggml/src/ggml-sycl/gemm.hpp
Normal file
|
@ -0,0 +1,101 @@
|
||||||
|
//
|
||||||
|
// MIT license
|
||||||
|
// Copyright (C) 2024 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: MIT
|
||||||
|
//
|
||||||
|
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef GGML_SYCL_GEMM_HPP
|
||||||
|
#define GGML_SYCL_GEMM_HPP
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
#include "ggml-sycl.h"
|
||||||
|
|
||||||
|
#if GGML_SYCL_DNNL
|
||||||
|
|
||||||
|
#include "dnnl.hpp"
|
||||||
|
#include "dnnl_sycl.hpp"
|
||||||
|
|
||||||
|
class DnnlGemmWrapper {
|
||||||
|
public:
|
||||||
|
using dt = dnnl::memory::data_type;
|
||||||
|
using tag = dnnl::memory::format_tag;
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
static constexpr dt to_dt() {
|
||||||
|
if constexpr (std::is_same_v<T, float>) return dt::f32;
|
||||||
|
else if constexpr (std::is_same_v<T, sycl::half>) return dt::f16;
|
||||||
|
else static_assert(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void row_gemm(sycl::queue& q, bool a_trans,
|
||||||
|
bool b_trans, int m, int n, int k,
|
||||||
|
const void* a, dt at, const void* b, dt bt, void* c, dt ct)
|
||||||
|
{
|
||||||
|
// Get the device associated with the queue
|
||||||
|
sycl::device dev = q.get_device();
|
||||||
|
// Get the context associated with the queue
|
||||||
|
sycl::context ctx = q.get_context();
|
||||||
|
const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
|
||||||
|
const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q);
|
||||||
|
dnnl::memory::dims a_dims = { m, k };
|
||||||
|
dnnl::memory::dims b_dims = { k, n };
|
||||||
|
dnnl::memory::dims c_dims = { m, n };
|
||||||
|
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
|
||||||
|
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
|
||||||
|
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
|
||||||
|
auto a_mem = dnnl::memory(a_in_md, eng, (void*)a);
|
||||||
|
auto b_mem = dnnl::memory(b_in_md, eng, (void*)b);
|
||||||
|
auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
|
||||||
|
auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
|
||||||
|
|
||||||
|
// Create the primitive.
|
||||||
|
auto matmul_prim = dnnl::matmul(matmul_pd);
|
||||||
|
// Primitive arguments.
|
||||||
|
std::unordered_map<int, dnnl::memory> matmul_args;
|
||||||
|
matmul_args.insert({ DNNL_ARG_SRC, a_mem });
|
||||||
|
matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
|
||||||
|
matmul_args.insert({ DNNL_ARG_DST, c_mem });
|
||||||
|
|
||||||
|
matmul_prim.execute(stream, matmul_args);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static inline void row_gemm(const dnnl::stream& stream, bool a_trans,
|
||||||
|
bool b_trans, int m, int n, int k,
|
||||||
|
const void* a, dt at, const void* b, dt bt, void* c, dt ct)
|
||||||
|
{
|
||||||
|
auto const eng = stream.get_engine();
|
||||||
|
dnnl::memory::dims a_dims = { m, k };
|
||||||
|
dnnl::memory::dims b_dims = { k, n };
|
||||||
|
dnnl::memory::dims c_dims = { m, n };
|
||||||
|
const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
|
||||||
|
const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
|
||||||
|
const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
|
||||||
|
auto a_mem = dnnl::memory(a_in_md, eng, (void*)a);
|
||||||
|
auto b_mem = dnnl::memory(b_in_md, eng, (void*)b);
|
||||||
|
auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
|
||||||
|
auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
|
||||||
|
|
||||||
|
// Create the primitive.
|
||||||
|
auto matmul_prim = dnnl::matmul(matmul_pd);
|
||||||
|
// Primitive arguments.
|
||||||
|
std::unordered_map<int, dnnl::memory> matmul_args;
|
||||||
|
matmul_args.insert({ DNNL_ARG_SRC, a_mem });
|
||||||
|
matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
|
||||||
|
matmul_args.insert({ DNNL_ARG_DST, c_mem });
|
||||||
|
|
||||||
|
matmul_prim.execute(stream, matmul_args);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif // GGML_SYCL_GEMM_HPP
|
306
ggml/src/ggml.c
306
ggml/src/ggml.c
|
@ -7095,7 +7095,8 @@ struct ggml_tensor * ggml_flash_attn_ext(
|
||||||
struct ggml_tensor * v,
|
struct ggml_tensor * v,
|
||||||
struct ggml_tensor * mask,
|
struct ggml_tensor * mask,
|
||||||
float scale,
|
float scale,
|
||||||
float max_bias) {
|
float max_bias,
|
||||||
|
float logit_softcap) {
|
||||||
GGML_ASSERT(ggml_can_mul_mat(k, q));
|
GGML_ASSERT(ggml_can_mul_mat(k, q));
|
||||||
// TODO: check if vT can be multiplied by (k*qT)
|
// TODO: check if vT can be multiplied by (k*qT)
|
||||||
|
|
||||||
|
@ -7122,7 +7123,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
|
||||||
int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
|
int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
|
||||||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
||||||
|
|
||||||
float params[] = { scale, max_bias };
|
float params[] = { scale, max_bias, logit_softcap };
|
||||||
ggml_set_op_params(result, params, sizeof(params));
|
ggml_set_op_params(result, params, sizeof(params));
|
||||||
|
|
||||||
result->op = GGML_OP_FLASH_ATTN_EXT;
|
result->op = GGML_OP_FLASH_ATTN_EXT;
|
||||||
|
@ -7142,7 +7143,7 @@ void ggml_flash_attn_ext_set_prec(
|
||||||
|
|
||||||
const int32_t prec_i32 = (int32_t) prec;
|
const int32_t prec_i32 = (int32_t) prec;
|
||||||
|
|
||||||
ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
|
ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second
|
||||||
}
|
}
|
||||||
|
|
||||||
// ggml_flash_attn_back
|
// ggml_flash_attn_back
|
||||||
|
@ -7229,43 +7230,34 @@ struct ggml_tensor * ggml_flash_attn_back(
|
||||||
|
|
||||||
struct ggml_tensor * ggml_ssm_conv(
|
struct ggml_tensor * ggml_ssm_conv(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * s,
|
struct ggml_tensor * sx,
|
||||||
struct ggml_tensor * x,
|
struct ggml_tensor * c) {
|
||||||
struct ggml_tensor * c,
|
GGML_ASSERT(ggml_is_3d(sx));
|
||||||
struct ggml_tensor * sq) {
|
|
||||||
GGML_ASSERT(ggml_is_3d(s));
|
|
||||||
GGML_ASSERT(ggml_is_matrix(x));
|
|
||||||
GGML_ASSERT(ggml_is_matrix(c));
|
GGML_ASSERT(ggml_is_matrix(c));
|
||||||
GGML_ASSERT(ggml_is_matrix(sq));
|
|
||||||
GGML_ASSERT(sq->type == GGML_TYPE_I32);
|
|
||||||
|
|
||||||
const int64_t d_conv = c->ne[0];
|
const int64_t d_conv = c->ne[0];
|
||||||
const int64_t d_inner = c->ne[1];
|
const int64_t d_inner = c->ne[1];
|
||||||
const int64_t n_tokens = x->ne[1];
|
const int64_t n_t = sx->ne[0] - d_conv + 1; // tokens per sequence
|
||||||
const int64_t n_kv = s->ne[2];
|
const int64_t n_s = sx->ne[2];
|
||||||
|
|
||||||
GGML_ASSERT( s->ne[0] == d_conv - 1);
|
// TODO: maybe support other strides than 1?
|
||||||
GGML_ASSERT( s->ne[1] == d_inner);
|
GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
|
||||||
GGML_ASSERT( x->ne[0] == d_inner);
|
GGML_ASSERT(sx->ne[1] == d_inner);
|
||||||
GGML_ASSERT(sq->ne[0] == n_kv);
|
GGML_ASSERT(n_t >= 0);
|
||||||
GGML_ASSERT(sq->ne[1] == n_tokens);
|
|
||||||
|
|
||||||
bool is_node = false;
|
bool is_node = false;
|
||||||
|
|
||||||
if (s->grad || x->grad || c->grad || sq->grad) {
|
if (sx->grad || c->grad) {
|
||||||
GGML_ABORT("fatal error"); // TODO: implement
|
GGML_ABORT("fatal error"); // TODO: implement
|
||||||
is_node = true;
|
is_node = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2-in-1 concatenated x and conv_states, {d_inner, n_tokens} with {d_conv, d_inner, n_kv}
|
struct ggml_tensor * result = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_t, n_s);
|
||||||
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, (d_inner*n_tokens) + (d_conv*d_inner*n_kv));
|
|
||||||
|
|
||||||
result->op = GGML_OP_SSM_CONV;
|
result->op = GGML_OP_SSM_CONV;
|
||||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||||
result->src[0] = s;
|
result->src[0] = sx;
|
||||||
result->src[1] = x;
|
result->src[1] = c;
|
||||||
result->src[2] = c;
|
|
||||||
result->src[3] = sq;
|
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -7279,39 +7271,42 @@ struct ggml_tensor * ggml_ssm_scan(
|
||||||
struct ggml_tensor * dt,
|
struct ggml_tensor * dt,
|
||||||
struct ggml_tensor * A,
|
struct ggml_tensor * A,
|
||||||
struct ggml_tensor * B,
|
struct ggml_tensor * B,
|
||||||
struct ggml_tensor * C,
|
struct ggml_tensor * C) {
|
||||||
struct ggml_tensor * sq) {
|
|
||||||
GGML_ASSERT(ggml_is_contiguous(s));
|
GGML_ASSERT(ggml_is_contiguous(s));
|
||||||
GGML_ASSERT(ggml_is_contiguous(x));
|
GGML_ASSERT(ggml_is_contiguous(x));
|
||||||
GGML_ASSERT(ggml_is_contiguous(dt));
|
GGML_ASSERT(ggml_is_contiguous(dt));
|
||||||
GGML_ASSERT(ggml_is_contiguous(A));
|
GGML_ASSERT(ggml_is_contiguous(A));
|
||||||
GGML_ASSERT(sq->type == GGML_TYPE_I32);
|
GGML_ASSERT(ggml_is_matrix(A));
|
||||||
|
GGML_ASSERT(ggml_is_3d(B));
|
||||||
|
GGML_ASSERT(ggml_is_3d(s));
|
||||||
GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
|
GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
|
||||||
GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
|
GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
|
||||||
GGML_ASSERT(ggml_are_same_shape(x, dt));
|
GGML_ASSERT(ggml_are_same_shape(x, dt));
|
||||||
|
GGML_ASSERT(ggml_are_same_shape(B, C));
|
||||||
|
|
||||||
{
|
{
|
||||||
const int64_t d_state = s->ne[0];
|
const int64_t d_state = s->ne[0];
|
||||||
const int64_t d_inner = s->ne[1];
|
const int64_t d_inner = s->ne[1];
|
||||||
const int64_t n_tokens = x->ne[1];
|
const int64_t n_seq_tokens = x->ne[1];
|
||||||
|
const int64_t n_seqs = x->ne[2];
|
||||||
|
|
||||||
|
GGML_ASSERT(s->ne[2] == n_seqs);
|
||||||
GGML_ASSERT(x->ne[0] == d_inner);
|
GGML_ASSERT(x->ne[0] == d_inner);
|
||||||
GGML_ASSERT(A->ne[0] == d_state);
|
GGML_ASSERT(A->ne[0] == d_state);
|
||||||
GGML_ASSERT(A->ne[1] == d_inner);
|
GGML_ASSERT(A->ne[1] == d_inner);
|
||||||
GGML_ASSERT(B->ne[0] == d_state);
|
GGML_ASSERT(B->ne[0] == d_state);
|
||||||
GGML_ASSERT(B->ne[1] == n_tokens);
|
GGML_ASSERT(B->ne[1] == n_seq_tokens);
|
||||||
GGML_ASSERT(C->ne[0] == d_state);
|
GGML_ASSERT(B->ne[2] == n_seqs);
|
||||||
GGML_ASSERT(C->ne[1] == n_tokens);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_node = false;
|
bool is_node = false;
|
||||||
|
|
||||||
if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad || sq->grad) {
|
if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad) {
|
||||||
GGML_ABORT("fatal error"); // TODO: implement
|
GGML_ABORT("fatal error"); // TODO: implement
|
||||||
is_node = true;
|
is_node = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2-in-1 concatenated y and ssm_states, {d_inner, n_tokens} with {d_state, d_inner, n_kv}
|
// concatenated y + ssm_states
|
||||||
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
|
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s));
|
||||||
|
|
||||||
result->op = GGML_OP_SSM_SCAN;
|
result->op = GGML_OP_SSM_SCAN;
|
||||||
|
@ -7322,7 +7317,6 @@ struct ggml_tensor * ggml_ssm_scan(
|
||||||
result->src[3] = A;
|
result->src[3] = A;
|
||||||
result->src[4] = B;
|
result->src[4] = B;
|
||||||
result->src[5] = C;
|
result->src[5] = C;
|
||||||
result->src[6] = sq;
|
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
@ -10995,11 +10989,6 @@ static void ggml_compute_forward_concat_f32(
|
||||||
|
|
||||||
GGML_TENSOR_BINARY_OP_LOCALS
|
GGML_TENSOR_BINARY_OP_LOCALS
|
||||||
|
|
||||||
// TODO: support for transposed / permuted tensors
|
|
||||||
GGML_ASSERT(nb0 == sizeof(float));
|
|
||||||
GGML_ASSERT(nb00 == sizeof(float));
|
|
||||||
GGML_ASSERT(nb10 == sizeof(float));
|
|
||||||
|
|
||||||
const int32_t dim = ggml_get_op_params_i32(dst, 0);
|
const int32_t dim = ggml_get_op_params_i32(dst, 0);
|
||||||
|
|
||||||
GGML_ASSERT(dim >= 0 && dim < 4);
|
GGML_ASSERT(dim >= 0 && dim < 4);
|
||||||
|
@ -15283,11 +15272,17 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||||
const int ir0 = dr*ith;
|
const int ir0 = dr*ith;
|
||||||
const int ir1 = MIN(ir0 + dr, nr);
|
const int ir1 = MIN(ir0 + dr, nr);
|
||||||
|
|
||||||
float scale = 1.0f;
|
float scale = 1.0f;
|
||||||
float max_bias = 0.0f;
|
float max_bias = 0.0f;
|
||||||
|
float logit_softcap = 0.0f;
|
||||||
|
|
||||||
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
||||||
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
|
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
|
||||||
|
memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
|
||||||
|
|
||||||
|
if (logit_softcap != 0) {
|
||||||
|
scale /= logit_softcap;
|
||||||
|
}
|
||||||
|
|
||||||
const uint32_t n_head = neq2;
|
const uint32_t n_head = neq2;
|
||||||
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
|
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
|
||||||
|
@ -15351,7 +15346,13 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||||
const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
|
const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
|
||||||
kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
|
kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
|
||||||
|
|
||||||
s = s*scale + mv; // scale KQ value and apply mask
|
s = s*scale; // scale KQ value
|
||||||
|
|
||||||
|
if (logit_softcap != 0.0f) {
|
||||||
|
s = logit_softcap*tanhf(s);
|
||||||
|
}
|
||||||
|
|
||||||
|
s += mv; // apply mask
|
||||||
|
|
||||||
const float Mold = M;
|
const float Mold = M;
|
||||||
|
|
||||||
|
@ -15360,7 +15361,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
||||||
|
|
||||||
const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
|
const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
|
||||||
|
|
||||||
if (v->type== GGML_TYPE_F16) {
|
if (v->type == GGML_TYPE_F16) {
|
||||||
if (s > M) {
|
if (s > M) {
|
||||||
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
|
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
|
||||||
M = s;
|
M = s;
|
||||||
|
@ -15427,7 +15428,7 @@ static void ggml_compute_forward_flash_attn_ext(
|
||||||
const struct ggml_tensor * v,
|
const struct ggml_tensor * v,
|
||||||
const struct ggml_tensor * mask,
|
const struct ggml_tensor * mask,
|
||||||
struct ggml_tensor * dst) {
|
struct ggml_tensor * dst) {
|
||||||
switch (dst->op_params[2]) {
|
switch (dst->op_params[3]) {
|
||||||
case GGML_PREC_DEFAULT:
|
case GGML_PREC_DEFAULT:
|
||||||
case GGML_PREC_F32:
|
case GGML_PREC_F32:
|
||||||
{
|
{
|
||||||
|
@ -15782,27 +15783,22 @@ static void ggml_compute_forward_flash_attn_back(
|
||||||
static void ggml_compute_forward_ssm_conv_f32(
|
static void ggml_compute_forward_ssm_conv_f32(
|
||||||
const struct ggml_compute_params * params,
|
const struct ggml_compute_params * params,
|
||||||
struct ggml_tensor * dst) {
|
struct ggml_tensor * dst) {
|
||||||
const struct ggml_tensor * src0 = dst->src[0]; // conv_state
|
const struct ggml_tensor * src0 = dst->src[0]; // conv_x
|
||||||
const struct ggml_tensor * src1 = dst->src[1]; // x
|
const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight
|
||||||
const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight
|
|
||||||
const struct ggml_tensor * src3 = dst->src[3]; // state_seq
|
|
||||||
|
|
||||||
const int ith = params->ith;
|
const int ith = params->ith;
|
||||||
const int nth = params->nth;
|
const int nth = params->nth;
|
||||||
|
|
||||||
const int nc = src2->ne[0]; // d_conv
|
const int nc = src1->ne[0]; // d_conv
|
||||||
const int nr = src0->ne[1]; // d_inner
|
const int ncs = src0->ne[0]; // d_conv - 1 + n_t
|
||||||
const int n_t = src1->ne[1]; // n_tokens
|
const int nr = src0->ne[1]; // d_inner
|
||||||
const int n_kv = src0->ne[2]; // max number of sequences in the batch
|
const int n_t = dst->ne[1]; // tokens per sequence
|
||||||
|
const int n_s = dst->ne[2]; // number of sequences in the batch
|
||||||
|
|
||||||
GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst));
|
GGML_ASSERT( dst->ne[0] == nr);
|
||||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||||
GGML_ASSERT(src1->nb[0] == sizeof(float));
|
GGML_ASSERT(src1->nb[0] == sizeof(float));
|
||||||
GGML_ASSERT(src2->nb[0] == sizeof(float));
|
|
||||||
GGML_ASSERT(src3->nb[0] == sizeof(int32_t));
|
|
||||||
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
|
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
|
||||||
// for use with the destination state offset between sequences
|
|
||||||
GGML_ASSERT(src2->nb[2] == src2->ne[1]*src2->ne[0]*sizeof(float));
|
|
||||||
|
|
||||||
// rows per thread
|
// rows per thread
|
||||||
const int dr = (nr + nth - 1)/nth;
|
const int dr = (nr + nth - 1)/nth;
|
||||||
|
@ -15812,76 +15808,29 @@ static void ggml_compute_forward_ssm_conv_f32(
|
||||||
const int ir1 = MIN(ir0 + dr, nr);
|
const int ir1 = MIN(ir0 + dr, nr);
|
||||||
const int ir = ir1 - ir0;
|
const int ir = ir1 - ir0;
|
||||||
|
|
||||||
if (n_kv > 1) {
|
for (int i3 = 0; i3 < n_s; ++i3) {
|
||||||
// multiple sequences means it's hard to know when it's the first time a state is read,
|
for (int i2 = 0; i2 < n_t; ++i2) {
|
||||||
// so copy them all over to the destination, just to be sure.
|
// {d_conv - 1 + n_t, d_inner, n_seqs}
|
||||||
for (int i3 = 0; i3 < n_kv; ++i3) {
|
// sliding window
|
||||||
float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
|
const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
|
||||||
float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + i3*(src2->nb[2]) + nr*n_t*sizeof(float));
|
const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner}
|
||||||
// can't use memcpy because of d_conv vs d_conv - 1
|
float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s}
|
||||||
|
|
||||||
|
// TODO: transpose the output for smaller strides for big batches?
|
||||||
|
// d_inner
|
||||||
for (int i1 = 0; i1 < ir; ++i1) {
|
for (int i1 = 0; i1 < ir; ++i1) {
|
||||||
for (int i0 = 0; i0 < nc - 1; ++i0) {
|
// rowwise dot product
|
||||||
// copy s0 to last (d_conv - 1) columns of s
|
// NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
|
||||||
s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)];
|
float sumf = 0.0f;
|
||||||
|
|
||||||
|
// d_conv
|
||||||
|
for (int i0 = 0; i0 < nc; ++i0) {
|
||||||
|
sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
|
||||||
}
|
}
|
||||||
|
x[i1] = sumf;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i2 = 0; i2 < n_t; ++i2) {
|
|
||||||
int32_t * sq = (int32_t *) ((char *) src3->data + i2*(src3->nb[1])); // {n_kv, n_tokens}
|
|
||||||
float * x = (float *) ((char *) dst->data + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens}
|
|
||||||
float * s = (float *) ((char *) dst->data + ir0*(src2->nb[1]) + sq[0]*(src2->nb[2]) + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv}
|
|
||||||
float * s0; // {d_conv - 1, d_inner, n_kv}
|
|
||||||
float * x0 = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
|
|
||||||
float * c = (float *) ((char *) src2->data + ir0*(src2->nb[1])); // {d_conv, d_inner}
|
|
||||||
int ne0s0;
|
|
||||||
|
|
||||||
GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
|
|
||||||
|
|
||||||
// avoid needing to copy the state for the first token
|
|
||||||
if (i2 == 0) {
|
|
||||||
s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_conv - 1, d_inner, n_kv}
|
|
||||||
ne0s0 = src0->ne[0];
|
|
||||||
} else {
|
|
||||||
// the source is the last (d_conv - 1) columns of the destination
|
|
||||||
s0 = s + 1;
|
|
||||||
ne0s0 = nc;
|
|
||||||
}
|
|
||||||
|
|
||||||
// d_inner
|
|
||||||
for (int i1 = 0; i1 < ir; ++i1) {
|
|
||||||
// shift state left
|
|
||||||
for (int i0 = 0; i0 < nc - 1; ++i0) {
|
|
||||||
s[i0 + i1*nc] = s0[i0 + i1*ne0s0];
|
|
||||||
}
|
|
||||||
// insert x on the last column
|
|
||||||
s[(nc - 1) + i1*nc] = x0[i1];
|
|
||||||
}
|
|
||||||
|
|
||||||
// handle copies when there are multiple output states
|
|
||||||
for (int i3 = 1; i3 < n_kv; ++i3) {
|
|
||||||
int32_t seq = sq[i3];
|
|
||||||
if (0 <= seq && seq < n_kv) {
|
|
||||||
float * s1 = s + (seq - sq[0])*nc*nr;
|
|
||||||
memcpy(s1, s, nc*ir*sizeof(float));
|
|
||||||
} else {
|
|
||||||
// stop at negative or too big seq_ids
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// it seems a little faster when this is separate from the state shift
|
|
||||||
for (int i1 = 0; i1 < ir; ++i1) {
|
|
||||||
// rowwise dot product
|
|
||||||
float sumf = 0.0f;
|
|
||||||
for (int i0 = 0; i0 < nc; ++i0) {
|
|
||||||
int i = i0 + i1*nc;
|
|
||||||
sumf += s[i] * c[i];
|
|
||||||
}
|
|
||||||
x[i1] = sumf;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_compute_forward_ssm_conv(
|
static void ggml_compute_forward_ssm_conv(
|
||||||
|
@ -15910,15 +15859,14 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||||
const struct ggml_tensor * src3 = dst->src[3]; // A
|
const struct ggml_tensor * src3 = dst->src[3]; // A
|
||||||
const struct ggml_tensor * src4 = dst->src[4]; // B
|
const struct ggml_tensor * src4 = dst->src[4]; // B
|
||||||
const struct ggml_tensor * src5 = dst->src[5]; // C
|
const struct ggml_tensor * src5 = dst->src[5]; // C
|
||||||
const struct ggml_tensor * src6 = dst->src[6]; // sq
|
|
||||||
|
|
||||||
const int ith = params->ith;
|
const int ith = params->ith;
|
||||||
const int nth = params->nth;
|
const int nth = params->nth;
|
||||||
|
|
||||||
const int64_t nc = src0->ne[0]; // d_state
|
const int64_t nc = src0->ne[0]; // d_state
|
||||||
const int64_t nr = src0->ne[1]; // d_inner
|
const int64_t nr = src0->ne[1]; // d_inner
|
||||||
const int64_t n_t = src1->ne[1]; // number of tokens in the batch
|
const int64_t n_t = src1->ne[1]; // number of tokens per sequence
|
||||||
const int64_t n_kv = src0->ne[2]; // max number of sequences in the batch
|
const int64_t n_s = src0->ne[2]; // number of sequences in the batch
|
||||||
|
|
||||||
GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
|
GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
|
||||||
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
GGML_ASSERT(src0->nb[0] == sizeof(float));
|
||||||
|
@ -15927,12 +15875,12 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||||
GGML_ASSERT(src3->nb[0] == sizeof(float));
|
GGML_ASSERT(src3->nb[0] == sizeof(float));
|
||||||
GGML_ASSERT(src4->nb[0] == sizeof(float));
|
GGML_ASSERT(src4->nb[0] == sizeof(float));
|
||||||
GGML_ASSERT(src5->nb[0] == sizeof(float));
|
GGML_ASSERT(src5->nb[0] == sizeof(float));
|
||||||
// required for the dot product between s and C, and when copying the states
|
// required for the dot product between s and C
|
||||||
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
|
GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
|
||||||
// required for per-sequence offsets for states
|
// required for per-sequence offsets for states
|
||||||
GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
|
GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
|
||||||
// required to get correct offset for state destination (i.e. src1->nb[2])
|
// required to get correct offset for state destination (i.e. src1->nb[3])
|
||||||
GGML_ASSERT(src1->nb[2] == src1->ne[0]*src1->ne[1]*sizeof(float));
|
GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
|
||||||
|
|
||||||
// rows per thread
|
// rows per thread
|
||||||
const int dr = (nr + nth - 1)/nth;
|
const int dr = (nr + nth - 1)/nth;
|
||||||
|
@ -15942,64 +15890,36 @@ static void ggml_compute_forward_ssm_scan_f32(
|
||||||
const int ir1 = MIN(ir0 + dr, nr);
|
const int ir1 = MIN(ir0 + dr, nr);
|
||||||
const int ir = ir1 - ir0;
|
const int ir = ir1 - ir0;
|
||||||
|
|
||||||
if (n_kv > 1) {
|
for (int i3 = 0; i3 < n_s; ++i3) {
|
||||||
// it's hard to know if the source states have already been copied
|
for (int i2 = 0; i2 < n_t; ++i2) {
|
||||||
// when there are multiple, so copy them already.
|
const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
|
||||||
for (int i3 = 0; i3 < n_kv; ++i3) {
|
const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
|
||||||
float * s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]));
|
const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
|
||||||
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[2]);
|
const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
|
||||||
memcpy(s, s0, nc*ir*sizeof(float));
|
const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
|
||||||
}
|
const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
|
||||||
}
|
float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
|
||||||
|
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
|
||||||
|
|
||||||
for (int i2 = 0; i2 < n_t; ++i2) {
|
// use the output as the source for the next token-wise iterations
|
||||||
int32_t * sq = (int32_t *) ((char *) src6->data + i2*(src6->nb[1])); // {n_kv, n_tokens}
|
if (i2 > 0) { s0 = s; }
|
||||||
float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
|
|
||||||
float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2]) + src1->nb[2]); // {d_state, d_inner, n_kv}
|
|
||||||
float * s0;
|
|
||||||
float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1])); // {d_inner, n_tokens}
|
|
||||||
float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1])); // {d_inner, n_tokens}
|
|
||||||
float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
|
|
||||||
float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens}
|
|
||||||
float * C = (float *) ((char *) src5->data + i2*(src5->nb[1])); // {d_state, n_tokens}
|
|
||||||
|
|
||||||
GGML_ASSERT(0 <= sq[0] && sq[0] < n_kv);
|
// d_inner
|
||||||
|
for (int i1 = 0; i1 < ir; ++i1) {
|
||||||
// avoid needing to copy the state for the first token
|
// ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
|
||||||
if (i2 == 0) {
|
float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
|
||||||
s0 = (float *) ((char *) src0->data + ir0*(src0->nb[1]) + sq[0]*(src0->nb[2])); // {d_state, d_inner, n_kv}
|
float x_dt = x[i1] * dt_soft_plus;
|
||||||
} else {
|
float sumf = 0.0f;
|
||||||
// otherwise the source is the same as the destination
|
// d_state
|
||||||
s0 = s;
|
for (int i0 = 0; i0 < nc; ++i0) {
|
||||||
}
|
int i = i0 + i1*nc;
|
||||||
|
// state = prev_state * dA + dB * x
|
||||||
// d_inner
|
float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
|
||||||
for (int i1 = 0; i1 < ir; ++i1) {
|
// y = rowwise_dotprod(state, C)
|
||||||
// ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
|
sumf += state * C[i0];
|
||||||
float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
|
s[i] = state;
|
||||||
float x_dt = x[i1] * dt_soft_plus;
|
}
|
||||||
float sumf = 0.0f;
|
y[i1] = sumf;
|
||||||
// d_state
|
|
||||||
for (int i0 = 0; i0 < nc; ++i0) {
|
|
||||||
int i = i0 + i1*nc;
|
|
||||||
// state = prev_state * dA + dB * x
|
|
||||||
float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
|
|
||||||
// y = rowwise_dotprod(state, C)
|
|
||||||
sumf += state * C[i0];
|
|
||||||
s[i] = state;
|
|
||||||
}
|
|
||||||
y[i1] = sumf;
|
|
||||||
}
|
|
||||||
|
|
||||||
// handle copies when there are multiple output states
|
|
||||||
for (int i3 = 1; i3 < n_kv; ++i3) {
|
|
||||||
int32_t seq = sq[i3];
|
|
||||||
if (0 <= seq && seq < n_kv) {
|
|
||||||
float * s1 = s + (seq - sq[0])*nc*nr;
|
|
||||||
memcpy(s1, s, nc*ir*sizeof(float));
|
|
||||||
} else {
|
|
||||||
// stop at negative or too big seq_ids
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -519,6 +519,9 @@ extern "C" {
|
||||||
// to the decoder to start generating output sequence. For other models, it returns -1.
|
// to the decoder to start generating output sequence. For other models, it returns -1.
|
||||||
LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model);
|
LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model);
|
||||||
|
|
||||||
|
// Returns true if the model is recurrent (like Mamba, RWKV, etc.)
|
||||||
|
LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);
|
||||||
|
|
||||||
// Returns 0 on success
|
// Returns 0 on success
|
||||||
LLAMA_API uint32_t llama_model_quantize(
|
LLAMA_API uint32_t llama_model_quantize(
|
||||||
const char * fname_inp,
|
const char * fname_inp,
|
||||||
|
|
1538
src/llama.cpp
1538
src/llama.cpp
File diff suppressed because it is too large
Load diff
|
@ -1652,19 +1652,20 @@ struct test_flash_attn_ext : public test_case {
|
||||||
const bool mask; // use mask
|
const bool mask; // use mask
|
||||||
|
|
||||||
const float max_bias; // ALiBi
|
const float max_bias; // ALiBi
|
||||||
|
const float logit_softcap; // Gemma 2
|
||||||
|
|
||||||
const ggml_type type_KV;
|
const ggml_type type_KV;
|
||||||
|
|
||||||
std::string vars() override {
|
std::string vars() override {
|
||||||
return VARS_TO_STR7(hs, nh, kv, nb, mask, max_bias, type_KV);
|
return VARS_TO_STR8(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV);
|
||||||
}
|
}
|
||||||
|
|
||||||
double max_nmse_err() override {
|
double max_nmse_err() override {
|
||||||
return 5e-4;
|
return 5e-4;
|
||||||
}
|
}
|
||||||
|
|
||||||
test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f, ggml_type type_KV = GGML_TYPE_F16)
|
test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16)
|
||||||
: hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), type_KV(type_KV) {}
|
: hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV) {}
|
||||||
|
|
||||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||||
const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
|
const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
|
||||||
|
@ -1673,7 +1674,7 @@ struct test_flash_attn_ext : public test_case {
|
||||||
ggml_tensor * k = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1);
|
ggml_tensor * k = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1);
|
||||||
ggml_tensor * v = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1);
|
ggml_tensor * v = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1);
|
||||||
ggml_tensor * m = mask ? ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1) : nullptr;
|
ggml_tensor * m = mask ? ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1) : nullptr;
|
||||||
ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias);
|
ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias, logit_softcap);
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -2437,11 +2438,14 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
||||||
for (bool mask : { true, false } ) {
|
for (bool mask : { true, false } ) {
|
||||||
for (float max_bias : { 0.0f, 8.0f }) {
|
for (float max_bias : { 0.0f, 8.0f }) {
|
||||||
if (!mask && max_bias > 0.0f) continue;
|
if (!mask && max_bias > 0.0f) continue;
|
||||||
for (int nh : { 32, }) {
|
for (float logit_softcap : {0.0f, 10.0f}) {
|
||||||
for (int kv : { 512, 1024, }) {
|
if (hs != 128 && logit_softcap != 0.0f) continue;
|
||||||
for (int nb : { 1, 2, 4, 8, }) {
|
for (int nh : { 32, }) {
|
||||||
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
|
for (int kv : { 512, 1024, }) {
|
||||||
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, type_KV));
|
for (int nb : { 1, 2, 4, 8, }) {
|
||||||
|
for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
|
||||||
|
test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,7 +14,7 @@ MODELS_REPO_URL=https://huggingface.co/ggml-org/$MODELS_REPO
|
||||||
# Clone the Hugging Face repository if the directory does not exist
|
# Clone the Hugging Face repository if the directory does not exist
|
||||||
if [ ! -d "$MODELS_REPO" ]; then
|
if [ ! -d "$MODELS_REPO" ]; then
|
||||||
echo "Cloning the Hugging Face repository..."
|
echo "Cloning the Hugging Face repository..."
|
||||||
git clone $MODELS_REPO_URL
|
git clone $MODELS_REPO_URL --depth 1
|
||||||
else
|
else
|
||||||
echo "Repository already exists. Skipping clone."
|
echo "Repository already exists. Skipping clone."
|
||||||
fi
|
fi
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue