diff --git a/cozo-core/src/runtime/hnsw.rs b/cozo-core/src/runtime/hnsw.rs index 1b713c57..e52ee3a7 100644 --- a/cozo-core/src/runtime/hnsw.rs +++ b/cozo-core/src/runtime/hnsw.rs @@ -984,12 +984,12 @@ impl<'a> SessionTx<'a> { } if config.bind_vector.is_some() { let vec = if cand_key.2 < 0 { + cand_tuple[cand_key.1].clone() + } else { match &cand_tuple[cand_key.1] { DataValue::List(v) => v[cand_key.2 as usize].clone(), - _ => bail!("corrupted index"), + v => bail!("corrupted index value {:?}", v), } - } else { - cand_tuple[cand_key.1].clone() }; cand_tuple.push(vec); } diff --git a/cozo-core/src/runtime/tests.rs b/cozo-core/src/runtime/tests.rs index 1a506edd..fe8cd0f3 100644 --- a/cozo-core/src/runtime/tests.rs +++ b/cozo-core/src/runtime/tests.rs @@ -708,7 +708,7 @@ fn test_vec_index_insertion() { :create a {k: String => v: , m: Bool} ", ) - .unwrap(); + .unwrap(); db.run_default( r" ::hnsw create a:vec { @@ -723,12 +723,17 @@ fn test_vec_index_insertion() { #keep_pruned_connections: true, }", ) + .unwrap(); + let res = db + .run_default("?[k] := *a:vec{layer: 0, fr_k, to_k}, k = fr_k or k = to_k") .unwrap(); - let res = db.run_default("?[k] := *a:vec{layer: 0, fr_k, to_k}, k = fr_k or k = to_k").unwrap(); assert_eq!(res.rows.len(), 1); println!("update!"); - db.run_default(r#"?[k, m] <- [["a", false]] :update a {}"#).unwrap(); - let res = db.run_default("?[k] := *a:vec{layer: 0, fr_k, to_k}, k = fr_k or k = to_k").unwrap(); + db.run_default(r#"?[k, m] <- [["a", false]] :update a {}"#) + .unwrap(); + let res = db + .run_default("?[k] := *a:vec{layer: 0, fr_k, to_k}, k = fr_k or k = to_k") + .unwrap(); assert_eq!(res.rows.len(), 0); println!("{}", res.into_json()); } @@ -908,7 +913,9 @@ fn test_lsh_indexing3() { #[test] fn filtering() { let db = DbInstance::default(); - let res = db.run_default(r" + let res = db + .run_default( + r" { ?[x, y] <- [[1, 2]] :create _rel {x => y} @@ -917,7 +924,8 @@ fn filtering() { { ?[x, y] := x = 1, *_rel{x, y: 3}, y = 2 } - ") + ", + ) .unwrap(); assert_eq!(0, res.rows.len()); } @@ -937,8 +945,7 @@ fn test_lsh_indexing4() { .unwrap(); db.run_default("?[k, v] <- [['a', 'ewiygfspeoighjsfcfxzdfncalsdf']] :put a {k => v}") .unwrap(); - db.run_default("?[k] <- [['a']] :rm a {k}") - .unwrap(); + db.run_default("?[k] <- [['a']] :rm a {k}").unwrap(); let res = db .run_default("?[k] := ~a:lsh{k | query: 'ewiygfspeoighjsfcfxzdfncalsdf', k: 1}") .unwrap(); @@ -946,7 +953,6 @@ fn test_lsh_indexing4() { } } - #[test] fn test_lsh_indexing() { let db = DbInstance::new("mem", "", "").unwrap(); @@ -1457,3 +1463,59 @@ fn crashy_imperative() { ) .unwrap(); } + +#[test] +fn hnsw_index() { + let db = DbInstance::default(); + db.run_default( + r#" + :create beliefs { + belief_id: Uuid, + character_id: Uuid, + belief: String, + last_accessed_at: Validity default [floor(now()), true], + => + details: String default "", + parent_belief_id: Uuid? default null, + valence: Float default 0, + aspects: [(String, Float, String, String)] default [], + belief_embedding: , + details_embedding: , + } + "#, + ) + .unwrap(); + db.run_default( + r#" + ::hnsw create beliefs:embedding_space { + dim: 768, + m: 50, + dtype: F32, + fields: [belief_embedding, details_embedding], + distance: Cosine, + ef_construction: 20, + extend_candidates: false, + keep_pruned_connections: false, + } + "#, + ) + .unwrap(); + db.run_default(r#" + ?[belief_id, character_id, belief, belief_embedding, details_embedding] <- [[rand_uuid_v1(), rand_uuid_v1(), "test", rand_vec(768), rand_vec(768)]] + :put beliefs {} + "#).unwrap(); + let res = db.run_default(r#" + ?[belief, valence, dist, character_id, vector] := ~beliefs:embedding_space{ belief, valence, character_id | + query: rand_vec(768), + k: 100, + ef: 20, + radius: 1.0, + bind_distance: dist, + bind_vector: vector + } + + :order -valence + :order dist + "#).unwrap(); + println!("{}", res.into_json()["rows"][0][4]); +}