227 lines
5.6 KiB
Java
227 lines
5.6 KiB
Java
package net.abhinavsarkar.algorist;
|
|
|
|
import java.io.IOException;
|
|
import java.nio.file.Files;
|
|
import java.nio.file.Paths;
|
|
import java.util.Arrays;
|
|
import java.util.Iterator;
|
|
import java.util.Objects;
|
|
import java.util.Optional;
|
|
import java.util.function.Function;
|
|
import java.util.stream.Stream;
|
|
|
|
public class Trie<K extends Comparable<K>, V> implements SortedMap<K, V>
|
|
{
|
|
private final Function<K, ? extends int[]> keyToStringFn;
|
|
private Node<V> root;
|
|
|
|
public Trie(Function<K, int[]> keyToStringFn, int minBound, int maxBound) {
|
|
|
|
this.keyToStringFn = keyToStringFn;
|
|
root = new RootNode<>(maxBound - minBound + 1);
|
|
}
|
|
|
|
@Override
|
|
public Optional<K> minimum()
|
|
{
|
|
return Optional.empty();
|
|
}
|
|
|
|
@Override
|
|
public Optional<K> maximum()
|
|
{
|
|
return Optional.empty();
|
|
}
|
|
|
|
@Override
|
|
public Optional<K> successor(K key)
|
|
{
|
|
return Optional.empty();
|
|
}
|
|
|
|
@Override
|
|
public Optional<K> predecessor(K key)
|
|
{
|
|
return Optional.empty();
|
|
}
|
|
|
|
@Override
|
|
public void put(K key, V val)
|
|
{
|
|
root.put(keyToStringFn.apply(key), 0, val);
|
|
}
|
|
|
|
@Override
|
|
public Optional<V> get(K key)
|
|
{
|
|
return this.root.get(keyToStringFn.apply(key), 0);
|
|
}
|
|
|
|
@Override
|
|
public void remove(K key)
|
|
{
|
|
this.root.remove(keyToStringFn.apply(key), 0);
|
|
}
|
|
|
|
@Override
|
|
public Iterator<Entry<K, V>> iterator()
|
|
{
|
|
return null;
|
|
}
|
|
|
|
private interface Node<V> {
|
|
void put(int[] chars, int idx, V val);
|
|
Optional<V> get(int[] chars, int idx);
|
|
boolean remove(int[] chars, int idx);
|
|
}
|
|
|
|
private static class RootNode<V> implements Node<V> {
|
|
protected final ValueNode<V>[] children;
|
|
protected final int capacity;
|
|
|
|
@SuppressWarnings("unchecked")
|
|
private RootNode(int capacity)
|
|
{
|
|
children = new ValueNode[capacity];
|
|
this.capacity = capacity;
|
|
}
|
|
|
|
@Override
|
|
public void put(int[] chars, int idx, V val)
|
|
{
|
|
if (idx != 0) {
|
|
throw new IllegalArgumentException("index must be zero");
|
|
}
|
|
|
|
if (chars.length == 0) {
|
|
return;
|
|
}
|
|
|
|
if (children[chars[idx]] == null)
|
|
{
|
|
children[chars[idx]] = new ValueNode<>(null, capacity);
|
|
}
|
|
children[chars[idx]].put(chars, idx+1, val);
|
|
}
|
|
|
|
@Override
|
|
public Optional<V> get(int[] chars, int idx)
|
|
{
|
|
if (idx != 0) {
|
|
throw new IllegalArgumentException("index must be zero");
|
|
}
|
|
|
|
if (children[chars[idx]] == null)
|
|
{
|
|
return Optional.empty();
|
|
}
|
|
return children[chars[idx]].get(chars, idx + 1);
|
|
}
|
|
|
|
@Override
|
|
public boolean remove(int[] chars, int idx)
|
|
{
|
|
if (idx != 0) {
|
|
throw new IllegalArgumentException("index must be zero");
|
|
}
|
|
|
|
doRemove(chars, idx);
|
|
return false;
|
|
}
|
|
|
|
protected void doRemove(int[] chars, int idx)
|
|
{
|
|
if (children[chars[idx]] != null && children[chars[idx]].remove(chars, idx + 1))
|
|
{
|
|
children[chars[idx]] = null;
|
|
}
|
|
}
|
|
}
|
|
|
|
private static class ValueNode<V> extends RootNode<V> {
|
|
private V val;
|
|
|
|
private ValueNode(V val, int capacity)
|
|
{
|
|
super(capacity);
|
|
this.val = val;
|
|
}
|
|
|
|
@Override
|
|
public void put(int[] chars, int idx, V val) {
|
|
if (idx > chars.length) {
|
|
throw new IllegalArgumentException("index too big");
|
|
}
|
|
|
|
if (chars.length == idx) {
|
|
this.val = val;
|
|
return;
|
|
}
|
|
|
|
if (children[chars[idx]] == null)
|
|
{
|
|
children[chars[idx]] = new ValueNode<>(null, capacity);
|
|
}
|
|
children[chars[idx]].put(chars, idx+1, val);
|
|
}
|
|
|
|
@Override
|
|
public Optional<V> get(int[] chars, int idx)
|
|
{
|
|
if (chars.length == idx) {
|
|
return Optional.ofNullable(val);
|
|
}
|
|
|
|
if (children[chars[idx]] == null)
|
|
{
|
|
return Optional.empty();
|
|
}
|
|
|
|
return children[chars[idx]].get(chars, idx+1);
|
|
}
|
|
|
|
@Override
|
|
public boolean remove(int[] chars, int idx)
|
|
{
|
|
if (chars.length == idx) {
|
|
val = null;
|
|
} else {
|
|
doRemove(chars, idx);
|
|
}
|
|
return Arrays.stream(children).noneMatch(Objects::nonNull);
|
|
}
|
|
|
|
}
|
|
|
|
public static void main(String[] args) throws IOException
|
|
{
|
|
Trie<String, String> trie = new Trie<>(s -> stringToIntArray(s, 'A'), 'A', 'z');
|
|
try (Stream<String> lines = Files.lines(Paths.get("/usr/share/dict/words")))
|
|
{
|
|
lines.forEach(line -> trie.put(line, line));
|
|
}
|
|
System.out.println(trie.contains("abactor"));
|
|
System.out.println(trie.contains("a"));
|
|
System.out.println(trie.contains("bbb"));
|
|
System.out.println(trie.contains("aardwolf"));
|
|
trie.remove("aardwolf");
|
|
System.out.println(trie.contains("aardwolf"));
|
|
}
|
|
|
|
private static int[] stringToIntArray(String s, char min)
|
|
{
|
|
int[] ints = new int[s.length()];
|
|
char[] charArray = s.toCharArray();
|
|
for (int i = 0; i < charArray.length; i++)
|
|
{
|
|
char c = charArray[i];
|
|
if (c >= min)
|
|
{
|
|
ints[i] = c - min;
|
|
}
|
|
}
|
|
return ints;
|
|
}
|
|
|
|
}
|