diff --git a/cozo-core/src/data/expr.rs b/cozo-core/src/data/expr.rs index faca5881..9e7fac71 100644 --- a/cozo-core/src/data/expr.rs +++ b/cozo-core/src/data/expr.rs @@ -12,10 +12,10 @@ use std::fmt::{Debug, Display, Formatter}; use std::mem; use itertools::Itertools; -use miette::{bail, Diagnostic, Result}; +use miette::{bail, miette, Diagnostic, Result}; use serde::de::{Error, Visitor}; use serde::{Deserializer, Serializer}; -use smartstring::SmartString; +use smartstring::{LazyCompact, SmartString}; use thiserror::Error; use crate::data::functions::*; @@ -560,6 +560,26 @@ impl Expr { }, }) } + pub(crate) fn to_var_list(&self) -> Result<Vec<SmartString<LazyCompact>>> { + return match self { + Expr::Apply { op, args, .. } => { + if op.name != "OP_LIST" { + Err(miette!("Invalid fields op: {} for {}", op.name, self)) + } else { + let mut collected = vec![]; + for field in args.iter() { + match field { + Expr::Binding { var, .. } => collected.push(var.name.clone()), + _ => return Err(miette!("Invalid field element: {}", field)), + } + } + Ok(collected) + } + } + Expr::Binding { var, .. 
} => Ok(vec![var.name.clone()]), + _ => Err(miette!("Invalid fields: {}", self)), + }; + } } pub(crate) fn compute_bounds( diff --git a/cozo-core/src/parse/sys.rs b/cozo-core/src/parse/sys.rs index c6af4a1e..dcda0cb2 100644 --- a/cozo-core/src/parse/sys.rs +++ b/cozo-core/src/parse/sys.rs @@ -11,8 +11,10 @@ use std::sync::Arc; use itertools::Itertools; use miette::{ensure, miette, Diagnostic, Result}; +use smartstring::{LazyCompact, SmartString}; use thiserror::Error; +use crate::data::functions::OP_LIST; use crate::data::program::InputProgram; use crate::data::relation::VecElementType; use crate::data::symb::Symbol; @@ -21,7 +23,7 @@ use crate::parse::expr::build_expr; use crate::parse::query::parse_query; use crate::parse::{ExtractSpan, Pairs, Rule, SourceSpan}; use crate::runtime::relation::AccessLevel; -use crate::FixedRule; +use crate::{Expr, FixedRule}; pub(crate) enum SysOp { Compact, @@ -44,15 +46,16 @@ pub(crate) enum SysOp { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub(crate) struct HnswIndexConfig { - pub(crate) base_relation: Symbol, - pub(crate) index_name: Symbol, + pub(crate) base_relation: SmartString<LazyCompact>, + pub(crate) index_name: SmartString<LazyCompact>, pub(crate) vec_dim: usize, pub(crate) dtype: VecElementType, - pub(crate) vec_fields: Vec<Symbol>, - pub(crate) tag_fields: Vec<Symbol>, + pub(crate) vec_fields: Vec<SmartString<LazyCompact>>, + pub(crate) tag_fields: Vec<SmartString<LazyCompact>>, pub(crate) distance: HnswDistance, pub(crate) ef_construction: usize, pub(crate) max_elements: usize, + pub(crate) index_filter: Option<String>, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -198,6 +201,7 @@ pub(crate) fn parse_sys( let mut distance = HnswDistance::L2; let mut ef_construction = 0; let mut max_elements = 0; + let mut index_filter = None; // TODO this is a bit of a mess for opt_pair in inner { @@ -216,18 +220,17 @@ pub(crate) fn parse_sys( "F32" | "Float" => VecElementType::F32, "F64" | "Double" => VecElementType::F64, _ => { - return Err(miette!( - "Invalid dtype: {}", - opt_val.as_str() - )) + return 
Err(miette!("Invalid dtype: {}", opt_val.as_str())) } } } "fields" => { - todo!() + let fields = build_expr(opt_val, &Default::default())?; + vec_fields = fields.to_var_list()?; } "tags" => { - todo!() + let fields = build_expr(opt_val, &Default::default())?; + tag_fields = fields.to_var_list()?; } "distance" => { distance = match opt_val.as_str() { @@ -254,17 +257,15 @@ pub(crate) fn parse_sys( .parse() .map_err(|e| miette!("Invalid max_elements: {}", e))?; } - _ => { - return Err(miette!( - "Invalid option: {}", - opt_name.as_str() - )) + "filter" => { + index_filter = Some(opt_val.as_str().to_string()); } + _ => return Err(miette!("Invalid option: {}", opt_name.as_str())), } } SysOp::CreateVectorIndex(HnswIndexConfig { - base_relation: Symbol::new(rel.as_str(), rel.extract_span()), - index_name: Symbol::new(name.as_str(), name.extract_span()), + base_relation: SmartString::from(rel.as_str()), + index_name: SmartString::from(name.as_str()), vec_dim, dtype, vec_fields, @@ -272,6 +273,7 @@ pub(crate) fn parse_sys( distance, ef_construction, max_elements, + index_filter, }) } Rule::index_drop => { diff --git a/cozo-core/src/runtime/db.rs b/cozo-core/src/runtime/db.rs index 01578608..27751be2 100644 --- a/cozo-core/src/runtime/db.rs +++ b/cozo-core/src/runtime/db.rs @@ -1158,6 +1158,7 @@ impl<'s, S: Storage<'s>> Db<S> { )) } SysOp::CreateVectorIndex(config) => { + dbg!(&config); todo!() } SysOp::RemoveVectorIndex(rel_name, idx_name) => { diff --git a/cozo-core/src/runtime/tests.rs b/cozo-core/src/runtime/tests.rs index e2000dec..7770830b 100644 --- a/cozo-core/src/runtime/tests.rs +++ b/cozo-core/src/runtime/tests.rs @@ -784,6 +784,16 @@ fn test_vec_index() { let db = DbInstance::new("mem", "", "").unwrap(); db.run_script(":create a {k: String => tags: [String], v: <F32; 128>}", Default::default()) .unwrap(); - // db.run_script("::hnsw create a:vec {dim: 128, dtype: f32, fields: [v], tags: tags, distance: Cosine, ef_construction: 20, max_elements: 50}", Default::default()) - // 
.unwrap(); + db.run_script(r" + ::hnsw create a:vec { + dim: 128, + dtype: F32, + fields: [v], + tags: tags, + distance: Cosine, + ef_construction: 20, + max_elements: 50, + filter: k != 'k1' + }", Default::default()) + .unwrap(); }