This commit is contained in:
Qingtao Li 2024-11-06 16:26:18 +08:00 committed by GitHub
commit 076c793e82
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 1039 additions and 9 deletions

View file

@ -65,6 +65,8 @@ class Model:
model_name: str | None
metadata_override: Path | None
dir_model_card: Path
enable_t_mac: bool
kcfg_file: Path | None
# subclasses should define this!
model_arch: gguf.MODEL_ARCH
@ -73,7 +75,8 @@ class Model:
use_temp_file: bool = False, eager: bool = False,
metadata_override: Path | None = None, model_name: str | None = None,
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
small_first_shard: bool = False, hparams: dict[str, Any] | None = None):
small_first_shard: bool = False, hparams: dict[str, Any] | None = None,
enable_t_mac: bool = False, kcfg_file: Path | None = None):
if type(self) is Model:
raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
@ -95,7 +98,8 @@ class Model:
self.metadata_override = metadata_override
self.model_name = model_name
self.dir_model_card = dir_model # overridden in convert_lora_to_gguf.py
self.enable_t_mac = enable_t_mac
self.kcfg_file = kcfg_file
# Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
if self.ftype == gguf.LlamaFileType.GUESSED:
# NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie.
@ -265,6 +269,73 @@ class Model:
return [(self.map_tensor_name(name), data_torch)]
_gptq_quant_dict: dict[str, Tensor] | None = None
_t_mac_bits: int = 0
_t_mac_raw_shape: tuple[int, ...] | None = None
# Repack and merge qweight, scales, and qzeros into a single tensor
# Currently, this logic is nearly impossible to be implemented in quants.py
def _modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if not self.enable_t_mac:
return self.modify_tensors(data_torch, name, bid)
# bits = 0 means not quantized
self._t_mac_bits = 0
self._t_mac_raw_shape = None
from t_mac.model_utils import get_quantization_config, preprocess_for_t_mac
quantization_config = get_quantization_config(self.dir_model)
if quantization_config["quant_method"] == "gptq": # AutoGPTQ/GPTQModel
if name.endswith(".g_idx"):
return []
if name.endswith(".qweight") or name.endswith(".scales") or name.endswith(".qzeros"):
if self._gptq_quant_dict is None:
self._gptq_quant_dict = {}
suffix = "." + name.split(".")[-1]
base_name = name.replace(suffix, "")
self._gptq_quant_dict.setdefault(base_name, {})[suffix] = data_torch
if len(self._gptq_quant_dict[base_name]) < 3:
return []
qweight = LazyTorchTensor.to_eager(self._gptq_quant_dict[base_name][".qweight"]).numpy()
scales = LazyTorchTensor.to_eager(self._gptq_quant_dict[base_name][".scales"]).numpy()
qzeros = LazyTorchTensor.to_eager(self._gptq_quant_dict[base_name][".qzeros"]).numpy()
name = base_name + ".weight"
from t_mac.model_utils import unpack_gptqv2
w, scales, zeros, bits, group_size = unpack_gptqv2(qweight, scales, qzeros, "gptqmodel" in quantization_config["quantizer"])
self._t_mac_bits = bits
self._t_mac_raw_shape = w.shape
if bits != quantization_config["bits"] or group_size != quantization_config["group_size"]:
logger.warning("Error while parsing weights for quantization_config: {}".format(quantization_config))
# For permutation in, e.g., LlamaModel
w = self.modify_tensors(torch.from_numpy(w), name, bid)[0][1].numpy()
scales = self.modify_tensors(torch.from_numpy(scales), name, bid)[0][1].numpy()
zeros = self.modify_tensors(torch.from_numpy(zeros), name, bid)[0][1].numpy()
if self.ftype == gguf.LlamaFileType.MOSTLY_INT_N:
if quantization_config["sym"]:
if not np.allclose(zeros, np.zeros_like(zeros)):
logger.warning("Although the quantized model claimed to be symmetric, the weights are asymmetric")
else:
zeros = None
data_torch = torch.from_numpy(preprocess_for_t_mac(self.kcfg_file, w, scales, zeros, bits=bits))
else:
old_shape = w.shape
w = w.astype("float32").reshape(-1, group_size)
scales = scales.astype("float32").reshape(-1, 1)
zeros = zeros.astype("float32").reshape(-1, 1)
data = (w - (zeros / scales + (2 ** (bits - 1)))) * scales
data_torch = torch.from_numpy(data.reshape(old_shape))
if self.ftype == gguf.LlamaFileType.MOSTLY_F16:
data_torch = data_torch.to(torch.float16)
return [(self.map_tensor_name(name), data_torch)]
return self.modify_tensors(data_torch, name, bid)
def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
del name, new_name, bid, n_dims # unused
@ -285,7 +356,7 @@ class Model:
old_dtype = data_torch.dtype
# convert any unsupported data types to float32
if data_torch.dtype not in (torch.float16, torch.float32):
if data_torch.dtype not in (torch.float16, torch.float32) and not self.enable_t_mac:
data_torch = data_torch.to(torch.float32)
# use the first number-like part of the tensor name as the block id
@ -295,7 +366,13 @@ class Model:
bid = int(part)
break
for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)):
for new_name, data_torch in (self._modify_tensors(data_torch, name, bid)):
# Some GPTQ models have empty bias tensors which are not in the model architecture.
# These tensors will cause tensor number check to fail, so we have to skip them.
if new_name.endswith(".bias") and np.all(LazyTorchTensor.to_eager(data_torch).numpy() == 0):
logger.info(f"Skipping empty bias tensor: {new_name}")
continue
data = data_torch.squeeze().numpy()
# if data ends up empty, it means data_torch was a scalar tensor -> restore
@ -344,6 +421,19 @@ class Model:
# TODO: use Q4_K and Q6_K
data_qtype = gguf.GGMLQuantizationType.F16
# If self._t_mac_bits > 0, the tensor is quantized by GPTQ
if self.enable_t_mac and self._t_mac_bits > 0 and self.ftype == gguf.LlamaFileType.MOSTLY_INT_N:
if self._t_mac_bits == 1:
data_qtype = gguf.GGMLQuantizationType.I1
elif self._t_mac_bits == 2:
data_qtype = gguf.GGMLQuantizationType.I2
elif self._t_mac_bits == 3:
data_qtype = gguf.GGMLQuantizationType.I3
elif self._t_mac_bits == 4:
data_qtype = gguf.GGMLQuantizationType.I4
else:
raise ValueError(f"Unsupported number of bits: {self._t_mac_bits}")
# No override (data_qtype is False), or wants to be quantized (data_qtype is True)
if isinstance(data_qtype, bool):
if self.ftype == gguf.LlamaFileType.ALL_F32:
@ -358,6 +448,12 @@ class Model:
data_qtype = gguf.GGMLQuantizationType.TQ1_0
elif self.ftype == gguf.LlamaFileType.MOSTLY_TQ2_0:
data_qtype = gguf.GGMLQuantizationType.TQ2_0
elif self.ftype == gguf.LlamaFileType.MOSTLY_INT_N:
# If the tensor is successfully quantized, data_qtype should be I1/2/3/4
# If data_qtype is still bool, then the tensor should not be quantized
# In practice, this tensor is `output.weight` for GPTQ models
# TODO: Consider quantizing it?
data_qtype = gguf.GGMLQuantizationType.F16
else:
raise ValueError(f"Unknown file type: {self.ftype.name}")
@ -369,6 +465,7 @@ class Model:
data = gguf.quants.quantize(data, data_qtype)
shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape
shape = self._t_mac_raw_shape or shape
# reverse shape to make it similar to the internal ggml dimension order
shape_str = f"{{{', '.join(str(n) for n in reversed(shape))}}}"
@ -376,7 +473,8 @@ class Model:
# n_dims is implicit in the shape
logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype)
raw_shape = gguf.quant_shape_to_byte_shape(self._t_mac_raw_shape, data_qtype) if self.ftype == gguf.LlamaFileType.MOSTLY_INT_N and self._t_mac_raw_shape else None
self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype, raw_shape=raw_shape)
def set_type(self):
self.gguf_writer.add_type(gguf.GGUFType.MODEL)
@ -1700,6 +1798,15 @@ class BitnetModel(Model):
]):
# transform weight into 1/0/-1 (in fp32)
data_torch = self.weight_quant(data_torch)
if self.enable_t_mac and self.ftype == gguf.LlamaFileType.MOSTLY_INT_N:
# transform weight into T-MAC INT_N format
from t_mac.model_utils import preprocess_for_t_mac
data = LazyTorchTensor.to_eager(data_torch).numpy()
scale = np.max(np.abs(data))
w = np.round(data / scale + 2).astype(np.uint8)
data_torch = torch.from_numpy(preprocess_for_t_mac(self.kcfg_file, w, scale.reshape(1), bits=2))
self._t_mac_bits = 2
self._t_mac_raw_shape = w.shape
yield (new_name, data_torch)
@ -4297,8 +4404,8 @@ def parse_args() -> argparse.Namespace:
help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
)
parser.add_argument(
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto"], default="f16",
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, and auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "int_n", "auto"], default="f16",
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, and int_n for int1/2/3/4, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
)
parser.add_argument(
"--bigendian", action="store_true",
@ -4344,6 +4451,14 @@ def parse_args() -> argparse.Namespace:
"--metadata", type=Path,
help="Specify the path for an authorship metadata override file"
)
parser.add_argument(
"--enable-t-mac", action="store_true",
help="Enable T-MAC quantization format (disabled by default). Support GPTQ, GPTQv2, BitNet and BitDistiller."
)
parser.add_argument(
"--kcfg", type=Path,
help="Specify the path for the T-MAC configuration file"
)
return parser.parse_args()
@ -4387,6 +4502,7 @@ def main() -> None:
"q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
"tq1_0": gguf.LlamaFileType.MOSTLY_TQ1_0,
"tq2_0": gguf.LlamaFileType.MOSTLY_TQ2_0,
"int_n": gguf.LlamaFileType.MOSTLY_INT_N,
"auto": gguf.LlamaFileType.GUESSED,
}
@ -4420,7 +4536,8 @@ def main() -> None:
metadata_override=args.metadata, model_name=args.model_name,
split_max_tensors=args.split_max_tensors,
split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
small_first_shard=args.no_tensor_first_split)
small_first_shard=args.no_tensor_first_split,
enable_t_mac=args.enable_t_mac, kcfg_file=args.kcfg)
if args.vocab_only:
logger.info("Exporting model vocab...")

View file

@ -166,6 +166,9 @@ option(GGML_SYCL "ggml: use SYCL"
option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF)
set (GGML_SYCL_TARGET "INTEL" CACHE STRING
"ggml: sycl target device")
option(GGML_TMAC "ggml: use TMAC" OFF)
option(GGML_TMAC_SYSLIB "ggml: use TMAC system library" OFF)
option(GGML_TMAC_TVM_THREADPOOL "ggml: use TVM threadpool for TMAC" OFF)
# extra artifacts
option(GGML_BUILD_TESTS "ggml: build tests" ${GGML_STANDALONE})

42
ggml/include/ggml-tmac.h Normal file
View file

@ -0,0 +1,42 @@
#pragma once
#include "ggml.h"
#include "ggml-backend.h"
#ifdef __ARM_NEON
#include <arm_neon.h>
typedef float16_t tmac_float_type;
#else
typedef float tmac_float_type;
#endif
#ifdef __cplusplus
extern "C" {
#endif
struct tmac_tensor_extra {
int lut_scales_size;
int scales_size;
int n_tile_num;
uint8_t * qweights;
tmac_float_type * scales;
};
GGML_API void ggml_tmac_init(void);
GGML_API void ggml_tmac_free(void);
// src0->type == Q4_0/IQ2_XXS/IQ3_XXS
// T-MAC currently only supports BitNet quantization or GPTQ-like quantization (only scales, without zeros)
// If use i-quantization gguf models, the results will be wrong
// TODO: add customized block types Q2_0/Q3_0
GGML_API bool ggml_tmac_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst);
GGML_API size_t ggml_tmac_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst);
GGML_API void ggml_tmac_mul_mat_task_init(void * src1, void * qlut, void * lut_scales, void * lut_biases, int n, int k, int m, int bits);
GGML_API void ggml_tmac_mul_mat_task_compute(void * src0, void * scales, void * qlut, void * lut_scales, void * lut_biases, void * dst, int n, int k, int m, int bits);
GGML_API void ggml_tmac_transform_tensor(struct ggml_tensor * tensor);
GGML_API int ggml_tmac_get_type_bits(enum ggml_type type);
GGML_API void ggml_tmac_set_n_threads(int n_threads);
GGML_API size_t ggml_tmac_get_nbytes(const struct ggml_tensor * tensor);
#ifdef __cplusplus
}
#endif

View file

@ -389,6 +389,10 @@ extern "C" {
GGML_TYPE_Q4_0_8_8 = 33,
GGML_TYPE_TQ1_0 = 34,
GGML_TYPE_TQ2_0 = 35,
GGML_TYPE_I1 = 36,
GGML_TYPE_I2 = 37,
GGML_TYPE_I3 = 38,
GGML_TYPE_I4 = 39,
GGML_TYPE_COUNT,
};

View file

@ -870,6 +870,42 @@ if (GGML_KOMPUTE)
endif()
endif()
if (GGML_TMAC)
find_package(TMAC)
if (TMAC_FOUND)
message(STATUS "TMAC found")
list(APPEND GGML_CDEF_PUBLIC GGML_USE_TMAC)
set(GGML_HEADERS_TMAC ../include/ggml-tmac.h)
set(GGML_SOURCES_TMAC ggml-tmac.cpp)
link_directories(${TMAC_LIB_DIR})
file(COPY ${TMAC_LIB_DIR}/kcfg.ini DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
# TODO: link t_mac_object when GGML_TMAC_SYSLIB
if (GGML_TMAC_TVM_THREADPOOL)
add_compile_definitions(TMAC_USE_TVM_THREADPOOL)
set(GGML_EXTRA_LIBS_PRIVATE ${GGML_EXTRA_LIBS_PRIVATE} t_mac)
else()
if ((NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang") OR
(NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang"))
message(FATAL_ERROR "Clang is required for T-MAC compilation")
endif()
set(GGML_EXTRA_LIBS_PRIVATE ${GGML_EXTRA_LIBS_PRIVATE} t_mac_no_tvm)
set(GGML_SOURCES_TMAC ${GGML_SOURCES_TMAC} ${TMAC_KERNELS_SOURCE})
endif()
if (GGML_TMAC_RECHUNK)
add_compile_definitions(TMAC_RECHUNK)
endif()
else()
message(WARNING "TMAC not found")
endif()
endif()
if (GGML_CPU_HBM)
find_library(memkind memkind REQUIRED)
@ -1170,6 +1206,26 @@ if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR
# Raspberry Pi 3, 4, Zero 2 (32-bit)
list(APPEND ARCH_FLAGS -mno-unaligned-access)
endif()
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64" AND GGML_TMAC AND TMAC_FOUND)
# We need fullfp16 for T-MAC
# TODO: we need to simplify this logic through check_cxx_source_compiles or Presets?
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { int8x16_t _a, _b; int32x4_t _s = vmlaq_f32(_s, _a, _b); return 0; }" GGML_COMPILER_SUPPORT_MATMUL_INT8)
if (GGML_COMPILER_SUPPORT_MATMUL_INT8)
# Device with armv8.7a+ cpu, e.g., WSL on Surface Laptop 7
# based on arm64-windows-llvm.cmake
list(APPEND ARCH_FLAGS -march=armv8.7-a+fp16 -fvectorize -ffp-model=fast -fno-finite-math-only)
add_compile_definitions(__ARM_FEATURE_MATMUL_INT8)
else ()
# Jetson AGX Orin, Raspberry Pi 5
list(APPEND ARCH_FLAGS -march=armv8.2a+fp16)
endif ()
endif()
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ARM64" AND GGML_TMAC AND TMAC_FOUND)
# ARM Windows with LLVM clang GNU interface
# We need fullfp16 for T-MAC
# TODO: check_cxx_source_compiles
list(APPEND ARCH_FLAGS -march=armv8.2a+fp16)
endif()
if (GGML_SVE)
list(APPEND ARCH_FLAGS -march=armv8.6-a+sve)
endif()
@ -1184,7 +1240,9 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
# TODO: improve, should not reference files from the parent folder
include(../cmake/FindSIMD.cmake)
endif ()
if (GGML_AVX512)
# Can't use GGML_AVX512 with Clang for MSVC
# with error: conflicting types for '_m_prefetchw
if (GGML_AVX512 AND (NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang") AND (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang"))
list(APPEND ARCH_FLAGS /arch:AVX512)
# MSVC has no compile-time flags enabling specific
# AVX512 extensions, neither it defines the
@ -1388,6 +1446,7 @@ add_library(ggml
${GGML_SOURCES_LLAMAFILE} ${GGML_HEADERS_LLAMAFILE}
${GGML_SOURCES_AMX} ${GGML_HEADERS_AMX}
${GGML_SOURCES_CANN} ${GGML_HEADERS_CANN}
${GGML_SOURCES_TMAC} ${GGML_HEADERS_TMAC}
ggml-aarch64.c ggml-aarch64.h
)

View file

@ -84,6 +84,10 @@
#include <Accelerate/Accelerate.h>
#endif
#if defined(GGML_USE_TMAC)
#include "ggml-tmac.h"
#endif
// floating point type used to accumulate sums
typedef double ggml_float;
@ -424,6 +428,26 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
},
[GGML_TYPE_I1] = {
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
.vec_dot_type = GGML_TYPE_F32,
.nrows = 1,
},
[GGML_TYPE_I2] = {
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
.vec_dot_type = GGML_TYPE_F32,
.nrows = 1,
},
[GGML_TYPE_I3] = {
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
.vec_dot_type = GGML_TYPE_F32,
.nrows = 1,
},
[GGML_TYPE_I4] = {
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
.vec_dot_type = GGML_TYPE_F32,
.nrows = 1,
},
};
const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) {
@ -7479,6 +7503,156 @@ static void ggml_compute_forward_mul_mat(
UseGgmlGemm1:;
#endif
// TODO: Refactor t-mac as ggml-backend,
// as ggml-blas.cpp has been moved to backend
#if defined(GGML_USE_TMAC)
if (ggml_tmac_can_mul_mat(src0, src1, dst)) {
const int bits = ggml_tmac_get_type_bits(type);
// src0: weight, ne00 = k, ne01 = n
// src1: activation, ne10 = k, ne11 = m
char * wdata = params->wdata;
struct tmac_tensor_extra * wt = src0->extra;
char * cur_wdata = wdata;
tmac_float_type * tmac_f_ptr = wdata;
if (sizeof(tmac_float_type) == 2) {
cur_wdata = wdata + MAX(ne10, ne01) * ne11 * sizeof(tmac_float_type);
};
int8_t * qlut = cur_wdata;
tmac_float_type * lut_scales = (tmac_float_type *) (qlut + ne10 * ne11 * 4);
tmac_float_type * lut_biases = (tmac_float_type *) (lut_scales + wt->lut_scales_size * ne11);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
tmac_float_type * act_input;
if (sizeof(tmac_float_type) == 2) {
act_input = tmac_f_ptr;
} else {
act_input = src1->data;
}
for (int ine11 = ith; ine11 < ne11; ine11 += nth) {
if (sizeof(tmac_float_type) == 2) {
ggml_fp32_to_fp16_row((const float *) src1->data + ne10 * ine11, act_input + ne10 * ine11, ne10);
}
ggml_tmac_mul_mat_task_init(act_input + ne10 * ine11,
qlut + ne10 * ine11 * 4,
lut_scales + wt->lut_scales_size * ine11,
lut_biases + wt->lut_scales_size * ine11,
ne01, ne00, 1, bits);
}
if (ith == 0) {
// Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
atomic_store_explicit(&params->threadpool->current_chunk, nth, memory_order_relaxed);
}
ggml_barrier(params->threadpool);
tmac_float_type * act_output;
if (sizeof(tmac_float_type) == 2) {
act_output = tmac_f_ptr;
} else {
act_output = dst->data;
}
// TODO: remove TVM threadpool if ensuring unused
#if defined(TMAC_USE_TVM_THREADPOOL)
if (ith != 0) {
return;
}
// TODO: schedule ne11(m) in T-MAC
for (int ine11 = 0; ine11 < ne11; ine11++) {
const int qlut_offset = ne10 * ine11 * 4;
const int lut_scales_offset = wt->lut_scales_size * ine11;
const int dst_offset = ne0 * ine11;
ggml_tmac_mul_mat_task_compute(wt->qweights,
wt->scales,
qlut + qlut_offset,
lut_scales + lut_scales_offset,
lut_biases + lut_scales_offset,
act_output + dst_offset,
ne01, ne00, 1, bits);
}
if (sizeof(tmac_float_type) == 2) {
ggml_fp16_to_fp32_row(tmac_f_ptr, dst->data, ne00 * ne01);
}
#else // #if defined(TMAC_USE_TVM_THREADPOOL)
const int n_tile_num = wt->n_tile_num;
// Currently, T-MAC requires ne0 devisible by n_tile_num
GGML_ASSERT(ne0 % n_tile_num == 0);
const int64_t w_size = ne00 * ne01 * bits / 8;
const int64_t w_chunk_size = w_size / n_tile_num;
const int64_t nr0 = ne0;
const int64_t nr1 = ne1 * ne2 * ne3;
// Adopt the same style with current llama.cpp impl
// But different chunk size for 0/1 dim.
// No scrap.
const int chunk_size0 = ne0 / n_tile_num;
const int chunk_size1 = 8; // TODO: tune in T-MAC
// nchunk0 == n_tile_num
int64_t nchunk0 = (nr0 + chunk_size0 - 1) / chunk_size0;
int64_t nchunk1 = (nr1 + chunk_size1 - 1) / chunk_size1;
int64_t dr0 = chunk_size0;
int64_t dr1 = chunk_size1;
#if defined(TMAC_RECHUNK)
// Rechunk
if ((nchunk1 == 1) && (nchunk0 > nth * 4)) {
// dr0 should be divisible by chunk_size0
dr0 = (ne0 / (nth * 4) / chunk_size0) * chunk_size0;
nchunk0 = (nr0 + dr0 - 1) / dr0;
}
#endif
int current_chunk = ith;
while (current_chunk < nchunk0 * nchunk1) {
const int64_t ith0 = current_chunk % nchunk0;
const int64_t ith1 = current_chunk / nchunk0;
const int64_t ir0_start = dr0 * ith0;
const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
const int64_t ir1_start = dr1 * ith1;
const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
// inline ggml_compute_forward_mul_mat_one_chunk here for simplicity
for (int64_t ichunk0 = ir0_start / chunk_size0; ichunk0 < ir0_end / chunk_size0; ichunk0++) {
const int64_t w_offset = ichunk0 * w_chunk_size;
const int64_t scales_offset = ichunk0 * wt->scales_size / n_tile_num;
for (int64_t ine11 = ir1_start; ine11 < ir1_end; ine11++) {
const int64_t qlut_offset = ne10 * ine11 * 4;
const int64_t lut_scales_offset = wt->lut_scales_size * ine11;
const int64_t dst_offset = ne0 * ine11 + ichunk0 * chunk_size0;
ggml_tmac_mul_mat_task_compute(wt->qweights + w_offset,
wt->scales + scales_offset,
qlut + qlut_offset,
lut_scales + lut_scales_offset,
lut_biases + lut_scales_offset,
act_output + dst_offset,
chunk_size0, ne00, 1, bits);
if (sizeof(tmac_float_type) == 2) {
ggml_fp16_to_fp32_row(act_output + dst_offset, (float *) dst->data + dst_offset, chunk_size0);
}
}
}
if (nth >= nchunk0 * nchunk1) {
break;
}
current_chunk = atomic_fetch_add_explicit(&params->threadpool->current_chunk, 1, memory_order_relaxed);
}
#endif // #if defined(TMAC_USE_TVM_THREADPOOL)
return;
} // if (ggml_tmac_can_mul_mat(src0, src1, dst))
#endif // #if defined(GGML_USE_TMAC)
if (src1->type != vec_dot_type) {
char * wdata = params->wdata;
@ -9124,6 +9298,10 @@ static void ggml_compute_forward_clamp(
case GGML_TYPE_I32:
case GGML_TYPE_I64:
case GGML_TYPE_F64:
case GGML_TYPE_I1:
case GGML_TYPE_I2:
case GGML_TYPE_I3:
case GGML_TYPE_I4:
case GGML_TYPE_COUNT:
{
GGML_ABORT("fatal error");
@ -13173,6 +13351,11 @@ struct ggml_cplan ggml_graph_plan(
{
const enum ggml_type vec_dot_type = type_traits_cpu[node->src[0]->type].vec_dot_type;
#if defined(GGML_USE_TMAC)
if (ggml_tmac_can_mul_mat(node->src[0], node->src[1], node)) {
cur = ggml_tmac_mul_mat_get_wsize(node->src[0], node->src[1], node);
} else
#endif
if (node->src[1]->type != vec_dot_type) {
cur = ggml_row_size(vec_dot_type, ggml_nelements(node->src[1]));
}

View file

@ -15728,6 +15728,12 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
case GGML_TYPE_I64:
// nothing to validate
break;
case GGML_TYPE_I1:
case GGML_TYPE_I2:
case GGML_TYPE_I3:
case GGML_TYPE_I4:
// nothing to validate
break;
default:
{
fprintf(stderr, "%s: invalid type %d\n", __func__, type);

523
ggml/src/ggml-tmac.cpp Normal file
View file

@ -0,0 +1,523 @@
#include <vector>
#include <type_traits>
#include "ggml-tmac.h"
#include "ggml-quants.h"
#include "t-mac/tmac_gemm_wrapper.h"
#define GGML_TMAC_MAX_NODES 8192
static bool initialized = false;
static TMAC::TMACGeMMWrapper<tmac_float_type> * wrapper = nullptr;
static tmac_tensor_extra * tmac_tensor_extras = nullptr;
static size_t tmac_tensor_extras_index = 0;
static void * aligned_malloc(size_t size) {
#if defined(_WIN32)
return _aligned_malloc(size, TMAC::kAllocAlignment);
#else
void * ptr = nullptr;
posix_memalign(&ptr, TMAC::kAllocAlignment, size);
return ptr;
#endif
}
static void aligned_free(void * ptr) {
#if defined(_WIN32)
_aligned_free(ptr);
#else
free(ptr);
#endif
}
void ggml_tmac_init(void) {
LOG(INFO) << "ggml_tmac_init";
if (initialized) {
return;
}
initialized = true;
if (wrapper == nullptr) {
wrapper = new TMAC::TMACGeMMWrapper<tmac_float_type>();
}
if (tmac_tensor_extras == nullptr) {
tmac_tensor_extras = new tmac_tensor_extra[GGML_TMAC_MAX_NODES];
}
tmac_tensor_extras_index = 0;
}
void ggml_tmac_free(void) {
LOG(INFO) << "ggml_tmac_free";
if (!initialized) {
return;
}
initialized = false;
delete wrapper;
wrapper = nullptr;
for (size_t i = 0; i < tmac_tensor_extras_index; i++) {
// aligned_free(tmac_tensor_extras[i].qweights);
// aligned_free(tmac_tensor_extras[i].scales);
}
delete[] tmac_tensor_extras;
tmac_tensor_extras = nullptr;
}
static bool is_type_supported(enum ggml_type type) {
if (type == GGML_TYPE_Q4_0 ||
type == GGML_TYPE_I1 ||
type == GGML_TYPE_I2 ||
type == GGML_TYPE_I3 ||
type == GGML_TYPE_I4 ||
type == GGML_TYPE_TQ1_0 ||
type == GGML_TYPE_TQ2_0) {
return true;
} else {
return false;
}
}
static bool do_permutate(enum ggml_type type) {
if (type == GGML_TYPE_I1 ||
type == GGML_TYPE_I2 ||
type == GGML_TYPE_I3 ||
type == GGML_TYPE_I4) {
// Add additional args to decide if permuted I2 or naive I2
return false;
} else {
return true;
}
}
struct BlockQ40TypeAccessor {
using block_t = block_q4_0;
static constexpr int BITS = 4;
static constexpr int SIMD_LEN = 16;
static constexpr int group_size = (sizeof(block_t) - sizeof(ggml_fp16_t)) * 8 / BITS;
static constexpr int simd_n_elem = SIMD_LEN * 8 / BITS;
static uint8_t get_q(const void * data, int idx) {
const uint8_t * qs = (const uint8_t *) ((((const block_t *) data)[idx / group_size]).qs);
int internal_idx = idx % group_size;
const uint8_t * simd_qs = qs + internal_idx / simd_n_elem * SIMD_LEN;
int simd_idx = internal_idx % simd_n_elem;
return simd_qs[simd_idx % SIMD_LEN] >> (simd_idx / SIMD_LEN * BITS);
}
static tmac_float_type get_scale(const void * data, int idx) {
ggml_fp16_t d = ((const block_t *) data)[idx / group_size].d;
if (sizeof(tmac_float_type) == 2) {
tmac_float_type * fp16dp = reinterpret_cast<tmac_float_type *>(&d);
return *fp16dp;
} else {
return ggml_fp16_to_fp32(d);
}
}
};
struct BlockI2TypeAccessor {
static constexpr int BITS = 2;
static constexpr int n_elem = 8 / BITS;
static uint8_t get_q(const void * data, int idx) {
const uint8_t * qs = (const uint8_t *) data;
int elem_idx = idx % n_elem;
return qs[idx / n_elem] >> (elem_idx * BITS);
}
static tmac_float_type get_scale(const void * data, int idx, int group_size) {
const float * ss = (const float *) data;
float s = ss[idx / group_size];
return (tmac_float_type) s;
}
};
struct BlockTQ10TypeAccessor {
using block_t = block_tq1_0;
static constexpr int elements_qs = 5; // 5 elements per byte
static constexpr int elements_qh = 4; // 4 elements per byte
static constexpr int BITS = 2;
static constexpr int group_size_qs = sizeof(((block_t *)0)->qs) * elements_qs;
static constexpr int group_size_qh = sizeof(((block_t *)0)->qh) * elements_qh;
static constexpr int group_size = group_size_qs + group_size_qh;
static constexpr int SIMD_LEN_qs_1 = 32;
static constexpr int SIMD_LEN_qs_2 = 16;
static constexpr int SIMD_LEN_qh = 4;
static constexpr int simd_n_elem_qs_1 = SIMD_LEN_qs_1 * elements_qs; // 160
static constexpr int simd_n_elem_qs_2 = SIMD_LEN_qs_2 * elements_qs; // 80
static constexpr int simd_n_elem_qh = SIMD_LEN_qh * elements_qh; // 16
static constexpr uint8_t pow3[5] = {1, 3, 9, 27, 81};
static uint8_t get_q(const void * data, int idx) {
const uint8_t * qs = (const uint8_t *) ((((const block_t *) data)[idx / group_size]).qs);
uint8_t cur_qs;
uint8_t trit;
int internal_idx = idx % group_size;
if (internal_idx < simd_n_elem_qs_1) {
const int internal_offset = 0;
const uint8_t * simd_qs = qs + internal_offset;
int simd_idx = internal_idx;
int simd_byte = simd_idx % SIMD_LEN_qs_1;
int simd_trit = simd_idx / SIMD_LEN_qs_1;
cur_qs = simd_qs[simd_byte] * pow3[simd_trit];
trit = ((uint16_t) cur_qs * 3) >> 8;
}
else if (internal_idx < simd_n_elem_qs_1 + simd_n_elem_qs_2) {
const int internal_offset = SIMD_LEN_qs_1;
const uint8_t * simd_qs = qs + internal_offset;
int simd_idx = internal_idx - simd_n_elem_qs_1;
int simd_byte = simd_idx % SIMD_LEN_qs_2;
int simd_trit = simd_idx / SIMD_LEN_qs_2;
cur_qs = simd_qs[simd_byte] * pow3[simd_trit];
trit = ((uint16_t) cur_qs * 3) >> 8;
}
else {
const int internal_offset = SIMD_LEN_qs_1 + SIMD_LEN_qs_2;
const uint8_t * simd_qs = qs + internal_offset;
int simd_idx = internal_idx - simd_n_elem_qs_1 - simd_n_elem_qs_2;
int simd_byte = simd_idx % SIMD_LEN_qh;
int simd_trit = simd_idx / SIMD_LEN_qh;
cur_qs = simd_qs[simd_byte] * pow3[simd_trit];
trit = ((uint16_t) cur_qs * 3) >> 8;
}
return trit + 1;
}
static tmac_float_type get_scale(const void * data, int idx, int group_size) {
ggml_fp16_t d = ((const block_t *) data)[idx / group_size].d;
if (sizeof(tmac_float_type) == 2) {
tmac_float_type * fp16dp = reinterpret_cast<tmac_float_type *>(&d);
return *fp16dp;
} else {
return ggml_fp16_to_fp32(d);
}
}
};
struct BlockTQ20TypeAccessor {
using block_t = block_tq2_0;
static constexpr int BITS = 2;
static constexpr int SIMD_LEN = 32;
static constexpr int group_size = (sizeof(block_t) - sizeof(ggml_fp16_t)) * 8 / BITS; // 256
static constexpr int simd_n_elem = SIMD_LEN * 8 / BITS; // 128
static uint8_t get_q(const void * data, int idx) {
const uint8_t * qs = (const uint8_t *) ((((const block_t *) data)[idx / group_size]).qs);
int internal_idx = idx % group_size;
const uint8_t * simd_qs = qs + internal_idx / simd_n_elem * SIMD_LEN;
int simd_idx = internal_idx % simd_n_elem;
return (simd_qs[simd_idx % SIMD_LEN] >> (simd_idx / SIMD_LEN * BITS)) + 1;
}
static tmac_float_type get_scale(const void * data, int idx, int group_size) {
ggml_fp16_t d = ((const block_t *) data)[idx / group_size].d;
if (sizeof(tmac_float_type) == 2) {
tmac_float_type * fp16dp = reinterpret_cast<tmac_float_type *>(&d);
return *fp16dp;
} else {
return ggml_fp16_to_fp32(d);
}
}
};
bool ggml_tmac_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) {
if ((is_type_supported(src0->type)) &&
src1->type == GGML_TYPE_F32 &&
dst->type == GGML_TYPE_F32 &&
src0->backend == GGML_BACKEND_TYPE_CPU &&
strcmp(src0->name, "token_embd.weight") && // means not equal
strcmp(src0->name, "output.weight")) {
return true;
}
return false;
}
size_t ggml_tmac_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) {
const size_t ne01 = src0->ne[1];
const size_t ne10 = src1->ne[0];
const size_t ne11 = src1->ne[1];
const int bits = ggml_tmac_get_type_bits(src0->type);
TMAC::TMACGeMMConfig kcfg = wrapper->get_kcfg(ne01, ne10, 1, bits);
size_t wsize = ne10 * ne11 * 4 * sizeof(int8_t) + kcfg.lut_scales_size * ne11 * 2 * sizeof(tmac_float_type);
if (sizeof(tmac_float_type) == 2) {
// Need fp32 to fp16 conversion
wsize += std::max(ne10, ne01) * ne11 * sizeof(tmac_float_type);
}
wsize = ((wsize - 1) / TMAC::kAllocAlignment + 1) * TMAC::kAllocAlignment;
return wsize;
}
// m = batch_size
// n = output_dim
void ggml_tmac_mul_mat_task_init(void * src1, void * qlut, void * lut_scales, void * lut_biases, int n, int k, int m, int bits) {
// t-mac llama.cpp n and m swapped
wrapper->llama_cpp_init(src1, qlut, lut_scales, lut_biases, n, k, m, bits);
}
void ggml_tmac_mul_mat_task_compute(void * src0, void * scales, void * qlut, void * lut_scales, void * lut_biases, void * dst, int n, int k, int m, int bits) {
wrapper->llama_cpp_compute(src0, scales, qlut, lut_scales, lut_biases, dst, n, k, m, bits);
}
size_t ggml_tmac_get_nbytes(const struct ggml_tensor * tensor) {
const int bits = ggml_tmac_get_type_bits(tensor->type);
int k = tensor->ne[0];
int m = tensor->ne[1]; // `n` in llama.cpp
TMAC::TMACGeMMConfig kcfg = wrapper->get_kcfg(m, k, 1, bits);
// Currently, I2 always uses float to store scales or zero points
size_t nbytes = k * m / 8 * bits + kcfg.scales_size * sizeof(float);
return nbytes;
}
void ggml_tmac_transform_tensor(struct ggml_tensor * tensor) {
if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) {
return;
}
const int bits = ggml_tmac_get_type_bits(tensor->type);
const int g = 4;
const int ngroups_per_elem = 2;
int k = tensor->ne[0];
int m = tensor->ne[1]; // `n` in llama.cpp
TMAC::TMACGeMMConfig kcfg = wrapper->get_kcfg(m, k, 1, bits);
const int bm = kcfg.bm;
const int simd_n_in = kcfg.simd_n_in;
const int simd_n_out = kcfg.simd_n_out;
const int kfactor = kcfg.kfactor;
const int group_size = kcfg.group_size; // could be different from block size in llama.cpp
const int lut_scales_size = kcfg.lut_scales_size;
const int scales_size = kcfg.scales_size;
const int n_tile_num = kcfg.n_tile_num;
DLOG(INFO) << "Transforming tensor: " << tensor->name << " (m: " << m << ", k: " << k << ", bits: " << bits << ")";
DLOG(INFO) << "kcfg (bm=" << bm << ", simd_n_in=" << simd_n_in << ", simd_n_out=" << simd_n_out << ", kfactor=" << kfactor
<< ", group_size=" << group_size << ", lut_scales_size=" << lut_scales_size << ", scales_size=" << scales_size << ", n_tile_num=" << n_tile_num << ")";
if (bm == 0) {
// TODO: warning token.embd if not support
if (!strcmp(tensor->name, "token_embd.weight") || !strcmp(tensor->name, "output.weight")) {
LOG(WARNING) << "Do not find kcfg for " << tensor->name << ". Consider compiling T-MAC kernel for it if vocab size is a multiply of 128 or 320, detected " << tensor->ne[1] << ".";
return;
}
else {
// Instead of fatal error, try to avoid using t-mac?
LOG(FATAL) << "Failed to find kcfg. Abort transforming";
return;
}
}
const int mgroup = ngroups_per_elem * simd_n_in;
m = m * bits;
uint8_t * qweights;
tmac_float_type * scales;
scales = (tmac_float_type *) aligned_malloc(scales_size * sizeof(tmac_float_type));
if (do_permutate(tensor->type)) {
qweights = (uint8_t *) aligned_malloc(k * m / 8);
} else {
qweights = (uint8_t *) tensor->data;
float * i2_scales = (float * )(qweights + k * m / 8);
for (int i = 0; i < scales_size; i++) {
scales[i] = (tmac_float_type) i2_scales[i];
}
}
tensor->extra = tmac_tensor_extras + tmac_tensor_extras_index;
tmac_tensor_extras[tmac_tensor_extras_index++] = {
/* .lut_scales_size = */ lut_scales_size,
/* .scales_size = */ scales_size,
/* .n_tile_num = */ n_tile_num,
/* .qweights = */ qweights,
/* .scales = */ scales
};
if (do_permutate(tensor->type)) {
// for fast testing
// #define TMAC_EMPTY_WEIGHTS
#ifndef TMAC_EMPTY_WEIGHTS
// TODO: optimize to accelerate weights loading
uint8_t * buf1 = new uint8_t[m * k];
uint8_t * buf2 = new uint8_t[m * k / g];
// # (M // bits, K, bits)
// w = np.stack([(w >> ib) & 1 for ib in range(bits)], axis=-1)
for (int im = 0; im < m / bits; im++) {
for (int ik = 0; ik < k; ik++) {
uint8_t v;
if (tensor->type == GGML_TYPE_Q4_0) {
v = BlockQ40TypeAccessor::get_q(tensor->data, im * k + ik);
} else if (tensor->type == GGML_TYPE_I2) {
v = BlockI2TypeAccessor::get_q(tensor->data, im * k + ik);
} else if (tensor->type == GGML_TYPE_TQ1_0) {
v = BlockTQ10TypeAccessor::get_q(tensor->data, im * k + ik);
} else if (tensor->type == GGML_TYPE_TQ2_0) {
v = BlockTQ20TypeAccessor::get_q(tensor->data, im * k + ik);
} else {
LOG(FATAL) << "Unsupported type";
}
for (int ib = 0; ib < bits; ib++) {
buf1[im * k * bits + ik * bits + ib] = (v >> ib) & 1;
}
}
}
// # (M // bits, K, bits) -> (M // bits, bits, K) -> (M // bits, bits, K // g, g) -> (M // bits, bits, K // g)
// w = w.transpose(0, 2, 1).reshape(M // bits, bits, K // g, g)
// w = sum([(w[:, :, :, ig] << ig) for ig in range(g)])
memset(buf2, 0, m * k / g);
for (int im = 0; im < m / bits; im++) {
for (int ik = 0; ik < k; ik++) {
for (int ib = 0; ib < bits; ib++) {
int new_im = im;
int new_ib = ib;
int new_ik = ik / g;
int new_ig = ik % g;
buf2[new_im * bits * k / g + new_ib * k / g + new_ik] += buf1[im * k * bits + ik * bits + ib] << new_ig;
}
}
}
// # 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31
// # for bits=3
// # bit0: [0, 8), bit1: [8, 16), bit2: [16, 24), bit0: [24, 32)
// # (M // bits // simd_n_float16, bits, simd_n_float16, K // g)
// w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3)
// mgroup = ngroups_per_elem * simd_n_in
// w = w.reshape(M // mgroup, ngroups_per_elem, simd_n_in, K // g).transpose(0, 2, 1, 3)
// # 0 1 2 3 4 5
// w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3)
// w = sum([(w[:, :, :, :, :, ng] << (ng * g)) for ng in range(ngroups_per_elem)])
memset(qweights, 0, m * k / g / ngroups_per_elem);
for (int im = 0; im < m / bits; im++) {
for (int ib = 0; ib < bits; ib++) {
for (int ik = 0; ik < k / g; ik++) {
int new_im = im / simd_n_out;
int new_isno = im % simd_n_out;
int new_ib = ib;
int new_ik = ik;
// w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3)
int new_idx = new_im * bits * simd_n_out * k / g + new_ib * simd_n_out * k / g + new_isno * k / g + new_ik;
// w = w.reshape(M // mgroup, ngroups_per_elem, simd_n_in, K // g).transpose(0, 2, 1, 3)
int nb2 = k / g;
int nb1 = simd_n_in * nb2;
int nb0 = ngroups_per_elem * nb1;
new_im = new_idx / nb0;
int new_ing = (new_idx % nb0) / nb1;
int new_isni = (new_idx % nb1) / nb2;
new_ik = (new_idx % nb2);
new_idx = new_im * ngroups_per_elem * simd_n_in * k / g + new_isni * ngroups_per_elem * k / g + new_ing * k / g + new_ik;
// # 0 1 2 3 4 5
// w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3)
int nb4 = kfactor;
int nb3 = k / g / kfactor * nb4;
nb2 = ngroups_per_elem * nb3;
nb1 = simd_n_in * nb2;
nb0 = bm / mgroup * nb1;
new_im = new_idx / nb0;
int new_ibm = (new_idx % nb0) / nb1;
new_isni = (new_idx % nb1) / nb2;
new_ing = (new_idx % nb2) / nb3;
new_ik = (new_idx % nb3) / nb4;
int new_ikf = (new_idx % nb4);
new_idx = new_im * k / g / kfactor * bm / mgroup * kfactor * simd_n_in * ngroups_per_elem +
new_ik * bm / mgroup * kfactor * simd_n_in * ngroups_per_elem +
new_ibm * kfactor * simd_n_in * ngroups_per_elem +
new_ikf * simd_n_in * ngroups_per_elem +
new_isni * ngroups_per_elem +
new_ing;
new_idx = new_idx / ngroups_per_elem;
// w = sum([(w[:, :, :, :, :, ng] << (ng * g)) for ng in range(ngroups_per_elem)])
qweights[new_idx] += buf2[im * bits * k / g + ib * k / g + ik] << (new_ing * g);
}
}
}
const float * i2_scales = (const float * ) ((const uint8_t *) tensor->data + k * m / 8);
if (scales_size < m / bits) { // BitNet-like scale (m_groups,)
for (int i = 0; i < scales_size; i++) {
scales[i] = (tmac_float_type) i2_scales[i];
}
} else { // GPTQ-like scale (m / bits, k / group_size)
GGML_ASSERT(scales_size == m / bits * k / group_size);
// scales = scales.reshape(M // bm, bm // bits, K // group_size).transpose(0, 2, 1)
for (int im = 0; im < m / bits; im += 1) {
for (int ik = 0; ik < k; ik += group_size) {
tmac_float_type scale;
int idx = im * k + ik;
if (tensor->type == GGML_TYPE_Q4_0) {
scale = BlockQ40TypeAccessor::get_scale(tensor->data, idx);
} else if (tensor->type == GGML_TYPE_I2) {
scale = BlockI2TypeAccessor::get_scale(i2_scales, idx, group_size);
} else if (tensor->type == GGML_TYPE_TQ1_0) {
scale = BlockTQ10TypeAccessor::get_scale(tensor->data, idx, group_size);
} else if (tensor->type == GGML_TYPE_TQ2_0) {
scale = BlockTQ20TypeAccessor::get_scale(tensor->data, idx, group_size);
} else {
LOG(FATAL) << "Unsupported type";
}
int new_idx;
idx = idx / group_size;
int new_im = idx / (bm / bits * k / group_size);
int new_ibm = (idx % (bm / bits * k / group_size)) / (k / group_size);
int new_ik = (idx % (k / group_size));
new_idx = new_im * k / group_size * bm / bits + new_ik * bm / bits + new_ibm;
scales[new_idx] = scale;
}
}
}
delete[] buf1;
delete[] buf2;
#else
memset(qweights, 0x88, k * m / 8);
for (int i = 0; i < scales_size; i++) {
scales[i] = 1.0f;
}
#endif
} // if (do_permutate(tensor->type))
}
int ggml_tmac_get_type_bits(enum ggml_type type) {
switch (type) {
case GGML_TYPE_I1:
return 1;
case GGML_TYPE_I2:
return 2;
case GGML_TYPE_I3:
return 3;
case GGML_TYPE_I4:
return 4;
case GGML_TYPE_Q4_0:
return 4;
case GGML_TYPE_TQ1_0:
return 2;
case GGML_TYPE_TQ2_0:
return 2;
default:
return 0;
}
}
void ggml_tmac_set_n_threads(int n_threads) {
wrapper->set_num_threads(n_threads);
}

View file

@ -8,6 +8,10 @@
#include "ggml.h"
#include "ggml-aarch64.h"
#if defined(GGML_USE_TMAC)
#include "ggml-tmac.h"
#endif
#if defined(_MSC_VER) || defined(__MINGW32__)
#include <malloc.h> // using malloc.h with MSC/MINGW
#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
@ -546,6 +550,30 @@ static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t *
static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc);
static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
[GGML_TYPE_I1] = {
.type_name = "i1",
.blck_size = 8,
.type_size = sizeof(int8_t),
.is_quantized = false,
},
[GGML_TYPE_I2] = {
.type_name = "i2",
.blck_size = 4,
.type_size = sizeof(int8_t),
.is_quantized = false,
},
[GGML_TYPE_I3] = {
.type_name = "i3",
.blck_size = 2,
.type_size = sizeof(int8_t),
.is_quantized = false,
},
[GGML_TYPE_I4] = {
.type_name = "i4",
.blck_size = 2,
.type_size = sizeof(int8_t),
.is_quantized = false,
},
[GGML_TYPE_I8] = {
.type_name = "i8",
.blck_size = 1,
@ -1166,6 +1194,14 @@ size_t ggml_nbytes(const struct ggml_tensor * tensor) {
nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
}
}
#if defined(GGML_USE_TMAC)
if(tensor->type == GGML_TYPE_I1 ||
tensor->type == GGML_TYPE_I2 ||
tensor->type == GGML_TYPE_I3 ||
tensor->type == GGML_TYPE_I4){
nbytes = ggml_tmac_get_nbytes(tensor);
}
#endif
return nbytes;
}
@ -1422,6 +1458,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
} u = {i};
ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16);
}
#if defined(GGML_USE_TMAC)
ggml_tmac_init();
#endif
is_first_call = true;
}

View file

@ -1404,6 +1404,10 @@ class GGMLQuantizationType(IntEnum):
Q4_0_8_8 = 33
TQ1_0 = 34
TQ2_0 = 35
I1 = 36
I2 = 37
I3 = 38
I4 = 39
# TODO: add GGMLFileType from ggml_ftype in ggml.h
@ -1450,6 +1454,7 @@ class LlamaFileType(IntEnum):
MOSTLY_Q4_0_8_8 = 35 # except 1d tensors
MOSTLY_TQ1_0 = 36 # except 1d tensors
MOSTLY_TQ2_0 = 37 # except 1d tensors
MOSTLY_INT_N = 38 # except 1d tensors
GUESSED = 1024 # not specified in the model file
@ -1528,6 +1533,15 @@ GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
GGMLQuantizationType.Q4_0_8_8:(32, 2 + 16),
GGMLQuantizationType.TQ1_0: (256, 2 + 4 * 13),
GGMLQuantizationType.TQ2_0: (256, 2 + 64),
# Currently, we use tricks here
# - The block size doesn't include scales or zero_points as group_size is changeable
# - So the size is slightly smaller than the real size
# - The n_bytes in gguf_reader.py is thus inaccurate
# - During inference, the accurate nbytes info will be known through ggml_tmac_get_nbytes
GGMLQuantizationType.I1: (8, 1),
GGMLQuantizationType.I2: (4, 1),
GGMLQuantizationType.I3: (8, 3),
GGMLQuantizationType.I4: (2, 1),
}

View file

@ -60,6 +60,15 @@ def quantize(data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
return data.astype(np.float16, copy=False)
elif (q := _type_traits.get(qtype)) is not None:
return q.quantize(data)
# Do nothing for I1/2/3/4, as they are already quantized
elif qtype == GGMLQuantizationType.I1:
return data
elif qtype == GGMLQuantizationType.I2:
return data
elif qtype == GGMLQuantizationType.I3:
return data
elif qtype == GGMLQuantizationType.I4:
return data
else:
raise NotImplementedError(f"Quantization for {qtype.name} is not yet implemented")

View file

@ -176,6 +176,7 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // except 1d tensors
LLAMA_FTYPE_MOSTLY_TQ1_0 = 36, // except 1d tensors
LLAMA_FTYPE_MOSTLY_TQ2_0 = 37, // except 1d tensors
LLAMA_FTYPE_MOSTLY_INT_N = 38,
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
};

View file

@ -9,6 +9,10 @@
#include "ggml-backend.h"
#include "ggml-cpp.h"
#ifdef GGML_USE_TMAC
# include "ggml-tmac.h"
#endif
// TODO: replace with ggml API call
#define QK_K 256
@ -4434,6 +4438,10 @@ struct llama_model_loader {
case GGML_TYPE_Q4_0_4_4: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_4; break;
case GGML_TYPE_Q4_0_4_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_8; break;
case GGML_TYPE_Q4_0_8_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_8_8; break;
case GGML_TYPE_I1: ftype = LLAMA_FTYPE_MOSTLY_INT_N; break;
case GGML_TYPE_I2: ftype = LLAMA_FTYPE_MOSTLY_INT_N; break;
case GGML_TYPE_I3: ftype = LLAMA_FTYPE_MOSTLY_INT_N; break;
case GGML_TYPE_I4: ftype = LLAMA_FTYPE_MOSTLY_INT_N; break;
default:
{
LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max));
@ -5032,6 +5040,11 @@ struct llama_model_loader {
}
size_done += n_size;
#if defined(GGML_USE_TMAC)
// Do pre-transformation to reduce first-run latency
ggml_tmac_transform_tensor(cur);
#endif
}
// free temporary resources used for async uploads
@ -5171,6 +5184,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0";
case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1";
case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0";
case LLAMA_FTYPE_MOSTLY_INT_N: return "INT_N";
case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium";
case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small";
case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small";
@ -17183,6 +17197,13 @@ static void llama_graph_compute(
ggml_cgraph * gf,
int n_threads,
ggml_threadpool * threadpool) {
#ifdef GGML_USE_TMAC
#ifdef TMAC_USE_TVM_THREADPOOL
ggml_tmac_set_n_threads(n_threads);
n_threads = 1;
#endif
#endif
if (lctx.backend_cpu != nullptr) {
ggml_backend_cpu_set_threadpool(lctx.backend_cpu, threadpool);
ggml_backend_cpu_set_abort_callback(lctx.backend_cpu, lctx.abort_callback, lctx.abort_callback_data);
@ -18747,6 +18768,13 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
if (params->output_tensor_type < GGML_TYPE_COUNT && strcmp(tensor->name, "output.weight") == 0) {
new_type = params->output_tensor_type;
}
if (tensor->type == GGML_TYPE_I1 ||
tensor->type == GGML_TYPE_I2 ||
tensor->type == GGML_TYPE_I3 ||
tensor->type == GGML_TYPE_I4) {
// no need quantize for iN
new_type = tensor->type;
}
// If we've decided to quantize to the same type the tensor is already
// in then there's nothing to do.