Merge branch 'master' into concedo_experimental

# Conflicts:
#	README.md

commit b85ea580d3

11 changed files with 166 additions and 50 deletions
@@ -167,11 +167,6 @@ if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES
    if (MSVC)
        # TODO: arm msvc?
    else()
        if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
            # Apple M1, M2, etc.
            # Raspberry Pi 3, 4, Zero 2 (64-bit)
            add_compile_options(-mcpu=native)
        endif()
        if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv6")
            # Raspberry Pi 1, Zero
            add_compile_options(-mfpu=neon-fp-armv8 -mfp16-format=ieee -mno-unaligned-access)

37 convert.py
@@ -136,7 +136,7 @@ def find_n_mult(n_ff: int, n_embd: int) -> int:
        calc_ff = (((8*n_embd) // 3 + n_mult - 1) // n_mult)*n_mult
        if calc_ff == n_ff:
            return n_mult
    return 1
    raise Exception(f"failed to find n_mult for (n_ff={n_ff}, n_embd={n_embd}).")

@dataclass
class Params:
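
The calc_ff line above rounds 8*n_embd/3 up to the nearest multiple of the candidate n_mult and compares it with the stored feed-forward size; with this change, failing to find a match now raises instead of silently returning 1. A small standalone check of that arithmetic, with n_embd and n_mult values picked purely for illustration:

#include <cstdio>

// Worked example of the calc_ff formula: round 8*n_embd/3 up to the nearest
// multiple of n_mult. The values below are illustrative, not taken from the diff.
int main() {
    const int n_embd = 4096;
    const int n_mult = 256;
    const int calc_ff = (((8 * n_embd) / 3 + n_mult - 1) / n_mult) * n_mult;
    std::printf("calc_ff = %d\n", calc_ff);  // 8*4096/3 = 10922 -> rounded up to 11008
    return 0;
}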
@@ -321,6 +321,10 @@ class Tensor(metaclass=ABCMeta):
    @abstractmethod
    def permute(self, n_head: int) -> 'Tensor': ...
    @abstractmethod
    def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor': ...
    @abstractmethod
    def part(self, n_part: int) -> 'UnquantizedTensor': ...
    @abstractmethod
    def to_ggml(self) -> 'GGMLCompatibleTensor': ...
@@ -345,6 +349,14 @@ class UnquantizedTensor(Tensor):
    def to_ggml(self) -> 'UnquantizedTensor':
        return self

    def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor':
        r = self.ndarray.shape[0] // 3
        return UnquantizedTensor(permute(self.ndarray[r * n_part : r * n_part + r, ...], n_head))

    def part(self, n_part: int) -> 'UnquantizedTensor':
        r = self.ndarray.shape[0] // 3
        return UnquantizedTensor(self.ndarray[r * n_part : r * n_part + r, ...])

    def permute(self, n_head: int) -> 'UnquantizedTensor':
        return UnquantizedTensor(permute(self.ndarray, n_head))
@@ -642,6 +654,19 @@ def permute_lazy(lazy_tensor: LazyTensor, n_head: int) -> LazyTensor:
        return lazy_tensor.load().permute(n_head)
    return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}) ' + lazy_tensor.description)

def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int) -> LazyTensor:
    def load() -> Tensor:
        return lazy_tensor.load().permute_part(n_part, n_head)
    s = lazy_tensor.shape.copy()
    s[0] = s[0] // 3
    return LazyTensor(load, s, lazy_tensor.data_type, f'permute({n_head}) ' + lazy_tensor.description)

def part_lazy(lazy_tensor: LazyTensor, n_part: int) -> LazyTensor:
    def load() -> Tensor:
        return lazy_tensor.load().part(n_part)
    s = lazy_tensor.shape.copy()
    s[0] = s[0] // 3
    return LazyTensor(load, s, lazy_tensor.data_type, 'part ' + lazy_tensor.description)

def convert_transformers_to_orig(model: LazyModel, params: Params) -> LazyModel:
    out: LazyModel = {}
@@ -650,11 +675,17 @@ def convert_transformers_to_orig(model: LazyModel, params: Params) -> LazyModel:
        out["output.weight"] = model["lm_head.weight"]

    for i in itertools.count():
        if f"model.layers.{i}.self_attn.q_proj.weight" not in model:
            break
        if f"model.layers.{i}.self_attn.q_proj.weight" in model:
            out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head)
            out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head)
            out[f"layers.{i}.attention.wv.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
        elif f"model.layers.{i}.self_attn.W_pack.weight" in model:
            out[f"layers.{i}.attention.wq.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head)
            out[f"layers.{i}.attention.wk.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 1, params.n_head)
            out[f"layers.{i}.attention.wv.weight"] = part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 2)
        else:
            break

        out[f"layers.{i}.attention.wo.weight"] = model[f"model.layers.{i}.self_attn.o_proj.weight"]

        out[f"layers.{i}.feed_forward.w1.weight"] = model[f"model.layers.{i}.mlp.gate_proj.weight"]
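
The new W_pack branch handles checkpoints that store the Q, K and V projections stacked into a single weight matrix; permute_part_lazy and part_lazy carve out one third of the rows per projection without loading the tensor eagerly. A minimal C++ sketch of that row-slicing idea, with a made-up matrix shape and values (this is an illustration, not the converter's code):

#include <cstdio>
#include <vector>

// A packed "W_pack"-style matrix stores Q, K and V stacked along the row
// dimension, so part n_part (0, 1 or 2) is the row range
// [r * n_part, r * n_part + r) with r = rows / 3.
int main() {
    const int rows = 6, cols = 4;              // toy packed matrix: 3 parts of 2 rows
    std::vector<float> w_pack(rows * cols);
    for (int i = 0; i < rows * cols; ++i) {
        w_pack[i] = (float) i;
    }

    const int r = rows / 3;                    // rows per part
    for (int n_part = 0; n_part < 3; ++n_part) {   // 0 -> Q, 1 -> K, 2 -> V
        std::printf("part %d -> rows [%d, %d)\n", n_part, r * n_part, r * n_part + r);
        for (int i = r * n_part; i < r * n_part + r; ++i) {
            for (int j = 0; j < cols; ++j) {
                std::printf("%6.1f", w_pack[i * cols + j]);
            }
            std::printf("\n");
        }
    }
    return 0;
}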
@@ -210,9 +210,12 @@ llama_token sampling_id(struct MyModel* mymodel) {
const char * sampling(struct MyModel * mymodel) {
    llama_context * ctx = mymodel->ctx;
    int id = sampling_id(mymodel);
    std::string ret;
    if (id == llama_token_eos()) ret = "</s>";
    else ret = llama_token_to_str(ctx, id);
    static std::string ret;
    if (id == llama_token_eos()) {
        ret = "</s>";
    } else {
        ret = llama_token_to_str(ctx, id);
    }
    eval_id(mymodel, id);
    return ret.c_str();
}
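
The hunk above makes ret a function-local static because sampling() returns ret.c_str(); with a plain local, that pointer dangles the moment the function returns. A minimal sketch of the difference, independent of the llama API (the function names here are invented for the example):

#include <cstdio>
#include <string>

// Returning c_str() of a non-static local is a dangling pointer: the string is
// destroyed when the function returns. Kept only to illustrate the bug.
const char * token_dangling() {
    std::string ret = "</s>";
    return ret.c_str();        // invalid once ret goes out of scope
}

// With a function-local static the buffer outlives the call, so the returned
// pointer stays valid until the next call overwrites it.
const char * token_static() {
    static std::string ret;
    ret = "</s>";
    return ret.c_str();
}

int main() {
    std::printf("%s\n", token_static());
    return 0;
}

The trade-off of the static fix is that the function is no longer reentrant or thread-safe: every call shares one buffer.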
@@ -5,7 +5,6 @@
#include "llama.h"
#include "build-info.h"

extern "C" {

typedef struct MyModel {
@@ -14,14 +13,13 @@ typedef struct MyModel {
    int n_past = 0;
} MyModel;

struct MyModel* create_mymodel(int argc, char ** argv);

bool eval_float(void* model, float* input, int N);
bool eval_tokens(void* model, std::vector<llama_token> tokens);
bool eval_id(struct MyModel* mymodel, int id);
bool eval_string(struct MyModel* mymodel, const char* str);
const char* sampling(struct MyModel* mymodel);
const char * sampling(struct MyModel* mymodel);
llama_token sampling_id(struct MyModel* mymodel);
void free_mymodel(struct MyModel* mymodel);
@@ -2671,7 +2671,8 @@ struct train_params {
    const char * fn_checkpoint_out;
    const char * fn_model_out;

    int seed;
    uint32_t seed;

    int n_ctx;
    int n_embd;
    int n_mult;

63 ggml-cuda.cu
@@ -214,6 +214,11 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
#endif

struct ggml_tensor_extra_gpu {
    void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
    cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
};

static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -1995,7 +2000,6 @@ inline void ggml_cuda_op_add(
    } else {
        GGML_ASSERT(false);
    }
    CUDA_CHECK(cudaGetLastError());

    (void) src1;
    (void) dst;
@@ -2027,7 +2031,6 @@ inline void ggml_cuda_op_mul(

        // compute
        mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main);
        CUDA_CHECK(cudaGetLastError());
    }

    (void) dst;
@@ -2048,7 +2051,6 @@ inline void ggml_cuda_op_silu(

    // compute
    silu_f32_cuda(src0_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);
    CUDA_CHECK(cudaGetLastError());

    (void) src1;
    (void) dst;
@@ -2071,7 +2073,6 @@ inline void ggml_cuda_op_rms_norm(

    // compute
    rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
    CUDA_CHECK(cudaGetLastError());

    (void) src1;
    (void) dst;
@@ -2150,7 +2151,6 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
            GGML_ASSERT(false);
            break;
    }
    CUDA_CHECK(cudaGetLastError());

#ifdef GGML_CUDA_DMMV_F16
    if (src1_convert_f16) {
@@ -2230,7 +2230,6 @@ inline void ggml_cuda_op_rope(

    // compute
    rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
    CUDA_CHECK(cudaGetLastError());

    (void) dst;
    (void) src0_ddq_i;
@@ -2254,7 +2253,6 @@ inline void ggml_cuda_op_diag_mask_inf(

    // compute
    diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main);
    CUDA_CHECK(cudaGetLastError());

    (void) dst;
    (void) src0_ddq_i;
@@ -2276,7 +2274,6 @@ inline void ggml_cuda_op_soft_max(

    // compute
    soft_max_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
    CUDA_CHECK(cudaGetLastError());

    (void) src1;
    (void) dst;
@@ -2372,10 +2369,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
    size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
    size_t dst_asf[GGML_CUDA_MAX_DEVICES] = {0};

    // if multiple GPUs are used they need to wait for the main GPU to finish
    // if multiple devices are used they need to wait for the main device
    // here an event is recorded that signifies that the main device has finished calculating the input data
    if (split && g_device_count > 1) {
        CUDA_CHECK(cudaSetDevice(g_main_device));
        CUDA_CHECK(cudaDeviceSynchronize());
        CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device], g_cudaStreams_main[g_main_device]));
    }

    for (int id = 0; id < g_device_count; ++id) {
@@ -2401,6 +2399,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
        int64_t row_diff = row_high - row_low;

        cudaSetDevice(id);
        cudaStream_t cudaStream_main = g_cudaStreams_main[id];

        // wait for main GPU data if necessary
        if (split && id != g_main_device) {
            CUDA_CHECK(cudaStreamWaitEvent(cudaStream_main, src0_extra->events[g_main_device]));
        }

        if (src0_on_device && src0_is_contiguous) {
            if (src0_is_f32) {
@@ -2476,8 +2480,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
            }
            const int64_t i11 = i13*ne12 + i12;

            cudaStream_t cudaStream_main = g_cudaStreams_main[id];

            // for split tensors the data begins at i0 == i0_offset_low
            char * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
            float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride;
@@ -2537,6 +2539,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm

            // do the computation
            op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i11, cudaStream_main);
            CUDA_CHECK(cudaGetLastError());

            // copy dst to host or other device if necessary
            if (!dst_on_device) {
@@ -2566,6 +2569,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, dst_stride*sizeof(float), kind, cudaStream_main));
            }
        }

            // signify to main device that other device is done
            if (split && g_device_count > 1 && id != g_main_device) {
                CUDA_CHECK(cudaEventRecord(src0_extra->events[id], cudaStream_main));
            }
        }
    }
}
@@ -2577,7 +2585,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
        }

        CUDA_CHECK(cudaSetDevice(id));
        CUDA_CHECK(cudaDeviceSynchronize());

        if (src0_asq[id] > 0) {
            ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]);
@@ -2592,6 +2599,21 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
            ggml_cuda_pool_free(dst_ddf[id], dst_asf[id]);
        }
    }

    // main device waits for all other devices to be finished
    if (split && g_device_count > 1) {
        CUDA_CHECK(cudaSetDevice(g_main_device));
        for (int id = 0; id < g_device_count; ++id) {
            if (id != g_main_device) {
                CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams_main[g_main_device], src0_extra->events[id]));
            }
        }
    }

    if (dst->backend == GGML_BACKEND_CPU) {
        CUDA_CHECK(cudaSetDevice(g_main_device));
        CUDA_CHECK(cudaDeviceSynchronize());
    }
}

void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -2831,6 +2853,10 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
            cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);

        extra->data_device[id] = buf;

        if (backend == GGML_BACKEND_GPU_SPLIT) {
            CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id], cudaEventDisableTiming));
        }
    }

    tensor->extra = extra;
@@ -2844,14 +2870,17 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
    ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;

    for (int id = 0; id < g_device_count; ++id) {
        if (extra->data_device[id] == nullptr) {
            continue;
        }

        if (extra->data_device[id] != nullptr) {
            CUDA_CHECK(cudaSetDevice(id));
            CUDA_CHECK(cudaFree(extra->data_device[id]));
        }

        if (extra->events[id] != nullptr) {
            CUDA_CHECK(cudaSetDevice(id));
            CUDA_CHECK(cudaEventDestroy(extra->events[id]));
        }
    }

    delete extra;
}
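
Taken together, the ggml-cuda.cu hunks above add per-device cudaEvent_t handles for split tensors: the main device records an event once the shared input is ready, the other devices make their streams wait on it before computing, and each of them records its own event that the main device waits on before the result is consumed; the events are created in ggml_cuda_transform_tensor and destroyed in ggml_cuda_free_data. A stripped-down sketch of that record/wait pattern, assuming two visible devices and omitting the actual tensor work (not the real ggml_cuda_op):

#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

#define CHECK(call) do { cudaError_t err_ = (call); \
    if (err_ != cudaSuccess) { \
        std::fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err_)); \
        std::exit(1); \
    } } while (0)

int main() {
    int n_devices = 0;
    CHECK(cudaGetDeviceCount(&n_devices));
    if (n_devices < 2) return 0;           // the pattern only matters with more than one GPU

    const int main_dev = 0;
    cudaStream_t streams[2];
    cudaEvent_t  events[2];
    for (int id = 0; id < 2; ++id) {
        CHECK(cudaSetDevice(id));
        CHECK(cudaStreamCreate(&streams[id]));
        CHECK(cudaEventCreateWithFlags(&events[id], cudaEventDisableTiming));
    }

    // main device: (input preparation would be enqueued here), then record
    CHECK(cudaSetDevice(main_dev));
    CHECK(cudaEventRecord(events[main_dev], streams[main_dev]));

    // worker device: wait for the main device's input before its own work
    CHECK(cudaSetDevice(1));
    CHECK(cudaStreamWaitEvent(streams[1], events[main_dev], 0));
    // (per-device computation would be enqueued here), then signal completion
    CHECK(cudaEventRecord(events[1], streams[1]));

    // main device: wait for every worker before consuming the split result
    CHECK(cudaSetDevice(main_dev));
    CHECK(cudaStreamWaitEvent(streams[main_dev], events[1], 0));
    CHECK(cudaStreamSynchronize(streams[main_dev]));

    for (int id = 0; id < 2; ++id) {
        CHECK(cudaSetDevice(id));
        CHECK(cudaEventDestroy(events[id]));
        CHECK(cudaStreamDestroy(streams[id]));
    }
    return 0;
}

The cudaEventDisableTiming flag matches the diff: timing data is not needed for pure synchronization, and disabling it reduces event overhead.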
@@ -8,10 +8,6 @@ extern "C" {

#define GGML_CUDA_MAX_DEVICES 16

struct ggml_tensor_extra_gpu {
    void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
};

void ggml_init_cublas(void);
void ggml_cuda_set_tensor_split(const float * tensor_split);
@@ -202,7 +202,9 @@ struct ggml_metal_context * ggml_metal_init(void) {

void ggml_metal_free(struct ggml_metal_context * ctx) {
    fprintf(stderr, "%s: deallocating\n", __func__);

    for (int i = 0; i < ctx->n_buffers; ++i) {
        [ctx->buffers[i].metal release];
    }
    free(ctx);
}

47 ggml.c
@@ -3846,6 +3846,40 @@ static_assert(GGML_OP_COUNT == 64, "GGML_OP_COUNT != 64");
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");

// WARN:
// Misconfiguration can lead to problems that are hard to reason about:
// * At best it crashes or produces nonsense.
// * At worst the output is only slightly wrong and hard to notice.
//
// An op has to enable INIT or FINALIZE when any of its branches needs that pass.
// Take care with compile options (e.g., GGML_USE_xxx).
static bool GGML_OP_HAS_INIT    [GGML_OP_COUNT] = { 0 };
static bool GGML_OP_HAS_FINALIZE[GGML_OP_COUNT] = { 0 };
static void ggml_setup_op_has_task_pass(void) {
    {   // INIT
        bool * I = GGML_OP_HAS_INIT;

        I[GGML_OP_ACC               ] = true;
        I[GGML_OP_MUL_MAT           ] = true;
        I[GGML_OP_OUT_PROD          ] = true;
        I[GGML_OP_SET               ] = true;
        I[GGML_OP_GET_ROWS_BACK     ] = true;
        I[GGML_OP_DIAG_MASK_INF     ] = true;
        I[GGML_OP_DIAG_MASK_ZERO    ] = true;
        I[GGML_OP_CONV_1D_S1_PH     ] = true;
        I[GGML_OP_CONV_1D_S2_PH     ] = true;
        I[GGML_OP_CONV_2D_SK_P0     ] = true;
        I[GGML_OP_FLASH_ATTN_BACK   ] = true;
        I[GGML_OP_CROSS_ENTROPY_LOSS] = true;
    }

    {   // FINALIZE
        bool * F = GGML_OP_HAS_FINALIZE;

        F[GGML_OP_CROSS_ENTROPY_LOSS] = true;
    }
}

//
// ggml context
//
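
The warning above is about keeping these tables in sync with the ops that actually implement the extra passes: the compute threads (see the ggml_graph_compute_thread hunks below) now skip INIT and FINALIZE entirely unless the op's flag is set, instead of running all three passes for every op. A self-contained sketch of that gating pattern with invented op and task names (this is not the real ggml scheduler):

#include <cstdio>

// Toy version of the pass gating: each op declares whether it needs an extra
// INIT and/or FINALIZE pass; the scheduler runs only the flagged passes and
// always runs COMPUTE.
enum toy_op   { TOY_OP_ADD, TOY_OP_MUL_MAT, TOY_OP_COUNT };
enum toy_task { TOY_TASK_INIT, TOY_TASK_COMPUTE, TOY_TASK_FINALIZE };

static bool TOY_OP_HAS_INIT    [TOY_OP_COUNT] = { false };
static bool TOY_OP_HAS_FINALIZE[TOY_OP_COUNT] = { false };

static void toy_setup_op_has_task_pass(void) {
    TOY_OP_HAS_INIT[TOY_OP_MUL_MAT] = true;   // e.g. needs to zero a scratch accumulator
}

static void toy_forward(toy_op op, toy_task task) {
    std::printf("op %d: running task %d\n", (int) op, (int) task);
}

static void toy_run(toy_op op) {
    if (TOY_OP_HAS_INIT[op])     toy_forward(op, TOY_TASK_INIT);
    toy_forward(op, TOY_TASK_COMPUTE);
    if (TOY_OP_HAS_FINALIZE[op]) toy_forward(op, TOY_TASK_FINALIZE);
}

int main() {
    toy_setup_op_has_task_pass();
    toy_run(TOY_OP_ADD);       // COMPUTE only
    toy_run(TOY_OP_MUL_MAT);   // INIT + COMPUTE
    return 0;
}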
@@ -4267,6 +4301,8 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
        ggml_cl_init();
#endif

        ggml_setup_op_has_task_pass();

        is_first_call = false;
    }
@@ -16805,10 +16841,12 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
        if (node_n != -1) {
            /* FINALIZE */
            struct ggml_tensor * node = state->shared->cgraph->nodes[node_n];
            if (GGML_OP_HAS_FINALIZE[node->op]) {
                params.nth = node->n_tasks;
                ggml_compute_forward(&params, node);
                ggml_graph_compute_perf_stats_node(node, state->shared);
            }
        }

        // distribute new work or execute it direct if 1T
        while (++node_n < cgraph->n_nodes) {
@@ -16819,10 +16857,13 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
            state->shared->perf_node_start_cycles = ggml_perf_cycles();
            state->shared->perf_node_start_time_us = ggml_perf_time_us();

            /* INIT */
            params.type = GGML_TASK_INIT;
            params.nth = node->n_tasks;

            /* INIT */
            if (GGML_OP_HAS_INIT[node->op]) {
                params.type = GGML_TASK_INIT;
                ggml_compute_forward(&params, node);
            }

            if (node->n_tasks == 1) {
                // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1,
@@ -16830,9 +16871,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
                params.type = GGML_TASK_COMPUTE;
                ggml_compute_forward(&params, node);

                if (GGML_OP_HAS_FINALIZE[node->op]) {
                    params.type = GGML_TASK_FINALIZE;
                    ggml_compute_forward(&params, node);
                    ggml_graph_compute_perf_stats_node(node, state->shared);
                }
            } else {
                break;
            }

3 ggml.h
@@ -450,6 +450,9 @@ extern "C" {

    // compute types

    // NOTE: the INIT or FINALIZE pass is not scheduled unless explicitly enabled.
    // This behavior was changed since https://github.com/ggerganov/llama.cpp/pull/1995.
    enum ggml_task_type {
        GGML_TASK_INIT = 0,
        GGML_TASK_COMPUTE,

19 llama.cpp
@@ -283,7 +283,13 @@ struct llama_model {

struct llama_context {
    llama_context(const llama_model & model, const llama_vocab & vocab) : model(model), vocab(vocab), t_load_us(model.t_load_us), t_start_us(model.t_start_us) {}

#ifdef GGML_USE_METAL
    ~llama_context() {
        if (ctx_metal) {
            ggml_metal_free(ctx_metal);
        }
    }
#endif
    std::mt19937 rng;

    bool has_evaluated_once = false;
@@ -3252,7 +3258,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
    return nread;
}

bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
static bool llama_load_session_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
    llama_file file(path_session, "rb");

    // sanity checks
@@ -3306,6 +3312,15 @@ bool llama_load_session_file(struct llama_context * ctx, const char * path_sessi
    return true;
}

bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
    try {
        return llama_load_session_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
    } catch (const std::exception & err) {
        fprintf(stderr, "error loading session file: %s\n", err.what());
        return false;
    }
}

bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
    llama_file file(path_session, "wb");
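
The final hunk reintroduces llama_load_session_file as a thin wrapper around the new _internal function, so exceptions thrown while reading the file are caught at the API boundary and converted into a false return plus a logged message. A generic sketch of that internal-throws/public-catches pattern, with hypothetical names (not the llama.cpp code itself):

#include <cstdio>
#include <stdexcept>
#include <string>

// The internal function is free to throw on I/O or parse errors.
static bool load_file_internal(const std::string & path) {
    std::FILE * f = std::fopen(path.c_str(), "rb");
    if (f == nullptr) {
        throw std::runtime_error("failed to open " + path);
    }
    // ... parse the contents here ...
    std::fclose(f);
    return true;
}

// The exported entry point turns any exception into a boolean failure, so
// nothing propagates across the C-style API boundary.
extern "C" bool load_file(const char * path) {
    try {
        return load_file_internal(path);
    } catch (const std::exception & err) {
        std::fprintf(stderr, "error loading file: %s\n", err.what());
        return false;
    }
}

int main() {
    if (!load_file("does-not-exist.bin")) {
        std::printf("load failed, but no exception crossed the C boundary\n");
    }
    return 0;
}

This keeps the bool-returning C API unchanged for callers while still allowing the implementation to use exceptions internally.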