Skip to main content

fotos_lib/ai/
ocr.rs

//! Tesseract OCR pipeline with tiling and preprocessing.
//!
//! Strategy:
//! - Small images (both dimensions ≤ 2000px): upscale 2× then run a single OCR pass.
//! - Large images (either dimension > 2000px): divide into overlapping tiles
//!   (1024px, 100px overlap), OCR each tile independently, translate
//!   coordinates, then deduplicate overlapping detections using IoU + text
//!   similarity.
8use anyhow::Result;
9use rayon::prelude::*;
10use std::sync::atomic::{AtomicU32, Ordering};
11use tesseract::Tesseract;
12
/// A single recognised word and its bounding box.
///
/// Coordinates are pixels in the image passed to [`run_ocr`], measured from
/// its top-left corner.
#[derive(Debug, Clone, PartialEq)]
pub struct OcrRegion {
    /// Recognised text of the word, with surrounding whitespace trimmed.
    pub text: String,
    /// Left edge of the bounding box.
    pub x: u32,
    /// Top edge of the bounding box.
    pub y: u32,
    /// Width of the bounding box.
    pub w: u32,
    /// Height of the bounding box.
    pub h: u32,
    /// Word confidence as reported in Tesseract's TSV output; always
    /// non-negative here (presumably on a 0–100 scale — confirm against the
    /// Tesseract version in use).
    pub confidence: f32,
}
21
/// Result of an OCR run.
pub struct OcrOutput {
    /// Text reconstructed from `regions` in approximate reading order by
    /// `regions_to_text`.
    pub full_text: String,
    /// Individual word detections in the input image's pixel coordinates.
    /// Their order is not reading order (the tiled path, for example,
    /// returns them sorted by confidence after deduplication).
    pub regions: Vec<OcrRegion>,
}
26
/// Tesseract configuration shared by every OCR pass.
#[derive(Debug, Clone)]
pub struct OcrOptions {
    /// Language argument handed to Tesseract (e.g. `"eng"`; presumably
    /// `+`-joined for several languages — confirm against callers).
    pub lang: String,
    /// Data directory handed to Tesseract (where its `tessdata` lives).
    pub tessdata_path: String,
}
31
/// Edge length, in pixels, of each square tile used by the tiled strategy.
/// A tile is clipped only when the image is smaller than this along an axis.
const TILE_SIZE: u32 = 1024;
/// Pixels shared by adjacent tiles, intended so that text crossing a tile
/// seam can also be read from the neighbouring tile.
const TILE_OVERLAP: u32 = 100;
/// Images whose width and height are both at most this many pixels are
/// upscaled and recognised in a single pass instead of being tiled.
const SMALL_IMAGE_THRESHOLD: u32 = 2000;
/// Integer factor by which small images are enlarged before OCR; region
/// coordinates are divided by it afterwards.
const UPSCALE_FACTOR: u32 = 2;
/// IoU threshold above which two detections are considered duplicates.
const IOU_DEDUP_THRESHOLD: f32 = 0.5;

// Compile-time guards: the upscale factor is used as a divisor when mapping
// coordinates back, and an overlap as large as a tile would leave the tiler
// with no forward stride.
const _: () = assert!(UPSCALE_FACTOR >= 1);
const _: () = assert!(TILE_OVERLAP < TILE_SIZE);
38
39/// Run OCR on an image. Selects between upscale+single-pass (small images)
40/// and tiled processing (large images).
41///
42/// `on_progress` is called after each tile completes with `(completed, total)`.
43pub fn run_ocr(
44    image: &image::DynamicImage,
45    opts: &OcrOptions,
46    on_progress: Option<&(dyn Fn(u32, u32) + Send + Sync)>,
47) -> Result<OcrOutput> {
48    let (w, h) = (image.width(), image.height());
49
50    if w <= SMALL_IMAGE_THRESHOLD && h <= SMALL_IMAGE_THRESHOLD {
51        run_upscaled(image, opts, on_progress)
52    } else {
53        run_tiled(image, opts, on_progress)
54    }
55}
56
57// ---------------------------------------------------------------------------
58// Upscale strategy
59// ---------------------------------------------------------------------------
60
61fn run_upscaled(
62    image: &image::DynamicImage,
63    opts: &OcrOptions,
64    on_progress: Option<&(dyn Fn(u32, u32) + Send + Sync)>,
65) -> Result<OcrOutput> {
66    let scale = UPSCALE_FACTOR;
67    let upscaled = image.resize(
68        image.width() * scale,
69        image.height() * scale,
70        image::imageops::FilterType::Lanczos3,
71    );
72
73    let mut regions = run_tesseract(&upscaled, opts)?;
74
75    // Scale coordinates back to original image space.
76    for r in &mut regions {
77        r.x /= scale;
78        r.y /= scale;
79        r.w = (r.w / scale).max(1);
80        r.h = (r.h / scale).max(1);
81    }
82
83    if let Some(cb) = on_progress {
84        cb(1, 1);
85    }
86
87    let full_text = regions_to_text(&regions);
88    Ok(OcrOutput { full_text, regions })
89}
90
91// ---------------------------------------------------------------------------
92// Tiled strategy
93// ---------------------------------------------------------------------------
94
95fn run_tiled(
96    image: &image::DynamicImage,
97    opts: &OcrOptions,
98    on_progress: Option<&(dyn Fn(u32, u32) + Send + Sync)>,
99) -> Result<OcrOutput> {
100    let img_w = image.width();
101    let img_h = image.height();
102
103    let xs = tile_positions(img_w, TILE_SIZE, TILE_OVERLAP);
104    let ys = tile_positions(img_h, TILE_SIZE, TILE_OVERLAP);
105
106    let coords: Vec<(u32, u32)> = ys
107        .iter()
108        .flat_map(|&ty| xs.iter().map(move |&tx| (tx, ty)))
109        .collect();
110    let total = coords.len() as u32;
111    let done = AtomicU32::new(0);
112
113    let results: Result<Vec<Vec<OcrRegion>>> = coords
114        .par_iter()
115        .map(|(tile_x, tile_y)| {
116            let tw = TILE_SIZE.min(img_w - tile_x);
117            let th = TILE_SIZE.min(img_h - tile_y);
118            let tile = image.crop_imm(*tile_x, *tile_y, tw, th);
119            let mut regions = run_tesseract(&tile, opts)?;
120            for r in &mut regions {
121                r.x += tile_x;
122                r.y += tile_y;
123            }
124            let completed = done.fetch_add(1, Ordering::Relaxed) + 1;
125            if let Some(cb) = on_progress {
126                cb(completed, total);
127            }
128            Ok(regions)
129        })
130        .collect();
131
132    let all_regions: Vec<OcrRegion> = results?.into_iter().flatten().collect();
133    let regions = deduplicate_regions(all_regions);
134    let full_text = regions_to_text(&regions);
135    Ok(OcrOutput { full_text, regions })
136}
137
138// ---------------------------------------------------------------------------
139// Core Tesseract call
140// ---------------------------------------------------------------------------
141
142fn run_tesseract(image: &image::DynamicImage, opts: &OcrOptions) -> Result<Vec<OcrRegion>> {
143    let rgb = image.to_rgb8();
144    let width = rgb.width() as i32;
145    let height = rgb.height() as i32;
146    let bytes_per_line = width * 3;
147    let raw = rgb.into_raw();
148
149    let mut tess = Tesseract::new(Some(&opts.tessdata_path), Some(&opts.lang))
150        .map_err(|e| anyhow::anyhow!("Tesseract init failed: {}", e))?
151        .set_frame(&raw, width, height, 3, bytes_per_line)
152        .map_err(|e| anyhow::anyhow!("Tesseract set_frame failed: {}", e))?
153        .recognize()
154        .map_err(|e| anyhow::anyhow!("Tesseract recognize failed: {}", e))?;
155
156    // TSV columns (level 5 = word):
157    // level page_num block_num par_num line_num word_num left top width height conf text
158    let tsv = tess
159        .get_tsv_text(0)
160        .map_err(|e| anyhow::anyhow!("Tesseract get_tsv_text failed: {}", e))?;
161
162    let regions = tsv
163        .lines()
164        .skip(1)
165        .filter_map(|line| {
166            let cols: Vec<&str> = line.splitn(12, '\t').collect();
167            if cols.len() < 12 {
168                return None;
169            }
170            let level: u32 = cols[0].parse().ok()?;
171            if level != 5 {
172                return None;
173            }
174            let conf: f32 = cols[10].parse().ok()?;
175            if conf < 0.0 {
176                return None;
177            }
178            let text = cols[11].trim().to_string();
179            if text.is_empty() {
180                return None;
181            }
182            Some(OcrRegion {
183                text,
184                x: cols[6].parse().ok()?,
185                y: cols[7].parse().ok()?,
186                w: cols[8].parse().ok()?,
187                h: cols[9].parse().ok()?,
188                confidence: conf,
189            })
190        })
191        .collect();
192
193    Ok(regions)
194}
195
196// ---------------------------------------------------------------------------
197// Tiling helpers
198// ---------------------------------------------------------------------------
199
/// Returns the left/top pixel offsets of tiles covering `total` pixels.
///
/// Tiles are `tile_size` pixels long and start every `tile_size - overlap`
/// pixels. The final tile is aligned to end exactly at `total`, so it may
/// overlap its neighbour by more than `overlap`. When `total <= tile_size`,
/// a single tile at offset 0 is returned (the caller clips it to the image).
///
/// The stride is clamped to at least 1 pixel, so the function terminates
/// even when `overlap >= tile_size`; a zero stride would never advance.
/// `tile_size` is expected to be non-zero.
fn tile_positions(total: u32, tile_size: u32, overlap: u32) -> Vec<u32> {
    if total <= tile_size {
        return vec![0];
    }
    let stride = tile_size.saturating_sub(overlap).max(1);
    // Offset of the tile that ends exactly at `total`; cannot underflow
    // because `total > tile_size` here.
    let last = total - tile_size;

    // Every stride step strictly before `last`, then `last` itself. This
    // avoids computing `pos + stride + tile_size`, which could overflow.
    let mut positions: Vec<u32> = (0..last).step_by(stride as usize).collect();
    positions.push(last);
    positions
}
224
225// ---------------------------------------------------------------------------
226// Deduplication (NMS-style with IoU + text similarity)
227// ---------------------------------------------------------------------------
228
229fn deduplicate_regions(mut regions: Vec<OcrRegion>) -> Vec<OcrRegion> {
230    // Sort by confidence descending so the best detection wins.
231    regions.sort_by(|a, b| {
232        b.confidence
233            .partial_cmp(&a.confidence)
234            .unwrap_or(std::cmp::Ordering::Equal)
235    });
236
237    let n = regions.len();
238    let mut suppressed = vec![false; n];
239
240    for i in 0..n {
241        if suppressed[i] {
242            continue;
243        }
244        for j in (i + 1)..n {
245            if suppressed[j] {
246                continue;
247            }
248            if iou(&regions[i], &regions[j]) > IOU_DEDUP_THRESHOLD
249                && texts_similar(&regions[i].text, &regions[j].text)
250            {
251                suppressed[j] = true;
252            }
253        }
254    }
255
256    regions
257        .into_iter()
258        .enumerate()
259        .filter_map(|(i, r)| if !suppressed[i] { Some(r) } else { None })
260        .collect()
261}
262
263fn iou(a: &OcrRegion, b: &OcrRegion) -> f32 {
264    let ax2 = a.x + a.w;
265    let ay2 = a.y + a.h;
266    let bx2 = b.x + b.w;
267    let by2 = b.y + b.h;
268
269    let ix1 = a.x.max(b.x);
270    let iy1 = a.y.max(b.y);
271    let ix2 = ax2.min(bx2);
272    let iy2 = ay2.min(by2);
273
274    if ix2 <= ix1 || iy2 <= iy1 {
275        return 0.0;
276    }
277
278    let inter = (ix2 - ix1) as f32 * (iy2 - iy1) as f32;
279    let area_a = (a.w * a.h) as f32;
280    let area_b = (b.w * b.h) as f32;
281    let union = area_a + area_b - inter;
282
283    if union <= 0.0 {
284        0.0
285    } else {
286        inter / union
287    }
288}
289
/// Returns `true` when two OCR strings plausibly denote the same word.
///
/// Two strings are similar when any of these holds:
/// - they are equal ignoring ASCII case;
/// - their edit distance is at most 1;
/// - the distance is at most 20% of the longer string's length.
///
/// Lengths are counted in `char`s, the same unit [`levenshtein`] counts
/// edits in, so multi-byte (non-ASCII) text is not judged more leniently
/// than ASCII. Apart from the equality shortcut, the comparison is
/// case-sensitive.
fn texts_similar(a: &str, b: &str) -> bool {
    if a.eq_ignore_ascii_case(b) {
        return true;
    }
    let max_len = a.chars().count().max(b.chars().count());
    if max_len == 0 {
        return true;
    }
    let dist = levenshtein(a, b);
    // `dist / max_len <= 0.2`, in exact integer arithmetic.
    dist <= 1 || dist * 5 <= max_len
}

/// Levenshtein edit distance between `s` and `t`, counted in `char`s: the
/// minimum number of single-character insertions, deletions, and
/// substitutions needed to turn one string into the other.
///
/// Uses the two-row dynamic-programming formulation: O(|s|·|t|) time and
/// O(|t|) extra space.
fn levenshtein(s: &str, t: &str) -> usize {
    let s: Vec<char> = s.chars().collect();
    let t: Vec<char> = t.chars().collect();
    let m = s.len();
    let n = t.len();

    if m == 0 {
        return n;
    }
    if n == 0 {
        return m;
    }

    // `prev[j]` = distance between s[..i-1] and t[..j]; `curr` is row i.
    let mut prev: Vec<usize> = (0..=n).collect();
    let mut curr = vec![0usize; n + 1];

    for i in 1..=m {
        curr[0] = i;
        for j in 1..=n {
            curr[j] = if s[i - 1] == t[j - 1] {
                prev[j - 1]
            } else {
                // substitution, deletion, insertion
                1 + prev[j - 1].min(prev[j]).min(curr[j - 1])
            };
        }
        std::mem::swap(&mut prev, &mut curr);
    }

    // After the final swap the last computed row lives in `prev`.
    prev[n]
}
334
335// ---------------------------------------------------------------------------
336// Reading-order text reconstruction
337// ---------------------------------------------------------------------------
338
339fn regions_to_text(regions: &[OcrRegion]) -> String {
340    if regions.is_empty() {
341        return String::new();
342    }
343
344    let avg_height = {
345        let sum: u32 = regions.iter().map(|r| r.h).sum();
346        (sum / regions.len() as u32).max(1)
347    };
348
349    let mut sorted: Vec<&OcrRegion> = regions.iter().collect();
350    sorted.sort_by_key(|r| (r.y, r.x));
351
352    let mut text = String::new();
353    let mut prev_bottom = 0u32;
354    let mut prev_right = 0u32;
355
356    for r in &sorted {
357        if !text.is_empty() {
358            if r.y > prev_bottom.saturating_add(avg_height / 2) {
359                text.push('\n');
360            } else if r.x >= prev_right {
361                text.push(' ');
362            }
363        }
364        text.push_str(&r.text);
365        prev_bottom = r.y + r.h;
366        prev_right = r.x + r.w;
367    }
368
369    text
370}
371
372// ---------------------------------------------------------------------------
373// Unit tests
374// ---------------------------------------------------------------------------
375
#[cfg(test)]
mod tests {
    //! Unit tests for the pure helpers: tiling, IoU, text similarity,
    //! deduplication, and reading-order text reconstruction. The Tesseract
    //! calls themselves are not exercised here.

    use super::*;

    /// Builds a region with the given text, box, and confidence.
    fn region(text: &str, x: u32, y: u32, w: u32, h: u32, conf: f32) -> OcrRegion {
        OcrRegion {
            text: text.into(),
            x,
            y,
            w,
            h,
            confidence: conf,
        }
    }

    // --- tile_positions ---

    #[test]
    fn tile_positions_image_smaller_than_tile() {
        assert_eq!(tile_positions(800, 1024, 100), vec![0]);
    }

    #[test]
    fn tile_positions_exact_tile_size() {
        assert_eq!(tile_positions(1024, 1024, 100), vec![0]);
    }

    #[test]
    fn tile_positions_two_tiles() {
        // total=1200, stride=924 → pos=0, next=924, 924+1024=1948≥1200 → last=176
        let positions = tile_positions(1200, 1024, 100);
        assert_eq!(positions, vec![0, 176]);
    }

    #[test]
    fn tile_positions_covers_full_width() {
        let total = 2560u32;
        let positions = tile_positions(total, TILE_SIZE, TILE_OVERLAP);
        assert_eq!(*positions.first().unwrap(), 0);
        let last = *positions.last().unwrap();
        assert!(last + TILE_SIZE >= total, "last tile does not reach end");
        for &p in &positions {
            assert!(p + TILE_SIZE <= total, "tile extends past image");
        }
    }

    #[test]
    fn tile_positions_adjacent_tiles_overlap() {
        // Consecutive tiles must advance and share at least TILE_OVERLAP pixels.
        let positions = tile_positions(5000, TILE_SIZE, TILE_OVERLAP);
        for pair in positions.windows(2) {
            assert!(pair[1] > pair[0], "positions must strictly increase");
            assert!(
                pair[0] + TILE_SIZE >= pair[1] + TILE_OVERLAP,
                "tiles at {} and {} overlap by less than {}px",
                pair[0],
                pair[1],
                TILE_OVERLAP
            );
        }
    }

    // --- IoU ---

    #[test]
    fn iou_identical_boxes() {
        let r = region("a", 0, 0, 100, 20, 90.0);
        assert!((iou(&r, &r) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn iou_non_overlapping() {
        let a = region("a", 0, 0, 50, 20, 90.0);
        let b = region("b", 100, 0, 50, 20, 90.0);
        assert_eq!(iou(&a, &b), 0.0);
    }

    #[test]
    fn iou_partial_overlap() {
        // Two 100×20 boxes, shifted 50px → inter=50×20=1000, union=3000, iou=1/3
        let a = region("a", 0, 0, 100, 20, 90.0);
        let b = region("b", 50, 0, 100, 20, 90.0);
        let score = iou(&a, &b);
        assert!((score - 1.0 / 3.0).abs() < 1e-4, "iou={score}");
    }

    #[test]
    fn iou_contained_box() {
        // 50×20 box inside a 100×20 box → inter=1000, union=2000, iou=0.5
        let a = region("a", 0, 0, 100, 20, 90.0);
        let b = region("b", 0, 0, 50, 20, 90.0);
        let score = iou(&a, &b);
        assert!((score - 0.5).abs() < 1e-5, "iou={score}");
    }

    // --- Levenshtein ---

    #[test]
    fn levenshtein_identical() {
        assert_eq!(levenshtein("hello", "hello"), 0);
    }

    #[test]
    fn levenshtein_one_deletion() {
        assert_eq!(levenshtein("hello", "helo"), 1);
    }

    #[test]
    fn levenshtein_empty() {
        assert_eq!(levenshtein("", "abc"), 3);
        assert_eq!(levenshtein("abc", ""), 3);
    }

    #[test]
    fn levenshtein_substitutions() {
        assert_eq!(levenshtein("kitten", "sitting"), 3);
    }

    // --- texts_similar ---

    #[test]
    fn texts_similar_exact() {
        assert!(texts_similar("word", "word"));
    }

    #[test]
    fn texts_similar_one_edit() {
        assert!(texts_similar("word", "wrd"));
    }

    #[test]
    fn texts_similar_case_insensitive() {
        assert!(texts_similar("Word", "word"));
    }

    #[test]
    fn texts_not_similar() {
        assert!(!texts_similar("hello", "world"));
    }

    // --- deduplication ---

    #[test]
    fn dedup_removes_overlapping_identical() {
        let regions = vec![
            region("Hello", 10, 10, 60, 20, 85.0),
            region("Hello", 10, 10, 60, 20, 70.0),
        ];
        let deduped = deduplicate_regions(regions);
        assert_eq!(deduped.len(), 1);
        assert!((deduped[0].confidence - 85.0).abs() < 1e-5);
    }

    #[test]
    fn dedup_keeps_highest_confidence_when_it_arrives_last() {
        let regions = vec![
            region("Hello", 10, 10, 60, 20, 60.0),
            region("Hello", 12, 10, 60, 20, 95.0),
        ];
        let deduped = deduplicate_regions(regions);
        assert_eq!(deduped.len(), 1);
        assert!((deduped[0].confidence - 95.0).abs() < 1e-5);
    }

    #[test]
    fn dedup_suppresses_near_duplicate_text() {
        // One OCR error apart and almost the same box → a single detection.
        let regions = vec![
            region("Hello", 10, 10, 60, 20, 80.0),
            region("Helo", 11, 10, 58, 20, 75.0),
        ];
        assert_eq!(deduplicate_regions(regions).len(), 1);
    }

    #[test]
    fn dedup_keeps_non_overlapping() {
        let regions = vec![
            region("Hello", 0, 0, 60, 20, 85.0),
            region("World", 200, 0, 60, 20, 85.0),
        ];
        assert_eq!(deduplicate_regions(regions).len(), 2);
    }

    #[test]
    fn dedup_keeps_different_text_same_location() {
        // Same bbox but very different texts → texts_similar returns false → both kept.
        let regions = vec![
            region("Hello", 10, 10, 60, 20, 85.0),
            region("World", 10, 10, 60, 20, 70.0),
        ];
        assert_eq!(deduplicate_regions(regions).len(), 2);
    }

    // --- regions_to_text ---

    #[test]
    fn text_reconstruction_empty() {
        assert_eq!(regions_to_text(&[]), "");
    }

    #[test]
    fn text_reconstruction_same_line() {
        let regions = vec![
            region("Hello", 0, 0, 60, 20, 90.0),
            region("World", 70, 0, 60, 20, 90.0),
        ];
        let text = regions_to_text(&regions);
        assert!(
            text.contains("Hello") && text.contains("World"),
            "got: {text}"
        );
    }

    #[test]
    fn text_reconstruction_newline_between_lines() {
        let regions = vec![
            region("Line1", 0, 0, 60, 20, 90.0),
            region("Line2", 0, 100, 60, 20, 90.0),
        ];
        let text = regions_to_text(&regions);
        assert!(text.contains('\n'), "expected newline, got: {text}");
    }

    // --- reading-order validation for complex / "scrambled" layouts ---

    /// Two-column layout: regions arrive in arbitrary order (as Tesseract might
    /// return them in a single-pass scan across the full image width).
    /// After tiling, coordinates are correct; regions_to_text must reconstruct
    /// reading order: left-to-right within each row, top-to-bottom across rows.
    #[test]
    fn text_reconstruction_two_column_layout() {
        // Col 1 (x≈0):   "Left"  row 0, "Text"  row 1
        // Col 2 (x≈300): "Right" row 0, "Side"  row 1
        // Arrive in scrambled (Tesseract single-pass) order.
        let regions = vec![
            region("Right", 300, 0, 60, 20, 90.0),
            region("Side", 300, 30, 60, 20, 90.0),
            region("Left", 0, 0, 60, 20, 90.0),
            region("Text", 0, 30, 60, 20, 90.0),
        ];
        let text = regions_to_text(&regions);
        let pos = |w: &str| {
            text.find(w)
                .unwrap_or_else(|| panic!("'{w}' missing in: {text}"))
        };
        // Within the same row, left column precedes right column.
        assert!(
            pos("Left") < pos("Right"),
            "row 0: left col should precede right col"
        );
        assert!(
            pos("Text") < pos("Side"),
            "row 1: left col should precede right col"
        );
        // Row 0 precedes row 1 for each column.
        assert!(
            pos("Left") < pos("Text"),
            "col 1: top word should precede bottom word"
        );
        assert!(
            pos("Right") < pos("Side"),
            "col 2: top word should precede bottom word"
        );
    }

    /// Sidebar layout: a narrow navigation column (x≈0) beside main content (x≈200).
    /// Regions arrive out of order; result must preserve per-column reading order.
    #[test]
    fn text_reconstruction_sidebar_layout() {
        let regions = vec![
            region("File", 10, 10, 60, 18, 90.0),
            region("The", 200, 10, 40, 18, 90.0),
            region("Edit", 10, 35, 60, 18, 90.0),
            region("quick", 250, 10, 50, 18, 90.0),
            region("View", 10, 60, 60, 18, 90.0),
            region("brown", 310, 10, 55, 18, 90.0),
        ];
        let text = regions_to_text(&regions);
        for word in &["File", "Edit", "View", "The", "quick", "brown"] {
            assert!(text.contains(word), "'{word}' missing in: {text}");
        }
        // Sidebar items share the same y-band as main content; left col comes first.
        let pos = |w: &str| text.find(w).unwrap();
        assert!(
            pos("File") < pos("The"),
            "sidebar 'File' should precede main 'The'"
        );
        // Sidebar items are top-to-bottom.
        assert!(
            pos("File") < pos("Edit"),
            "File should precede Edit in sidebar"
        );
        assert!(
            pos("Edit") < pos("View"),
            "Edit should precede View in sidebar"
        );
    }

    /// Unsorted input (worst-case Tesseract scramble): regions arrive in reverse
    /// spatial order. regions_to_text must still produce correct reading order.
    #[test]
    fn text_reconstruction_reverse_input_order() {
        // Three lines, fed bottom-to-top (reverse spatial order).
        let regions = vec![
            region("Third", 0, 80, 60, 20, 90.0),
            region("Second", 0, 40, 60, 20, 90.0),
            region("First", 0, 0, 60, 20, 90.0),
        ];
        let text = regions_to_text(&regions);
        let pos = |w: &str| {
            text.find(w)
                .unwrap_or_else(|| panic!("'{w}' missing in: {text}"))
        };
        assert!(pos("First") < pos("Second"), "First should precede Second");
        assert!(pos("Second") < pos("Third"), "Second should precede Third");
    }
}