possibly slower, but cannot use larger batches without modifying ggml library.
commit ca9a11697c
parent bfeb3471d7
2 changed files with 53 additions and 11 deletions
@@ -432,10 +432,10 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
 {
     rwkv_ctx_v3 = rwkv_init_from_file(modelname.c_str(), n_threads);
 
-    // if(inputs.gpulayers>0)
-    // {
-    //     rwkv_gpu_offload_layers(rwkv_ctx_v3,inputs.gpulayers);
-    // }
+    if(inputs.gpulayers>0)
+    {
+        rwkv_gpu_offload_layers(rwkv_ctx_v3,inputs.gpulayers);
+    }
 
     const struct rwkv_file_header & header = rwkv_ctx_v3->instance->model.header;
     const size_t n_vocab = header.n_vocab;
@@ -1066,15 +1066,15 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
 }
 else
 {
-    // if(embd.size()>1)
-    // {
-    //     evalres = rwkv_eval_sequence(rwkv_ctx_v3, (uint32_t*)embd.data(), embd.size(), rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
-    // }
-    // else
-    // {
+    if(embd.size()>1)
+    {
+        evalres = rwkv_eval_sequence(rwkv_ctx_v3, (uint32_t*)embd.data(), embd.size(), rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, rwkv_ctx_v3->logits_out);
+    }
+    else
+    {
         bool ignoreLogits = (!startedsampling && ((int)embd_inp.size() > input_consumed + 2));
         evalres = rwkv_eval(rwkv_ctx_v3, embd[0], rwkv_ctx_v3->state_in, rwkv_ctx_v3->state_out, ignoreLogits?nullptr:rwkv_ctx_v3->logits_out);
-    //}
+    }
 
     memcpy(logits.data(), rwkv_ctx_v3->logits_out, sizeof(float) * rwkv_vocab.size());
     rwkv_ctx_v3->state_in = rwkv_ctx_v3->state_out;
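
Note: prompt ingestion now goes through rwkv_eval_sequence, which evaluates the whole pending batch in one graph, while generation still steps one token at a time through rwkv_eval, passing nullptr for logits when they would be discarded anyway. A minimal sketch of that dispatch with placeholder names (eval_tokens and need_logits are illustrative; the real buffers live on the koboldcpp context):

    #include "rwkv_v3.h"
    #include <cstdint>
    #include <vector>

    bool eval_tokens(struct rwkv_context * ctx, const std::vector<uint32_t> & tokens, bool need_logits) {
        bool ok;
        if (tokens.size() > 1) {
            // One graph evaluation for the whole batch: faster prompt processing,
            // but the batch size is bounded without modifying ggml (see commit message).
            ok = rwkv_eval_sequence(ctx, tokens.data(), tokens.size(),
                                    ctx->state_in, ctx->state_out, ctx->logits_out);
        } else {
            // Single token; skip writing logits when they are not needed.
            ok = rwkv_eval(ctx, tokens[0], ctx->state_in, ctx->state_out,
                           need_logits ? ctx->logits_out : nullptr);
        }
        ctx->state_in = ctx->state_out; // carry the recurrent state forward
        return ok;
    }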
@@ -6,6 +6,13 @@
 #include "rwkv_v3.h"
 #include "ggml.h"
 
+#ifdef GGML_USE_CUBLAS
+#include "ggml-cuda.h"
+#endif
+#if defined(GGML_USE_CLBLAST)
+#include "ggml-opencl.h"
+#endif
+
 #include <string>
 #include <vector>
 #include <cstring>
@@ -1058,7 +1065,11 @@ struct rwkv_future_tensor rwkv_future_graph_work(struct rwkv_future_ctx & ctx,
     const size_t n_threads,
     const size_t sequence_len = 1
 ) {
+#if defined(GGML_USE_CLBLAST) || defined(GGML_USE_CUBLAS)
+    enum ggml_type mul_mat_type = type == GGML_TYPE_F32 ? GGML_TYPE_F32 : GGML_TYPE_F16;
+#else
     enum ggml_type mul_mat_type = ggml_is_quantized(type) ? GGML_TYPE_Q8_1 : type;
+#endif
     return ctx.alloc(GGML_TYPE_I8, rwkv_future_tensor::size(mul_mat_type, ffn_key_height, sequence_len) * n_threads + 64 * (n_threads - 1));
 }
 
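
Note: the branch above mirrors how ggml's mat mul uses its per-thread work buffer. With cuBLAS/CLBlast, quantized weights are dequantized into an F16 scratch (F32 stays F32), while the CPU path quantizes activations to Q8_1, so the buffer must be sized for whichever element type the build actually uses. A rough, self-contained sizing sketch under that assumption (graph_work_bytes and row_bytes are illustrative names, not part of the library; rwkv_future_tensor::size is this file's own helper):

    #include "ggml.h"
    #include <cstddef>

    // Bytes for one row of 'width' elements in type 't'.
    static size_t row_bytes(enum ggml_type t, size_t width) {
        return ggml_type_size(t) * (width / ggml_blck_size(t));
    }

    size_t graph_work_bytes(enum ggml_type type, size_t ffn_key_height,
                            size_t sequence_len, size_t n_threads) {
    #if defined(GGML_USE_CLBLAST) || defined(GGML_USE_CUBLAS)
        // BLAS path: quantized weights are dequantized to F16; F32 stays F32.
        enum ggml_type mul_mat_type = type == GGML_TYPE_F32 ? GGML_TYPE_F32 : GGML_TYPE_F16;
    #else
        // CPU path: activations are quantized to Q8_1 against quantized weights.
        enum ggml_type mul_mat_type = ggml_is_quantized(type) ? GGML_TYPE_Q8_1 : type;
    #endif
        return row_bytes(mul_mat_type, ffn_key_height) * sequence_len * n_threads
               + 64 * (n_threads - 1); // 64-byte alignment gap between threads
    }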
@@ -1545,7 +1556,38 @@ struct rwkv_context * rwkv_clone_context(struct rwkv_context * ctx, const uint32
 }
 
 bool rwkv_gpu_offload_layers(struct rwkv_context * ctx, const uint32_t n_layers) {
+#if defined(GGML_USE_CLBLAST) || defined(GGML_USE_CUBLAS)
+    printf("\nOffloading %u (or fewer) layers...",n_layers);
+    const auto offload = [&](struct ggml_tensor * tensor) {
+        // TODO support multi-GPU
+        tensor->backend = GGML_BACKEND_GPU;
+#if defined(GGML_USE_CLBLAST)
+        ggml_cl_transform_tensor(tensor->data, tensor);
+#else
+        ggml_cuda_transform_tensor(tensor->data, tensor);
+#endif
+    };
+
+    const size_t n_gpu = std::min(n_layers, ctx->instance->model.header.n_layer);
+
+    if (ctx->gpu_layers < n_gpu) {
+        for (size_t & i = ctx->gpu_layers; i < n_gpu; i++) {
+            const struct rwkv_layer & layer = ctx->instance->model.layers[i];
+
+            // TODO also offload other operations to GPU with ggml_cuda_assign_buffers
+            offload(layer.att_key);
+            offload(layer.att_value);
+            offload(layer.att_receptance);
+            offload(layer.att_output);
+
+            offload(layer.ffn_key);
+            offload(layer.ffn_value);
+            offload(layer.ffn_receptance);
+        }
+
+        return true;
+    }
+#endif
     return false;
 }
 
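
Note: in the loop above, size_t & i = ctx->gpu_layers binds the loop counter to the context's offload count by reference, so offloading is incremental: a repeated call transforms only layers not already on the GPU, and returns false once there is nothing new to offload (n_gpu is also clamped to the model's layer count). A usage sketch with placeholder path and counts, assuming a model with at least 24 layers:

    struct rwkv_context * ctx = rwkv_init_from_file("model.bin", 8);
    rwkv_gpu_offload_layers(ctx, 10); // transforms layers 0..9, returns true
    rwkv_gpu_offload_layers(ctx, 24); // transforms only layers 10..23
    rwkv_gpu_offload_layers(ctx, 24); // nothing left below 24: returns false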