Added support for SFTTrainer checkpoint models and adapter models containing one or more non-LoRA weights
My initial commit was more of a brute-force approach; the edits suggested by @FirstTimeEZ reduce the complexity.
parent c6396aa4bb
commit c2c2626ec6
1 changed file with 4 additions and 14 deletions
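To illustrate the effect of the simplified check added in the diff below, here is a minimal sketch. The tensor names are made up for illustration (typical PEFT-style adapter keys); an SFTTrainer checkpoint can carry extra non-LoRA tensors alongside the LoRA pairs, and the single substring test keeps only what the converter can handle:

# Hypothetical adapter state-dict keys; names are illustrative only.
adapter_keys = [
    "base_model.model.layers.0.self_attn.q_proj.lora_A.weight",
    "base_model.model.layers.0.self_attn.q_proj.lora_B.weight",
    "base_model.model.layers.0.self_attn.q_proj.base_layer.weight",
    "base_model.model.extra_head.weight",  # a non-LoRA weight that previously aborted conversion
]

for name in adapter_keys:
    # Same condition as the line added in this commit.
    if ("lora_" in name) or (".base_layer.weight" in name):
        print("convert:", name)
    else:
        print("skip   :", name)

Running this prints "convert:" for the three LoRA/base-layer tensors and "skip:" for the extra weight, which is exactly the behavior the commit message describes.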
@@ -338,16 +338,8 @@ if __name__ == '__main__':
         def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
             tensor_map: dict[str, PartialLoraTensor] = {}
 
-            # The following edits will enable conversion for: SFTTrainer checkpoint adapter models and other adapter models that contain weights besides LoRA weights
-
-
-            # Here, we first get the items with the 'lora_' substring
-            lora_model_items_name = [name for name,_ in lora_model.items()]
-            lora_model_items_with_lora_tensors = [name for name in lora_model_items_name if 'lora_' in name]
-
             for name, tensor in lora_model.items():
-                # Check for only LoRA finetuned weights and base layer weights
-                if (name in lora_model_items_with_lora_tensors) or (".base_layer.weight" in name):
+                if ("lora_" in name) or (".base_layer.weight" in name):
                     if self.lazy:
                         tensor = LazyTorchTensor.from_eager(tensor)
                     base_name = get_base_tensor_name(name)
@@ -356,11 +348,7 @@ if __name__ == '__main__':
                     if not is_lora_a and not is_lora_b:
                         if ".base_layer.weight" in name:
                             continue
-
-                        # we will either have a lora weight or a base layer weight, this error becomes trivial
-                        # logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor")
-                        # sys.exit(1)
 
                     if base_name in tensor_map:
                         if is_lora_a:
                             tensor_map[base_name].A = tensor
@@ -371,6 +359,8 @@ if __name__ == '__main__':
                             tensor_map[base_name] = PartialLoraTensor(A=tensor)
                         else:
                             tensor_map[base_name] = PartialLoraTensor(B=tensor)
+                else:
+                    pass
 
             for name, tensor in tensor_map.items():
                 assert tensor.A is not None
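For context, here is a self-contained sketch of the pairing loop this hunk feeds into. It uses stand-in names (PartialPair, base_name_of) rather than the script's own PartialLoraTensor and get_base_tensor_name, and it only illustrates the control flow shown in the diff, not the real converter:

from dataclasses import dataclass

@dataclass
class PartialPair:  # stand-in for the script's PartialLoraTensor
    A: object = None
    B: object = None

def base_name_of(name: str) -> str:
    # stand-in for get_base_tensor_name: strip the lora_A/lora_B suffix
    return name.replace(".lora_A.weight", "").replace(".lora_B.weight", "")

def pair_lora_tensors(lora_model: dict) -> dict:
    tensor_map: dict[str, PartialPair] = {}
    for name, tensor in lora_model.items():
        # keep only LoRA weights and base-layer weights, as in this commit
        if ("lora_" in name) or (".base_layer.weight" in name):
            if ".base_layer.weight" in name:
                continue  # base-layer weights are not part of an A/B pair
            entry = tensor_map.setdefault(base_name_of(name), PartialPair())
            if ".lora_A." in name:
                entry.A = tensor
            else:
                entry.B = tensor
        # anything else (e.g. extra SFTTrainer tensors) is silently skipped,
        # matching the added "else: pass" branch
    return tensor_map

Each entry in the returned map should end up with both A and B set, which is what the asserts after the loop verify.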