Skip to main content

fotos_lib/commands/
ai.rs

1use crate::ai::ocr::OcrOptions;
2use crate::capture::ImageStore;
3use serde::Serialize;
4use std::path::PathBuf;
5use tauri::Emitter;
6use uuid::Uuid;
7
8#[derive(Serialize)]
9pub struct OcrRegion {
10    pub text: String,
11    pub x: u32,
12    pub y: u32,
13    pub w: u32,
14    pub h: u32,
15    pub confidence: f32,
16}
17
18#[derive(Serialize)]
19pub struct OcrResult {
20    pub text: String,
21    pub regions: Vec<OcrRegion>,
22}
23
24#[derive(Serialize)]
25pub struct BlurRegion {
26    pub x: u32,
27    pub y: u32,
28    pub w: u32,
29    pub h: u32,
30    pub pii_type: String,
31}
32
33#[derive(Clone, Serialize)]
34pub struct OcrProgressPayload {
35    pub current: u32,
36    pub total: u32,
37}
38
39#[derive(Serialize)]
40pub struct LlmResponse {
41    pub provider: String,
42    pub model: String,
43    pub response_text: String,
44    pub tokens_used: u32,
45    pub latency_ms: u64,
46}
47
48/// Resolve the tessdata directory path for the given language.
49/// - "eng" uses the bundled tessdata (or the Flatpak-provided path).
50/// - Other languages use the app data directory where traineddata files
51///   are downloaded on demand.
52pub fn resolve_tessdata_path(app: &tauri::AppHandle, lang: &str) -> Result<String, String> {
53    use tauri::Manager;
54
55    // Non-English langs always live in the app data directory.
56    if lang != "eng" {
57        let dir = app
58            .path()
59            .app_data_dir()
60            .map_err(|e| format!("Failed to get app data dir: {e}"))?
61            .join("tessdata");
62        return dir
63            .to_str()
64            .ok_or_else(|| "tessdata path contains invalid UTF-8".to_string())
65            .map(|s| s.to_string());
66    }
67
68    // English: use the bundled tessdata.
69    if std::env::var("FLATPAK_ID").is_ok() {
70        return Ok("/app/share/tessdata".to_string());
71    }
72    let path: PathBuf = app
73        .path()
74        .resource_dir()
75        .map_err(|e| format!("Failed to get resource dir: {e}"))?
76        .join("resources")
77        .join("tessdata");
78    path.to_str()
79        .ok_or_else(|| "tessdata path contains invalid UTF-8".to_string())
80        .map(|s| s.to_string())
81}
82
83/// Return whether the traineddata file for `lang` is available locally.
84#[tauri::command]
85pub fn tessdata_available(app: tauri::AppHandle, lang: String) -> Result<bool, String> {
86    use tauri::Manager;
87    if lang == "eng" {
88        // English is always bundled.
89        return Ok(true);
90    }
91    let path = app
92        .path()
93        .app_data_dir()
94        .map_err(|e| format!("Failed to get app data dir: {e}"))?
95        .join("tessdata")
96        .join(format!("{lang}.traineddata"));
97    Ok(path.exists())
98}
99
100#[derive(Clone, serde::Serialize)]
101pub struct TessdataProgressPayload {
102    pub lang: String,
103    pub downloaded: u64,
104    pub total: u64,
105}
106
107/// Download the tessdata file for `lang` from the Tesseract GitHub release.
108/// No-ops if the file is already present. Emits `tessdata:progress` events
109/// with `{ lang, downloaded, total }` (total=0 when content-length is unknown).
110#[tauri::command]
111pub async fn download_tessdata(app: tauri::AppHandle, lang: String) -> Result<(), String> {
112    use tauri::{Emitter, Manager};
113
114    match lang.as_str() {
115        "fra" | "deu" | "spa" => {}
116        other => return Err(format!("Unsupported tessdata language: {other}")),
117    }
118
119    let tessdata_dir = app
120        .path()
121        .app_data_dir()
122        .map_err(|e| format!("Failed to get app data dir: {e}"))?
123        .join("tessdata");
124
125    std::fs::create_dir_all(&tessdata_dir)
126        .map_err(|e| format!("Failed to create tessdata dir: {e}"))?;
127
128    let dest = tessdata_dir.join(format!("{lang}.traineddata"));
129    if dest.exists() {
130        return Ok(());
131    }
132
133    let url =
134        format!("https://raw.githubusercontent.com/tesseract-ocr/tessdata/main/{lang}.traineddata");
135
136    let client = reqwest::Client::new();
137    let response = client
138        .get(&url)
139        .send()
140        .await
141        .map_err(|e| format!("Download request failed: {e}"))?;
142
143    if !response.status().is_success() {
144        return Err(format!("Download failed: HTTP {}", response.status()));
145    }
146
147    let total = response.content_length().unwrap_or(0);
148
149    // Emit an initial progress event so the UI can show "Downloading…".
150    let _ = app.emit(
151        "tessdata:progress",
152        TessdataProgressPayload {
153            lang: lang.clone(),
154            downloaded: 0,
155            total,
156        },
157    );
158
159    let bytes = response
160        .bytes()
161        .await
162        .map_err(|e| format!("Failed to read download body: {e}"))?;
163
164    let downloaded = bytes.len() as u64;
165    std::fs::write(&dest, &bytes).map_err(|e| format!("Failed to write tessdata: {e}"))?;
166
167    let _ = app.emit(
168        "tessdata:progress",
169        TessdataProgressPayload {
170            lang: lang.clone(),
171            downloaded,
172            total: downloaded,
173        },
174    );
175
176    Ok(())
177}
178
179#[tauri::command]
180pub fn run_ocr(
181    app: tauri::AppHandle,
182    image_id: String,
183    lang: Option<String>,
184    store: tauri::State<'_, ImageStore>,
185) -> Result<OcrResult, String> {
186    let uuid = Uuid::parse_str(&image_id).map_err(|e| format!("Invalid image ID: {e}"))?;
187    let image = store
188        .get(&uuid)
189        .ok_or_else(|| format!("Image not found: {image_id}"))?;
190
191    let lang = lang.unwrap_or_else(|| "eng".to_string());
192    let tessdata_path = resolve_tessdata_path(&app, &lang)?;
193    let opts = OcrOptions {
194        lang,
195        tessdata_path,
196    };
197
198    let progress_app = app.clone();
199    let on_progress = move |current: u32, total: u32| {
200        let _ = progress_app.emit("ocr:progress", OcrProgressPayload { current, total });
201    };
202    let output = crate::ai::ocr::run_ocr(&image, &opts, Some(&on_progress))
203        .map_err(|e| format!("OCR failed: {e}"))?;
204
205    let regions = output
206        .regions
207        .into_iter()
208        .map(|r| OcrRegion {
209            text: r.text,
210            x: r.x,
211            y: r.y,
212            w: r.w,
213            h: r.h,
214            confidence: r.confidence,
215        })
216        .collect();
217
218    Ok(OcrResult {
219        text: output.full_text,
220        regions,
221    })
222}
223
224#[tauri::command]
225pub fn auto_blur_pii(
226    app: tauri::AppHandle,
227    image_id: String,
228    store: tauri::State<'_, ImageStore>,
229) -> Result<Vec<BlurRegion>, String> {
230    let uuid = Uuid::parse_str(&image_id).map_err(|e| format!("Invalid image ID: {e}"))?;
231    let image = store
232        .get(&uuid)
233        .ok_or_else(|| format!("Image not found: {image_id}"))?;
234
235    let tessdata_path = resolve_tessdata_path(&app, "eng")?;
236    let opts = OcrOptions {
237        lang: "eng".to_string(),
238        tessdata_path,
239    };
240
241    let ocr_output =
242        crate::ai::ocr::run_ocr(&image, &opts, None).map_err(|e| format!("OCR failed: {e}"))?;
243
244    let pii_matches = crate::ai::pii::detect_pii(&ocr_output.regions)
245        .map_err(|e| format!("PII detection failed: {e}"))?;
246
247    let blur_regions = pii_matches
248        .into_iter()
249        .map(|m| BlurRegion {
250            x: m.x,
251            y: m.y,
252            w: m.w,
253            h: m.h,
254            pii_type: m.pii_type,
255        })
256        .collect();
257
258    Ok(blur_regions)
259}
260
261#[tauri::command]
262pub async fn analyze_llm(
263    app: tauri::AppHandle,
264    image_id: String,
265    prompt: Option<String>,
266    provider: String,
267    store: tauri::State<'_, ImageStore>,
268) -> Result<LlmResponse, String> {
269    use crate::ai::{compress, llm, openai_compat};
270    use tauri_plugin_store::StoreExt;
271
272    let uuid = Uuid::parse_str(&image_id).map_err(|e| format!("Invalid image ID: {e}"))?;
273    let image = store
274        .get(&uuid)
275        .ok_or_else(|| format!("Image not found: {image_id}"))?;
276
277    let prefs_store = app
278        .store("prefs.json")
279        .map_err(|e| format!("Store error: {e}"))?;
280    let ai_settings: crate::commands::settings::AiSettings = prefs_store
281        .get("ai")
282        .and_then(|v| serde_json::from_value(v).ok())
283        .unwrap_or_default();
284
285    let image_b64 =
286        compress::compress_for_llm(&image, ai_settings.image_max_dim, ai_settings.image_quality)
287            .map_err(|e| format!("Image compression failed: {e}"))?;
288
289    let prompt_text = prompt.unwrap_or_else(|| "Describe this image.".to_string());
290
291    let output = match provider.as_str() {
292        "claude" | "anthropic" => {
293            let api_key = crate::credentials::get_api_key("anthropic")
294                .map_err(|_| "No Anthropic API key configured".to_string())?;
295            let llm_provider = llm::LlmProvider::Claude {
296                model: ai_settings.claude_model.clone(),
297            };
298            llm::analyze(&image_b64, &prompt_text, &llm_provider, &api_key)
299                .await
300                .map_err(|e| e.to_string())?
301        }
302        "gemini" => {
303            let api_key = crate::credentials::get_api_key("gemini")
304                .map_err(|_| "No Gemini API key configured".to_string())?;
305            let llm_provider = llm::LlmProvider::Gemini {
306                model: ai_settings.gemini_model.clone(),
307            };
308            llm::analyze(&image_b64, &prompt_text, &llm_provider, &api_key)
309                .await
310                .map_err(|e| e.to_string())?
311        }
312        s if s.starts_with("endpoint:") => {
313            let id = &s["endpoint:".len()..];
314            let endpoint = ai_settings
315                .endpoints
316                .iter()
317                .find(|e| e.id == id)
318                .ok_or_else(|| format!("Unknown endpoint '{id}'"))?;
319            let api_key = crate::credentials::get_api_key(&provider).unwrap_or_default();
320            openai_compat::analyze(
321                &image_b64,
322                &prompt_text,
323                &endpoint.base_url,
324                &endpoint.model,
325                &api_key,
326            )
327            .await
328            .map_err(|e| e.to_string())?
329        }
330        other => return Err(format!("Unknown provider '{other}'")),
331    };
332
333    Ok(LlmResponse {
334        provider: provider.clone(),
335        model: output.model,
336        response_text: output.response,
337        tokens_used: output.tokens_used,
338        latency_ms: output.latency_ms,
339    })
340}