From 44ee2e939a839d74427376e62bc394eaaed9696f Mon Sep 17 00:00:00 2001 From: Ziyang Hu Date: Thu, 13 Apr 2023 21:24:36 +0800 Subject: [PATCH] use ndarray to represent vector --- Cargo.lock | 30 ++++++++++++++++++++ cozo-core/Cargo.toml | 3 +- cozo-core/src/cozoscript.pest | 9 ++++-- cozo-core/src/data/json.rs | 6 ++-- cozo-core/src/data/memcmp.rs | 51 +++++----------------------------- cozo-core/src/data/relation.rs | 36 +++++++----------------- cozo-core/src/data/value.rs | 41 +++++++++++++-------------- cozo-core/src/parse/schema.rs | 2 -- cozo-core/src/parse/sys.rs | 1 + cozo-core/src/runtime/db.rs | 3 ++ 10 files changed, 80 insertions(+), 102 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 00642fc2..86d8859b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -621,6 +621,7 @@ dependencies = [ "log", "miette", "minreq", + "ndarray", "num-traits", "ordered-float", "pest", @@ -1900,6 +1901,15 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b87248edafb776e59e6ee64a79086f65890d3510f2c656c000bf2a7e8a0aea40" +[[package]] +name = "matrixmultiply" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "add85d4dd35074e6fedc608f8c8f513a3548619a9024b751949ef0e8e45a4d84" +dependencies = [ + "rawpointer", +] + [[package]] name = "memchr" version = "2.5.0" @@ -2056,6 +2066,20 @@ dependencies = [ "tempfile", ] +[[package]] +name = "ndarray" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "rawpointer", + "serde", +] + [[package]] name = "neon" version = "0.10.1" @@ -2942,6 +2966,12 @@ dependencies = [ "rand_core 0.5.1", ] +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "rayon" version = "1.6.1" diff --git a/cozo-core/Cargo.toml b/cozo-core/Cargo.toml index 439f8882..914284a5 100644 --- a/cozo-core/Cargo.toml +++ b/cozo-core/Cargo.toml @@ -126,4 +126,5 @@ sqlite = { version = "0.30.1", optional = true } sqlite3-src = { version = "0.4.0", optional = true, features = ["bundled"] } js-sys = { version = "0.3.60", optional = true } graph = { version = "0.3.0", optional = true } -crossbeam = "0.8.2" \ No newline at end of file +crossbeam = "0.8.2" +ndarray = { version = "0.15.6", features = ["serde"] } \ No newline at end of file diff --git a/cozo-core/src/cozoscript.pest b/cozo-core/src/cozoscript.pest index 068b4d03..9fd43cc7 100644 --- a/cozo-core/src/cozoscript.pest +++ b/cozo-core/src/cozoscript.pest @@ -13,10 +13,11 @@ query_script_inner_no_bracket = { (option | rule | const_rule | fixed_rule)+ } imperative_script = {SOI ~ imperative_stmt+ ~ EOI} sys_script = {SOI ~ "::" ~ (list_relations_op | list_relation_op | remove_relations_op | trigger_relation_op | trigger_relation_show_op | rename_relations_op | running_op | kill_op | explain_op | - access_level_op | index_op | compact_op | list_fixed_rules) ~ EOI} + access_level_op | index_op | vec_idx_op | compact_op | list_fixed_rules) ~ EOI} index_op = {"index" ~ (index_create | index_drop)} +vec_idx_op = {"hnsw" ~ (index_create_hnsw | index_drop)} index_create = {"create" ~ compound_ident ~ ":" ~ ident ~ "{" ~ (ident ~ ",")* ~ ident? ~ "}"} -index_create_hnsw = {"create_hnsw" ~ compound_ident ~ ":" ~ ident ~ "{" ~ (index_opt_field ~ ",")* ~ index_opt_field? ~ "}"} +index_create_hnsw = {"create" ~ compound_ident ~ ":" ~ ident ~ "{" ~ (index_opt_field ~ ",")* ~ index_opt_field? ~ "}"} index_drop = {"drop" ~ compound_ident ~ ":" ~ ident } compact_op = {"compact"} list_fixed_rules = {"fixed_rules"} @@ -51,6 +52,7 @@ param = @{"$" ~ (XID_CONTINUE | "_")*} ident = @{XID_START ~ ("_" | XID_CONTINUE)*} underscore_ident = @{("_" | XID_START) ~ ("_" | XID_CONTINUE)*} relation_ident = @{"*" ~ (compound_or_index_ident | underscore_ident)} +vector_index_ident = @{"~" ~ compound_or_index_ident} compound_ident = @{ident ~ ("." ~ ident)*} compound_or_index_ident = @{ident ~ ("." ~ ident)* ~ (":" ~ ident)?} @@ -76,6 +78,7 @@ rule_body = {(disjunction ~ ",")* ~ disjunction?} rule_apply = {underscore_ident ~ "[" ~ apply_args ~ "]"} relation_named_apply = {relation_ident ~ "{" ~ named_apply_args ~ validity_clause? ~ "}"} relation_apply = {relation_ident ~ "[" ~ apply_args ~ validity_clause? ~ "]"} +vector_search = {vector_index_ident ~ "{" ~ named_apply_args ~ "|" ~ (index_opt_field ~ ",")* ~ index_opt_field? ~ "}"} disjunction = {(atom ~ "or" )* ~ atom} atom = _{ negation | relation_named_apply | relation_apply | rule_apply | unify_multi | unify | expr | grouped} @@ -207,7 +210,7 @@ validity_type = {"Validity"} list_type = {"[" ~ col_type ~ (";" ~ expr)? ~ "]"} tuple_type = {"(" ~ (col_type ~ ",")* ~ col_type? ~ ")"} vec_type = {"<" ~ vec_el_type ~ ";" ~ pos_int ~ ">"} -vec_el_type = {"F32" | "F64" | "F32" | "F64" | "Float" | "Double" | "Long" | "Int" } +vec_el_type = {"F32" | "F64" | "Float" | "Double" } imperative_stmt = _{ break_stmt | continue_stmt | return_stmt | debug_stmt | diff --git a/cozo-core/src/data/json.rs b/cozo-core/src/data/json.rs index 0df0ccba..212c61d1 100644 --- a/cozo-core/src/data/json.rs +++ b/cozo-core/src/data/json.rs @@ -99,10 +99,8 @@ impl From for JsonValue { } DataValue::Vec(arr) => { match arr { - Vector::F32(a) => json!(a), - Vector::F64(a) => json!(a), - Vector::I32(a) => json!(a), - Vector::I64(a) => json!(a), + Vector::F32(a) => json!(a.as_slice().unwrap()), + Vector::F64(a) => json!(a.as_slice().unwrap()), } } DataValue::Validity(v) => { diff --git a/cozo-core/src/data/memcmp.rs b/cozo-core/src/data/memcmp.rs index 51c48ea2..19e4158f 100644 --- a/cozo-core/src/data/memcmp.rs +++ b/cozo-core/src/data/memcmp.rs @@ -10,6 +10,7 @@ use std::cmp::Reverse; use std::collections::BTreeSet; use std::io::Write; use std::str::FromStr; +use ndarray; use byteorder::{BigEndian, ByteOrder, WriteBytesExt}; use regex::Regex; @@ -33,8 +34,6 @@ const BOT_TAG: u8 = 0xFF; const VEC_F32: u8 = 0x01; const VEC_F64: u8 = 0x02; -const VEC_I32: u8 = 0x03; -const VEC_I64: u8 = 0x04; const IS_FLOAT: u8 = 0b00010000; const IS_APPROX_INT: u8 = 0b00000100; @@ -66,22 +65,6 @@ pub(crate) trait MemCmpEncoder: Write { self.write_f64::(*el).unwrap(); } } - Vector::I32(a) => { - self.write_u8(VEC_I32).unwrap(); - let l = a.len(); - self.write_u64::(l as u64).unwrap(); - for el in a { - self.write_i32::(*el).unwrap(); - } - } - Vector::I64(a) => { - self.write_u8(VEC_I64).unwrap(); - let l = a.len(); - self.write_u64::(l as u64).unwrap(); - for el in a { - self.write_i64::(*el).unwrap(); - } - } } } DataValue::Num(n) => { @@ -344,45 +327,25 @@ impl DataValue { let len = BigEndian::read_u64(len_bytes) as usize; match *t_tag { VEC_F32 => { - let mut res_arr = Vec::with_capacity(len); - for _ in 0..len { + let mut res_arr = ndarray::Array1::zeros(len); + for mut row in res_arr.axis_iter_mut(ndarray::Axis(0)) { let (f_bytes, next_chunk) = rest.split_at(4); rest = next_chunk; let f = BigEndian::read_f32(f_bytes); - res_arr.push(f); + row.fill(f); } (DataValue::Vec(Vector::F32(res_arr)), rest) } VEC_F64 => { - let mut res_arr = Vec::with_capacity(len); - for _ in 0..len { + let mut res_arr = ndarray::Array1::zeros(len); + for mut row in res_arr.axis_iter_mut(ndarray::Axis(0)) { let (f_bytes, next_chunk) = rest.split_at(8); rest = next_chunk; let f = BigEndian::read_f64(f_bytes); - res_arr.push(f); + row.fill(f); } (DataValue::Vec(Vector::F64(res_arr)), rest) } - VEC_I32 => { - let mut res_arr = Vec::with_capacity(len); - for _ in 0..len { - let (i_bytes, next_chunk) = rest.split_at(4); - rest = next_chunk; - let i = BigEndian::read_i32(i_bytes); - res_arr.push(i); - } - (DataValue::Vec(Vector::I32(res_arr)), rest) - } - VEC_I64 => { - let mut res_arr = Vec::with_capacity(len); - for _ in 0..len { - let (i_bytes, next_chunk) = rest.split_at(8); - rest = next_chunk; - let i = BigEndian::read_i64(i_bytes); - res_arr.push(i); - } - (DataValue::Vec(Vector::I64(res_arr)), rest) - } _ => unreachable!() } } diff --git a/cozo-core/src/data/relation.rs b/cozo-core/src/data/relation.rs index ca19b0b2..154026e1 100644 --- a/cozo-core/src/data/relation.rs +++ b/cozo-core/src/data/relation.rs @@ -62,8 +62,6 @@ impl Display for NullableColType { match eltype { VecElementType::F32 => f.write_str("F32")?, VecElementType::F64 => f.write_str("F64")?, - VecElementType::I32 => f.write_str("I32")?, - VecElementType::I64 => f.write_str("I64")?, } write!(f, ";{len}")?; f.write_str(">")?; @@ -101,8 +99,6 @@ pub(crate) enum ColType { pub(crate) enum VecElementType { F32, F64, - I32, - I64, } #[derive(Debug, Clone, Eq, PartialEq, serde_derive::Deserialize, serde_derive::Serialize)] @@ -254,32 +250,20 @@ impl NullableColType { } match eltype { VecElementType::F32 => { - let mut v = Vec::with_capacity(l.len()); - for el in l { - v.push(el.get_float().ok_or_else(make_err)? as f32) + let mut res_arr = ndarray::Array1::zeros(*len); + for (mut row, el) in res_arr.axis_iter_mut(ndarray::Axis(0)).zip(l.iter()) { + let f = el.get_float().ok_or_else(make_err)? as f32; + row.fill(f); } - DataValue::Vec(Vector::F32(v)) + DataValue::Vec(Vector::F32(res_arr)) } VecElementType::F64 => { - let mut v = Vec::with_capacity(l.len()); - for el in l { - v.push(el.get_float().ok_or_else(make_err)?) + let mut res_arr = ndarray::Array1::zeros(*len); + for (mut row, el) in res_arr.axis_iter_mut(ndarray::Axis(0)).zip(l.iter()) { + let f = el.get_float().ok_or_else(make_err)?; + row.fill(f); } - DataValue::Vec(Vector::F64(v)) - } - VecElementType::I32 => { - let mut v = Vec::with_capacity(l.len()); - for el in l { - v.push(el.get_int().ok_or_else(make_err)? as i32) - } - DataValue::Vec(Vector::I32(v)) - } - VecElementType::I64 => { - let mut v = Vec::with_capacity(l.len()); - for el in l { - v.push(el.get_int().ok_or_else(make_err)?) - } - DataValue::Vec(Vector::I64(v)) + DataValue::Vec(Vector::F64(res_arr)) } } } diff --git a/cozo-core/src/data/value.rs b/cozo-core/src/data/value.rs index 9f88194d..2d7032db 100644 --- a/cozo-core/src/data/value.rs +++ b/cozo-core/src/data/value.rs @@ -8,17 +8,18 @@ use base64::engine::general_purpose::STANDARD; use base64::Engine; +use ndarray::Array1; use std::cmp::{Ordering, Reverse}; use std::collections::BTreeSet; use std::fmt::{Debug, Display, Formatter}; use std::hash::{Hash, Hasher}; +use crate::data::relation::VecElementType; use ordered_float::OrderedFloat; use regex::Regex; use serde::{Deserialize, Deserializer, Serialize}; use smartstring::{LazyCompact, SmartString}; use uuid::Uuid; -use crate::data::relation::VecElementType; /// UUID value in the database #[derive(Clone, Hash, Eq, PartialEq, serde_derive::Deserialize, serde_derive::Serialize)] @@ -157,10 +158,8 @@ pub enum DataValue { #[derive(Clone, serde_derive::Serialize, serde_derive::Deserialize)] pub enum Vector { - F32(Vec), - F64(Vec), - I32(Vec), - I64(Vec), + F32(Array1), + F64(Array1), } impl Vector { @@ -168,16 +167,12 @@ impl Vector { match self { Vector::F32(v) => v.len(), Vector::F64(v) => v.len(), - Vector::I32(v) => v.len(), - Vector::I64(v) => v.len(), } } pub fn is_compatible(&self, other: &Self) -> bool { match (self, other) { (Vector::F32(_), Vector::F32(_)) => true, (Vector::F64(_), Vector::F64(_)) => true, - (Vector::I32(_), Vector::I32(_)) => true, - (Vector::I64(_), Vector::I64(_)) => true, _ => false, } } @@ -185,8 +180,6 @@ impl Vector { match self { Vector::F32(_) => VecElementType::F32, Vector::F64(_) => VecElementType::F64, - Vector::I32(_) => VecElementType::I32, - Vector::I64(_) => VecElementType::I64, } } } @@ -195,6 +188,9 @@ impl PartialEq for Vector { fn eq(&self, other: &Self) -> bool { match (self, other) { (Vector::F32(l), Vector::F32(r)) => { + if l.len() != r.len() { + return false; + } for (le, re) in l.iter().zip(r) { if !OrderedFloat(*le).eq(&OrderedFloat(*re)) { return false; @@ -203,6 +199,9 @@ impl PartialEq for Vector { true } (Vector::F64(l), Vector::F64(r)) => { + if l.len() != r.len() { + return false; + } for (le, re) in l.iter().zip(r) { if !OrderedFloat(*le).eq(&OrderedFloat(*re)) { return false; @@ -210,8 +209,6 @@ impl PartialEq for Vector { } true } - (Vector::I32(l), Vector::I32(r)) => l == r, - (Vector::I64(l), Vector::I64(r)) => l == r, _ => false, } } @@ -229,6 +226,10 @@ impl Ord for Vector { fn cmp(&self, other: &Self) -> Ordering { match (self, other) { (Vector::F32(l), Vector::F32(r)) => { + match l.len().cmp(&r.len()) { + Ordering::Equal => (), + o => return o, + } for (le, re) in l.iter().zip(r) { match OrderedFloat(*le).cmp(&OrderedFloat(*re)) { Ordering::Equal => continue, @@ -237,8 +238,12 @@ impl Ord for Vector { } return Ordering::Equal; } - (Vector::F32(_), _) => Ordering::Less, + (Vector::F32(_), Vector::F64(_)) => Ordering::Less, (Vector::F64(l), Vector::F64(r)) => { + match l.len().cmp(&r.len()) { + Ordering::Equal => (), + o => return o, + } for (le, re) in l.iter().zip(r) { match OrderedFloat(*le).cmp(&OrderedFloat(*re)) { Ordering::Equal => continue, @@ -248,12 +253,6 @@ impl Ord for Vector { return Ordering::Equal; } (Vector::F64(_), Vector::F32(_)) => Ordering::Greater, - (Vector::F64(_), _) => Ordering::Less, - (Vector::I32(l), Vector::I32(r)) => l.cmp(r), - (Vector::I32(_), Vector::I64(_)) => Ordering::Less, - (Vector::I32(_), _) => Ordering::Greater, - (Vector::I64(l), Vector::I64(r)) => l.cmp(r), - (Vector::I64(_), _) => Ordering::Greater, } } } @@ -271,8 +270,6 @@ impl Hash for Vector { OrderedFloat(*el).hash(state) } } - Vector::I32(a) => {a.hash(state)} - Vector::I64(a) => {a.hash(state)} } } } diff --git a/cozo-core/src/parse/schema.rs b/cozo-core/src/parse/schema.rs index 952ba837..75529032 100644 --- a/cozo-core/src/parse/schema.rs +++ b/cozo-core/src/parse/schema.rs @@ -153,8 +153,6 @@ fn parse_type_inner(pair: Pair<'_>) -> Result { let eltype = match inner.next().unwrap().as_str() { "F32" | "Float" => VecElementType::F32, "F64" | "Double" => VecElementType::F64, - "I32" | "Int" => VecElementType::I32, - "I64" | "Long" => VecElementType::I64, _ => unreachable!() }; let len = inner.next().unwrap(); diff --git a/cozo-core/src/parse/sys.rs b/cozo-core/src/parse/sys.rs index 6ac089a6..38494d70 100644 --- a/cozo-core/src/parse/sys.rs +++ b/cozo-core/src/parse/sys.rs @@ -36,6 +36,7 @@ pub(crate) enum SysOp { SetTriggers(Symbol, Vec, Vec, Vec), SetAccessLevel(Vec, AccessLevel), CreateIndex(Symbol, Symbol, Vec), + CreateVectorIndex(Symbol, Symbol, Vec), RemoveIndex(Symbol, Symbol), } diff --git a/cozo-core/src/runtime/db.rs b/cozo-core/src/runtime/db.rs index 0337204f..bd78a668 100644 --- a/cozo-core/src/runtime/db.rs +++ b/cozo-core/src/runtime/db.rs @@ -1157,6 +1157,9 @@ impl<'s, S: Storage<'s>> Db { vec![vec![DataValue::from(OK_STR)]], )) } + SysOp::CreateVectorIndex(..) => { + todo!() + } SysOp::RemoveIndex(rel_name, idx_name) => { let lock = self .obtain_relation_locks(iter::once(&rel_name.name))