1use anyhow::Result;
9use rayon::prelude::*;
10use std::sync::atomic::{AtomicU32, Ordering};
11use tesseract::Tesseract;
12
/// A single word recognised by Tesseract, together with its bounding box.
///
/// Coordinates are pixels in the coordinate space of the image passed to
/// `run_ocr`: tile offsets and upscaling are undone before regions are
/// returned.
#[derive(Debug, Clone, PartialEq)]
pub struct OcrRegion {
    /// Recognised text of the word, with surrounding whitespace trimmed.
    pub text: String,
    /// Left edge of the bounding box.
    pub x: u32,
    /// Top edge of the bounding box.
    pub y: u32,
    /// Width of the bounding box.
    pub w: u32,
    /// Height of the bounding box.
    pub h: u32,
    /// Tesseract's confidence for this word; rows with negative confidence
    /// are discarded before a region is built.
    pub confidence: f32,
}

/// Result of an OCR run: the individual words plus a plain-text rendering.
#[derive(Debug, Clone, PartialEq)]
pub struct OcrOutput {
    /// All recognised words joined into plain text in reading order.
    pub full_text: String,
    /// The individual recognised words.
    pub regions: Vec<OcrRegion>,
}

/// Settings forwarded to Tesseract for every recognition call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OcrOptions {
    /// Language argument passed to `Tesseract::new`.
    pub lang: String,
    /// Data-path argument passed to `Tesseract::new`.
    pub tessdata_path: String,
}
31
/// Edge length, in pixels, of the square tiles used for large images.
const TILE_SIZE: u32 = 1_024;
/// Pixels shared by neighbouring tiles, so a word crossing a tile seam has a
/// chance to appear whole in one of them.
const TILE_OVERLAP: u32 = 1_00;
/// Images whose width and height are both at most this many pixels are
/// upscaled instead of tiled.
const SMALL_IMAGE_THRESHOLD: u32 = 2_000;
/// Integer factor by which small images are enlarged before recognition.
const UPSCALE_FACTOR: u32 = 2;
/// Two similarly-worded regions whose intersection-over-union exceeds this
/// value are treated as duplicates.
const IOU_DEDUP_THRESHOLD: f32 = 0.5;
38
39pub fn run_ocr(
44 image: &image::DynamicImage,
45 opts: &OcrOptions,
46 on_progress: Option<&(dyn Fn(u32, u32) + Send + Sync)>,
47) -> Result<OcrOutput> {
48 let (w, h) = (image.width(), image.height());
49
50 if w <= SMALL_IMAGE_THRESHOLD && h <= SMALL_IMAGE_THRESHOLD {
51 run_upscaled(image, opts, on_progress)
52 } else {
53 run_tiled(image, opts, on_progress)
54 }
55}
56
57fn run_upscaled(
62 image: &image::DynamicImage,
63 opts: &OcrOptions,
64 on_progress: Option<&(dyn Fn(u32, u32) + Send + Sync)>,
65) -> Result<OcrOutput> {
66 let scale = UPSCALE_FACTOR;
67 let upscaled = image.resize(
68 image.width() * scale,
69 image.height() * scale,
70 image::imageops::FilterType::Lanczos3,
71 );
72
73 let mut regions = run_tesseract(&upscaled, opts)?;
74
75 for r in &mut regions {
77 r.x /= scale;
78 r.y /= scale;
79 r.w = (r.w / scale).max(1);
80 r.h = (r.h / scale).max(1);
81 }
82
83 if let Some(cb) = on_progress {
84 cb(1, 1);
85 }
86
87 let full_text = regions_to_text(®ions);
88 Ok(OcrOutput { full_text, regions })
89}
90
91fn run_tiled(
96 image: &image::DynamicImage,
97 opts: &OcrOptions,
98 on_progress: Option<&(dyn Fn(u32, u32) + Send + Sync)>,
99) -> Result<OcrOutput> {
100 let img_w = image.width();
101 let img_h = image.height();
102
103 let xs = tile_positions(img_w, TILE_SIZE, TILE_OVERLAP);
104 let ys = tile_positions(img_h, TILE_SIZE, TILE_OVERLAP);
105
106 let coords: Vec<(u32, u32)> = ys
107 .iter()
108 .flat_map(|&ty| xs.iter().map(move |&tx| (tx, ty)))
109 .collect();
110 let total = coords.len() as u32;
111 let done = AtomicU32::new(0);
112
113 let results: Result<Vec<Vec<OcrRegion>>> = coords
114 .par_iter()
115 .map(|(tile_x, tile_y)| {
116 let tw = TILE_SIZE.min(img_w - tile_x);
117 let th = TILE_SIZE.min(img_h - tile_y);
118 let tile = image.crop_imm(*tile_x, *tile_y, tw, th);
119 let mut regions = run_tesseract(&tile, opts)?;
120 for r in &mut regions {
121 r.x += tile_x;
122 r.y += tile_y;
123 }
124 let completed = done.fetch_add(1, Ordering::Relaxed) + 1;
125 if let Some(cb) = on_progress {
126 cb(completed, total);
127 }
128 Ok(regions)
129 })
130 .collect();
131
132 let all_regions: Vec<OcrRegion> = results?.into_iter().flatten().collect();
133 let regions = deduplicate_regions(all_regions);
134 let full_text = regions_to_text(®ions);
135 Ok(OcrOutput { full_text, regions })
136}
137
138fn run_tesseract(image: &image::DynamicImage, opts: &OcrOptions) -> Result<Vec<OcrRegion>> {
143 let rgb = image.to_rgb8();
144 let width = rgb.width() as i32;
145 let height = rgb.height() as i32;
146 let bytes_per_line = width * 3;
147 let raw = rgb.into_raw();
148
149 let mut tess = Tesseract::new(Some(&opts.tessdata_path), Some(&opts.lang))
150 .map_err(|e| anyhow::anyhow!("Tesseract init failed: {}", e))?
151 .set_frame(&raw, width, height, 3, bytes_per_line)
152 .map_err(|e| anyhow::anyhow!("Tesseract set_frame failed: {}", e))?
153 .recognize()
154 .map_err(|e| anyhow::anyhow!("Tesseract recognize failed: {}", e))?;
155
156 let tsv = tess
159 .get_tsv_text(0)
160 .map_err(|e| anyhow::anyhow!("Tesseract get_tsv_text failed: {}", e))?;
161
162 let regions = tsv
163 .lines()
164 .skip(1)
165 .filter_map(|line| {
166 let cols: Vec<&str> = line.splitn(12, '\t').collect();
167 if cols.len() < 12 {
168 return None;
169 }
170 let level: u32 = cols[0].parse().ok()?;
171 if level != 5 {
172 return None;
173 }
174 let conf: f32 = cols[10].parse().ok()?;
175 if conf < 0.0 {
176 return None;
177 }
178 let text = cols[11].trim().to_string();
179 if text.is_empty() {
180 return None;
181 }
182 Some(OcrRegion {
183 text,
184 x: cols[6].parse().ok()?,
185 y: cols[7].parse().ok()?,
186 w: cols[8].parse().ok()?,
187 h: cols[9].parse().ok()?,
188 confidence: conf,
189 })
190 })
191 .collect();
192
193 Ok(regions)
194}
195
/// Returns the start offsets of tiles of length `tile_size` that cover
/// `total` pixels along one axis, with consecutive tiles overlapping by
/// `overlap` pixels.
///
/// - When `total <= tile_size`, a single tile at 0 is returned; the caller
///   clamps its size to the image.
/// - Otherwise the first tile starts at 0 and the others step by
///   `tile_size - overlap`.
/// - The last tile is placed flush with the far edge (`total - tile_size`),
///   so every tile lies inside `0..total` and together they cover it.
///
/// If `overlap >= tile_size`, the step is clamped to 1 pixel so the
/// function still terminates.
fn tile_positions(total: u32, tile_size: u32, overlap: u32) -> Vec<u32> {
    if total <= tile_size {
        return vec![0];
    }
    // A zero stride would never advance; one pixel keeps the loop finite.
    let stride = tile_size.saturating_sub(overlap).max(1);
    let last = total - tile_size;

    let mut positions = Vec::new();
    let mut pos = 0u32;
    // `pos < last` and `stride <= max(tile_size, 1)` keep `pos + stride`
    // within `total`, so the addition cannot overflow.
    while pos < last {
        positions.push(pos);
        pos += stride;
    }
    positions.push(last);
    positions
}
224
225fn deduplicate_regions(mut regions: Vec<OcrRegion>) -> Vec<OcrRegion> {
230 regions.sort_by(|a, b| {
232 b.confidence
233 .partial_cmp(&a.confidence)
234 .unwrap_or(std::cmp::Ordering::Equal)
235 });
236
237 let n = regions.len();
238 let mut suppressed = vec![false; n];
239
240 for i in 0..n {
241 if suppressed[i] {
242 continue;
243 }
244 for j in (i + 1)..n {
245 if suppressed[j] {
246 continue;
247 }
248 if iou(®ions[i], ®ions[j]) > IOU_DEDUP_THRESHOLD
249 && texts_similar(®ions[i].text, ®ions[j].text)
250 {
251 suppressed[j] = true;
252 }
253 }
254 }
255
256 regions
257 .into_iter()
258 .enumerate()
259 .filter_map(|(i, r)| if !suppressed[i] { Some(r) } else { None })
260 .collect()
261}
262
263fn iou(a: &OcrRegion, b: &OcrRegion) -> f32 {
264 let ax2 = a.x + a.w;
265 let ay2 = a.y + a.h;
266 let bx2 = b.x + b.w;
267 let by2 = b.y + b.h;
268
269 let ix1 = a.x.max(b.x);
270 let iy1 = a.y.max(b.y);
271 let ix2 = ax2.min(bx2);
272 let iy2 = ay2.min(by2);
273
274 if ix2 <= ix1 || iy2 <= iy1 {
275 return 0.0;
276 }
277
278 let inter = (ix2 - ix1) as f32 * (iy2 - iy1) as f32;
279 let area_a = (a.w * a.h) as f32;
280 let area_b = (b.w * b.h) as f32;
281 let union = area_a + area_b - inter;
282
283 if union <= 0.0 {
284 0.0
285 } else {
286 inter / union
287 }
288}
289
290fn texts_similar(a: &str, b: &str) -> bool {
293 if a.eq_ignore_ascii_case(b) {
294 return true;
295 }
296 let max_len = a.len().max(b.len());
297 if max_len == 0 {
298 return true;
299 }
300 let dist = levenshtein(a, b);
301 dist <= 1 || (dist as f32 / max_len as f32) <= 0.2
302}
303
/// Levenshtein edit distance between `s` and `t`: the minimum number of
/// single-character insertions, deletions and substitutions needed to turn
/// one into the other, counted in `char`s rather than bytes.
///
/// Uses the two-row dynamic-programming formulation, with O(len(t)) extra
/// memory.
fn levenshtein(s: &str, t: &str) -> usize {
    let left: Vec<char> = s.chars().collect();
    let right: Vec<char> = t.chars().collect();

    if left.is_empty() {
        return right.len();
    }
    if right.is_empty() {
        return left.len();
    }

    // `above` is the finished previous row; `row` is being filled in.
    let mut above: Vec<usize> = (0..=right.len()).collect();
    let mut row = vec![0usize; right.len() + 1];

    for (i, &lc) in left.iter().enumerate() {
        row[0] = i + 1;
        for (j, &rc) in right.iter().enumerate() {
            row[j + 1] = if lc == rc {
                above[j]
            } else {
                let substitute = above[j];
                let delete = above[j + 1];
                let insert = row[j];
                1 + substitute.min(delete).min(insert)
            };
        }
        std::mem::swap(&mut above, &mut row);
    }

    above[right.len()]
}
334
335fn regions_to_text(regions: &[OcrRegion]) -> String {
340 if regions.is_empty() {
341 return String::new();
342 }
343
344 let avg_height = {
345 let sum: u32 = regions.iter().map(|r| r.h).sum();
346 (sum / regions.len() as u32).max(1)
347 };
348
349 let mut sorted: Vec<&OcrRegion> = regions.iter().collect();
350 sorted.sort_by_key(|r| (r.y, r.x));
351
352 let mut text = String::new();
353 let mut prev_bottom = 0u32;
354 let mut prev_right = 0u32;
355
356 for r in &sorted {
357 if !text.is_empty() {
358 if r.y > prev_bottom.saturating_add(avg_height / 2) {
359 text.push('\n');
360 } else if r.x >= prev_right {
361 text.push(' ');
362 }
363 }
364 text.push_str(&r.text);
365 prev_bottom = r.y + r.h;
366 prev_right = r.x + r.w;
367 }
368
369 text
370}
371
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a region with the given text, box and confidence.
    fn region(text: &str, x: u32, y: u32, w: u32, h: u32, conf: f32) -> OcrRegion {
        OcrRegion {
            text: text.into(),
            x,
            y,
            w,
            h,
            confidence: conf,
        }
    }

    // --- tile_positions ---

    #[test]
    fn tile_positions_image_smaller_than_tile() {
        assert_eq!(tile_positions(800, 1024, 100), vec![0]);
    }

    #[test]
    fn tile_positions_exact_tile_size() {
        assert_eq!(tile_positions(1024, 1024, 100), vec![0]);
    }

    #[test]
    fn tile_positions_two_tiles() {
        // The second tile is pinned flush with the far edge: 1200 - 1024.
        let positions = tile_positions(1200, 1024, 100);
        assert_eq!(positions, vec![0, 176]);
    }

    #[test]
    fn tile_positions_covers_full_width() {
        let total = 2560u32;
        let positions = tile_positions(total, TILE_SIZE, TILE_OVERLAP);
        assert_eq!(*positions.first().unwrap(), 0);
        let last = *positions.last().unwrap();
        assert!(last + TILE_SIZE >= total, "last tile does not reach end");
        for &p in &positions {
            assert!(p + TILE_SIZE <= total, "tile extends past image");
        }
    }

    // --- iou ---

    #[test]
    fn iou_identical_boxes() {
        let r = region("a", 0, 0, 100, 20, 90.0);
        assert!((iou(&r, &r) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn iou_non_overlapping() {
        let a = region("a", 0, 0, 50, 20, 90.0);
        let b = region("b", 100, 0, 50, 20, 90.0);
        assert_eq!(iou(&a, &b), 0.0);
    }

    #[test]
    fn iou_partial_overlap() {
        // Intersection 50x20 = 1000, union 2000 + 2000 - 1000 = 3000.
        let a = region("a", 0, 0, 100, 20, 90.0);
        let b = region("b", 50, 0, 100, 20, 90.0);
        let score = iou(&a, &b);
        assert!((score - 1.0 / 3.0).abs() < 1e-4, "iou={score}");
    }

    // --- levenshtein ---

    #[test]
    fn levenshtein_identical() {
        assert_eq!(levenshtein("hello", "hello"), 0);
    }

    #[test]
    fn levenshtein_one_deletion() {
        assert_eq!(levenshtein("hello", "helo"), 1);
    }

    #[test]
    fn levenshtein_empty() {
        assert_eq!(levenshtein("", "abc"), 3);
        assert_eq!(levenshtein("abc", ""), 3);
    }

    // --- texts_similar ---

    #[test]
    fn texts_similar_exact() {
        assert!(texts_similar("word", "word"));
    }

    #[test]
    fn texts_similar_one_edit() {
        assert!(texts_similar("word", "wrd"));
    }

    #[test]
    fn texts_similar_case_insensitive() {
        assert!(texts_similar("Word", "word"));
    }

    #[test]
    fn texts_not_similar() {
        assert!(!texts_similar("hello", "world"));
    }

    // --- deduplicate_regions ---

    #[test]
    fn dedup_removes_overlapping_identical() {
        let regions = vec![
            region("Hello", 10, 10, 60, 20, 85.0),
            region("Hello", 10, 10, 60, 20, 70.0),
        ];
        let deduped = deduplicate_regions(regions);
        assert_eq!(deduped.len(), 1);
        assert!((deduped[0].confidence - 85.0).abs() < 1e-5);
    }

    #[test]
    fn dedup_prefers_highest_confidence_regardless_of_input_order() {
        // The lower-confidence reading comes first and is slightly offset,
        // as happens when the same word is seen by two overlapping tiles.
        let regions = vec![
            region("Hello", 10, 10, 60, 20, 70.0),
            region("Hello", 12, 10, 60, 20, 95.0),
        ];
        let deduped = deduplicate_regions(regions);
        assert_eq!(deduped.len(), 1);
        assert!((deduped[0].confidence - 95.0).abs() < 1e-5);
        assert_eq!(deduped[0].x, 12);
    }

    #[test]
    fn dedup_keeps_non_overlapping() {
        let regions = vec![
            region("Hello", 0, 0, 60, 20, 85.0),
            region("World", 200, 0, 60, 20, 85.0),
        ];
        assert_eq!(deduplicate_regions(regions).len(), 2);
    }

    #[test]
    fn dedup_keeps_different_text_same_location() {
        // Same box but dissimilar text: not a duplicate.
        let regions = vec![
            region("Hello", 10, 10, 60, 20, 85.0),
            region("World", 10, 10, 60, 20, 70.0),
        ];
        assert_eq!(deduplicate_regions(regions).len(), 2);
    }

    // --- regions_to_text ---

    #[test]
    fn text_reconstruction_same_line() {
        let regions = vec![
            region("Hello", 0, 0, 60, 20, 90.0),
            region("World", 70, 0, 60, 20, 90.0),
        ];
        let text = regions_to_text(&regions);
        assert!(
            text.contains("Hello") && text.contains("World"),
            "got: {text}"
        );
    }

    #[test]
    fn text_reconstruction_newline_between_lines() {
        let regions = vec![
            region("Line1", 0, 0, 60, 20, 90.0),
            region("Line2", 0, 100, 60, 20, 90.0),
        ];
        let text = regions_to_text(&regions);
        assert!(text.contains('\n'), "expected newline, got: {text}");
    }

    #[test]
    fn text_reconstruction_two_column_layout() {
        // Two rows of two columns, supplied right column first.
        let regions = vec![
            region("Right", 300, 0, 60, 20, 90.0),
            region("Side", 300, 30, 60, 20, 90.0),
            region("Left", 0, 0, 60, 20, 90.0),
            region("Text", 0, 30, 60, 20, 90.0),
        ];
        let text = regions_to_text(&regions);
        let pos = |w: &str| {
            text.find(w)
                .unwrap_or_else(|| panic!("'{w}' missing in: {text}"))
        };
        assert!(
            pos("Left") < pos("Right"),
            "row 0: left col should precede right col"
        );
        assert!(
            pos("Text") < pos("Side"),
            "row 1: left col should precede right col"
        );
        assert!(
            pos("Left") < pos("Text"),
            "col 1: top word should precede bottom word"
        );
        assert!(
            pos("Right") < pos("Side"),
            "col 2: top word should precede bottom word"
        );
    }

    #[test]
    fn text_reconstruction_sidebar_layout() {
        // A narrow left sidebar next to a main text line, interleaved in input.
        let regions = vec![
            region("File", 10, 10, 60, 18, 90.0),
            region("The", 200, 10, 40, 18, 90.0),
            region("Edit", 10, 35, 60, 18, 90.0),
            region("quick", 250, 10, 50, 18, 90.0),
            region("View", 10, 60, 60, 18, 90.0),
            region("brown", 310, 10, 55, 18, 90.0),
        ];
        let text = regions_to_text(&regions);
        for word in &["File", "Edit", "View", "The", "quick", "brown"] {
            assert!(text.contains(word), "'{word}' missing in: {text}");
        }
        let pos = |w: &str| text.find(w).unwrap();
        assert!(
            pos("File") < pos("The"),
            "sidebar 'File' should precede main 'The'"
        );
        assert!(
            pos("File") < pos("Edit"),
            "File should precede Edit in sidebar"
        );
        assert!(
            pos("Edit") < pos("View"),
            "Edit should precede View in sidebar"
        );
    }

    #[test]
    fn text_reconstruction_reverse_input_order() {
        // Regions arrive bottom-to-top. Output must still read top-to-bottom.
        let regions = vec![
            region("Third", 0, 80, 60, 20, 90.0),
            region("Second", 0, 40, 60, 20, 90.0),
            region("First", 0, 0, 60, 20, 90.0),
        ];
        let text = regions_to_text(&regions);
        let pos = |w: &str| {
            text.find(w)
                .unwrap_or_else(|| panic!("'{w}' missing in: {text}"))
        };
        assert!(pos("First") < pos("Second"), "First should precede Second");
        assert!(pos("Second") < pos("Third"), "Second should precede Third");
    }
}