use ndarray to represent vector

main
Ziyang Hu 1 year ago
parent a67dcda7bc
commit 44ee2e939a

30
Cargo.lock generated

@ -621,6 +621,7 @@ dependencies = [
"log", "log",
"miette", "miette",
"minreq", "minreq",
"ndarray",
"num-traits", "num-traits",
"ordered-float", "ordered-float",
"pest", "pest",
@ -1900,6 +1901,15 @@ version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b87248edafb776e59e6ee64a79086f65890d3510f2c656c000bf2a7e8a0aea40" checksum = "b87248edafb776e59e6ee64a79086f65890d3510f2c656c000bf2a7e8a0aea40"
[[package]]
name = "matrixmultiply"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "add85d4dd35074e6fedc608f8c8f513a3548619a9024b751949ef0e8e45a4d84"
dependencies = [
"rawpointer",
]
[[package]] [[package]]
name = "memchr" name = "memchr"
version = "2.5.0" version = "2.5.0"
@ -2056,6 +2066,20 @@ dependencies = [
"tempfile", "tempfile",
] ]
[[package]]
name = "ndarray"
version = "0.15.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
dependencies = [
"matrixmultiply",
"num-complex",
"num-integer",
"num-traits",
"rawpointer",
"serde",
]
[[package]] [[package]]
name = "neon" name = "neon"
version = "0.10.1" version = "0.10.1"
@ -2942,6 +2966,12 @@ dependencies = [
"rand_core 0.5.1", "rand_core 0.5.1",
] ]
[[package]]
name = "rawpointer"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
[[package]] [[package]]
name = "rayon" name = "rayon"
version = "1.6.1" version = "1.6.1"

@ -127,3 +127,4 @@ sqlite3-src = { version = "0.4.0", optional = true, features = ["bundled"] }
js-sys = { version = "0.3.60", optional = true } js-sys = { version = "0.3.60", optional = true }
graph = { version = "0.3.0", optional = true } graph = { version = "0.3.0", optional = true }
crossbeam = "0.8.2" crossbeam = "0.8.2"
ndarray = { version = "0.15.6", features = ["serde"] }

@ -13,10 +13,11 @@ query_script_inner_no_bracket = { (option | rule | const_rule | fixed_rule)+ }
imperative_script = {SOI ~ imperative_stmt+ ~ EOI} imperative_script = {SOI ~ imperative_stmt+ ~ EOI}
sys_script = {SOI ~ "::" ~ (list_relations_op | list_relation_op | remove_relations_op | trigger_relation_op | sys_script = {SOI ~ "::" ~ (list_relations_op | list_relation_op | remove_relations_op | trigger_relation_op |
trigger_relation_show_op | rename_relations_op | running_op | kill_op | explain_op | trigger_relation_show_op | rename_relations_op | running_op | kill_op | explain_op |
access_level_op | index_op | compact_op | list_fixed_rules) ~ EOI} access_level_op | index_op | vec_idx_op | compact_op | list_fixed_rules) ~ EOI}
index_op = {"index" ~ (index_create | index_drop)} index_op = {"index" ~ (index_create | index_drop)}
vec_idx_op = {"hnsw" ~ (index_create_hnsw | index_drop)}
index_create = {"create" ~ compound_ident ~ ":" ~ ident ~ "{" ~ (ident ~ ",")* ~ ident? ~ "}"} index_create = {"create" ~ compound_ident ~ ":" ~ ident ~ "{" ~ (ident ~ ",")* ~ ident? ~ "}"}
index_create_hnsw = {"create_hnsw" ~ compound_ident ~ ":" ~ ident ~ "{" ~ (index_opt_field ~ ",")* ~ index_opt_field? ~ "}"} index_create_hnsw = {"create" ~ compound_ident ~ ":" ~ ident ~ "{" ~ (index_opt_field ~ ",")* ~ index_opt_field? ~ "}"}
index_drop = {"drop" ~ compound_ident ~ ":" ~ ident } index_drop = {"drop" ~ compound_ident ~ ":" ~ ident }
compact_op = {"compact"} compact_op = {"compact"}
list_fixed_rules = {"fixed_rules"} list_fixed_rules = {"fixed_rules"}
@ -51,6 +52,7 @@ param = @{"$" ~ (XID_CONTINUE | "_")*}
ident = @{XID_START ~ ("_" | XID_CONTINUE)*} ident = @{XID_START ~ ("_" | XID_CONTINUE)*}
underscore_ident = @{("_" | XID_START) ~ ("_" | XID_CONTINUE)*} underscore_ident = @{("_" | XID_START) ~ ("_" | XID_CONTINUE)*}
relation_ident = @{"*" ~ (compound_or_index_ident | underscore_ident)} relation_ident = @{"*" ~ (compound_or_index_ident | underscore_ident)}
vector_index_ident = @{"~" ~ compound_or_index_ident}
compound_ident = @{ident ~ ("." ~ ident)*} compound_ident = @{ident ~ ("." ~ ident)*}
compound_or_index_ident = @{ident ~ ("." ~ ident)* ~ (":" ~ ident)?} compound_or_index_ident = @{ident ~ ("." ~ ident)* ~ (":" ~ ident)?}
@ -76,6 +78,7 @@ rule_body = {(disjunction ~ ",")* ~ disjunction?}
rule_apply = {underscore_ident ~ "[" ~ apply_args ~ "]"} rule_apply = {underscore_ident ~ "[" ~ apply_args ~ "]"}
relation_named_apply = {relation_ident ~ "{" ~ named_apply_args ~ validity_clause? ~ "}"} relation_named_apply = {relation_ident ~ "{" ~ named_apply_args ~ validity_clause? ~ "}"}
relation_apply = {relation_ident ~ "[" ~ apply_args ~ validity_clause? ~ "]"} relation_apply = {relation_ident ~ "[" ~ apply_args ~ validity_clause? ~ "]"}
vector_search = {vector_index_ident ~ "{" ~ named_apply_args ~ "|" ~ (index_opt_field ~ ",")* ~ index_opt_field? ~ "}"}
disjunction = {(atom ~ "or" )* ~ atom} disjunction = {(atom ~ "or" )* ~ atom}
atom = _{ negation | relation_named_apply | relation_apply | rule_apply | unify_multi | unify | expr | grouped} atom = _{ negation | relation_named_apply | relation_apply | rule_apply | unify_multi | unify | expr | grouped}
@ -207,7 +210,7 @@ validity_type = {"Validity"}
list_type = {"[" ~ col_type ~ (";" ~ expr)? ~ "]"} list_type = {"[" ~ col_type ~ (";" ~ expr)? ~ "]"}
tuple_type = {"(" ~ (col_type ~ ",")* ~ col_type? ~ ")"} tuple_type = {"(" ~ (col_type ~ ",")* ~ col_type? ~ ")"}
vec_type = {"<" ~ vec_el_type ~ ";" ~ pos_int ~ ">"} vec_type = {"<" ~ vec_el_type ~ ";" ~ pos_int ~ ">"}
vec_el_type = {"F32" | "F64" | "F32" | "F64" | "Float" | "Double" | "Long" | "Int" } vec_el_type = {"F32" | "F64" | "Float" | "Double" }
imperative_stmt = _{ imperative_stmt = _{
break_stmt | continue_stmt | return_stmt | debug_stmt | break_stmt | continue_stmt | return_stmt | debug_stmt |

@ -99,10 +99,8 @@ impl From<DataValue> for JsonValue {
} }
DataValue::Vec(arr) => { DataValue::Vec(arr) => {
match arr { match arr {
Vector::F32(a) => json!(a), Vector::F32(a) => json!(a.as_slice().unwrap()),
Vector::F64(a) => json!(a), Vector::F64(a) => json!(a.as_slice().unwrap()),
Vector::I32(a) => json!(a),
Vector::I64(a) => json!(a),
} }
} }
DataValue::Validity(v) => { DataValue::Validity(v) => {

@ -10,6 +10,7 @@ use std::cmp::Reverse;
use std::collections::BTreeSet; use std::collections::BTreeSet;
use std::io::Write; use std::io::Write;
use std::str::FromStr; use std::str::FromStr;
use ndarray;
use byteorder::{BigEndian, ByteOrder, WriteBytesExt}; use byteorder::{BigEndian, ByteOrder, WriteBytesExt};
use regex::Regex; use regex::Regex;
@ -33,8 +34,6 @@ const BOT_TAG: u8 = 0xFF;
const VEC_F32: u8 = 0x01; const VEC_F32: u8 = 0x01;
const VEC_F64: u8 = 0x02; const VEC_F64: u8 = 0x02;
const VEC_I32: u8 = 0x03;
const VEC_I64: u8 = 0x04;
const IS_FLOAT: u8 = 0b00010000; const IS_FLOAT: u8 = 0b00010000;
const IS_APPROX_INT: u8 = 0b00000100; const IS_APPROX_INT: u8 = 0b00000100;
@ -66,22 +65,6 @@ pub(crate) trait MemCmpEncoder: Write {
self.write_f64::<BigEndian>(*el).unwrap(); self.write_f64::<BigEndian>(*el).unwrap();
} }
} }
Vector::I32(a) => {
self.write_u8(VEC_I32).unwrap();
let l = a.len();
self.write_u64::<BigEndian>(l as u64).unwrap();
for el in a {
self.write_i32::<BigEndian>(*el).unwrap();
}
}
Vector::I64(a) => {
self.write_u8(VEC_I64).unwrap();
let l = a.len();
self.write_u64::<BigEndian>(l as u64).unwrap();
for el in a {
self.write_i64::<BigEndian>(*el).unwrap();
}
}
} }
} }
DataValue::Num(n) => { DataValue::Num(n) => {
@ -344,45 +327,25 @@ impl DataValue {
let len = BigEndian::read_u64(len_bytes) as usize; let len = BigEndian::read_u64(len_bytes) as usize;
match *t_tag { match *t_tag {
VEC_F32 => { VEC_F32 => {
let mut res_arr = Vec::with_capacity(len); let mut res_arr = ndarray::Array1::zeros(len);
for _ in 0..len { for mut row in res_arr.axis_iter_mut(ndarray::Axis(0)) {
let (f_bytes, next_chunk) = rest.split_at(4); let (f_bytes, next_chunk) = rest.split_at(4);
rest = next_chunk; rest = next_chunk;
let f = BigEndian::read_f32(f_bytes); let f = BigEndian::read_f32(f_bytes);
res_arr.push(f); row.fill(f);
} }
(DataValue::Vec(Vector::F32(res_arr)), rest) (DataValue::Vec(Vector::F32(res_arr)), rest)
} }
VEC_F64 => { VEC_F64 => {
let mut res_arr = Vec::with_capacity(len); let mut res_arr = ndarray::Array1::zeros(len);
for _ in 0..len { for mut row in res_arr.axis_iter_mut(ndarray::Axis(0)) {
let (f_bytes, next_chunk) = rest.split_at(8); let (f_bytes, next_chunk) = rest.split_at(8);
rest = next_chunk; rest = next_chunk;
let f = BigEndian::read_f64(f_bytes); let f = BigEndian::read_f64(f_bytes);
res_arr.push(f); row.fill(f);
} }
(DataValue::Vec(Vector::F64(res_arr)), rest) (DataValue::Vec(Vector::F64(res_arr)), rest)
} }
VEC_I32 => {
let mut res_arr = Vec::with_capacity(len);
for _ in 0..len {
let (i_bytes, next_chunk) = rest.split_at(4);
rest = next_chunk;
let i = BigEndian::read_i32(i_bytes);
res_arr.push(i);
}
(DataValue::Vec(Vector::I32(res_arr)), rest)
}
VEC_I64 => {
let mut res_arr = Vec::with_capacity(len);
for _ in 0..len {
let (i_bytes, next_chunk) = rest.split_at(8);
rest = next_chunk;
let i = BigEndian::read_i64(i_bytes);
res_arr.push(i);
}
(DataValue::Vec(Vector::I64(res_arr)), rest)
}
_ => unreachable!() _ => unreachable!()
} }
} }

@ -62,8 +62,6 @@ impl Display for NullableColType {
match eltype { match eltype {
VecElementType::F32 => f.write_str("F32")?, VecElementType::F32 => f.write_str("F32")?,
VecElementType::F64 => f.write_str("F64")?, VecElementType::F64 => f.write_str("F64")?,
VecElementType::I32 => f.write_str("I32")?,
VecElementType::I64 => f.write_str("I64")?,
} }
write!(f, ";{len}")?; write!(f, ";{len}")?;
f.write_str(">")?; f.write_str(">")?;
@ -101,8 +99,6 @@ pub(crate) enum ColType {
pub(crate) enum VecElementType { pub(crate) enum VecElementType {
F32, F32,
F64, F64,
I32,
I64,
} }
#[derive(Debug, Clone, Eq, PartialEq, serde_derive::Deserialize, serde_derive::Serialize)] #[derive(Debug, Clone, Eq, PartialEq, serde_derive::Deserialize, serde_derive::Serialize)]
@ -254,32 +250,20 @@ impl NullableColType {
} }
match eltype { match eltype {
VecElementType::F32 => { VecElementType::F32 => {
let mut v = Vec::with_capacity(l.len()); let mut res_arr = ndarray::Array1::zeros(*len);
for el in l { for (mut row, el) in res_arr.axis_iter_mut(ndarray::Axis(0)).zip(l.iter()) {
v.push(el.get_float().ok_or_else(make_err)? as f32) let f = el.get_float().ok_or_else(make_err)? as f32;
row.fill(f);
} }
DataValue::Vec(Vector::F32(v)) DataValue::Vec(Vector::F32(res_arr))
} }
VecElementType::F64 => { VecElementType::F64 => {
let mut v = Vec::with_capacity(l.len()); let mut res_arr = ndarray::Array1::zeros(*len);
for el in l { for (mut row, el) in res_arr.axis_iter_mut(ndarray::Axis(0)).zip(l.iter()) {
v.push(el.get_float().ok_or_else(make_err)?) let f = el.get_float().ok_or_else(make_err)?;
row.fill(f);
} }
DataValue::Vec(Vector::F64(v)) DataValue::Vec(Vector::F64(res_arr))
}
VecElementType::I32 => {
let mut v = Vec::with_capacity(l.len());
for el in l {
v.push(el.get_int().ok_or_else(make_err)? as i32)
}
DataValue::Vec(Vector::I32(v))
}
VecElementType::I64 => {
let mut v = Vec::with_capacity(l.len());
for el in l {
v.push(el.get_int().ok_or_else(make_err)?)
}
DataValue::Vec(Vector::I64(v))
} }
} }
} }

@ -8,17 +8,18 @@
use base64::engine::general_purpose::STANDARD; use base64::engine::general_purpose::STANDARD;
use base64::Engine; use base64::Engine;
use ndarray::Array1;
use std::cmp::{Ordering, Reverse}; use std::cmp::{Ordering, Reverse};
use std::collections::BTreeSet; use std::collections::BTreeSet;
use std::fmt::{Debug, Display, Formatter}; use std::fmt::{Debug, Display, Formatter};
use std::hash::{Hash, Hasher}; use std::hash::{Hash, Hasher};
use crate::data::relation::VecElementType;
use ordered_float::OrderedFloat; use ordered_float::OrderedFloat;
use regex::Regex; use regex::Regex;
use serde::{Deserialize, Deserializer, Serialize}; use serde::{Deserialize, Deserializer, Serialize};
use smartstring::{LazyCompact, SmartString}; use smartstring::{LazyCompact, SmartString};
use uuid::Uuid; use uuid::Uuid;
use crate::data::relation::VecElementType;
/// UUID value in the database /// UUID value in the database
#[derive(Clone, Hash, Eq, PartialEq, serde_derive::Deserialize, serde_derive::Serialize)] #[derive(Clone, Hash, Eq, PartialEq, serde_derive::Deserialize, serde_derive::Serialize)]
@ -157,10 +158,8 @@ pub enum DataValue {
#[derive(Clone, serde_derive::Serialize, serde_derive::Deserialize)] #[derive(Clone, serde_derive::Serialize, serde_derive::Deserialize)]
pub enum Vector { pub enum Vector {
F32(Vec<f32>), F32(Array1<f32>),
F64(Vec<f64>), F64(Array1<f64>),
I32(Vec<i32>),
I64(Vec<i64>),
} }
impl Vector { impl Vector {
@ -168,16 +167,12 @@ impl Vector {
match self { match self {
Vector::F32(v) => v.len(), Vector::F32(v) => v.len(),
Vector::F64(v) => v.len(), Vector::F64(v) => v.len(),
Vector::I32(v) => v.len(),
Vector::I64(v) => v.len(),
} }
} }
pub fn is_compatible(&self, other: &Self) -> bool { pub fn is_compatible(&self, other: &Self) -> bool {
match (self, other) { match (self, other) {
(Vector::F32(_), Vector::F32(_)) => true, (Vector::F32(_), Vector::F32(_)) => true,
(Vector::F64(_), Vector::F64(_)) => true, (Vector::F64(_), Vector::F64(_)) => true,
(Vector::I32(_), Vector::I32(_)) => true,
(Vector::I64(_), Vector::I64(_)) => true,
_ => false, _ => false,
} }
} }
@ -185,8 +180,6 @@ impl Vector {
match self { match self {
Vector::F32(_) => VecElementType::F32, Vector::F32(_) => VecElementType::F32,
Vector::F64(_) => VecElementType::F64, Vector::F64(_) => VecElementType::F64,
Vector::I32(_) => VecElementType::I32,
Vector::I64(_) => VecElementType::I64,
} }
} }
} }
@ -195,6 +188,9 @@ impl PartialEq<Self> for Vector {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
match (self, other) { match (self, other) {
(Vector::F32(l), Vector::F32(r)) => { (Vector::F32(l), Vector::F32(r)) => {
if l.len() != r.len() {
return false;
}
for (le, re) in l.iter().zip(r) { for (le, re) in l.iter().zip(r) {
if !OrderedFloat(*le).eq(&OrderedFloat(*re)) { if !OrderedFloat(*le).eq(&OrderedFloat(*re)) {
return false; return false;
@ -203,6 +199,9 @@ impl PartialEq<Self> for Vector {
true true
} }
(Vector::F64(l), Vector::F64(r)) => { (Vector::F64(l), Vector::F64(r)) => {
if l.len() != r.len() {
return false;
}
for (le, re) in l.iter().zip(r) { for (le, re) in l.iter().zip(r) {
if !OrderedFloat(*le).eq(&OrderedFloat(*re)) { if !OrderedFloat(*le).eq(&OrderedFloat(*re)) {
return false; return false;
@ -210,8 +209,6 @@ impl PartialEq<Self> for Vector {
} }
true true
} }
(Vector::I32(l), Vector::I32(r)) => l == r,
(Vector::I64(l), Vector::I64(r)) => l == r,
_ => false, _ => false,
} }
} }
@ -229,6 +226,10 @@ impl Ord for Vector {
fn cmp(&self, other: &Self) -> Ordering { fn cmp(&self, other: &Self) -> Ordering {
match (self, other) { match (self, other) {
(Vector::F32(l), Vector::F32(r)) => { (Vector::F32(l), Vector::F32(r)) => {
match l.len().cmp(&r.len()) {
Ordering::Equal => (),
o => return o,
}
for (le, re) in l.iter().zip(r) { for (le, re) in l.iter().zip(r) {
match OrderedFloat(*le).cmp(&OrderedFloat(*re)) { match OrderedFloat(*le).cmp(&OrderedFloat(*re)) {
Ordering::Equal => continue, Ordering::Equal => continue,
@ -237,8 +238,12 @@ impl Ord for Vector {
} }
return Ordering::Equal; return Ordering::Equal;
} }
(Vector::F32(_), _) => Ordering::Less, (Vector::F32(_), Vector::F64(_)) => Ordering::Less,
(Vector::F64(l), Vector::F64(r)) => { (Vector::F64(l), Vector::F64(r)) => {
match l.len().cmp(&r.len()) {
Ordering::Equal => (),
o => return o,
}
for (le, re) in l.iter().zip(r) { for (le, re) in l.iter().zip(r) {
match OrderedFloat(*le).cmp(&OrderedFloat(*re)) { match OrderedFloat(*le).cmp(&OrderedFloat(*re)) {
Ordering::Equal => continue, Ordering::Equal => continue,
@ -248,12 +253,6 @@ impl Ord for Vector {
return Ordering::Equal; return Ordering::Equal;
} }
(Vector::F64(_), Vector::F32(_)) => Ordering::Greater, (Vector::F64(_), Vector::F32(_)) => Ordering::Greater,
(Vector::F64(_), _) => Ordering::Less,
(Vector::I32(l), Vector::I32(r)) => l.cmp(r),
(Vector::I32(_), Vector::I64(_)) => Ordering::Less,
(Vector::I32(_), _) => Ordering::Greater,
(Vector::I64(l), Vector::I64(r)) => l.cmp(r),
(Vector::I64(_), _) => Ordering::Greater,
} }
} }
} }
@ -271,8 +270,6 @@ impl Hash for Vector {
OrderedFloat(*el).hash(state) OrderedFloat(*el).hash(state)
} }
} }
Vector::I32(a) => {a.hash(state)}
Vector::I64(a) => {a.hash(state)}
} }
} }
} }

@ -153,8 +153,6 @@ fn parse_type_inner(pair: Pair<'_>) -> Result<ColType> {
let eltype = match inner.next().unwrap().as_str() { let eltype = match inner.next().unwrap().as_str() {
"F32" | "Float" => VecElementType::F32, "F32" | "Float" => VecElementType::F32,
"F64" | "Double" => VecElementType::F64, "F64" | "Double" => VecElementType::F64,
"I32" | "Int" => VecElementType::I32,
"I64" | "Long" => VecElementType::I64,
_ => unreachable!() _ => unreachable!()
}; };
let len = inner.next().unwrap(); let len = inner.next().unwrap();

@ -36,6 +36,7 @@ pub(crate) enum SysOp {
SetTriggers(Symbol, Vec<String>, Vec<String>, Vec<String>), SetTriggers(Symbol, Vec<String>, Vec<String>, Vec<String>),
SetAccessLevel(Vec<Symbol>, AccessLevel), SetAccessLevel(Vec<Symbol>, AccessLevel),
CreateIndex(Symbol, Symbol, Vec<Symbol>), CreateIndex(Symbol, Symbol, Vec<Symbol>),
CreateVectorIndex(Symbol, Symbol, Vec<Symbol>),
RemoveIndex(Symbol, Symbol), RemoveIndex(Symbol, Symbol),
} }

@ -1157,6 +1157,9 @@ impl<'s, S: Storage<'s>> Db<S> {
vec![vec![DataValue::from(OK_STR)]], vec![vec![DataValue::from(OK_STR)]],
)) ))
} }
SysOp::CreateVectorIndex(..) => {
todo!()
}
SysOp::RemoveIndex(rel_name, idx_name) => { SysOp::RemoveIndex(rel_name, idx_name) => {
let lock = self let lock = self
.obtain_relation_locks(iter::once(&rel_name.name)) .obtain_relation_locks(iter::once(&rel_name.name))

Loading…
Cancel
Save