Applied changes from upstream PR: save memory with lazy evaluation #7075 (shameless copy from LlamaModel).
parent f3d1227ca4
commit a89257151f
1 changed file with 47 additions and 82 deletions
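The core of the upstream change, as the diff below shows, is that the rewritten code returns tensors from modify_tensors() instead of writing them directly, which lets the upstream lazy-evaluation machinery defer the actual work; expert weights are buffered per layer and merged into one stacked 3-D tensor only once every expert of that layer has been seen. A minimal, standalone sketch of that per-layer buffering pattern (all names here are illustrative, not the project's code):

import torch

# toy stand-ins: 2 layers, 3 experts, tiny weight matrices
n_layers, n_experts = 2, 3
buffers = [{} for _ in range(n_layers)]  # one dict of pending expert tensors per layer

def on_tensor(layer: int, name: str, tensor: torch.Tensor) -> list[tuple[str, torch.Tensor]]:
    # buffer per-expert tensors; emit a stacked 3-D tensor once the layer is complete
    buffers[layer][name] = tensor
    if len(buffers[layer]) < n_experts:
        return []  # not complete yet; keep only this layer's partial state around
    stacked = torch.stack([buffers[layer].pop(f"expert.{x}.w1") for x in range(n_experts)], dim=0)
    return [(f"layer.{layer}.experts.w1", stacked)]

# feed tensors one at a time, the way a lazy loader would
for x in range(n_experts):
    out = on_tensor(0, f"expert.{x}.w1", torch.zeros(4, 8))
print(out[0][0], tuple(out[0][1].shape))  # layer.0.experts.w1 (3, 4, 8)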
@@ -2370,104 +2370,69 @@ class ArcticModel(Model):
         self.gguf_writer.add_vocab_size(hparams["vocab_size"])
         self.gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"])

-    # Same as super class, but permuting q_proj, k_proj
-    def write_tensors(self):
-        block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
-        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
-        n_head = self.hparams.get("num_attention_heads")
-        n_kv_head = self.hparams.get("num_key_value_heads")
-        n_experts = self.hparams.get("num_local_experts")
-        experts = dict()
-        for name, data_torch in self.get_tensors():
-            # we don't need these
-            if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
-                continue
-
-            old_dtype = data_torch.dtype
-
-            # convert any unsupported data types to float32
-            if data_torch.dtype not in (torch.float16, torch.float32):
-                data_torch = data_torch.to(torch.float32)
-
-            data = data_torch.numpy()
-
-            if name.endswith("q_proj.weight"):
-                data = permute(data, n_head, n_head)
-            if name.endswith("k_proj.weight"):
-                data = permute(data, n_head, n_kv_head)
-
-            data = data.squeeze()
-
-            # process the experts separately
-            if name.find("block_sparse_moe.experts") != -1:
-                experts[name] = data
-                if len(experts) >= n_experts:
-                    # merge the experts into a single 3d tensor
-                    for bid in range(block_count):
-                        for wid in range(1, 4):
-                            full = True
-                            for xid in range(n_experts):
-                                ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.w{wid}.weight"
-                                if ename not in experts:
-                                    full = False
-                                    break
-                            if not full:
-                                continue
-
-                            datas = []
-                            for xid in range(n_experts):
-                                ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.w{wid}.weight"
-                                datas.append(experts[ename])
-                                del experts[ename]
-
-                            data = np.stack(datas, axis=0)
-                            data_dtype = data.dtype
-
-                            if self.ftype == 0 and data_dtype == np.float16:
-                                data = data.astype(np.float32)
-
-                            if self.ftype == 1 and data_dtype == np.float32:
-                                data = data.astype(np.float16)
-
-                            merged_name = f"layers.{bid}.feed_forward.experts.w{wid}.weight"
-
-                            new_name = tensor_map.get_name(merged_name, try_suffixes=(".weight", ".bias"))
-                            if new_name is None:
-                                print(f"Can not map tensor {name!r}")
-                                sys.exit()
-
-                            print(f"{new_name}, n_dims = {len(data.shape)}, shape = {data.shape} --> {data.dtype}")
-
-                            self.gguf_writer.add_tensor(new_name, data)
-                continue
-
-            # map tensor names
-            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
-            if new_name is None:
-                print(f"Can not map tensor {name!r}")
-                sys.exit()
-
-            n_dims = len(data.shape)
-            data_dtype = data.dtype
-
-            # if f32 desired, convert any float16 to float32
-            if self.ftype == 0 and data_dtype == np.float16:
-                data = data.astype(np.float32)
-
-            # 1d tensors need to be converted to float32
-            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
-                data = data.astype(np.float32)
-
-            # if f16 desired, convert any float32 2-dim weight tensors to float16
-            if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
-                data = data.astype(np.float16)
-
-            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
-
-            self.gguf_writer.add_tensor(new_name, data)
-
-        if len(experts) > 0:
-            raise ValueError(f"Unprocessed experts: {experts.keys()}")
+    @staticmethod
+    def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
+        if n_head_kv is not None and n_head != n_head_kv:
+            n_head = n_head_kv
+        return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
+                .swapaxes(1, 2)
+                .reshape(weights.shape))
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        n_head = self.hparams["num_attention_heads"]
+        n_kv_head = self.hparams.get("num_key_value_heads")
+
+        if name.endswith("q_proj.weight"):
+            data_torch = LlamaModel.permute(data_torch, n_head, n_head)
+        if name.endswith("k_proj.weight"):
+            data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
+
+        # process the experts separately
+        if name.find("block_sparse_moe.experts") != -1:
+            n_experts = self.hparams["num_local_experts"]
+
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for wid in ["w1", "w2", "w3"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{wid}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    merged_name = f"layers.{bid}.feed_forward.experts.{wid}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    tensors.append((new_name, data_torch))
+                return tensors
+            else:
+                return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def write_tensors(self):
+        super().write_tensors()
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")


 ###### CONVERSION LOGIC ######
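A note on the permute helper added above (and the LlamaModel.permute calls in modify_tensors): it reorders the rows inside each attention head of the Q/K projection weights, taking the head's two halves and interleaving them, while leaving the tensor shape unchanged; the converter uses this to match the Q/K row layout the runtime expects for rotary embeddings. A small illustration with toy dimensions (illustrative sketch, not project code; the printed values are simply what the reshape/swapaxes round trip produces):

import torch

# hypothetical toy dimensions: 2 heads of size 4, hidden size 8
n_head, head_dim, hidden = 2, 4, 8
w = torch.arange(n_head * head_dim * hidden, dtype=torch.float32).reshape(n_head * head_dim, hidden)

def permute(weights: torch.Tensor, n_head: int, n_head_kv: int | None) -> torch.Tensor:
    # same arithmetic as the staticmethod in the diff above
    if n_head_kv is not None and n_head != n_head_kv:
        n_head = n_head_kv
    return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
            .swapaxes(1, 2)
            .reshape(weights.shape))

out = permute(w, n_head, n_head)
print(out.shape)   # torch.Size([8, 8]): shape is unchanged
print(out[:4, 0])  # head 0 rows reordered: tensor([ 0., 16.,  8., 24.])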