ggml : avoid multiply by D in GGML_OP_SSM_SCAN
This makes the weight buft detection in src/llama.cpp simpler.

* convert : transpose Mamba-2 A, D and reshape SSM_NORM

This breaks existing conversions of Mamba-2 models to avoid some reshapes.

Not sure if it's a good idea, but it makes the graph slightly cleaner.

* llama : more appropriate SSM_SCAN and SSM_CONV buft support checks
parent 7d16e1bc8c
commit 3bc7103d2e

7 changed files with 98 additions and 95 deletions
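For context, the scan the commit title refers to is along these lines. A minimal NumPy sketch of the usual Mamba recurrence (illustrative only, not the ggml kernel or its memory layout), written with the D skip-connection factored out of the op, as this commit does:

import numpy as np

# assumed recurrence (standard Mamba formulation):
#   h_t = exp(dt_t * A) * h_{t-1} + (dt_t * B_t) * x_t
#   y_t = C_t . h_t
def ssm_scan(x, dt, A, B, C):
    d_inner, d_state = A.shape
    h = np.zeros((d_inner, d_state))
    y = np.empty((len(x), d_inner))
    for t in range(len(x)):
        h = np.exp(dt[t][:, None] * A) * h + (dt[t] * x[t])[:, None] * B[t][None, :]
        y[t] = h @ C[t]
    return y

# with this change the D term is no longer fused into the scan op;
# it becomes a separate broadcasted multiply-add in the compute graph:
#   y = ssm_scan(x, dt, A, B, C)
#   y = y + x * D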
@@ -264,6 +264,12 @@ class Model:
         return [(self.map_tensor_name(name), data_torch)]
 
+    # TODO: merge into modify_tensors? (need to check tensor shapes for all arches before doing that)
+    def reshape_tensors(self, data_torch: Tensor, new_name: str, bid: int | None) -> Tensor:
+        del new_name, bid  # unused
+
+        return data_torch.squeeze()
+
     def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
         del name, new_name, bid, n_dims  # unused
 
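The base hook keeps the old behavior: every tensor simply has its size-1 dimensions dropped, and subclasses override it for tensors that need something else. A quick hypothetical check of what that default does:

import torch

t = torch.zeros((1, 4096, 1))
assert t.squeeze().shape == (4096,)  # all size-1 dims are dropped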
@@ -295,7 +301,7 @@ class Model:
                     break
 
             for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)):
-                data = data_torch.squeeze().numpy()
+                data = self.reshape_tensors(data_torch, new_name, bid).numpy()
 
                 # if data ends up empty, it means data_torch was a scalar tensor -> restore
                 if len(data.shape) == 0:
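Routing through the hook matters because a bare squeeze() can destroy shape information a subclass wants to keep; e.g. with n_groups == 1 it would collapse exactly the group axis the Mamba-2 override below introduces. A hypothetical check, with an illustrative d_inner of 1536:

import torch

norm = torch.empty(1, 1536)  # (n_group, d_inner // n_group), n_group == 1
print(norm.squeeze().shape)  # torch.Size([1536]) -- the group axis is lost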
@@ -3063,6 +3069,24 @@ class Mamba2Model(Model):
 
             yield (new_name, data_torch)
 
+    def reshape_tensors(self, data_torch: Tensor, new_name: str, bid: int | None) -> Tensor:
+        if any(self.match_model_tensor_name(new_name, t, bid, suffix="") for t in [
+            gguf.MODEL_TENSOR.SSM_A,
+            gguf.MODEL_TENSOR.SSM_D,
+        ]):
+            # unsqueeze A to use similar shape semantics as Mamba-1
+            # (D is also unsqueezed, but for more straightforward broadcast internally)
+            return data_torch.reshape((*data_torch.shape, 1))
+
+        elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
+            d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
+            d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
+            n_group = self.hparams.get("n_groups", 1)
+            return data_torch.reshape((n_group, d_inner // n_group))
+
+        return data_torch.squeeze()
+
+
 @Model.register("CohereForCausalLM")
 class CommandR2Model(Model):
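To make the resulting shapes concrete, a small standalone sketch with illustrative hyperparameters (roughly a mamba2-130m-like config: d_model 768, 24 heads, one group; the numbers are assumptions, not read from any checkpoint):

import torch

n_head, d_model, n_group = 24, 768, 1
d_inner = 2 * d_model  # the fallback when intermediate_size / d_inner is absent

# Mamba-2 keeps one A (and D) value per head; appending a trailing axis
# gives them Mamba-1-like shape semantics and straightforward broadcasting
A = torch.empty(n_head)
print(A.reshape((*A.shape, 1)).shape)                     # torch.Size([24, 1])

# SSM_NORM is stored flat as (d_inner,); regrouping makes each of the
# n_group groups own a contiguous slice of d_inner // n_group elements
norm = torch.empty(d_inner)
print(norm.reshape((n_group, d_inner // n_group)).shape)  # torch.Size([1, 1536])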