ggml_is_permuted

parent 8d648a34d8 · commit 19c0bf5c86

4 changed files with 45 additions and 17 deletions
ggml-cuda.cu | 53
@@ -1287,10 +1287,25 @@ void ggml_cuda_host_free(void * ptr) {
     CUDA_CHECK(cudaFreeHost(ptr));
 }
 
-static cudaError_t ggml_cuda_h2d_tensor_2d(
+static cudaError_t ggml_cuda_tensor_2d(
     void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
 
-    char * dst_char = (char *) dst;
+    cudaMemcpyKind kind;
+    char * src_ptr;
+    if (src->backend == GGML_BACKEND_CPU) {
+        kind = cudaMemcpyHostToDevice;
+        src_ptr = (char *) src->data;
+    } else if (src->backend == GGML_BACKEND_GPU) {
+        kind = cudaMemcpyDeviceToDevice;
+        struct ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra;
+        int id;
+        CUDA_CHECK(cudaGetDevice(&id));
+        src_ptr = (char *) extra->data_device[id];
+    } else {
+        GGML_ASSERT(false);
+    }
+    char * dst_ptr = (char *) dst;
+
     const int64_t ne0 = src->ne[0];
     const int64_t nb0 = src->nb[0];
     const int64_t nb1 = src->nb[1];
@@ -1301,17 +1316,17 @@ static cudaError_t ggml_cuda_h2d_tensor_2d(
     const int64_t bs = ggml_blck_size(type);
     int64_t i1_diff = i1_high - i1_low;
 
-    const void * x = (const void *) ((const char *) src->data + i1_low*nb1 + i2*nb2 + i3*nb3);
+    const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
     if (nb0 == ts && nb1 == ts*ne0/bs) {
-        return cudaMemcpyAsync(dst_char, x, i1_diff*nb1, cudaMemcpyHostToDevice, stream);
+        return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, cudaMemcpyHostToDevice, stream);
     } else if (nb0 == ts) {
-        return cudaMemcpy2DAsync(dst_char, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, cudaMemcpyHostToDevice, stream);
+        return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, kind, stream);
     } else {
         for (int64_t i1 = 0; i1 < i1_diff; i1++) {
             const void * rx = (const void *) ((const char *) x + i1*nb1);
-            void * rd = (void *) (dst_char + i1*ts*ne0/bs);
+            void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
             // pretend the row is a matrix with cols=1
-            cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyHostToDevice, stream);
+            cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, kind, stream);
             if (r != cudaSuccess) return r;
         }
         return cudaSuccess;
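Aside, not part of the diff: the renamed helper keeps the same three copy strategies, now parameterized by the cudaMemcpyKind picked from the source backend. The standalone sketch below shows the two dense-row cases in isolation for plain float rows; copy_rows_2d and its parameter names are illustrative, not ggml API, and only the CUDA runtime is assumed. cudaMemcpy2DAsync copies `height` rows of `width` bytes, reading with source pitch `spitch` and writing with destination pitch `dpitch`, which is how rows whose stride nb1 exceeds their payload get packed tightly on the destination.

#include <cuda_runtime.h>
#include <stddef.h>

// Sketch only: pack nrows rows of ne0 floats, whose source rows are
// nb1 bytes apart, into a tightly packed destination buffer.
static cudaError_t copy_rows_2d(
        float * dst, const float * src,
        size_t ne0,      // elements per row
        size_t nb1,      // byte stride between consecutive source rows
        size_t nrows,
        cudaMemcpyKind kind, cudaStream_t stream) {
    const size_t row_bytes = ne0 * sizeof(float);
    if (nb1 == row_bytes) {
        // rows are already adjacent: one flat copy suffices
        return cudaMemcpyAsync(dst, src, nrows*row_bytes, kind, stream);
    }
    // rows are padded: a single pitched copy packs them tightly at dst
    return cudaMemcpy2DAsync(dst, row_bytes, src, nb1, row_bytes, nrows, kind, stream);
}

The third case in the helper, where even elements within a row are strided (nb0 != ts), reuses the same pitched copy per row by treating each row as a matrix with cols=1, as the comment in the diff notes.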
@@ -1656,8 +1671,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
     struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
 
     const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
+    const bool src0_is_contiguous = ggml_is_contiguous(src0);
     const bool src0_is_f32 = src0->type == GGML_TYPE_F32;
 
+    const bool src1_is_contiguous = use_src1 && ggml_is_contiguous(src1);
+
     const bool split = src0->backend == GGML_BACKEND_GPU_SPLIT;
 
     const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
@@ -1700,7 +1718,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
 
         cudaSetDevice(id);
 
-        if (src0_on_device) {
+        if (src0_on_device && src0_is_contiguous) {
             if (src0_is_f32) {
                 src0_ddf[id] = (float *) src0_extra->data_device[id];
             } else {
@@ -1719,7 +1737,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
         }
 
         if (use_src1) {
-            if (src1_on_device) {
+            if (src1_on_device && src1_is_contiguous) {
                 src1_ddf[id] = (float *) src1_extra->data_device[id];
             } else {
                 src1_ddf[id] = (float *) ggml_cuda_pool_malloc(num_iters*src1_stride * sizeof(float), &src1_asf[id]);
@@ -1797,24 +1815,26 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                 // copy src0, src1 to device if necessary
                 if (use_src1) {
                     if (src1->backend == GGML_BACKEND_CPU) {
-                        CUDA_CHECK(ggml_cuda_h2d_tensor_2d(src1_ddf_i, src1, i03, i02, 0, ne11, cudaStream_memcpy_src1));
-                    } else if (src1->backend == GGML_BACKEND_GPU) {
+                        CUDA_CHECK(ggml_cuda_tensor_2d(src1_ddf_i, src1, i03, i02, 0, ne11, cudaStream_memcpy_src1));
+                    } else if (src1->backend == GGML_BACKEND_GPU && src1_is_contiguous) {
                         if (id != g_main_device) {
                             float * src1_ddf_i_source = (float *) src1_extra->data_device[g_main_device];
                             src1_ddf_i_source += i11*src1_stride;
                             CUDA_CHECK(cudaMemcpyAsync(src1_ddf_i, src1_ddf_i_source, src1_stride*sizeof(float),
                                                        cudaMemcpyDeviceToDevice, cudaStream_memcpy_src1));
                         }
+                    } else if (src1_on_device && !src1_is_contiguous) {
+                        CUDA_CHECK(ggml_cuda_tensor_2d(src1_ddf_i, src1, i03, i02, i01_low, i01_high, cudaStream_main));
                     } else {
                         GGML_ASSERT(false);
                     }
                 }
                 CUDA_CHECK(cudaEventRecord(cudaEvent_memcpy_src1, cudaStream_memcpy_src1));
-                if (!src0_on_device) {
+                if (!src0_on_device || !src0_is_contiguous) {
                     if (src0_is_f32) {
-                        CUDA_CHECK(ggml_cuda_h2d_tensor_2d(src0_ddf_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
+                        CUDA_CHECK(ggml_cuda_tensor_2d(src0_ddf_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
                     } else {
-                        CUDA_CHECK(ggml_cuda_h2d_tensor_2d(src0_ddq_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
+                        CUDA_CHECK(ggml_cuda_tensor_2d(src0_ddq_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
                     }
                 }
 
@@ -2027,13 +2047,12 @@ void ggml_cuda_mul_mat_p021_f16_f32(const ggml_tensor * src0, const ggml_tensor
 }
 
 void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    if (!ggml_is_contiguous(src0) && !ggml_is_contiguous(src1)) {
+    if (ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
         ggml_cuda_mul_mat_p021_f16_f32(src0, src1, dst);
-        // ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true);
     } else if (src0->type == GGML_TYPE_F32) {
         ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true);
     } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
-        if (src1->ne[1] == 1) {
+        if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src0->ne[1] % GGML_CUDA_DMMV_Y == 0) {
             ggml_cuda_op(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
         } else {
             ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true);
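Aside, not part of the diff: the extra guard on the dequantize-mul-mat-vec path exists because that kernel processes src0 in fixed-size blocks, so the row length and row count must divide evenly by the block sizes. The sketch below restates the guard; can_use_dmmv is an illustrative name, and the values 32 and 1 are assumed build defaults (both are configurable at compile time).

#include <stdbool.h>
#include <stdint.h>

// Assumed build defaults; both are overridable compile-time options.
#ifndef GGML_CUDA_DMMV_X
#define GGML_CUDA_DMMV_X 32
#endif
#ifndef GGML_CUDA_DMMV_Y
#define GGML_CUDA_DMMV_Y 1
#endif

// Illustrative restatement of the condition in ggml_cuda_mul_mat.
static bool can_use_dmmv(int64_t src0_ne0, int64_t src0_ne1, int64_t src1_ne1) {
    return src1_ne1 == 1                        // single-column src1 (one token)
        && src0_ne0 % GGML_CUDA_DMMV_X == 0     // row length divides the block width
        && src0_ne1 % GGML_CUDA_DMMV_Y == 0;    // row count divides the block height
}

With these assumed defaults a 4096 x 4096 LLaMA-7B weight matrix satisfies both divisibility conditions, so single-token decoding keeps the fast vector path; otherwise the dispatch falls back to cuBLAS.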
ggml.c | 6
@@ -3917,6 +3917,12 @@ bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
 
+bool ggml_is_permuted(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
+}
+
 static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
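Aside, not part of the diff: ggml_is_permuted reports whether the byte strides are out of ascending order, which is exactly what ggml_permute (and ggml_transpose) produce, since they rearrange nb[] without moving any data. A minimal usage sketch under that assumption:

#include "ggml.h"
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // contiguous 4x8 tensor: nb[0] <= nb[1] <= nb[2] <= nb[3]
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 8);
    assert(!ggml_is_permuted(a));

    // swapping the first two axes rearranges the strides, so nb[0] > nb[1]
    struct ggml_tensor * at = ggml_permute(ctx, a, 1, 0, 2, 3);
    assert(ggml_is_permuted(at));

    ggml_free(ctx);
    return 0;
}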
ggml.h | 1
@@ -478,6 +478,7 @@ extern "C" {
 
     GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor);
     GGML_API bool ggml_is_contiguous(const struct ggml_tensor * tensor);
+    GGML_API bool ggml_is_permuted (const struct ggml_tensor * tensor);
 
     // use this to compute the memory overhead of a tensor
     GGML_API size_t ggml_tensor_overhead(void);
llama.cpp | 2
@@ -895,6 +895,7 @@ static bool kv_cache_init(
 
 #ifdef GGML_USE_CUBLAS
     ggml_cuda_assign_buffers_no_scratch(cache.k);
+    // ggml_cuda_assign_buffers_no_scratch(cache.v);
 #endif // GGML_USE_CUBLAS
 
     return true;
@@ -1452,6 +1453,7 @@ static bool llama_eval_internal(
                     n_ctx*ggml_element_size(kv_self.v),
                     n_ctx*ggml_element_size(kv_self.v)*n_embd/n_head,
                     il*n_ctx*ggml_element_size(kv_self.v)*n_embd);
+        // offload_func(V);
         ggml_set_name(V, "V");
 
 #if 1