Skip to main content

easy_sat_rs/
lib.rs

1pub mod api;
2pub mod common;
3pub mod constants;
4pub mod kernel;
5pub mod passes;
6pub mod search;
7pub mod utils;
8
/// The solver encodes its state machine with the type-state pattern, so illegal
/// state transitions are rejected at compile time.
///
/// State transitions:
/// `UNKNOWN -> SOLVING -> {SAT | UNSAT | UNKNOWN}`. Preprocessing may also move
/// `UNKNOWN` directly to `SAT` or `UNSAT` without ever entering `SOLVING`.
///
/// Typical usage:
/// 1. build the kernel from DIMACS input with [SolverBuilder];
/// 2. call `build().solve()`;
/// 3. match on [SolveResult] to read the SAT/UNSAT verdict and the model.
///
/// # Usage
/// (Not compiled as a doctest: `Searcher` stands for your own `Search`
/// implementation, and `?` needs an enclosing function returning `Result`.)
/// ```rust,ignore
/// let cnf_path = "path/to/cnf/file.cnf";
/// let searcher = Searcher;
/// let solver = SolverBuilder::from_dimacs_file(searcher, &cnf_path)?.build();
/// let result = solver.solve();
/// match result {
///     SolveResult::SAT(solver) => {
///         assert!(solver.check_sat().is_ok());
///         println!("s SATISFIABLE");
///         solver.print_model();
///     }
///     SolveResult::UNSAT(_solver) => {
///         println!("s UNSATISFIABLE");
///     }
///     SolveResult::UNKNOWN(_solver) => {
///         println!("s UNKNOWN");
///     }
/// }
/// ```
39use std::marker::PhantomData;
40use std::{fmt, fs, path::Path};
41
42use tracing::info;
43
44use crate::utils::init_logger;
45use crate::{
46    api::{Pass, Search},
47    constants::SATResult,
48    kernel::Kernel,
49};
50
/// Externally visible solver status, the runtime mirror of the type-state markers.
///
/// Every `SolverState` marker maps to exactly one variant through its `STATUS`
/// constant, so a typed solver can report its state without inspecting types.
///
/// Note that `solve()` can skip `SOLVING` entirely: when a preprocessing pass
/// already decides the formula, the solver goes straight from `UNKNOWN` to
/// `SAT` or `UNSAT`.
///
/// ```mermaid
/// stateDiagram-v2
///     [*] --> UNKNOWN
///     UNKNOWN --> SAT: solve(), preprocessing finds a model
///     UNKNOWN --> UNSAT: solve(), preprocessing refutes the formula
///     UNKNOWN --> SOLVING: solve(), preprocessing undecided
///     SOLVING --> SAT: satisfying assignment found
///     SOLVING --> UNSAT: proven unsatisfiable
///     SOLVING --> UNKNOWN: search returned no verdict
/// ```
#[cfg_attr(doc, aquamarine::aquamarine)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SolverStatus {
    /// Initial state, and the state after a search that reached no verdict.
    UNKNOWN,
    /// Search in progress; only exists inside `solve()`, never returned to callers.
    SOLVING,
    /// The formula is satisfiable and a model is available.
    SAT,
    /// The formula was shown to be unsatisfiable.
    UNSAT,
}
69
70/// Type-State 标记 trait。
71pub trait SolverState {
72    const STATUS: SolverStatus;
73}
74
75#[derive(Debug, Clone, Copy, PartialEq, Eq)]
76pub struct UNKNOWN;
77#[derive(Debug, Clone, Copy, PartialEq, Eq)]
78pub struct SOLVING;
79#[derive(Debug, Clone, Copy, PartialEq, Eq)]
80pub struct SAT;
81#[derive(Debug, Clone, Copy, PartialEq, Eq)]
82pub struct UNSAT;
83
84impl SolverState for UNKNOWN {
85    const STATUS: SolverStatus = SolverStatus::UNKNOWN;
86}
87impl SolverState for SOLVING {
88    const STATUS: SolverStatus = SolverStatus::SOLVING;
89}
90impl SolverState for SAT {
91    const STATUS: SolverStatus = SolverStatus::SAT;
92}
93impl SolverState for UNSAT {
94    const STATUS: SolverStatus = SolverStatus::UNSAT;
95}
96
97/// 求解结果(按最终状态分型返回)。
98pub enum SolveResult<S>
99where
100    S: Search,
101{
102    SAT(Solver<S, SAT>),
103    UNSAT(Solver<S, UNSAT>),
104    UNKNOWN(Solver<S, UNKNOWN>),
105}
106
107impl<S> SolveResult<S>
108where
109    S: Search,
110{
111    /// 返回结果对应的状态枚举。
112    pub const fn status(&self) -> SolverStatus {
113        match self {
114            Self::SAT(_) => SolverStatus::SAT,
115            Self::UNSAT(_) => SolverStatus::UNSAT,
116            Self::UNKNOWN(_) => SolverStatus::UNKNOWN,
117        }
118    }
119}
120
/// Errors produced while reading or parsing DIMACS CNF input.
///
/// `line` fields are 1-based line numbers in the input text.
#[derive(Debug)]
pub enum DimacsError {
    /// No `p cnf <vars> <clauses>` header line was found.
    MissingHeader,
    /// A clause line appeared before the header.
    ClauseBeforeHeader { line: usize },
    /// The header line is malformed, or a second header line was found.
    InvalidHeader { line: usize, content: String },
    /// A token on a clause line is not an integer.
    InvalidLiteral { line: usize, token: String },
    /// A literal refers to a variable beyond the count declared in the header.
    LiteralOutOfRange { line: usize, lit: isize, max_vars: usize },
    /// The input ended while a clause was still open (missing trailing `0`).
    UnterminatedClause,
    /// The number of parsed clauses differs from the header's clause count.
    ClauseCountMismatch { expected: usize, parsed: usize },
    /// Reading the input failed.
    Io(std::io::Error),
}

impl fmt::Display for DimacsError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::MissingHeader => write!(f, "missing DIMACS header line 'p cnf <vars> <clauses>'"),
            Self::ClauseBeforeHeader { line } => {
                write!(f, "clause appears before header at line {line}")
            }
            Self::InvalidHeader { line, content } => {
                write!(f, "invalid DIMACS header at line {line}: '{content}'")
            }
            Self::InvalidLiteral { line, token } => {
                write!(f, "invalid literal '{token}' at line {line}")
            }
            Self::LiteralOutOfRange { line, lit, max_vars } => {
                write!(f, "literal {lit} out of range at line {line} (max var id {max_vars})")
            }
            Self::UnterminatedClause => write!(f, "unterminated clause: missing trailing 0"),
            Self::ClauseCountMismatch { expected, parsed } => {
                write!(f, "DIMACS clause count mismatch: expected {expected}, parsed {parsed}")
            }
            // The I/O message is shown directly so that plain `{}` printing stays
            // informative; the error itself is also exposed through `source()`.
            Self::Io(err) => write!(f, "{err}"),
        }
    }
}

impl std::error::Error for DimacsError {
    /// Exposes the wrapped I/O error so callers can walk the error chain
    /// (e.g. downcast it and inspect its `ErrorKind`).
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(err) => Some(err),
            _ => None,
        }
    }
}

impl From<std::io::Error> for DimacsError {
    fn from(value: std::io::Error) -> Self {
        Self::Io(value)
    }
}
166
167/// 解析 DIMACS CNF 文本并构造 [Kernel]。
168///
169/// 输入约定:
170/// - 头部:`p cnf <vars> <clauses>`
171/// - 子句:以 `0` 结束
172/// - `c` 开头行为注释
173///
174/// 例:
175/// `p cnf 3 2`
176/// `1 -2 0`
177/// `2 3 0`
178fn parse_dimacs_kernel(input: &str) -> Result<Kernel, DimacsError> {
179    let mut kernel: Option<Kernel> = None;
180    let mut expected_clauses = 0usize;
181    let mut parsed_clauses = 0usize;
182    let mut open_clause = false;
183
184    for (idx, raw_line) in input.lines().enumerate() {
185        let line_no = idx + 1;
186        let line = raw_line.trim();
187        if line.is_empty() || line.starts_with('c') {
188            continue;
189        }
190
191        if line.starts_with('p') {
192            if kernel.is_some() {
193                return Err(DimacsError::InvalidHeader {
194                    line: line_no,
195                    content: line.to_string(),
196                });
197            }
198
199            let mut parts = line.split_whitespace();
200            let p = parts.next();
201            let cnf = parts.next();
202            let vars = parts.next();
203            let clauses = parts.next();
204            let extra = parts.next();
205            if p != Some("p")
206                || cnf != Some("cnf")
207                || vars.is_none()
208                || clauses.is_none()
209                || extra.is_some()
210            {
211                return Err(DimacsError::InvalidHeader {
212                    line: line_no,
213                    content: line.to_string(),
214                });
215            }
216
217            let max_vars = vars.and_then(|s| s.parse::<usize>().ok()).ok_or_else(|| {
218                DimacsError::InvalidHeader { line: line_no, content: line.to_string() }
219            })?;
220            expected_clauses = clauses.and_then(|s| s.parse::<usize>().ok()).ok_or_else(|| {
221                DimacsError::InvalidHeader { line: line_no, content: line.to_string() }
222            })?;
223            kernel = Some(Kernel::new(max_vars));
224            continue;
225        }
226
227        let Some(k) = kernel.as_mut() else {
228            return Err(DimacsError::ClauseBeforeHeader { line: line_no });
229        };
230
231        for token in line.split_whitespace() {
232            let lit = token.parse::<isize>().map_err(|_| DimacsError::InvalidLiteral {
233                line: line_no,
234                token: token.to_string(),
235            })?;
236            if lit == 0 {
237                k.add(None);
238                parsed_clauses += 1;
239                open_clause = false;
240                continue;
241            }
242
243            let var_id = lit.unsigned_abs();
244            let max_vars = k.assignment.len() - 1;
245            if var_id == 0 || var_id > max_vars {
246                return Err(DimacsError::LiteralOutOfRange { line: line_no, lit, max_vars });
247            }
248            k.add(Some(lit));
249            open_clause = true;
250        }
251    }
252
253    let Some(kernel) = kernel else {
254        return Err(DimacsError::MissingHeader);
255    };
256    if open_clause {
257        return Err(DimacsError::UnterminatedClause);
258    }
259    if parsed_clauses != expected_clauses {
260        return Err(DimacsError::ClauseCountMismatch {
261            expected: expected_clauses,
262            parsed: parsed_clauses,
263        });
264    }
265    Ok(kernel)
266}
267
268/// `Solver` 构建器:负责准备搜索器与内核初始状态。
269pub struct SolverBuilder<S>
270where
271    S: Search,
272{
273    search: S,
274    kernel: Kernel,
275}
276
277impl<S> SolverBuilder<S>
278where
279    S: Search,
280{
281    /// 用变量上限创建空公式求解器构建器。
282    pub fn with_max_vars(search: S, max_vars: usize) -> Self {
283        Self { search, kernel: Kernel::new(max_vars) }
284    }
285
286    /// 从 DIMACS 字符串创建构建器。
287    pub fn from_dimacs_str(search: S, dimacs: &str) -> Result<Self, DimacsError> {
288        let kernel = parse_dimacs_kernel(dimacs)?;
289        Ok(Self { search, kernel })
290    }
291
292    /// 从 DIMACS 文件创建构建器。
293    pub fn from_dimacs_file(search: S, path: impl AsRef<Path>) -> Result<Self, DimacsError> {
294        let input = fs::read_to_string(path)?;
295        Self::from_dimacs_str(search, &input)
296    }
297
298    /// 只读访问内核(用于调试或查询)。
299    pub fn kernel(&self) -> &Kernel {
300        &self.kernel
301    }
302
303    /// 可变访问内核(用于自定义注入子句/参数)。
304    pub fn kernel_mut(&mut self) -> &mut Kernel {
305        &mut self.kernel
306    }
307
308    /// 构建求解器,初始状态为 `UNKNOWN`。
309    pub fn build(self) -> Solver<S, UNKNOWN> {
310        Solver::new(self.search, self.kernel)
311    }
312}
313
314/// 主求解器对象。
315///
316/// 组件说明:
317/// - `pre_processor`:搜索前执行;
318/// - `in_processor`:搜索循环中周期执行;
319/// - `search`:CDCL 核心;
320/// - `kernel`:统一状态与数据。
321pub struct Solver<S, St = UNKNOWN>
322where
323    S: Search,
324    St: SolverState,
325{
326    pre_processor: Vec<Box<dyn Pass>>,
327    in_processor: Vec<Box<dyn Pass>>,
328    search: S,
329    kernel: Kernel,
330    _state: PhantomData<St>,
331}
332
333impl<S, St> Solver<S, St>
334where
335    S: Search,
336    St: SolverState,
337{
338    /// 返回编译期状态对应的运行时枚举。
339    pub const fn status(&self) -> SolverStatus {
340        St::STATUS
341    }
342
343    /// 只读访问内核。
344    pub fn kernel(&self) -> &Kernel {
345        &self.kernel
346    }
347
348    fn into_state<Next>(self) -> Solver<S, Next>
349    where
350        Next: SolverState,
351    {
352        Solver {
353            pre_processor: self.pre_processor,
354            in_processor: self.in_processor,
355            search: self.search,
356            kernel: self.kernel,
357            _state: PhantomData,
358        }
359    }
360}
361
362impl<S> Solver<S, UNKNOWN>
363where
364    S: Search,
365{
366    /// 创建处于 `UNKNOWN` 状态的求解器。
367    pub fn new(search: S, kernel: Kernel) -> Self {
368        init_logger();
369        Self {
370            pre_processor: Vec::new(),
371            in_processor: Vec::new(),
372            search,
373            kernel,
374            _state: PhantomData,
375        }
376    }
377
378    /// 注册一个预处理 Pass。
379    pub fn add_preprocess_pass(&mut self, pass: impl Pass + 'static) {
380        self.pre_processor.push(Box::new(pass));
381    }
382
383    /// 按逗号分隔短名重排预处理 Pass。
384    pub fn arrange_preprocess_passes(&mut self, ordered: &str) {
385        Self::arrange_passes(&mut self.pre_processor, ordered);
386    }
387
388    /// 注册一个搜索中 Pass。
389    pub fn add_inprocess_pass(&mut self, pass: impl Pass + 'static) {
390        self.in_processor.push(Box::new(pass));
391    }
392
393    /// 按逗号分隔短名重排搜索中 Pass。
394    pub fn arrange_inprocess_passes(&mut self, ordered: &str) {
395        Self::arrange_passes(&mut self.in_processor, ordered);
396    }
397
398    /// 执行完整求解流程并返回分型结果。
399    ///
400    /// 主要顺序:
401    /// 1. 先跑预处理,若可直接判定则提前返回;
402    /// 2. 进入 `SOLVING` 状态,执行 CDCL 搜索;
403    /// 3. 按最终结果转移到 `SAT`/`UNSAT`/`UNKNOWN` 状态。
404    pub fn solve(mut self) -> SolveResult<S> {
405        let mut pre_result = SATResult::UNKNOWN;
406        for pass in &mut self.pre_processor {
407            if pass.applying(&self.kernel) {
408                pre_result = pass.apply(&mut self.kernel);
409                if pre_result != SATResult::UNKNOWN {
410                    break;
411                }
412            }
413        }
414
415        if pre_result == SATResult::SAT {
416            return SolveResult::SAT(self.into_state::<SAT>());
417        }
418        if pre_result == SATResult::UNSAT {
419            return SolveResult::UNSAT(self.into_state::<UNSAT>());
420        }
421
422        let mut solving = self.into_state::<SOLVING>();
423        let result = solving.search.search(&mut solving.kernel, &mut solving.in_processor);
424
425        info!("c solver finished with result: {:?}", result);
426        match result {
427            SATResult::SAT => SolveResult::SAT(solving.into_state::<SAT>()),
428            SATResult::UNSAT => SolveResult::UNSAT(solving.into_state::<UNSAT>()),
429            SATResult::UNKNOWN => SolveResult::UNKNOWN(solving.into_state::<UNKNOWN>()),
430        }
431    }
432
433    /// 根据短名称重排 Pass,未出现在 `ordered` 中的 Pass 会被丢弃。
434    fn arrange_passes(passes: &mut Vec<Box<dyn Pass>>, ordered: &str) {
435        let mut remaining = std::mem::take(passes);
436        let mut ordered_passes: Vec<Box<dyn Pass>> = Vec::with_capacity(remaining.len());
437
438        for short_name in ordered.split(',').map(str::trim).filter(|s| !s.is_empty()) {
439            if let Some(idx) = remaining.iter().position(|p| p.name() == short_name) {
440                ordered_passes.push(remaining.swap_remove(idx));
441            }
442        }
443
444        *passes = ordered_passes;
445    }
446}
447
448impl<S> Solver<S, SAT>
449where
450    S: Search,
451{
452    /// 导出模型(索引 0 保留,不对应实际变量)。
453    pub fn model(&self) -> Vec<bool> {
454        self.kernel.assignment.iter().map(|&value| value == 1).collect()
455    }
456
457    /// 用当前模型逐子句校验 SAT 结果。
458    pub fn check_sat(&self) -> Result<(), String> {
459        let model = self.model();
460        for clause in &self.kernel.clauses {
461            let mut satisfied = false;
462            for &lit in clause.literals() {
463                let var = lit.unsigned_abs();
464                let val = model[var];
465                if (lit > 0 && val) || (lit < 0 && !val) {
466                    satisfied = true;
467                    break;
468                }
469            }
470            if !satisfied {
471                return Err(format!("clause {:?} is not satisfied", clause));
472            }
473        }
474        Ok(())
475    }
476
477    /// 按 DIMACS 竞赛风格打印模型。
478    pub fn print_model(&self) {
479        let model = self.model();
480        print!("v ");
481        for (v, val) in model.iter().enumerate().skip(1) {
482            if *val {
483                print!("{} ", v);
484            } else {
485                print!("-{} ", v);
486            }
487            if v % 10 == 0 {
488                print!("\nv ");
489            }
490        }
491        println!("0");
492    }
493}