Added support for SFTTrainer checkpoint models and adapter models containing one or more non-LoRA weights
My initial commit was more like a brute force. The edits suggested by @FirstTimeEZ reduces the complexity.
This commit is contained in:
parent
c6396aa4bb
commit
c2c2626ec6
1 changed files with 4 additions and 14 deletions
|
@ -338,16 +338,8 @@ if __name__ == '__main__':
|
|||
def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
|
||||
tensor_map: dict[str, PartialLoraTensor] = {}
|
||||
|
||||
# The following edits will enable conversion for: SFTTrainer checkpoint adapter models and other adapter models that contain weights besides LoRA weights
|
||||
|
||||
# Here, we first get the items with the 'lora_' substring
|
||||
lora_model_items_name = [name for name,_ in lora_model.items()]
|
||||
lora_model_items_with_lora_tensors = [name for name in lora_model_items_name if 'lora_' in name]
|
||||
|
||||
for name, tensor in lora_model.items():
|
||||
|
||||
# Check for only LoRA finetuned weights and base layer weights
|
||||
if (name in lora_model_items_with_lora_tensors) or (".base_layer.weight" in name):
|
||||
if ("lora_" in name) or (".base_layer.weight" in name):
|
||||
if self.lazy:
|
||||
tensor = LazyTorchTensor.from_eager(tensor)
|
||||
base_name = get_base_tensor_name(name)
|
||||
|
@ -356,11 +348,7 @@ if __name__ == '__main__':
|
|||
if not is_lora_a and not is_lora_b:
|
||||
if ".base_layer.weight" in name:
|
||||
continue
|
||||
|
||||
# we will either have a lora weight or a base layer weight, this error becomes trivial
|
||||
# logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor")
|
||||
# sys.exit(1)
|
||||
|
||||
|
||||
if base_name in tensor_map:
|
||||
if is_lora_a:
|
||||
tensor_map[base_name].A = tensor
|
||||
|
@ -371,6 +359,8 @@ if __name__ == '__main__':
|
|||
tensor_map[base_name] = PartialLoraTensor(A=tensor)
|
||||
else:
|
||||
tensor_map[base_name] = PartialLoraTensor(B=tensor)
|
||||
else:
|
||||
pass
|
||||
|
||||
for name, tensor in tensor_map.items():
|
||||
assert tensor.A is not None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue