happy clippy

main
Ziyang Hu 1 year ago
parent 5a23d515a2
commit e3e155cfa3

@ -568,7 +568,7 @@ impl Expr {
}) })
} }
pub(crate) fn to_var_list(&self) -> Result<Vec<SmartString<LazyCompact>>> { pub(crate) fn to_var_list(&self) -> Result<Vec<SmartString<LazyCompact>>> {
return match self { match self {
Expr::Apply { op, args, .. } => { Expr::Apply { op, args, .. } => {
if op.name != "OP_LIST" { if op.name != "OP_LIST" {
Err(miette!("Invalid fields op: {} for {}", op.name, self)) Err(miette!("Invalid fields op: {} for {}", op.name, self))
@ -585,7 +585,7 @@ impl Expr {
} }
Expr::Binding { var, .. } => Ok(vec![var.name.clone()]), Expr::Binding { var, .. } => Ok(vec![var.name.clone()]),
_ => Err(miette!("Invalid fields: {}", self)), _ => Err(miette!("Invalid fields: {}", self)),
}; }
} }
} }

@ -198,7 +198,7 @@ fn add_vecs(args: &[DataValue]) -> Result<DataValue> {
let f = b let f = b
.get_float() .get_float()
.ok_or_else(|| miette!("can only add numbers to vectors"))?; .ok_or_else(|| miette!("can only add numbers to vectors"))?;
match a.clone() { match a {
Vector::F32(mut v) => { Vector::F32(mut v) => {
v += f as f32; v += f as f32;
Ok(DataValue::Vec(Vector::F32(v))) Ok(DataValue::Vec(Vector::F32(v)))
@ -360,7 +360,7 @@ fn mul_vecs(args: &[DataValue]) -> Result<DataValue> {
let f = b let f = b
.get_float() .get_float()
.ok_or_else(|| miette!("can only add numbers to vectors"))?; .ok_or_else(|| miette!("can only add numbers to vectors"))?;
match a.clone() { match a {
Vector::F32(mut v) => { Vector::F32(mut v) => {
v *= f as f32; v *= f as f32;
Ok(DataValue::Vec(Vector::F32(v))) Ok(DataValue::Vec(Vector::F32(v)))
@ -1708,7 +1708,7 @@ pub(crate) fn op_vec(args: &[DataValue]) -> Result<DataValue> {
let f = el let f = el
.get_float() .get_float()
.ok_or_else(|| miette!("'vec' requires a list of numbers"))?; .ok_or_else(|| miette!("'vec' requires a list of numbers"))?;
row.fill(f as f64); row.fill(f);
} }
Ok(DataValue::Vec(Vector::F64(res_arr))) Ok(DataValue::Vec(Vector::F64(res_arr)))
} }
@ -1794,7 +1794,7 @@ pub(crate) fn op_l2_dist(args: &[DataValue]) -> Result<DataValue> {
bail!("'l2_dist' requires two vectors of the same length"); bail!("'l2_dist' requires two vectors of the same length");
} }
let diff = a - b; let diff = a - b;
Ok(DataValue::from(diff.dot(&diff) as f64)) Ok(DataValue::from(diff.dot(&diff)))
} }
_ => bail!("'l2_dist' requires two vectors of the same type"), _ => bail!("'l2_dist' requires two vectors of the same type"),
} }
@ -1817,7 +1817,7 @@ pub(crate) fn op_ip_dist(args: &[DataValue]) -> Result<DataValue> {
bail!("'ip_dist' requires two vectors of the same length"); bail!("'ip_dist' requires two vectors of the same length");
} }
let dot = a.dot(b); let dot = a.dot(b);
Ok(DataValue::from(1. - dot as f64)) Ok(DataValue::from(1. - dot))
} }
_ => bail!("'ip_dist' requires two vectors of the same type"), _ => bail!("'ip_dist' requires two vectors of the same type"),
} }

@ -10,7 +10,6 @@ use std::cmp::Reverse;
use std::collections::BTreeSet; use std::collections::BTreeSet;
use std::io::Write; use std::io::Write;
use std::str::FromStr; use std::str::FromStr;
use ndarray;
use byteorder::{BigEndian, ByteOrder, WriteBytesExt}; use byteorder::{BigEndian, ByteOrder, WriteBytesExt};
use regex::Regex; use regex::Regex;

@ -190,7 +190,7 @@ impl serde::Serialize for Vector {
Vector::F32(a) => { Vector::F32(a) => {
state.serialize_element(&0u8)?; state.serialize_element(&0u8)?;
let arr = a.as_slice().unwrap(); let arr = a.as_slice().unwrap();
let len = arr.len() * std::mem::size_of::<f32>(); let len = std::mem::size_of_val(arr);
let ptr = arr.as_ptr() as *const u8; let ptr = arr.as_ptr() as *const u8;
let bytes = unsafe { std::slice::from_raw_parts(ptr, len) }; let bytes = unsafe { std::slice::from_raw_parts(ptr, len) };
state.serialize_element(&VecBytes(bytes))?; state.serialize_element(&VecBytes(bytes))?;
@ -198,7 +198,7 @@ impl serde::Serialize for Vector {
Vector::F64(a) => { Vector::F64(a) => {
state.serialize_element(&1u8)?; state.serialize_element(&1u8)?;
let arr = a.as_slice().unwrap(); let arr = a.as_slice().unwrap();
let len = arr.len() * std::mem::size_of::<f64>(); let len = std::mem::size_of_val(arr);
let ptr = arr.as_ptr() as *const u8; let ptr = arr.as_ptr() as *const u8;
let bytes = unsafe { std::slice::from_raw_parts(ptr, len) }; let bytes = unsafe { std::slice::from_raw_parts(ptr, len) };
state.serialize_element(&VecBytes(bytes))?; state.serialize_element(&VecBytes(bytes))?;
@ -274,6 +274,13 @@ impl Vector {
Vector::F64(v) => v.len(), Vector::F64(v) => v.len(),
} }
} }
/// Check if the vector is empty
pub fn is_empty(&self) -> bool {
match self {
Vector::F32(v) => v.is_empty(),
Vector::F64(v) => v.is_empty(),
}
}
pub(crate) fn el_type(&self) -> VecElementType { pub(crate) fn el_type(&self) -> VecElementType {
match self { match self {
Vector::F32(_) => VecElementType::F32, Vector::F32(_) => VecElementType::F32,
@ -285,12 +292,12 @@ impl Vector {
match self { match self {
Vector::F32(v) => { Vector::F32(v) => {
for e in v.iter() { for e in v.iter() {
hasher.update(&e.to_le_bytes()); hasher.update(e.to_le_bytes());
} }
} }
Vector::F64(v) => { Vector::F64(v) => {
for e in v.iter() { for e in v.iter() {
hasher.update(&e.to_le_bytes()); hasher.update(e.to_le_bytes());
} }
} }
} }
@ -350,7 +357,7 @@ impl Ord for Vector {
o => return o, o => return o,
} }
} }
return Ordering::Equal; Ordering::Equal
} }
(Vector::F32(_), Vector::F64(_)) => Ordering::Less, (Vector::F32(_), Vector::F64(_)) => Ordering::Less,
(Vector::F64(l), Vector::F64(r)) => { (Vector::F64(l), Vector::F64(r)) => {
@ -364,7 +371,7 @@ impl Ord for Vector {
o => return o, o => return o,
} }
} }
return Ordering::Equal; Ordering::Equal
} }
(Vector::F64(_), Vector::F32(_)) => Ordering::Greater, (Vector::F64(_), Vector::F32(_)) => Ordering::Greater,
} }

@ -664,7 +664,7 @@ impl FixedRule for SimpleFixedRule {
.map(|s| s.name.to_string()) .map(|s| s.name.to_string())
.collect_vec(); .collect_vec();
let l = headers.len(); let l = headers.len();
let m = input.arg_manifest.arity(&payload.tx, &payload.stores)?; let m = input.arg_manifest.arity(payload.tx, payload.stores)?;
for i in l..m { for i in l..m {
headers.push(format!("_{i}")); headers.push(format!("_{i}"));
} }

@ -224,7 +224,7 @@ impl DbInstance {
self.run_script_fold_err(payload, params_json).to_string() self.run_script_fold_err(payload, params_json).to_string()
} }
/// Dispatcher method. See [crate::Db::export_relations]. /// Dispatcher method. See [crate::Db::export_relations].
pub fn export_relations<'a, I, T>(&self, relations: I) -> Result<BTreeMap<String, NamedRows>> pub fn export_relations<I, T>(&self, relations: I) -> Result<BTreeMap<String, NamedRows>>
where where
T: AsRef<str>, T: AsRef<str>,
I: Iterator<Item = T>, I: Iterator<Item = T>,

@ -23,9 +23,10 @@ use crate::data::aggr::{parse_aggr, Aggregation};
use crate::data::expr::Expr; use crate::data::expr::Expr;
use crate::data::functions::{str2vld, MAX_VALIDITY_TS}; use crate::data::functions::{str2vld, MAX_VALIDITY_TS};
use crate::data::program::{ use crate::data::program::{
FixedRuleApply, FixedRuleArg, HnswSearchInput, InputAtom, InputInlineRule, InputInlineRulesOrFixed, FixedRuleApply, FixedRuleArg, HnswSearchInput, InputAtom, InputInlineRule,
InputNamedFieldRelationApplyAtom, InputProgram, InputRelationApplyAtom, InputRuleApplyAtom, InputInlineRulesOrFixed, InputNamedFieldRelationApplyAtom, InputProgram,
QueryAssertion, QueryOutOptions, RelationOp, SortDir, Unification, InputRelationApplyAtom, InputRuleApplyAtom, QueryAssertion, QueryOutOptions, RelationOp,
SortDir, Unification,
}; };
use crate::data::relation::{ColType, ColumnDef, NullableColType, StoredRelationMetadata}; use crate::data::relation::{ColType, ColumnDef, NullableColType, StoredRelationMetadata};
use crate::data::symb::{Symbol, PROG_ENTRY}; use crate::data::symb::{Symbol, PROG_ENTRY};
@ -106,8 +107,10 @@ pub(crate) fn parse_query(
cur_vld: ValidityTs, cur_vld: ValidityTs,
) -> Result<InputProgram> { ) -> Result<InputProgram> {
let mut progs: BTreeMap<Symbol, InputInlineRulesOrFixed> = Default::default(); let mut progs: BTreeMap<Symbol, InputInlineRulesOrFixed> = Default::default();
let mut out_opts: QueryOutOptions = Default::default(); let mut out_opts: QueryOutOptions = QueryOutOptions {
out_opts.timeout = Some(DEFAULT_TIMEOUT); timeout: Some(DEFAULT_TIMEOUT),
..Default::default()
};
let mut stored_relation = None; let mut stored_relation = None;
for pair in src { for pair in src {

@ -501,7 +501,7 @@ impl<'a> SessionTx<'a> {
} }
} }
ret = ret.hnsw_search(s.clone(), own_bindings)?; ret = ret.hnsw_search(s.clone(), own_bindings)?;
if post_filters.len() > 0 { if !post_filters.is_empty() {
ret = ret.filter(Expr::build_and(post_filters, s.span)); ret = ret.filter(Expr::build_and(post_filters, s.span));
} }
} }

@ -202,7 +202,7 @@ impl NamedRows {
let row = row let row = row
.as_array() .as_array()
.ok_or_else(|| miette!("'rows' field must be an array of arrays"))?; .ok_or_else(|| miette!("'rows' field must be an array of arrays"))?;
Ok(row.iter().map(|el| DataValue::from(el)).collect_vec()) Ok(row.iter().map(DataValue::from).collect_vec())
}) })
.try_collect()?; .try_collect()?;
Ok(Self { Ok(Self {
@ -373,7 +373,7 @@ impl<'s, S: Storage<'s>> Db<S> {
/// Export relations to JSON data. /// Export relations to JSON data.
/// ///
/// `relations` contains names of the stored relations to export. /// `relations` contains names of the stored relations to export.
pub fn export_relations<'a, I, T>(&'s self, relations: I) -> Result<BTreeMap<String, NamedRows>> pub fn export_relations<I, T>(&'s self, relations: I) -> Result<BTreeMap<String, NamedRows>>
where where
T: AsRef<str>, T: AsRef<str>,
I: Iterator<Item = T>, I: Iterator<Item = T>,
@ -726,7 +726,7 @@ impl<'s, S: Storage<'s>> Db<S> {
}; };
let cb = CallbackDeclaration { let cb = CallbackDeclaration {
dependent: SmartString::from(relation), dependent: SmartString::from(relation),
sender: sender, sender,
}; };
let mut guard = self.event_callbacks.write().unwrap(); let mut guard = self.event_callbacks.write().unwrap();

@ -20,9 +20,9 @@ use miette::{bail, miette, Result};
use ordered_float::OrderedFloat; use ordered_float::OrderedFloat;
use priority_queue::PriorityQueue; use priority_queue::PriorityQueue;
use rand::Rng; use rand::Rng;
use rustc_hash::{FxHashMap, FxHashSet};
use smartstring::{LazyCompact, SmartString}; use smartstring::{LazyCompact, SmartString};
use std::cmp::{max, Reverse}; use std::cmp::{max, Reverse};
use rustc_hash::{FxHashMap, FxHashSet};
#[derive(Debug, Clone, PartialEq, serde_derive::Serialize, serde_derive::Deserialize)] #[derive(Debug, Clone, PartialEq, serde_derive::Serialize, serde_derive::Deserialize)]
pub(crate) struct HnswIndexManifest { pub(crate) struct HnswIndexManifest {
@ -81,8 +81,8 @@ impl VectorCache {
1.0 - dot / (a_norm * b_norm).sqrt() 1.0 - dot / (a_norm * b_norm).sqrt()
} }
(Vector::F64(a), Vector::F64(b)) => { (Vector::F64(a), Vector::F64(b)) => {
let a_norm = a.dot(a) as f64; let a_norm = a.dot(a);
let b_norm = b.dot(b) as f64; let b_norm = b.dot(b);
let dot = a.dot(b); let dot = a.dot(b);
1.0 - dot / (a_norm * b_norm).sqrt() 1.0 - dot / (a_norm * b_norm).sqrt()
} }
@ -98,7 +98,7 @@ impl VectorCache {
} }
(Vector::F64(a), Vector::F64(b)) => { (Vector::F64(a), Vector::F64(b)) => {
let dot = a.dot(b); let dot = a.dot(b);
1. - dot as f64 1. - dot
} }
_ => panic!("Cannot compute inner product between {:?} and {:?}", v1, v2), _ => panic!("Cannot compute inner product between {:?} and {:?}", v1, v2),
}, },
@ -164,7 +164,9 @@ impl<'a> SessionTx<'a> {
distance: manifest.distance, distance: manifest.distance,
}; };
let tuple_key = &tuple[..orig_table.metadata.keys.len()]; let tuple_key = &tuple[..orig_table.metadata.keys.len()];
vec_cache.cache.insert((tuple_key.to_vec(), idx, subidx), q.clone()); vec_cache
.cache
.insert((tuple_key.to_vec(), idx, subidx), q.clone());
let hash = q.get_hash(); let hash = q.get_hash();
let mut canary_tuple = vec![DataValue::from(0)]; let mut canary_tuple = vec![DataValue::from(0)];
for _ in 0..2 { for _ in 0..2 {
@ -339,7 +341,7 @@ impl<'a> SessionTx<'a> {
if target_degree > m_max { if target_degree > m_max {
// shrink links // shrink links
target_degree = self.hnsw_shrink_neighbour( target_degree = self.hnsw_shrink_neighbour(
&neighbour, neighbour,
m_max, m_max,
current_level, current_level,
manifest, manifest,
@ -383,8 +385,8 @@ impl<'a> SessionTx<'a> {
orig_table: &RelationHandle, orig_table: &RelationHandle,
vec_cache: &mut VectorCache, vec_cache: &mut VectorCache,
) -> Result<usize> { ) -> Result<usize> {
vec_cache.ensure_key(&target_key, orig_table, self)?; vec_cache.ensure_key(target_key, orig_table, self)?;
let vec = vec_cache.get_key(&target_key).clone(); let vec = vec_cache.get_key(target_key).clone();
let mut candidates = PriorityQueue::new(); let mut candidates = PriorityQueue::new();
for (neighbour_key, neighbour_dist) in for (neighbour_key, neighbour_dist) in
self.hnsw_get_neighbours(target_key, level, idx_table, false)? self.hnsw_get_neighbours(target_key, level, idx_table, false)?
@ -499,12 +501,7 @@ impl<'a> SessionTx<'a> {
if manifest.extend_candidates { if manifest.extend_candidates {
for (item, _) in found.iter() { for (item, _) in found.iter() {
// Extend by neighbours // Extend by neighbours
for (neighbour_key, _) in self.hnsw_get_neighbours( for (neighbour_key, _) in self.hnsw_get_neighbours(item, level, idx_table, false)? {
&item,
level,
idx_table,
false,
)? {
vec_cache.ensure_key(&neighbour_key, orig_table, self)?; vec_cache.ensure_key(&neighbour_key, orig_table, self)?;
let dist = vec_cache.v_dist(q, &neighbour_key); let dist = vec_cache.v_dist(q, &neighbour_key);
candidates.push( candidates.push(
@ -519,7 +516,7 @@ impl<'a> SessionTx<'a> {
let mut should_add = true; let mut should_add = true;
for (existing, _) in ret.iter() { for (existing, _) in ret.iter() {
vec_cache.ensure_key(&cand_key, orig_table, self)?; vec_cache.ensure_key(&cand_key, orig_table, self)?;
vec_cache.ensure_key(&existing, orig_table, self)?; vec_cache.ensure_key(existing, orig_table, self)?;
let dist_to_existing = vec_cache.k_dist(existing, &cand_key); let dist_to_existing = vec_cache.k_dist(existing, &cand_key);
if dist_to_existing < cand_dist_to_q { if dist_to_existing < cand_dist_to_q {
should_add = false; should_add = false;
@ -568,12 +565,9 @@ impl<'a> SessionTx<'a> {
break; break;
} }
// loop over each of the candidate's neighbors // loop over each of the candidate's neighbors
for (neighbour_key, _) in self.hnsw_get_neighbours( for (neighbour_key, _) in
&candidate, self.hnsw_get_neighbours(&candidate, cur_level, idx_table, false)?
cur_level, {
idx_table,
false,
)? {
if visited.contains(&neighbour_key) { if visited.contains(&neighbour_key) {
continue; continue;
} }
@ -702,7 +696,7 @@ impl<'a> SessionTx<'a> {
for idx in &manifest.vec_fields { for idx in &manifest.vec_fields {
let val = tuple.get(*idx).unwrap(); let val = tuple.get(*idx).unwrap();
if let DataValue::Vec(v) = val { if let DataValue::Vec(v) = val {
extracted_vectors.push((v, *idx, -1 as i32)); extracted_vectors.push((v, *idx, -1));
} else if let DataValue::List(l) = val { } else if let DataValue::List(l) = val {
for (sidx, v) in l.iter().enumerate() { for (sidx, v) in l.iter().enumerate() {
if let DataValue::Vec(v) = v { if let DataValue::Vec(v) = v {
@ -715,7 +709,7 @@ impl<'a> SessionTx<'a> {
return Ok(false); return Ok(false);
} }
for (vec, idx, sub) in extracted_vectors { for (vec, idx, sub) in extracted_vectors {
self.hnsw_put_vector(&tuple, vec, idx, sub, manifest, orig_table, idx_table)?; self.hnsw_put_vector(tuple, vec, idx, sub, manifest, orig_table, idx_table)?;
} }
Ok(true) Ok(true)
} }
@ -951,13 +945,11 @@ impl<'a> SessionTx<'a> {
.ok_or_else(|| miette!("corrupted index"))?; .ok_or_else(|| miette!("corrupted index"))?;
if config.bind_field.is_some() { if config.bind_field.is_some() {
let field = if cand_key.1 as usize >= config.base_handle.metadata.keys.len() { let field = if cand_key.1 >= config.base_handle.metadata.keys.len() {
config.base_handle.metadata.keys[cand_key.1 as usize] config.base_handle.metadata.keys[cand_key.1].name.clone()
.name
.clone()
} else { } else {
config.base_handle.metadata.non_keys config.base_handle.metadata.non_keys
[cand_key.1 as usize - config.base_handle.metadata.keys.len()] [cand_key.1 - config.base_handle.metadata.keys.len()]
.name .name
.clone() .clone()
}; };

@ -280,7 +280,7 @@ impl<'s, S: Storage<'s>> Db<S> {
}; };
match self.execute_imperative_stmts( match self.execute_imperative_stmts(
&ps, ps,
&mut tx, &mut tx,
&mut cleanups, &mut cleanups,
cur_vld, cur_vld,

@ -52,7 +52,7 @@ impl<'a> SessionTx<'a> {
bail!("Storage is used but un-versioned, probably created by an ancient version of Cozo.") bail!("Storage is used but un-versioned, probably created by an ancient version of Cozo.")
} }
Some(v) => { Some(v) => {
if &v != &CURRENT_STORAGE_VERSION { if v != CURRENT_STORAGE_VERSION {
bail!( bail!(
"Version mismatch: expect storage version {:?}, got {:?}", "Version mismatch: expect storage version {:?}, got {:?}",
CURRENT_STORAGE_VERSION, CURRENT_STORAGE_VERSION,

@ -53,7 +53,7 @@ fn js2value<'a>(
val: Handle<'a, JsValue>, val: Handle<'a, JsValue>,
coll: &mut DataValue, coll: &mut DataValue,
) -> JsResult<'a, JsUndefined> { ) -> JsResult<'a, JsUndefined> {
if let Ok(_) = val.downcast::<JsNull, _>(cx) { if val.downcast::<JsNull, _>(cx).is_ok() {
*coll = DataValue::Null; *coll = DataValue::Null;
} else if let Ok(n) = val.downcast::<JsNumber, _>(cx) { } else if let Ok(n) = val.downcast::<JsNumber, _>(cx) {
let n = n.value(cx); let n = n.value(cx);
@ -61,7 +61,7 @@ fn js2value<'a>(
} else if let Ok(b) = val.downcast::<JsBoolean, _>(cx) { } else if let Ok(b) = val.downcast::<JsBoolean, _>(cx) {
let b = b.value(cx); let b = b.value(cx);
*coll = DataValue::from(b); *coll = DataValue::from(b);
} else if let Ok(_) = val.downcast::<JsUndefined, _>(cx) { } else if val.downcast::<JsUndefined, _>(cx).is_ok() {
*coll = DataValue::Null; *coll = DataValue::Null;
} else if let Ok(s) = val.downcast::<JsString, _>(cx) { } else if let Ok(s) = val.downcast::<JsString, _>(cx) {
let s = s.value(cx); let s = s.value(cx);
@ -300,7 +300,6 @@ macro_rules! get_tx {
}}; }};
} }
macro_rules! remove_tx { macro_rules! remove_tx {
($cx:expr) => {{ ($cx:expr) => {{
let id = $cx.argument::<JsNumber>(0)?.value(&mut $cx) as u32; let id = $cx.argument::<JsNumber>(0)?.value(&mut $cx) as u32;
@ -321,7 +320,6 @@ macro_rules! remove_tx {
}}; }};
} }
fn multi_transact(mut cx: FunctionContext) -> JsResult<JsNumber> { fn multi_transact(mut cx: FunctionContext) -> JsResult<JsNumber> {
let db = get_db!(cx); let db = get_db!(cx);
let write = cx.argument::<JsBoolean>(1)?.value(&mut cx); let write = cx.argument::<JsBoolean>(1)?.value(&mut cx);
@ -398,7 +396,10 @@ fn query_tx(mut cx: FunctionContext) -> JsResult<JsUndefined> {
let callback = cx.argument::<JsFunction>(3)?.root(&mut cx); let callback = cx.argument::<JsFunction>(3)?.root(&mut cx);
let channel = cx.channel(); let channel = cx.channel();
match tx.sender.send(TransactionPayload::Query((query.clone(), params))) { match tx
.sender
.send(TransactionPayload::Query((query.clone(), params)))
{
Ok(_) => { Ok(_) => {
thread::spawn(move || { thread::spawn(move || {
let result = tx.receiver.recv(); let result = tx.receiver.recv();

Loading…
Cancel
Save