|
|
@ -11,7 +11,7 @@ use std::collections::{BTreeMap, BTreeSet};
|
|
|
|
use std::fmt::{Debug, Display, Formatter};
|
|
|
|
use std::fmt::{Debug, Display, Formatter};
|
|
|
|
use std::sync::Arc;
|
|
|
|
use std::sync::Arc;
|
|
|
|
|
|
|
|
|
|
|
|
use miette::{bail, ensure, Diagnostic, Result};
|
|
|
|
use miette::{bail, ensure, miette, Diagnostic, Result};
|
|
|
|
use smallvec::SmallVec;
|
|
|
|
use smallvec::SmallVec;
|
|
|
|
use smartstring::{LazyCompact, SmartString};
|
|
|
|
use smartstring::{LazyCompact, SmartString};
|
|
|
|
use thiserror::Error;
|
|
|
|
use thiserror::Error;
|
|
|
@ -915,25 +915,17 @@ pub(crate) enum InputAtom {
|
|
|
|
Unification {
|
|
|
|
Unification {
|
|
|
|
inner: Unification,
|
|
|
|
inner: Unification,
|
|
|
|
},
|
|
|
|
},
|
|
|
|
HnswSearch {
|
|
|
|
Search {
|
|
|
|
inner: HnswSearchInput,
|
|
|
|
inner: SearchInput,
|
|
|
|
},
|
|
|
|
},
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
#[derive(Clone)]
|
|
|
|
#[derive(Clone)]
|
|
|
|
pub(crate) struct HnswSearchInput {
|
|
|
|
pub(crate) struct SearchInput {
|
|
|
|
pub(crate) relation: Symbol,
|
|
|
|
pub(crate) relation: Symbol,
|
|
|
|
pub(crate) index: Symbol,
|
|
|
|
pub(crate) index: Symbol,
|
|
|
|
pub(crate) bindings: BTreeMap<SmartString<LazyCompact>, Expr>,
|
|
|
|
pub(crate) bindings: BTreeMap<SmartString<LazyCompact>, Expr>,
|
|
|
|
pub(crate) k: usize,
|
|
|
|
pub(crate) parameters: BTreeMap<SmartString<LazyCompact>, Expr>,
|
|
|
|
pub(crate) ef: usize,
|
|
|
|
|
|
|
|
pub(crate) query: Expr,
|
|
|
|
|
|
|
|
pub(crate) bind_field: Option<Expr>,
|
|
|
|
|
|
|
|
pub(crate) bind_field_idx: Option<Expr>,
|
|
|
|
|
|
|
|
pub(crate) bind_distance: Option<Expr>,
|
|
|
|
|
|
|
|
pub(crate) bind_vector: Option<Expr>,
|
|
|
|
|
|
|
|
pub(crate) radius: Option<f64>,
|
|
|
|
|
|
|
|
pub(crate) filter: Option<Expr>,
|
|
|
|
|
|
|
|
pub(crate) span: SourceSpan,
|
|
|
|
pub(crate) span: SourceSpan,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -966,41 +958,14 @@ impl HnswSearch {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
impl HnswSearchInput {
|
|
|
|
impl SearchInput {
|
|
|
|
pub(crate) fn normalize(
|
|
|
|
fn normalize_hnsw(
|
|
|
|
mut self,
|
|
|
|
mut self,
|
|
|
|
|
|
|
|
base_handle: RelationHandle,
|
|
|
|
|
|
|
|
idx_handle: RelationHandle,
|
|
|
|
|
|
|
|
manifest: HnswIndexManifest,
|
|
|
|
gen: &mut TempSymbGen,
|
|
|
|
gen: &mut TempSymbGen,
|
|
|
|
tx: &SessionTx<'_>,
|
|
|
|
|
|
|
|
) -> Result<Disjunction> {
|
|
|
|
) -> Result<Disjunction> {
|
|
|
|
let base_handle = tx.get_relation(&self.relation, false)?;
|
|
|
|
|
|
|
|
if base_handle.access_level < AccessLevel::ReadOnly {
|
|
|
|
|
|
|
|
bail!(InsufficientAccessLevel(
|
|
|
|
|
|
|
|
base_handle.name.to_string(),
|
|
|
|
|
|
|
|
"reading rows".to_string(),
|
|
|
|
|
|
|
|
base_handle.access_level
|
|
|
|
|
|
|
|
));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
let (idx_handle, manifest) = base_handle
|
|
|
|
|
|
|
|
.hnsw_indices
|
|
|
|
|
|
|
|
.get(&self.index.name)
|
|
|
|
|
|
|
|
.ok_or_else(|| {
|
|
|
|
|
|
|
|
#[derive(Debug, Error, Diagnostic)]
|
|
|
|
|
|
|
|
#[error("hnsw index {name} not found on relation {relation}")]
|
|
|
|
|
|
|
|
#[diagnostic(code(eval::hnsw_index_not_found))]
|
|
|
|
|
|
|
|
struct HnswIndexNotFound {
|
|
|
|
|
|
|
|
relation: String,
|
|
|
|
|
|
|
|
name: String,
|
|
|
|
|
|
|
|
#[label]
|
|
|
|
|
|
|
|
span: SourceSpan,
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
HnswIndexNotFound {
|
|
|
|
|
|
|
|
relation: self.relation.name.to_string(),
|
|
|
|
|
|
|
|
name: self.index.name.to_string(),
|
|
|
|
|
|
|
|
span: self.index.span,
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
})?
|
|
|
|
|
|
|
|
.clone();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let mut conj = Vec::with_capacity(self.bindings.len() + 8);
|
|
|
|
let mut conj = Vec::with_capacity(self.bindings.len() + 8);
|
|
|
|
let mut bindings = Vec::with_capacity(self.bindings.len());
|
|
|
|
let mut bindings = Vec::with_capacity(self.bindings.len());
|
|
|
|
let mut seen_variables = BTreeSet::new();
|
|
|
|
let mut seen_variables = BTreeSet::new();
|
|
|
@ -1060,7 +1025,16 @@ impl HnswSearchInput {
|
|
|
|
));
|
|
|
|
));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
let query = match self.query {
|
|
|
|
#[derive(Debug, Error, Diagnostic)]
|
|
|
|
|
|
|
|
#[error("Field `{0}` is required for HNSW search")]
|
|
|
|
|
|
|
|
#[diagnostic(code(parser::hnsw_query_required))]
|
|
|
|
|
|
|
|
struct HnswRequiredMissing(String, #[label] SourceSpan);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let query = match self
|
|
|
|
|
|
|
|
.parameters
|
|
|
|
|
|
|
|
.remove("query")
|
|
|
|
|
|
|
|
.ok_or_else(|| miette!(HnswRequiredMissing("query".to_string(), self.span)))?
|
|
|
|
|
|
|
|
{
|
|
|
|
Expr::Binding { var, .. } => var,
|
|
|
|
Expr::Binding { var, .. } => var,
|
|
|
|
expr => {
|
|
|
|
expr => {
|
|
|
|
let span = expr.span();
|
|
|
|
let span = expr.span();
|
|
|
@ -1076,7 +1050,54 @@ impl HnswSearchInput {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
let bind_field = match self.bind_field {
|
|
|
|
let k_expr = self
|
|
|
|
|
|
|
|
.parameters
|
|
|
|
|
|
|
|
.remove("k")
|
|
|
|
|
|
|
|
.ok_or_else(|| miette!(HnswRequiredMissing("k".to_string(), self.span)))?;
|
|
|
|
|
|
|
|
let k = k_expr.eval_to_const()?;
|
|
|
|
|
|
|
|
let k = k.get_int().ok_or(ExpectedPosIntForHnswK(self.span))?;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Error, Diagnostic)]
|
|
|
|
|
|
|
|
#[error("Expected positive integer for `k`")]
|
|
|
|
|
|
|
|
#[diagnostic(code(parser::expected_int_for_hnsw_k))]
|
|
|
|
|
|
|
|
struct ExpectedPosIntForHnswK(#[label] SourceSpan);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ensure!(k > 0, ExpectedPosIntForHnswK(self.span));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let ef_expr = self
|
|
|
|
|
|
|
|
.parameters
|
|
|
|
|
|
|
|
.remove("ef")
|
|
|
|
|
|
|
|
.ok_or_else(|| miette!(HnswRequiredMissing("ef".to_string(), self.span)))?;
|
|
|
|
|
|
|
|
let ef = ef_expr.eval_to_const()?;
|
|
|
|
|
|
|
|
let ef = ef.get_int().ok_or(ExpectedPosIntForHnswEf(self.span))?;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Error, Diagnostic)]
|
|
|
|
|
|
|
|
#[error("Expected positive integer for `ef`")]
|
|
|
|
|
|
|
|
#[diagnostic(code(parser::expected_int_for_hnsw_ef))]
|
|
|
|
|
|
|
|
struct ExpectedPosIntForHnswEf(#[label] SourceSpan);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ensure!(ef > 0, ExpectedPosIntForHnswEf(self.span));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let radius_expr = self.parameters.remove("radius");
|
|
|
|
|
|
|
|
let radius = match radius_expr {
|
|
|
|
|
|
|
|
Some(expr) => {
|
|
|
|
|
|
|
|
let r = expr.eval_to_const()?;
|
|
|
|
|
|
|
|
let r = r.get_float().ok_or(ExpectedFloatForHnswRadius(self.span))?;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, Error, Diagnostic)]
|
|
|
|
|
|
|
|
#[error("Expected positive float for `radius`")]
|
|
|
|
|
|
|
|
#[diagnostic(code(parser::expected_float_for_hnsw_radius))]
|
|
|
|
|
|
|
|
struct ExpectedFloatForHnswRadius(#[label] SourceSpan);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ensure!(r > 0.0, ExpectedFloatForHnswRadius(self.span));
|
|
|
|
|
|
|
|
Some(r)
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
None => None,
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let filter = self.parameters.remove("filter");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let bind_field = match self.parameters.remove("bind_field") {
|
|
|
|
None => None,
|
|
|
|
None => None,
|
|
|
|
Some(Expr::Binding { var, .. }) => Some(var),
|
|
|
|
Some(Expr::Binding { var, .. }) => Some(var),
|
|
|
|
Some(expr) => {
|
|
|
|
Some(expr) => {
|
|
|
@ -1093,7 +1114,7 @@ impl HnswSearchInput {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
let bind_field_idx = match self.bind_field_idx {
|
|
|
|
let bind_field_idx = match self.parameters.remove("bind_field_idx") {
|
|
|
|
None => None,
|
|
|
|
None => None,
|
|
|
|
Some(Expr::Binding { var, .. }) => Some(var),
|
|
|
|
Some(Expr::Binding { var, .. }) => Some(var),
|
|
|
|
Some(expr) => {
|
|
|
|
Some(expr) => {
|
|
|
@ -1110,7 +1131,7 @@ impl HnswSearchInput {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
let bind_distance = match self.bind_distance {
|
|
|
|
let bind_distance = match self.parameters.remove("bind_distance") {
|
|
|
|
None => None,
|
|
|
|
None => None,
|
|
|
|
Some(Expr::Binding { var, .. }) => Some(var),
|
|
|
|
Some(Expr::Binding { var, .. }) => Some(var),
|
|
|
|
Some(expr) => {
|
|
|
|
Some(expr) => {
|
|
|
@ -1127,7 +1148,7 @@ impl HnswSearchInput {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
let bind_vector = match self.bind_vector {
|
|
|
|
let bind_vector = match self.parameters.remove("bind_vector") {
|
|
|
|
None => None,
|
|
|
|
None => None,
|
|
|
|
Some(Expr::Binding { var, .. }) => Some(var),
|
|
|
|
Some(Expr::Binding { var, .. }) => Some(var),
|
|
|
|
Some(expr) => {
|
|
|
|
Some(expr) => {
|
|
|
@ -1149,37 +1170,53 @@ impl HnswSearchInput {
|
|
|
|
idx_handle,
|
|
|
|
idx_handle,
|
|
|
|
manifest,
|
|
|
|
manifest,
|
|
|
|
bindings,
|
|
|
|
bindings,
|
|
|
|
k: self.k,
|
|
|
|
k: k as usize,
|
|
|
|
ef: self.ef,
|
|
|
|
ef: ef as usize,
|
|
|
|
query,
|
|
|
|
query,
|
|
|
|
bind_field,
|
|
|
|
bind_field,
|
|
|
|
bind_field_idx,
|
|
|
|
bind_field_idx,
|
|
|
|
bind_distance,
|
|
|
|
bind_distance,
|
|
|
|
bind_vector,
|
|
|
|
bind_vector,
|
|
|
|
radius: self.radius,
|
|
|
|
radius,
|
|
|
|
filter: self.filter,
|
|
|
|
filter,
|
|
|
|
span: self.span,
|
|
|
|
span: self.span,
|
|
|
|
}));
|
|
|
|
}));
|
|
|
|
|
|
|
|
|
|
|
|
// ret.push(if is_negated {
|
|
|
|
|
|
|
|
// NormalFormAtom::NegatedRelation(NormalFormRelationApplyAtom {
|
|
|
|
|
|
|
|
// name: self.name,
|
|
|
|
|
|
|
|
// args,
|
|
|
|
|
|
|
|
// valid_at: self.valid_at,
|
|
|
|
|
|
|
|
// span: self.span,
|
|
|
|
|
|
|
|
// })
|
|
|
|
|
|
|
|
// } else {
|
|
|
|
|
|
|
|
// NormalFormAtom::Relation(NormalFormRelationApplyAtom {
|
|
|
|
|
|
|
|
// name: self.name,
|
|
|
|
|
|
|
|
// args,
|
|
|
|
|
|
|
|
// valid_at: self.valid_at,
|
|
|
|
|
|
|
|
// span: self.span,
|
|
|
|
|
|
|
|
// })
|
|
|
|
|
|
|
|
// });
|
|
|
|
|
|
|
|
// Disjunction::conj(ret)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Ok(Disjunction::conj(conj))
|
|
|
|
Ok(Disjunction::conj(conj))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
pub(crate) fn normalize(
|
|
|
|
|
|
|
|
self,
|
|
|
|
|
|
|
|
gen: &mut TempSymbGen,
|
|
|
|
|
|
|
|
tx: &SessionTx<'_>,
|
|
|
|
|
|
|
|
) -> Result<Disjunction> {
|
|
|
|
|
|
|
|
let base_handle = tx.get_relation(&self.relation, false)?;
|
|
|
|
|
|
|
|
if base_handle.access_level < AccessLevel::ReadOnly {
|
|
|
|
|
|
|
|
bail!(InsufficientAccessLevel(
|
|
|
|
|
|
|
|
base_handle.name.to_string(),
|
|
|
|
|
|
|
|
"reading rows".to_string(),
|
|
|
|
|
|
|
|
base_handle.access_level
|
|
|
|
|
|
|
|
));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if let Some((idx_handle, manifest)) =
|
|
|
|
|
|
|
|
base_handle.hnsw_indices.get(&self.index.name).cloned()
|
|
|
|
|
|
|
|
{
|
|
|
|
|
|
|
|
return self.normalize_hnsw(base_handle, idx_handle, manifest, gen);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
#[derive(Debug, Error, Diagnostic)]
|
|
|
|
|
|
|
|
#[error("Index {name} not found on relation {relation}")]
|
|
|
|
|
|
|
|
#[diagnostic(code(eval::hnsw_index_not_found))]
|
|
|
|
|
|
|
|
struct IndexNotFound {
|
|
|
|
|
|
|
|
relation: String,
|
|
|
|
|
|
|
|
name: String,
|
|
|
|
|
|
|
|
#[label]
|
|
|
|
|
|
|
|
span: SourceSpan,
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
bail!(IndexNotFound {
|
|
|
|
|
|
|
|
relation: self.relation.to_string(),
|
|
|
|
|
|
|
|
name: self.index.to_string(),
|
|
|
|
|
|
|
|
span: self.span,
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
impl Debug for InputAtom {
|
|
|
|
impl Debug for InputAtom {
|
|
|
@ -1213,32 +1250,14 @@ impl Display for InputAtom {
|
|
|
|
write!(f, ":{name}")?;
|
|
|
|
write!(f, ":{name}")?;
|
|
|
|
f.debug_list().entries(args).finish()?;
|
|
|
|
f.debug_list().entries(args).finish()?;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
InputAtom::HnswSearch { inner } => {
|
|
|
|
InputAtom::Search { inner } => {
|
|
|
|
write!(f, "~{}:{}{{", inner.relation, inner.index)?;
|
|
|
|
write!(f, "~{}:{}{{", inner.relation, inner.index)?;
|
|
|
|
for (binding, expr) in &inner.bindings {
|
|
|
|
for (binding, expr) in &inner.bindings {
|
|
|
|
write!(f, "{binding}: {expr}, ")?;
|
|
|
|
write!(f, "{binding}: {expr}, ")?;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
write!(f, "| ")?;
|
|
|
|
write!(f, "| ")?;
|
|
|
|
write!(f, " query: {}, ", inner.query)?;
|
|
|
|
for (k, v) in inner.parameters.iter() {
|
|
|
|
write!(f, " k: {}, ", inner.k)?;
|
|
|
|
write!(f, "{k}: {v}, ")?;
|
|
|
|
write!(f, " ef: {}, ", inner.ef)?;
|
|
|
|
|
|
|
|
if let Some(radius) = &inner.radius {
|
|
|
|
|
|
|
|
write!(f, " radius: {}, ", radius)?;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if let Some(filter) = &inner.filter {
|
|
|
|
|
|
|
|
write!(f, " filter: {}, ", filter)?;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if let Some(bind_distance) = &inner.bind_distance {
|
|
|
|
|
|
|
|
write!(f, " bind_distance: {}, ", bind_distance)?;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if let Some(bind_field) = &inner.bind_field {
|
|
|
|
|
|
|
|
write!(f, " bind_field: {}, ", bind_field)?;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if let Some(bind_field_idx) = &inner.bind_field_idx {
|
|
|
|
|
|
|
|
write!(f, " bind_field_idx: {}, ", bind_field_idx)?;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if let Some(bind_vector) = &inner.bind_vector {
|
|
|
|
|
|
|
|
write!(f, " bind_vector: {}, ", bind_vector)?;
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
write!(f, "}}")?;
|
|
|
|
write!(f, "}}")?;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -1307,7 +1326,7 @@ impl InputAtom {
|
|
|
|
InputAtom::Relation { inner, .. } => inner.span,
|
|
|
|
InputAtom::Relation { inner, .. } => inner.span,
|
|
|
|
InputAtom::Predicate { inner, .. } => inner.span(),
|
|
|
|
InputAtom::Predicate { inner, .. } => inner.span(),
|
|
|
|
InputAtom::Unification { inner, .. } => inner.span,
|
|
|
|
InputAtom::Unification { inner, .. } => inner.span,
|
|
|
|
InputAtom::HnswSearch { inner, .. } => inner.span,
|
|
|
|
InputAtom::Search { inner, .. } => inner.span,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|