Skip to main content

easy_sat_rs/
search.rs

1use tracing::debug;
2
3use crate::{
4    api::{Pass, Search},
5    common::Watches,
6    constants::SATResult,
7    kernel::Kernel,
8};
9
10/// 默认 CDCL 搜索器实现。
11///
12/// 其主循环为:
13/// `BCP -> (冲突 ? 分析+回跳 : 决策) -> 重复`。
14pub struct Searcher;
15
16#[cfg_attr(doc, aquamarine::aquamarine)]
17impl Search for Searcher {
18    /// 使用双观察字(2WL)执行布尔约束传播(BCP)。
19    ///
20    /// # 核心状态与索引约定
21    /// - `trail[..propagated]`:已经完成传播的真文字。
22    /// - `trail[propagated..]`:刚变为真、但尚未处理的文字。
23    /// - 监视器按“使被监视文字变假的文字”来建索引。
24    ///   换言之,若子句监视 $w$,则其条目存放在 $watch(\neg w)$。
25    ///
26    /// 监视列表会在子句加入内核时由内部初始化逻辑自动建立。
27    ///
28    /// 这也解释了这里为什么用 `lit` 访问 [watches](Kernel::watches):
29    /// 当文字 $l$ 被赋为真时,所有监视 $\neg l$ 的子句都可能失效,必须立即重检。
30    ///
31    /// # 算法流程
32    /// 1. 从 `trail` 取出一个尚未传播的真文字 $l$。
33    /// 2. 用 `mem::take` 取走 `watch(l)`,避免遍历时与原桶相互干扰。
34    /// 3. 对桶内每个 watcher 执行三路判断:
35    ///    - 若 `blocker` 已为真,子句已满足,直接保留 watcher;
36    ///    - 否则尝试在子句中寻找新的可监视文字(值不为假),若找到则迁移监视;
37    ///    - 若找不到替代文字,则该子句要么成为单子句并触发蕴含赋值,要么直接冲突。
38    /// 4. 将压实后的 watcher 列表写回同一个监视桶。
39    ///
40    /// # 关键结果分支
41    /// - **迁移监视成功**:子句继续保持“两个被监视文字”的结构,无需立即赋值。
42    /// - **单子句**:另一个被监视文字成为唯一可满足候选,触发强制赋值。
43    /// - **冲突**:两个被监视文字都为假,且无替代文字,返回冲突供后续分析学习。
44    ///
45    /// ```mermaid
46    /// flowchart TD
47    ///     A[从 trail 取 lit] --> B[取出 watch(lit)]
48    ///     B --> C{blocker 为真?}
49    ///     C -->|是| K[保留 watcher]
50    ///     C -->|否| D[检查子句并规范化被监视位]
51    ///     D --> E{能找到新监视文字?}
52    ///     E -->|是| M[迁移 watcher 到新桶]
53    ///     E -->|否| F{first_lit 为假?}
54    ///     F -->|是| X[记录冲突并返回 false]
55    ///     F -->|否| U[first_lit 成为单子句并赋值]
56    ///     K --> N[写回压实 watcher]
57    ///     M --> N
58    ///     U --> N
59    /// ```
60    ///
61    /// # 示例
62    /// 设子句 $C = (\neg x \lor y \lor z)$,初始监视 $(\neg x, y)$。
63    ///
64    /// 监视表中对应条目为:
65    /// - `watch(x)` 中存放“监视 $\neg x$”的 watcher(因为 $x = -(\neg x)$);
66    /// - `watch(\neg y)` 中存放“监视 $y$”的 watcher。
67    ///
68    /// 若一次决策令 $x = \top$,则 `trail` 新增 `x`:
69    /// - 此时 $\neg x = \bot$,子句 $C$ 变为“需要检查”;
70    /// - [propagate](Self::propagate) 会处理 `lit = x` 并读取 `watch(x)`;
71    /// - 之后可能迁移到监视 $z$、推出 $y$ 为单子句,或直接报告冲突。
72    ///
73    /// 例如当前赋值为:
74    /// - `x = true`
75    /// - `y = false`
76    /// - `z = unassigned`
77    ///
78    /// 则该子句在传播时会把监视从 `¬x` 迁移到 `z`,避免整句重复扫描。
79    fn propagate(&mut self, kernel: &mut Kernel) -> bool {
80        // TODO Should re-write it more rusty
81        while kernel.propagated < kernel.trail.len() {
82            let lit = kernel.trail[kernel.propagated];
83            kernel.propagated += 1;
84            let mut ws = kernel.watches(lit);
85            let mut i = 0usize;
86            let mut j = 0usize;
87            let size = ws.len();
88            while i < size {
89                let blocker = ws[i].blocker;
90                if kernel.value(blocker) == 1 {
91                    ws[j] = ws[i];
92                    i += 1;
93                    j += 1;
94                    continue;
95                }
96
97                let clause_id = ws[i].clause_id;
98                let (first_lit, clause_len) = {
99                    let clause = &mut kernel.clauses[clause_id];
100                    let clause_len = clause.literals().len();
101                    if clause_len > 1 && clause[0] == -lit {
102                        clause[0] = clause[1];
103                        clause[1] = -lit;
104                    }
105                    (clause[0], clause_len)
106                };
107                let w: Watches = Watches { clause_id, blocker: first_lit };
108                i += 1;
109                if kernel.value(first_lit) == 1 {
110                    ws[j] = w;
111                    j += 1;
112                    continue;
113                }
114                // The first two literals are watched literals.
115                let mut k = 2usize;
116                while k < clause_len {
117                    let lit_k = kernel.clauses[clause_id][k];
118                    if kernel.value(lit_k) != -1 {
119                        break;
120                    }
121                    k += 1;
122                }
123                if k < clause_len {
124                    let moved_watch_lit = {
125                        let clause = &mut kernel.clauses[clause_id];
126                        clause[1] = clause[k];
127                        clause[k] = -lit;
128                        clause[1]
129                    };
130                    kernel.add_watch(-moved_watch_lit, w);
131                } else {
132                    ws[j] = w;
133                    j += 1;
134                    if kernel.value(first_lit) == -1 {
135                        while i < size {
136                            ws[j] = ws[i];
137                            j += 1;
138                            i += 1;
139                        }
140                        ws.truncate(j);
141                        kernel.set_watches(lit, ws);
142                        kernel.conflict = Some((clause_id, first_lit));
143                        return false;
144                    }
145                    kernel.assign(first_lit.unsigned_abs(), first_lit, Some(clause_id));
146                }
147            }
148            ws.truncate(j);
149            kernel.set_watches(lit, ws);
150        }
151        true
152    }
153
154    /// 选择一个变量进行决策赋值。
155    ///
156    /// # 流程
157    /// 1. 用 VSIDS 在“未赋值变量”中选出优先级最高的变量。
158    /// 2. 通过相位启发式(phase saving / target / forced)确定该变量的极性。
159    /// 3. 进入新的决策层,并更新决策统计计数。
160    /// 4. 将该文字作为“决策赋值”压入 trail。
161    ///
162    /// ```mermaid
163    /// flowchart TD
164    ///     A[从 VSIDS 取最高分未赋值变量] --> B{找到变量?}
165    ///     B -->|是| C[按 phase 选择极性]
166    ///     C --> D[决策层 +1]
167    ///     D --> E[assign(reason=None)]
168    ///     B -->|否| F[所有变量已赋值 -> SAT]
169    /// ```
170    ///
171    /// 因此这里调用 [assign](Kernel::assign) 时 `reason = None`:
172    /// 该赋值不是由任何子句蕴含出来的,而是一个分支点。后续冲突分析与非时序回溯
173    /// 都会以这些分支点为边界进行学习和回跳。
174    ///
175    /// # SAT 结束条件
176    /// 如果 VSIDS 找不到任何未赋值变量,说明所有变量都已被赋值,且先前传播阶段
177    /// 没有导出矛盾,因此该实例可判定为 SAT。
178    fn decide(&mut self, kernel: &mut Kernel) {
179        let var = kernel.vsids.next_variable(&|v| kernel.assignment[v] == 0);
180        if let Some(var_id) = var {
181            let lit = kernel.phases.decide_phase(var_id, true);
182            debug!("c deciding variable: {:?}, and assign literal: {:?}", var_id, lit);
183            kernel.statistics.decisions += 1;
184            kernel.level += 1;
185            kernel.assign(var_id, lit, None);
186        } else {
187            kernel.result = SATResult::SAT;
188        }
189    }
190
191    /// 冲突分析:基于 First-UIP 构造学习子句并计算回跳层级。
192    ///
193    /// # 对应理论
194    /// 从冲突子句出发,沿 trail 逆序做归结,
195    /// 直到学习子句中“当前层变量”只剩一个(即 First-UIP)。
196    ///
197    /// ```mermaid
198    /// flowchart TD
199    ///     A[冲突子句] --> B[标记文字并统计 open]
200    ///     B --> C[逆序扫描 trail 找当前层已标记文字]
201    ///     C --> D[将该文字作为 resolve_lit]
202    ///     D --> E[open -= 1]
203    ///     E --> F{open == 0?}
204    ///     F -->|否| G[跳到原因子句继续归结]
205    ///     G --> B
206    ///     F -->|是| H[得到 First-UIP]
207    ///     H --> I[构建学习子句并计算 LBD]
208    ///     I --> J[确定 backtrack_level 与 learnt]
209    /// ```
210    ///
211    /// # 主要步骤
212    /// 1. 读取当前冲突子句,标记参与归结的变量并 bump VSIDS 分数。
213    /// 2. 用 `open` 统计“当前层尚未消解完”的变量个数。
214    /// 3. 逆序扫描 trail,不断取原因子句继续归结,直到 `open == 0`。
215    /// 4. 生成学习子句,计算 LBD,并把次高层级放到位置 1 以确定回跳层。
216    /// 5. 写入 `kernel.learnt`,供后续 [backtrack](Self::backtrack) 断言。
217    ///
218    /// # 例子
219    /// 设当前 trail 末尾为:`x5@2, x6@2, x7@2`,冲突子句为 `(!x1 v !x6 v !x7)`。
220    /// 归结过程会先消掉 `x7`、再消掉 `x6`,最终在当前层仅剩一个文字(First-UIP),
221    /// 得到学习子句形如 `(!x1 v !x5)`,并据此回跳到次高层。
222    fn analyze(&mut self, kernel: &mut Kernel) {
223        let Some((mut conflict_idx, _conflict_lit)) = kernel.conflict else {
224            panic!("c no conflict clause found, crashed in propagate");
225        };
226
227        kernel.statistics.conflicts += 1;
228        let conflict_level = kernel.level;
229        if conflict_level == 0 {
230            kernel.backtrack_level = 0;
231            kernel.result = SATResult::UNSAT;
232            return;
233        }
234
235        kernel.lemma.clear();
236        kernel.lemma.push(0);
237
238        let var_stamp = kernel.next_mark_epoch();
239        let mut bump_vars: Vec<usize> = Vec::new();
240        let mut open = 0usize;
241        let mut resolve_lit = 0isize;
242        let mut trail_idx = kernel.trail.len();
243
244        while open > 0 || resolve_lit == 0 {
245            let clause_len = kernel.clauses[conflict_idx].literals().len();
246            for i in 0..clause_len {
247                let q = kernel.clauses[conflict_idx][i];
248                if q == resolve_lit {
249                    continue;
250                }
251
252                let var_id = q.unsigned_abs();
253                let level = kernel.vars[var_id].level;
254                if level == 0 || kernel.mark_at(var_id) == var_stamp {
255                    continue;
256                }
257
258                kernel.set_mark_at(var_id, var_stamp);
259                kernel.vsids.bump_var_score(var_id);
260                bump_vars.push(var_id);
261                if level == conflict_level {
262                    open += 1;
263                } else {
264                    kernel.lemma.push(q);
265                }
266            }
267
268            loop {
269                trail_idx -= 1;
270                let lit = kernel.trail[trail_idx];
271                let var_id = lit.unsigned_abs();
272                if kernel.mark_at(var_id) == var_stamp
273                    && kernel.vars[var_id].level == conflict_level
274                {
275                    resolve_lit = lit;
276                    break;
277                }
278            }
279
280            let resolve_var = resolve_lit.unsigned_abs();
281            kernel.set_mark_at(resolve_var, 0);
282            open -= 1;
283            if open == 0 {
284                break;
285            }
286            conflict_idx = kernel.vars[resolve_var]
287                .reason
288                .unwrap_or_else(|| panic!("c missing reason for lit {resolve_lit}"));
289        }
290
291        kernel.lemma[0] = -resolve_lit;
292
293        let level_stamp = kernel.next_mark_epoch();
294        let mut lbd = 0u32;
295        for i in 0..kernel.lemma.len() {
296            let lit = kernel.lemma[i];
297            let level = kernel.vars[lit.unsigned_abs()].level;
298            if level > 0 && kernel.mark_at(level) != level_stamp {
299                kernel.set_mark_at(level, level_stamp);
300                lbd += 1;
301            }
302        }
303
304        if kernel.lemma.len() == 1 {
305            kernel.backtrack_level = 0;
306        } else {
307            let mut max_idx = 1usize;
308            let mut max_level = kernel.vars[kernel.lemma[1].unsigned_abs()].level;
309            for i in 2..kernel.lemma.len() {
310                let level = kernel.vars[kernel.lemma[i].unsigned_abs()].level;
311                if level > max_level {
312                    max_level = level;
313                    max_idx = i;
314                }
315            }
316            if max_idx != 1 {
317                kernel.lemma.swap(1, max_idx);
318            }
319            kernel.backtrack_level = max_level;
320        }
321
322        let threshold = kernel.backtrack_level.saturating_sub(1);
323        for var_id in bump_vars {
324            if kernel.vars[var_id].level >= threshold {
325                kernel.vsids.bump_var_score(var_id);
326            }
327        }
328
329        let lemma = std::mem::take(&mut kernel.lemma);
330        debug!("c learned clause: {:?}", lemma);
331        let first_lit = lemma[0];
332        let clause_id = kernel.add_learned_clause_with_lbd(lemma, lbd);
333        if kernel.clauses[clause_id].literals().len() == 1 {
334            kernel.learnt = (first_lit, None);
335        } else {
336            kernel.learnt = (first_lit, Some(clause_id));
337        }
338        kernel.conflict = None;
339        if kernel.statistics.conflicts % 5000 == 0 {
340            kernel.vsids.bump_decay_factor();
341        }
342    }
343
344    /// 非时序回溯(Backjumping)。
345    ///
346    /// 该步骤会撤销所有高于 `backtrack_level` 的赋值,然后立刻断言学习子句
347    /// 的 UIP 文字,使搜索跳转到更有信息的位置,而不是简单回到上一层。
348    ///
349    /// ```mermaid
350    /// flowchart TD
351    ///     A[根据 backtrack_level 弹出 trail] --> B[撤销对应 assignment]
352    ///     B --> C[恢复 level 与 propagated]
353    ///     C --> D[断言 learnt 文字]
354    ///     D --> E[进入下一轮传播]
355    /// ```
356    ///
357    /// 例:若当前在 7 层,分析得到 `backtrack_level = 3`,
358    /// 则会一次性撤销 4~7 层赋值,再在 3 层断言学习子句首文字。
359    fn backtrack(&mut self, kernel: &mut Kernel) {
360        debug!(
361            "c backtracking to level: {}, and assign literal: {:?}",
362            kernel.backtrack_level, kernel.learnt
363        );
364        while let Some(&lit) = kernel.trail.last() {
365            let var_id = lit.unsigned_abs();
366            if kernel.vars[var_id].level <= kernel.backtrack_level {
367                break;
368            }
369            kernel.trail.pop();
370            kernel.reset_value(var_id);
371        }
372        kernel.level = kernel.backtrack_level;
373        kernel.propagated = kernel.propagated.min(kernel.trail.len());
374
375        let (lit, clause_id) = kernel.learnt;
376        kernel.assign(lit.unsigned_abs(), lit, clause_id);
377    }
378
379    /// 驱动 CDCL 主循环直到得到结果。
380    ///
381    /// 流程顺序:
382    /// 1. 先做传播,若冲突则分析并回跳;
383    /// 2. 若无冲突且已全赋值,返回 SAT;
384    /// 3. 否则执行 in-process passes,再做一次决策,继续循环。
385    fn search(&mut self, kernel: &mut Kernel, in_processor: &mut Vec<Box<dyn Pass>>) -> SATResult {
386        while kernel.result == SATResult::UNKNOWN {
387            if kernel.result == SATResult::UNSAT {
388                return SATResult::UNSAT;
389            } else if !self.propagate(kernel) {
390                self.analyze(kernel);
391                if kernel.result == SATResult::UNSAT {
392                    return SATResult::UNSAT;
393                }
394                self.backtrack(kernel);
395            } else if kernel.satisfied() {
396                return SATResult::SAT;
397            } else {
398                for pass in in_processor.iter_mut() {
399                    if pass.applying(kernel) {
400                        pass.apply(kernel);
401                    }
402                }
403                self.decide(kernel);
404            }
405        }
406        kernel.result
407    }
408}