mamba : recurrent inference almost works, but incoherent
This commit is contained in:
parent
5a69a262a1
commit
f680364bd8
3 changed files with 128 additions and 41 deletions
|
@ -1851,13 +1851,57 @@ class MambaModel(Model):
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
d_model = self.hparams["d_model"]
|
d_model = self.hparams["d_model"]
|
||||||
self.gguf_writer.add_name(self.dir_model.name)
|
self.gguf_writer.add_name(self.dir_model.name)
|
||||||
|
self.gguf_writer.add_context_length(128) # arbitrary value; it shouldn't be important for Mamba
|
||||||
self.gguf_writer.add_embedding_length(d_model)
|
self.gguf_writer.add_embedding_length(d_model)
|
||||||
self.gguf_writer.add_block_count(self.hparams["n_layer"])
|
self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
|
||||||
self.gguf_writer.add_head_count(2 * d_model) # d_inner
|
self.gguf_writer.add_head_count(2 * d_model) # d_inner
|
||||||
|
self.gguf_writer.add_block_count(self.hparams["n_layer"])
|
||||||
|
self.gguf_writer.add_layer_norm_rms_eps(1e-5)
|
||||||
self.gguf_writer.add_key_length(4) # d_conv
|
self.gguf_writer.add_key_length(4) # d_conv
|
||||||
self.gguf_writer.add_value_length(16) # d_state
|
self.gguf_writer.add_value_length(16) # d_state
|
||||||
self.gguf_writer.add_file_type(self.ftype)
|
self.gguf_writer.add_file_type(self.ftype)
|
||||||
|
|
||||||
|
def write_tensors(self):
|
||||||
|
block_count = self.hparams["n_layer"]
|
||||||
|
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
|
||||||
|
for name, data_torch in self.get_tensors():
|
||||||
|
old_dtype = data_torch.dtype
|
||||||
|
|
||||||
|
# convert any unsupported data types to float32
|
||||||
|
if data_torch.dtype not in (torch.float16, torch.float32):
|
||||||
|
data_torch = data_torch.to(torch.float32)
|
||||||
|
|
||||||
|
# map tensor names
|
||||||
|
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
|
||||||
|
if new_name is None:
|
||||||
|
print(f"Can not map tensor {name!r}")
|
||||||
|
sys.exit()
|
||||||
|
|
||||||
|
if name.endswith(".A_log"):
|
||||||
|
print("A_log --> A ==> " + new_name)
|
||||||
|
data_torch = -torch.exp(data_torch)
|
||||||
|
|
||||||
|
data = data_torch.squeeze().numpy()
|
||||||
|
|
||||||
|
n_dims = len(data.shape)
|
||||||
|
data_dtype = data.dtype
|
||||||
|
|
||||||
|
# if f32 desired, convert any float16 to float32
|
||||||
|
if self.ftype == 0 and data_dtype == np.float16:
|
||||||
|
data = data.astype(np.float32)
|
||||||
|
|
||||||
|
# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
|
||||||
|
if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
|
||||||
|
data = data.astype(np.float32)
|
||||||
|
|
||||||
|
# if f16 desired, convert big float32 2-dim weight tensors to float16
|
||||||
|
if self.ftype == 1 and data_dtype == np.float32 and new_name.removesuffix(".weight").endswith((".ssm_in", ".ssm_out", "token_embd", "output")) and n_dims == 2:
|
||||||
|
data = data.astype(np.float16)
|
||||||
|
|
||||||
|
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
|
||||||
|
|
||||||
|
self.gguf_writer.add_tensor(new_name, data)
|
||||||
|
|
||||||
|
|
||||||
###### CONVERSION LOGIC ######
|
###### CONVERSION LOGIC ######
|
||||||
|
|
||||||
|
|
4
ggml.c
4
ggml.c
|
@ -5331,7 +5331,7 @@ struct ggml_tensor * ggml_soft_max_back_inplace(
|
||||||
|
|
||||||
// ggml_soft_plus
|
// ggml_soft_plus
|
||||||
|
|
||||||
struct ggml_tensor * ggml_soft_plus_impl(
|
static struct ggml_tensor * ggml_soft_plus_impl(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
bool inplace) {
|
bool inplace) {
|
||||||
|
@ -15737,7 +15737,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||||
case GGML_OP_SOFT_PLUS:
|
case GGML_OP_SOFT_PLUS:
|
||||||
{
|
{
|
||||||
ggml_compute_forward_soft_plus(params, tensor->src[0], tensor);
|
ggml_compute_forward_soft_plus(params, tensor->src[0], tensor);
|
||||||
}
|
} break;
|
||||||
case GGML_OP_ROPE:
|
case GGML_OP_ROPE:
|
||||||
{
|
{
|
||||||
ggml_compute_forward_rope(params, tensor);
|
ggml_compute_forward_rope(params, tensor);
|
||||||
|
|
119
llama.cpp
119
llama.cpp
|
@ -1765,7 +1765,6 @@ struct llama_layer {
|
||||||
struct ggml_tensor * ffn_up_b; // b3
|
struct ggml_tensor * ffn_up_b; // b3
|
||||||
struct ggml_tensor * ffn_act;
|
struct ggml_tensor * ffn_act;
|
||||||
|
|
||||||
|
|
||||||
// mamba proj
|
// mamba proj
|
||||||
struct ggml_tensor * ssm_in;
|
struct ggml_tensor * ssm_in;
|
||||||
struct ggml_tensor * ssm_x;
|
struct ggml_tensor * ssm_x;
|
||||||
|
@ -3435,6 +3434,7 @@ static void llm_load_hparams(
|
||||||
} break;
|
} break;
|
||||||
case LLM_ARCH_MAMBA:
|
case LLM_ARCH_MAMBA:
|
||||||
{
|
{
|
||||||
|
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||||
switch (hparams.n_layer) {
|
switch (hparams.n_layer) {
|
||||||
case 24:
|
case 24:
|
||||||
switch (hparams.n_embd) {
|
switch (hparams.n_embd) {
|
||||||
|
@ -3455,7 +3455,7 @@ static void llm_load_hparams(
|
||||||
} break;
|
} break;
|
||||||
default: model.type = e_model::MODEL_UNKNOWN;
|
default: model.type = e_model::MODEL_UNKNOWN;
|
||||||
}
|
}
|
||||||
}
|
} break;
|
||||||
default: (void)0;
|
default: (void)0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3939,7 +3939,10 @@ static bool llm_load_tensors(
|
||||||
const int64_t n_vocab_type = hparams.n_vocab_type;
|
const int64_t n_vocab_type = hparams.n_vocab_type;
|
||||||
const int64_t n_ff = hparams.n_ff;
|
const int64_t n_ff = hparams.n_ff;
|
||||||
|
|
||||||
GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
|
// Mamba uses these in its own way
|
||||||
|
if (model.arch != LLM_ARCH_MAMBA) {
|
||||||
|
GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
|
||||||
|
}
|
||||||
|
|
||||||
ggml_context * ctx_input = ctx_map.at(model.buft_input.buft);
|
ggml_context * ctx_input = ctx_map.at(model.buft_input.buft);
|
||||||
ggml_context * ctx_output = ctx_map.at(model.buft_output.buft);
|
ggml_context * ctx_output = ctx_map.at(model.buft_output.buft);
|
||||||
|
@ -4678,19 +4681,21 @@ static bool llm_load_tensors(
|
||||||
} break;
|
} break;
|
||||||
case LLM_ARCH_MAMBA:
|
case LLM_ARCH_MAMBA:
|
||||||
{
|
{
|
||||||
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
|
||||||
|
|
||||||
const int64_t d_conv = hparams.n_embd_head_k;
|
const int64_t d_conv = hparams.n_embd_head_k;
|
||||||
const int64_t d_state = hparams.n_embd_head_v;
|
const int64_t d_state = hparams.n_embd_head_v;
|
||||||
const int64_t d_inner = hparams.n_head;
|
const int64_t d_inner = hparams.n_head;
|
||||||
// FIXME: ceiling instead of floor
|
// FIXME: ceiling instead of floor
|
||||||
const int64_t dt_rank = n_embd / 16;
|
const int64_t dt_rank = n_embd / 16;
|
||||||
GGML_ASSERT(2 * n_embd == d_inner);
|
GGML_ASSERT(2 * n_embd == d_inner);
|
||||||
|
// round up the vocab size to the next multiple of 8
|
||||||
|
const int64_t rounded_vocab = (n_vocab + 7) & -8;
|
||||||
|
|
||||||
|
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, rounded_vocab});
|
||||||
|
|
||||||
// output
|
// output
|
||||||
{
|
{
|
||||||
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
|
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
|
||||||
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
|
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, rounded_vocab});
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < n_layer; ++i) {
|
for (int i = 0; i < n_layer; ++i) {
|
||||||
|
@ -4707,7 +4712,7 @@ static bool llm_load_tensors(
|
||||||
|
|
||||||
layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner});
|
layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner});
|
||||||
|
|
||||||
layer.ssm_conv1d = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, 1, d_inner});
|
layer.ssm_conv1d = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner});
|
||||||
layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner});
|
layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner});
|
||||||
|
|
||||||
layer.ssm_x = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state});
|
layer.ssm_x = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state});
|
||||||
|
@ -4715,9 +4720,9 @@ static bool llm_load_tensors(
|
||||||
layer.ssm_dt = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner});
|
layer.ssm_dt = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner});
|
||||||
layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner});
|
layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner});
|
||||||
|
|
||||||
// FIXME: maybe no suffix for these
|
// no "weight" suffix for these
|
||||||
layer.ssm_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, "weight", i), {d_state, d_inner});
|
layer.ssm_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner});
|
||||||
layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, "weight", i), {d_inner});
|
layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, i), {d_inner});
|
||||||
|
|
||||||
// out_proj
|
// out_proj
|
||||||
layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd});
|
layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd});
|
||||||
|
@ -7907,16 +7912,18 @@ struct llm_build_context {
|
||||||
return gf;
|
return gf;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_cgraph * build_mamba(bool use_conv) {
|
struct ggml_cgraph * build_mamba() {
|
||||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
|
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
|
||||||
|
|
||||||
|
const bool use_conv = batch.n_tokens > 1;
|
||||||
GGML_ASSERT(use_conv == false); // TODO: implement
|
GGML_ASSERT(use_conv == false); // TODO: implement
|
||||||
|
|
||||||
const int64_t d_model = hparams.n_embd;
|
// hopefully the compiler does constant folding
|
||||||
const int64_t d_inner = hparams.n_head;
|
const int64_t d_model = n_embd;
|
||||||
|
const int64_t d_inner = n_head;
|
||||||
GGML_ASSERT(2 * d_model == d_inner);
|
GGML_ASSERT(2 * d_model == d_inner);
|
||||||
const int64_t d_conv = hparams.n_embd_head_k;
|
const int64_t d_conv = n_embd_head_k;
|
||||||
const int64_t d_state = hparams.n_embd_head_v;
|
const int64_t d_state = n_embd_head_v;
|
||||||
const int64_t dt_rank = d_model / 16;
|
const int64_t dt_rank = d_model / 16;
|
||||||
|
|
||||||
struct ggml_tensor * cur;
|
struct ggml_tensor * cur;
|
||||||
|
@ -7928,8 +7935,10 @@ struct llm_build_context {
|
||||||
|
|
||||||
for (int il = 0; il < n_layer; ++il) {
|
for (int il = 0; il < n_layer; ++il) {
|
||||||
// (ab)using the kv cache to store the state
|
// (ab)using the kv cache to store the state
|
||||||
ggml_tensor * conv_state = kv_self.k_l[il]; // {d_conv, d_inner}
|
// NOTE: the conv_state is transposed to ease shifting it.
|
||||||
ggml_tensor * ssm_state = kv_self.v_l[il]; // {d_state, d_inner}
|
// if you figured out a way to shift it without transposing it like this, go ahead and fix this.
|
||||||
|
ggml_tensor * conv_state = kv_self.k_l[il]; // {d_inner, d_conv}
|
||||||
|
ggml_tensor * ssm_state = ggml_reshape_2d(ctx0, kv_self.v_l[il], d_state, d_inner);
|
||||||
|
|
||||||
// norm
|
// norm
|
||||||
cur = llm_build_norm(ctx0, inpL, hparams,
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||||
|
@ -7941,36 +7950,73 @@ struct llm_build_context {
|
||||||
// {n_embd, batch} * {n_embd, 2*d_inner} = {batch, 2*d_inner}
|
// {n_embd, batch} * {n_embd, 2*d_inner} = {batch, 2*d_inner}
|
||||||
struct ggml_tensor * xz = ggml_mul_mat(ctx0, cur, model.layers[il].ssm_in);
|
struct ggml_tensor * xz = ggml_mul_mat(ctx0, cur, model.layers[il].ssm_in);
|
||||||
// split the above in two
|
// split the above in two
|
||||||
|
// assuming it's contiguous
|
||||||
|
// FIXME: handle batches of more than 1 token
|
||||||
struct ggml_tensor * x = ggml_view_1d(ctx0, xz, d_inner, 0);
|
struct ggml_tensor * x = ggml_view_1d(ctx0, xz, d_inner, 0);
|
||||||
struct ggml_tensor * z = ggml_view_1d(ctx0, xz, d_inner, d_inner);
|
struct ggml_tensor * z = ggml_view_1d(ctx0, xz, d_inner, ggml_element_size(xz)*d_inner);
|
||||||
|
|
||||||
|
cur = x;
|
||||||
|
|
||||||
// FIXME: figure out when to transpose
|
|
||||||
// conv
|
// conv
|
||||||
{
|
{
|
||||||
// TODO: figure out how to do a row-wise dot product
|
// shift conv state left
|
||||||
// TODO: use the kv-cache to store the state
|
conv_state = ggml_set_1d(ctx0, conv_state, ggml_view_1d(ctx0, conv_state, (d_conv - 1)*d_inner, ggml_element_size(conv_state)*d_inner), 0);
|
||||||
kv_self.k_l[il];
|
|
||||||
|
|
||||||
// FIXME: this is wrong
|
// update last column
|
||||||
cur = ggml_conv_1d(ctx0, cur, model.layers[il].ssm_conv1d, 1, d_conv - 1, 1);
|
conv_state = ggml_set_1d(ctx0, conv_state, x, ggml_element_size(conv_state)*(d_conv - 1)*d_inner);
|
||||||
|
|
||||||
cur = ggml_add(ctx0, cur, model.layers[il].ssm_conv1d_b);
|
ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state, ggml_view_tensor(ctx0, kv_self.k_l[il])));
|
||||||
|
|
||||||
// TODO: there's some SiLU in there (but no ffn? or is the conv an ffn?)
|
// rearrange and sum
|
||||||
cur = ggml_silu(ctx0, cur);
|
conv_state = ggml_reshape_2d(ctx0, conv_state, d_inner, d_conv);
|
||||||
|
// TODO: find a way to directly shift a 2d conv_state, avoiding the need to transpose here.
|
||||||
|
conv_state = ggml_cont(ctx0, ggml_transpose(ctx0, conv_state));
|
||||||
|
|
||||||
|
// --> {1, d_inner}
|
||||||
|
x = ggml_sum_rows(ctx0, ggml_mul(ctx0, conv_state, model.layers[il].ssm_conv1d));
|
||||||
|
x = ggml_transpose(ctx0, x);
|
||||||
|
|
||||||
|
// bias
|
||||||
|
x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b);
|
||||||
|
|
||||||
|
x = ggml_silu(ctx0, x);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ssm
|
// ssm
|
||||||
{
|
{
|
||||||
|
// {2*n_embd, batch} * {2*n_embd, dt_rank + 2*d_state} = {batch, dt_rank + 2*d_state}
|
||||||
|
struct ggml_tensor * x_db = ggml_mul_mat(ctx0, x, model.layers[il].ssm_x);
|
||||||
|
// FIXME: handle batches of more than 1 token
|
||||||
|
struct ggml_tensor * dt = ggml_view_1d(ctx0, x_db, dt_rank, 0);
|
||||||
|
struct ggml_tensor * B = ggml_view_1d(ctx0, x_db, d_state, ggml_element_size(x_db)*dt_rank);
|
||||||
|
struct ggml_tensor * C = ggml_view_1d(ctx0, x_db, d_state, ggml_element_size(x_db)*(dt_rank+d_state));
|
||||||
|
|
||||||
// TODO: use ggml_soft_plus here
|
// {dt_rank} * {dt_rank, d_inner} = {1, d_inner}
|
||||||
|
dt = ggml_mul_mat(ctx0, dt, model.layers[il].ssm_dt);
|
||||||
|
dt = ggml_add(ctx0, dt, ggml_transpose(ctx0, model.layers[il].ssm_dt_b));
|
||||||
|
dt = ggml_soft_plus(ctx0, dt);
|
||||||
|
|
||||||
}
|
// => {d_state, d_inner}
|
||||||
|
struct ggml_tensor * dA = ggml_exp(ctx0, ggml_mul(ctx0, model.layers[il].ssm_a, dt));
|
||||||
|
|
||||||
// TODO: there's some SiLU again towards the end. Can the `llm_build_ffn` helper be used?
|
// => {d_state, d_inner}
|
||||||
// Maybe the best way is to implement it, _then_ check if that helper would do the same thing.
|
struct ggml_tensor * dB = ggml_out_prod(ctx0, B, ggml_transpose(ctx0, dt));
|
||||||
// discretize
|
|
||||||
{
|
// => {d_state, d_inner}
|
||||||
|
cur = ggml_mul(ctx0, dB, ggml_transpose(ctx0, x));
|
||||||
|
|
||||||
|
ssm_state = ggml_add(ctx0, ggml_mul(ctx0, ssm_state, dA), cur);
|
||||||
|
|
||||||
|
ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_state, ggml_view_tensor(ctx0, kv_self.v_l[il])));
|
||||||
|
|
||||||
|
// row-wise dot product ("dn,n->d")
|
||||||
|
// {d_state, d_inner} * {d_state} => {d_inner, 1}
|
||||||
|
struct ggml_tensor * y = ggml_mul_mat(ctx0, ssm_state, C);
|
||||||
|
y = ggml_add(ctx0, y, ggml_mul(ctx0, model.layers[il].ssm_d, x));
|
||||||
|
y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));
|
||||||
|
|
||||||
|
// {d_inner, n_embd} * {d_inner, 1} = {n_embd, 1}
|
||||||
|
cur = ggml_mul_mat(ctx0, model.layers[il].ssm_out, y);
|
||||||
}
|
}
|
||||||
|
|
||||||
// residual
|
// residual
|
||||||
|
@ -7981,11 +8027,8 @@ struct llm_build_context {
|
||||||
inpL = cur;
|
inpL = cur;
|
||||||
}
|
}
|
||||||
|
|
||||||
// the last step of each layer already makes these equivalent
|
|
||||||
// cur = inpL;
|
|
||||||
|
|
||||||
// final rmsnorm
|
// final rmsnorm
|
||||||
cur = llm_build_norm(ctx0, cur, hparams,
|
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||||
model.output_norm, NULL,
|
model.output_norm, NULL,
|
||||||
LLM_NORM_RMS, cb, -1);
|
LLM_NORM_RMS, cb, -1);
|
||||||
cb(cur, "result_norm", -1);
|
cb(cur, "result_norm", -1);
|
||||||
|
@ -8150,7 +8193,7 @@ static struct ggml_cgraph * llama_build_graph(
|
||||||
} break;
|
} break;
|
||||||
case LLM_ARCH_MAMBA:
|
case LLM_ARCH_MAMBA:
|
||||||
{
|
{
|
||||||
result = llm.build_mamba(/* use_conv =*/ batch.n_tokens > 1);
|
result = llm.build_mamba();
|
||||||
} break;
|
} break;
|
||||||
default:
|
default:
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue