Merge branch 'master' into concedo_experimental

# Conflicts:
#	.github/ISSUE_TEMPLATE/bug.md
#	Makefile
#	README.md
#	ggml-cuda.cu
#	tests/test-grad0.cpp
This commit is contained in:
Concedo 2023-12-25 18:47:21 +08:00
commit cc64f2cad1
10 changed files with 659 additions and 238 deletions

View file

@ -106,6 +106,8 @@ if (LLAMA_CUBLAS)
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt) set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
endif() endif()
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cuda_driver)
if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES) if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
# 52 == lowest CUDA 12 standard # 52 == lowest CUDA 12 standard
# 60 == f16 CUDA intrinsics # 60 == f16 CUDA intrinsics

View file

@ -184,6 +184,8 @@ class Model:
return MixtralModel return MixtralModel
if model_architecture == "PhiForCausalLM": if model_architecture == "PhiForCausalLM":
return Phi2Model return Phi2Model
if model_architecture == "PlamoForCausalLM":
return PlamoModel
return Model return Model
def _is_model_safetensors(self) -> bool: def _is_model_safetensors(self) -> bool:
@ -225,6 +227,8 @@ class Model:
return gguf.MODEL_ARCH.LLAMA return gguf.MODEL_ARCH.LLAMA
if arch == "PhiForCausalLM": if arch == "PhiForCausalLM":
return gguf.MODEL_ARCH.PHI2 return gguf.MODEL_ARCH.PHI2
if arch == "PlamoForCausalLM":
return gguf.MODEL_ARCH.PLAMO
raise NotImplementedError(f'Architecture "{arch}" not supported!') raise NotImplementedError(f'Architecture "{arch}" not supported!')
@ -1002,11 +1006,91 @@ class Phi2Model(Model):
self.gguf_writer.add_add_bos_token(False) self.gguf_writer.add_add_bos_token(False)
class PlamoModel(Model):
def set_vocab(self):
self._set_vocab_sentencepiece()
def set_gguf_parameters(self):
hparams = self.hparams
block_count = hparams["num_hidden_layers"]
self.gguf_writer.add_name("PLaMo")
self.gguf_writer.add_context_length(4096) # not in config.json
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong
self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
def shuffle_attn_q_weight(self, data_torch):
assert data_torch.size() == (5120, 5120)
data_torch = data_torch.reshape(8, 5, 128, 5120)
data_torch = torch.permute(data_torch, (1, 0, 2, 3))
data_torch = torch.reshape(data_torch, (5120, 5120))
return data_torch
def shuffle_attn_output_weight(self, data_torch):
assert data_torch.size() == (5120, 5120)
data_torch = data_torch.reshape(5120, 8, 5, 128)
data_torch = torch.permute(data_torch, (0, 2, 1, 3))
data_torch = torch.reshape(data_torch, (5120, 5120))
return data_torch
def write_tensors(self):
block_count = self.hparams.get("num_layers", self.hparams.get("num_hidden_layers"))
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
for name, data_torch in self.get_tensors():
if "self_attn.rotary_emb.inv_freq" in name:
continue
# map tensor names
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
if new_name is None:
print(f"Can not map tensor {name!r}")
sys.exit()
# shuffle for broadcasting of gqa in ggml_mul_mat
if new_name.endswith("attn_q.weight"):
data_torch = self.shuffle_attn_q_weight(data_torch)
elif new_name.endswith("attn_output.weight"):
data_torch = self.shuffle_attn_output_weight(data_torch)
old_dtype = data_torch.dtype
# convert any unsupported data types to float32
if data_torch.dtype not in (torch.float16, torch.float32):
data_torch = data_torch.to(torch.float32)
data = data_torch.squeeze().numpy()
n_dims = len(data.shape)
data_dtype = data.dtype
# if f32 desired, convert any float16 to float32
if self.ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)
# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
data = data.astype(np.float32)
# if f16 desired, convert any float32 2-dim weight tensors to float16
if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
data = data.astype(np.float16)
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
self.gguf_writer.add_tensor(new_name, data)
###### CONVERSION LOGIC ###### ###### CONVERSION LOGIC ######
def parse_args() -> argparse.Namespace: def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Convert a huggingface model to a GGML compatible file") parser = argparse.ArgumentParser(
description="Convert a huggingface model to a GGML compatible file")
parser.add_argument( parser.add_argument(
"--vocab-only", action="store_true", "--vocab-only", action="store_true",
help="extract only the vocab", help="extract only the vocab",

View file

@ -297,7 +297,7 @@ static void ggml_backend_registry_init(void) {
void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data) { void ggml_backend_register(const char * name, ggml_backend_init_fn init_fn, ggml_backend_buffer_type_t default_buffer_type, void * user_data) {
GGML_ASSERT(ggml_backend_registry_count < GGML_MAX_BACKENDS_REG); GGML_ASSERT(ggml_backend_registry_count < GGML_MAX_BACKENDS_REG);
int id = ggml_backend_registry_count; size_t id = ggml_backend_registry_count;
ggml_backend_registry[id] = (struct ggml_backend_reg) { ggml_backend_registry[id] = (struct ggml_backend_reg) {
/* .name = */ {0}, /* .name = */ {0},
@ -330,6 +330,8 @@ size_t ggml_backend_reg_find_by_name(const char * name) {
return i; return i;
} }
} }
// not found
return SIZE_MAX; return SIZE_MAX;
} }
@ -340,15 +342,15 @@ ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str)
const char * params = strchr(backend_str, ':'); const char * params = strchr(backend_str, ':');
char backend_name[128]; char backend_name[128];
if (params == NULL) { if (params == NULL) {
strcpy(backend_name, backend_str); snprintf(backend_name, sizeof(backend_name), "%s", backend_str);
params = ""; params = "";
} else { } else {
strncpy(backend_name, backend_str, params - backend_str); snprintf(backend_name, sizeof(backend_name), "%.*s", (int)(params - backend_str), backend_str);
backend_name[params - backend_str] = '\0';
params++; params++;
} }
size_t backend_i = ggml_backend_reg_find_by_name(backend_name); size_t backend_i = ggml_backend_reg_find_by_name(backend_name);
if (backend_i == SIZE_MAX) { if (backend_i == SIZE_MAX) {
fprintf(stderr, "%s: backend %s not found\n", __func__, backend_name); fprintf(stderr, "%s: backend %s not found\n", __func__, backend_name);
return NULL; return NULL;
@ -396,18 +398,12 @@ static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
} }
static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
memcpy((char *)tensor->data + offset, data, size); memcpy((char *)tensor->data + offset, data, size);
GGML_UNUSED(buffer); GGML_UNUSED(buffer);
} }
static void ggml_backend_cpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { static void ggml_backend_cpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
memcpy(data, (const char *)tensor->data + offset, size); memcpy(data, (const char *)tensor->data + offset, size);
GGML_UNUSED(buffer); GGML_UNUSED(buffer);

View file

@ -86,17 +86,28 @@
#define cudaStream_t hipStream_t #define cudaStream_t hipStream_t
#define cudaSuccess hipSuccess #define cudaSuccess hipSuccess
#define __trap abort #define __trap abort
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
#else #else
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda.h>
#include <cublas_v2.h> #include <cublas_v2.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
// CUDA 10.2 does not have these macro definitions.
#ifndef CUBLAS_TF32_TENSOR_OP_MATH #if CUDART_VERSION < 11020
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH #define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
#define CUBLAS_COMPUTE_16F CUDA_R_16F #define CUBLAS_COMPUTE_16F CUDA_R_16F
#define CUBLAS_COMPUTE_32F CUDA_R_32F #define CUBLAS_COMPUTE_32F CUDA_R_32F
#define cublasComputeType_t cudaDataType_t #define cublasComputeType_t cudaDataType_t
#endif #endif // CUDART_VERSION < 11020
#endif // defined(GGML_USE_HIPBLAS) #endif // defined(GGML_USE_HIPBLAS)
#include "ggml-cuda.h" #include "ggml-cuda.h"
@ -200,45 +211,45 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
#define CUDA_CHECK(err) \
do { \
cudaError_t err_ = (err); \
if (err_ != cudaSuccess) { \
int id; \
cudaGetDevice(&id); \
fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
cudaGetErrorString(err_)); \
fprintf(stderr, "current device: %d\n", id); \
GGML_ASSERT(!"CUDA error"); \
} \
} while (0)
#if CUDART_VERSION >= 12000 #if CUDART_VERSION >= 12000
#define CUBLAS_CHECK(err) \ static const char * cublas_get_error_str(const cublasStatus_t err) {
do { \ return cublasGetStatusString(err);
cublasStatus_t err_ = (err); \ }
if (err_ != CUBLAS_STATUS_SUCCESS) { \
int id; \
cudaGetDevice(&id); \
fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n", \
err_, __FILE__, __LINE__, cublasGetStatusString(err_)); \
fprintf(stderr, "current device: %d\n", id); \
GGML_ASSERT(!"cuBLAS error"); \
} \
} while (0)
#else #else
#define CUBLAS_CHECK(err) \ static const char * cublas_get_error_str(const cublasStatus_t err) {
do { \ switch (err) {
cublasStatus_t err_ = (err); \ case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS";
if (err_ != CUBLAS_STATUS_SUCCESS) { \ case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED";
int id; \ case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED";
cudaGetDevice(&id); \ case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE";
fprintf(stderr, "\ncuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \ case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH";
fprintf(stderr, "current device: %d\n", id); \ case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR";
GGML_ASSERT(!"cuBLAS error"); \ case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED";
} \ case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR";
} while (0) case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED";
#endif // CUDART_VERSION >= 11 default: return "unknown error";
}
}
#endif // CUDART_VERSION >= 12000
[[noreturn]]
static void ggml_cuda_error(const char * stmt, const char * func, const char * file, const int line, const char * msg) {
fprintf(stderr, "CUDA error: %s: %s\n", stmt, msg);
fprintf(stderr, " in function %s at %s:%d\n", func, file, line);
GGML_ASSERT(!"CUDA error");
}
#define CUDA_CHECK(err) do { auto err_ = (err); if (err_ != cudaSuccess) ggml_cuda_error(#err, __func__, __FILE__, __LINE__, cudaGetErrorString(err_)); } while (0)
#define CUBLAS_CHECK(err) do { auto err_ = (err); if (err_ != CUBLAS_STATUS_SUCCESS) ggml_cuda_error(#err, __func__, __FILE__, __LINE__, cublas_get_error_str(err_)); } while (0)
#if !defined(GGML_USE_HIPBLAS)
static const char * cu_get_error_str(CUresult err) {
const char * err_str;
cuGetErrorString(err, &err_str);
return err_str;
}
#define CU_CHECK(err) do { auto err_ = (err); if (err_ != CUDA_SUCCESS) ggml_cuda_error(#err, __func__, __FILE__, __LINE__, cu_get_error_str(err_)); } while (0)
#endif
#if CUDART_VERSION >= 11100 #if CUDART_VERSION >= 11100
#define GGML_CUDA_ASSUME(x) __builtin_assume(x) #define GGML_CUDA_ASSUME(x) __builtin_assume(x)
@ -516,10 +527,18 @@ inline cudaError_t ggml_cuda_set_device(const int device) {
static int g_device_count = -1; static int g_device_count = -1;
static int g_main_device = 0; static int g_main_device = 0;
static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0}; static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
static bool g_mul_mat_q = false; static bool g_mul_mat_q = false;
struct cuda_device_capabilities {
int cc; // compute capability
bool vmm; // virtual memory support
size_t vmm_granularity; // granularity of virtual memory
};
static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, false, 0} };
static void * g_scratch_buffer = nullptr; static void * g_scratch_buffer = nullptr;
static size_t g_scratch_size = 0; // disabled by default static size_t g_scratch_size = 0; // disabled by default
static size_t g_scratch_offset = 0; static size_t g_scratch_offset = 0;
@ -5876,7 +5895,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
int id; int id;
CUDA_CHECK(cudaGetDevice(&id)); CUDA_CHECK(cudaGetDevice(&id));
const int compute_capability = g_compute_capabilities[id]; const int compute_capability = g_device_caps[id].cc;
int mmq_x, mmq_y, nwarps; int mmq_x, mmq_y, nwarps;
if (compute_capability >= CC_RDNA2) { if (compute_capability >= CC_RDNA2) {
@ -5921,7 +5940,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
int id; int id;
CUDA_CHECK(cudaGetDevice(&id)); CUDA_CHECK(cudaGetDevice(&id));
const int compute_capability = g_compute_capabilities[id]; const int compute_capability = g_device_caps[id].cc;
int mmq_x, mmq_y, nwarps; int mmq_x, mmq_y, nwarps;
if (compute_capability >= CC_RDNA2) { if (compute_capability >= CC_RDNA2) {
@ -5966,7 +5985,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
int id; int id;
CUDA_CHECK(cudaGetDevice(&id)); CUDA_CHECK(cudaGetDevice(&id));
const int compute_capability = g_compute_capabilities[id]; const int compute_capability = g_device_caps[id].cc;
int mmq_x, mmq_y, nwarps; int mmq_x, mmq_y, nwarps;
if (compute_capability >= CC_RDNA2) { if (compute_capability >= CC_RDNA2) {
@ -6011,7 +6030,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
int id; int id;
CUDA_CHECK(cudaGetDevice(&id)); CUDA_CHECK(cudaGetDevice(&id));
const int compute_capability = g_compute_capabilities[id]; const int compute_capability = g_device_caps[id].cc;
int mmq_x, mmq_y, nwarps; int mmq_x, mmq_y, nwarps;
if (compute_capability >= CC_RDNA2) { if (compute_capability >= CC_RDNA2) {
@ -6056,7 +6075,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
int id; int id;
CUDA_CHECK(cudaGetDevice(&id)); CUDA_CHECK(cudaGetDevice(&id));
const int compute_capability = g_compute_capabilities[id]; const int compute_capability = g_device_caps[id].cc;
int mmq_x, mmq_y, nwarps; int mmq_x, mmq_y, nwarps;
if (compute_capability >= CC_RDNA2) { if (compute_capability >= CC_RDNA2) {
@ -6101,7 +6120,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
int id; int id;
CUDA_CHECK(cudaGetDevice(&id)); CUDA_CHECK(cudaGetDevice(&id));
const int compute_capability = g_compute_capabilities[id]; const int compute_capability = g_device_caps[id].cc;
int mmq_x, mmq_y, nwarps; int mmq_x, mmq_y, nwarps;
if (compute_capability >= CC_RDNA2) { if (compute_capability >= CC_RDNA2) {
@ -6148,7 +6167,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
int id; int id;
CUDA_CHECK(cudaGetDevice(&id)); CUDA_CHECK(cudaGetDevice(&id));
const int compute_capability = g_compute_capabilities[id]; const int compute_capability = g_device_caps[id].cc;
int mmq_x, mmq_y, nwarps; int mmq_x, mmq_y, nwarps;
if (compute_capability >= CC_RDNA2) { if (compute_capability >= CC_RDNA2) {
@ -6194,7 +6213,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
int id; int id;
CUDA_CHECK(cudaGetDevice(&id)); CUDA_CHECK(cudaGetDevice(&id));
const int compute_capability = g_compute_capabilities[id]; const int compute_capability = g_device_caps[id].cc;
int mmq_x, mmq_y, nwarps; int mmq_x, mmq_y, nwarps;
if (compute_capability >= CC_RDNA2) { if (compute_capability >= CC_RDNA2) {
@ -6239,7 +6258,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
int id; int id;
CUDA_CHECK(cudaGetDevice(&id)); CUDA_CHECK(cudaGetDevice(&id));
const int compute_capability = g_compute_capabilities[id]; const int compute_capability = g_device_caps[id].cc;
int mmq_x, mmq_y, nwarps; int mmq_x, mmq_y, nwarps;
if (compute_capability >= CC_RDNA2) { if (compute_capability >= CC_RDNA2) {
@ -6284,7 +6303,7 @@ static void ggml_mul_mat_q6_K_q8_1_cuda(
int id; int id;
CUDA_CHECK(cudaGetDevice(&id)); CUDA_CHECK(cudaGetDevice(&id));
const int compute_capability = g_compute_capabilities[id]; const int compute_capability = g_device_caps[id].cc;
int mmq_x, mmq_y, nwarps; int mmq_x, mmq_y, nwarps;
if (compute_capability >= CC_RDNA2) { if (compute_capability >= CC_RDNA2) {
@ -6544,65 +6563,73 @@ struct scoped_spin_lock {
scoped_spin_lock& operator=(const scoped_spin_lock&) = delete; scoped_spin_lock& operator=(const scoped_spin_lock&) = delete;
}; };
static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
// #define DEBUG_CUDA_MALLOC
struct cuda_buffer { struct cuda_buffer {
void * ptr = nullptr; void * ptr = nullptr;
size_t size = 0; size_t size = 0;
}; };
static cuda_buffer g_cuda_buffer_pool[GGML_CUDA_MAX_DEVICES][MAX_CUDA_BUFFERS]; static cuda_buffer g_cuda_buffer_pool[GGML_CUDA_MAX_DEVICES][MAX_CUDA_BUFFERS];
static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT; static size_t g_cuda_pool_size[GGML_CUDA_MAX_DEVICES] = {0};
static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) { static void * ggml_cuda_pool_malloc_leg(size_t size, size_t * actual_size) {
scoped_spin_lock lock(g_cuda_pool_lock); scoped_spin_lock lock(g_cuda_pool_lock);
int id; int id;
CUDA_CHECK(cudaGetDevice(&id)); CUDA_CHECK(cudaGetDevice(&id));
#ifdef DEBUG_CUDA_MALLOC
int best_i = -1; int nnz = 0;
size_t best_size = std::numeric_limits<size_t>::max(); //smallest unused buffer that fits our needs size_t max_size = 0;
int worst_i = -1; #endif
size_t worst_size = 0; //largest unused buffer seen so far size_t best_diff = 1ull << 36;
int ibest = -1;
for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) { for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
cuda_buffer& b = g_cuda_buffer_pool[id][i]; cuda_buffer& b = g_cuda_buffer_pool[id][i];
if (b.size > 0 && b.size >= size && b.size < best_size) if (b.ptr != nullptr) {
{ #ifdef DEBUG_CUDA_MALLOC
best_i = i; ++nnz;
best_size = b.size; if (b.size > max_size) max_size = b.size;
} #endif
if (b.size > 0 && b.size > worst_size) if (b.size >= size) {
{ size_t diff = b.size - size;
worst_i = i; if (diff < best_diff) {
worst_size = b.size; best_diff = diff;
ibest = i;
if (!best_diff) {
void * ptr = b.ptr;
*actual_size = b.size;
b.ptr = nullptr;
b.size = 0;
return ptr;
}
}
}
} }
} }
if(best_i!=-1) //found the smallest buffer that fits our needs if (ibest >= 0) {
{ cuda_buffer& b = g_cuda_buffer_pool[id][ibest];
cuda_buffer& b = g_cuda_buffer_pool[id][best_i];
void * ptr = b.ptr; void * ptr = b.ptr;
*actual_size = b.size; *actual_size = b.size;
b.ptr = nullptr; b.ptr = nullptr;
b.size = 0; b.size = 0;
return ptr; return ptr;
} }
if(worst_i!=-1 && !g_mul_mat_q) //no buffer that fits our needs, resize largest one to save memory (non mmq only)
{
cuda_buffer& b = g_cuda_buffer_pool[id][worst_i];
b.size = 0;
void * ptr = b.ptr;
cudaFree(ptr);
b.ptr = ptr = nullptr;
}
void * ptr; void * ptr;
size_t look_ahead_size = (size_t) (1.05 * size); size_t look_ahead_size = (size_t) (1.05 * size);
look_ahead_size = 256 * ((look_ahead_size + 255)/256); look_ahead_size = 256 * ((look_ahead_size + 255)/256);
CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size)); CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size));
*actual_size = look_ahead_size; *actual_size = look_ahead_size;
g_cuda_pool_size[id] += look_ahead_size;
#ifdef DEBUG_CUDA_MALLOC
fprintf(stderr, "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, id, nnz,
(uint32_t)(max_size/1024/1024), (uint32_t)(g_cuda_pool_size[id]/1024/1024), (uint32_t)(size/1024/1024));
#endif
return ptr; return ptr;
} }
static void ggml_cuda_pool_free(void * ptr, size_t size) { static void ggml_cuda_pool_free_leg(void * ptr, size_t size) {
scoped_spin_lock lock(g_cuda_pool_lock); scoped_spin_lock lock(g_cuda_pool_lock);
int id; int id;
CUDA_CHECK(cudaGetDevice(&id)); CUDA_CHECK(cudaGetDevice(&id));
@ -6617,8 +6644,152 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
} }
fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n"); fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
CUDA_CHECK(cudaFree(ptr)); CUDA_CHECK(cudaFree(ptr));
g_cuda_pool_size[id] -= size;
} }
#if !defined(GGML_USE_HIPBLAS)
// pool with virtual memory
static std::vector<CUmemGenericAllocationHandle> g_cuda_pool_handles[GGML_CUDA_MAX_DEVICES];
static CUdeviceptr g_cuda_pool_addr[GGML_CUDA_MAX_DEVICES] = {0};
static size_t g_cuda_pool_used[GGML_CUDA_MAX_DEVICES] = {0};
static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 36; // 64 GB
static void * ggml_cuda_pool_malloc_vmm(size_t size, size_t * actual_size) {
scoped_spin_lock lock(g_cuda_pool_lock);
int id;
CUDA_CHECK(cudaGetDevice(&id));
// round up the allocation size to the alignment to ensure that all allocations are aligned for all data types
const size_t alignment = 128;
size = alignment * ((size + alignment - 1) / alignment);
size_t avail = g_cuda_pool_size[id] - g_cuda_pool_used[id];
if (size > avail) {
// round up to the next multiple of the granularity
size_t reserve_size = size - avail;
const size_t granularity = g_device_caps[id].vmm_granularity;
reserve_size = granularity * ((reserve_size + granularity - 1) / granularity);
GGML_ASSERT(g_cuda_pool_size[id] + reserve_size <= CUDA_POOL_VMM_MAX_SIZE);
// allocate more physical memory
CUmemAllocationProp prop = {};
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
prop.location.id = id;
CUmemGenericAllocationHandle handle;
CU_CHECK(cuMemCreate(&handle, reserve_size, &prop, 0));
// reserve virtual address space (if not already reserved)
if (g_cuda_pool_addr[id] == 0) {
CU_CHECK(cuMemAddressReserve(&g_cuda_pool_addr[id], CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0));
}
// map at the end of the pool
CU_CHECK(cuMemMap(g_cuda_pool_addr[id] + g_cuda_pool_size[id], reserve_size, 0, handle, 0));
// set access
CUmemAccessDesc access = {};
access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
access.location.id = id;
access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
CU_CHECK(cuMemSetAccess(g_cuda_pool_addr[id] + g_cuda_pool_size[id], reserve_size, &access, 1));
// add to the pool
g_cuda_pool_handles[id].push_back(handle);
g_cuda_pool_size[id] += reserve_size;
//printf("cuda pool[%d]: size increased to %llu MB (reserved %llu MB)\n",
// id, (unsigned long long) (g_cuda_pool_size[id]/1024/1024),
// (unsigned long long) (reserve_size/1024/1024));
}
GGML_ASSERT(g_cuda_pool_addr[id] != 0);
void * ptr = (void *) (g_cuda_pool_addr[id] + g_cuda_pool_used[id]);
*actual_size = size;
g_cuda_pool_used[id] += size;
#ifdef DEBUG_CUDA_MALLOC
printf("cuda pool[%d]: allocated %llu bytes at %llx [%s]\n", id, (unsigned long long) size, ptr);
#endif
return ptr;
}
static void ggml_cuda_pool_free_vmm(void * ptr, size_t size) {
scoped_spin_lock lock(g_cuda_pool_lock);
int id;
CUDA_CHECK(cudaGetDevice(&id));
#ifdef DEBUG_CUDA_MALLOC
printf("cuda pool[%d]: freed %llu bytes at %llx\n", id, (unsigned long long) size, ptr);
#endif
g_cuda_pool_used[id] -= size;
// all deallocations must be in reverse order of the allocations
GGML_ASSERT(ptr == (void *) (g_cuda_pool_addr[id] + g_cuda_pool_used[id]));
}
static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
int id;
CUDA_CHECK(cudaGetDevice(&id));
if (g_device_caps[id].vmm) {
return ggml_cuda_pool_malloc_vmm(size, actual_size);
} else {
return ggml_cuda_pool_malloc_leg(size, actual_size);
}
}
static void ggml_cuda_pool_free(void * ptr, size_t size) {
int id;
CUDA_CHECK(cudaGetDevice(&id));
if (g_device_caps[id].vmm) {
ggml_cuda_pool_free_vmm(ptr, size);
} else {
ggml_cuda_pool_free_leg(ptr, size);
}
}
#else
#define ggml_cuda_pool_malloc ggml_cuda_pool_malloc_leg
#define ggml_cuda_pool_free ggml_cuda_pool_free_leg
#endif // !defined(GGML_USE_HIPBLAS)
template<typename T>
struct cuda_pool_alloc {
T * ptr = nullptr;
size_t actual_size = 0;
// size is in number of elements
T * alloc(size_t size) {
GGML_ASSERT(ptr == nullptr);
ptr = (T *) ggml_cuda_pool_malloc(size * sizeof(T), &this->actual_size);
return ptr;
}
cuda_pool_alloc(size_t size) {
alloc(size);
}
~cuda_pool_alloc() {
if (ptr != nullptr) {
ggml_cuda_pool_free(ptr, actual_size);
}
}
T * get() {
return ptr;
}
cuda_pool_alloc() = default;
cuda_pool_alloc(const cuda_pool_alloc &) = delete;
cuda_pool_alloc(cuda_pool_alloc &&) = delete;
cuda_pool_alloc& operator=(const cuda_pool_alloc &) = delete;
cuda_pool_alloc& operator=(cuda_pool_alloc &&) = delete;
};
static bool g_cublas_loaded = false; static bool g_cublas_loaded = false;
bool ggml_cublas_loaded(void) { bool ggml_cublas_loaded(void) {
@ -6649,16 +6820,33 @@ void ggml_init_cublas() {
// fprintf(stderr, "%s: CUDA_USE_TENSOR_CORES: %s\n", __func__,"maybe"); // fprintf(stderr, "%s: CUDA_USE_TENSOR_CORES: %s\n", __func__,"maybe");
fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, g_device_count); fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, g_device_count);
for (int id = 0; id < g_device_count; ++id) { for (int id = 0; id < g_device_count; ++id) {
int device_vmm = 0;
#if !defined(GGML_USE_HIPBLAS)
CUdevice device;
CU_CHECK(cuDeviceGet(&device, id));
CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
if (device_vmm) {
CUmemAllocationProp alloc_prop = {};
alloc_prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
alloc_prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
alloc_prop.location.id = id;
CU_CHECK(cuMemGetAllocationGranularity(&g_device_caps[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
}
#endif // !defined(GGML_USE_HIPBLAS)
g_device_caps[id].vmm = !!device_vmm;
cudaDeviceProp prop; cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, id)); CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
fprintf(stderr, " Device %d: %s, compute capability %d.%d\n", id, prop.name, prop.major, prop.minor); fprintf(stderr, " Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
g_tensor_split[id] = total_vram; g_tensor_split[id] = total_vram;
total_vram += prop.totalGlobalMem; total_vram += prop.totalGlobalMem;
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
g_compute_capabilities[id] = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD; g_device_caps[id].cc = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD;
#else #else
g_compute_capabilities[id] = 100*prop.major + 10*prop.minor; g_device_caps[id].cc = 100*prop.major + 10*prop.minor;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
} }
for (int id = 0; id < g_device_count; ++id) { for (int id = 0; id < g_device_count; ++id) {
@ -7167,11 +7355,11 @@ static int64_t get_row_rounding(ggml_type type) {
int64_t max_compute_capability = INT_MIN; int64_t max_compute_capability = INT_MIN;
for (int64_t id = 0; id < g_device_count; ++id) { for (int64_t id = 0; id < g_device_count; ++id) {
if (g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) { if (g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
if (min_compute_capability > g_compute_capabilities[id]) { if (min_compute_capability > g_device_caps[id].cc) {
min_compute_capability = g_compute_capabilities[id]; min_compute_capability = g_device_caps[id].cc;
} }
if (max_compute_capability < g_compute_capabilities[id]) { if (max_compute_capability < g_device_caps[id].cc) {
max_compute_capability = g_compute_capabilities[id]; max_compute_capability = g_device_caps[id].cc;
} }
} }
} }
@ -7286,8 +7474,8 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
// on some GPUs it is faster to convert src1 to half and to use half precision intrinsics // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
#ifdef GGML_CUDA_F16 #ifdef GGML_CUDA_F16
size_t ash; cuda_pool_alloc<half> src1_dfloat_a;
dfloat * src1_dfloat = nullptr; // dfloat == half half * src1_dfloat = nullptr; // dfloat == half
bool src1_convert_f16 = bool src1_convert_f16 =
src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
@ -7295,7 +7483,7 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16; src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
if (src1_convert_f16) { if (src1_convert_f16) {
src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash); src1_dfloat = src1_dfloat_a.alloc(ne00);
ggml_cpy_f32_f16_cuda((const char *) src1_ddf_i, (char *) src1_dfloat, ne00, ggml_cpy_f32_f16_cuda((const char *) src1_ddf_i, (char *) src1_dfloat, ne00,
ne00, 1, sizeof(float), 0, 0, ne00, 1, sizeof(float), 0, 0,
ne00, 1, sizeof(half), 0, 0, stream); ne00, 1, sizeof(half), 0, 0, stream);
@ -7343,12 +7531,6 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
break; break;
} }
#ifdef GGML_CUDA_F16
if (src1_convert_f16) {
ggml_cuda_pool_free(src1_dfloat, ash);
}
#endif // GGML_CUDA_F16
(void) src1; (void) src1;
(void) dst; (void) dst;
(void) src1_ddq_i; (void) src1_ddq_i;
@ -7379,33 +7561,30 @@ inline void ggml_cuda_op_mul_mat_cublas(
// ldc == nrows of the matrix that cuBLAS writes into // ldc == nrows of the matrix that cuBLAS writes into
int ldc = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff; int ldc = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;
const int compute_capability = g_compute_capabilities[id]; const int compute_capability = g_device_caps[id].cc;
if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) { if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32 // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
half * src0_as_f16 = nullptr; cuda_pool_alloc<half> src0_as_f16;
size_t src0_as = 0;
if (src0->type != GGML_TYPE_F16) { if (src0->type != GGML_TYPE_F16) {
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type); const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
GGML_ASSERT(to_fp16_cuda != nullptr); GGML_ASSERT(to_fp16_cuda != nullptr);
size_t ne = row_diff*ne00; size_t ne = row_diff*ne00;
src0_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src0_as); src0_as_f16.alloc(ne);
to_fp16_cuda(src0_dd_i, src0_as_f16, ne, stream); to_fp16_cuda(src0_dd_i, src0_as_f16.get(), ne, stream);
} }
const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16; const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16.get();
half * src1_as_f16 = nullptr; cuda_pool_alloc<half> src1_as_f16;
size_t src1_as = 0;
if (src1->type != GGML_TYPE_F16) { if (src1->type != GGML_TYPE_F16) {
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
GGML_ASSERT(to_fp16_cuda != nullptr); GGML_ASSERT(to_fp16_cuda != nullptr);
size_t ne = src1_ncols*ne10; size_t ne = src1_ncols*ne10;
src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as); src1_as_f16.alloc(ne);
to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream); to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
} }
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16; const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
size_t dst_as = 0; cuda_pool_alloc<half> dst_f16(row_diff*src1_ncols);
half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
const half alpha_f16 = 1.0f; const half alpha_f16 = 1.0f;
const half beta_f16 = 0.0f; const half beta_f16 = 0.0f;
@ -7414,36 +7593,25 @@ inline void ggml_cuda_op_mul_mat_cublas(
CUBLAS_CHECK( CUBLAS_CHECK(
cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
row_diff, src1_ncols, ne10, row_diff, src1_ncols, ne10,
&alpha_f16, src0_ptr, CUDA_R_16F, ne00, &alpha_f16, src0_ptr, CUDA_R_16F, ne00,
src1_ptr, CUDA_R_16F, ne10, src1_ptr, CUDA_R_16F, ne10,
&beta_f16, dst_f16, CUDA_R_16F, ldc, &beta_f16, dst_f16.get(), CUDA_R_16F, ldc,
CUBLAS_COMPUTE_16F, CUBLAS_COMPUTE_16F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP));
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream); to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
ggml_cuda_pool_free(dst_f16, dst_as);
if (src0_as != 0) {
ggml_cuda_pool_free(src0_as_f16, src0_as);
}
if (src1_as != 0) {
ggml_cuda_pool_free(src1_as_f16, src1_as);
}
} }
else { else {
float * src0_ddq_as_f32 = nullptr; cuda_pool_alloc<float> src0_ddq_as_f32;
size_t src0_as = 0;
if (src0->type != GGML_TYPE_F32) { if (src0->type != GGML_TYPE_F32) {
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type); const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
GGML_ASSERT(to_fp32_cuda != nullptr); GGML_ASSERT(to_fp32_cuda != nullptr);
src0_ddq_as_f32 = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_as); // NOLINT src0_ddq_as_f32.alloc(row_diff*ne00);
to_fp32_cuda(src0_dd_i, src0_ddq_as_f32, row_diff*ne00, stream); to_fp32_cuda(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);
} }
const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32; const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
const float alpha = 1.0f; const float alpha = 1.0f;
const float beta = 0.0f; const float beta = 0.0f;
@ -7455,10 +7623,6 @@ inline void ggml_cuda_op_mul_mat_cublas(
&alpha, src0_ddf_i, ne00, &alpha, src0_ddf_i, ne00,
src1_ddf_i, ne10, src1_ddf_i, ne10,
&beta, dst_dd_i, ldc)); &beta, dst_dd_i, ldc));
if (src0_as != 0) {
ggml_cuda_pool_free(src0_ddq_as_f32, src0_as);
}
} }
(void) dst; (void) dst;
@ -7750,18 +7914,17 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
float * src1_ddf = nullptr; float * src1_ddf = nullptr;
float * dst_ddf = nullptr; float * dst_ddf = nullptr;
// as = actual size cuda_pool_alloc<float> src0_f;
size_t src0_asf = 0; cuda_pool_alloc<float> src1_f;
size_t src1_asf = 0; cuda_pool_alloc<float> dst_f;
size_t dst_asf = 0;
ggml_cuda_set_device(g_main_device); ggml_cuda_set_device(g_main_device);
const cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
if (src0_on_device) { if (src0_on_device) {
src0_ddf = (float *) src0_extra->data_device[g_main_device]; src0_ddf = (float *) src0_extra->data_device[g_main_device];
} else { } else {
src0_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_asf); src0_ddf = src0_f.alloc(ggml_nelements(src0));
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf, src0, 0, 0, 0, nrows0, main_stream)); CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_ddf, src0, 0, 0, 0, nrows0, main_stream));
} }
@ -7769,14 +7932,14 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
if (src1_on_device) { if (src1_on_device) {
src1_ddf = (float *) src1_extra->data_device[g_main_device]; src1_ddf = (float *) src1_extra->data_device[g_main_device];
} else { } else {
src1_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(src1), &src1_asf); src1_ddf = src1_f.alloc(ggml_nelements(src1));
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf, src1, 0, 0, 0, nrows1, main_stream)); CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src1_ddf, src1, 0, 0, 0, nrows1, main_stream));
} }
} }
if (dst_on_device) { if (dst_on_device) {
dst_ddf = (float *) dst_extra->data_device[g_main_device]; dst_ddf = (float *) dst_extra->data_device[g_main_device];
} else { } else {
dst_ddf = (float *) ggml_cuda_pool_malloc(ggml_nbytes(dst), &dst_asf); dst_ddf = dst_f.alloc(ggml_nelements(dst));
} }
// do the computation // do the computation
@ -7788,16 +7951,6 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream)); CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream));
} }
if (src0_asf > 0) {
ggml_cuda_pool_free(src0_ddf, src0_asf);
}
if (src1_asf > 0) {
ggml_cuda_pool_free(src1_ddf, src1_asf);
}
if (dst_asf > 0) {
ggml_cuda_pool_free(dst_ddf, dst_asf);
}
if (dst->backend == GGML_BACKEND_CPU) { if (dst->backend == GGML_BACKEND_CPU) {
CUDA_CHECK(cudaDeviceSynchronize()); CUDA_CHECK(cudaDeviceSynchronize());
} }
@ -8111,17 +8264,17 @@ static void ggml_cuda_op_mul_mat(
CUDA_CHECK(ggml_cuda_set_device(id)); CUDA_CHECK(ggml_cuda_set_device(id));
// free buffers again when done // free buffers again when done
if (src0_as[id] > 0) { if (dst_as[id] > 0) {
ggml_cuda_pool_free(src0_dd[id], src0_as[id]); ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
}
if (src1_asf[id] > 0) {
ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
} }
if (src1_asq[id] > 0) { if (src1_asq[id] > 0) {
ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]); ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]);
} }
if (dst_as[id] > 0) { if (src1_asf[id] > 0) {
ggml_cuda_pool_free(dst_dd[id], dst_as[id]); ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
}
if (src0_as[id] > 0) {
ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
} }
} }
@ -8374,14 +8527,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
GGML_ASSERT(to_fp16_cuda != nullptr); GGML_ASSERT(to_fp16_cuda != nullptr);
size_t src1_as = 0; cuda_pool_alloc<half> src1_as_f16(ne1);
half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as); to_fp16_cuda(src1_ddf, src1_as_f16.get(), ne1, main_stream);
to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
size_t dst_as = 0; cuda_pool_alloc<half> dst_f16;
char * dst_t;
half * dst_f16 = nullptr;
char * dst_t = nullptr;
cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F; cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
cudaDataType_t cu_data_type = CUDA_R_16F; cudaDataType_t cu_data_type = CUDA_R_16F;
@ -8400,8 +8550,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
const void * beta = &beta_f16; const void * beta = &beta_f16;
if (dst->op_params[0] == GGML_PREC_DEFAULT) { if (dst->op_params[0] == GGML_PREC_DEFAULT) {
dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as); dst_t = (char *) dst_f16.alloc(ne);
dst_t = (char *) dst_f16;
nbd2 /= sizeof(float) / sizeof(half); nbd2 /= sizeof(float) / sizeof(half);
nbd3 /= sizeof(float) / sizeof(half); nbd3 /= sizeof(float) / sizeof(half);
@ -8448,9 +8597,9 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
CUBLAS_CHECK( CUBLAS_CHECK(
cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N, cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10, ne01, ne11, ne10,
alpha, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA alpha, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA
(const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB (const char *) src1_as_f16.get(), CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
beta, ( char *) dst_t, cu_data_type, ne01, dst->nb[2]/sizeof(float), // strideC beta, ( char *) dst_t, cu_data_type, ne01, dst->nb[2]/sizeof(float), // strideC
ne12*ne13, ne12*ne13,
cu_compute_type, cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@ -8458,19 +8607,13 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
// use cublasGemmBatchedEx // use cublasGemmBatchedEx
const int ne23 = ne12*ne13; const int ne23 = ne12*ne13;
const void ** ptrs_src = nullptr; cuda_pool_alloc<const void *> ptrs_src(2*ne23);
void ** ptrs_dst = nullptr; cuda_pool_alloc< void *> ptrs_dst(1*ne23);
size_t ptrs_src_s = 0;
size_t ptrs_dst_s = 0;
ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s);
ptrs_dst = ( void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s);
dim3 block_dims(ne13, ne12); dim3 block_dims(ne13, ne12);
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>( k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
src0_as_f16, src1_as_f16, dst_t, src0_as_f16, src1_as_f16.get(), dst_t,
ptrs_src, ptrs_dst, ptrs_src.get(), ptrs_dst.get(),
ne12, ne13, ne12, ne13,
ne23, ne23,
nb02, nb03, nb02, nb03,
@ -8482,30 +8625,19 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
CUBLAS_CHECK( CUBLAS_CHECK(
cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N, cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10, ne01, ne11, ne10,
alpha, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half), alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
(const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float), (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
beta, ( void **) (ptrs_dst + 0*ne23), cu_data_type, ne01, beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne01,
ne23, ne23,
cu_compute_type, cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP));
if (ptrs_src_s != 0) {
ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
}
if (ptrs_dst_s != 0) {
ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
}
} }
#endif #endif
if (dst->op_params[0] == GGML_PREC_DEFAULT) { if (dst->op_params[0] == GGML_PREC_DEFAULT) {
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream); to_fp32_cuda(dst_f16.get(), dst_ddf, ne, main_stream);
ggml_cuda_pool_free(dst_f16, dst_as);
} }
ggml_cuda_pool_free(src1_as_f16, src1_as);
} }
static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@ -8518,8 +8650,8 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
int64_t min_compute_capability = INT_MAX; int64_t min_compute_capability = INT_MAX;
for (int64_t id = 0; id < g_device_count; ++id) { for (int64_t id = 0; id < g_device_count; ++id) {
if (min_compute_capability > g_compute_capabilities[id] && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) { if (min_compute_capability > g_device_caps[id].cc && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
min_compute_capability = g_compute_capabilities[id]; min_compute_capability = g_device_caps[id].cc;
} }
} }
@ -8835,12 +8967,11 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row); ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row);
} }
} else { } else {
size_t as_src1, as_dst; cuda_pool_alloc<char> src1_contiguous(sizeof(float)*ggml_nelements(src1));
char * src1_contiguous = (char *) ggml_cuda_pool_malloc(sizeof(float)*ggml_nelements(src1), &as_src1); cuda_pool_alloc<char> dst_contiguous(sizeof(float)*ggml_nelements(dst));
char * dst_contiguous = (char *) ggml_cuda_pool_malloc(sizeof(float)*ggml_nelements(dst), &as_dst);
src1_row_extra.data_device[g_main_device] = src1_contiguous; src1_row_extra.data_device[g_main_device] = src1_contiguous.get();
dst_row_extra.data_device[g_main_device] = dst_contiguous; dst_row_extra.data_device[g_main_device] = dst_contiguous.get();
const cudaMemcpyKind src1_kind = src1->backend == GGML_BACKEND_CPU ? const cudaMemcpyKind src1_kind = src1->backend == GGML_BACKEND_CPU ?
cudaMemcpyHostToDevice : cudaMemcpyDeviceToDevice; cudaMemcpyHostToDevice : cudaMemcpyDeviceToDevice;
@ -8860,7 +8991,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
GGML_ASSERT(row_id >= 0 && row_id < n_as); GGML_ASSERT(row_id >= 0 && row_id < n_as);
CUDA_CHECK(cudaMemcpyAsync(src1_contiguous + num_src1_rows*nb11, src1_original + i01*nb11, CUDA_CHECK(cudaMemcpyAsync(src1_contiguous.get() + num_src1_rows*nb11, src1_original + i01*nb11,
nb11, src1_kind, stream)); nb11, src1_kind, stream));
num_src1_rows++; num_src1_rows++;
} }
@ -8892,14 +9023,11 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
GGML_ASSERT(row_id >= 0 && row_id < n_as); GGML_ASSERT(row_id >= 0 && row_id < n_as);
CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous + num_src1_rows*nb1, CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous.get() + num_src1_rows*nb1,
nb1, dst_kind, stream)); nb1, dst_kind, stream));
num_src1_rows++; num_src1_rows++;
} }
} }
ggml_cuda_pool_free(src1_contiguous, as_src1);
ggml_cuda_pool_free(dst_contiguous, as_dst);
} }
if (dst->backend == GGML_BACKEND_CPU) { if (dst->backend == GGML_BACKEND_CPU) {
@ -9674,8 +9802,10 @@ static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buff
static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
void * ptr = ggml_cuda_host_malloc(size); void * ptr = ggml_cuda_host_malloc(size);
if (ptr == nullptr) { if (ptr == nullptr) {
return nullptr; // fallback to cpu buffer
return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
} }
// FIXME: this is a hack to avoid having to implement a new buffer type // FIXME: this is a hack to avoid having to implement a new buffer type

2
ggml.c
View file

@ -19392,7 +19392,7 @@ void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) {
data[j] = ((struct gguf_str *)src->kv[i].value.arr.data)[j].data; data[j] = ((struct gguf_str *)src->kv[i].value.arr.data)[j].data;
} }
gguf_set_arr_str(ctx, src->kv[i].key.data, data, src->kv[i].value.arr.n); gguf_set_arr_str(ctx, src->kv[i].key.data, data, src->kv[i].value.arr.n);
free(data); free((void *)data);
} else if (src->kv[i].value.arr.type == GGUF_TYPE_ARRAY) { } else if (src->kv[i].value.arr.type == GGUF_TYPE_ARRAY) {
GGML_ASSERT(false && "nested arrays not supported"); GGML_ASSERT(false && "nested arrays not supported");
} else { } else {

2
ggml.h
View file

@ -262,6 +262,8 @@
#define GGML_UNREACHABLE() GGML_ASSERT(!"statement should not be reached") #define GGML_UNREACHABLE() GGML_ASSERT(!"statement should not be reached")
#elif defined(__GNUC__) #elif defined(__GNUC__)
#define GGML_UNREACHABLE() __builtin_unreachable() #define GGML_UNREACHABLE() __builtin_unreachable()
#elif defined(_MSC_VER)
#define GGML_UNREACHABLE() __assume(0)
#else #else
#define GGML_UNREACHABLE() ((void) 0) #define GGML_UNREACHABLE() ((void) 0)
#endif #endif

View file

@ -96,6 +96,7 @@ class MODEL_ARCH(IntEnum):
STABLELM = auto() STABLELM = auto()
QWEN = auto() QWEN = auto()
PHI2 = auto() PHI2 = auto()
PLAMO = auto()
class MODEL_TENSOR(IntEnum): class MODEL_TENSOR(IntEnum):
@ -142,6 +143,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.STABLELM: "stablelm", MODEL_ARCH.STABLELM: "stablelm",
MODEL_ARCH.QWEN: "qwen", MODEL_ARCH.QWEN: "qwen",
MODEL_ARCH.PHI2: "phi2", MODEL_ARCH.PHI2: "phi2",
MODEL_ARCH.PLAMO: "plamo",
} }
TENSOR_NAMES: dict[MODEL_TENSOR, str] = { TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@ -349,6 +351,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP, MODEL_TENSOR.FFN_UP,
], ],
MODEL_ARCH.PLAMO: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.GPT2: [ MODEL_ARCH.GPT2: [
# TODO # TODO
], ],

View file

@ -79,6 +79,7 @@ class TensorNameMap:
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon "language_model.encoder.layers.{bid}.input_layernorm", # persimmon
"model.layers.{bid}.ln1", # yi "model.layers.{bid}.ln1", # yi
"transformer.h.{bid}.ln", # phi2 "transformer.h.{bid}.ln", # phi2
"model.layers.layers.{bid}.norm", # plamo
), ),
# Attention norm 2 # Attention norm 2
@ -99,26 +100,29 @@ class TensorNameMap:
# Attention query # Attention query
MODEL_TENSOR.ATTN_Q: ( MODEL_TENSOR.ATTN_Q: (
"model.layers.{bid}.self_attn.q_proj", # llama-hf "model.layers.{bid}.self_attn.q_proj", # llama-hf
"layers.{bid}.attention.wq", # llama-pth "layers.{bid}.attention.wq", # llama-pth
"encoder.layer.{bid}.attention.self.query", # bert "encoder.layer.{bid}.attention.self.query", # bert
"transformer.h.{bid}.attn.q_proj", # gpt-j "transformer.h.{bid}.attn.q_proj", # gpt-j
"model.layers.layers.{bid}.self_attn.q_proj", # plamo
), ),
# Attention key # Attention key
MODEL_TENSOR.ATTN_K: ( MODEL_TENSOR.ATTN_K: (
"model.layers.{bid}.self_attn.k_proj", # llama-hf "model.layers.{bid}.self_attn.k_proj", # llama-hf
"layers.{bid}.attention.wk", # llama-pth "layers.{bid}.attention.wk", # llama-pth
"encoder.layer.{bid}.attention.self.key", # bert "encoder.layer.{bid}.attention.self.key", # bert
"transformer.h.{bid}.attn.k_proj", # gpt-j "transformer.h.{bid}.attn.k_proj", # gpt-j
"model.layers.layers.{bid}.self_attn.k_proj", # plamo
), ),
# Attention value # Attention value
MODEL_TENSOR.ATTN_V: ( MODEL_TENSOR.ATTN_V: (
"model.layers.{bid}.self_attn.v_proj", # llama-hf "model.layers.{bid}.self_attn.v_proj", # llama-hf
"layers.{bid}.attention.wv", # llama-pth "layers.{bid}.attention.wv", # llama-pth
"encoder.layer.{bid}.attention.self.value", # bert "encoder.layer.{bid}.attention.self.value", # bert
"transformer.h.{bid}.attn.v_proj", # gpt-j "transformer.h.{bid}.attn.v_proj", # gpt-j
"model.layers.layers.{bid}.self_attn.v_proj", # plamo
), ),
# Attention output # Attention output
@ -134,12 +138,14 @@ class TensorNameMap:
"transformer.h.{bid}.attn.out_proj", # gpt-j "transformer.h.{bid}.attn.out_proj", # gpt-j
"language_model.encoder.layers.{bid}.self_attention.dense", # persimmon "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
"transformer.h.{bid}.mixer.out_proj", # phi2 "transformer.h.{bid}.mixer.out_proj", # phi2
"model.layers.layers.{bid}.self_attn.o_proj", # plamo
), ),
# Rotary embeddings # Rotary embeddings
MODEL_TENSOR.ATTN_ROT_EMBD: ( MODEL_TENSOR.ATTN_ROT_EMBD: (
"model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf "model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf
"layers.{bid}.attention.inner_attention.rope.freqs", # llama-pth "layers.{bid}.attention.inner_attention.rope.freqs", # llama-pth
"model.layers.layers.{bid}.self_attn.rotary_emb.inv_freq", # plamo
), ),
# Feed-forward norm # Feed-forward norm
@ -174,6 +180,7 @@ class TensorNameMap:
"language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon
"transformer.h.{bid}.mlp.w1", # qwen "transformer.h.{bid}.mlp.w1", # qwen
"transformer.h.{bid}.mlp.fc1", # phi2 "transformer.h.{bid}.mlp.fc1", # phi2
"model.layers.layers.{bid}.mlp.up_proj", # plamo
), ),
MODEL_TENSOR.FFN_UP_EXP: ( MODEL_TENSOR.FFN_UP_EXP: (
@ -186,6 +193,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.gate_proj", # llama-hf refact "model.layers.{bid}.mlp.gate_proj", # llama-hf refact
"layers.{bid}.feed_forward.w1", # llama-pth "layers.{bid}.feed_forward.w1", # llama-pth
"transformer.h.{bid}.mlp.w2", # qwen "transformer.h.{bid}.mlp.w2", # qwen
"model.layers.layers.{bid}.mlp.gate_proj", # plamo
), ),
MODEL_TENSOR.FFN_GATE_EXP: ( MODEL_TENSOR.FFN_GATE_EXP: (
@ -206,6 +214,7 @@ class TensorNameMap:
"transformer.h.{bid}.mlp.fc_out", # gpt-j "transformer.h.{bid}.mlp.fc_out", # gpt-j
"language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
"transformer.h.{bid}.mlp.fc2", # phi2 "transformer.h.{bid}.mlp.fc2", # phi2
"model.layers.layers.{bid}.mlp.down_proj", # plamo
), ),
MODEL_TENSOR.FFN_DOWN_EXP: ( MODEL_TENSOR.FFN_DOWN_EXP: (

View file

@ -394,7 +394,7 @@ maxhordelen = 256
modelbusy = threading.Lock() modelbusy = threading.Lock()
requestsinqueue = 0 requestsinqueue = 0
defaultport = 5001 defaultport = 5001
KcppVersion = "1.53.1" KcppVersion = "1.54"
showdebug = True showdebug = True
showsamplerwarning = True showsamplerwarning = True
showmaxctxwarning = True showmaxctxwarning = True

187
llama.cpp
View file

@ -199,6 +199,7 @@ enum llm_arch {
LLM_ARCH_STABLELM, LLM_ARCH_STABLELM,
LLM_ARCH_QWEN, LLM_ARCH_QWEN,
LLM_ARCH_PHI2, LLM_ARCH_PHI2,
LLM_ARCH_PLAMO,
LLM_ARCH_UNKNOWN, LLM_ARCH_UNKNOWN,
}; };
@ -217,6 +218,7 @@ static std::map<llm_arch, std::string> LLM_ARCH_NAMES = {
{ LLM_ARCH_STABLELM, "stablelm" }, { LLM_ARCH_STABLELM, "stablelm" },
{ LLM_ARCH_QWEN, "qwen" }, { LLM_ARCH_QWEN, "qwen" },
{ LLM_ARCH_PHI2, "phi2" }, { LLM_ARCH_PHI2, "phi2" },
{ LLM_ARCH_PLAMO, "plamo" },
}; };
enum llm_kv { enum llm_kv {
@ -568,6 +570,24 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
}, },
}, },
{
LLM_ARCH_PLAMO,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{ {
LLM_ARCH_UNKNOWN, LLM_ARCH_UNKNOWN,
@ -1286,7 +1306,7 @@ struct llama_hparams {
if (this->rope_finetuned != other.rope_finetuned) return true; if (this->rope_finetuned != other.rope_finetuned) return true;
if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true; if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;
const float EPSILON = 1e-9; const float EPSILON = 1e-9f;
if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true; if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true;
if (!is_float_close(this->f_norm_rms_eps, other.f_norm_rms_eps, EPSILON)) return true; if (!is_float_close(this->f_norm_rms_eps, other.f_norm_rms_eps, EPSILON)) return true;
@ -2760,6 +2780,15 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN; default: model.type = e_model::MODEL_UNKNOWN;
} }
} break; } break;
case LLM_ARCH_PLAMO:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 40: model.type = e_model::MODEL_13B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
default: (void)0; default: (void)0;
} }
@ -3659,6 +3688,51 @@ static bool llm_load_tensors(
layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend); layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend);
} }
} break; } break;
case LLM_ARCH_PLAMO:
{
model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
// output
{
ggml_backend_type backend_norm;
ggml_backend_type backend_output;
if (n_gpu_layers > int(n_layer)) {
backend_norm = llama_backend_offload;
backend_output = llama_backend_offload_split;
} else {
backend_norm = GGML_BACKEND_CPU;
backend_output = GGML_BACKEND_CPU;
}
model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm);
model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output);
}
const uint32_t n_ff = hparams.n_ff;
const int i_gpu_start = n_layer - n_gpu_layers;
model.layers.resize(n_layer);
for (uint32_t i = 0; i < n_layer; ++i) {
const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT
const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT
auto & layer = model.layers[i];
layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
layer.wq = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, backend_split);
layer.wk = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, backend_split);
layer.wv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, backend_split);
layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split);
layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split);
layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
}
} break;
default: default:
throw std::runtime_error("unknown architecture"); throw std::runtime_error("unknown architecture");
} }
@ -5584,6 +5658,109 @@ struct llm_build_context {
return gf; return gf;
} }
struct ggml_cgraph * build_plamo() {
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
cb(inpL, "inp_embd", -1);
// inp_pos - contains the positions
struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
cb(inp_pos, "inp_pos", -1);
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
cb(KQ_mask, "KQ_mask", -1);
// shift the entire K-cache if needed
if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
}
for (int il = 0; il < n_layer; ++il) {
// norm
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
struct ggml_tensor * attention_norm = cur;
// self-attention
{
// compute Q and K and RoPE them
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
n_embd_head, 2, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
n_embd_head, 2, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Kcur, "Kcur", il);
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
cur = llm_build_kqv(ctx0, model, hparams, kv_self,
model.layers[il].wo, NULL,
Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}
struct ggml_tensor * sa_out = cur;
cur = attention_norm;
// feed-forward network
{
cur = llm_build_ffn(ctx0, cur,
model.layers[il].ffn_up, NULL,
model.layers[il].ffn_gate, NULL,
model.layers[il].ffn_down, NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il);
}
cur = ggml_add(ctx0, cur, sa_out);
cb(cur, "l_out", il);
cur = ggml_add(ctx0, cur, inpL);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = llm_build_norm(ctx0, cur, hparams,
model.output_norm, NULL,
LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1);
// lm_head
cur = ggml_mul_mat(ctx0, model.output, cur);
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}
}; };
// //
@ -6094,6 +6271,10 @@ static struct ggml_cgraph * llama_build_graph(
{ {
result = llm.build_phi2(); result = llm.build_phi2();
} break; } break;
case LLM_ARCH_PLAMO:
{
result = llm.build_plamo();
} break;
default: default:
GGML_ASSERT(false); GGML_ASSERT(false);
} }
@ -10580,7 +10761,7 @@ int llama_token_to_piece(const struct llama_model * model, llama_token token, ch
std::string result = model->vocab.id_to_token[token].text; std::string result = model->vocab.id_to_token[token].text;
llama_unescape_whitespace(result); llama_unescape_whitespace(result);
if (length < (int) result.length()) { if (length < (int) result.length()) {
return -result.length(); return -(int) result.length();
} }
memcpy(buf, result.c_str(), result.length()); memcpy(buf, result.c_str(), result.length());
return result.length(); return result.length();
@ -10610,7 +10791,7 @@ int llama_token_to_piece(const struct llama_model * model, llama_token token, ch
std::string result = model->vocab.id_to_token[token].text; std::string result = model->vocab.id_to_token[token].text;
result = llama_decode_text(result); result = llama_decode_text(result);
if (length < (int) result.length()) { if (length < (int) result.length()) {
return -result.length(); return -(int) result.length();
} }
memcpy(buf, result.c_str(), result.length()); memcpy(buf, result.c_str(), result.length());
return result.length(); return result.length();