diff --git a/cozo-core/src/fts/tokenizer/mod.rs b/cozo-core/src/fts/tokenizer/mod.rs
index a90e3184..77896213 100644
--- a/cozo-core/src/fts/tokenizer/mod.rs
+++ b/cozo-core/src/fts/tokenizer/mod.rs
@@ -156,10 +156,12 @@ pub(crate) use self::whitespace_tokenizer::WhitespaceTokenizer;
 
 #[cfg(test)]
 pub(crate) mod tests {
-    use super::{
-        Language, LowerCaser, RemoveLongFilter, SimpleTokenizer, Stemmer, Token,
-    };
-    use crate::fts::tokenizer::TextAnalyzer;
+    // use super::{
+    //     Language, LowerCaser, RemoveLongFilter, SimpleTokenizer, Stemmer, Token,
+    // };
+    // use crate::fts::tokenizer::TextAnalyzer;
+
+    use crate::fts::tokenizer::Token;
 
     /// This is a function that can be used in tests and doc tests
     /// to assert a token's correctness.
diff --git a/cozo-core/src/fts/tokenizer/ngram_tokenizer.rs b/cozo-core/src/fts/tokenizer/ngram_tokenizer.rs
index 2f7982a3..22e777c9 100644
--- a/cozo-core/src/fts/tokenizer/ngram_tokenizer.rs
+++ b/cozo-core/src/fts/tokenizer/ngram_tokenizer.rs
@@ -301,16 +301,16 @@ fn utf8_codepoint_width(b: u8) -> usize {
 #[cfg(test)]
 mod tests {
 
-    use super::{utf8_codepoint_width, CodepointFrontiers, NgramTokenizer, StutteringIterator};
-    use crate::fts::tokenizer::tests::assert_token;
+    use super::{utf8_codepoint_width, CodepointFrontiers, StutteringIterator};
+    // use crate::fts::tokenizer::tests::assert_token;
     // use crate::fts::tokenizer::tokenizer_impl::Tokenizer;
-    use crate::fts::tokenizer::{BoxTokenStream, Token};
+    // use crate::fts::tokenizer::{BoxTokenStream, Token};
 
-    fn test_helper(mut tokenizer: BoxTokenStream<'_>) -> Vec<Token> {
-        let mut tokens: Vec<Token> = vec![];
-        tokenizer.process(&mut |token: &Token| tokens.push(token.clone()));
-        tokens
-    }
+    // fn test_helper(mut tokenizer: BoxTokenStream<'_>) -> Vec<Token> {
+    //     let mut tokens: Vec<Token> = vec![];
+    //     tokenizer.process(&mut |token: &Token| tokens.push(token.clone()));
+    //     tokens
+    // }
 
     #[test]
     fn test_utf8_codepoint_width() {
diff --git a/cozo-core/src/query/stored.rs b/cozo-core/src/query/stored.rs
index 0353ea6c..60ded480 100644
--- a/cozo-core/src/query/stored.rs
+++ b/cozo-core/src/query/stored.rs
@@ -23,6 +23,7 @@ use crate::data::tuple::{Tuple, ENCODED_KEY_MIN_LEN};
 use crate::data::value::{DataValue, ValidityTs};
 use crate::fixed_rule::utilities::constant::Constant;
 use crate::fixed_rule::FixedRuleHandle;
+use crate::fts::tokenizer::TextAnalyzer;
 use crate::parse::expr::build_expr;
 use crate::parse::{parse_script, CozoScriptParser, Rule};
 use crate::runtime::callback::{CallbackCollector, CallbackOp};
@@ -236,6 +237,7 @@ impl<'a> SessionTx<'a> {
             || (propagate_triggers && !relation_store.put_triggers.is_empty()));
         let has_indices = !relation_store.indices.is_empty();
         let has_hnsw_indices = !relation_store.hnsw_indices.is_empty();
+        let has_fts_indices = !relation_store.fts_indices.is_empty();
         let mut new_tuples: Vec<DataValue> = vec![];
         let mut old_tuples: Vec<DataValue> = vec![];
 
@@ -248,6 +250,7 @@
         key_extractors.extend(val_extractors);
         let mut stack = vec![];
         let hnsw_filters = Self::make_hnsw_filters(relation_store)?;
+        let fts_processors = self.make_fts_processors(relation_store)?;
 
         for tuple in res_iter {
            let extracted: Vec<DataValue> = key_extractors
@@ -258,12 +261,13 @@
             let key = relation_store.encode_key_for_store(&extracted, span)?;
             let val = relation_store.encode_val_for_store(&extracted, span)?;
 
-            if need_to_collect || has_indices || has_hnsw_indices {
+            if need_to_collect || has_indices || has_hnsw_indices || has_fts_indices {
                 if let Some(existing) = self.store_tx.get(&key, false)? {
                    let mut tup = extracted[0..relation_store.metadata.keys.len()].to_vec();
                     extend_tuple_from_v(&mut tup, &existing);
                     if has_indices && extracted != tup {
                         self.update_in_index(relation_store, &extracted, &tup)?;
+                        self.del_in_fts(relation_store, &mut stack, &fts_processors, &tup)?;
                     }
 
                     if need_to_collect {
@@ -282,6 +286,7 @@
                 }
 
                 self.update_in_hnsw(relation_store, &mut stack, &hnsw_filters, &extracted)?;
+                self.put_in_fts(relation_store, &mut stack, &fts_processors, &extracted)?;
 
                 if need_to_collect {
                     new_tuples.push(DataValue::List(extracted));
@@ -312,6 +317,34 @@
         Ok(())
     }
 
+    fn put_in_fts(
+        &mut self,
+        rel_handle: &RelationHandle,
+        stack: &mut Vec<DataValue>,
+        processors: &BTreeMap<SmartString<LazyCompact>, (Arc<TextAnalyzer>, Vec<Bytecode>)>,
+        new_kv: &[DataValue],
+    ) -> Result<()> {
+        for (k, (idx_handle, _)) in rel_handle.fts_indices.iter() {
+            let (tokenizer, extractor) = processors.get(k).unwrap();
+            self.put_fts_index_item(new_kv, extractor, stack, tokenizer, rel_handle, idx_handle)?;
+        }
+        Ok(())
+    }
+
+    fn del_in_fts(
+        &mut self,
+        rel_handle: &RelationHandle,
+        stack: &mut Vec<DataValue>,
+        processors: &BTreeMap<SmartString<LazyCompact>, (Arc<TextAnalyzer>, Vec<Bytecode>)>,
+        old_kv: &[DataValue],
+    ) -> Result<()> {
+        for (k, (idx_handle, _)) in rel_handle.fts_indices.iter() {
+            let (tokenizer, extractor) = processors.get(k).unwrap();
+            self.del_fts_index_item(old_kv, extractor, stack, tokenizer, rel_handle, idx_handle)?;
+        }
+        Ok(())
+    }
+
     fn update_in_hnsw(
         &mut self,
         relation_store: &RelationHandle,
@@ -333,6 +366,31 @@
         Ok(())
     }
 
+    fn make_fts_processors(
+        &self,
+        relation_store: &RelationHandle,
+    ) -> Result<BTreeMap<SmartString<LazyCompact>, (Arc<TextAnalyzer>, Vec<Bytecode>)>> {
+        let mut fts_processors = BTreeMap::new();
+        for (name, (_, manifest)) in relation_store.fts_indices.iter() {
+            let tokenizer = self.tokenizers.get(
+                &relation_store.name,
+                &manifest.tokenizer,
+                &manifest.filters,
+            )?;
+
+            let parsed = CozoScriptParser::parse(Rule::expr, &manifest.extractor)
+                .into_diagnostic()?
+                .next()
+                .unwrap();
+            let mut code_expr = build_expr(parsed, &Default::default())?;
+            let binding_map = relation_store.raw_binding_map();
+            code_expr.fill_binding_indices(&binding_map)?;
+            let extractor = code_expr.compile()?;
+            fts_processors.insert(name.clone(), (tokenizer, extractor));
+        }
+        Ok(fts_processors)
+    }
+
     fn make_hnsw_filters(
         relation_store: &RelationHandle,
     ) -> Result<BTreeMap<SmartString<LazyCompact>, Vec<Bytecode>>> {
@@ -390,6 +448,7 @@
             || (propagate_triggers && !relation_store.put_triggers.is_empty()));
         let has_indices = !relation_store.indices.is_empty();
         let has_hnsw_indices = !relation_store.hnsw_indices.is_empty();
+        let has_fts_indices = !relation_store.fts_indices.is_empty();
         let mut new_tuples: Vec<DataValue> = vec![];
         let mut old_tuples: Vec<DataValue> = vec![];
 
@@ -402,6 +461,7 @@
 
         let mut stack = vec![];
         let hnsw_filters = Self::make_hnsw_filters(relation_store)?;
+        let fts_processors = self.make_fts_processors(relation_store)?;
 
         for tuple in res_iter {
             let mut new_kv: Vec<DataValue> = key_extractors
@@ -442,7 +502,8 @@
             }
 
             let new_val = relation_store.encode_val_for_store(&new_kv, span)?;
-            if need_to_collect || has_indices || has_hnsw_indices {
+            if need_to_collect || has_indices || has_hnsw_indices || has_fts_indices {
+                self.del_in_fts(relation_store, &mut stack, &fts_processors, &old_kv)?;
                 self.update_in_index(relation_store, &new_kv, &old_kv)?;
 
                 if need_to_collect {
@@ -450,6 +511,7 @@
                 }
 
                 self.update_in_hnsw(relation_store, &mut stack, &hnsw_filters, &new_kv)?;
+                self.put_in_fts(relation_store, &mut stack, &fts_processors, &new_kv)?;
 
                 if need_to_collect {
                     new_tuples.push(DataValue::List(new_kv));
@@ -762,8 +824,11 @@
             || (propagate_triggers && !relation_store.rm_triggers.is_empty()));
         let has_indices = !relation_store.indices.is_empty();
         let has_hnsw_indices = !relation_store.hnsw_indices.is_empty();
+        let has_fts_indices = !relation_store.fts_indices.is_empty();
+        let fts_processors = self.make_fts_processors(relation_store)?;
         let mut new_tuples: Vec<DataValue> = vec![];
         let mut old_tuples: Vec<DataValue> = vec![];
+        let mut stack = vec![];
 
         for tuple in res_iter {
             let extracted: Vec<DataValue> = key_extractors
@@ -771,10 +836,11 @@
                 .map(|ex| ex.extract_data(&tuple, cur_vld))
                 .try_collect()?;
             let key = relation_store.encode_key_for_store(&extracted, span)?;
-            if need_to_collect || has_indices || has_hnsw_indices {
+            if need_to_collect || has_indices || has_hnsw_indices || has_fts_indices {
                 if let Some(existing) = self.store_tx.get(&key, false)? {
                     let mut tup = extracted.clone();
                     extend_tuple_from_v(&mut tup, &existing);
+                    self.del_in_fts(relation_store, &mut stack, &fts_processors, &tup)?;
                     if has_indices {
                         for (idx_rel, extractor) in relation_store.indices.values() {
                             let idx_tup = extractor.iter().map(|i| tup[*i].clone()).collect_vec();
diff --git a/cozo-core/src/runtime/tests.rs b/cozo-core/src/runtime/tests.rs
index 1833675e..3ee961af 100644
--- a/cozo-core/src/runtime/tests.rs
+++ b/cozo-core/src/runtime/tests.rs
@@ -917,7 +917,10 @@ fn test_fts_indexing() {
     )
     .unwrap();
     db.run_script(
-        r"?[k, v] <- [['c', 'see you at the end of the world!']] :put a {k => v}",
+        r"?[k, v] <- [
+            ['b', 'the world is square!'],
+            ['c', 'see you at the end of the world!']
+        ] :put a {k => v}",
         Default::default(),
     )
    .unwrap();
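
For context on how these hooks surface to users: once an FTS index exists on a relation, every `:put` on that relation now writes token postings for the new row via `put_in_fts`, and every overwrite or `:rm` first drops the stale postings via `del_in_fts`, in the same `SessionTx` code paths that write the base tuples. The sketch below shows a minimal round trip in the style of `test_fts_indexing`, written against the public `cozo` crate rather than the internal test harness. Treat it as illustrative only: the engine string, the two-argument `run_script` form (later releases add a mutability argument), the `::fts create` options (`extractor`, `tokenizer`, `filters` mirror the manifest fields consumed by `make_fts_processors`), the `~a:fts_idx{... | query, bind_score}` search syntax, and the index name `fts_idx` are all assumptions for the example, not part of this patch.

// Illustrative sketch only; see the assumptions noted above.
use cozo::DbInstance;

fn main() {
    // In-memory engine; any backend goes through the same stored.rs paths.
    let db = DbInstance::new("mem", "", "").unwrap();

    // Base relation with a nullable text column, as in test_fts_indexing.
    db.run_script(
        r"?[k, v] <- [['a', 'hello world!']] :create a {k: String => v: String?}",
        Default::default(),
    )
    .unwrap();

    // Hypothetical FTS index over `v`; extractor/tokenizer/filters are the
    // manifest fields that make_fts_processors compiles into processors.
    db.run_script(
        r"::fts create a:fts_idx {extractor: v, tokenizer: Simple, filters: [Lowercase]}",
        Default::default(),
    )
    .unwrap();

    // :put goes through put_in_fts (and del_in_fts for any overwritten row)...
    db.run_script(
        r"?[k, v] <- [['b', 'the world is square!']] :put a {k => v}",
        Default::default(),
    )
    .unwrap();

    // ...and :rm goes through del_in_fts, so searches never see stale rows.
    db.run_script(r"?[k] <- [['b']] :rm a {k}", Default::default()).unwrap();

    // Search via the index; some versions may also expect a top-k parameter here.
    let res = db
        .run_script(
            r"?[k, v, score] := ~a:fts_idx{k, v | query: 'world', bind_score: score}",
            Default::default(),
        )
        .unwrap();
    println!("{:?}", res.rows);
}

Because the index maintenance runs inside the same transaction as the base-relation write, `del_in_fts` is deliberately called before the old tuple is replaced or removed in all three code paths touched by this patch (plain put, validity-aware put, and rm).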