improve performance

main
Ziyang Hu 1 year ago
parent 4f839bbab1
commit d90aea2dad

4
.gitignore vendored

@ -27,7 +27,6 @@ _test*
*.dll
*.db
.DS_Store
flamegraph.svg
release.zip
.idea
.fleet
@ -38,4 +37,5 @@ Cross.toml
/tools
*.cozo_auth
.cozo_repl_history
/venv/
/venv/
flamegraph*.svg

1
Cargo.lock generated

@ -709,6 +709,7 @@ dependencies = [
"rmp",
"rmp-serde",
"rmpv",
"rustc-hash",
"serde",
"serde_bytes",
"serde_derive",

@ -12,4 +12,5 @@ members = [
]
[profile.bench]
lto = true
lto = true
debug = true

@ -128,4 +128,5 @@ js-sys = { version = "0.3.60", optional = true }
graph = { version = "0.3.0", optional = true }
crossbeam = "0.8.2"
ndarray = { version = "0.15.6", features = ["serde"] }
sha2 = "0.10.6"
sha2 = "0.10.6"
rustc-hash = "1.1.0"

@ -38,9 +38,9 @@ where
}
}
pub fn decode_tuple_from_key(key: &[u8]) -> Tuple {
pub fn decode_tuple_from_key(key: &[u8], size_hint: usize) -> Tuple {
let mut remaining = &key[ENCODED_KEY_MIN_LEN..];
let mut ret = vec![];
let mut ret = Vec::with_capacity(size_hint);
while !remaining.is_empty() {
let (val, next) = DataValue::decode_from_key(remaining);
ret.push(val);
@ -49,14 +49,16 @@ pub fn decode_tuple_from_key(key: &[u8]) -> Tuple {
ret
}
const DEFAULT_SIZE_HINT: usize = 16;
/// Check if the tuple key passed in should be a valid return for a validity query.
///
/// Returns two elements, the first element contains `Some(tuple)` if the key should be included
/// in the return set and `None` otherwise,
/// the second element gives the next binary key for the seek to be used as an inclusive
/// lower bound.
pub fn check_key_for_validity(key: &[u8], valid_at: ValidityTs) -> (Option<Tuple>, Vec<u8>) {
let mut decoded = decode_tuple_from_key(key);
pub fn check_key_for_validity(key: &[u8], valid_at: ValidityTs, size_hint: Option<usize>) -> (Option<Tuple>, Vec<u8>) {
let mut decoded = decode_tuple_from_key(key, size_hint.unwrap_or(DEFAULT_SIZE_HINT));
let rel_id = RelationId::raw_decode(key);
let vld = match decoded.last().unwrap() {
DataValue::Validity(vld) => vld,
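
Not part of the diff: the hunk above threads a size hint through tuple decoding so the result vector can be allocated once at roughly the right capacity instead of growing as values are pushed; callers that know the schema (the Db and RelationHandle hunks further down) pass the relation arity, and everything else falls back to DEFAULT_SIZE_HINT. A minimal sketch of the idea, using a stand-in decode_one rather than the crate's DataValue::decode_from_key:

// Pre-sizing the output vector avoids repeated reallocation while decoding
// an estimable number of values (e.g. the relation's arity).
fn decode_one(input: &[u8]) -> (u8, &[u8]) {
    // Stand-in for DataValue::decode_from_key: consume one byte, return the rest.
    (input[0], &input[1..])
}

fn decode_all(mut remaining: &[u8], size_hint: usize) -> Vec<u8> {
    // With a good hint, this is typically the only allocation the loop makes.
    let mut ret = Vec::with_capacity(size_hint);
    while !remaining.is_empty() {
        let (val, next) = decode_one(remaining);
        ret.push(val);
        remaining = next;
    }
    ret
}

fn main() {
    assert_eq!(decode_all(&[1, 2, 3], 3), vec![1, 2, 3]);
}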

@ -17,7 +17,9 @@ use std::hash::{Hash, Hasher};
use crate::data::relation::VecElementType;
use ordered_float::OrderedFloat;
use regex::Regex;
use serde::{Deserialize, Deserializer, Serialize};
use serde::de::{SeqAccess, Visitor};
use serde::ser::SerializeTuple;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use sha2::digest::FixedOutput;
use sha2::{Digest, Sha256};
use smartstring::{LazyCompact, SmartString};
@ -159,7 +161,7 @@ pub enum DataValue {
}
/// Vector of floating numbers
#[derive(Debug, Clone, serde_derive::Serialize, serde_derive::Deserialize)]
#[derive(Debug, Clone)]
pub enum Vector {
/// 32-bit float array
F32(Array1<f32>),
@ -167,6 +169,103 @@ pub enum Vector {
F64(Array1<f64>),
}
struct VecBytes<'a>(&'a [u8]);
impl serde::Serialize for VecBytes<'_> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_bytes(self.0)
}
}
impl serde::Serialize for Vector {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut state = serializer.serialize_tuple(2)?;
match self {
Vector::F32(a) => {
state.serialize_element(&0u8)?;
let arr = a.as_slice().unwrap();
let len = arr.len() * std::mem::size_of::<f32>();
let ptr = arr.as_ptr() as *const u8;
let bytes = unsafe { std::slice::from_raw_parts(ptr, len) };
state.serialize_element(&VecBytes(bytes))?;
}
Vector::F64(a) => {
state.serialize_element(&1u8)?;
let arr = a.as_slice().unwrap();
let len = arr.len() * std::mem::size_of::<f64>();
let ptr = arr.as_ptr() as *const u8;
let bytes = unsafe { std::slice::from_raw_parts(ptr, len) };
state.serialize_element(&VecBytes(bytes))?;
}
}
state.end()
}
}
impl<'de> serde::Deserialize<'de> for Vector {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_tuple(2, VectorVisitor)
}
}
struct VectorVisitor;
impl<'de> Visitor<'de> for VectorVisitor {
type Value = Vector;
fn expecting(&self, formatter: &mut Formatter<'_>) -> std::fmt::Result {
formatter.write_str("vector representation")
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
let tag: u8 = seq
.next_element()?
.ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
let bytes: &[u8] = seq
.next_element()?
.ok_or_else(|| serde::de::Error::invalid_length(1, &self))?;
match tag {
0u8 => {
let len = bytes.len() / std::mem::size_of::<f32>();
let mut v = vec![];
v.reserve_exact(len);
let ptr = v.as_mut_ptr() as *mut u8;
unsafe {
std::ptr::copy_nonoverlapping(bytes.as_ptr(), ptr, bytes.len());
v.set_len(len);
}
Ok(Vector::F32(Array1::from(v)))
}
1u8 => {
let len = bytes.len() / std::mem::size_of::<f64>();
let mut v = vec![];
v.reserve_exact(len);
let ptr = v.as_mut_ptr() as *mut u8;
unsafe {
std::ptr::copy_nonoverlapping(bytes.as_ptr(), ptr, bytes.len());
v.set_len(len);
}
Ok(Vector::F64(Array1::from(v)))
}
_ => Err(serde::de::Error::invalid_value(
serde::de::Unexpected::Unsigned(tag as u64),
&self,
)),
}
}
}
impl Vector {
/// Get the length of the vector
pub fn len(&self) -> usize {
@ -459,16 +558,14 @@ impl Display for DataValue {
.field("timestamp", &v.timestamp.0)
.field("retracted", &v.is_assert)
.finish(),
DataValue::Vec(a) => {
match a {
Vector::F32(a) => {
write!(f, "vec({:?})", a.to_vec())
}
Vector::F64(a) => {
write!(f, "vec({:?}, \"F64\")", a.to_vec())
}
DataValue::Vec(a) => match a {
Vector::F32(a) => {
write!(f, "vec({:?})", a.to_vec())
}
}
Vector::F64(a) => {
write!(f, "vec({:?}, \"F64\")", a.to_vec())
}
},
}
}
}
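
Not part of the diff: the manual Serialize/Deserialize for Vector above replaces the derived implementations with a (type tag, raw bytes) pair, viewing the float slice as bytes on the way out and copying the bytes back into a pre-sized buffer on the way in. A standalone sketch of the same reinterpretation trick for Vec<f32> (not the crate's code, and it assumes producer and consumer agree on endianness):

use std::mem::size_of;

// View a contiguous &[f32] as raw bytes without copying.
fn f32s_as_bytes(v: &[f32]) -> &[u8] {
    unsafe { std::slice::from_raw_parts(v.as_ptr() as *const u8, v.len() * size_of::<f32>()) }
}

// Rebuild a Vec<f32> by copying the bytes into an aligned, pre-sized buffer.
fn bytes_to_f32s(bytes: &[u8]) -> Vec<f32> {
    let len = bytes.len() / size_of::<f32>();
    let mut out: Vec<f32> = Vec::with_capacity(len);
    unsafe {
        std::ptr::copy_nonoverlapping(bytes.as_ptr(), out.as_mut_ptr() as *mut u8, bytes.len());
        out.set_len(len);
    }
    out
}

fn main() {
    let original = vec![1.0f32, 2.5, -3.75];
    let bytes = f32s_as_bytes(&original);
    assert_eq!(bytes_to_f32s(bytes), original);
}

Compared with running every float through serde individually, each vector is written as a single tagged byte string, which matters most for large embeddings like the 1536-dimensional vectors exercised in the updated test further down.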

@ -382,6 +382,7 @@ impl<'s, S: Storage<'s>> Db<S> {
let mut ret: BTreeMap<String, NamedRows> = BTreeMap::new();
for rel in relations {
let handle = tx.get_relation(rel.as_ref(), false)?;
let size_hint = handle.metadata.keys.len() + handle.metadata.non_keys.len();
if handle.access_level < AccessLevel::ReadOnly {
bail!(InsufficientAccessLevel(
@ -412,7 +413,7 @@ impl<'s, S: Storage<'s>> Db<S> {
let mut rows = vec![];
for data in tx.store_tx.range_scan(&start, &end) {
let (k, v) = data?;
let tuple = decode_tuple_from_kv(&k, &v);
let tuple = decode_tuple_from_kv(&k, &v, Some(size_hint));
rows.push(tuple);
}
let headers = cols.iter().map(|col| col.to_string()).collect_vec();

@ -22,7 +22,7 @@ use priority_queue::PriorityQueue;
use rand::Rng;
use smartstring::{LazyCompact, SmartString};
use std::cmp::{max, Reverse};
use std::collections::{BTreeMap, BTreeSet};
use rustc_hash::{FxHashMap, FxHashSet};
#[derive(Debug, Clone, PartialEq, serde_derive::Serialize, serde_derive::Deserialize)]
pub(crate) struct HnswIndexManifest {
@ -55,7 +55,7 @@ impl HnswIndexManifest {
type CompoundKey = (Tuple, usize, i32);
struct VectorCache {
cache: BTreeMap<CompoundKey, Vector>,
cache: FxHashMap<CompoundKey, Vector>,
distance: HnswDistance,
}
@ -160,7 +160,7 @@ impl<'a> SessionTx<'a> {
idx_table: &RelationHandle,
) -> Result<()> {
let mut vec_cache = VectorCache {
cache: BTreeMap::new(),
cache: FxHashMap::default(),
distance: manifest.distance,
};
let tuple_key = &tuple[..orig_table.metadata.keys.len()];
@ -406,11 +406,11 @@ impl<'a> SessionTx<'a> {
orig_table,
vec_cache,
)?;
let mut old_candidate_set = BTreeSet::new();
let mut old_candidate_set = FxHashSet::default();
for (old, _) in &candidates {
old_candidate_set.insert(old.clone());
}
let mut new_candidate_set = BTreeSet::new();
let mut new_candidate_set = FxHashSet::default();
for (new, _) in &new_candidates {
new_candidate_set.insert(new.clone());
}
@ -558,7 +558,7 @@ impl<'a> SessionTx<'a> {
found_nn: &mut PriorityQueue<CompoundKey, OrderedFloat<f64>>,
vec_cache: &mut VectorCache,
) -> Result<()> {
let mut visited: BTreeSet<CompoundKey> = BTreeSet::new();
let mut visited: FxHashSet<CompoundKey> = FxHashSet::default();
// min queue
let mut candidates: PriorityQueue<CompoundKey, Reverse<OrderedFloat<f64>>> =
PriorityQueue::new();
@ -738,7 +738,7 @@ impl<'a> SessionTx<'a> {
) -> Result<()> {
let mut prefix = vec![DataValue::from(0)];
prefix.extend_from_slice(&tuple[0..orig_table.metadata.keys.len()]);
let candidates: BTreeSet<_> = idx_table
let candidates: FxHashSet<_> = idx_table
.scan_prefix(self, &prefix)
.filter_map(|t| match t {
Ok(t) => Some({
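
Not part of the diff: the HNSW hunks above also swap the bookkeeping collections (visited sets, candidate sets, the per-operation vector cache) from BTreeMap/BTreeSet to FxHashMap/FxHashSet from the newly added rustc-hash dependency; the Fx hasher is generally faster than the standard library's default hasher, and ordering is not needed for these internal keys. A minimal usage sketch with illustrative key types, not the actual CompoundKey:

use rustc_hash::{FxHashMap, FxHashSet};

fn main() {
    // Construct with default(), as the diff does; the Fx hasher trades
    // DoS resistance for speed, which is fine for internal, trusted keys.
    let mut visited: FxHashSet<(u64, usize)> = FxHashSet::default();
    visited.insert((42, 0));

    let mut vec_cache: FxHashMap<(u64, usize), Vec<f32>> = FxHashMap::default();
    vec_cache.insert((42, 0), vec![0.1, 0.2, 0.3]);

    assert!(visited.contains(&(42, 0)));
    assert_eq!(vec_cache[&(42, 0)].len(), 3);
}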

@ -360,12 +360,12 @@ impl RelationHandle {
Ok(tx
.temp_store_tx
.get(&key_data, false)?
.map(|val_data| decode_tuple_from_kv(&key_data, &val_data)))
.map(|val_data| decode_tuple_from_kv(&key_data, &val_data, Some(self.arity()))))
} else {
Ok(tx
.store_tx
.get(&key_data, false)?
.map(|val_data| decode_tuple_from_kv(&key_data, &val_data)))
.map(|val_data| decode_tuple_from_kv(&key_data, &val_data, Some(self.arity()))))
}
}
@ -465,11 +465,13 @@ impl RelationHandle {
}
}
const DEFAULT_SIZE_HINT: usize = 16;
/// Decode tuple from key-value pairs. Used for customizing storage
/// in trait [`StoreTx`](crate::StoreTx).
#[inline]
pub fn decode_tuple_from_kv(key: &[u8], val: &[u8]) -> Tuple {
let mut tup = decode_tuple_from_key(key);
pub fn decode_tuple_from_kv(key: &[u8], val: &[u8], size_hint: Option<usize>) -> Tuple {
let mut tup = decode_tuple_from_key(key, size_hint.unwrap_or(DEFAULT_SIZE_HINT));
extend_tuple_from_v(&mut tup, val);
tup
}
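
Not part of the diff: decode_tuple_from_kv now takes an Option<usize> hint, so storage backends that cannot cheaply know the arity pass None and fall back to DEFAULT_SIZE_HINT (16), while callers with schema information pass the exact arity. A small sketch of that optional-hint pattern, with hypothetical names:

// Hypothetical illustration of an optional size hint with a fixed fallback,
// mirroring decode_tuple_from_kv(key, val, size_hint).
const DEFAULT_SIZE_HINT: usize = 16;

fn decode_row(values: &[i64], size_hint: Option<usize>) -> Vec<i64> {
    // A known arity gives an exact allocation; None still avoids growing from zero.
    let mut row = Vec::with_capacity(size_hint.unwrap_or(DEFAULT_SIZE_HINT));
    row.extend_from_slice(values);
    row
}

fn main() {
    let with_hint = decode_row(&[1, 2, 3], Some(3));
    let without_hint = decode_row(&[1, 2, 3], None);
    assert_eq!(with_hint, without_hint);
}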

@ -854,7 +854,7 @@ fn test_vec_index() {
fn test_insertions() {
let db = DbInstance::new("mem", "", "").unwrap();
db.run_script(
r":create a {k => v: <F32; 100> default rand_vec(100)}",
r":create a {k => v: <F32; 1536> default rand_vec(1536)}",
Default::default(),
)
.unwrap();
@ -863,7 +863,7 @@ fn test_insertions() {
db.run_script(r"?[k, v] := *a{k, v}", Default::default())
.unwrap();
db.run_script(
r"::hnsw create a:i {fields: [v], dim: 100, ef: 16, m: 32}",
r"::hnsw create a:i {fields: [v], dim: 1536, ef: 16, m: 32}",
Default::default(),
)
.unwrap();
@ -871,7 +871,7 @@ fn test_insertions() {
.unwrap();
db.run_script(r"?[k] <- [[1]] :put a {k}", Default::default())
.unwrap();
db.run_script(r"?[k] := k in int_range(10000) :put a {k}", Default::default()).unwrap();
db.run_script(r"?[k] := k in int_range(100000) :put a {k}", Default::default()).unwrap();
let res = db
.run_script(
r"?[dist, k] := ~a:i{k | query: v, bind_distance: dist, k:10, ef: 5}, *a{k: 8888, v}",

@ -191,7 +191,7 @@ impl<'s> StoreTx<'s> for MemTx<'s> {
match self {
MemTx::Reader(rdr) => Box::new(
rdr.range(lower.to_vec()..upper.to_vec())
.map(|(k, v)| Ok(decode_tuple_from_kv(k, v))),
.map(|(k, v)| Ok(decode_tuple_from_kv(k, v, None))),
),
MemTx::Writer(wtr, cache) => Box::new(CacheIter {
change_iter: cache.range(lower.to_vec()..upper.to_vec()).fuse(),
@ -215,6 +215,7 @@ impl<'s> StoreTx<'s> for MemTx<'s> {
upper: upper.to_vec(),
valid_at,
next_bound: lower.to_vec(),
size_hint: None,
}
.map(Ok),
),
@ -389,24 +390,24 @@ impl CacheIter<'_> {
let (k, cv) = self.change_cache.take().unwrap();
match cv {
None => continue,
Some(v) => return Ok(Some(decode_tuple_from_kv(k, v))),
Some(v) => return Ok(Some(decode_tuple_from_kv(k, v, None))),
}
}
(None, Some(_)) => {
let (k, v) = self.db_cache.take().unwrap();
return Ok(Some(decode_tuple_from_kv(k, v)));
return Ok(Some(decode_tuple_from_kv(k, v, None)));
}
(Some((ck, _)), Some((dk, _))) => match ck.cmp(dk) {
Ordering::Less => {
let (k, sv) = self.change_cache.take().unwrap();
match sv {
None => continue,
Some(v) => return Ok(Some(decode_tuple_from_kv(k, v))),
Some(v) => return Ok(Some(decode_tuple_from_kv(k, v, None))),
}
}
Ordering::Greater => {
let (k, v) = self.db_cache.take().unwrap();
return Ok(Some(decode_tuple_from_kv(k, v)));
return Ok(Some(decode_tuple_from_kv(k, v, None)));
}
Ordering::Equal => {
self.db_cache.take();
@ -433,6 +434,7 @@ pub(crate) struct SkipIterator<'a> {
pub(crate) upper: Vec<u8>,
pub(crate) valid_at: ValidityTs,
pub(crate) next_bound: Vec<u8>,
pub(crate) size_hint: Option<usize>,
}
impl<'a> Iterator for SkipIterator<'a> {
@ -450,7 +452,7 @@ impl<'a> Iterator for SkipIterator<'a> {
match nxt {
None => return None,
Some((candidate_key, candidate_val)) => {
let (ret, nxt_bound) = check_key_for_validity(candidate_key, self.valid_at);
let (ret, nxt_bound) = check_key_for_validity(candidate_key, self.valid_at, self.size_hint);
self.next_bound = nxt_bound;
if let Some(mut nk) = ret {
extend_tuple_from_v(&mut nk, candidate_val);
@ -493,7 +495,7 @@ impl<'a> Iterator for SkipDualIterator<'a> {
(None, None) => return None,
(None, Some((delta_key, maybe_delta_val))) => match maybe_delta_val {
None => {
let (_, nxt_seek) = check_key_for_validity(delta_key, self.valid_at);
let (_, nxt_seek) = check_key_for_validity(delta_key, self.valid_at, None);
self.next_bound = nxt_seek;
continue;
}
@ -507,7 +509,7 @@ impl<'a> Iterator for SkipDualIterator<'a> {
match maybe_delta_val {
None => {
let (_, nxt_seek) =
check_key_for_validity(delta_key, self.valid_at);
check_key_for_validity(delta_key, self.valid_at, None);
self.next_bound = nxt_seek;
continue;
}
@ -516,7 +518,7 @@ impl<'a> Iterator for SkipDualIterator<'a> {
}
}
};
let (ret, nxt_bound) = check_key_for_validity(candidate_key, self.valid_at);
let (ret, nxt_bound) = check_key_for_validity(candidate_key, self.valid_at, None);
self.next_bound = nxt_bound;
if let Some(mut nk) = ret {
extend_tuple_from_v(&mut nk, candidate_val);

@ -101,7 +101,7 @@ pub trait StoreTx<'s>: Sync {
's: 'a,
{
let it = self.range_scan(lower, upper);
Box::new(it.map_ok(|(k, v)| decode_tuple_from_kv(&k, &v)))
Box::new(it.map_ok(|(k, v)| decode_tuple_from_kv(&k, &v, None)))
}
/// Scan on a range with a certain validity.

@ -271,7 +271,7 @@ impl RocksDbIterator {
None
} else {
// upper bound is exclusive
Some(decode_tuple_from_kv(k_slice, v_slice))
Some(decode_tuple_from_kv(k_slice, v_slice, None))
}
}
})
@ -305,7 +305,7 @@ impl RocksDbSkipIterator {
return Ok(None);
}
let (ret, nxt_bound) = check_key_for_validity(k_slice, self.valid_at);
let (ret, nxt_bound) = check_key_for_validity(k_slice, self.valid_at, None);
self.next_bound = nxt_bound;
if let Some(mut tup) = ret {
extend_tuple_from_v(&mut tup, v_slice);

@ -344,7 +344,7 @@ impl<'l> Iterator for TupleIter<'l> {
Ok(State::Row) => {
let k = self.0.read::<Vec<u8>, _>(0).unwrap();
let v = self.0.read::<Vec<u8>, _>(1).unwrap();
let tuple = decode_tuple_from_kv(&k, &v);
let tuple = decode_tuple_from_kv(&k, &v, None);
Some(Ok(tuple))
}
Err(err) => Some(Err(miette!(err))),
@ -388,7 +388,7 @@ impl<'l> SkipIter<'l> {
State::Done => return Ok(None),
State::Row => {
let k = self.stmt.read::<Vec<u8>, _>(0).unwrap();
let (ret, nxt_bound) = check_key_for_validity(&k, self.valid_at);
let (ret, nxt_bound) = check_key_for_validity(&k, self.valid_at, None);
self.next_bound = nxt_bound;
if let Some(mut tup) = ret {
let v = self.stmt.read::<Vec<u8>, _>(1).unwrap();

@ -95,7 +95,7 @@ impl<'s> StoreTx<'s> for TempTx {
Box::new(
self.store
.range(lower.to_vec()..upper.to_vec())
.map(|(k, v)| Ok(decode_tuple_from_kv(k, v))),
.map(|(k, v)| Ok(decode_tuple_from_kv(k, v, None))),
)
}
@ -111,6 +111,7 @@ impl<'s> StoreTx<'s> for TempTx {
upper: upper.to_vec(),
valid_at,
next_bound: lower.to_vec(),
size_hint: None,
}
.map(Ok),
)
