wip bring in ort

This commit is contained in:
KCaverly
2023-09-08 10:27:09 -04:00
parent e7b7ac9d8c
commit 65add70a37
4 changed files with 197 additions and 2 deletions

171
Cargo.lock generated
View File

@@ -550,7 +550,7 @@ dependencies = [
"libc",
"pin-project",
"redox_syscall 0.2.16",
"xattr",
"xattr 0.2.3",
]
[[package]]
@@ -2008,6 +2008,12 @@ dependencies = [
"cfg-if 1.0.0",
]
[[package]]
name = "crunchy"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
[[package]]
name = "crypto-common"
version = "0.1.6"
@@ -3206,6 +3212,16 @@ dependencies = [
"tracing",
]
[[package]]
name = "half"
version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc52e53916c08643f1b56ec082790d1e86a32e58dc5268f897f313fbae7b4872"
dependencies = [
"cfg-if 1.0.0",
"crunchy",
]
[[package]]
name = "hashbrown"
version = "0.11.2"
@@ -4513,6 +4529,19 @@ dependencies = [
"tempfile",
]
[[package]]
name = "ndarray"
version = "0.15.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
dependencies = [
"matrixmultiply",
"num-complex 0.4.4",
"num-integer",
"num-traits",
"rawpointer",
]
[[package]]
name = "ndk"
version = "0.7.0"
@@ -4640,7 +4669,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8536030f9fea7127f841b45bb6243b27255787fb4eb83958aa1ef9d2fdc0c36"
dependencies = [
"num-bigint 0.2.6",
"num-complex",
"num-complex 0.2.4",
"num-integer",
"num-iter",
"num-rational 0.2.4",
@@ -4696,6 +4725,15 @@ dependencies = [
"num-traits",
]
[[package]]
name = "num-complex"
version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214"
dependencies = [
"num-traits",
]
[[package]]
name = "num-derive"
version = "0.3.3"
@@ -4935,6 +4973,26 @@ dependencies = [
"num-traits",
]
[[package]]
name = "ort"
version = "1.15.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c5e56c9c4185ee949ef961aca8777d1dbd52cb104b444669adad63e8181820a7"
dependencies = [
"flate2",
"half",
"lazy_static",
"libc",
"ndarray",
"tar",
"thiserror",
"tracing",
"ureq",
"vswhom",
"winapi 0.3.9",
"zip",
]
[[package]]
name = "os_str_bytes"
version = "6.5.1"
@@ -6404,6 +6462,18 @@ dependencies = [
"webpki 0.22.0",
]
[[package]]
name = "rustls"
version = "0.21.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cd8d6c9f025a446bc4d18ad9632e69aec8f287aa84499ee335599fabd20c3fd8"
dependencies = [
"log",
"ring",
"rustls-webpki 0.101.4",
"sct 0.7.0",
]
[[package]]
name = "rustls-pemfile"
version = "1.0.3"
@@ -6413,6 +6483,26 @@ dependencies = [
"base64 0.21.2",
]
[[package]]
name = "rustls-webpki"
version = "0.100.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e98ff011474fa39949b7e5c0428f9b4937eda7da7848bbb947786b7be0b27dab"
dependencies = [
"ring",
"untrusted",
]
[[package]]
name = "rustls-webpki"
version = "0.101.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7d93931baf2d282fff8d3a532bbfd7653f734643161b87e3e01e59a04439bf0d"
dependencies = [
"ring",
"untrusted",
]
[[package]]
name = "rustversion"
version = "1.0.14"
@@ -6734,6 +6824,7 @@ dependencies = [
"lazy_static",
"log",
"matrixmultiply",
"ort",
"parking_lot 0.11.2",
"parse_duration",
"picker",
@@ -7586,6 +7677,17 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
[[package]]
name = "tar"
version = "0.4.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b16afcea1f22891c49a00c751c7b63b2233284064f11a200fc624137c51e2ddb"
dependencies = [
"filetime",
"libc",
"xattr 1.0.1",
]
[[package]]
name = "target-lexicon"
version = "0.12.11"
@@ -8657,6 +8759,21 @@ version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a"
[[package]]
name = "ureq"
version = "2.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b11c96ac7ee530603dcdf68ed1557050f374ce55a5a07193ebf8cbc9f8927e9"
dependencies = [
"base64 0.21.2",
"log",
"once_cell",
"rustls 0.21.7",
"rustls-webpki 0.100.2",
"url",
"webpki-roots 0.23.1",
]
[[package]]
name = "url"
version = "2.4.0"
@@ -8849,6 +8966,26 @@ dependencies = [
"workspace",
]
[[package]]
name = "vswhom"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be979b7f07507105799e854203b470ff7c78a1639e330a58f183b5fea574608b"
dependencies = [
"libc",
"vswhom-sys",
]
[[package]]
name = "vswhom-sys"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3b17ae1f6c8a2b28506cd96d412eebf83b4a0ff2cbefeeb952f2f9dfa44ba18"
dependencies = [
"cc",
"libc",
]
[[package]]
name = "vte"
version = "0.11.1"
@@ -9315,6 +9452,15 @@ dependencies = [
"webpki 0.22.0",
]
[[package]]
name = "webpki-roots"
version = "0.23.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b03058f88386e5ff5310d9111d53f48b17d732b401aeb83a8d5190f2ac459338"
dependencies = [
"rustls-webpki 0.100.2",
]
[[package]]
name = "weezl"
version = "0.1.7"
@@ -9711,6 +9857,15 @@ dependencies = [
"libc",
]
[[package]]
name = "xattr"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f4686009f71ff3e5c4dbcf1a282d0a44db3f021ba69350cd42086b3e5f1c6985"
dependencies = [
"libc",
]
[[package]]
name = "xmlparser"
version = "0.13.5"
@@ -9914,6 +10069,18 @@ dependencies = [
"syn 2.0.29",
]
[[package]]
name = "zip"
version = "0.6.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261"
dependencies = [
"byteorder",
"crc32fast",
"crossbeam-utils",
"flate2",
]
[[package]]
name = "zstd"
version = "0.11.2+zstd.1.5.2"

View File

@@ -41,6 +41,7 @@ schemars.workspace = true
globset.workspace = true
sha1 = "0.10.5"
parse_duration = "2.1.1"
ort = { version = "1.15.2", features = ["coreml"]}
[dev-dependencies]
collections = { path = "../collections", features = ["test-support"] }

View File

@@ -0,0 +1,26 @@
use ort::{Environment, ExecutionProvider, GraphOptimizationLevel};
struct CrossEncoder {}
impl CrossEncoder {
pub fn load() -> anyhow::Result<Self> {
let environment = Environment::builder()
.with_name("cross-encoder")
.with_execution_providers([ExecutionProvider::CoreML(Default::default())])
.build()?
.into_arc();
let model = "../models/cross-encoder.onnx";
let mut session = environment
.new_session_builder()
.unwrap()
.with_optimization_level(GraphOptimizationLevel::Basic)
.unwrap()
.with_number_threads(1)
.unwrap()
.with_model_from_file(model)
.unwrap();
Ok(Self {})
}
}

View File

@@ -1,3 +1,4 @@
mod cross_encoder;
mod db;
mod embedding;
mod embedding_queue;