Vectorize Rust solution based on inspiration from NumPy experiment

This commit is contained in:
Luke Hubmayer-Werner 2022-02-06 00:37:59 +10:30
parent eeafb060c4
commit 4449c47149
2 changed files with 75 additions and 42 deletions

View File

@ -10,3 +10,4 @@ regex = "1"
rayon = "1.5" rayon = "1.5"
bitintr = "0.3.0" bitintr = "0.3.0"
itertools = "0.10.2" itertools = "0.10.2"
array-init = "2.0.0"

View File

@ -1,9 +1,12 @@
#![allow(dead_code)]
#![allow(unused_imports)]
use std::fs; use std::fs;
use std::collections::HashMap; use std::collections::HashMap;
use bitintr::{Lzcnt, Tzcnt}; use bitintr::{Lzcnt, Tzcnt};
use regex::Regex; use regex::Regex;
use rayon::prelude::*; use rayon::prelude::*;
use itertools::zip; use itertools::zip;
use array_init::array_init;
type Charmask = i32; type Charmask = i32;
type Achar = i8; // ASCII char type Achar = i8; // ASCII char
@ -12,6 +15,8 @@ const WORD_LENGTH: usize = 5;
const WORD_LENGTH_P: usize = 5; // Padded for SIMD shenanigans const WORD_LENGTH_P: usize = 5; // Padded for SIMD shenanigans
const GUESS_DEPTH: usize = 1; // TODO: Change this whenever working at different depths const GUESS_DEPTH: usize = 1; // TODO: Change this whenever working at different depths
const N_SOLUTIONS: usize = 2315; const N_SOLUTIONS: usize = 2315;
const IDX_ALL_WORDS: Charmask = (1<<26) - 1;
const IDX_VALID_SOLUTIONS: Charmask = 0;
const A: Achar = 'A' as Achar; const A: Achar = 'A' as Achar;
const Z: Achar = 'Z' as Achar; const Z: Achar = 'Z' as Achar;
@ -19,7 +24,7 @@ const Z: Achar = 'Z' as Achar;
struct Word { struct Word {
charbits: [Charmask; WORD_LENGTH_P], // Each letter in bitmask form charbits: [Charmask; WORD_LENGTH_P], // Each letter in bitmask form
charmask: Charmask, // All of the characters contained charmask: Charmask, // All of the characters contained
letters: [Achar; WORD_LENGTH] //letters: [Achar; WORD_LENGTH]
} }
type WordCache = HashMap<Charmask, Vec<Word>>; type WordCache = HashMap<Charmask, Vec<Word>>;
@ -38,6 +43,10 @@ fn letters2str(letters: [Achar; WORD_LENGTH]) -> String {
letters.iter().map(|x| (*x as u8) as char).collect() letters.iter().map(|x| (*x as u8) as char).collect()
} }
/// Renders a word stored as per-position letter bitmasks back into text.
///
/// Every position's mask is decoded with `cm2char(mask, 0)`; the decoded
/// value is narrowed to a byte and then widened to a `char` before being
/// appended, so the output has one character per entry of `charbits`.
fn charbits2str(charbits: [Charmask; WORD_LENGTH]) -> String {
    let mut text = String::new();
    for mask in charbits.iter() {
        let decoded = cm2char(*mask, 0) as u8;
        text.push(decoded as char);
    }
    text
}
fn str2word(s: &str) -> Word { fn str2word(s: &str) -> Word {
let mut word = Word::default(); let mut word = Word::default();
let mut iter = s.chars(); let mut iter = s.chars();
@ -45,7 +54,7 @@ fn str2word(s: &str) -> Word {
let c = iter.next().unwrap() as Achar; let c = iter.next().unwrap() as Achar;
let cb = char2bit(c); let cb = char2bit(c);
word.charbits[i] = cb; word.charbits[i] = cb;
word.letters[i] = c; //word.letters[i] = c;
word.charmask |= cb; word.charmask |= cb;
} }
word word
@ -109,8 +118,8 @@ fn generate_wordcache(valid_words: Vec<Word>) -> WordCache {
let mut cache: WordCache = HashMap::new(); let mut cache: WordCache = HashMap::new();
let valid_solutions: Vec<Word> = valid_words[..N_SOLUTIONS].to_vec(); // Hacky way to separate the valid solutions from the larger guessing list let valid_solutions: Vec<Word> = valid_words[..N_SOLUTIONS].to_vec(); // Hacky way to separate the valid solutions from the larger guessing list
_generate_wordcache_nested(&mut cache, &valid_solutions, 0, 5); _generate_wordcache_nested(&mut cache, &valid_solutions, 0, 5);
cache.insert(0, valid_solutions); cache.insert(IDX_VALID_SOLUTIONS, valid_solutions);
cache.insert(-1, valid_words); cache.insert(IDX_ALL_WORDS, valid_words);
cache cache
} }
@ -118,57 +127,79 @@ fn filter_word(w: &[Charmask; WORD_LENGTH_P], banned_chars: &[Charmask; WORD_LEN
zip(w, banned_chars).all(|(x,y)| x & y == 0) zip(w, banned_chars).all(|(x,y)| x & y == 0)
} }
fn simulate(guess_ids: [usize; GUESS_DEPTH], solution_id: usize, wordcache: &WordCache) -> usize { fn aggregate_guesses(guess_ids: Vec<usize>, wordcache: &WordCache) -> Word {
let valid_words = &wordcache[&-1]; //guess_ids.iter().reduce(|out, g| out |= wordcache[IDX_ALL_WORDS][g]).unwrap()
let solution = valid_words[solution_id]; // Technically this should never cross past N_SOLUTIONS or it breaks cache guarantees let all_words = &wordcache[&IDX_ALL_WORDS];
let mut required_chars: Charmask = 0; let mut iter = guess_ids.iter();
let mut banned_chars: [Charmask; WORD_LENGTH] = [0; WORD_LENGTH]; let mut aggregate_guess = all_words[*iter.next().unwrap()];
let mut bans = 0; for g in iter {
for guess_id in guess_ids { let guess = all_words[*g];
let guess = valid_words[guess_id]; for i in 0..aggregate_guess.charbits.len() {
required_chars |= guess.charmask & solution.charmask; aggregate_guess.charbits[i] |= guess.charbits[i];
bans |= guess.charmask & !solution.charmask;
for i in 0..WORD_LENGTH {
if guess.letters[i] == solution.letters[i] { // Right letter right position
banned_chars[i] = !guess.charbits[i];
} else if guess.charbits[i] & solution.charmask != 0 { // Right letter wrong position
banned_chars[i] |= guess.charbits[i];
} }
aggregate_guess.charmask |= guess.charmask;
} }
} aggregate_guess
for j in 0..WORD_LENGTH {
banned_chars[j] |= bans;
}
let cachekey = required_chars;
match wordcache.contains_key(&cachekey) {
true => wordcache[&cachekey].iter().filter(|w| filter_word(&w.charbits, &banned_chars)).count(),
false => 0,
}
} }
fn find_worstcase(word_ids: [usize; GUESS_DEPTH], wordcache: &WordCache) -> (String, usize) { fn simulate(guess: Word, wordcache: &WordCache) -> (String, usize) {
let valid_words = &wordcache[&-1]; let valid_words = &wordcache[&IDX_ALL_WORDS];
let valid_solutions = &wordcache[&0]; let valid_solutions = &wordcache[&IDX_VALID_SOLUTIONS];
let required_chars: [Charmask; N_SOLUTIONS] = array_init::from_iter(
valid_solutions.iter().map(|s| s.charmask & guess.charmask)
).unwrap();
let mut banned_chars: [Charmask; WORD_LENGTH*N_SOLUTIONS] = [0; WORD_LENGTH*N_SOLUTIONS];
/* array_init::from_iter(
valid_solutions.iter().map(|s| s.charmask & guess.charmask)
).unwrap(); */
for i in 0..N_SOLUTIONS {
let s = valid_solutions[i];
let bans = guess.charmask & !s.charmask; // A letter fully rejected in any position bans it in all positions
for j in 0..WORD_LENGTH {
banned_chars[i*WORD_LENGTH + j] = bans;
banned_chars[i*WORD_LENGTH + j] |= guess.charbits[j] & !s.charbits[j]; // A letter in the wrong position
// A correct letter bans all others in the position. TODO: test branchless toggle
let correct = guess.charbits[j] & s.charbits[j];
//Branch
/* if correct != 0 {
banned_chars[i*WORD_LENGTH + j] |= !correct;
} */
//Branchless
banned_chars[i*WORD_LENGTH + j] |= !correct * (correct !=0) as i32;
}
}
let mut worst = 0; let mut worst = 0;
let mut worst_w = 0; let mut worst_w = 0;
for target_id in 0..valid_solutions.len() { for target_id in 0..N_SOLUTIONS {
let remaining = simulate(word_ids, target_id, wordcache); let cachekey = required_chars[target_id];
if wordcache.contains_key(&cachekey) {
let mut remaining = 0;
for word in &wordcache[&cachekey] {
// TODO: test branchless toggle
let mut error = 0;
for c in 0..WORD_LENGTH {
error += word.charbits[c] & banned_chars[target_id*WORD_LENGTH + c];
}
remaining += (error == 0) as usize;
}
if remaining > worst { if remaining > worst {
worst = remaining; worst = remaining;
worst_w = target_id; worst_w = target_id;
};
} }
let wordstr: String = word_ids.map(|i| letters2str(valid_words[i].letters)).join(", "); }
let worststr: String = letters2str(valid_words[worst_w].letters); }
let wordstr: String = charbits2str(guess.charbits); // THIS IS NOT SUITED FOR AGGREGATE GUESSES YET!
let worststr: String = charbits2str(valid_words[worst_w].charbits);
let output = format!("{} - {} ({})", wordstr, worst, worststr); let output = format!("{} - {} ({})", wordstr, worst, worststr);
println!("{}", output);
(output, worst) (output, worst)
} }
fn find_word_id_from_str(s: &str, words: &Vec<Word>) -> usize { fn find_word_id_from_str(s: &str, words: &Vec<Word>) -> usize {
let w = str2word(s); let w = str2word(s);
words.iter().position(|x| x.letters==w.letters).unwrap() words.iter().position(|x| x.charbits==w.charbits).unwrap()
} }
fn main() { fn main() {
@ -177,6 +208,7 @@ fn main() {
let totalwords = words.len(); let totalwords = words.len();
println!("Hello, world! {} words in dict", totalwords); println!("Hello, world! {} words in dict", totalwords);
let wordcache = generate_wordcache(words); let wordcache = generate_wordcache(words);
let all_words = &wordcache[&IDX_ALL_WORDS];
//let sr = simulate(&wordcache[""][0], &wordcache[""][5000], &wordcache); //let sr = simulate(&wordcache[""][0], &wordcache[""][5000], &wordcache);
//println!("{:?}", sr); //println!("{:?}", sr);
@ -189,7 +221,7 @@ fn main() {
// .map(|(i, j)| find_worstcase([i, j], &wordcache)).collect(); // .map(|(i, j)| find_worstcase([i, j], &wordcache)).collect();
// Depth-1 full // Depth-1 full
let mut results: Vec<(String, usize)> = (0..totalwords).into_par_iter().map(|i| find_worstcase([i], &wordcache)).collect(); let mut results: Vec<(String, usize)> = (0..totalwords).into_par_iter().map(|i| simulate(all_words[i], &wordcache)).collect();
// Depth-3 (word1,word2,?) // Depth-3 (word1,word2,?)
// let i1 = find_word_id_from_str("CARET", &wordcache[&0]); // let i1 = find_word_id_from_str("CARET", &wordcache[&0]);