package net.abhinavsarkar.algorist.tree; import java.io.IOException; import java.nio.file.Files; import java.nio.file.Paths; import java.util.Arrays; import java.util.Iterator; import java.util.Objects; import java.util.Optional; import java.util.function.Function; import java.util.stream.Stream; import net.abhinavsarkar.algorist.SortedMap; public class Trie, V> implements SortedMap { private final Function keyToStringFn; private Node root; public Trie(Function keyToStringFn, int minBound, int maxBound) { this.keyToStringFn = keyToStringFn; root = new RootNode<>(maxBound - minBound + 1); } @Override public Optional minimum() { return Optional.empty(); } @Override public Optional maximum() { return Optional.empty(); } @Override public Optional successor(K key) { return Optional.empty(); } @Override public Optional predecessor(K key) { return Optional.empty(); } @Override public void put(K key, V val) { root.put(keyToStringFn.apply(key), 0, val); } @Override public Optional get(K key) { return this.root.get(keyToStringFn.apply(key), 0); } @Override public void remove(K key) { this.root.remove(keyToStringFn.apply(key), 0); } @Override public Iterator> iterator() { return null; } private interface Node { void put(int[] chars, int idx, V val); Optional get(int[] chars, int idx); boolean remove(int[] chars, int idx); } private static class RootNode implements Node { protected final ValueNode[] children; protected final int capacity; @SuppressWarnings("unchecked") private RootNode(int capacity) { children = new ValueNode[capacity]; this.capacity = capacity; } @Override public void put(int[] chars, int idx, V val) { if (idx != 0) { throw new IllegalArgumentException("index must be zero"); } if (chars.length == 0) { return; } if (children[chars[idx]] == null) { children[chars[idx]] = new ValueNode<>(null, capacity); } children[chars[idx]].put(chars, idx+1, val); } @Override public Optional get(int[] chars, int idx) { if (idx != 0) { throw new IllegalArgumentException("index must be zero"); } if (children[chars[idx]] == null) { return Optional.empty(); } return children[chars[idx]].get(chars, idx + 1); } @Override public boolean remove(int[] chars, int idx) { if (idx != 0) { throw new IllegalArgumentException("index must be zero"); } doRemove(chars, idx); return false; } protected void doRemove(int[] chars, int idx) { if (children[chars[idx]] != null && children[chars[idx]].remove(chars, idx + 1)) { children[chars[idx]] = null; } } } private static class ValueNode extends RootNode { private V val; private ValueNode(V val, int capacity) { super(capacity); this.val = val; } @Override public void put(int[] chars, int idx, V val) { if (idx > chars.length) { throw new IllegalArgumentException("index too big"); } if (chars.length == idx) { this.val = val; return; } if (children[chars[idx]] == null) { children[chars[idx]] = new ValueNode<>(null, capacity); } children[chars[idx]].put(chars, idx+1, val); } @Override public Optional get(int[] chars, int idx) { if (chars.length == idx) { return Optional.ofNullable(val); } if (children[chars[idx]] == null) { return Optional.empty(); } return children[chars[idx]].get(chars, idx+1); } @Override public boolean remove(int[] chars, int idx) { if (chars.length == idx) { val = null; } else { doRemove(chars, idx); } return Arrays.stream(children).noneMatch(Objects::nonNull); } } public static void main(String[] args) throws IOException { Trie trie = new Trie<>(s -> stringToIntArray(s, 'A'), 'A', 'z'); try (Stream lines = Files.lines(Paths.get("/usr/share/dict/words"))) { lines.forEach(line -> trie.put(line, line)); } System.out.println(trie.contains("abactor")); System.out.println(trie.contains("a")); System.out.println(trie.contains("bbb")); System.out.println(trie.contains("aardwolf")); trie.remove("aardwolf"); System.out.println(trie.contains("aardwolf")); } private static int[] stringToIntArray(String s, char min) { int[] ints = new int[s.length()]; char[] charArray = s.toCharArray(); for (int i = 0; i < charArray.length; i++) { char c = charArray[i]; if (c >= min) { ints[i] = c - min; } } return ints; } }