Vectorize Rust solution based on inspiration from Numpy experiment
This commit is contained in:
parent
eeafb060c4
commit
4449c47149
|
@ -10,3 +10,4 @@ regex = "1"
|
||||||
rayon = "1.5"
|
rayon = "1.5"
|
||||||
bitintr = "0.3.0"
|
bitintr = "0.3.0"
|
||||||
itertools = "0.10.2"
|
itertools = "0.10.2"
|
||||||
|
array-init = "2.0.0"
|
||||||
|
|
110
src/main.rs
110
src/main.rs
|
@ -1,9 +1,12 @@
|
||||||
|
#![allow(dead_code)]
|
||||||
|
#![allow(unused_imports)]
|
||||||
use std::fs;
|
use std::fs;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use bitintr::{Lzcnt, Tzcnt};
|
use bitintr::{Lzcnt, Tzcnt};
|
||||||
use regex::Regex;
|
use regex::Regex;
|
||||||
use rayon::prelude::*;
|
use rayon::prelude::*;
|
||||||
use itertools::zip;
|
use itertools::zip;
|
||||||
|
use array_init::array_init;
|
||||||
|
|
||||||
type Charmask = i32;
|
type Charmask = i32;
|
||||||
type Achar = i8; // ASCII char
|
type Achar = i8; // ASCII char
|
||||||
|
@ -12,6 +15,8 @@ const WORD_LENGTH: usize = 5;
|
||||||
const WORD_LENGTH_P: usize = 5; // Padded for SIMD shenanigans
|
const WORD_LENGTH_P: usize = 5; // Padded for SIMD shenanigans
|
||||||
const GUESS_DEPTH: usize = 1; // TODO: Change this whenever working at different depths
|
const GUESS_DEPTH: usize = 1; // TODO: Change this whenever working at different depths
|
||||||
const N_SOLUTIONS: usize = 2315;
|
const N_SOLUTIONS: usize = 2315;
|
||||||
|
const IDX_ALL_WORDS: Charmask = (1<<26) - 1;
|
||||||
|
const IDX_VALID_SOLUTIONS: Charmask = 0;
|
||||||
const A: Achar = 'A' as Achar;
|
const A: Achar = 'A' as Achar;
|
||||||
const Z: Achar = 'Z' as Achar;
|
const Z: Achar = 'Z' as Achar;
|
||||||
|
|
||||||
|
@ -19,7 +24,7 @@ const Z: Achar = 'Z' as Achar;
|
||||||
struct Word {
|
struct Word {
|
||||||
charbits: [Charmask; WORD_LENGTH_P], // Each letter in bitmask form
|
charbits: [Charmask; WORD_LENGTH_P], // Each letter in bitmask form
|
||||||
charmask: Charmask, // All of the characters contained
|
charmask: Charmask, // All of the characters contained
|
||||||
letters: [Achar; WORD_LENGTH]
|
//letters: [Achar; WORD_LENGTH]
|
||||||
}
|
}
|
||||||
|
|
||||||
type WordCache = HashMap<Charmask, Vec<Word>>;
|
type WordCache = HashMap<Charmask, Vec<Word>>;
|
||||||
|
@ -38,6 +43,10 @@ fn letters2str(letters: [Achar; WORD_LENGTH]) -> String {
|
||||||
letters.iter().map(|x| (*x as u8) as char).collect()
|
letters.iter().map(|x| (*x as u8) as char).collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn charbits2str(charbits: [Charmask; WORD_LENGTH]) -> String {
|
||||||
|
charbits.iter().map(|x| (cm2char(*x, 0) as u8) as char).collect()
|
||||||
|
}
|
||||||
|
|
||||||
fn str2word(s: &str) -> Word {
|
fn str2word(s: &str) -> Word {
|
||||||
let mut word = Word::default();
|
let mut word = Word::default();
|
||||||
let mut iter = s.chars();
|
let mut iter = s.chars();
|
||||||
|
@ -45,7 +54,7 @@ fn str2word(s: &str) -> Word {
|
||||||
let c = iter.next().unwrap() as Achar;
|
let c = iter.next().unwrap() as Achar;
|
||||||
let cb = char2bit(c);
|
let cb = char2bit(c);
|
||||||
word.charbits[i] = cb;
|
word.charbits[i] = cb;
|
||||||
word.letters[i] = c;
|
//word.letters[i] = c;
|
||||||
word.charmask |= cb;
|
word.charmask |= cb;
|
||||||
}
|
}
|
||||||
word
|
word
|
||||||
|
@ -109,8 +118,8 @@ fn generate_wordcache(valid_words: Vec<Word>) -> WordCache {
|
||||||
let mut cache: WordCache = HashMap::new();
|
let mut cache: WordCache = HashMap::new();
|
||||||
let valid_solutions: Vec<Word> = valid_words[..N_SOLUTIONS].to_vec(); // Hacky way to separate the valid solutions from the larger guessing list
|
let valid_solutions: Vec<Word> = valid_words[..N_SOLUTIONS].to_vec(); // Hacky way to separate the valid solutions from the larger guessing list
|
||||||
_generate_wordcache_nested(&mut cache, &valid_solutions, 0, 5);
|
_generate_wordcache_nested(&mut cache, &valid_solutions, 0, 5);
|
||||||
cache.insert(0, valid_solutions);
|
cache.insert(IDX_VALID_SOLUTIONS, valid_solutions);
|
||||||
cache.insert(-1, valid_words);
|
cache.insert(IDX_ALL_WORDS, valid_words);
|
||||||
cache
|
cache
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -118,57 +127,79 @@ fn filter_word(w: &[Charmask; WORD_LENGTH_P], banned_chars: &[Charmask; WORD_LEN
|
||||||
zip(w, banned_chars).all(|(x,y)| x & y == 0)
|
zip(w, banned_chars).all(|(x,y)| x & y == 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn simulate(guess_ids: [usize; GUESS_DEPTH], solution_id: usize, wordcache: &WordCache) -> usize {
|
fn aggregate_guesses(guess_ids: Vec<usize>, wordcache: &WordCache) -> Word {
|
||||||
let valid_words = &wordcache[&-1];
|
//guess_ids.iter().reduce(|out, g| out |= wordcache[IDX_ALL_WORDS][g]).unwrap()
|
||||||
let solution = valid_words[solution_id]; // Technically this should never cross past N_SOLUTIONS or it breaks cache guarantees
|
let all_words = &wordcache[&IDX_ALL_WORDS];
|
||||||
let mut required_chars: Charmask = 0;
|
let mut iter = guess_ids.iter();
|
||||||
let mut banned_chars: [Charmask; WORD_LENGTH] = [0; WORD_LENGTH];
|
let mut aggregate_guess = all_words[*iter.next().unwrap()];
|
||||||
let mut bans = 0;
|
for g in iter {
|
||||||
for guess_id in guess_ids {
|
let guess = all_words[*g];
|
||||||
let guess = valid_words[guess_id];
|
for i in 0..aggregate_guess.charbits.len() {
|
||||||
required_chars |= guess.charmask & solution.charmask;
|
aggregate_guess.charbits[i] |= guess.charbits[i];
|
||||||
bans |= guess.charmask & !solution.charmask;
|
|
||||||
for i in 0..WORD_LENGTH {
|
|
||||||
if guess.letters[i] == solution.letters[i] { // Right letter right position
|
|
||||||
banned_chars[i] = !guess.charbits[i];
|
|
||||||
} else if guess.charbits[i] & solution.charmask != 0 { // Right letter wrong position
|
|
||||||
banned_chars[i] |= guess.charbits[i];
|
|
||||||
}
|
}
|
||||||
|
aggregate_guess.charmask |= guess.charmask;
|
||||||
}
|
}
|
||||||
}
|
aggregate_guess
|
||||||
for j in 0..WORD_LENGTH {
|
|
||||||
banned_chars[j] |= bans;
|
|
||||||
}
|
|
||||||
let cachekey = required_chars;
|
|
||||||
match wordcache.contains_key(&cachekey) {
|
|
||||||
true => wordcache[&cachekey].iter().filter(|w| filter_word(&w.charbits, &banned_chars)).count(),
|
|
||||||
false => 0,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn find_worstcase(word_ids: [usize; GUESS_DEPTH], wordcache: &WordCache) -> (String, usize) {
|
fn simulate(guess: Word, wordcache: &WordCache) -> (String, usize) {
|
||||||
let valid_words = &wordcache[&-1];
|
let valid_words = &wordcache[&IDX_ALL_WORDS];
|
||||||
let valid_solutions = &wordcache[&0];
|
let valid_solutions = &wordcache[&IDX_VALID_SOLUTIONS];
|
||||||
|
|
||||||
|
let required_chars: [Charmask; N_SOLUTIONS] = array_init::from_iter(
|
||||||
|
valid_solutions.iter().map(|s| s.charmask & guess.charmask)
|
||||||
|
).unwrap();
|
||||||
|
let mut banned_chars: [Charmask; WORD_LENGTH*N_SOLUTIONS] = [0; WORD_LENGTH*N_SOLUTIONS];
|
||||||
|
/* array_init::from_iter(
|
||||||
|
valid_solutions.iter().map(|s| s.charmask & guess.charmask)
|
||||||
|
).unwrap(); */
|
||||||
|
for i in 0..N_SOLUTIONS {
|
||||||
|
let s = valid_solutions[i];
|
||||||
|
let bans = guess.charmask & !s.charmask; // A letter fully rejected in any position bans it in all positions
|
||||||
|
for j in 0..WORD_LENGTH {
|
||||||
|
banned_chars[i*WORD_LENGTH + j] = bans;
|
||||||
|
banned_chars[i*WORD_LENGTH + j] |= guess.charbits[j] & !s.charbits[j]; // A letter in the wrong position
|
||||||
|
// A correct letter bans all others in the position. TODO: test branchless toggle
|
||||||
|
let correct = guess.charbits[j] & s.charbits[j];
|
||||||
|
//Branch
|
||||||
|
/* if correct != 0 {
|
||||||
|
banned_chars[i*WORD_LENGTH + j] |= !correct;
|
||||||
|
} */
|
||||||
|
//Branchless
|
||||||
|
banned_chars[i*WORD_LENGTH + j] |= !correct * (correct !=0) as i32;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let mut worst = 0;
|
let mut worst = 0;
|
||||||
let mut worst_w = 0;
|
let mut worst_w = 0;
|
||||||
for target_id in 0..valid_solutions.len() {
|
for target_id in 0..N_SOLUTIONS {
|
||||||
let remaining = simulate(word_ids, target_id, wordcache);
|
let cachekey = required_chars[target_id];
|
||||||
|
if wordcache.contains_key(&cachekey) {
|
||||||
|
let mut remaining = 0;
|
||||||
|
for word in &wordcache[&cachekey] {
|
||||||
|
// TODO: test branchless toggle
|
||||||
|
let mut error = 0;
|
||||||
|
for c in 0..WORD_LENGTH {
|
||||||
|
error += word.charbits[c] & banned_chars[target_id*WORD_LENGTH + c];
|
||||||
|
}
|
||||||
|
remaining += (error == 0) as usize;
|
||||||
|
}
|
||||||
if remaining > worst {
|
if remaining > worst {
|
||||||
worst = remaining;
|
worst = remaining;
|
||||||
worst_w = target_id;
|
worst_w = target_id;
|
||||||
};
|
|
||||||
}
|
}
|
||||||
let wordstr: String = word_ids.map(|i| letters2str(valid_words[i].letters)).join(", ");
|
}
|
||||||
let worststr: String = letters2str(valid_words[worst_w].letters);
|
}
|
||||||
|
|
||||||
|
let wordstr: String = charbits2str(guess.charbits); // THIS IS NOT SUITED FOR AGGREGATE GUESSES YET!
|
||||||
|
let worststr: String = charbits2str(valid_words[worst_w].charbits);
|
||||||
let output = format!("{} - {} ({})", wordstr, worst, worststr);
|
let output = format!("{} - {} ({})", wordstr, worst, worststr);
|
||||||
println!("{}", output);
|
|
||||||
(output, worst)
|
(output, worst)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn find_word_id_from_str(s: &str, words: &Vec<Word>) -> usize {
|
fn find_word_id_from_str(s: &str, words: &Vec<Word>) -> usize {
|
||||||
let w = str2word(s);
|
let w = str2word(s);
|
||||||
words.iter().position(|x| x.letters==w.letters).unwrap()
|
words.iter().position(|x| x.charbits==w.charbits).unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
|
@ -177,6 +208,7 @@ fn main() {
|
||||||
let totalwords = words.len();
|
let totalwords = words.len();
|
||||||
println!("Hello, world! {} words in dict", totalwords);
|
println!("Hello, world! {} words in dict", totalwords);
|
||||||
let wordcache = generate_wordcache(words);
|
let wordcache = generate_wordcache(words);
|
||||||
|
let all_words = &wordcache[&IDX_ALL_WORDS];
|
||||||
|
|
||||||
//let sr = simulate(&wordcache[""][0], &wordcache[""][5000], &wordcache);
|
//let sr = simulate(&wordcache[""][0], &wordcache[""][5000], &wordcache);
|
||||||
//println!("{:?}", sr);
|
//println!("{:?}", sr);
|
||||||
|
@ -189,7 +221,7 @@ fn main() {
|
||||||
// .map(|(i, j)| find_worstcase([i, j], &wordcache)).collect();
|
// .map(|(i, j)| find_worstcase([i, j], &wordcache)).collect();
|
||||||
|
|
||||||
// Depth-1 full
|
// Depth-1 full
|
||||||
let mut results: Vec<(String, usize)> = (0..totalwords).into_par_iter().map(|i| find_worstcase([i], &wordcache)).collect();
|
let mut results: Vec<(String, usize)> = (0..totalwords).into_par_iter().map(|i| simulate(all_words[i], &wordcache)).collect();
|
||||||
|
|
||||||
// Depth-3 (word1,word2,?)
|
// Depth-3 (word1,word2,?)
|
||||||
// let i1 = find_word_id_from_str("CARET", &wordcache[&0]);
|
// let i1 = find_word_id_from_str("CARET", &wordcache[&0]);
|
||||||
|
|
Loading…
Reference in New Issue