diff --git a/Cargo.toml b/Cargo.toml
index e351472..1e6f501 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -10,3 +10,4 @@ regex = "1"
 rayon = "1.5"
 bitintr = "0.3.0"
 itertools = "0.10.2"
+array-init = "2.0.0"
diff --git a/src/main.rs b/src/main.rs
index 8d50ed7..65e783b 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,9 +1,12 @@
+#![allow(dead_code)]
+#![allow(unused_imports)]
 use std::fs;
 use std::collections::HashMap;
 use bitintr::{Lzcnt, Tzcnt};
 use regex::Regex;
 use rayon::prelude::*;
 use itertools::zip;
+use array_init::array_init;
 
 type Charmask = i32;
 type Achar = i8; // ASCII char
@@ -12,6 +15,8 @@ const WORD_LENGTH: usize = 5;
 const WORD_LENGTH_P: usize = 5; // Padded for SIMD shenanigans
 const GUESS_DEPTH: usize = 1; // TODO: Change this whenever working at different depths
 const N_SOLUTIONS: usize = 2315;
+const IDX_ALL_WORDS: Charmask = (1<<26) - 1;
+const IDX_VALID_SOLUTIONS: Charmask = 0;
 
 const A: Achar = 'A' as Achar;
 const Z: Achar = 'Z' as Achar;
@@ -19,7 +24,7 @@ const Z: Achar = 'Z' as Achar;
 struct Word {
     charbits: [Charmask; WORD_LENGTH_P], // Each letter in bitmask form
     charmask: Charmask, // All of the characters contained
-    letters: [Achar; WORD_LENGTH]
+    //letters: [Achar; WORD_LENGTH]
 }
 
 type WordCache = HashMap<Charmask, Vec<Word>>;
@@ -38,6 +43,10 @@ fn letters2str(letters: [Achar; WORD_LENGTH]) -> String {
     letters.iter().map(|x| (*x as u8) as char).collect()
 }
 
+fn charbits2str(charbits: [Charmask; WORD_LENGTH]) -> String {
+    charbits.iter().map(|x| (cm2char(*x, 0) as u8) as char).collect()
+}
+
 fn str2word(s: &str) -> Word {
     let mut word = Word::default();
     let mut iter = s.chars();
@@ -45,7 +54,7 @@ fn str2word(s: &str) -> Word {
         let c = iter.next().unwrap() as Achar;
         let cb = char2bit(c);
         word.charbits[i] = cb;
-        word.letters[i] = c;
+        //word.letters[i] = c;
         word.charmask |= cb;
     }
     word
@@ -109,8 +118,8 @@ fn generate_wordcache(valid_words: Vec<Word>) -> WordCache {
     let mut cache: WordCache = HashMap::new();
     let valid_solutions: Vec<Word> = valid_words[..N_SOLUTIONS].to_vec(); // Hacky way to separate the valid solutions from the larger guessing list
     _generate_wordcache_nested(&mut cache, &valid_solutions, 0, 5);
-    cache.insert(0, valid_solutions);
-    cache.insert(-1, valid_words);
+    cache.insert(IDX_VALID_SOLUTIONS, valid_solutions);
+    cache.insert(IDX_ALL_WORDS, valid_words);
     cache
 }
 
@@ -118,57 +127,79 @@ fn filter_word(w: &[Charmask; WORD_LENGTH_P], banned_chars: &[Charmask; WORD_LEN
     zip(w, banned_chars).all(|(x,y)| x & y == 0)
 }
 
-fn simulate(guess_ids: [usize; GUESS_DEPTH], solution_id: usize, wordcache: &WordCache) -> usize {
-    let valid_words = &wordcache[&-1];
-    let solution = valid_words[solution_id]; // Technically this should never cross past N_SOLUTIONS or it breaks cache guarantees
-    let mut required_chars: Charmask = 0;
-    let mut banned_chars: [Charmask; WORD_LENGTH] = [0; WORD_LENGTH];
-    let mut bans = 0;
-    for guess_id in guess_ids {
-        let guess = valid_words[guess_id];
-        required_chars |= guess.charmask & solution.charmask;
-        bans |= guess.charmask & !solution.charmask;
-        for i in 0..WORD_LENGTH {
-            if guess.letters[i] == solution.letters[i] { // Right letter right position
-                banned_chars[i] = !guess.charbits[i];
-            } else if guess.charbits[i] & solution.charmask != 0 { // Right letter wrong position
-                banned_chars[i] |= guess.charbits[i];
-            }
+fn aggregate_guesses(guess_ids: Vec<usize>, wordcache: &WordCache) -> Word {
+    //guess_ids.iter().reduce(|out, g| out |= wordcache[IDX_ALL_WORDS][g]).unwrap()
+    let all_words = &wordcache[&IDX_ALL_WORDS];
+    let mut iter = guess_ids.iter();
+    let mut aggregate_guess = all_words[*iter.next().unwrap()];
+    for g in iter {
+        let guess = all_words[*g];
+        for i in 0..aggregate_guess.charbits.len() {
+            aggregate_guess.charbits[i] |= guess.charbits[i];
         }
+        aggregate_guess.charmask |= guess.charmask;
     }
-    for j in 0..WORD_LENGTH {
-        banned_chars[j] |= bans;
-    }
-    let cachekey = required_chars;
-    match wordcache.contains_key(&cachekey) {
-        true => wordcache[&cachekey].iter().filter(|w| filter_word(&w.charbits, &banned_chars)).count(),
-        false => 0,
-    }
+    aggregate_guess
 }
 
-fn find_worstcase(word_ids: [usize; GUESS_DEPTH], wordcache: &WordCache) -> (String, usize) {
-    let valid_words = &wordcache[&-1];
-    let valid_solutions = &wordcache[&0];
+fn simulate(guess: Word, wordcache: &WordCache) -> (String, usize) {
+    let valid_words = &wordcache[&IDX_ALL_WORDS];
+    let valid_solutions = &wordcache[&IDX_VALID_SOLUTIONS];
+
+    let required_chars: [Charmask; N_SOLUTIONS] = array_init::from_iter(
+        valid_solutions.iter().map(|s| s.charmask & guess.charmask)
+    ).unwrap();
+    let mut banned_chars: [Charmask; WORD_LENGTH*N_SOLUTIONS] = [0; WORD_LENGTH*N_SOLUTIONS];
+    /* array_init::from_iter(
+        valid_solutions.iter().map(|s| s.charmask & guess.charmask)
+    ).unwrap(); */
+    for i in 0..N_SOLUTIONS {
+        let s = valid_solutions[i];
+        let bans = guess.charmask & !s.charmask; // A letter fully rejected in any position bans it in all positions
+        for j in 0..WORD_LENGTH {
+            banned_chars[i*WORD_LENGTH + j] = bans;
+            banned_chars[i*WORD_LENGTH + j] |= guess.charbits[j] & !s.charbits[j]; // A letter in the wrong position
+            // A correct letter bans all others in the position. TODO: test branchless toggle
+            let correct = guess.charbits[j] & s.charbits[j];
+            //Branch
+            /* if correct != 0 {
+                banned_chars[i*WORD_LENGTH + j] |= !correct;
+            } */
+            //Branchless
+            banned_chars[i*WORD_LENGTH + j] |= !correct * (correct !=0) as i32;
+        }
+    }
 
     let mut worst = 0;
     let mut worst_w = 0;
-    for target_id in 0..valid_solutions.len() {
-        let remaining = simulate(word_ids, target_id, wordcache);
-        if remaining > worst {
-            worst = remaining;
-            worst_w = target_id;
-        };
+    for target_id in 0..N_SOLUTIONS {
+        let cachekey = required_chars[target_id];
+        if wordcache.contains_key(&cachekey) {
+            let mut remaining = 0;
+            for word in &wordcache[&cachekey] {
+                // TODO: test branchless toggle
+                let mut error = 0;
+                for c in 0..WORD_LENGTH {
+                    error += word.charbits[c] & banned_chars[target_id*WORD_LENGTH + c];
+                }
+                remaining += (error == 0) as usize;
+            }
+            if remaining > worst {
+                worst = remaining;
+                worst_w = target_id;
+            }
+        }
     }
-    let wordstr: String = word_ids.map(|i| letters2str(valid_words[i].letters)).join(", ");
-    let worststr: String = letters2str(valid_words[worst_w].letters);
+
+    let wordstr: String = charbits2str(guess.charbits); // THIS IS NOT SUITED FOR AGGREGATE GUESSES YET!
+    let worststr: String = charbits2str(valid_words[worst_w].charbits);
     let output = format!("{} - {} ({})", wordstr, worst, worststr);
-    println!("{}", output);
     (output, worst)
 }
 
 fn find_word_id_from_str(s: &str, words: &Vec<Word>) -> usize {
     let w = str2word(s);
-    words.iter().position(|x| x.letters==w.letters).unwrap()
+    words.iter().position(|x| x.charbits==w.charbits).unwrap()
 }
 
 fn main() {
@@ -177,6 +208,7 @@ fn main() {
     let totalwords = words.len();
     println!("Hello, world! {} words in dict", totalwords);
     let wordcache = generate_wordcache(words);
+    let all_words = &wordcache[&IDX_ALL_WORDS];
 
     //let sr = simulate(&wordcache[""][0], &wordcache[""][5000], &wordcache);
     //println!("{:?}", sr);
@@ -189,7 +221,7 @@ fn main() {
     // .map(|(i, j)| find_worstcase([i, j], &wordcache)).collect();
 
     // Depth-1 full
-    let mut results: Vec<(String, usize)> = (0..totalwords).into_par_iter().map(|i| find_worstcase([i], &wordcache)).collect();
+    let mut results: Vec<(String, usize)> = (0..totalwords).into_par_iter().map(|i| simulate(all_words[i], &wordcache)).collect();
 
     // Depth-3 (word1,word2,?)
     // let i1 = find_word_id_from_str("CARET", &wordcache[&0]);