/**
 * Builds a node holding {@code key}/{@code value} between the subtrees {@code l} and {@code r},
 * rotating first when one side is more than 2 levels taller than the other.
 *
 * <p>The tolerated height difference is 2, not the textbook AVL value of 1. Rebalancing only
 * triggers on a difference of 3 or more.
 *
 * <p>NOTE(review): the three-argument {@code createNode(left, node, right)} overload used below
 * appears to build a new node that carries {@code node}'s key/value with the given children.
 * That overload is not visible here, so confirm.
 *
 * <p>NOTE(review): presumably callers invoke this after a single insertion or removal, so
 * {@code l} and {@code r} are already balanced and the imbalance is small. TODO confirm.
 *
 * @param l     left subtree (keys ordered before {@code key})
 * @param key   key of the node being built
 * @param value value associated with {@code key}
 * @param r     right subtree (keys ordered after {@code key})
 * @return a node containing all entries of {@code l}, {@code key}/{@code value} and {@code r}
 */
private static AVLTree balance(AVLTree l, Object key, Object value, AVLTree r) {
  // Left side too tall.
  if (l.height() > r.height() + 2) {
    // l is at least 3 levels taller than r, so it cannot be empty.
    assert !l.isEmpty();
    AVLTree ll = l.left();
    AVLTree lr = l.right();
    if (ll.height() >= lr.height()) {
      // Outer grandchild is at least as tall: single rotation, l becomes the new root.
      return createNode(ll, l, createNode(lr, key, value, r));
    }
    // Inner grandchild is strictly taller than ll, so it has entries: double rotation,
    // lr becomes the new root and its children are split between the two sides.
    assert !lr.isEmpty();
    AVLTree lrl = lr.left();
    AVLTree lrr = lr.right();
    return createNode(createNode(ll, l, lrl), lr, createNode(lrr, key, value, r));
  }
  // Right side too tall (mirror image of the case above).
  if (r.height() > l.height() + 2) {
    assert !r.isEmpty();
    AVLTree rl = r.left();
    AVLTree rr = r.right();
    if (rr.height() >= rl.height()) {
      // Single rotation: r becomes the new root.
      return createNode(createNode(l, key, value, rl), r, rr);
    }
    // Double rotation: rl becomes the new root.
    assert !rl.isEmpty();
    AVLTree rll = rl.left();
    AVLTree rlr = rl.right();
    return createNode(createNode(l, key, value, rll), rl, createNode(rlr, r, rr));
  }
  // Within tolerance: no rotation needed.
  return createNode(l, key, value, r);
}
/**
 * Round-trip test over the keys 0..99 inserted in shuffled order.
 *
 * <p>It checks that:
 * <ul>
 *   <li>re-adding a present key returns the same instance;</li>
 *   <li>the set/map counts and the tree height are as expected;</li>
 *   <li>removing every key again (checking that re-removing returns the same instance)
 *       leaves an empty tree.</li>
 * </ul>
 *
 * <p>A fresh shuffle seed is drawn on every run, so different permutations keep being exercised.
 * The seed is included in every assertion description, so a failing permutation can be replayed
 * by hard-coding the reported seed.
 */
@Test
public void test() {
  long seed = System.nanoTime();
  String context = "shuffle seed " + seed;

  List<Integer> keys = new ArrayList<>();
  for (int i = 0; i < 100; i++) {
    keys.add(i);
  }
  // Seeded shuffle: same permutation for the same seed, unlike the global-Random overload.
  Collections.shuffle(keys, new java.util.Random(seed));

  AVLTree<Integer, Object> t = AVLTree.create();
  for (Integer key : keys) {
    t = t.add(key);
    // Adding a key that is already present must return the very same instance.
    assertThat(t.add(key)).as(context).isSameAs(t);
  }
  assertThat(Counter.countSet(t)).as(context).isEqualTo(100);
  assertThat(Counter.countMap(t)).as(context).isEqualTo(100);
  // NOTE(review): balance() tolerates a height difference of 2, so a 100-node tree can in
  // principle reach height 11. The upper bound here relies on the insertion order not producing
  // such a shape, and the seed in the description makes any such failure reproducible.
  assertThat(t.height()).as(context).isGreaterThanOrEqualTo(8).isLessThanOrEqualTo(10);

  for (Integer key : keys) {
    assertThat(t.contains(key)).as(context).isTrue();
    t = t.remove(key);
    // Removing a key that is no longer present must return the very same instance.
    assertThat(t.remove(key)).as(context).isSameAs(t);
  }
  assertThat(Counter.countSet(t)).as(context).isEqualTo(0);
  assertThat(Counter.countMap(t)).as(context).isEqualTo(0);
}
/**
 * Checks that a tree keeps a value stored under a second key sharing the integer {@code 1} with
 * an existing key after the tree is rebalanced.
 *
 * <p>NOTE(review): the first {@code Key} constructor argument looks like a forced hash code,
 * which would put {@code "k1"} and {@code "k1_1"} in the same bucket. Confirm against the
 * {@code Key} helper.
 */
@Test
public void balancing_should_preserve_buckets() {
  Object firstKey = new Key(1, "k1");
  Object secondKey = new Key(2, "k2");
  Object thirdKey = new Key(3, "k3");
  Object fourthKey = new Key(4, "k4");

  AVLTree<Object, Object> tree = AVLTree.create()
      .put(firstKey, "v1")
      .put(secondKey, "v2")
      .put(thirdKey, "v3");

  // Second entry sharing the integer 1 with firstKey.
  Object collidingKey = new Key(1, "k1_1");
  tree = tree.put(collidingKey, "v1_1");

  // This insertion is expected to force a rebalance (see the assertion description below).
  tree = tree.put(fourthKey, "v4");

  assertThat(tree.height()).as("height after balancing").isEqualTo(3);
  assertThat(tree.get(collidingKey)).isEqualTo("v1_1");
}
/**
 * Computes the height of a node whose children are {@code l} and {@code r}.
 *
 * @param l left child
 * @param r right child
 * @return one more than the height of the taller child
 */
private static int incrementHeight(AVLTree l, AVLTree r) {
  int tallerChild = Math.max(l.height(), r.height());
  return tallerChild + 1;
}