cuda : loading models directly into VRAM, norm calculation on GPU, broadcasting for ggml_mul (#1483)
* Broadcasting for ggml_mul
* CUDA kernel for ggml_mul, norms in VRAM
* GPU weights not in RAM, direct loading with cuFile
* fixup! GPU weights not in RAM, direct loading with cuFile
* fixup! GPU weights not in RAM, direct loading with cuFile
* define default model path once, sync path with readme (#1366)
* ~7% faster Q5_1 AVX2 code (#1477)
* convert.py: Support models which are stored in a single pytorch_model.bin (#1469)
  * Support models in a single pytorch_model.bin
  * Remove spurious line with typo
* benchmark-matmul: Print the average of the test results (#1490)
* Remove unused n_parts parameter (#1509)
* Fixes #1511 lambda issue for w64devkit (mingw) (#1513)
  * Fix for w64devkit and mingw
* make kv_f16 the default for api users (#1517)
* minor : fix compile warnings
* readme : adds WizardLM to the list of supported models (#1485)
* main : make reverse prompt option act as a stop token in non-interactive mode (#1032)
  * Make reverse prompt option act as a stop token in non-interactive scenarios
  * Making requested review changes
  * Update gpt_params_parse and fix a merge error
  * Revert "Update gpt_params_parse and fix a merge error" (reverts commit 2bb2ff1748)
  * Update gpt_params_parse and fix a merge error, take 2
* examples : add persistent chat (#1495)
  * examples : fix whitespace
* tests : add missing header
* ggml : use F16 instead of F32 in Q4_0, Q4_1 and Q8_0 (#1508)
  * llama : bump LLAMA_FILE_VERSION to 3
  * cuda : update Q4 and Q8 dequantize kernels
  * ggml : fix AVX dot products
  * readme : update performance table + hot topics
  * ggml : fix scalar implementation of Q4_1 dot
* llama : fix compile warnings in llama_set_state_data()
* llama : fix name shadowing and C4146 (#1526)
  * Fix name shadowing and C4146
  * Fix if macros not using defined when required
  * Update llama-util.h (code style)
* Fix for mingw (#1462)
* llama : add llama_init_backend() API (close #1527)
* feature : add blis and other BLAS implementation support (#1502)
  * allow all BLA_VENDOR to be assigned in cmake arguments; align with whisper.cpp PR 927
  * fix version detection for BLA_SIZEOF_INTEGER, recover min version of cmake
  * fix typo in INTEGER
* Revert "feature : add blis and other BLAS implementation support (#1502)" (reverts commit 07e9ace0f9)
* GPU weights not in RAM, direct loading with cuFile
* llama : code style fixes + progress print fix
* ggml : ggml_mul better broadcast support
* cmake : workarounds for cufile when CMake version < 3.25
* gg rebase fixup
* Loop in llama.cpp, fixed progress callback
* Attempt clang-tidy fix
* llama : fix vram size computation
* Add forgotten fclose()

Co-authored-by: András Salamon <ott2@users.noreply.github.com>
Co-authored-by: Ilya Kurdyukov <59548320+ilyakurdyukov@users.noreply.github.com>
Co-authored-by: Tom Jobbins <784313+TheBloke@users.noreply.github.com>
Co-authored-by: rankaiyx <rankaiyx@rankaiyx.com>
Co-authored-by: Stephan Walter <stephan@walter.name>
Co-authored-by: DannyDaemonic <DannyDaemonic@gmail.com>
Co-authored-by: Erik Scholz <Green-Sky@users.noreply.github.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: David Kennedy <dakennedyd@gmail.com>
Co-authored-by: Jason McCartney <jmac@theroot.org>
Co-authored-by: Evan Jones <evan.q.jones@gmail.com>
Co-authored-by: Maxime <672982+maximegmd@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Zenix <zenixls2@gmail.com>
parent ea600071cb
commit affc76edfd
5 changed files with 304 additions and 116 deletions
ggml-cuda.cu (123 changed lines)
@@ -83,9 +83,19 @@ typedef struct {
 } block_q8_0;
 static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
 
+#define CUDA_MUL_BLOCK_SIZE 256
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
 #define CUDA_DMMV_BLOCK_SIZE 32 // dmmv = dequantize_mul_mat_vec
 
+static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= kx) {
+        return;
+    }
+    dst[i] = x[i] * y[i%ky];
+}
+
 static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
     const block_q4_0 * x = (const block_q4_0 *) vx;
 
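For context: the new mul_f32 kernel implements the broadcast purely through the flat index, computing dst[i] = x[i] * y[i % ky], so a y of ky elements is tiled across an x of kx elements. Below is a minimal self-contained sketch of that behavior; the kernel body is copied from the hunk above, while main, the buffer sizes, and the sample values are illustrative only.

// broadcast_mul_demo.cu -- illustrative only; kernel copied from the hunk above
#include <cstdio>
#include <cuda_runtime.h>

#define CUDA_MUL_BLOCK_SIZE 256

static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= kx) {
        return;
    }
    dst[i] = x[i] * y[i%ky]; // y is re-read every ky elements: the broadcast
}

int main() {
    const int kx = 8, ky = 4; // x has 8 elements, y is broadcast twice
    float hx[kx], hy[ky], hd[kx];
    for (int i = 0; i < kx; i++) hx[i] = 1.0f + i;
    for (int i = 0; i < ky; i++) hy[i] = 10.0f * (i + 1);

    float *dx, *dy, *dd;
    cudaMalloc(&dx, kx*sizeof(float));
    cudaMalloc(&dy, ky*sizeof(float));
    cudaMalloc(&dd, kx*sizeof(float));
    cudaMemcpy(dx, hx, kx*sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(dy, hy, ky*sizeof(float), cudaMemcpyHostToDevice);

    const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE; // round up
    mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE>>>(dx, dy, dd, kx, ky);

    cudaMemcpy(hd, dd, kx*sizeof(float), cudaMemcpyDeviceToHost);
    for (int i = 0; i < kx; i++) printf("dst[%d] = %.0f\n", i, hd[i]); // 10 40 90 160 50 120 210 320
    cudaFree(dx); cudaFree(dy); cudaFree(dd);
    return 0;
}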
@@ -228,6 +238,11 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
     }
 }
 
+static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
+    const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
+    mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
+}
+
 static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
     dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
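The launcher uses the standard round-up division (kx + BLOCK - 1) / BLOCK so that a partially filled final block is still launched; the i >= kx guard in mul_f32 masks the surplus threads. A quick worked check of the arithmetic (host-only, illustrative):

#include <cstdio>

// Same rounding as mul_f32_cuda above, with CUDA_MUL_BLOCK_SIZE = 256.
static int num_blocks(int kx) { return (kx + 256 - 1) / 256; }

int main() {
    printf("%d\n", num_blocks(4096)); // 16 -- every block fully used
    printf("%d\n", num_blocks(4097)); // 17 -- last block has 255 idle threads
    return 0;
}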
@@ -467,6 +482,67 @@ static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor
     }
 }
 
+static void ggml_cuda_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src1->backend == GGML_BACKEND_CUDA);
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+    const int64_t ne0 = ne00 * ne01 * ne02 * ne03;
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+    size_t x_size, d_size;
+
+    float * d_X = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &x_size); // src0
+    float * d_Y = (float *) src1->data; // src1 is already on device, broadcasted.
+    float * d_D = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &d_size); // dst
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            const int i0 = i03*ne02 + i02;
+            float * c_X2 = d_X + i0*ne01*ne00;
+            float * c_D2 = d_D + i0*ne01*ne00;
+
+            cudaStream_t cudaStream = g_cudaStreams[i0 % GGML_CUDA_MAX_STREAMS];
+            cudaStream_t cudaStream2 = g_cudaStreams2[i0 % GGML_CUDA_MAX_STREAMS];
+            cudaEvent_t cudaEvent = g_cudaEvents[i0 % GGML_CUDA_MAX_EVENTS];
+
+            // copy src0 to device
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X2, src0, i03, i02, cudaStream2));
+            CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
+
+            // wait for data
+            CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
+
+            for (int64_t i01 = 0; i01 < ne01; i01++) {
+                const int64_t i13 = i03%ne13;
+                const int64_t i12 = i02%ne12;
+                const int64_t i11 = i01%ne11;
+                const int i1 = i13*ne12*ne11 + i12*ne11 + i11;
+
+                float * c_X1 = c_X2 + i01*ne00;
+                float * c_Y = d_Y + i1*ne10;
+                float * c_D1 = c_D2 + i01*ne00;
+
+                // compute
+                mul_f32_cuda(c_X1, c_Y, c_D1, ne00, ne10, cudaStream);
+                CUDA_CHECK(cudaGetLastError());
+            }
+
+            // copy dst to host
+            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+            CUDA_CHECK(cudaMemcpyAsync(d, c_D2, sizeof(float)*ne00*ne01, cudaMemcpyDeviceToHost, cudaStream));
+        }
+    }
+    CUDA_CHECK(cudaDeviceSynchronize());
+    ggml_cuda_pool_free(d_X, x_size);
+    ggml_cuda_pool_free(d_D, d_size);
+}
+
 static void ggml_cuda_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
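The modulo indexing (i13 = i03%ne13, and so on) is what lets src1 be smaller than src0 in any dimension and still be reused cyclically, row by row. A plain host-side reference of the same index mapping, useful for checking the GPU result; the function, buffer sizes, and values here are illustrative and assume contiguous row-major tensors:

#include <cstdio>

// CPU reference of the broadcast mapping used by ggml_cuda_mul_f32 above.
static void mul_broadcast_ref(const float * x, const float * y, float * dst,
                              int ne00, int ne01, int ne02, int ne03,
                              int ne10, int ne11, int ne12, int ne13) {
    for (int i03 = 0; i03 < ne03; i03++) {
        for (int i02 = 0; i02 < ne02; i02++) {
            for (int i01 = 0; i01 < ne01; i01++) {
                // wrap each src1 index so that smaller dimensions repeat
                const int i13 = i03 % ne13;
                const int i12 = i02 % ne12;
                const int i11 = i01 % ne11;
                const float * xr = x   + ((i03*ne02 + i02)*ne01 + i01)*ne00; // src0 row
                const float * yr = y   + ((i13*ne12 + i12)*ne11 + i11)*ne10; // src1 row (reused)
                float       * dr = dst + ((i03*ne02 + i02)*ne01 + i01)*ne00;
                for (int i00 = 0; i00 < ne00; i00++) {
                    dr[i00] = xr[i00] * yr[i00 % ne10]; // same i%ky wrap as mul_f32
                }
            }
        }
    }
}

int main() {
    // src0: 3 rows of 4; src1: a single row of 4, broadcast over every row --
    // the shape of the norm-weight multiply this commit moves to the GPU.
    float x[12], w[4] = { 0.5f, 1.0f, 2.0f, 4.0f }, d[12];
    for (int i = 0; i < 12; i++) x[i] = (float)(i + 1);
    mul_broadcast_ref(x, w, d, 4, 3, 1, 1, 4, 1, 1, 1);
    for (int i = 0; i < 12; i++) printf("%g ", d[i]); // 0.5 2 6 16 2.5 6 14 32 4.5 10 22 48
    printf("\n");
    return 0;
}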
@@ -724,6 +800,11 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
     ggml_cuda_pool_free(d_Q, q_size);
 }
 
+void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_mul_f32(src0, src1, dst);
+}
+
 bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
     const int64_t ne10 = src1->ne[0];
 
@@ -797,14 +878,48 @@ void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
     const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type);
 
     size_t q_size;
-    char * d_Q = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);
+    char * dst = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);
 
     cudaStream_t cudaStream2 = g_cudaStreams2[0];
 
     // copy tensor to device
-    CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, tensor, 0, 0, cudaStream2));
-    CUDA_CHECK(cudaDeviceSynchronize());
+    for (int64_t i3 = 0; i3 < ne3; i3++) {
+        for (int64_t i2 = 0; i2 < ne2; i2++) {
+            int i = i3*ne2 + i2;
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(dst + i*ne0*ne1, tensor, i3, i2, cudaStream2));
+        }
+    }
 
-    tensor->data = d_Q;
+    tensor->data = dst;
     tensor->backend = GGML_BACKEND_CUDA;
 }
+
+void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const size_t offset) {
+    FILE * fp = fopen(fname, "rb");
+
+    const size_t size = ggml_nbytes(tensor);
+
+    void * buf;
+    CUDA_CHECK(cudaMalloc(&buf, size));
+    void * buf_host = malloc(size);
+
+#ifdef _WIN32
+    int ret = _fseeki64(fp, (__int64) offset, SEEK_SET);
+#else
+    int ret = fseek(fp, (long) offset, SEEK_SET);
+#endif
+    GGML_ASSERT(ret == 0); // same
+
+    size_t ret2 = fread(buf_host, size, 1, fp);
+    if (ret2 != 1) {
+        fprintf(stderr, "unexpectedly reached end of file");
+        exit(1);
+    }
+
+    cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
+    cudaDeviceSynchronize();
+
+    tensor->data = buf;
+    free(buf_host);
+    fclose(fp);
+}
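The ggml_cuda_load_data above is the portable path: seek, fread into a temporary host buffer, then a blocking cudaMemcpy to the device. The "direct loading with cuFile" named in the commit title refers to NVIDIA GPUDirect Storage, which can skip the host bounce buffer entirely. A rough sketch of what such a path could look like, assuming a GDS-capable driver and cufile.h; this is an illustration with error handling elided, not the code from this commit:

#include <cstring>
#include <fcntl.h>
#include <unistd.h>
#include <cuda_runtime.h>
#include <cufile.h>

// Hypothetical cuFile-based variant of ggml_cuda_load_data: read `size`
// bytes at `offset` from `fname` straight into freshly allocated VRAM.
static void * load_tensor_gds(const char * fname, size_t offset, size_t size) {
    void * buf;
    cudaMalloc(&buf, size);

    const int fd = open(fname, O_RDONLY | O_DIRECT); // cuFile wants O_DIRECT for the GDS path

    CUfileDescr_t descr;
    memset(&descr, 0, sizeof(descr));
    descr.handle.fd = fd;
    descr.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD;

    CUfileHandle_t fh;
    cuFileHandleRegister(&fh, &descr);

    // DMA from storage into device memory, no intermediate host buffer
    cuFileRead(fh, buf, size, (off_t) offset, 0);

    cuFileHandleDeregister(fh);
    close(fd);
    return buf;
}

On systems without GPUDirect Storage, cuFile can fall back to an internal bounce buffer, so the fread-based path above remains the portable default.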