|
|
|
@ -23,7 +23,7 @@ use crate::data::aggr::{parse_aggr, Aggregation};
|
|
|
|
|
use crate::data::expr::Expr;
|
|
|
|
|
use crate::data::functions::{str2vld, MAX_VALIDITY_TS};
|
|
|
|
|
use crate::data::program::{
|
|
|
|
|
FixedRuleApply, FixedRuleArg, InputAtom, InputInlineRule, InputInlineRulesOrFixed,
|
|
|
|
|
FixedRuleApply, FixedRuleArg, HnswSearch, InputAtom, InputInlineRule, InputInlineRulesOrFixed,
|
|
|
|
|
InputNamedFieldRelationApplyAtom, InputProgram, InputRelationApplyAtom, InputRuleApplyAtom,
|
|
|
|
|
QueryAssertion, QueryOutOptions, RelationOp, SortDir, Unification,
|
|
|
|
|
};
|
|
|
|
@ -618,6 +618,147 @@ fn parse_atom(
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
Rule::vector_search => {
|
|
|
|
|
let span = src.extract_span();
|
|
|
|
|
let mut src = src.into_inner();
|
|
|
|
|
let name_p = src.next().unwrap();
|
|
|
|
|
let name_segs = name_p.as_str().split(':').collect_vec();
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Error, Diagnostic)]
|
|
|
|
|
#[error("Search head must be of the form `relation_name:index_name`")]
|
|
|
|
|
#[diagnostic(code(parser::invalid_search_head))]
|
|
|
|
|
struct InvalidSearchHead(#[label] SourceSpan);
|
|
|
|
|
|
|
|
|
|
ensure!(
|
|
|
|
|
name_segs.len() == 2,
|
|
|
|
|
InvalidSearchHead(name_p.extract_span())
|
|
|
|
|
);
|
|
|
|
|
let relation = Symbol::new(name_segs[0], name_p.extract_span());
|
|
|
|
|
let index = Symbol::new(name_segs[1], name_p.extract_span());
|
|
|
|
|
let bindings: BTreeMap<SmartString<LazyCompact>, Expr> = src
|
|
|
|
|
.next()
|
|
|
|
|
.unwrap()
|
|
|
|
|
.into_inner()
|
|
|
|
|
.map(|arg| extract_named_apply_arg(arg, param_pool))
|
|
|
|
|
.try_collect()?;
|
|
|
|
|
|
|
|
|
|
let mut opts = HnswSearch {
|
|
|
|
|
relation,
|
|
|
|
|
index,
|
|
|
|
|
bindings,
|
|
|
|
|
k: 0,
|
|
|
|
|
ef: 0,
|
|
|
|
|
query: Expr::Const {
|
|
|
|
|
val: DataValue::Bot,
|
|
|
|
|
span: SourceSpan::default(),
|
|
|
|
|
},
|
|
|
|
|
bind_field: None,
|
|
|
|
|
bind_field_idx: None,
|
|
|
|
|
bind_distance: None,
|
|
|
|
|
bind_vector: None,
|
|
|
|
|
radius: None,
|
|
|
|
|
filter: None,
|
|
|
|
|
span,
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
for field in src {
|
|
|
|
|
let mut fs = field.into_inner();
|
|
|
|
|
let fname = fs.next().unwrap();
|
|
|
|
|
let fval = fs.next().unwrap();
|
|
|
|
|
match fname.as_str() {
|
|
|
|
|
"query" => {
|
|
|
|
|
opts.query = build_expr(fval, param_pool)?;
|
|
|
|
|
}
|
|
|
|
|
"filter" => {
|
|
|
|
|
opts.filter = Some(build_expr(fval, param_pool)?);
|
|
|
|
|
}
|
|
|
|
|
"bind_field" => {
|
|
|
|
|
opts.bind_field = Some(build_expr(fval, param_pool)?);
|
|
|
|
|
}
|
|
|
|
|
"bind_field_idx" => {
|
|
|
|
|
opts.bind_field_idx = Some(build_expr(fval, param_pool)?);
|
|
|
|
|
}
|
|
|
|
|
"bind_distance" => {
|
|
|
|
|
opts.bind_distance = Some(build_expr(fval, param_pool)?);
|
|
|
|
|
}
|
|
|
|
|
"bind_vector" => {
|
|
|
|
|
opts.bind_vector = Some(build_expr(fval, param_pool)?);
|
|
|
|
|
}
|
|
|
|
|
"k" => {
|
|
|
|
|
#[derive(Debug, Error, Diagnostic)]
|
|
|
|
|
#[error("Expected positive integer for `k`")]
|
|
|
|
|
#[diagnostic(code(parser::expected_int_for_hnsw_k))]
|
|
|
|
|
struct ExpectedPosIntForHnswK(#[label] SourceSpan);
|
|
|
|
|
|
|
|
|
|
let k = build_expr(fval, param_pool)?;
|
|
|
|
|
let k = k.eval_to_const()?;
|
|
|
|
|
let k = k
|
|
|
|
|
.get_int()
|
|
|
|
|
.ok_or_else(|| ExpectedPosIntForHnswK(fname.extract_span()))?;
|
|
|
|
|
ensure!(k > 0, ExpectedPosIntForHnswK(fname.extract_span()));
|
|
|
|
|
opts.k = k as usize;
|
|
|
|
|
}
|
|
|
|
|
"ef" => {
|
|
|
|
|
#[derive(Debug, Error, Diagnostic)]
|
|
|
|
|
#[error("Expected positive integer for `ef`")]
|
|
|
|
|
#[diagnostic(code(parser::expected_int_for_hnsw_ef))]
|
|
|
|
|
struct ExpectedPosIntForHnswEf(#[label] SourceSpan);
|
|
|
|
|
|
|
|
|
|
let ef = build_expr(fval, param_pool)?;
|
|
|
|
|
let ef = ef.eval_to_const()?;
|
|
|
|
|
let ef = ef
|
|
|
|
|
.get_int()
|
|
|
|
|
.ok_or_else(|| ExpectedPosIntForHnswEf(fname.extract_span()))?;
|
|
|
|
|
ensure!(ef > 0, ExpectedPosIntForHnswEf(fname.extract_span()));
|
|
|
|
|
opts.ef = ef as usize;
|
|
|
|
|
}
|
|
|
|
|
"radius" => {
|
|
|
|
|
#[derive(Debug, Error, Diagnostic)]
|
|
|
|
|
#[error("Expected positive float for `radius`")]
|
|
|
|
|
#[diagnostic(code(parser::expected_float_for_hnsw_radius))]
|
|
|
|
|
struct ExpectedPosFloatForHnswRadius(#[label] SourceSpan);
|
|
|
|
|
|
|
|
|
|
let radius = build_expr(fval, param_pool)?;
|
|
|
|
|
let radius = radius.eval_to_const()?;
|
|
|
|
|
let radius = radius
|
|
|
|
|
.get_float()
|
|
|
|
|
.ok_or_else(|| ExpectedPosFloatForHnswRadius(fname.extract_span()))?;
|
|
|
|
|
ensure!(
|
|
|
|
|
radius > 0.0,
|
|
|
|
|
ExpectedPosFloatForHnswRadius(fname.extract_span())
|
|
|
|
|
);
|
|
|
|
|
opts.radius = Some(radius);
|
|
|
|
|
}
|
|
|
|
|
_ => {
|
|
|
|
|
#[derive(Debug, Error, Diagnostic)]
|
|
|
|
|
#[error("Unexpected HNSW option `{0}`")]
|
|
|
|
|
#[diagnostic(code(parser::unexpected_hnsw_option))]
|
|
|
|
|
struct UnexpectedHnswOption(String, #[label] SourceSpan);
|
|
|
|
|
bail!(UnexpectedHnswOption(
|
|
|
|
|
fname.as_str().into(),
|
|
|
|
|
fname.extract_span()
|
|
|
|
|
))
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if matches!(
|
|
|
|
|
opts.query,
|
|
|
|
|
Expr::Const {
|
|
|
|
|
val: DataValue::Bot,
|
|
|
|
|
..
|
|
|
|
|
}
|
|
|
|
|
) {
|
|
|
|
|
#[derive(Debug, Error, Diagnostic)]
|
|
|
|
|
#[error("Field `query` is required for HNSW search")]
|
|
|
|
|
#[diagnostic(code(parser::hnsw_query_required))]
|
|
|
|
|
struct HnswQueryRequired(#[label] SourceSpan);
|
|
|
|
|
|
|
|
|
|
bail!(HnswQueryRequired(span))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
InputAtom::HnswSearch { inner: opts }
|
|
|
|
|
}
|
|
|
|
|
Rule::relation_named_apply => {
|
|
|
|
|
let span = src.extract_span();
|
|
|
|
|
let mut src = src.into_inner();
|
|
|
|
@ -627,19 +768,7 @@ fn parse_atom(
|
|
|
|
|
.next()
|
|
|
|
|
.unwrap()
|
|
|
|
|
.into_inner()
|
|
|
|
|
.map(|pair| -> Result<(SmartString<LazyCompact>, Expr)> {
|
|
|
|
|
let mut inner = pair.into_inner();
|
|
|
|
|
let name_p = inner.next().unwrap();
|
|
|
|
|
let name = SmartString::from(name_p.as_str());
|
|
|
|
|
let arg = match inner.next() {
|
|
|
|
|
Some(a) => build_expr(a, param_pool)?,
|
|
|
|
|
None => Expr::Binding {
|
|
|
|
|
var: Symbol::new(name.clone(), name_p.extract_span()),
|
|
|
|
|
tuple_pos: None,
|
|
|
|
|
},
|
|
|
|
|
};
|
|
|
|
|
Ok((name, arg))
|
|
|
|
|
})
|
|
|
|
|
.map(|arg| extract_named_apply_arg(arg, param_pool))
|
|
|
|
|
.try_collect()?;
|
|
|
|
|
let valid_at = match src.next() {
|
|
|
|
|
None => None,
|
|
|
|
@ -661,6 +790,23 @@ fn parse_atom(
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn extract_named_apply_arg(
|
|
|
|
|
pair: Pair<'_>,
|
|
|
|
|
param_pool: &BTreeMap<String, DataValue>,
|
|
|
|
|
) -> Result<(SmartString<LazyCompact>, Expr)> {
|
|
|
|
|
let mut inner = pair.into_inner();
|
|
|
|
|
let name_p = inner.next().unwrap();
|
|
|
|
|
let name = SmartString::from(name_p.as_str());
|
|
|
|
|
let arg = match inner.next() {
|
|
|
|
|
Some(a) => build_expr(a, param_pool)?,
|
|
|
|
|
None => Expr::Binding {
|
|
|
|
|
var: Symbol::new(name.clone(), name_p.extract_span()),
|
|
|
|
|
tuple_pos: None,
|
|
|
|
|
},
|
|
|
|
|
};
|
|
|
|
|
Ok((name, arg))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn parse_rule_head(
|
|
|
|
|
src: Pair<'_>,
|
|
|
|
|
param_pool: &BTreeMap<String, DataValue>,
|
|
|
|
|