ggml : avoid multiply by D in GGML_OP_SSM_SCAN

This makes the weight buft (buffer type) detection in src/llama.cpp simpler.
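
For intuition, a minimal NumPy sketch (not the ggml kernel; all shapes and names are illustrative) of why the D term can be factored out of the scan: D only contributes an elementwise skip connection x * D, so once that is a separate graph op, the scan itself never reads D:

import numpy as np

def ssm_scan_no_D(x, dt, A, B, C, h0):
    # Sequential selective scan, Mamba-1-style shapes for illustration:
    # x, dt: (T, d_inner), A: (d_inner, d_state), B, C: (T, d_state),
    # h0: (d_inner, d_state). D is deliberately not an input.
    h = h0.copy()
    ys = []
    for t in range(x.shape[0]):
        dA = np.exp(dt[t][:, None] * A)              # (d_inner, d_state)
        h = dA * h + (dt[t] * x[t])[:, None] * B[t]  # state update
        ys.append(h @ C[t])                          # (d_inner,)
    return np.stack(ys)                              # (T, d_inner)

T, d_inner, d_state = 4, 6, 3                        # made-up sizes
rng = np.random.default_rng(0)
x, dt = rng.standard_normal((T, d_inner)), rng.random((T, d_inner))
A = -rng.random((d_inner, d_state))
B, C = rng.standard_normal((2, T, d_state))
D = rng.standard_normal(d_inner)
y = ssm_scan_no_D(x, dt, A, B, C, np.zeros((d_inner, d_state)))
y = y + x * D   # the factored-out skip connection: a plain elementwise mul

Since the scan op no longer consumes D, D no longer constrains which buffer type the scan's inputs must use, which is what simplifies the detection logic.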

* convert : transpose Mamba-2 A, D and reshape SSM_NORM

This breaks existing conversions of Mamba-2 models
in order to avoid some reshapes in the compute graph
(see the shape sketch after this list).

Not sure if it's a good idea,
but it makes the graph slightly cleaner.

* llama : more appropriate SSM_SCAN and SSM_CONV buft support checks
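
For concreteness, a small PyTorch sketch of the new conversion-time shapes, mirroring the reshape_tensors logic in the diff below; the sizes (d_model = 8, n_groups = 2) are made up, and the real values come from the checkpoint's hparams:

import torch

d_model, n_group = 8, 2          # hypothetical sizes for illustration
d_inner = 2 * d_model            # default when intermediate_size is absent

ssm_a = torch.randn(d_inner)     # Mamba-2 stores one A value per channel
ssm_d = torch.randn(d_inner)
ssm_norm = torch.randn(d_inner)

ssm_a = ssm_a.reshape((*ssm_a.shape, 1))                    # (16,) -> (16, 1)
ssm_d = ssm_d.reshape((*ssm_d.shape, 1))                    # (16,) -> (16, 1)
ssm_norm = ssm_norm.reshape((n_group, d_inner // n_group))  # (16,) -> (2, 8)

# gguf writes dims in reverse order, so the (16, 1) tensors land as
# ne = [1, 16] on the ggml side -- the "transpose" from the commit title.
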
Francis Couture-Harpin 2024-11-04 11:36:37 -05:00
parent 7d16e1bc8c
commit 3bc7103d2e
7 changed files with 98 additions and 95 deletions


@@ -264,6 +264,12 @@ class Model:
 
         return [(self.map_tensor_name(name), data_torch)]
 
+    # TODO: merge into modify_tensors? (need to check tensor shapes for all arches before doing that)
+    def reshape_tensors(self, data_torch: Tensor, new_name: str, bid: int | None) -> Tensor:
+        del new_name, bid  # unused
+
+        return data_torch.squeeze()
+
     def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
         del name, new_name, bid, n_dims  # unused
 
@@ -295,7 +301,7 @@ class Model:
                     break
 
             for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)):
-                data = data_torch.squeeze().numpy()
+                data = self.reshape_tensors(data_torch, new_name, bid).numpy()
 
                 # if data ends up empty, it means data_torch was a scalar tensor -> restore
                 if len(data.shape) == 0:
@@ -3063,6 +3069,24 @@ class Mamba2Model(Model):
 
             yield (new_name, data_torch)
 
+    def reshape_tensors(self, data_torch: Tensor, new_name: str, bid: int | None) -> Tensor:
+        if any(self.match_model_tensor_name(new_name, t, bid, suffix="") for t in [
+            gguf.MODEL_TENSOR.SSM_A,
+            gguf.MODEL_TENSOR.SSM_D,
+        ]):
+            # unsqueeze A to use similar shape semantics as Mamba-1
+            # (D is also unsqueezed, but for more straightforward broadcast internally)
+            return data_torch.reshape((*data_torch.shape, 1))
+        elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
+            d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
+            d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
+            n_group = self.hparams.get("n_groups", 1)
+            return data_torch.reshape((n_group, d_inner // n_group))
+
+        return data_torch.squeeze()
+
 
 @Model.register("CohereForCausalLM")
 class CommandR2Model(Model):
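
A hedged sketch of what the SSM_NORM reshape buys: with the weight stored as (n_group, d_inner // n_group), a group-wise RMS norm can work on a matching 2-D view of the activations with no runtime reshape of the weight. grouped_rms_norm and all sizes here are illustrative, not llama.cpp code:

import torch

def grouped_rms_norm(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # x: (T, n_group, d_inner // n_group), w: (n_group, d_inner // n_group)
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return x * w  # weight broadcasts over T without any reshape

T, n_group, d_group = 3, 2, 8             # hypothetical sizes
x = torch.randn(T, n_group, d_group)
w = torch.ones(n_group, d_group)          # SSM_NORM, shaped as converted above
print(grouped_rms_norm(x, w).shape)       # torch.Size([3, 2, 8])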