diff --git a/lib/maple_tree.c b/lib/maple_tree.c index 880ce0fcdcac..0d4573a8d134 100644 --- a/lib/maple_tree.c +++ b/lib/maple_tree.c @@ -1756,6 +1756,36 @@ static inline void mas_replace(struct ma_state *mas, bool advanced) } } +/* + * mas_replace_node() - Replace a node by putting it in the tree, marking it + * dead, and freeing it. + * the parent encoding to locate the maple node in the tree. + * @mas - the ma_state with @mas->node pointing to the new node. + * @old_enode - The old maple encoded node. + */ +static inline void mas_replace_node(struct ma_state *mas, + struct maple_enode *old_enode) + __must_hold(mas->tree->ma_lock) +{ + if (mte_is_root(mas->node)) { + mas_mn(mas)->parent = ma_parent_ptr( + ((unsigned long)mas->tree | MA_ROOT_PARENT)); + rcu_assign_pointer(mas->tree->ma_root, mte_mk_root(mas->node)); + mas_set_height(mas); + } else { + unsigned char offset = 0; + void __rcu **slots = NULL; + + offset = mte_parent_slot(mas->node); + slots = ma_slots(mte_parent(mas->node), + mas_parent_type(mas, mas->node)); + rcu_assign_pointer(slots[offset], mas->node); + } + + mte_set_node_dead(old_enode); + mas_free(mas, old_enode); +} + /* * mas_new_child() - Find the new child of a node. * @mas: the maple state @@ -3176,7 +3206,7 @@ static inline void mas_destroy_rebalance(struct ma_state *mas, unsigned char end { enum maple_type mt = mte_node_type(mas->node); struct maple_node reuse, *newnode, *parent, *new_left, *left, *node; - struct maple_enode *eparent; + struct maple_enode *eparent, *old_eparent; unsigned char offset, tmp, split = mt_slots[mt] / 2; void __rcu **l_slots, **slots; unsigned long *l_pivs, *pivs, gap; @@ -3218,7 +3248,7 @@ static inline void mas_destroy_rebalance(struct ma_state *mas, unsigned char end l_mas.max = l_pivs[split]; mas->min = l_mas.max + 1; - eparent = mt_mk_node(mte_parent(l_mas.node), + old_eparent = mt_mk_node(mte_parent(l_mas.node), mas_parent_type(&l_mas, l_mas.node)); tmp += end; if (!in_rcu) { @@ -3234,7 +3264,7 @@ static inline void mas_destroy_rebalance(struct ma_state *mas, unsigned char end memcpy(node, newnode, sizeof(struct maple_node)); ma_set_meta(node, mt, 0, tmp - 1); - mte_set_pivot(eparent, mte_parent_slot(l_mas.node), + mte_set_pivot(old_eparent, mte_parent_slot(l_mas.node), l_pivs[split]); /* Remove data from l_pivs. */ @@ -3242,6 +3272,7 @@ static inline void mas_destroy_rebalance(struct ma_state *mas, unsigned char end memset(l_pivs + tmp, 0, sizeof(unsigned long) * (max_p - tmp)); memset(l_slots + tmp, 0, sizeof(void *) * (max_s - tmp)); ma_set_meta(left, mt, 0, split); + eparent = old_eparent; goto done; } @@ -3266,7 +3297,7 @@ static inline void mas_destroy_rebalance(struct ma_state *mas, unsigned char end parent = mas_pop_node(mas); slots = ma_slots(parent, mt); pivs = ma_pivots(parent, mt); - memcpy(parent, mte_to_node(eparent), sizeof(struct maple_node)); + memcpy(parent, mte_to_node(old_eparent), sizeof(struct maple_node)); rcu_assign_pointer(slots[offset], mas->node); rcu_assign_pointer(slots[offset - 1], l_mas.node); pivs[offset - 1] = l_mas.max; @@ -3278,8 +3309,10 @@ done: mte_set_gap(eparent, mte_parent_slot(l_mas.node), gap); mas_ascend(mas); - if (in_rcu) - mas_replace(mas, false); + if (in_rcu) { + mas_replace_node(mas, old_eparent); + mas_adopt_children(mas, mas->node); + } mas_update_gap(mas); } @@ -3596,11 +3629,13 @@ static noinline_for_kasan int mas_commit_b_node(struct ma_wr_state *wr_mas, struct maple_big_node *b_node, unsigned char end) { struct maple_node *node; + struct maple_enode *old_enode; unsigned char b_end = b_node->b_end; enum maple_type b_type = b_node->type; + old_enode = wr_mas->mas->node; if ((b_end < mt_min_slots[b_type]) && - (!mte_is_root(wr_mas->mas->node)) && + (!mte_is_root(old_enode)) && (mas_mt_height(wr_mas->mas) > 1)) return mas_rebalance(wr_mas->mas, b_node); @@ -3618,7 +3653,7 @@ static noinline_for_kasan int mas_commit_b_node(struct ma_wr_state *wr_mas, node->parent = mas_mn(wr_mas->mas)->parent; wr_mas->mas->node = mt_mk_node(node, b_type); mab_mas_cp(b_node, 0, b_end, wr_mas->mas, false); - mas_replace(wr_mas->mas, false); + mas_replace_node(wr_mas->mas, old_enode); reuse_node: mas_update_gap(wr_mas->mas); return 1; @@ -4117,9 +4152,10 @@ static inline bool mas_wr_node_store(struct ma_wr_state *wr_mas, done: mas_leaf_set_meta(mas, newnode, dst_pivots, maple_leaf_64, new_end); if (in_rcu) { - mte_set_node_dead(mas->node); + struct maple_enode *old_enode = mas->node; + mas->node = mt_mk_node(newnode, wr_mas->type); - mas_replace(mas, false); + mas_replace_node(mas, old_enode); } else { memcpy(wr_mas->node, newnode, sizeof(struct maple_node)); }