cuBLAS: non-contiguous tensor support (#1215)
* Cuda: non-contiguous tensor support * remove extra stuff * rename * fix error * more fixes, now OpenBLAS and CLBlast build too * now then?
This commit is contained in:
parent
36d19a603b
commit
b1ee8f59b4
3 changed files with 44 additions and 11 deletions
28
ggml-cuda.cu
28
ggml-cuda.cu
|
@ -302,3 +302,31 @@ void ggml_init_cublas(void) {
|
|||
// CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
|
||||
}
|
||||
}
|
||||
|
||||
cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream) {
|
||||
const uint64_t ne0 = src->ne[0];
|
||||
const uint64_t ne1 = src->ne[1];
|
||||
const uint64_t nb0 = src->nb[0];
|
||||
const uint64_t nb1 = src->nb[1];
|
||||
const uint64_t nb2 = src->nb[2];
|
||||
const uint64_t nb3 = src->nb[3];
|
||||
const enum ggml_type type = src->type;
|
||||
const size_t ts = ggml_type_size(type);
|
||||
const size_t bs = ggml_blck_size(type);
|
||||
|
||||
const void * x = (const void *) ((const char *) src->data + i2*nb2 + i3*nb3);
|
||||
if (nb0 == ts && nb1 == ts*ne0/bs) {
|
||||
return cudaMemcpyAsync(dst, x, ne1*nb1, cudaMemcpyHostToDevice, stream);
|
||||
} else if (nb0 == ts) {
|
||||
return cudaMemcpy2DAsync(dst, ts*ne0/bs, x, nb1, ts*ne0/bs, ne1, cudaMemcpyHostToDevice, stream);
|
||||
} else {
|
||||
for (uint64_t i1 = 0; i1 < ne1; i1++) {
|
||||
const void * rx = (const void *) ((const char *) x + i1*nb1);
|
||||
void * rd = (void *) ((char *) dst + i1*ts*ne0/bs);
|
||||
// pretend the row is a matrix with cols=1
|
||||
cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyHostToDevice, stream);
|
||||
if (r != cudaSuccess) return r;
|
||||
}
|
||||
return cudaSuccess;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue