diff --git a/Cargo.lock b/Cargo.lock index 121e9a28dd..1ff9de0d5a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -550,7 +550,7 @@ dependencies = [ "libc", "pin-project", "redox_syscall 0.2.16", - "xattr", + "xattr 0.2.3", ] [[package]] @@ -2008,6 +2008,12 @@ dependencies = [ "cfg-if 1.0.0", ] +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + [[package]] name = "crypto-common" version = "0.1.6" @@ -3206,6 +3212,16 @@ dependencies = [ "tracing", ] +[[package]] +name = "half" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc52e53916c08643f1b56ec082790d1e86a32e58dc5268f897f313fbae7b4872" +dependencies = [ + "cfg-if 1.0.0", + "crunchy", +] + [[package]] name = "hashbrown" version = "0.11.2" @@ -4513,6 +4529,19 @@ dependencies = [ "tempfile", ] +[[package]] +name = "ndarray" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" +dependencies = [ + "matrixmultiply", + "num-complex 0.4.4", + "num-integer", + "num-traits", + "rawpointer", +] + [[package]] name = "ndk" version = "0.7.0" @@ -4640,7 +4669,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b8536030f9fea7127f841b45bb6243b27255787fb4eb83958aa1ef9d2fdc0c36" dependencies = [ "num-bigint 0.2.6", - "num-complex", + "num-complex 0.2.4", "num-integer", "num-iter", "num-rational 0.2.4", @@ -4696,6 +4725,15 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-complex" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214" +dependencies = [ + "num-traits", +] + [[package]] name = "num-derive" version = "0.3.3" @@ -4935,6 +4973,26 @@ dependencies = [ "num-traits", ] +[[package]] +name = "ort" +version = "1.15.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5e56c9c4185ee949ef961aca8777d1dbd52cb104b444669adad63e8181820a7" +dependencies = [ + "flate2", + "half", + "lazy_static", + "libc", + "ndarray", + "tar", + "thiserror", + "tracing", + "ureq", + "vswhom", + "winapi 0.3.9", + "zip", +] + [[package]] name = "os_str_bytes" version = "6.5.1" @@ -6404,6 +6462,18 @@ dependencies = [ "webpki 0.22.0", ] +[[package]] +name = "rustls" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd8d6c9f025a446bc4d18ad9632e69aec8f287aa84499ee335599fabd20c3fd8" +dependencies = [ + "log", + "ring", + "rustls-webpki 0.101.4", + "sct 0.7.0", +] + [[package]] name = "rustls-pemfile" version = "1.0.3" @@ -6413,6 +6483,26 @@ dependencies = [ "base64 0.21.2", ] +[[package]] +name = "rustls-webpki" +version = "0.100.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e98ff011474fa39949b7e5c0428f9b4937eda7da7848bbb947786b7be0b27dab" +dependencies = [ + "ring", + "untrusted", +] + +[[package]] +name = "rustls-webpki" +version = "0.101.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d93931baf2d282fff8d3a532bbfd7653f734643161b87e3e01e59a04439bf0d" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "rustversion" version = "1.0.14" @@ -6734,6 +6824,7 @@ dependencies = [ "lazy_static", "log", "matrixmultiply", + "ort", "parking_lot 0.11.2", "parse_duration", "picker", @@ -7586,6 +7677,17 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" +[[package]] +name = "tar" +version = "0.4.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b16afcea1f22891c49a00c751c7b63b2233284064f11a200fc624137c51e2ddb" +dependencies = [ + "filetime", + "libc", + "xattr 1.0.1", +] + [[package]] name = "target-lexicon" version = "0.12.11" @@ -8657,6 +8759,21 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" +[[package]] +name = "ureq" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b11c96ac7ee530603dcdf68ed1557050f374ce55a5a07193ebf8cbc9f8927e9" +dependencies = [ + "base64 0.21.2", + "log", + "once_cell", + "rustls 0.21.7", + "rustls-webpki 0.100.2", + "url", + "webpki-roots 0.23.1", +] + [[package]] name = "url" version = "2.4.0" @@ -8849,6 +8966,26 @@ dependencies = [ "workspace", ] +[[package]] +name = "vswhom" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be979b7f07507105799e854203b470ff7c78a1639e330a58f183b5fea574608b" +dependencies = [ + "libc", + "vswhom-sys", +] + +[[package]] +name = "vswhom-sys" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3b17ae1f6c8a2b28506cd96d412eebf83b4a0ff2cbefeeb952f2f9dfa44ba18" +dependencies = [ + "cc", + "libc", +] + [[package]] name = "vte" version = "0.11.1" @@ -9315,6 +9452,15 @@ dependencies = [ "webpki 0.22.0", ] +[[package]] +name = "webpki-roots" +version = "0.23.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b03058f88386e5ff5310d9111d53f48b17d732b401aeb83a8d5190f2ac459338" +dependencies = [ + "rustls-webpki 0.100.2", +] + [[package]] name = "weezl" version = "0.1.7" @@ -9711,6 +9857,15 @@ dependencies = [ "libc", ] +[[package]] +name = "xattr" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4686009f71ff3e5c4dbcf1a282d0a44db3f021ba69350cd42086b3e5f1c6985" +dependencies = [ + "libc", +] + [[package]] name = "xmlparser" version = "0.13.5" @@ -9914,6 +10069,18 @@ dependencies = [ "syn 2.0.29", ] +[[package]] +name = "zip" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261" +dependencies = [ + "byteorder", + "crc32fast", + "crossbeam-utils", + "flate2", +] + [[package]] name = "zstd" version = "0.11.2+zstd.1.5.2" diff --git a/crates/semantic_index/Cargo.toml b/crates/semantic_index/Cargo.toml index 72a36efd50..93658b1c3f 100644 --- a/crates/semantic_index/Cargo.toml +++ b/crates/semantic_index/Cargo.toml @@ -41,6 +41,7 @@ schemars.workspace = true globset.workspace = true sha1 = "0.10.5" parse_duration = "2.1.1" +ort = { version = "1.15.2", features = ["coreml"]} [dev-dependencies] collections = { path = "../collections", features = ["test-support"] } diff --git a/crates/semantic_index/src/cross_encoder.rs b/crates/semantic_index/src/cross_encoder.rs new file mode 100644 index 0000000000..9fb928e329 --- /dev/null +++ b/crates/semantic_index/src/cross_encoder.rs @@ -0,0 +1,26 @@ +use ort::{Environment, ExecutionProvider, GraphOptimizationLevel}; + +struct CrossEncoder {} + +impl CrossEncoder { + pub fn load() -> anyhow::Result { + let environment = Environment::builder() + .with_name("cross-encoder") + .with_execution_providers([ExecutionProvider::CoreML(Default::default())]) + .build()? + .into_arc(); + + let model = "../models/cross-encoder.onnx"; + let mut session = environment + .new_session_builder() + .unwrap() + .with_optimization_level(GraphOptimizationLevel::Basic) + .unwrap() + .with_number_threads(1) + .unwrap() + .with_model_from_file(model) + .unwrap(); + + Ok(Self {}) + } +} diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 0e18c42049..83cdd8bdd0 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -1,3 +1,4 @@ +mod cross_encoder; mod db; mod embedding; mod embedding_queue;