Compare commits

8 commits

Author SHA1 Message Date
Georgi Gerganov      a3e6d62283   cuda : alternative q4_q8 kernel                      2023-05-12 17:02:39 +03:00
JohannesGaessler     e7b9d97bae   More int mult, less float mult, worse performance    2023-05-12 09:11:47 +02:00
JohannesGaessler     d882d1c2fe   Performance no longer terrible                       2023-05-11 23:27:06 +02:00
JohannesGaessler     4b12881329   WAKE ME UP                                           2023-05-11 22:47:38 +02:00
JohannesGaessler     8a9d7ce624   fixup! Store layers in VRAM                          2023-05-11 07:05:52 +02:00
JohannesGaessler     3ed4588e22   Store layers in VRAM                                 2023-05-09 11:05:58 +02:00
JohannesGaessler     d052a0ed4c   Faster than CPU without 80% runtime memcpy           2023-05-09 09:47:55 +02:00
JohannesGaessler     229aa1f504   Works but slower than CPU                            2023-05-09 09:47:55 +02:00
8 changed files with 304 additions and 29 deletions


@@ -271,6 +271,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.use_color = true;
} else if (arg == "--mlock") {
params.use_mlock = true;
} else if (arg == "--gpu_layers") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.gpu_layers = std::stoi(argv[i]);
} else if (arg == "--no-mmap") {
params.use_mmap = false;
} else if (arg == "--mtest") {
@@ -406,6 +412,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
if (llama_mmap_supported()) {
fprintf(stderr, " --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
}
fprintf(stderr, " --gpu_layers number of layers to store in VRAM\n");
fprintf(stderr, " --mtest compute maximum memory usage\n");
fprintf(stderr, " --verbose-prompt print prompt before generation\n");
fprintf(stderr, " --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
@@ -454,6 +461,7 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
lparams.f16_kv = params.memory_f16;
lparams.use_mmap = params.use_mmap;
lparams.use_mlock = params.use_mlock;
lparams.gpu_layers = params.gpu_layers;
lparams.logits_all = params.perplexity;
lparams.embedding = params.embedding;
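
Taken together, the hunks above plumb a single new integer option from the command line into the context parameters: --gpu_layers is parsed into params.gpu_layers and then copied into lparams.gpu_layers. A minimal sketch of that flow (a fragment; the argv values are illustrative and not part of the diff):

    // sketch only: assumes the examples' "common.h" and "llama.h" are included
    gpt_params params;
    const char * demo_argv[] = {"main", "--gpu_layers", "20"};
    if (gpt_params_parse(3, (char **) demo_argv, params)) {
        llama_context_params lparams = llama_context_default_params();
        lparams.gpu_layers = params.gpu_layers; // the same assignment llama_init_from_gpt_params performs
    }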


@@ -68,6 +68,7 @@ struct gpt_params {
bool perplexity = false; // compute perplexity over the prompt
bool use_mmap = true; // use mmap for faster loads
bool use_mlock = false; // use mlock to keep model in memory
int gpu_layers = 0; // number of layers to store in VRAM
bool mem_test = false; // compute maximum memory usage
bool verbose_prompt = false; // print prompt tokens before generation
};


@@ -225,6 +225,141 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
}
}
template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const void * vx, const void * vy, float * dst, const int ncols) {
const block_q4_0 * x = (const block_q4_0 *) vx;
const block_q8_0 * y = (const block_q8_0 *) vy;
const int row = blockIdx.x;
const int tid = threadIdx.x;
__shared__ float tmp[block_size]; // separate sum for each thread
tmp[tid] = 0;
for (int i = 0; i < ncols/block_size; i += 4) {
const int col = i*block_size + 4*tid;
// dequantize
const float d0 = x[(row*ncols + col)/QK4_0].d;
const float d1 = y[col/QK8_0].d;
const uint8_t * p0 = x[(row*ncols + col)/QK4_0].qs;
const int8_t * p1 = y[col/QK8_0].qs;
const uint8_t vui00 = p0[((row*ncols + col)%QK4_0)/2];
const uint8_t vui01 = p0[((row*ncols + col + 2)%QK4_0)/2];
const int vi10 = p1[(col + 0)%QK8_0];
const int vi11 = p1[(col + 1)%QK8_0];
const int vi12 = p1[(col + 2)%QK8_0];
const int vi13 = p1[(col + 3)%QK8_0];
const int vi00 = vui00 & 0xF;
const int vi01 = vui00 >> 4;
const int vi02 = vui01 & 0xF;
const int vi03 = vui01 >> 4;
// matrix multiplication
const int sumi = (vi00 - 8)*vi10 + (vi01 - 8)*vi11 + (vi02 - 8)*vi12 + (vi03 - 8)*vi13;
tmp[tid] += sumi*d0*d1;
}
// sum up partial sums and write back result
for (int s=block_size/2; s>0; s>>=1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
__syncthreads();
}
if (tid == 0) {
dst[row] = tmp[0];
}
}
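
The kernel above gives each thread its own partial sum in shared memory and then folds the sums together with a tree reduction. Below is a standalone sketch of that reduction pattern, independent of the quantization details (the kernel name and launch shape are illustrative, not part of the diff); the sketch places a __syncthreads() before the first folding step because each thread reads a partial sum written by another thread:

    // launch as: row_sum<64><<<nrows, 64, 0, stream>>>(x, dst, ncols);
    template <int BLOCK_SIZE> static __global__ void row_sum(const float * x, float * dst, const int ncols) {
        __shared__ float tmp[BLOCK_SIZE];
        const int row = blockIdx.x;
        const int tid = threadIdx.x;
        float sum = 0.0f;
        for (int col = tid; col < ncols; col += BLOCK_SIZE) {
            sum += x[row*ncols + col]; // each thread accumulates a strided slice of the row
        }
        tmp[tid] = sum;
        __syncthreads(); // make all partial sums visible before folding
        for (int s = BLOCK_SIZE/2; s > 0; s >>= 1) {
            if (tid < s) {
                tmp[tid] += tmp[tid + s];
            }
            __syncthreads();
        }
        if (tid == 0) {
            dst[row] = tmp[0]; // one result per row, i.e. per block
        }
    }
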
template <int NT, int NR> static __global__ void dequantize_mul_mat_q4_0_test(const void * vx, const void * vy, float * dst, const int ncols, const int nrows) {
const block_q4_0 * x = (const block_q4_0 *) vx;
const block_q8_0 * y = (const block_q8_0 *) vy;
const int bid = blockIdx.x;
const int tid = threadIdx.x;
__shared__ float tmp[NR][NT];
for (int i = 0; i < NR; ++i) {
tmp[i][tid] = 0.0f;
}
const int nbc = (ncols + 16*NT - 1)/(16*NT);
const int nbm = ncols/QK8_0;
uint64_t xa0;
uint64_t xa1;
const int8_t * xb0 = (const int8_t *) &xa0;
const int8_t * xb1 = (const int8_t *) &xa1;
for (int ibc = 0; ibc < nbc; ++ibc) {
const int iyb = (ibc*(16*NT) + 16*tid)/QK8_0;
const int iyq = (ibc*(16*NT) + 16*tid)%QK8_0;
if (iyb >= nbm) {
continue;
}
const int8_t * yb = (const int8_t *) &y[iyb].qs[iyq];
const float dy = y[iyb].d;
for (int ibr = 0; ibr < NR; ++ibr) {
const int ir = bid*NR + ibr;
if (ir >= nrows) {
continue;
}
// block offset
const int ixo = (ir*ncols)/QK4_0 + iyb;
memcpy(&xa0, &x[ixo].qs[iyq/2 + 0], sizeof(uint64_t));
xa1 = xa0;
xa0 = (xa0 ) & 0x0F0F0F0F0F0F0F0F;
xa1 = (xa1 >> 4) & 0x0F0F0F0F0F0F0F0F;
const float dx = x[ixo].d;
// the (int) cast is probably unnecessary, but just to make sure the result is accumulated in 32 bits
tmp[ibr][tid] += (
((int)(xb0[0] - 8))*yb[0] + ((int)(xb1[0] - 8))*yb[1] +
((int)(xb0[1] - 8))*yb[2] + ((int)(xb1[1] - 8))*yb[3] +
((int)(xb0[2] - 8))*yb[4] + ((int)(xb1[2] - 8))*yb[5] +
((int)(xb0[3] - 8))*yb[6] + ((int)(xb1[3] - 8))*yb[7] +
((int)(xb0[4] - 8))*yb[8] + ((int)(xb1[4] - 8))*yb[9] +
((int)(xb0[5] - 8))*yb[10] + ((int)(xb1[5] - 8))*yb[11] +
((int)(xb0[6] - 8))*yb[12] + ((int)(xb1[6] - 8))*yb[13] +
((int)(xb0[7] - 8))*yb[14] + ((int)(xb1[7] - 8))*yb[15]
)*dx*dy;
}
}
// reduce
__syncthreads();
for (int s = NT/2; s > 0; s >>= 1) {
if (tid < s) {
for (int ibr = 0; ibr < NR; ++ibr) {
tmp[ibr][tid] += tmp[ibr][tid + s];
}
}
__syncthreads();
}
if (tid == 0) {
for (int ibr = 0; ibr < NR; ++ibr) {
const int ir = bid*NR + ibr;
if (ir < nrows) {
dst[ir] = tmp[ibr][0];
}
}
}
}
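
dequantize_mul_mat_q4_0_test loads eight quantized bytes at a time into a uint64_t and splits them into low and high nibbles with a single mask and shift, instead of unpacking byte by byte as the kernel above does. A small host-side sketch of the same unpacking (the values and names are illustrative, not part of the diff):

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    int main() {
        const uint8_t qs[8] = {0x21, 0x43, 0x65, 0x87, 0xA9, 0xCB, 0xED, 0x0F}; // 8 bytes, two 4-bit values each
        uint64_t lo, hi;
        memcpy(&lo, qs, sizeof(lo));
        hi = (lo >> 4) & 0x0F0F0F0F0F0F0F0Full; // high nibble of every byte
        lo = (lo     ) & 0x0F0F0F0F0F0F0F0Full; // low nibble of every byte
        const int8_t * l = (const int8_t *) &lo;
        const int8_t * h = (const int8_t *) &hi;
        for (int i = 0; i < 8; ++i) {
            printf("byte %d -> %d %d\n", i, l[i] - 8, h[i] - 8); // q4_0 stores values with an offset of 8
        }
        return 0;
    }
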
static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK4_0;
dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
@@ -255,6 +390,28 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStre
dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
}
static void dequantize_mul_mat_q4_0_cuda(const void * vx, const void * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
// static int block_size = -1;
// if (block_size == -1) {
// int min_grid_size, max_block_size = 1;
// CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &max_block_size, dequantize_mul_mat_q4_0<256>, 0, 0));
// max_block_size = min(max_block_size, GGML_CUDA_MAX_BLOCK_SIZE);
// block_size = 1;
// while (block_size*2 <= max_block_size && block_size*2 % ncols == 0) {
// block_size *= 2;
// }
// }
// dequantize_mul_mat_q4_0<<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
//const int block_size = 32;
//GGML_ASSERT(ncols % block_size == 0);
//dequantize_mul_mat_q4_0<block_size><<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
const int NR = 1; // unroll rows (seems to not help)
const int NT = 64; // number of threads per row
dequantize_mul_mat_q4_0_test<NT, NR><<<(nrows + NR - 1)/NR, NT, 0, stream>>>(vx, y, dst, ncols, nrows);
}
// TODO: optimize
static __global__ void convert_fp16_to_fp32(const void * vx, float * y) {
const half * x = (const half *) vx;
@@ -290,7 +447,7 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
}
// buffer pool for cuda
#define MAX_CUDA_BUFFERS 16
#define MAX_CUDA_BUFFERS 256
struct scoped_spin_lock {
std::atomic_flag& lock;
@@ -424,6 +581,34 @@ static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor
}
}
static cudaError_t ggml_cuda_h2d_tensor_2d_hack(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream, void * wdata) {
const uint64_t ne0 = src->ne[0];
const uint64_t ne1 = src->ne[1];
const uint64_t nb0 = src->nb[0];
const uint64_t nb1 = src->nb[1];
const uint64_t nb2 = src->nb[2];
const uint64_t nb3 = src->nb[3];
const enum ggml_type type = src->type;
const size_t ts = ggml_type_size(type);
const size_t bs = ggml_blck_size(type);
const void * x = (const void *) ((const char *) wdata + i2*nb2 + i3*nb3);
if (nb0 == ts && nb1 == ts*ne0/bs) {
return cudaMemcpyAsync(dst, x, ne1*nb1, cudaMemcpyHostToDevice, stream);
} else if (nb0 == ts) {
return cudaMemcpy2DAsync(dst, ts*ne0/bs, x, nb1, ts*ne0/bs, ne1, cudaMemcpyHostToDevice, stream);
} else {
for (uint64_t i1 = 0; i1 < ne1; i1++) {
const void * rx = (const void *) ((const char *) x + i1*nb1);
void * rd = (void *) ((char *) dst + i1*ts*ne0/bs);
// pretend the row is a matrix with cols=1
cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyHostToDevice, stream);
if (r != cudaSuccess) return r;
}
return cudaSuccess;
}
}
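
ggml_cuda_h2d_tensor_2d_hack mirrors ggml_cuda_h2d_tensor_2d but reads from the caller-supplied wdata buffer, which in this change holds the q8_0-quantized src1, instead of from src->data. Its middle branch handles rows that are contiguous internally but padded between rows by issuing one strided 2D copy. A standalone sketch of that case (the helper name and parameters are illustrative, not part of the diff):

    #include <cuda_runtime.h>

    // copy `height` rows of `width` bytes from a host buffer whose rows are `spitch` bytes apart
    // into a densely packed device buffer (destination pitch == width)
    static cudaError_t copy_padded_rows(void * dst, const void * src,
                                        size_t width, size_t height, size_t spitch,
                                        cudaStream_t stream) {
        return cudaMemcpy2DAsync(dst, width, src, spitch, width, height,
                                 cudaMemcpyHostToDevice, stream);
    }
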
static void ggml_cuda_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
@@ -575,7 +760,7 @@ static void ggml_cuda_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor *
ggml_cuda_pool_free(d_D, d_size);
}
static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata) {
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
@@ -597,7 +782,10 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);
size_t x_size, y_size, d_size, q_size;
float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
float * d_X;
if (ne11 > 1) {
d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
}
float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
char * d_Q = (char *) ggml_cuda_pool_malloc(n_mm * q_sz, &q_size);
@@ -612,13 +800,35 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
cudaStream_t cudaStream2 = g_cudaStreams2[i % GGML_CUDA_MAX_STREAMS];
cudaEvent_t cudaEvent = g_cudaEvents[i % GGML_CUDA_MAX_EVENTS];
float * c_X = d_X + i * x_ne;
float * c_Y = d_Y + i * y_ne;
float * c_D = d_D + i * d_ne;
char * c_Q = d_Q + i * q_sz;
// copy src0 and convert to fp32 on device
// copy src0 to device if necessary
if (src0->backend == GGML_BACKEND_CPU) {
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
} else if (src0->backend == GGML_BACKEND_CUDA) {
c_Q = ((char *) src0->data) + i * q_sz;
} else {
GGML_ASSERT(false);
}
if (ne11 == 1) {
CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
// copy src1 to device
CUDA_CHECK(ggml_cuda_h2d_tensor_2d_hack(c_Y, src1, i03, i02, cudaStream, wdata));
// wait for data
CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
// compute
dequantize_mul_mat_q4_0_cuda(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
CUDA_CHECK(cudaGetLastError());
} else {
float * c_X = d_X + i * x_ne;
// convert src0 to fp32 on device
to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
CUDA_CHECK(cudaGetLastError());
CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
@@ -637,6 +847,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
&alpha, c_X, ne00,
c_Y, ne10,
&beta, c_D, ne01));
}
// copy dst to host
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
@@ -645,7 +856,9 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
}
CUDA_CHECK(cudaDeviceSynchronize());
if (ne11 > 1) {
ggml_cuda_pool_free(d_X, x_size);
}
ggml_cuda_pool_free(d_Y, y_size);
ggml_cuda_pool_free(d_D, d_size);
ggml_cuda_pool_free(d_Q, q_size);
@@ -661,8 +874,7 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
src1->type == GGML_TYPE_F32 &&
dst->type == GGML_TYPE_F32 &&
(ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_CUDA)) {
return true;
}
@@ -695,11 +907,11 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
ggml_cuda_mul_mat_f16(src0, src1, dst, wdata, wsize);
}
else {
ggml_cuda_mul_mat_q_f32(src0, src1, dst);
ggml_cuda_mul_mat_q_f32(src0, src1, dst, wdata);
}
}
else if (ggml_is_quantized(src0->type)) {
ggml_cuda_mul_mat_q_f32(src0, src1, dst);
ggml_cuda_mul_mat_q_f32(src0, src1, dst, wdata);
}
else {
GGML_ASSERT(false);
@@ -714,3 +926,25 @@ size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct
return 0;
}
}
void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
const int64_t ne0 = tensor->ne[0];
const int64_t ne1 = tensor->ne[1];
const int64_t ne2 = tensor->ne[2];
const int64_t ne3 = tensor->ne[3];
const ggml_type type = tensor->type;
const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type);
size_t q_size;
char * d_Q = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);
cudaStream_t cudaStream2 = g_cudaStreams2[0];
// copy tensor to device
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, tensor, 0, 0, cudaStream2));
CUDA_CHECK(cudaDeviceSynchronize());
tensor->data = d_Q;
tensor->backend = GGML_BACKEND_CUDA;
}
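
ggml_cuda_transform_tensor uploads a tensor's data once and flips its backend flag, so later multiplications take the src0->backend == GGML_BACKEND_CUDA branch above and skip the per-evaluation weight upload. A hedged usage sketch (the helper name is illustrative, not part of the diff):

    // upload a set of already-loaded weight tensors to VRAM
    static void offload_weights(struct ggml_tensor ** tensors, int n) {
        for (int i = 0; i < n; ++i) {
            // replaces tensors[i]->data with a device buffer and sets backend = GGML_BACKEND_CUDA;
            // the original host copy is left untouched
            ggml_cuda_transform_tensor(tensors[i]);
        }
    }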


@@ -14,6 +14,8 @@ void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
void * ggml_cuda_host_malloc(size_t size);
void ggml_cuda_host_free(void * ptr);
void ggml_cuda_transform_tensor(struct ggml_tensor * tensor);
#ifdef __cplusplus
}
#endif

ggml.c

@@ -4711,6 +4711,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
*result = (struct ggml_tensor) {
/*.type =*/ type,
/*.backend =*/ GGML_BACKEND_CPU,
/*.n_dims =*/ n_dims,
/*.ne =*/ { 1, 1, 1, 1 },
/*.nb =*/ { 0, 0, 0, 0 },
@@ -8810,6 +8811,10 @@ static void ggml_compute_forward_mul_mat_q_f32(
#if defined(GGML_USE_CUBLAS)
if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
if (ne11 == 1 && ne12 == 1 && ne13 == 1) {
char * wdata = params->wdata;
quantize_row_q_dot((float *)((char *) src1->data), (void *) wdata, ne10);
}
ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
}
return;

ggml.h

@@ -243,6 +243,11 @@ extern "C" {
GGML_TYPE_COUNT,
};
enum ggml_backend {
GGML_BACKEND_CPU = 0,
GGML_BACKEND_CUDA = 1,
};
// model file types
enum ggml_ftype {
GGML_FTYPE_UNKNOWN = -1,
@@ -323,6 +328,7 @@ extern "C" {
// n-dimensional tensor
struct ggml_tensor {
enum ggml_type type;
enum ggml_backend backend;
int n_dims;
int64_t ne[GGML_MAX_DIMS]; // number of elements
@@ -353,7 +359,7 @@ extern "C" {
char name[32];
char padding[8]; // TODO: remove and add padding to name?
char padding[9]; // TODO: remove and add padding to name?
};
// computation graph


@@ -9,6 +9,9 @@
#include "llama.h"
#include "ggml.h"
#ifdef GGML_USE_CUBLAS
#include "ggml-cuda.h"
#endif
#include <array>
#include <ctime>
@@ -815,6 +818,7 @@ struct llama_context_params llama_context_default_params() {
/*.vocab_only =*/ false,
/*.use_mmap =*/ true,
/*.use_mlock =*/ false,
/*.gpu_layers =*/ 0,
/*.embedding =*/ false,
/*.progress_callback =*/ nullptr,
/*.progress_callback_user_data =*/ nullptr,
@@ -877,6 +881,7 @@ static void llama_model_load_internal(
ggml_type memory_type,
bool use_mmap,
bool use_mlock,
int gpu_layers,
bool vocab_only,
llama_progress_callback progress_callback,
void * progress_callback_user_data) {
@@ -1011,6 +1016,18 @@ static void llama_model_load_internal(
ml->load_all_data(progress_callback, progress_callback_user_data, use_mlock ? &lctx.model.mlock_mmap : NULL);
model.mapping = std::move(ml->mapping);
#ifdef GGML_USE_CUBLAS
for (int i = 0; i < std::min(gpu_layers, int(hparams.n_layer)); ++i) {
auto & layer = model.layers[i];
ggml_cuda_transform_tensor(layer.wq);
ggml_cuda_transform_tensor(layer.wk);
ggml_cuda_transform_tensor(layer.wv);
ggml_cuda_transform_tensor(layer.wo);
ggml_cuda_transform_tensor(layer.w1);
ggml_cuda_transform_tensor(layer.w2);
ggml_cuda_transform_tensor(layer.w3);
}
#endif
// loading time will be recalculate after the first eval, so
// we take page faults deferred by mmap() into consideration
@@ -1024,11 +1041,12 @@ static bool llama_model_load(
ggml_type memory_type,
bool use_mmap,
bool use_mlock,
int gpu_layers,
bool vocab_only,
llama_progress_callback progress_callback,
void *progress_callback_user_data) {
try {
llama_model_load_internal(fname, lctx, n_ctx, memory_type, use_mmap, use_mlock,
llama_model_load_internal(fname, lctx, n_ctx, memory_type, use_mmap, use_mlock, gpu_layers,
vocab_only, progress_callback, progress_callback_user_data);
return true;
} catch (const std::string & err) {
@@ -2088,7 +2106,7 @@ struct llama_context * llama_init_from_file(
ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
if (!llama_model_load(path_model, *ctx, params.n_ctx, memory_type,
params.use_mmap, params.use_mlock, params.vocab_only,
params.use_mmap, params.use_mlock, params.gpu_layers, params.vocab_only,
params.progress_callback, params.progress_callback_user_data)) {
fprintf(stderr, "%s: failed to load model\n", __func__);
llama_free(ctx);
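
On the public API side the offload is requested through the new gpu_layers field of llama_context_params. A minimal usage sketch (a fragment; the model path and layer count are illustrative, not part of the diff):

    // sketch only: assumes <cstdio> and "llama.h" are included
    struct llama_context_params lparams = llama_context_default_params();
    lparams.gpu_layers = 20; // store the first 20 layers in VRAM; 0 (the default) keeps everything on the CPU
    struct llama_context * ctx = llama_init_from_file("models/7B/ggml-model-q4_0.bin", lparams);
    if (ctx == NULL) {
        fprintf(stderr, "failed to load model\n");
    }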


@@ -63,6 +63,7 @@ extern "C" {
bool vocab_only; // only load the vocabulary, no weights
bool use_mmap; // use mmap if possible
bool use_mlock; // force system to keep model in RAM
int gpu_layers; // number of layers to store in VRAM
bool embedding; // embedding mode only
// called with a progress value between 0 and 1, pass NULL to disable