parsing of HNSW search query

main
Ziyang Hu 1 year ago
parent e352c762f1
commit bf479c3a64

@ -52,7 +52,7 @@ param = @{"$" ~ (XID_CONTINUE | "_")*}
ident = @{XID_START ~ ("_" | XID_CONTINUE)*}
underscore_ident = @{("_" | XID_START) ~ ("_" | XID_CONTINUE)*}
relation_ident = @{"*" ~ (compound_or_index_ident | underscore_ident)}
vector_index_ident = @{"~" ~ compound_or_index_ident}
vector_index_ident = _{"~" ~ compound_or_index_ident}
compound_ident = @{ident ~ ("." ~ ident)*}
compound_or_index_ident = @{ident ~ ("." ~ ident)* ~ (":" ~ ident)?}
@ -81,7 +81,7 @@ relation_apply = {relation_ident ~ "[" ~ apply_args ~ validity_clause? ~ "]"}
vector_search = {vector_index_ident ~ "{" ~ named_apply_args ~ "|" ~ (index_opt_field ~ ",")* ~ index_opt_field? ~ "}"}
disjunction = {(atom ~ "or" )* ~ atom}
atom = _{ negation | relation_named_apply | relation_apply | rule_apply | unify_multi | unify | expr | grouped}
atom = _{ negation | relation_named_apply | relation_apply | vector_search | rule_apply | unify_multi | unify | expr | grouped}
unify = {var ~ "=" ~ expr}
unify_multi = {var ~ "in" ~ expr}
negation = {"not" ~ atom}

@ -907,6 +907,26 @@ pub(crate) enum InputAtom {
Unification {
inner: Unification,
},
HnswSearch {
inner: HnswSearch,
},
}
#[derive(Clone)]
pub(crate) struct HnswSearch {
pub(crate) relation: Symbol,
pub(crate) index: Symbol,
pub(crate) bindings: BTreeMap<SmartString<LazyCompact>, Expr>,
pub(crate) k: usize,
pub(crate) ef: usize,
pub(crate) query: Expr,
pub(crate) bind_field: Option<Expr>,
pub(crate) bind_field_idx: Option<Expr>,
pub(crate) bind_distance: Option<Expr>,
pub(crate) bind_vector: Option<Expr>,
pub(crate) radius: Option<f64>,
pub(crate) filter: Option<Expr>,
pub(crate) span: SourceSpan,
}
impl Debug for InputAtom {
@ -979,6 +999,9 @@ impl Display for InputAtom {
}
write!(f, "{expr}")?;
}
InputAtom::HnswSearch { .. } => {
todo!()
}
}
Ok(())
}
@ -1005,6 +1028,9 @@ impl InputAtom {
InputAtom::Relation { inner, .. } => inner.span,
InputAtom::Predicate { inner, .. } => inner.span(),
InputAtom::Unification { inner, .. } => inner.span,
InputAtom::HnswSearch { .. } => {
todo!()
}
}
}
}

@ -23,7 +23,7 @@ use crate::data::aggr::{parse_aggr, Aggregation};
use crate::data::expr::Expr;
use crate::data::functions::{str2vld, MAX_VALIDITY_TS};
use crate::data::program::{
FixedRuleApply, FixedRuleArg, InputAtom, InputInlineRule, InputInlineRulesOrFixed,
FixedRuleApply, FixedRuleArg, HnswSearch, InputAtom, InputInlineRule, InputInlineRulesOrFixed,
InputNamedFieldRelationApplyAtom, InputProgram, InputRelationApplyAtom, InputRuleApplyAtom,
QueryAssertion, QueryOutOptions, RelationOp, SortDir, Unification,
};
@ -618,6 +618,147 @@ fn parse_atom(
},
}
}
Rule::vector_search => {
let span = src.extract_span();
let mut src = src.into_inner();
let name_p = src.next().unwrap();
let name_segs = name_p.as_str().split(':').collect_vec();
#[derive(Debug, Error, Diagnostic)]
#[error("Search head must be of the form `relation_name:index_name`")]
#[diagnostic(code(parser::invalid_search_head))]
struct InvalidSearchHead(#[label] SourceSpan);
ensure!(
name_segs.len() == 2,
InvalidSearchHead(name_p.extract_span())
);
let relation = Symbol::new(name_segs[0], name_p.extract_span());
let index = Symbol::new(name_segs[1], name_p.extract_span());
let bindings: BTreeMap<SmartString<LazyCompact>, Expr> = src
.next()
.unwrap()
.into_inner()
.map(|arg| extract_named_apply_arg(arg, param_pool))
.try_collect()?;
let mut opts = HnswSearch {
relation,
index,
bindings,
k: 0,
ef: 0,
query: Expr::Const {
val: DataValue::Bot,
span: SourceSpan::default(),
},
bind_field: None,
bind_field_idx: None,
bind_distance: None,
bind_vector: None,
radius: None,
filter: None,
span,
};
for field in src {
let mut fs = field.into_inner();
let fname = fs.next().unwrap();
let fval = fs.next().unwrap();
match fname.as_str() {
"query" => {
opts.query = build_expr(fval, param_pool)?;
}
"filter" => {
opts.filter = Some(build_expr(fval, param_pool)?);
}
"bind_field" => {
opts.bind_field = Some(build_expr(fval, param_pool)?);
}
"bind_field_idx" => {
opts.bind_field_idx = Some(build_expr(fval, param_pool)?);
}
"bind_distance" => {
opts.bind_distance = Some(build_expr(fval, param_pool)?);
}
"bind_vector" => {
opts.bind_vector = Some(build_expr(fval, param_pool)?);
}
"k" => {
#[derive(Debug, Error, Diagnostic)]
#[error("Expected positive integer for `k`")]
#[diagnostic(code(parser::expected_int_for_hnsw_k))]
struct ExpectedPosIntForHnswK(#[label] SourceSpan);
let k = build_expr(fval, param_pool)?;
let k = k.eval_to_const()?;
let k = k
.get_int()
.ok_or_else(|| ExpectedPosIntForHnswK(fname.extract_span()))?;
ensure!(k > 0, ExpectedPosIntForHnswK(fname.extract_span()));
opts.k = k as usize;
}
"ef" => {
#[derive(Debug, Error, Diagnostic)]
#[error("Expected positive integer for `ef`")]
#[diagnostic(code(parser::expected_int_for_hnsw_ef))]
struct ExpectedPosIntForHnswEf(#[label] SourceSpan);
let ef = build_expr(fval, param_pool)?;
let ef = ef.eval_to_const()?;
let ef = ef
.get_int()
.ok_or_else(|| ExpectedPosIntForHnswEf(fname.extract_span()))?;
ensure!(ef > 0, ExpectedPosIntForHnswEf(fname.extract_span()));
opts.ef = ef as usize;
}
"radius" => {
#[derive(Debug, Error, Diagnostic)]
#[error("Expected positive float for `radius`")]
#[diagnostic(code(parser::expected_float_for_hnsw_radius))]
struct ExpectedPosFloatForHnswRadius(#[label] SourceSpan);
let radius = build_expr(fval, param_pool)?;
let radius = radius.eval_to_const()?;
let radius = radius
.get_float()
.ok_or_else(|| ExpectedPosFloatForHnswRadius(fname.extract_span()))?;
ensure!(
radius > 0.0,
ExpectedPosFloatForHnswRadius(fname.extract_span())
);
opts.radius = Some(radius);
}
_ => {
#[derive(Debug, Error, Diagnostic)]
#[error("Unexpected HNSW option `{0}`")]
#[diagnostic(code(parser::unexpected_hnsw_option))]
struct UnexpectedHnswOption(String, #[label] SourceSpan);
bail!(UnexpectedHnswOption(
fname.as_str().into(),
fname.extract_span()
))
}
}
}
if matches!(
opts.query,
Expr::Const {
val: DataValue::Bot,
..
}
) {
#[derive(Debug, Error, Diagnostic)]
#[error("Field `query` is required for HNSW search")]
#[diagnostic(code(parser::hnsw_query_required))]
struct HnswQueryRequired(#[label] SourceSpan);
bail!(HnswQueryRequired(span))
}
InputAtom::HnswSearch { inner: opts }
}
Rule::relation_named_apply => {
let span = src.extract_span();
let mut src = src.into_inner();
@ -627,19 +768,7 @@ fn parse_atom(
.next()
.unwrap()
.into_inner()
.map(|pair| -> Result<(SmartString<LazyCompact>, Expr)> {
let mut inner = pair.into_inner();
let name_p = inner.next().unwrap();
let name = SmartString::from(name_p.as_str());
let arg = match inner.next() {
Some(a) => build_expr(a, param_pool)?,
None => Expr::Binding {
var: Symbol::new(name.clone(), name_p.extract_span()),
tuple_pos: None,
},
};
Ok((name, arg))
})
.map(|arg| extract_named_apply_arg(arg, param_pool))
.try_collect()?;
let valid_at = match src.next() {
None => None,
@ -661,6 +790,23 @@ fn parse_atom(
})
}
fn extract_named_apply_arg(
pair: Pair<'_>,
param_pool: &BTreeMap<String, DataValue>,
) -> Result<(SmartString<LazyCompact>, Expr)> {
let mut inner = pair.into_inner();
let name_p = inner.next().unwrap();
let name = SmartString::from(name_p.as_str());
let arg = match inner.next() {
Some(a) => build_expr(a, param_pool)?,
None => Expr::Binding {
var: Symbol::new(name.clone(), name_p.extract_span()),
tuple_pos: None,
},
};
Ok((name, arg))
}
fn parse_rule_head(
src: Pair<'_>,
param_pool: &BTreeMap<String, DataValue>,

@ -121,7 +121,9 @@ impl InputAtom {
InputAtom::Unification { inner } => {
bail!(UnsafeNegation(inner.span))
}
InputAtom::HnswSearch { .. } => todo!(),
},
InputAtom::HnswSearch { .. } => todo!(),
})
}
@ -225,6 +227,9 @@ impl InputAtom {
InputAtom::Unification { inner: u } => {
Disjunction::singlet(NormalFormAtom::Unification(u))
}
InputAtom::HnswSearch { .. } => {
todo!()
}
})
}
}

@ -41,6 +41,7 @@ pub(crate) struct HnswIndexManifest {
pub(crate) keep_pruned_connections: bool,
}
#[derive(Clone)]
pub(crate) struct HnswKnnQueryOptions {
k: usize,
ef: usize,

@ -766,10 +766,7 @@ fn test_vec_types() {
let res = db
.run_script("?[v] <- [[rand_vec(5)]]", Default::default())
.unwrap();
assert_eq!(
5,
res.into_json()["rows"][0][0].as_array().unwrap().len()
);
assert_eq!(5, res.into_json()["rows"][0][0].as_array().unwrap().len());
let res = db
.run_script(r#"
val[v] <- [[vec([1,2,3,4,5,6,7,8])]]
@ -782,7 +779,8 @@ fn test_vec_types() {
#[test]
fn test_vec_index() {
let db = DbInstance::new("mem", "", "").unwrap();
db.run_script(r"
db.run_script(
r"
?[k, v] <- [['a', [1,2,3,4,5,6,7,8]],
['b', [2,3,4,5,6,7,8,9]],
['bb', [2,3,4,5,6,7,8,9]],
@ -791,9 +789,12 @@ fn test_vec_index() {
['b', [1,1,1,1,1,1,1,1]]]
:create a {k: String => v: <F32; 8>}
", Default::default())
.unwrap();
db.run_script(r"
",
Default::default(),
)
.unwrap();
db.run_script(
r"
::hnsw create a:vec {
dim: 8,
m: 50,
@ -802,9 +803,12 @@ fn test_vec_index() {
distance: Cosine,
ef_construction: 20,
filter: k != 'k1'
}", Default::default())
.unwrap();
db.run_script(r"
}",
Default::default(),
)
.unwrap();
db.run_script(
r"
?[k, v] <- [
['a2', [1,2,3,4,5,6,7,8]],
['b2', [2,3,4,5,6,7,8,9]],
@ -814,7 +818,16 @@ fn test_vec_index() {
['b2', [1,1,1,1,1,1,1,1]]
]
:put a {k => v}
", Default::default())
.unwrap();
",
Default::default(),
)
.unwrap();
db.run_script(
r"
?[key, v] := ~a:vec{k: 'a', v | query: [1,1,1,1,1,1,1,1]}
",
Default::default(),
)
.unwrap();
println!("{:#?}", db.export_relations(["a", "a:vec"].iter()));
}

Loading…
Cancel
Save