1pub mod api;
2pub mod common;
3pub mod constants;
4pub mod kernel;
5pub mod passes;
6pub mod search;
7pub mod utils;
8
9use std::marker::PhantomData;
40use std::{fmt, fs, path::Path};
41
42use tracing::info;
43
44use crate::utils::init_logger;
45use crate::{
46 api::{Pass, Search},
47 constants::SATResult,
48 kernel::Kernel,
49};
50
#[cfg_attr(doc, aquamarine::aquamarine)]
/// Runtime tag describing which phase a solver is in.
///
/// Each variant has a zero-sized marker type of the same name whose
/// [`SolverState::STATUS`] constant evaluates to that variant, so the
/// compile-time state of a solver can also be inspected at run time.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SolverStatus {
    /// No verdict is available (initial state, or an inconclusive search).
    UNKNOWN,
    /// The search phase is in progress.
    SOLVING,
    /// The formula was reported satisfiable.
    SAT,
    /// The formula was reported unsatisfiable.
    UNSAT,
}
69
70pub trait SolverState {
72 const STATUS: SolverStatus;
73}
74
/// Typestate marker: no verdict is known for the formula.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct UNKNOWN;

/// Typestate marker: the search phase is running.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SOLVING;

/// Typestate marker: the formula was found satisfiable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SAT;

/// Typestate marker: the formula was found unsatisfiable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct UNSAT;
83
84impl SolverState for UNKNOWN {
85 const STATUS: SolverStatus = SolverStatus::UNKNOWN;
86}
87impl SolverState for SOLVING {
88 const STATUS: SolverStatus = SolverStatus::SOLVING;
89}
90impl SolverState for SAT {
91 const STATUS: SolverStatus = SolverStatus::SAT;
92}
93impl SolverState for UNSAT {
94 const STATUS: SolverStatus = SolverStatus::UNSAT;
95}
96
97pub enum SolveResult<S>
99where
100 S: Search,
101{
102 SAT(Solver<S, SAT>),
103 UNSAT(Solver<S, UNSAT>),
104 UNKNOWN(Solver<S, UNKNOWN>),
105}
106
107impl<S> SolveResult<S>
108where
109 S: Search,
110{
111 pub const fn status(&self) -> SolverStatus {
113 match self {
114 Self::SAT(_) => SolverStatus::SAT,
115 Self::UNSAT(_) => SolverStatus::UNSAT,
116 Self::UNKNOWN(_) => SolverStatus::UNKNOWN,
117 }
118 }
119}
120
/// Errors reported while reading a DIMACS CNF problem.
#[derive(Debug)]
pub enum DimacsError {
    /// The input contains no `p cnf <vars> <clauses>` header line.
    MissingHeader,
    /// A clause line was encountered before any header line.
    ClauseBeforeHeader { line: usize },
    /// A header line is malformed, or a second header line was found.
    InvalidHeader { line: usize, content: String },
    /// A clause token could not be parsed as a signed integer.
    InvalidLiteral { line: usize, token: String },
    /// A literal refers to a variable id above the declared maximum.
    LiteralOutOfRange { line: usize, lit: isize, max_vars: usize },
    /// The input ended while a clause was still missing its terminating `0`.
    UnterminatedClause,
    /// The number of parsed clauses differs from the count in the header.
    ClauseCountMismatch { expected: usize, parsed: usize },
    /// The underlying read failed.
    Io(std::io::Error),
}

impl fmt::Display for DimacsError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::MissingHeader => {
                write!(f, "missing DIMACS header line 'p cnf <vars> <clauses>'")
            }
            Self::ClauseBeforeHeader { line } => {
                write!(f, "clause appears before header at line {line}")
            }
            Self::InvalidHeader { line, content } => {
                write!(f, "invalid DIMACS header at line {line}: '{content}'")
            }
            Self::InvalidLiteral { line, token } => {
                write!(f, "invalid literal '{token}' at line {line}")
            }
            Self::LiteralOutOfRange { line, lit, max_vars } => {
                write!(f, "literal {lit} out of range at line {line} (max var id {max_vars})")
            }
            Self::UnterminatedClause => {
                write!(f, "unterminated clause: missing trailing 0")
            }
            Self::ClauseCountMismatch { expected, parsed } => {
                write!(f, "DIMACS clause count mismatch: expected {expected}, parsed {parsed}")
            }
            Self::Io(err) => write!(f, "{err}"),
        }
    }
}

impl std::error::Error for DimacsError {
    /// Exposes the wrapped I/O error so callers can walk the error chain or
    /// downcast to `std::io::Error`; all other variants have no cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(err) => Some(err),
            Self::MissingHeader
            | Self::ClauseBeforeHeader { .. }
            | Self::InvalidHeader { .. }
            | Self::InvalidLiteral { .. }
            | Self::LiteralOutOfRange { .. }
            | Self::UnterminatedClause
            | Self::ClauseCountMismatch { .. } => None,
        }
    }
}

impl From<std::io::Error> for DimacsError {
    /// Wraps an I/O failure so `?` can be used on file reads.
    fn from(value: std::io::Error) -> Self {
        Self::Io(value)
    }
}
166
167fn parse_dimacs_kernel(input: &str) -> Result<Kernel, DimacsError> {
179 let mut kernel: Option<Kernel> = None;
180 let mut expected_clauses = 0usize;
181 let mut parsed_clauses = 0usize;
182 let mut open_clause = false;
183
184 for (idx, raw_line) in input.lines().enumerate() {
185 let line_no = idx + 1;
186 let line = raw_line.trim();
187 if line.is_empty() || line.starts_with('c') {
188 continue;
189 }
190
191 if line.starts_with('p') {
192 if kernel.is_some() {
193 return Err(DimacsError::InvalidHeader {
194 line: line_no,
195 content: line.to_string(),
196 });
197 }
198
199 let mut parts = line.split_whitespace();
200 let p = parts.next();
201 let cnf = parts.next();
202 let vars = parts.next();
203 let clauses = parts.next();
204 let extra = parts.next();
205 if p != Some("p")
206 || cnf != Some("cnf")
207 || vars.is_none()
208 || clauses.is_none()
209 || extra.is_some()
210 {
211 return Err(DimacsError::InvalidHeader {
212 line: line_no,
213 content: line.to_string(),
214 });
215 }
216
217 let max_vars = vars.and_then(|s| s.parse::<usize>().ok()).ok_or_else(|| {
218 DimacsError::InvalidHeader { line: line_no, content: line.to_string() }
219 })?;
220 expected_clauses = clauses.and_then(|s| s.parse::<usize>().ok()).ok_or_else(|| {
221 DimacsError::InvalidHeader { line: line_no, content: line.to_string() }
222 })?;
223 kernel = Some(Kernel::new(max_vars));
224 continue;
225 }
226
227 let Some(k) = kernel.as_mut() else {
228 return Err(DimacsError::ClauseBeforeHeader { line: line_no });
229 };
230
231 for token in line.split_whitespace() {
232 let lit = token.parse::<isize>().map_err(|_| DimacsError::InvalidLiteral {
233 line: line_no,
234 token: token.to_string(),
235 })?;
236 if lit == 0 {
237 k.add(None);
238 parsed_clauses += 1;
239 open_clause = false;
240 continue;
241 }
242
243 let var_id = lit.unsigned_abs();
244 let max_vars = k.assignment.len() - 1;
245 if var_id == 0 || var_id > max_vars {
246 return Err(DimacsError::LiteralOutOfRange { line: line_no, lit, max_vars });
247 }
248 k.add(Some(lit));
249 open_clause = true;
250 }
251 }
252
253 let Some(kernel) = kernel else {
254 return Err(DimacsError::MissingHeader);
255 };
256 if open_clause {
257 return Err(DimacsError::UnterminatedClause);
258 }
259 if parsed_clauses != expected_clauses {
260 return Err(DimacsError::ClauseCountMismatch {
261 expected: expected_clauses,
262 parsed: parsed_clauses,
263 });
264 }
265 Ok(kernel)
266}
267
268pub struct SolverBuilder<S>
270where
271 S: Search,
272{
273 search: S,
274 kernel: Kernel,
275}
276
277impl<S> SolverBuilder<S>
278where
279 S: Search,
280{
281 pub fn with_max_vars(search: S, max_vars: usize) -> Self {
283 Self { search, kernel: Kernel::new(max_vars) }
284 }
285
286 pub fn from_dimacs_str(search: S, dimacs: &str) -> Result<Self, DimacsError> {
288 let kernel = parse_dimacs_kernel(dimacs)?;
289 Ok(Self { search, kernel })
290 }
291
292 pub fn from_dimacs_file(search: S, path: impl AsRef<Path>) -> Result<Self, DimacsError> {
294 let input = fs::read_to_string(path)?;
295 Self::from_dimacs_str(search, &input)
296 }
297
298 pub fn kernel(&self) -> &Kernel {
300 &self.kernel
301 }
302
303 pub fn kernel_mut(&mut self) -> &mut Kernel {
305 &mut self.kernel
306 }
307
308 pub fn build(self) -> Solver<S, UNKNOWN> {
310 Solver::new(self.search, self.kernel)
311 }
312}
313
314pub struct Solver<S, St = UNKNOWN>
322where
323 S: Search,
324 St: SolverState,
325{
326 pre_processor: Vec<Box<dyn Pass>>,
327 in_processor: Vec<Box<dyn Pass>>,
328 search: S,
329 kernel: Kernel,
330 _state: PhantomData<St>,
331}
332
333impl<S, St> Solver<S, St>
334where
335 S: Search,
336 St: SolverState,
337{
338 pub const fn status(&self) -> SolverStatus {
340 St::STATUS
341 }
342
343 pub fn kernel(&self) -> &Kernel {
345 &self.kernel
346 }
347
348 fn into_state<Next>(self) -> Solver<S, Next>
349 where
350 Next: SolverState,
351 {
352 Solver {
353 pre_processor: self.pre_processor,
354 in_processor: self.in_processor,
355 search: self.search,
356 kernel: self.kernel,
357 _state: PhantomData,
358 }
359 }
360}
361
362impl<S> Solver<S, UNKNOWN>
363where
364 S: Search,
365{
366 pub fn new(search: S, kernel: Kernel) -> Self {
368 init_logger();
369 Self {
370 pre_processor: Vec::new(),
371 in_processor: Vec::new(),
372 search,
373 kernel,
374 _state: PhantomData,
375 }
376 }
377
378 pub fn add_preprocess_pass(&mut self, pass: impl Pass + 'static) {
380 self.pre_processor.push(Box::new(pass));
381 }
382
383 pub fn arrange_preprocess_passes(&mut self, ordered: &str) {
385 Self::arrange_passes(&mut self.pre_processor, ordered);
386 }
387
388 pub fn add_inprocess_pass(&mut self, pass: impl Pass + 'static) {
390 self.in_processor.push(Box::new(pass));
391 }
392
393 pub fn arrange_inprocess_passes(&mut self, ordered: &str) {
395 Self::arrange_passes(&mut self.in_processor, ordered);
396 }
397
398 pub fn solve(mut self) -> SolveResult<S> {
405 let mut pre_result = SATResult::UNKNOWN;
406 for pass in &mut self.pre_processor {
407 if pass.applying(&self.kernel) {
408 pre_result = pass.apply(&mut self.kernel);
409 if pre_result != SATResult::UNKNOWN {
410 break;
411 }
412 }
413 }
414
415 if pre_result == SATResult::SAT {
416 return SolveResult::SAT(self.into_state::<SAT>());
417 }
418 if pre_result == SATResult::UNSAT {
419 return SolveResult::UNSAT(self.into_state::<UNSAT>());
420 }
421
422 let mut solving = self.into_state::<SOLVING>();
423 let result = solving.search.search(&mut solving.kernel, &mut solving.in_processor);
424
425 info!("c solver finished with result: {:?}", result);
426 match result {
427 SATResult::SAT => SolveResult::SAT(solving.into_state::<SAT>()),
428 SATResult::UNSAT => SolveResult::UNSAT(solving.into_state::<UNSAT>()),
429 SATResult::UNKNOWN => SolveResult::UNKNOWN(solving.into_state::<UNKNOWN>()),
430 }
431 }
432
433 fn arrange_passes(passes: &mut Vec<Box<dyn Pass>>, ordered: &str) {
435 let mut remaining = std::mem::take(passes);
436 let mut ordered_passes: Vec<Box<dyn Pass>> = Vec::with_capacity(remaining.len());
437
438 for short_name in ordered.split(',').map(str::trim).filter(|s| !s.is_empty()) {
439 if let Some(idx) = remaining.iter().position(|p| p.name() == short_name) {
440 ordered_passes.push(remaining.swap_remove(idx));
441 }
442 }
443
444 *passes = ordered_passes;
445 }
446}
447
448impl<S> Solver<S, SAT>
449where
450 S: Search,
451{
452 pub fn model(&self) -> Vec<bool> {
454 self.kernel.assignment.iter().map(|&value| value == 1).collect()
455 }
456
457 pub fn check_sat(&self) -> Result<(), String> {
459 let model = self.model();
460 for clause in &self.kernel.clauses {
461 let mut satisfied = false;
462 for &lit in clause.literals() {
463 let var = lit.unsigned_abs();
464 let val = model[var];
465 if (lit > 0 && val) || (lit < 0 && !val) {
466 satisfied = true;
467 break;
468 }
469 }
470 if !satisfied {
471 return Err(format!("clause {:?} is not satisfied", clause));
472 }
473 }
474 Ok(())
475 }
476
477 pub fn print_model(&self) {
479 let model = self.model();
480 print!("v ");
481 for (v, val) in model.iter().enumerate().skip(1) {
482 if *val {
483 print!("{} ", v);
484 } else {
485 print!("-{} ", v);
486 }
487 if v % 10 == 0 {
488 print!("\nv ");
489 }
490 }
491 println!("0");
492 }
493}