refactor search parsing

main
Ziyang Hu 1 year ago
parent fdb5e0fddc
commit 378dd33fb6

@ -54,7 +54,7 @@ param = @{"$" ~ (XID_CONTINUE | "_")*}
ident = @{XID_START ~ ("_" | XID_CONTINUE)*}
underscore_ident = @{("_" | XID_START) ~ ("_" | XID_CONTINUE)*}
relation_ident = @{"*" ~ (compound_or_index_ident | underscore_ident)}
vector_index_ident = _{"~" ~ compound_or_index_ident}
search_index_ident = _{"~" ~ compound_or_index_ident}
compound_ident = @{ident ~ ("." ~ ident)*}
compound_or_index_ident = @{ident ~ ("." ~ ident)* ~ (":" ~ ident)?}
@ -80,10 +80,10 @@ rule_body = {(disjunction ~ ",")* ~ disjunction?}
rule_apply = {underscore_ident ~ "[" ~ apply_args ~ "]"}
relation_named_apply = {relation_ident ~ "{" ~ named_apply_args ~ validity_clause? ~ "}"}
relation_apply = {relation_ident ~ "[" ~ apply_args ~ validity_clause? ~ "]"}
vector_search = {vector_index_ident ~ "{" ~ named_apply_args ~ "|" ~ (index_opt_field ~ ",")* ~ index_opt_field? ~ "}"}
search_apply = {search_index_ident ~ "{" ~ named_apply_args ~ "|" ~ (index_opt_field ~ ",")* ~ index_opt_field? ~ "}"}
disjunction = {(atom ~ "or" )* ~ atom}
atom = _{ negation | relation_named_apply | relation_apply | vector_search | rule_apply | unify_multi | unify | expr | grouped}
atom = _{ negation | relation_named_apply | relation_apply | search_apply | rule_apply | unify_multi | unify | expr | grouped}
unify = {var ~ "=" ~ expr}
unify_multi = {var ~ "in" ~ expr}
negation = {"not" ~ atom}

@ -11,7 +11,7 @@ use std::collections::{BTreeMap, BTreeSet};
use std::fmt::{Debug, Display, Formatter};
use std::sync::Arc;
use miette::{bail, ensure, Diagnostic, Result};
use miette::{bail, ensure, miette, Diagnostic, Result};
use smallvec::SmallVec;
use smartstring::{LazyCompact, SmartString};
use thiserror::Error;
@ -915,25 +915,17 @@ pub(crate) enum InputAtom {
Unification {
inner: Unification,
},
HnswSearch {
inner: HnswSearchInput,
Search {
inner: SearchInput,
},
}
#[derive(Clone)]
pub(crate) struct HnswSearchInput {
pub(crate) struct SearchInput {
pub(crate) relation: Symbol,
pub(crate) index: Symbol,
pub(crate) bindings: BTreeMap<SmartString<LazyCompact>, Expr>,
pub(crate) k: usize,
pub(crate) ef: usize,
pub(crate) query: Expr,
pub(crate) bind_field: Option<Expr>,
pub(crate) bind_field_idx: Option<Expr>,
pub(crate) bind_distance: Option<Expr>,
pub(crate) bind_vector: Option<Expr>,
pub(crate) radius: Option<f64>,
pub(crate) filter: Option<Expr>,
pub(crate) parameters: BTreeMap<SmartString<LazyCompact>, Expr>,
pub(crate) span: SourceSpan,
}
@ -966,41 +958,14 @@ impl HnswSearch {
}
}
impl HnswSearchInput {
pub(crate) fn normalize(
impl SearchInput {
fn normalize_hnsw(
mut self,
base_handle: RelationHandle,
idx_handle: RelationHandle,
manifest: HnswIndexManifest,
gen: &mut TempSymbGen,
tx: &SessionTx<'_>,
) -> Result<Disjunction> {
let base_handle = tx.get_relation(&self.relation, false)?;
if base_handle.access_level < AccessLevel::ReadOnly {
bail!(InsufficientAccessLevel(
base_handle.name.to_string(),
"reading rows".to_string(),
base_handle.access_level
));
}
let (idx_handle, manifest) = base_handle
.hnsw_indices
.get(&self.index.name)
.ok_or_else(|| {
#[derive(Debug, Error, Diagnostic)]
#[error("hnsw index {name} not found on relation {relation}")]
#[diagnostic(code(eval::hnsw_index_not_found))]
struct HnswIndexNotFound {
relation: String,
name: String,
#[label]
span: SourceSpan,
}
HnswIndexNotFound {
relation: self.relation.name.to_string(),
name: self.index.name.to_string(),
span: self.index.span,
}
})?
.clone();
let mut conj = Vec::with_capacity(self.bindings.len() + 8);
let mut bindings = Vec::with_capacity(self.bindings.len());
let mut seen_variables = BTreeSet::new();
@ -1060,7 +1025,16 @@ impl HnswSearchInput {
));
}
let query = match self.query {
#[derive(Debug, Error, Diagnostic)]
#[error("Field `{0}` is required for HNSW search")]
#[diagnostic(code(parser::hnsw_query_required))]
struct HnswRequiredMissing(String, #[label] SourceSpan);
let query = match self
.parameters
.remove("query")
.ok_or_else(|| miette!(HnswRequiredMissing("query".to_string(), self.span)))?
{
Expr::Binding { var, .. } => var,
expr => {
let span = expr.span();
@ -1076,7 +1050,54 @@ impl HnswSearchInput {
}
};
let bind_field = match self.bind_field {
let k_expr = self
.parameters
.remove("k")
.ok_or_else(|| miette!(HnswRequiredMissing("k".to_string(), self.span)))?;
let k = k_expr.eval_to_const()?;
let k = k.get_int().ok_or(ExpectedPosIntForHnswK(self.span))?;
#[derive(Debug, Error, Diagnostic)]
#[error("Expected positive integer for `k`")]
#[diagnostic(code(parser::expected_int_for_hnsw_k))]
struct ExpectedPosIntForHnswK(#[label] SourceSpan);
ensure!(k > 0, ExpectedPosIntForHnswK(self.span));
let ef_expr = self
.parameters
.remove("ef")
.ok_or_else(|| miette!(HnswRequiredMissing("ef".to_string(), self.span)))?;
let ef = ef_expr.eval_to_const()?;
let ef = ef.get_int().ok_or(ExpectedPosIntForHnswEf(self.span))?;
#[derive(Debug, Error, Diagnostic)]
#[error("Expected positive integer for `ef`")]
#[diagnostic(code(parser::expected_int_for_hnsw_ef))]
struct ExpectedPosIntForHnswEf(#[label] SourceSpan);
ensure!(ef > 0, ExpectedPosIntForHnswEf(self.span));
let radius_expr = self.parameters.remove("radius");
let radius = match radius_expr {
Some(expr) => {
let r = expr.eval_to_const()?;
let r = r.get_float().ok_or(ExpectedFloatForHnswRadius(self.span))?;
#[derive(Debug, Error, Diagnostic)]
#[error("Expected positive float for `radius`")]
#[diagnostic(code(parser::expected_float_for_hnsw_radius))]
struct ExpectedFloatForHnswRadius(#[label] SourceSpan);
ensure!(r > 0.0, ExpectedFloatForHnswRadius(self.span));
Some(r)
}
None => None,
};
let filter = self.parameters.remove("filter");
let bind_field = match self.parameters.remove("bind_field") {
None => None,
Some(Expr::Binding { var, .. }) => Some(var),
Some(expr) => {
@ -1093,7 +1114,7 @@ impl HnswSearchInput {
}
};
let bind_field_idx = match self.bind_field_idx {
let bind_field_idx = match self.parameters.remove("bind_field_idx") {
None => None,
Some(Expr::Binding { var, .. }) => Some(var),
Some(expr) => {
@ -1110,7 +1131,7 @@ impl HnswSearchInput {
}
};
let bind_distance = match self.bind_distance {
let bind_distance = match self.parameters.remove("bind_distance") {
None => None,
Some(Expr::Binding { var, .. }) => Some(var),
Some(expr) => {
@ -1127,7 +1148,7 @@ impl HnswSearchInput {
}
};
let bind_vector = match self.bind_vector {
let bind_vector = match self.parameters.remove("bind_vector") {
None => None,
Some(Expr::Binding { var, .. }) => Some(var),
Some(expr) => {
@ -1149,37 +1170,53 @@ impl HnswSearchInput {
idx_handle,
manifest,
bindings,
k: self.k,
ef: self.ef,
k: k as usize,
ef: ef as usize,
query,
bind_field,
bind_field_idx,
bind_distance,
bind_vector,
radius: self.radius,
filter: self.filter,
radius,
filter,
span: self.span,
}));
// ret.push(if is_negated {
// NormalFormAtom::NegatedRelation(NormalFormRelationApplyAtom {
// name: self.name,
// args,
// valid_at: self.valid_at,
// span: self.span,
// })
// } else {
// NormalFormAtom::Relation(NormalFormRelationApplyAtom {
// name: self.name,
// args,
// valid_at: self.valid_at,
// span: self.span,
// })
// });
// Disjunction::conj(ret)
Ok(Disjunction::conj(conj))
}
pub(crate) fn normalize(
self,
gen: &mut TempSymbGen,
tx: &SessionTx<'_>,
) -> Result<Disjunction> {
let base_handle = tx.get_relation(&self.relation, false)?;
if base_handle.access_level < AccessLevel::ReadOnly {
bail!(InsufficientAccessLevel(
base_handle.name.to_string(),
"reading rows".to_string(),
base_handle.access_level
));
}
if let Some((idx_handle, manifest)) =
base_handle.hnsw_indices.get(&self.index.name).cloned()
{
return self.normalize_hnsw(base_handle, idx_handle, manifest, gen);
}
#[derive(Debug, Error, Diagnostic)]
#[error("Index {name} not found on relation {relation}")]
#[diagnostic(code(eval::hnsw_index_not_found))]
struct IndexNotFound {
relation: String,
name: String,
#[label]
span: SourceSpan,
}
bail!(IndexNotFound {
relation: self.relation.to_string(),
name: self.index.to_string(),
span: self.span,
})
}
}
impl Debug for InputAtom {
@ -1213,32 +1250,14 @@ impl Display for InputAtom {
write!(f, ":{name}")?;
f.debug_list().entries(args).finish()?;
}
InputAtom::HnswSearch { inner } => {
InputAtom::Search { inner } => {
write!(f, "~{}:{}{{", inner.relation, inner.index)?;
for (binding, expr) in &inner.bindings {
write!(f, "{binding}: {expr}, ")?;
}
write!(f, "| ")?;
write!(f, " query: {}, ", inner.query)?;
write!(f, " k: {}, ", inner.k)?;
write!(f, " ef: {}, ", inner.ef)?;
if let Some(radius) = &inner.radius {
write!(f, " radius: {}, ", radius)?;
}
if let Some(filter) = &inner.filter {
write!(f, " filter: {}, ", filter)?;
}
if let Some(bind_distance) = &inner.bind_distance {
write!(f, " bind_distance: {}, ", bind_distance)?;
}
if let Some(bind_field) = &inner.bind_field {
write!(f, " bind_field: {}, ", bind_field)?;
}
if let Some(bind_field_idx) = &inner.bind_field_idx {
write!(f, " bind_field_idx: {}, ", bind_field_idx)?;
}
if let Some(bind_vector) = &inner.bind_vector {
write!(f, " bind_vector: {}, ", bind_vector)?;
for (k, v) in inner.parameters.iter() {
write!(f, "{k}: {v}, ")?;
}
write!(f, "}}")?;
}
@ -1307,7 +1326,7 @@ impl InputAtom {
InputAtom::Relation { inner, .. } => inner.span,
InputAtom::Predicate { inner, .. } => inner.span(),
InputAtom::Unification { inner, .. } => inner.span,
InputAtom::HnswSearch { inner, .. } => inner.span,
InputAtom::Search { inner, .. } => inner.span,
}
}
}

@ -26,12 +26,12 @@ pub(crate) mod cangjie;
pub(crate) mod tokenizer;
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde_derive::Serialize, serde_derive::Deserialize)]
pub(crate) struct TokenizerFilterConfig {
pub(crate) struct TokenizerConfig {
pub(crate) name: SmartString<LazyCompact>,
pub(crate) args: Vec<DataValue>,
}
impl TokenizerFilterConfig {
impl TokenizerConfig {
// use sha256::digest;
pub(crate) fn config_hash(&self, filters: &[Self]) -> impl AsRef<[u8]> {
let mut hasher = Sha256::new();
@ -225,8 +225,8 @@ pub(crate) struct FtsIndexConfig {
base_relation: SmartString<LazyCompact>,
index_name: SmartString<LazyCompact>,
fts_fields: Vec<SmartString<LazyCompact>>,
tokenizer: TokenizerFilterConfig,
filters: Vec<TokenizerFilterConfig>,
tokenizer: TokenizerConfig,
filters: Vec<TokenizerConfig>,
}
#[derive(Default)]
@ -239,8 +239,8 @@ impl TokenizerCache {
pub(crate) fn get(
&self,
tokenizer_name: &str,
tokenizer: &TokenizerFilterConfig,
filters: &[TokenizerFilterConfig],
tokenizer: &TokenizerConfig,
filters: &[TokenizerConfig],
) -> Result<Arc<TextAnalyzer>> {
{
let idx_cache = self.named_cache.read().unwrap();

@ -22,12 +22,7 @@ use thiserror::Error;
use crate::data::aggr::{parse_aggr, Aggregation};
use crate::data::expr::Expr;
use crate::data::functions::{str2vld, MAX_VALIDITY_TS};
use crate::data::program::{
FixedRuleApply, FixedRuleArg, HnswSearchInput, InputAtom, InputInlineRule,
InputInlineRulesOrFixed, InputNamedFieldRelationApplyAtom, InputProgram,
InputRelationApplyAtom, InputRuleApplyAtom, QueryAssertion, QueryOutOptions, RelationOp,
SortDir, Unification,
};
use crate::data::program::{FixedRuleApply, FixedRuleArg, InputAtom, InputInlineRule, InputInlineRulesOrFixed, InputNamedFieldRelationApplyAtom, InputProgram, InputRelationApplyAtom, InputRuleApplyAtom, QueryAssertion, QueryOutOptions, RelationOp, SearchInput, SortDir, Unification};
use crate::data::relation::{ColType, ColumnDef, NullableColType, StoredRelationMetadata};
use crate::data::symb::{Symbol, PROG_ENTRY};
use crate::data::value::{DataValue, ValidityTs};
@ -626,7 +621,7 @@ fn parse_atom(
},
}
}
Rule::vector_search => {
Rule::search_apply => {
let span = src.extract_span();
let mut src = src.into_inner();
let name_p = src.next().unwrap();
@ -649,141 +644,19 @@ fn parse_atom(
.into_inner()
.map(|arg| extract_named_apply_arg(arg, param_pool))
.try_collect()?;
let parameters: BTreeMap<SmartString<LazyCompact>, Expr> = src
.map(|arg| extract_named_apply_arg(arg, param_pool))
.try_collect()?;
let mut opts = HnswSearchInput {
let opts = SearchInput {
relation,
index,
bindings,
k: 0,
ef: 0,
query: Expr::Const {
val: DataValue::Bot,
span: SourceSpan::default(),
},
bind_field: None,
bind_field_idx: None,
bind_distance: None,
bind_vector: None,
radius: None,
filter: None,
span,
parameters,
};
for field in src {
let mut fs = field.into_inner();
let fname = fs.next().unwrap();
let fval = fs.next().unwrap();
match fname.as_str() {
"query" => {
opts.query = build_expr(fval, param_pool)?;
}
"filter" => {
opts.filter = Some(build_expr(fval, param_pool)?);
}
"bind_field" => {
opts.bind_field = Some(build_expr(fval, param_pool)?);
}
"bind_field_idx" => {
opts.bind_field_idx = Some(build_expr(fval, param_pool)?);
}
"bind_distance" => {
opts.bind_distance = Some(build_expr(fval, param_pool)?);
}
"bind_vector" => {
opts.bind_vector = Some(build_expr(fval, param_pool)?);
}
"k" => {
#[derive(Debug, Error, Diagnostic)]
#[error("Expected positive integer for `k`")]
#[diagnostic(code(parser::expected_int_for_hnsw_k))]
struct ExpectedPosIntForHnswK(#[label] SourceSpan);
let k = build_expr(fval, param_pool)?;
let k = k.eval_to_const()?;
let k = k
.get_int()
.ok_or_else(|| ExpectedPosIntForHnswK(fname.extract_span()))?;
ensure!(k > 0, ExpectedPosIntForHnswK(fname.extract_span()));
opts.k = k as usize;
}
"ef" => {
#[derive(Debug, Error, Diagnostic)]
#[error("Expected positive integer for `ef`")]
#[diagnostic(code(parser::expected_int_for_hnsw_ef))]
struct ExpectedPosIntForHnswEf(#[label] SourceSpan);
let ef = build_expr(fval, param_pool)?;
let ef = ef.eval_to_const()?;
let ef = ef
.get_int()
.ok_or_else(|| ExpectedPosIntForHnswEf(fname.extract_span()))?;
ensure!(ef > 0, ExpectedPosIntForHnswEf(fname.extract_span()));
opts.ef = ef as usize;
}
"radius" => {
#[derive(Debug, Error, Diagnostic)]
#[error("Expected positive float for `radius`")]
#[diagnostic(code(parser::expected_float_for_hnsw_radius))]
struct ExpectedPosFloatForHnswRadius(#[label] SourceSpan);
let radius = build_expr(fval, param_pool)?;
let radius = radius.eval_to_const()?;
let radius = radius
.get_float()
.ok_or_else(|| ExpectedPosFloatForHnswRadius(fname.extract_span()))?;
ensure!(
radius > 0.0,
ExpectedPosFloatForHnswRadius(fname.extract_span())
);
opts.radius = Some(radius);
}
_ => {
#[derive(Debug, Error, Diagnostic)]
#[error("Unexpected HNSW option `{0}`")]
#[diagnostic(code(parser::unexpected_hnsw_option))]
struct UnexpectedHnswOption(String, #[label] SourceSpan);
bail!(UnexpectedHnswOption(
fname.as_str().into(),
fname.extract_span()
))
}
}
}
if matches!(
opts.query,
Expr::Const {
val: DataValue::Bot,
..
}
) {
#[derive(Debug, Error, Diagnostic)]
#[error("Field `query` is required for HNSW search")]
#[diagnostic(code(parser::hnsw_query_required))]
struct HnswQueryRequired(#[label] SourceSpan);
bail!(HnswQueryRequired(span))
}
if opts.k == 0 {
#[derive(Debug, Error, Diagnostic)]
#[error("Field `k > 0` is required for HNSW search")]
#[diagnostic(code(parser::hnsw_k_required))]
struct HnswKRequired(#[label] SourceSpan);
bail!(HnswKRequired(span))
}
if opts.ef == 0 {
#[derive(Debug, Error, Diagnostic)]
#[error("Field `ef > 0` is required for HNSW search")]
#[diagnostic(code(parser::hnsw_ef_required))]
struct HnswEfRequired(#[label] SourceSpan);
bail!(HnswEfRequired(span))
}
InputAtom::HnswSearch { inner: opts }
InputAtom::Search { inner: opts }
}
Rule::relation_named_apply => {
let span = src.extract_span();

@ -18,6 +18,7 @@ use crate::data::program::InputProgram;
use crate::data::relation::VecElementType;
use crate::data::symb::Symbol;
use crate::data::value::{DataValue, ValidityTs};
use crate::fts::TokenizerConfig;
use crate::parse::expr::build_expr;
use crate::parse::query::parse_query;
use crate::parse::{ExtractSpan, Pairs, Rule, SourceSpan};
@ -39,8 +40,17 @@ pub(crate) enum SysOp {
SetAccessLevel(Vec<Symbol>, AccessLevel),
CreateIndex(Symbol, Symbol, Vec<Symbol>),
CreateVectorIndex(HnswIndexConfig),
CreateFtsIndex(FtsIndexConfig),
RemoveIndex(Symbol, Symbol),
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct FtsIndexConfig {
pub(crate) base_relation: SmartString<LazyCompact>,
pub(crate) index_name: SmartString<LazyCompact>,
pub(crate) extractor: String,
pub(crate) tokenizer: TokenizerConfig,
pub(crate) filters: Vec<TokenizerConfig>,
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct HnswIndexConfig {
@ -227,7 +237,9 @@ pub(crate) fn parse_sys(
let v = build_expr(opt_val, param_pool)?
.eval_to_const()?
.get_int()
.ok_or_else(|| miette!("Invalid m_neighbours: {}", opt_val_str))?;
.ok_or_else(|| {
miette!("Invalid m_neighbours: {}", opt_val_str)
})?;
ensure!(v > 0, "Invalid m_neighbours: {}", v);
m_neighbours = v as usize;
}

@ -121,11 +121,11 @@ impl InputAtom {
InputAtom::Unification { inner } => {
bail!(UnsafeNegation(inner.span))
}
InputAtom::HnswSearch { inner } => {
InputAtom::Search { inner } => {
bail!(UnsafeNegation(inner.span))
}
},
InputAtom::HnswSearch { inner } => InputAtom::HnswSearch { inner },
InputAtom::Search { inner } => InputAtom::Search { inner },
})
}
@ -229,7 +229,7 @@ impl InputAtom {
InputAtom::Unification { inner: u } => {
Disjunction::singlet(NormalFormAtom::Unification(u))
}
InputAtom::HnswSearch { inner } => inner.normalize(gen, tx)?,
InputAtom::Search { inner } => inner.normalize(gen, tx)?,
})
}
}

@ -1192,6 +1192,9 @@ impl<'s, S: Storage<'s>> Db<S> {
vec![vec![DataValue::from(OK_STR)]],
))
}
SysOp::CreateFtsIndex(_) => {
todo!("FTS index creation is not yet implemented")
}
SysOp::RemoveIndex(rel_name, idx_name) => {
let lock = self
.obtain_relation_locks(iter::once(&rel_name.name))

@ -19,7 +19,7 @@ use crate::data::expr::Expr;
use crate::data::symb::Symbol;
use crate::data::value::DataValue;
use crate::fixed_rule::FixedRulePayload;
use crate::fts::{TokenizerCache, TokenizerFilterConfig};
use crate::fts::{TokenizerCache, TokenizerConfig};
use crate::parse::SourceSpan;
use crate::runtime::callback::CallbackOp;
use crate::runtime::db::Poison;
@ -948,7 +948,7 @@ fn tentivy_tokenizers() {
let tokenizer = tokenizers
.get(
"simple",
&TokenizerFilterConfig {
&TokenizerConfig {
name: "Simple".into(),
args: vec![],
},
@ -970,7 +970,7 @@ fn tentivy_tokenizers() {
let tokenizer = tokenizers
.get(
"cangjie",
&TokenizerFilterConfig {
&TokenizerConfig {
name: "Cangjie".into(),
args: vec![],
},

Loading…
Cancel
Save