mirror of
https://github.com/Priler/jarvis.git
synced 2026-06-03 02:49:46 +00:00
New intent classification engine - MiniLM L6v2 and MiniLM L12v2 ONNX
This commit is contained in:
145
Cargo.lock
generated
145
Cargo.lock
generated
@@ -877,19 +877,6 @@ dependencies = [
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "console"
|
||||
version = "0.15.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8"
|
||||
dependencies = [
|
||||
"encode_unicode",
|
||||
"libc",
|
||||
"once_cell",
|
||||
"unicode-width",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "convert_case"
|
||||
version = "0.4.0"
|
||||
@@ -1674,12 +1661,6 @@ version = "1.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4ef6b89e5b37196644d8796de5268852ff179b44e96276cf4290264843743bb7"
|
||||
|
||||
[[package]]
|
||||
name = "encode_unicode"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0"
|
||||
|
||||
[[package]]
|
||||
name = "encoding_rs"
|
||||
version = "0.8.35"
|
||||
@@ -1811,8 +1792,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "59a3f841f27a44bcc32214f8df75cc9b6cea55dbbebbfe546735690eab5bb2d2"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"hf-hub",
|
||||
"image",
|
||||
"ndarray",
|
||||
"ort",
|
||||
"safetensors 0.7.0",
|
||||
@@ -2829,27 +2808,6 @@ version = "0.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
|
||||
|
||||
[[package]]
|
||||
name = "hf-hub"
|
||||
version = "0.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "629d8f3bbeda9d148036d6b0de0a3ab947abd08ce90626327fc3547a49d59d97"
|
||||
dependencies = [
|
||||
"dirs",
|
||||
"http",
|
||||
"indicatif",
|
||||
"libc",
|
||||
"log",
|
||||
"native-tls",
|
||||
"rand 0.9.2",
|
||||
"reqwest 0.12.28",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror 2.0.17",
|
||||
"ureq 2.12.1",
|
||||
"windows-sys 0.60.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hmac-sha256"
|
||||
version = "1.1.13"
|
||||
@@ -2957,22 +2915,6 @@ dependencies = [
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper-tls"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"http-body-util",
|
||||
"hyper",
|
||||
"hyper-util",
|
||||
"native-tls",
|
||||
"tokio",
|
||||
"tokio-native-tls",
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper-util"
|
||||
version = "0.1.19"
|
||||
@@ -3204,19 +3146,6 @@ dependencies = [
|
||||
"serde_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indicatif"
|
||||
version = "0.17.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235"
|
||||
dependencies = [
|
||||
"console",
|
||||
"number_prefix",
|
||||
"portable-atomic",
|
||||
"unicode-width",
|
||||
"web-time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "infer"
|
||||
version = "0.19.0"
|
||||
@@ -4281,12 +4210,6 @@ dependencies = [
|
||||
"syn 2.0.114",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "number_prefix"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3"
|
||||
|
||||
[[package]]
|
||||
name = "objc-sys"
|
||||
version = "0.3.5"
|
||||
@@ -4874,7 +4797,7 @@ dependencies = [
|
||||
"ort-sys",
|
||||
"smallvec",
|
||||
"tracing",
|
||||
"ureq 3.2.0",
|
||||
"ureq",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4885,7 +4808,7 @@ checksum = "06503bb33f294c5f1ba484011e053bfa6ae227074bdb841e9863492dc5960d4b"
|
||||
dependencies = [
|
||||
"hmac-sha256",
|
||||
"lzma-rust2",
|
||||
"ureq 3.2.0",
|
||||
"ureq",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5866,30 +5789,22 @@ checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147"
|
||||
dependencies = [
|
||||
"base64 0.22.1",
|
||||
"bytes",
|
||||
"encoding_rs",
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"h2",
|
||||
"http",
|
||||
"http-body",
|
||||
"http-body-util",
|
||||
"hyper",
|
||||
"hyper-rustls",
|
||||
"hyper-tls",
|
||||
"hyper-util",
|
||||
"js-sys",
|
||||
"log",
|
||||
"mime",
|
||||
"native-tls",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"rustls-pki-types",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_urlencoded",
|
||||
"sync_wrapper",
|
||||
"tokio",
|
||||
"tokio-native-tls",
|
||||
"tokio-util",
|
||||
"tower",
|
||||
"tower-http",
|
||||
@@ -6079,9 +5994,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c665f33d38cea657d9614f766881e4d510e0eda4239891eea56b4cadcf01801b"
|
||||
dependencies = [
|
||||
"aws-lc-rs",
|
||||
"log",
|
||||
"once_cell",
|
||||
"ring",
|
||||
"rustls-pki-types",
|
||||
"rustls-webpki",
|
||||
"subtle",
|
||||
@@ -7801,16 +7714,6 @@ dependencies = [
|
||||
"syn 2.0.114",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-native-tls"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2"
|
||||
dependencies = [
|
||||
"native-tls",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-rustls"
|
||||
version = "0.26.4"
|
||||
@@ -8210,12 +8113,6 @@ version = "1.12.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-width"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-xid"
|
||||
version = "0.2.6"
|
||||
@@ -8240,26 +8137,6 @@ version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
|
||||
|
||||
[[package]]
|
||||
name = "ureq"
|
||||
version = "2.12.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "02d1a66277ed75f640d608235660df48c8e3c19f3b4edb6a263315626cc3c01d"
|
||||
dependencies = [
|
||||
"base64 0.22.1",
|
||||
"flate2",
|
||||
"log",
|
||||
"native-tls",
|
||||
"once_cell",
|
||||
"rustls",
|
||||
"rustls-pki-types",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"socks",
|
||||
"url",
|
||||
"webpki-roots 0.26.11",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ureq"
|
||||
version = "3.2.0"
|
||||
@@ -8706,24 +8583,6 @@ dependencies = [
|
||||
"rustls-pki-types",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webpki-roots"
|
||||
version = "0.26.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9"
|
||||
dependencies = [
|
||||
"webpki-roots 1.0.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webpki-roots"
|
||||
version = "1.0.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "22cfaf3c063993ff62e73cb4311efde4db1efb31ab78a3e5c457939ad5cc0bed"
|
||||
dependencies = [
|
||||
"rustls-pki-types",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webview2-com"
|
||||
version = "0.38.2"
|
||||
|
||||
@@ -45,4 +45,4 @@ mlua = { version = "0.11.5", features = ["lua55", "vendored", "async", "serde"]
|
||||
reqwest = { version = "0.13.1", features = ["blocking", "json"] }
|
||||
tempfile = "^3.24"
|
||||
winrt-notification = "0.5"
|
||||
fastembed = "^5.8.1"
|
||||
fastembed = { version = "^5.8.1", default-features = false, features = ["ort-download-binaries"] }
|
||||
@@ -156,6 +156,8 @@ pub const VOSK_SPEECH_PARTIAL_WORDS: bool = false;
|
||||
// IRE (intents recognition)
|
||||
pub const INTENT_CLASSIFIER_MIN_CONFIDENCE: f64 = 0.75;
|
||||
|
||||
// embedding classifier
|
||||
pub const EMBEDDING_MIN_CONFIDENCE: f64 = 0.60;
|
||||
|
||||
// AUDIO PROCESSING DEFAULTS
|
||||
pub const DEFAULT_NOISE_SUPPRESSION: NoiseSuppressionBackend = NoiseSuppressionBackend::None;
|
||||
@@ -180,7 +182,7 @@ pub const DEFAULT_LUA_SANDBOX: &str = "standard";
|
||||
pub const DEFAULT_LUA_TIMEOUT: u64 = 10000; // ms
|
||||
|
||||
// ETC
|
||||
pub const CMD_RATIO_THRESHOLD: f64 = 65f64;
|
||||
pub const CMD_RATIO_THRESHOLD: f64 = 75f64;
|
||||
pub const CMS_WAIT_DELAY: std::time::Duration = std::time::Duration::from_secs(15);
|
||||
|
||||
// pub const ASSISTANT_GREET_PHRASES: [&str; 3] = ["greet1", "greet2", "greet3"];
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
mod intentclassifier;
|
||||
mod embeddingclassifier;
|
||||
|
||||
use std::path::PathBuf;
|
||||
|
||||
@@ -23,7 +24,11 @@ pub async fn init(commands: &Vec<JCommandsList>) -> Result<(), String> {
|
||||
intentclassifier::init(&commands).await?;
|
||||
info!("IRE backend initialized.");
|
||||
},
|
||||
IntentRecognitionEngine::Rasa => todo!(),
|
||||
IntentRecognitionEngine::EmbeddingClassifier => {
|
||||
info!("Initializing EmbeddingClassifier IRE backend.");
|
||||
embeddingclassifier::init(&commands)?;
|
||||
info!("EmbeddingClassifier IRE backend initialized.");
|
||||
},
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -47,7 +52,21 @@ pub async fn classify(text: &str) -> Option<(String, f64)> {
|
||||
}
|
||||
}
|
||||
}
|
||||
IntentRecognitionEngine::Rasa => todo!(),
|
||||
IntentRecognitionEngine::EmbeddingClassifier => {
|
||||
match embeddingclassifier::classify(text) {
|
||||
Ok((intent_id, confidence)) => {
|
||||
if confidence >= config::EMBEDDING_MIN_CONFIDENCE {
|
||||
Some((intent_id, confidence))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Embedding classification error: {}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,6 +75,8 @@ pub fn get_command_by_intent(commands: &'static Vec<JCommandsList>, intent_id: &
|
||||
IntentRecognitionEngine::IntentClassifier => {
|
||||
intentclassifier::get_command(commands, intent_id)
|
||||
}
|
||||
IntentRecognitionEngine::Rasa => todo!(),
|
||||
IntentRecognitionEngine::EmbeddingClassifier => {
|
||||
embeddingclassifier::get_command(commands, intent_id)
|
||||
}
|
||||
}
|
||||
}
|
||||
240
crates/jarvis-core/src/intent/embeddingclassifier.rs
Normal file
240
crates/jarvis-core/src/intent/embeddingclassifier.rs
Normal file
@@ -0,0 +1,240 @@
|
||||
use parking_lot::Mutex;
|
||||
use std::path::PathBuf;
|
||||
use std::fs;
|
||||
|
||||
// use fastembed::{TextEmbedding, InitOptions, EmbeddingModel};
|
||||
use fastembed::{TextEmbedding, UserDefinedEmbeddingModel, TokenizerFiles, InitOptionsUserDefined, Pooling, QuantizationMode, OutputKey};
|
||||
use once_cell::sync::OnceCell;
|
||||
|
||||
use crate::commands::JCommandsList;
|
||||
use crate::i18n::get_language;
|
||||
use crate::{APP_CONFIG_DIR, APP_DIR, i18n};
|
||||
|
||||
static CLASSIFIER: OnceCell<Mutex<EmbeddingClassifier>> = OnceCell::new();
|
||||
|
||||
struct IntentVector {
|
||||
id: String,
|
||||
vector: Vec<f32>,
|
||||
}
|
||||
|
||||
struct EmbeddingClassifier {
|
||||
model: TextEmbedding,
|
||||
intents: Vec<IntentVector>,
|
||||
}
|
||||
|
||||
const CACHE_FILE: &str = "embedding_intents.json";
|
||||
const HASH_FILE: &str = "embedding_hash.txt";
|
||||
|
||||
pub fn init(commands: &[JCommandsList]) -> Result<(), String> {
|
||||
if CLASSIFIER.get().is_some() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
info!("Initializing embedding model...");
|
||||
|
||||
// let mut model = TextEmbedding::try_new(
|
||||
// InitOptions::new(EmbeddingModel::AllMiniLML6V2).with_show_download_progress(true),
|
||||
// ).map_err(|e| format!("Failed to load embedding model: {}", e))?;
|
||||
|
||||
let model_dir;
|
||||
match i18n::get_language().as_str() {
|
||||
"en" => {
|
||||
// smaller model for English
|
||||
model_dir = APP_DIR.join("resources").join("models").join("all-MiniLM-L6-v2");
|
||||
},
|
||||
_ => {
|
||||
// bigger model for any other languages (multilingual)
|
||||
model_dir = APP_DIR.join("resources").join("models").join("paraphrase-multilingual-MiniLM-L12-v2-onnx-Q");
|
||||
}
|
||||
}
|
||||
|
||||
let user_model = UserDefinedEmbeddingModel {
|
||||
onnx_file: std::fs::read(model_dir.join("model.onnx"))
|
||||
.map_err(|e| format!("Failed to read model.onnx: {}", e))?,
|
||||
tokenizer_files: TokenizerFiles {
|
||||
tokenizer_file: std::fs::read(model_dir.join("tokenizer.json"))
|
||||
.map_err(|e| format!("Failed to read tokenizer.json: {}", e))?,
|
||||
config_file: std::fs::read(model_dir.join("config.json"))
|
||||
.map_err(|e| format!("Failed to read config.json: {}", e))?,
|
||||
special_tokens_map_file: std::fs::read(model_dir.join("special_tokens_map.json"))
|
||||
.map_err(|e| format!("Failed to read special_tokens_map.json: {}", e))?,
|
||||
tokenizer_config_file: std::fs::read(model_dir.join("tokenizer_config.json"))
|
||||
.map_err(|e| format!("Failed to read tokenizer_config.json: {}", e))?,
|
||||
},
|
||||
pooling: Some(Pooling::Mean),
|
||||
quantization: QuantizationMode::None,
|
||||
output_key: Some(OutputKey::ByName("last_hidden_state")),
|
||||
};
|
||||
|
||||
let mut model = TextEmbedding::try_new_from_user_defined(user_model, Default::default())
|
||||
.map_err(|e| format!("Failed to load embedding model: {}", e))?;
|
||||
|
||||
info!("Embedding model loaded");
|
||||
|
||||
let current_hash = crate::commands::commands_hash(commands);
|
||||
let config_dir = APP_CONFIG_DIR.get().ok_or("Config dir not set")?;
|
||||
let hash_path = config_dir.join(HASH_FILE);
|
||||
let cache_path = config_dir.join(CACHE_FILE);
|
||||
|
||||
// check if cached vectors are still valid
|
||||
let should_retrain = if hash_path.exists() && cache_path.exists() {
|
||||
let stored_hash = fs::read_to_string(&hash_path).unwrap_or_default();
|
||||
stored_hash.trim() != current_hash
|
||||
} else {
|
||||
true
|
||||
};
|
||||
|
||||
let intents = if should_retrain {
|
||||
info!("Building intent vectors from commands...");
|
||||
let intents = build_intent_vectors(&mut model, commands)?;
|
||||
|
||||
// cache to disk
|
||||
if let Ok(json) = serde_json::to_string(&intents_to_cache(&intents)) {
|
||||
let _ = fs::write(&cache_path, json);
|
||||
let _ = fs::write(&hash_path, ¤t_hash);
|
||||
info!("Intent vectors cached");
|
||||
}
|
||||
|
||||
intents
|
||||
} else {
|
||||
info!("Loading cached intent vectors...");
|
||||
load_cached_intents(&cache_path)?
|
||||
};
|
||||
|
||||
info!("Embedding classifier ready with {} intents", intents.len());
|
||||
|
||||
CLASSIFIER.set(Mutex::new(EmbeddingClassifier { model, intents }))
|
||||
.map_err(|_| "Classifier already set")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn build_intent_vectors(
|
||||
model: &mut TextEmbedding,
|
||||
commands: &[JCommandsList],
|
||||
) -> Result<Vec<IntentVector>, String> {
|
||||
let lang = i18n::get_language();
|
||||
let mut intents = Vec::new();
|
||||
|
||||
for cmd_list in commands {
|
||||
for cmd in &cmd_list.commands {
|
||||
let phrases = cmd.get_phrases(&lang);
|
||||
if phrases.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let texts: Vec<&str> = phrases.iter().map(|s| s.as_str()).collect();
|
||||
|
||||
let embeddings = model.embed(texts, None)
|
||||
.map_err(|e| format!("Embedding failed for '{}': {}", cmd.id, e))?;
|
||||
|
||||
// average all phrase vectors into one intent vector
|
||||
let dim = embeddings[0].len();
|
||||
let mut avg = vec![0.0f32; dim];
|
||||
|
||||
for emb in &embeddings {
|
||||
for (i, val) in emb.iter().enumerate() {
|
||||
avg[i] += val;
|
||||
}
|
||||
}
|
||||
|
||||
let count = embeddings.len() as f32;
|
||||
for val in &mut avg {
|
||||
*val /= count;
|
||||
}
|
||||
|
||||
// normalize
|
||||
let norm: f32 = avg.iter().map(|v| v * v).sum::<f32>().sqrt();
|
||||
if norm > 0.0 {
|
||||
for val in &mut avg {
|
||||
*val /= norm;
|
||||
}
|
||||
}
|
||||
|
||||
intents.push(IntentVector {
|
||||
id: cmd.id.clone(),
|
||||
vector: avg,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(intents)
|
||||
}
|
||||
|
||||
pub fn classify(text: &str) -> Result<(String, f64), String> {
|
||||
let mut classifier = CLASSIFIER.get().ok_or("Classifier not initialized")?.lock();
|
||||
|
||||
let embeddings = classifier.model.embed(vec![text], None)
|
||||
.map_err(|e| format!("Failed to embed query: {}", e))?;
|
||||
|
||||
let mut query_vec = embeddings.into_iter().next()
|
||||
.ok_or("Empty embedding result")?;
|
||||
|
||||
// normalize query
|
||||
let norm: f32 = query_vec.iter().map(|v| v * v).sum::<f32>().sqrt();
|
||||
if norm > 0.0 {
|
||||
for val in &mut query_vec {
|
||||
*val /= norm;
|
||||
}
|
||||
}
|
||||
|
||||
// cosine similarity against all intents (dot product of normalized vectors)
|
||||
let mut best_id = String::new();
|
||||
let mut best_score: f64 = -1.0;
|
||||
|
||||
for intent in &classifier.intents {
|
||||
let score: f64 = query_vec.iter()
|
||||
.zip(intent.vector.iter())
|
||||
.map(|(a, b)| (*a as f64) * (*b as f64))
|
||||
.sum();
|
||||
|
||||
if score > best_score {
|
||||
best_score = score;
|
||||
best_id = intent.id.clone();
|
||||
}
|
||||
}
|
||||
|
||||
Ok((best_id, best_score))
|
||||
}
|
||||
|
||||
pub fn get_command<'a>(
|
||||
commands: &'a [JCommandsList],
|
||||
intent_id: &str,
|
||||
) -> Option<(&'a PathBuf, &'a crate::commands::JCommand)> {
|
||||
for cmd_list in commands {
|
||||
for cmd in &cmd_list.commands {
|
||||
if cmd.id == intent_id {
|
||||
return Some((&cmd_list.path, cmd));
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
// ### CACHE HELPERS
|
||||
|
||||
#[derive(serde::Serialize, serde::Deserialize)]
|
||||
struct CachedIntent {
|
||||
id: String,
|
||||
vector: Vec<f32>,
|
||||
}
|
||||
|
||||
fn intents_to_cache(intents: &[IntentVector]) -> Vec<CachedIntent> {
|
||||
intents.iter().map(|i| CachedIntent {
|
||||
id: i.id.clone(),
|
||||
vector: i.vector.clone(),
|
||||
}).collect()
|
||||
}
|
||||
|
||||
fn load_cached_intents(path: &PathBuf) -> Result<Vec<IntentVector>, String> {
|
||||
let json = fs::read_to_string(path)
|
||||
.map_err(|e| format!("Failed to read cache: {}", e))?;
|
||||
|
||||
let cached: Vec<CachedIntent> = serde_json::from_str(&json)
|
||||
.map_err(|e| format!("Failed to parse cache: {}", e))?;
|
||||
|
||||
Ok(cached.into_iter().map(|c| IntentVector {
|
||||
id: c.id,
|
||||
vector: c.vector,
|
||||
}).collect())
|
||||
}
|
||||
@@ -50,7 +50,7 @@ pub fn db_write(state: tauri::State<'_, AppState>, key: &str, val: &str) -> bool
|
||||
"selected_intent_recognition_engine" => {
|
||||
match val.to_lowercase().as_str() {
|
||||
"intentclassifier" => settings.intent_recognition_engine = jarvis_core::config::structs::IntentRecognitionEngine::IntentClassifier,
|
||||
"rasa" => settings.intent_recognition_engine = jarvis_core::config::structs::IntentRecognitionEngine::Rasa,
|
||||
"embeddingclassifier" => settings.intent_recognition_engine = jarvis_core::config::structs::IntentRecognitionEngine::EmbeddingClassifier,
|
||||
_ => return false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -359,7 +359,7 @@
|
||||
<NativeSelect
|
||||
data={[
|
||||
{ label: "Intent Classifier", value: "IntentClassifier" },
|
||||
{ label: "Rasa", value: "Rasa" }
|
||||
{ label: "Embedding Classifier", value: "EmbeddingClassifier" }
|
||||
]}
|
||||
label={t('settings-intent-engine')}
|
||||
description={t('settings-intent-engine-desc')}
|
||||
|
||||
@@ -17,6 +17,7 @@ SOURCE = (
|
||||
("resources/keywords/", "resources/keywords/"),
|
||||
("resources/rustpotter/", "resources/rustpotter/"),
|
||||
("resources/sound/", "resources/sound/"),
|
||||
("resources/models/", "resources/models/"),
|
||||
|
||||
# vosk
|
||||
("lib/windows/amd64/libgcc_s_seh-1.dll", None),
|
||||
|
||||
35
resources/models/all-MiniLM-L6-v2/.gitattributes
vendored
Normal file
35
resources/models/all-MiniLM-L6-v2/.gitattributes
vendored
Normal file
@@ -0,0 +1,35 @@
|
||||
*.7z filter=lfs diff=lfs merge=lfs -text
|
||||
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||
*.bin filter=lfs diff=lfs merge=lfs -text
|
||||
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
||||
*.ftz filter=lfs diff=lfs merge=lfs -text
|
||||
*.gz filter=lfs diff=lfs merge=lfs -text
|
||||
*.h5 filter=lfs diff=lfs merge=lfs -text
|
||||
*.joblib filter=lfs diff=lfs merge=lfs -text
|
||||
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
||||
*.model filter=lfs diff=lfs merge=lfs -text
|
||||
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
||||
*.npy filter=lfs diff=lfs merge=lfs -text
|
||||
*.npz filter=lfs diff=lfs merge=lfs -text
|
||||
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||
*.ot filter=lfs diff=lfs merge=lfs -text
|
||||
*.parquet filter=lfs diff=lfs merge=lfs -text
|
||||
*.pb filter=lfs diff=lfs merge=lfs -text
|
||||
*.pickle filter=lfs diff=lfs merge=lfs -text
|
||||
*.pkl filter=lfs diff=lfs merge=lfs -text
|
||||
*.pt filter=lfs diff=lfs merge=lfs -text
|
||||
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||
*.rar filter=lfs diff=lfs merge=lfs -text
|
||||
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar filter=lfs diff=lfs merge=lfs -text
|
||||
*.tflite filter=lfs diff=lfs merge=lfs -text
|
||||
*.tgz filter=lfs diff=lfs merge=lfs -text
|
||||
*.wasm filter=lfs diff=lfs merge=lfs -text
|
||||
*.xz filter=lfs diff=lfs merge=lfs -text
|
||||
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||
*.zst filter=lfs diff=lfs merge=lfs -text
|
||||
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||
35
resources/models/all-MiniLM-L6-v2/README.md
Normal file
35
resources/models/all-MiniLM-L6-v2/README.md
Normal file
@@ -0,0 +1,35 @@
|
||||
---
|
||||
license: apache-2.0
|
||||
pipeline_tag: sentence-similarity
|
||||
---
|
||||
|
||||
ONNX port of [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) for text classification and similarity searches.
|
||||
|
||||
### Usage
|
||||
|
||||
Here's an example of performing inference using the model with [FastEmbed](https://github.com/qdrant/fastembed).
|
||||
|
||||
```py
|
||||
from fastembed import TextEmbedding
|
||||
|
||||
documents = [
|
||||
"You should stay, study and sprint.",
|
||||
"History can only prepare us to be surprised yet again.",
|
||||
]
|
||||
|
||||
model = TextEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
||||
embeddings = list(model.embed(documents))
|
||||
|
||||
# [
|
||||
# array([
|
||||
# 0.00611658, 0.00068912, -0.0203846, ..., -0.01751488, -0.01174267,
|
||||
# 0.01463472
|
||||
# ],
|
||||
# dtype=float32),
|
||||
# array([
|
||||
# 0.00173448, -0.00329958, 0.01557874, ..., -0.01473586, 0.0281806,
|
||||
# -0.00448205
|
||||
# ],
|
||||
# dtype=float32)
|
||||
# ]
|
||||
```
|
||||
25
resources/models/all-MiniLM-L6-v2/config.json
Normal file
25
resources/models/all-MiniLM-L6-v2/config.json
Normal file
@@ -0,0 +1,25 @@
|
||||
{
|
||||
"_name_or_path": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
"architectures": [
|
||||
"BertModel"
|
||||
],
|
||||
"attention_probs_dropout_prob": 0.1,
|
||||
"classifier_dropout": null,
|
||||
"gradient_checkpointing": false,
|
||||
"hidden_act": "gelu",
|
||||
"hidden_dropout_prob": 0.1,
|
||||
"hidden_size": 384,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 1536,
|
||||
"layer_norm_eps": 1e-12,
|
||||
"max_position_embeddings": 512,
|
||||
"model_type": "bert",
|
||||
"num_attention_heads": 12,
|
||||
"num_hidden_layers": 6,
|
||||
"pad_token_id": 0,
|
||||
"position_embedding_type": "absolute",
|
||||
"transformers_version": "4.36.2",
|
||||
"type_vocab_size": 2,
|
||||
"use_cache": true,
|
||||
"vocab_size": 30522
|
||||
}
|
||||
3
resources/models/all-MiniLM-L6-v2/model.onnx
Normal file
3
resources/models/all-MiniLM-L6-v2/model.onnx
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:bbd7b466f6d58e646fdc2bd5fd67b2f5e93c0b687011bd4548c420f7bd46f0c5
|
||||
size 90387630
|
||||
37
resources/models/all-MiniLM-L6-v2/special_tokens_map.json
Normal file
37
resources/models/all-MiniLM-L6-v2/special_tokens_map.json
Normal file
@@ -0,0 +1,37 @@
|
||||
{
|
||||
"cls_token": {
|
||||
"content": "[CLS]",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"mask_token": {
|
||||
"content": "[MASK]",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"pad_token": {
|
||||
"content": "[PAD]",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"sep_token": {
|
||||
"content": "[SEP]",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"unk_token": {
|
||||
"content": "[UNK]",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
||||
30686
resources/models/all-MiniLM-L6-v2/tokenizer.json
Normal file
30686
resources/models/all-MiniLM-L6-v2/tokenizer.json
Normal file
File diff suppressed because it is too large
Load Diff
64
resources/models/all-MiniLM-L6-v2/tokenizer_config.json
Normal file
64
resources/models/all-MiniLM-L6-v2/tokenizer_config.json
Normal file
@@ -0,0 +1,64 @@
|
||||
{
|
||||
"added_tokens_decoder": {
|
||||
"0": {
|
||||
"content": "[PAD]",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"100": {
|
||||
"content": "[UNK]",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"101": {
|
||||
"content": "[CLS]",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"102": {
|
||||
"content": "[SEP]",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"103": {
|
||||
"content": "[MASK]",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
}
|
||||
},
|
||||
"clean_up_tokenization_spaces": true,
|
||||
"cls_token": "[CLS]",
|
||||
"do_basic_tokenize": true,
|
||||
"do_lower_case": true,
|
||||
"mask_token": "[MASK]",
|
||||
"max_length": 128,
|
||||
"model_max_length": 512,
|
||||
"never_split": null,
|
||||
"pad_to_multiple_of": null,
|
||||
"pad_token": "[PAD]",
|
||||
"pad_token_type_id": 0,
|
||||
"padding_side": "right",
|
||||
"sep_token": "[SEP]",
|
||||
"stride": 0,
|
||||
"strip_accents": null,
|
||||
"tokenize_chinese_chars": true,
|
||||
"tokenizer_class": "BertTokenizer",
|
||||
"truncation_side": "right",
|
||||
"truncation_strategy": "longest_first",
|
||||
"unk_token": "[UNK]"
|
||||
}
|
||||
30522
resources/models/all-MiniLM-L6-v2/vocab.txt
Normal file
30522
resources/models/all-MiniLM-L6-v2/vocab.txt
Normal file
File diff suppressed because it is too large
Load Diff
37
resources/models/paraphrase-multilingual-MiniLM-L12-v2-onnx-Q/.gitattributes
vendored
Normal file
37
resources/models/paraphrase-multilingual-MiniLM-L12-v2-onnx-Q/.gitattributes
vendored
Normal file
@@ -0,0 +1,37 @@
|
||||
*.7z filter=lfs diff=lfs merge=lfs -text
|
||||
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||
*.bin filter=lfs diff=lfs merge=lfs -text
|
||||
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
||||
*.ftz filter=lfs diff=lfs merge=lfs -text
|
||||
*.gz filter=lfs diff=lfs merge=lfs -text
|
||||
*.h5 filter=lfs diff=lfs merge=lfs -text
|
||||
*.joblib filter=lfs diff=lfs merge=lfs -text
|
||||
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
||||
*.model filter=lfs diff=lfs merge=lfs -text
|
||||
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
||||
*.npy filter=lfs diff=lfs merge=lfs -text
|
||||
*.npz filter=lfs diff=lfs merge=lfs -text
|
||||
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||
*.ot filter=lfs diff=lfs merge=lfs -text
|
||||
*.parquet filter=lfs diff=lfs merge=lfs -text
|
||||
*.pb filter=lfs diff=lfs merge=lfs -text
|
||||
*.pickle filter=lfs diff=lfs merge=lfs -text
|
||||
*.pkl filter=lfs diff=lfs merge=lfs -text
|
||||
*.pt filter=lfs diff=lfs merge=lfs -text
|
||||
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||
*.rar filter=lfs diff=lfs merge=lfs -text
|
||||
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar filter=lfs diff=lfs merge=lfs -text
|
||||
*.tflite filter=lfs diff=lfs merge=lfs -text
|
||||
*.tgz filter=lfs diff=lfs merge=lfs -text
|
||||
*.wasm filter=lfs diff=lfs merge=lfs -text
|
||||
*.xz filter=lfs diff=lfs merge=lfs -text
|
||||
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||
*.zst filter=lfs diff=lfs merge=lfs -text
|
||||
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
||||
unigram.json filter=lfs diff=lfs merge=lfs -text
|
||||
@@ -0,0 +1,27 @@
|
||||
---
|
||||
license: apache-2.0
|
||||
pipeline_tag: sentence-similarity
|
||||
---
|
||||
|
||||
Quantized ONNX port of [sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2](https://huggingface.co/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2) for text classification and similarity searches.
|
||||
|
||||
### Usage
|
||||
|
||||
Here's an example of performing inference using the model with [FastEmbed](https://github.com/qdrant/fastembed).
|
||||
|
||||
```py
|
||||
from fastembed import TextEmbedding
|
||||
|
||||
documents = [
|
||||
"You should stay, study and sprint.",
|
||||
"History can only prepare us to be surprised yet again.",
|
||||
]
|
||||
|
||||
model = TextEmbedding(model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
|
||||
embeddings = list(model.embed(documents))
|
||||
|
||||
# [
|
||||
# array([1.96449570e-02, 1.60677675e-02, 4.10149433e-02...]),
|
||||
# array([-1.56669170e-02, -1.66313536e-02, -6.84525725e-03...])
|
||||
# ]
|
||||
```
|
||||
@@ -0,0 +1,25 @@
|
||||
{
|
||||
"_name_or_path": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
|
||||
"architectures": [
|
||||
"BertModel"
|
||||
],
|
||||
"attention_probs_dropout_prob": 0.1,
|
||||
"classifier_dropout": null,
|
||||
"gradient_checkpointing": false,
|
||||
"hidden_act": "gelu",
|
||||
"hidden_dropout_prob": 0.1,
|
||||
"hidden_size": 384,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 1536,
|
||||
"layer_norm_eps": 1e-12,
|
||||
"max_position_embeddings": 512,
|
||||
"model_type": "bert",
|
||||
"num_attention_heads": 12,
|
||||
"num_hidden_layers": 12,
|
||||
"pad_token_id": 0,
|
||||
"position_embedding_type": "absolute",
|
||||
"transformers_version": "4.36.2",
|
||||
"type_vocab_size": 2,
|
||||
"use_cache": true,
|
||||
"vocab_size": 250037
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:634d0f66c29dc934c8fa72b8a4fe91dd4d420a22f1d82a241058d4316e659a99
|
||||
size 235052644
|
||||
@@ -0,0 +1,39 @@
|
||||
{
|
||||
"one_external_file": true,
|
||||
"opset": null,
|
||||
"optimization": {
|
||||
"disable_attention": null,
|
||||
"disable_attention_fusion": false,
|
||||
"disable_bias_gelu": null,
|
||||
"disable_bias_gelu_fusion": false,
|
||||
"disable_bias_skip_layer_norm": null,
|
||||
"disable_bias_skip_layer_norm_fusion": false,
|
||||
"disable_embed_layer_norm": true,
|
||||
"disable_embed_layer_norm_fusion": true,
|
||||
"disable_gelu": null,
|
||||
"disable_gelu_fusion": false,
|
||||
"disable_group_norm_fusion": true,
|
||||
"disable_layer_norm": null,
|
||||
"disable_layer_norm_fusion": false,
|
||||
"disable_packed_kv": true,
|
||||
"disable_rotary_embeddings": false,
|
||||
"disable_shape_inference": false,
|
||||
"disable_skip_layer_norm": null,
|
||||
"disable_skip_layer_norm_fusion": false,
|
||||
"enable_gelu_approximation": true,
|
||||
"enable_gemm_fast_gelu_fusion": false,
|
||||
"enable_transformers_specific_optimizations": true,
|
||||
"fp16": true,
|
||||
"no_attention_mask": false,
|
||||
"optimization_level": 2,
|
||||
"optimize_for_gpu": true,
|
||||
"optimize_with_onnxruntime_only": null,
|
||||
"use_mask_index": false,
|
||||
"use_multi_head_attention": false,
|
||||
"use_raw_attention_mask": false
|
||||
},
|
||||
"optimum_version": "1.15.0",
|
||||
"quantization": {},
|
||||
"transformers_version": "4.36.2",
|
||||
"use_external_data_format": false
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
{
|
||||
"bos_token": {
|
||||
"content": "<s>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"cls_token": {
|
||||
"content": "<s>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"eos_token": {
|
||||
"content": "</s>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"mask_token": {
|
||||
"content": "<mask>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"pad_token": {
|
||||
"content": "<pad>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"sep_token": {
|
||||
"content": "</s>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"unk_token": {
|
||||
"content": "<unk>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:fa685fc160bbdbab64058d4fc91b60e62d207e8dc60b9af5c002c5ab946ded00
|
||||
size 17083009
|
||||
@@ -0,0 +1,64 @@
|
||||
{
|
||||
"added_tokens_decoder": {
|
||||
"0": {
|
||||
"content": "<s>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"1": {
|
||||
"content": "<pad>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"2": {
|
||||
"content": "</s>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"3": {
|
||||
"content": "<unk>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"250001": {
|
||||
"content": "<mask>",
|
||||
"lstrip": true,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
}
|
||||
},
|
||||
"bos_token": "<s>",
|
||||
"clean_up_tokenization_spaces": true,
|
||||
"cls_token": "<s>",
|
||||
"do_lower_case": true,
|
||||
"eos_token": "</s>",
|
||||
"mask_token": "<mask>",
|
||||
"max_length": 128,
|
||||
"model_max_length": 512,
|
||||
"pad_to_multiple_of": null,
|
||||
"pad_token": "<pad>",
|
||||
"pad_token_type_id": 0,
|
||||
"padding_side": "right",
|
||||
"sep_token": "</s>",
|
||||
"stride": 0,
|
||||
"strip_accents": null,
|
||||
"tokenize_chinese_chars": true,
|
||||
"tokenizer_class": "BertTokenizer",
|
||||
"truncation_side": "right",
|
||||
"truncation_strategy": "longest_first",
|
||||
"unk_token": "<unk>"
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:da145b5e7700ae40f16691ec32a0b1fdc1ee3298db22a31ea55f57a966c4a65d
|
||||
size 14763260
|
||||
Reference in New Issue
Block a user