diff --git a/Cargo.lock b/Cargo.lock index f125e7da..5d717d50 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -353,6 +353,15 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "block-buffer" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4152116fd6e9dadb291ae18fc1ec3575ed6d84c29642d97890f4b4a3417297e4" +dependencies = [ + "generic-array", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -705,6 +714,7 @@ dependencies = [ "serde_bytes", "serde_derive", "serde_json", + "sha2 0.9.9", "sled", "smallvec", "smartstring", @@ -1034,13 +1044,22 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "digest" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3dd60d1080a57a05ab032377049e0591415d2b31afd7028356dbf3cc6dcb066" +dependencies = [ + "generic-array", +] + [[package]] name = "digest" version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8168378f4e5023e7218c89c891c0fd8ecdb5e5e4f18cb78f38cf245dd021e76f" dependencies = [ - "block-buffer", + "block-buffer 0.10.4", "crypto-common", ] @@ -2346,6 +2365,12 @@ version = "1.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3" +[[package]] +name = "opaque-debug" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" + [[package]] name = "openssl" version = "0.10.50" @@ -2547,7 +2572,7 @@ checksum = "6733073c7cff3d8459fda0e42f13a047870242aed8b509fe98000928975f359e" dependencies = [ "once_cell", "pest", - "sha2", + "sha2 0.10.6", ] [[package]] @@ -3413,6 +3438,19 @@ dependencies = [ "serde", ] +[[package]] +name = "sha2" +version = "0.9.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d58a1e1bf39749807d89cf2d98ac2dfa0ff1cb3faa38fbb64dd88ac8013d800" +dependencies = [ + "block-buffer 0.9.0", + "cfg-if 1.0.0", + "cpufeatures", + "digest 0.9.0", + "opaque-debug", +] + [[package]] name = "sha2" version = "0.10.6" @@ -3421,7 +3459,7 @@ checksum = "82e6b795fe2e3b1e845bafcb27aa35405c4d47cdfc92af5fc8d3002f76cebdc0" dependencies = [ "cfg-if 1.0.0", "cpufeatures", - "digest", + "digest 0.10.6", ] [[package]] diff --git a/cozo-core/Cargo.toml b/cozo-core/Cargo.toml index 914284a5..8dd0a527 100644 --- a/cozo-core/Cargo.toml +++ b/cozo-core/Cargo.toml @@ -127,4 +127,5 @@ sqlite3-src = { version = "0.4.0", optional = true, features = ["bundled"] } js-sys = { version = "0.3.60", optional = true } graph = { version = "0.3.0", optional = true } crossbeam = "0.8.2" -ndarray = { version = "0.15.6", features = ["serde"] } \ No newline at end of file +ndarray = { version = "0.15.6", features = ["serde"] } +sha2 = "0.9.8" \ No newline at end of file diff --git a/cozo-core/src/data/value.rs b/cozo-core/src/data/value.rs index 2d7032db..9b86ba37 100644 --- a/cozo-core/src/data/value.rs +++ b/cozo-core/src/data/value.rs @@ -156,7 +156,7 @@ pub enum DataValue { Bot, } -#[derive(Clone, serde_derive::Serialize, serde_derive::Deserialize)] +#[derive(Debug, Clone, serde_derive::Serialize, serde_derive::Deserialize)] pub enum Vector { F32(Array1), F64(Array1), diff --git a/cozo-core/src/parse/sys.rs b/cozo-core/src/parse/sys.rs index 0fcf5946..e997fcb6 100644 --- a/cozo-core/src/parse/sys.rs +++ b/cozo-core/src/parse/sys.rs @@ -49,28 +49,14 @@ pub(crate) struct HnswIndexConfig { pub(crate) vec_dim: usize, pub(crate) dtype: VecElementType, pub(crate) vec_fields: Vec>, - pub(crate) tag_fields: Vec>, pub(crate) distance: HnswDistance, pub(crate) ef_construction: usize, - pub(crate) max_elements: usize, + pub(crate) m_neighbours: usize, pub(crate) index_filter: Option, + pub(crate) extend_candidates: bool, + pub(crate) keep_pruned_connections: bool } -#[derive(Debug, Clone, PartialEq, Eq, Hash, serde_derive::Serialize, serde_derive::Deserialize)] -pub(crate) struct HnswIndexManifest { - pub(crate) base_relation: SmartString, - pub(crate) index_name: SmartString, - pub(crate) vec_dim: usize, - pub(crate) dtype: VecElementType, - pub(crate) vec_fields: Vec, - pub(crate) tag_fields: Vec, - pub(crate) distance: HnswDistance, - pub(crate) ef_construction: usize, - pub(crate) max_elements: usize, - pub(crate) index_filter: Option, -} - - #[derive( Debug, Clone, Copy, PartialEq, Eq, Hash, serde_derive::Serialize, serde_derive::Deserialize, )] @@ -80,17 +66,6 @@ pub(crate) enum HnswDistance { Cosine, } -pub(crate) struct HnswKnnQueryOptions { - k: usize, - ef: usize, - max_distance: f64, - min_margin: f64, - auto_margin_factor: Option, - bind_field: Option, - bind_distance: Option, - bind_vector: Option, -} - #[derive(Debug, Diagnostic, Error)] #[error("Cannot interpret {0} as process ID")] #[diagnostic(code(parser::not_proc_id))] @@ -211,13 +186,13 @@ pub(crate) fn parse_sys( let mut vec_dim = 0; let mut dtype = VecElementType::F32; let mut vec_fields = vec![]; - let mut tag_fields = vec![]; let mut distance = HnswDistance::L2; let mut ef_construction = 0; let mut max_elements = 0; let mut index_filter = None; + let mut extend_candidates = false; + let mut keep_pruned_connections = false; - // TODO this is a bit of a mess for opt_pair in inner { let mut opt_inner = opt_pair.into_inner(); let opt_name = opt_inner.next().unwrap(); @@ -242,11 +217,7 @@ pub(crate) fn parse_sys( let fields = build_expr(opt_val, &Default::default())?; vec_fields = fields.to_var_list()?; } - "tags" => { - let fields = build_expr(opt_val, &Default::default())?; - tag_fields = fields.to_var_list()?; - } - "distance" => { + "distance" | "dist" => { distance = match opt_val.as_str() { "L2" => HnswDistance::L2, "IP" => HnswDistance::InnerProduct, @@ -259,13 +230,13 @@ pub(crate) fn parse_sys( } } } - "ef_construction" => { + "ef_construction" | "ef" => { ef_construction = opt_val .as_str() .parse() .map_err(|e| miette!("Invalid ef_construction: {}", e))?; } - "max_elements" => { + "m_neighbours" | "m" | "M" => { max_elements = opt_val .as_str() .parse() @@ -274,6 +245,12 @@ pub(crate) fn parse_sys( "filter" => { index_filter = Some(opt_val.as_str().to_string()); } + "extend_candidates" => { + extend_candidates = opt_val.as_str() == "true"; + } + "keep_pruned_connections" => { + keep_pruned_connections = opt_val.as_str() == "true"; + } _ => return Err(miette!("Invalid option: {}", opt_name.as_str())), } } @@ -283,11 +260,12 @@ pub(crate) fn parse_sys( vec_dim, dtype, vec_fields, - tag_fields, distance, ef_construction, - max_elements, + m_neighbours: max_elements, index_filter, + extend_candidates, + keep_pruned_connections, }) } Rule::index_drop => { diff --git a/cozo-core/src/query/stored.rs b/cozo-core/src/query/stored.rs index e504c582..f1ffc8e7 100644 --- a/cozo-core/src/query/stored.rs +++ b/cozo-core/src/query/stored.rs @@ -153,7 +153,7 @@ impl<'a> SessionTx<'a> { let mut old_tuples: Vec = vec![]; for tuple in res_iter { - let extracted = key_extractors + let extracted: Vec = key_extractors .iter() .map(|ex| ex.extract_data(&tuple, cur_vld)) .try_collect()?; @@ -309,7 +309,7 @@ impl<'a> SessionTx<'a> { key_extractors.extend(val_extractors); for tuple in res_iter { - let extracted = key_extractors + let extracted: Vec = key_extractors .iter() .map(|ex| ex.extract_data(&tuple, cur_vld)) .try_collect()?; @@ -360,7 +360,7 @@ impl<'a> SessionTx<'a> { )?; for tuple in res_iter { - let extracted = key_extractors + let extracted: Vec = key_extractors .iter() .map(|ex| ex.extract_data(&tuple, cur_vld)) .try_collect()?; @@ -411,7 +411,7 @@ impl<'a> SessionTx<'a> { key_extractors.extend(val_extractors); for tuple in res_iter { - let extracted = key_extractors + let extracted: Vec = key_extractors .iter() .map(|ex| ex.extract_data(&tuple, cur_vld)) .try_collect()?; diff --git a/cozo-core/src/runtime/hnsw.rs b/cozo-core/src/runtime/hnsw.rs index cbda0519..02883666 100644 --- a/cozo-core/src/runtime/hnsw.rs +++ b/cozo-core/src/runtime/hnsw.rs @@ -7,30 +7,392 @@ */ use crate::data::expr::{eval_bytecode_pred, Bytecode}; -use crate::data::tuple::Tuple; +use crate::data::relation::VecElementType; +use crate::data::tuple::{decode_tuple_from_key, Tuple}; use crate::data::value::Vector; -use crate::parse::sys::HnswIndexManifest; +use crate::parse::sys::HnswDistance; use crate::runtime::relation::RelationHandle; use crate::runtime::transact::SessionTx; -use crate::DataValue; -use miette::Result; +use crate::{decode_tuple_from_kv, DataValue, Symbol}; +use miette::{bail, Result}; +use ordered_float::OrderedFloat; +use priority_queue::PriorityQueue; +use rand::Rng; use smartstring::{LazyCompact, SmartString}; +use std::cmp::{max, Reverse}; +use std::collections::BTreeSet; + +#[derive(Debug, Clone, PartialEq, serde_derive::Serialize, serde_derive::Deserialize)] +pub(crate) struct HnswIndexManifest { + pub(crate) base_relation: SmartString, + pub(crate) index_name: SmartString, + pub(crate) vec_dim: usize, + pub(crate) dtype: VecElementType, + pub(crate) vec_fields: Vec, + pub(crate) distance: HnswDistance, + pub(crate) ef_construction: usize, + pub(crate) m_neighbours: usize, + pub(crate) m_max: usize, + pub(crate) m_max0: usize, + pub(crate) level_multiplier: f64, + pub(crate) index_filter: Option, + pub(crate) extend_candidates: bool, + pub(crate) keep_pruned_connections: bool, +} + +pub(crate) struct HnswKnnQueryOptions { + k: usize, + ef: usize, + max_distance: f64, + min_margin: f64, + auto_margin_factor: Option, + bind_field: Option, + bind_distance: Option, + bind_vector: Option, +} + +impl HnswIndexManifest { + fn get_random_level(&self) -> i64 { + let mut rng = rand::thread_rng(); + let uniform_num: f64 = rng.gen_range(0.0..1.0); + let r = -uniform_num.ln() * self.level_multiplier; + // the level is the largest integer smaller than r + -(r.floor() as i64) + } + fn get_vector(&self, tuple: &Tuple, idx: usize, sub_idx: i32) -> Result { + let field = tuple.get(idx).unwrap(); + if sub_idx >= 0 { + match field { + DataValue::List(l) => match l.get(sub_idx as usize) { + Some(DataValue::Vec(v)) => Ok(v.clone()), + _ => bail!( + "Cannot extract vector from {} for sub index {}", + field, + sub_idx + ), + }, + _ => bail!("Cannot interpret {} as list", field), + } + } else { + match field { + DataValue::Vec(v) => Ok(v.clone()), + _ => bail!("Cannot interpret {} as vector", field), + } + } + } + fn get_distance(&self, q: &Vector, tuple: &Tuple, idx: usize, sub_idx: i32) -> Result { + let field = tuple.get(idx).unwrap(); + let target = if sub_idx >= 0 { + match field { + DataValue::List(l) => match l.get(sub_idx as usize) { + Some(DataValue::Vec(v)) => v, + _ => bail!( + "Cannot extract vector from {} for sub index {}", + field, + sub_idx + ), + }, + _ => bail!("Cannot interpret {} as list", field), + } + } else { + match field { + DataValue::Vec(v) => v, + _ => bail!("Cannot interpret {} as vector", field), + } + }; + Ok(match self.distance { + HnswDistance::L2 => match (q, target) { + (Vector::F32(a), Vector::F32(b)) => { + let diff = a - b; + diff.dot(&diff) as f64 + } + (Vector::F64(a), Vector::F64(b)) => { + let diff = a - b; + diff.dot(&diff) + } + _ => bail!( + "Cannot compute L2 distance between {:?} and {:?}", + q, + target + ), + }, + HnswDistance::Cosine => match (q, target) { + (Vector::F32(a), Vector::F32(b)) => { + let a_norm = a.dot(a) as f64; + let b_norm = b.dot(b) as f64; + let dot = a.dot(b) as f64; + 1.0 - dot / (a_norm * b_norm).sqrt() + } + (Vector::F64(a), Vector::F64(b)) => { + let a_norm = a.dot(a) as f64; + let b_norm = b.dot(b) as f64; + let dot = a.dot(b); + 1.0 - dot / (a_norm * b_norm).sqrt() + } + _ => bail!( + "Cannot compute cosine distance between {:?} and {:?}", + q, + target + ), + }, + HnswDistance::InnerProduct => match (q, target) { + (Vector::F32(a), Vector::F32(b)) => { + let dot = a.dot(b); + 1. - dot as f64 + } + (Vector::F64(a), Vector::F64(b)) => { + let dot = a.dot(b); + 1. - dot as f64 + } + _ => bail!( + "Cannot compute inner product between {:?} and {:?}", + q, + target + ), + }, + }) + } +} impl<'a> SessionTx<'a> { fn hnsw_put_vector( &mut self, - vec: &Vector, + tuple: &Tuple, + q: &Vector, idx: usize, subidx: i32, + manifest: &HnswIndexManifest, orig_table: &RelationHandle, idx_table: &RelationHandle, - tags: &[SmartString] ) -> Result<()> { + let start_tuple = + idx_table.encode_key_for_store(&vec![DataValue::from(i64::MIN)], Default::default())?; + let end_tuple = + idx_table.encode_key_for_store(&vec![DataValue::from(1)], Default::default())?; + let ep_res = self.store_tx.range_scan(&start_tuple, &end_tuple).next(); + if let Some(ep) = ep_res { + let (ep_key_bytes, _) = ep?; + let ep_key_tuple = decode_tuple_from_key(&ep_key_bytes); + // bottom level since we are going up + let bottom_level = ep_key_tuple[0].get_int().unwrap(); + let ep_key = ep_key_tuple[1..orig_table.metadata.keys.len() + 1].to_vec(); + let ep_idx = ep_key_tuple[orig_table.metadata.keys.len() + 1] + .get_int() + .unwrap() as usize; + let ep_subidx = ep_key_tuple[orig_table.metadata.keys.len() + 2] + .get_int() + .unwrap() as i32; + let ep_distance = + self.hnsw_compare_vector(q, &ep_key, idx, subidx, manifest, orig_table)?; + let mut found_nn = PriorityQueue::new(); + found_nn.push((ep_key, ep_idx, ep_subidx), OrderedFloat(ep_distance)); + let target_level = manifest.get_random_level(); + if target_level < bottom_level { + // this becomes the entry point + self.hnsw_put_fresh_at_levels( + tuple, + idx, + subidx, + orig_table, + idx_table, + target_level, + bottom_level - 1, + )?; + } + for current_level in bottom_level..target_level { + self.hnsw_search_level( + q, + 1, + current_level, + manifest, + orig_table, + idx_table, + &mut found_nn, + )?; + } + for current_level in max(target_level, bottom_level)..=0 { + self.hnsw_search_level( + q, + manifest.ef_construction, + current_level, + manifest, + orig_table, + idx_table, + &mut found_nn, + )?; + // add bidirectional links to the nearest neighbors + todo!(); + // shrink links if necessary + todo!(); + } + } else { + // This is the first vector in the index. + let level = manifest.get_random_level(); + self.hnsw_put_fresh_at_levels(tuple, idx, subidx, orig_table, idx_table, level, 0)?; + } + Ok(()) + } + fn hnsw_compare_vector( + &self, + q: &Vector, + target_key: &[DataValue], + target_idx: usize, + target_subidx: i32, + manifest: &HnswIndexManifest, + orig_table: &RelationHandle, + ) -> Result { + let target_key_bytes = orig_table.encode_key_for_store(target_key, Default::default())?; + let bytes = match self.store_tx.get(&target_key_bytes, false)? { + Some(bytes) => bytes, + None => bail!("Indexed data not found, this signifies a bug in the index."), + }; + let target_tuple = decode_tuple_from_kv(&target_key_bytes, &bytes); + manifest.get_distance(q, &target_tuple, target_idx, target_subidx) + } + fn hnsw_select_neighbours_heuristic(&self) -> Result<()> { todo!() } - pub(crate) fn hnsw_put( + fn hnsw_search_level( + &self, + q: &Vector, + ef: usize, + cur_level: i64, + manifest: &HnswIndexManifest, + orig_table: &RelationHandle, + idx_table: &RelationHandle, + found_nn: &mut PriorityQueue<(Tuple, usize, i32), OrderedFloat>, + ) -> Result<()> { + let mut visited: BTreeSet<(Tuple, usize, i32)> = BTreeSet::new(); + let mut candidates: PriorityQueue<(Tuple, usize, i32), Reverse>> = + PriorityQueue::new(); + + for item in found_nn.iter() { + visited.insert(item.0.clone()); + candidates.push(item.0.clone(), Reverse(*item.1)); + } + + while let Some((candidate, Reverse(OrderedFloat(candidate_dist)))) = candidates.pop() { + let (_, OrderedFloat(furtherest_dist)) = found_nn.peek().unwrap(); + let furtherest_dist = *furtherest_dist; + if candidate_dist > furtherest_dist { + break; + } + // loop over each of the candidate's neighbors + for neighbour_triple in self.hnsw_get_neighbours( + candidate.0, + candidate.1, + candidate.2, + cur_level, + idx_table, + )? { + if visited.contains(&neighbour_triple) { + continue; + } + let neighbour_dist = self.hnsw_compare_vector( + q, + &neighbour_triple.0, + neighbour_triple.1, + neighbour_triple.2, + manifest, + orig_table, + )?; + let (_, OrderedFloat(cand_furtherest_dist)) = found_nn.peek().unwrap(); + if found_nn.len() < ef || neighbour_dist < *cand_furtherest_dist { + candidates.push( + neighbour_triple.clone(), + Reverse(OrderedFloat(neighbour_dist)), + ); + found_nn.push(neighbour_triple.clone(), OrderedFloat(neighbour_dist)); + if found_nn.len() > ef { + found_nn.pop(); + } + } + visited.insert(neighbour_triple); + } + } + + Ok(()) + } + fn hnsw_get_neighbours<'b>( + &'b self, + cand_key: Vec, + cand_idx: usize, + cand_sub_idx: i32, + level: i64, + idx_handle: &RelationHandle, + ) -> Result + 'b> { + let mut start_tuple = Vec::with_capacity(cand_key.len() + 3); + start_tuple.push(DataValue::from(level)); + start_tuple.extend_from_slice(&cand_key); + start_tuple.push(DataValue::from(cand_idx as i64)); + start_tuple.push(DataValue::from(cand_sub_idx as i64)); + let mut end_tuple = start_tuple.clone(); + end_tuple.push(DataValue::Bot); + let start_bytes = idx_handle.encode_key_for_store(&start_tuple, Default::default())?; + let end_bytes = idx_handle.encode_key_for_store(&end_tuple, Default::default())?; + Ok(self + .store_tx + .range_scan(&start_bytes, &end_bytes) + .filter_map(move |res| { + let (key, _value) = res.unwrap(); + let key_tuple = decode_tuple_from_key(&key); + let key_total_len = key_tuple.len(); + let key_idx = key_tuple[key_total_len - 2].get_int().unwrap() as usize; + let key_subidx = key_tuple[key_total_len - 1].get_int().unwrap() as i32; + let key_slice = key_tuple[cand_key.len() + 3..key_total_len - 2].to_vec(); + if key_slice == cand_key { + None + } else { + Some((key_slice, key_idx, key_subidx)) + } + })) + } + fn hnsw_put_fresh_at_levels( &mut self, - config: &HnswIndexManifest, + tuple: &Tuple, + idx: usize, + subidx: i32, + orig_table: &RelationHandle, + idx_table: &RelationHandle, + bottom_level: i64, + top_level: i64, + ) -> Result<()> { + let mut target_key = vec![DataValue::Null]; + let mut canary_key = vec![DataValue::from(1)]; + for _ in 0..2 { + for i in 0..orig_table.metadata.keys.len() { + target_key.push(tuple.get(i).unwrap().clone()); + canary_key.push(DataValue::Null); + } + target_key.push(DataValue::from(idx as i64)); + target_key.push(DataValue::from(subidx as i64)); + canary_key.push(DataValue::Null); + canary_key.push(DataValue::Null); + } + let target_value = [DataValue::from(0.0), DataValue::Null]; + let target_key_bytes = idx_table.encode_key_for_store(&target_key, Default::default())?; + + // canary value is for conflict detection: prevent the scenario of disconnected graphs at all levels + let canary_value = [ + DataValue::from(bottom_level), + DataValue::Bytes(target_key_bytes), + ]; + let canary_key_bytes = idx_table.encode_key_for_store(&canary_key, Default::default())?; + let canary_value_bytes = + idx_table.encode_val_for_store(&canary_value, Default::default())?; + self.store_tx.put(&canary_key_bytes, &canary_value_bytes)?; + + for cur_level in bottom_level..=top_level { + target_key[0] = DataValue::from(cur_level); + let key = idx_table.encode_key_for_store(&target_key, Default::default())?; + let val = idx_table.encode_val_for_store(&target_value, Default::default())?; + self.store_tx.put(&key, &val)?; + } + Ok(()) + } + pub(crate) fn hnsw_put( + &'a mut self, + manifest: &HnswIndexManifest, orig_table: &RelationHandle, idx_table: &RelationHandle, filter: Option<(&[Bytecode], &mut Vec)>, @@ -42,7 +404,7 @@ impl<'a> SessionTx<'a> { } } let mut extracted_vectors = vec![]; - for idx in &config.vec_fields { + for idx in &manifest.vec_fields { let val = tuple.get(*idx).unwrap(); if let DataValue::Vec(v) = val { extracted_vectors.push((v, *idx, -1 as i32)); @@ -57,34 +419,37 @@ impl<'a> SessionTx<'a> { if extracted_vectors.is_empty() { return Ok(false); } - let mut extracted_tags: Vec> = vec![]; - for tag_idx in &config.tag_fields { - let tag_field = tuple.get(*tag_idx).unwrap(); - if let Some(s) = tag_field.get_str() { - extracted_tags.push(SmartString::from(s)); - } else if let DataValue::List(l) = tag_field { - for tag in l { - if let Some(s) = tag.get_str() { - extracted_tags.push(SmartString::from(s)); - } - } - } - } for (vec, idx, sub) in extracted_vectors { - self.hnsw_put_vector(vec, idx, sub, orig_table, idx_table, &extracted_tags)?; + self.hnsw_put_vector(&tuple, vec, idx, sub, manifest, orig_table, idx_table)?; } Ok(true) } - pub(crate) fn hnsw_remove( - &mut self, - config: &HnswIndexManifest, - orig_table: &RelationHandle, - idx_table: &RelationHandle, - tuple: &Tuple, - ) -> Result<()> { + pub(crate) fn hnsw_remove(&mut self) -> Result<()> { todo!() } - pub(crate) fn hnsw_knn(&self, node: u64, k: usize) -> Vec<(u64, f32)> { + pub(crate) fn hnsw_knn(&self) -> Result<()> { todo!() } } + +#[cfg(test)] +mod tests { + use rand::Rng; + use std::collections::BTreeMap; + + #[test] + fn test_random_level() { + let m = 20; + let mult = 1. / (m as f64).ln(); + let mut rng = rand::thread_rng(); + let mut collected = BTreeMap::new(); + for _ in 0..10000 { + let uniform_num: f64 = rng.gen_range(0.0..1.0); + let r = -uniform_num.ln() * mult; + // the level is the largest integer smaller than r + let level = -(r.floor() as i64); + collected.entry(level).and_modify(|x| *x += 1).or_insert(1); + } + println!("{:?}", collected); + } +} diff --git a/cozo-core/src/runtime/relation.rs b/cozo-core/src/runtime/relation.rs index ace0d0e8..ee58527c 100644 --- a/cozo-core/src/runtime/relation.rs +++ b/cozo-core/src/runtime/relation.rs @@ -23,11 +23,12 @@ use crate::data::relation::{ColType, ColumnDef, NullableColType, StoredRelationM use crate::data::symb::Symbol; use crate::data::tuple::{decode_tuple_from_key, Tuple, TupleT, ENCODED_KEY_MIN_LEN}; use crate::data::value::{DataValue, ValidityTs}; -use crate::parse::sys::{HnswIndexConfig, HnswIndexManifest}; +use crate::parse::sys::HnswIndexConfig; use crate::parse::SourceSpan; use crate::query::compile::IndexPositionUse; use crate::runtime::transact::SessionTx; use crate::{NamedRows, StoreTx}; +use crate::runtime::hnsw::HnswIndexManifest; #[derive( Copy, @@ -65,7 +66,7 @@ impl RelationId { } } -#[derive(Clone, Eq, PartialEq, serde_derive::Serialize, serde_derive::Deserialize)] +#[derive(Clone, PartialEq, serde_derive::Serialize, serde_derive::Deserialize)] pub(crate) struct RelationHandle { pub(crate) name: SmartString, pub(crate) id: RelationId, @@ -204,7 +205,11 @@ impl RelationHandle { } chosen } - pub(crate) fn encode_key_for_store(&self, tuple: &Tuple, span: SourceSpan) -> Result> { + pub(crate) fn encode_key_for_store( + &self, + tuple: &[DataValue], + span: SourceSpan, + ) -> Result> { let len = self.metadata.keys.len(); ensure!( tuple.len() >= len, @@ -221,7 +226,11 @@ impl RelationHandle { } Ok(ret) } - pub(crate) fn encode_val_for_store(&self, tuple: &Tuple, _span: SourceSpan) -> Result> { + pub(crate) fn encode_val_for_store( + &self, + tuple: &[DataValue], + _span: SourceSpan, + ) -> Result> { let start = self.metadata.keys.len(); let len = self.metadata.non_keys.len(); let mut ret = self.encode_key_prefix(len); @@ -699,34 +708,6 @@ impl<'a> SessionTx<'a> { } } - // We only allow string tags - let mut tag_field_indices = vec![]; - for field in config.tag_fields.iter() { - for (i, col) in rel_handle - .metadata - .keys - .iter() - .chain(rel_handle.metadata.non_keys.iter()).enumerate() - { - if col.name == *field { - let mut col_type = col.typing.coltype.clone(); - if let ColType::List { eltype, .. } = &col_type { - col_type = eltype.coltype.clone(); - } - - if col_type != ColType::String { - bail!( - "Cannot create HNSW index with field {} of type {:?} (expected Str)", - field, - col_type - ); - } - tag_field_indices.push(i); - break; - } - } - } - // Build key columns definitions let mut idx_keys: Vec = vec![ColumnDef { // layer -1 stores the self-loops @@ -775,16 +756,10 @@ impl<'a> SessionTx<'a> { }, // For self-loops, stores a hash of the neighbours, for conflict detection ColumnDef { - name: SmartString::from("tags"), + name: SmartString::from("hash"), typing: NullableColType { - coltype: ColType::List { - eltype: Box::new(NullableColType { - coltype: ColType::String, - nullable: false, - }), - len: None, - }, - nullable: false, + coltype: ColType::Bytes, + nullable: true, }, default_gen: None, }, @@ -823,11 +798,15 @@ impl<'a> SessionTx<'a> { vec_dim: config.vec_dim, dtype: config.dtype, vec_fields: vec_field_indices, - tag_fields: tag_field_indices, distance: config.distance, ef_construction: config.ef_construction, - max_elements: config.max_elements, + m_neighbours: config.m_neighbours, + m_max: config.m_neighbours, + m_max0: config.m_neighbours * 2, + level_multiplier: 1. / (config.m_neighbours as f64).ln(), index_filter: config.index_filter, + extend_candidates: config.extend_candidates, + keep_pruned_connections: config.keep_pruned_connections, }; rel_handle .hnsw_indices diff --git a/cozo-core/src/runtime/tests.rs b/cozo-core/src/runtime/tests.rs index 6948b543..ced36861 100644 --- a/cozo-core/src/runtime/tests.rs +++ b/cozo-core/src/runtime/tests.rs @@ -787,12 +787,11 @@ fn test_vec_index() { db.run_script(r" ::hnsw create a:vec { dim: 8, + m: 50, dtype: F32, fields: [v], - tags: tags, distance: Cosine, ef_construction: 20, - max_elements: 50, filter: k != 'k1' }", Default::default()) .unwrap();