h2d tensors during loadup

This commit is contained in:
niansa 2023-06-23 17:10:45 +02:00
parent 98e588c6eb
commit 46f577bfc1
2 changed files with 86 additions and 12 deletions


@@ -378,6 +378,19 @@ void ggml_vk_scale(kp::Sequence& seq,
     seq.record<kp::OpAlgoDispatch>(mgr.algorithm<float, PushConstants>({in, out}, spirv, {size}, {}, {pushConsts}));
 }
 
+void ggml_vk_xxlu(const std::vector<uint32_t>& spirv, kp::Sequence& seq,
+                  const std::shared_ptr<kp::Tensor>& in, uint32_t inOff,
+                  const std::shared_ptr<kp::Tensor>& out, uint32_t outOff,
+                  uint32_t size) {
+    struct PushConstants {
+        uint32_t inOff, outOff;
+    } pushConsts {
+        inOff, outOff
+    };
+    seq.record<kp::OpAlgoDispatch>(mgr.algorithm<float, PushConstants>({in, out}, spirv, {size}, {}, {pushConsts}));
+}
+
 static const std::string program_silu =
     MULTILINE_QUOTE(
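
For context on the Kompute call recorded by ggml_vk_xxlu: in mgr.algorithm<float, PushConstants>(...), the template arguments name the specialization-constant and push-constant element types, and the three trailing brace lists are the workgroup, the (empty) specialization constants, and the push constants. A minimal annotated sketch of that call, reusing the names from the hunk above (illustration only, not code from this commit):

    // Annotated equivalent of the dispatch recorded inside ggml_vk_xxlu.
    // `mgr`, `seq`, `in`, `out`, `spirv`, `size` and `pushConsts` are the names used above.
    auto algo = mgr.algorithm<float, PushConstants>(
        {in, out},     // tensors bound to the shader's buffer bindings 0 (in_) and 1 (out_)
        spirv,         // compiled compute shader
        {size},        // dispatch size: one shader invocation per element (local_size_x = 1)
        {},            // no specialization constants
        {pushConsts}); // inOff/outOff, visible to the shader as the pcs push-constant block
    seq.record<kp::OpAlgoDispatch>(algo);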
@@ -398,19 +411,64 @@ void main() {
 }
 );
 
-void ggml_vk_silu(kp::Sequence& seq,
-                  const std::shared_ptr<kp::Tensor>& in, uint32_t inOff,
-                  const std::shared_ptr<kp::Tensor>& out, uint32_t outOff,
-                  uint32_t size) {
+template <typename... Args>
+void ggml_vk_silu(Args&&... args) {
     const static auto spirv = compileSource(program_source_head+program_silu);
-    struct PushConstants {
-        uint32_t inOff, outOff;
-    } pushConsts {
-        inOff, outOff
-    };
-    seq.record<kp::OpAlgoDispatch>(mgr.algorithm<float, PushConstants>({in, out}, spirv, {size}, {}, {pushConsts}));
+    ggml_vk_xxlu(spirv, std::forward<Args>(args)...);
+}
+
+
+static const std::string program_relu =
+    MULTILINE_QUOTE(
+layout(push_constant) uniform PushConstants {
+    uint inOff;
+    uint outOff;
+} pcs;
+
+layout(local_size_x = 1) in;
+layout(binding = 0) buffer tensorInA { float in_[]; };
+layout(binding = 1) buffer tensorOut { float out_[]; };
+
+void main() {
+    const uint i = gl_GlobalInvocationID.x;
+
+    out_[pcs.outOff+i] = max(0.0, in_[pcs.inOff+i]);
+}
+);
+
+template <typename... Args>
+void ggml_vk_relu(Args&&... args) {
+    const static auto spirv = compileSource(program_source_head+program_relu);
+    ggml_vk_xxlu(spirv, std::forward<Args>(args)...);
+}
+
+
+static const std::string program_gelu =
+    MULTILINE_QUOTE(
+layout(push_constant) uniform PushConstants {
+    uint inOff;
+    uint outOff;
+} pcs;
+
+layout(local_size_x = 1) in;
+layout(binding = 0) buffer tensorInA { float in_[]; };
+layout(binding = 1) buffer tensorOut { float out_[]; };
+
+void main() {
+    const uint i = gl_GlobalInvocationID.x;
+    const float x = in_[pcs.inOff+i];
+
+    out_[pcs.outOff+i] = 0.5*x*(1.0 + tanh(SQRT_2_OVER_PI*x*(1.0 + GELU_COEF_A*x*x)));
+}
+);
+
+template <typename... Args>
+void ggml_vk_gelu(Args&&... args) {
+    const static auto spirv = compileSource(program_source_head+program_gelu);
+    ggml_vk_xxlu(spirv, std::forward<Args>(args)...);
 }
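
For reference, the expression in program_gelu is the standard tanh approximation of GELU. A small host-side C++ sketch of the same formula follows; the constant values are assumptions taken from ggml's CPU implementation, while the shader itself expects program_source_head to define GELU_COEF_A and SQRT_2_OVER_PI:

    #include <cmath>

    // Reference implementation of the shader's GELU formula (illustration only).
    // Constant values are assumed from ggml's CPU code; the GPU shader obtains
    // them from program_source_head instead.
    static const float GELU_COEF_A    = 0.044715f;
    static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;

    static float gelu_ref(float x) {
        return 0.5f*x*(1.0f + std::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
    }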
@@ -515,6 +573,14 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
                 {
                     ggml_vk_silu(seq, id_src0, offs_src0, id_dst, offs_dst, ggml_nelements(dst));
                 } break;
+            case GGML_OP_RELU:
+                {
+                    ggml_vk_relu(seq, id_src0, offs_src0, id_dst, offs_dst, ggml_nelements(dst));
+                } break;
+            case GGML_OP_GELU:
+                {
+                    ggml_vk_gelu(seq, id_src0, offs_src0, id_dst, offs_dst, ggml_nelements(dst));
+                } break;
         }
     }
 });
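
Because ggml_vk_relu and ggml_vk_gelu are thin variadic forwarders, each new case arm boils down to a call to the shared helper with the matching cached SPIR-V. As a sketch, the GGML_OP_RELU arm is equivalent to the call below, where spirv_relu is a hypothetical name standing in for the static SPIR-V cached inside ggml_vk_relu:

    // What the GGML_OP_RELU dispatch expands to after perfect forwarding (illustration only).
    ggml_vk_xxlu(spirv_relu,              // stands in for the static SPIR-V cached in ggml_vk_relu
                 seq,
                 id_src0, offs_src0,      // input tensor and offset
                 id_dst,  offs_dst,       // output tensor and offset
                 ggml_nelements(dst));    // one invocation per element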


@@ -753,7 +753,7 @@ struct llama_model_loader {
         }
     }
 
-    void load_all_data(llama_progress_callback progress_callback, void * progress_callback_user_data, llama_mlock * lmlock) {
+    void load_all_data(llama_context & lctx, llama_progress_callback progress_callback, void * progress_callback_user_data, llama_mlock * lmlock) {
         size_t data_size = 0;
         size_t prefetch_size = 0;
         size_t lock_size = 0;
@@ -810,6 +810,14 @@ struct llama_model_loader {
                     free(lt.data);
                 }
                 break;
+#elif defined(GGML_USE_KOMPUTE)
+            case GGML_BACKEND_GPU:
+                lt.ggml_tensor->data = lt.data;
+                ggml_vk_h2d_tensor(lctx.ctx_kompute, lt.ggml_tensor);
+                if (!use_mmap) {
+                    free(lt.data);
+                }
+                break;
 #endif
             default:
                 continue;
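
ggml_vk_h2d_tensor is the backend's host-to-device upload for a ggml tensor; its implementation is not shown in this diff. With stock Kompute, the underlying staging pattern looks roughly like the sketch below, a minimal illustration rather than the actual ggml_vk_h2d_tensor (kp::Manager, kp::OpTensorSyncDevice and kp::Sequence are Kompute's public API; the wrapper function and its name are made up here):

    #include <kompute/Kompute.hpp>
    #include <vector>

    // Minimal host-to-device upload with Kompute (illustration only): wrap the
    // host buffer in a kp::Tensor, then record and run an OpTensorSyncDevice so
    // the data is copied into device memory.
    void upload_to_device_example(kp::Manager & mgr, const std::vector<float> & host_data) {
        auto tensor = mgr.tensorT<float>(host_data);    // tensor initialised from host data
        mgr.sequence()
           ->record<kp::OpTensorSyncDevice>({tensor})   // host -> device copy
           ->eval();                                    // submit and wait for completion
    }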
@@ -1315,7 +1323,7 @@ static void llama_model_load_internal(
     }
 #endif
 
-    ml->load_all_data(progress_callback, progress_callback_user_data, use_mlock ? &lctx.model.mlock_mmap : NULL);
+    ml->load_all_data(lctx, progress_callback, progress_callback_user_data, use_mlock ? &lctx.model.mlock_mmap : NULL);
 
     if (progress_callback) {
         progress_callback(1.0f, progress_callback_user_data);