Skip to content

Commit

Permalink
[red-knot] Enhancing Diagnostics for Compare Expression Inference (#1…
Browse files Browse the repository at this point in the history
…3819)

## Summary

- Refactored comparison type inference functions in `infer.rs`: Changed
the return type from `Option` to `Result` to lay the groundwork for
providing more detailed diagnostics.
- Updated diagnostic messages.

This is a small step toward improving diagnostics in the future.

Please refer to #13787

## Test Plan

mdtest included!

---------

Co-authored-by: Carl Meyer <carl@astral.sh>
  • Loading branch information
cake-monotone and carljm authored Oct 19, 2024
1 parent 55bccf6 commit fb66f71
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,21 @@ reveal_type(a) # revealed: bool
b = 0 not in 10 # error: "Operator `not in` is not supported for types `Literal[0]` and `Literal[10]`"
reveal_type(b) # revealed: bool

c = object() < 5 # error: "Operator `<` is not supported for types `object` and `Literal[5]`"
c = object() < 5 # error: "Operator `<` is not supported for types `object` and `int`"
reveal_type(c) # revealed: Unknown

# TODO should error, need to check if __lt__ signature is valid for right operand
d = 5 < object()
# TODO: should be `Unknown`
reveal_type(d) # revealed: bool

int_literal_or_str_literal = 1 if flag else "foo"
# error: "Operator `in` is not supported for types `Literal[42]` and `Literal[1]`, in comparing `Literal[42]` with `Literal[1] | Literal["foo"]`"
e = 42 in int_literal_or_str_literal
reveal_type(e) # revealed: bool

# TODO: should error, need to check if __lt__ signature is valid for right operand
# error may be "Operator `<` is not supported for types `int` and `str`, in comparing `tuple[Literal[1], Literal[2]]` with `tuple[Literal[1], Literal["hello"]]`"
f = (1, 2) < (1, "hello")
reveal_type(f) # revealed: @Todo
```
172 changes: 109 additions & 63 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2776,18 +2776,28 @@ impl<'db> TypeInferenceBuilder<'db> {
let right_ty = self.expression_ty(right);

self.infer_binary_type_comparison(left_ty, *op, right_ty)
.unwrap_or_else(|| {
.unwrap_or_else(|error| {
// Handle unsupported operators (diagnostic, `bool`/`Unknown` outcome)
self.add_diagnostic(
AnyNodeRef::ExprCompare(compare),
"operator-unsupported",
format_args!(
"Operator `{}` is not supported for types `{}` and `{}`",
op,
left_ty.display(self.db),
right_ty.display(self.db)
"Operator `{}` is not supported for types `{}` and `{}`{}",
error.op,
error.left_ty.display(self.db),
error.right_ty.display(self.db),
if (left_ty, right_ty) == (error.left_ty, error.right_ty) {
String::new()
} else {
format!(
", in comparing `{}` with `{}`",
left_ty.display(self.db),
right_ty.display(self.db)
)
}
),
);

match op {
// `in, not in, is, is not` always return bool instances
ast::CmpOp::In
Expand All @@ -2814,7 +2824,7 @@ impl<'db> TypeInferenceBuilder<'db> {
left: Type<'db>,
op: ast::CmpOp,
right: Type<'db>,
) -> Option<Type<'db>> {
) -> Result<Type<'db>, CompareUnsupportedError<'db>> {
// Note: identity (is, is not) for equal builtin types is unreliable and not part of the
// language spec.
// - `[ast::CompOp::Is]`: return `false` if unequal, `bool` if equal
Expand All @@ -2825,39 +2835,43 @@ impl<'db> TypeInferenceBuilder<'db> {
for element in union.elements(self.db) {
builder = builder.add(self.infer_binary_type_comparison(*element, op, other)?);
}
Some(builder.build())
Ok(builder.build())
}
(other, Type::Union(union)) => {
let mut builder = UnionBuilder::new(self.db);
for element in union.elements(self.db) {
builder = builder.add(self.infer_binary_type_comparison(other, op, *element)?);
}
Some(builder.build())
Ok(builder.build())
}

(Type::IntLiteral(n), Type::IntLiteral(m)) => match op {
ast::CmpOp::Eq => Some(Type::BooleanLiteral(n == m)),
ast::CmpOp::NotEq => Some(Type::BooleanLiteral(n != m)),
ast::CmpOp::Lt => Some(Type::BooleanLiteral(n < m)),
ast::CmpOp::LtE => Some(Type::BooleanLiteral(n <= m)),
ast::CmpOp::Gt => Some(Type::BooleanLiteral(n > m)),
ast::CmpOp::GtE => Some(Type::BooleanLiteral(n >= m)),
ast::CmpOp::Eq => Ok(Type::BooleanLiteral(n == m)),
ast::CmpOp::NotEq => Ok(Type::BooleanLiteral(n != m)),
ast::CmpOp::Lt => Ok(Type::BooleanLiteral(n < m)),
ast::CmpOp::LtE => Ok(Type::BooleanLiteral(n <= m)),
ast::CmpOp::Gt => Ok(Type::BooleanLiteral(n > m)),
ast::CmpOp::GtE => Ok(Type::BooleanLiteral(n >= m)),
ast::CmpOp::Is => {
if n == m {
Some(KnownClass::Bool.to_instance(self.db))
Ok(KnownClass::Bool.to_instance(self.db))
} else {
Some(Type::BooleanLiteral(false))
Ok(Type::BooleanLiteral(false))
}
}
ast::CmpOp::IsNot => {
if n == m {
Some(KnownClass::Bool.to_instance(self.db))
Ok(KnownClass::Bool.to_instance(self.db))
} else {
Some(Type::BooleanLiteral(true))
Ok(Type::BooleanLiteral(true))
}
}
// Undefined for (int, int)
ast::CmpOp::In | ast::CmpOp::NotIn => None,
ast::CmpOp::In | ast::CmpOp::NotIn => Err(CompareUnsupportedError {
op,
left_ty: left,
right_ty: right,
}),
},
(Type::IntLiteral(_), Type::Instance(_)) => {
self.infer_binary_type_comparison(KnownClass::Int.to_instance(self.db), op, right)
Expand Down Expand Up @@ -2888,26 +2902,26 @@ impl<'db> TypeInferenceBuilder<'db> {
let s1 = salsa_s1.value(self.db);
let s2 = salsa_s2.value(self.db);
match op {
ast::CmpOp::Eq => Some(Type::BooleanLiteral(s1 == s2)),
ast::CmpOp::NotEq => Some(Type::BooleanLiteral(s1 != s2)),
ast::CmpOp::Lt => Some(Type::BooleanLiteral(s1 < s2)),
ast::CmpOp::LtE => Some(Type::BooleanLiteral(s1 <= s2)),
ast::CmpOp::Gt => Some(Type::BooleanLiteral(s1 > s2)),
ast::CmpOp::GtE => Some(Type::BooleanLiteral(s1 >= s2)),
ast::CmpOp::In => Some(Type::BooleanLiteral(s2.contains(s1.as_ref()))),
ast::CmpOp::NotIn => Some(Type::BooleanLiteral(!s2.contains(s1.as_ref()))),
ast::CmpOp::Eq => Ok(Type::BooleanLiteral(s1 == s2)),
ast::CmpOp::NotEq => Ok(Type::BooleanLiteral(s1 != s2)),
ast::CmpOp::Lt => Ok(Type::BooleanLiteral(s1 < s2)),
ast::CmpOp::LtE => Ok(Type::BooleanLiteral(s1 <= s2)),
ast::CmpOp::Gt => Ok(Type::BooleanLiteral(s1 > s2)),
ast::CmpOp::GtE => Ok(Type::BooleanLiteral(s1 >= s2)),
ast::CmpOp::In => Ok(Type::BooleanLiteral(s2.contains(s1.as_ref()))),
ast::CmpOp::NotIn => Ok(Type::BooleanLiteral(!s2.contains(s1.as_ref()))),
ast::CmpOp::Is => {
if s1 == s2 {
Some(KnownClass::Bool.to_instance(self.db))
Ok(KnownClass::Bool.to_instance(self.db))
} else {
Some(Type::BooleanLiteral(false))
Ok(Type::BooleanLiteral(false))
}
}
ast::CmpOp::IsNot => {
if s1 == s2 {
Some(KnownClass::Bool.to_instance(self.db))
Ok(KnownClass::Bool.to_instance(self.db))
} else {
Some(Type::BooleanLiteral(true))
Ok(Type::BooleanLiteral(true))
}
}
}
Expand All @@ -2930,30 +2944,30 @@ impl<'db> TypeInferenceBuilder<'db> {
let b1 = &**salsa_b1.value(self.db);
let b2 = &**salsa_b2.value(self.db);
match op {
ast::CmpOp::Eq => Some(Type::BooleanLiteral(b1 == b2)),
ast::CmpOp::NotEq => Some(Type::BooleanLiteral(b1 != b2)),
ast::CmpOp::Lt => Some(Type::BooleanLiteral(b1 < b2)),
ast::CmpOp::LtE => Some(Type::BooleanLiteral(b1 <= b2)),
ast::CmpOp::Gt => Some(Type::BooleanLiteral(b1 > b2)),
ast::CmpOp::GtE => Some(Type::BooleanLiteral(b1 >= b2)),
ast::CmpOp::Eq => Ok(Type::BooleanLiteral(b1 == b2)),
ast::CmpOp::NotEq => Ok(Type::BooleanLiteral(b1 != b2)),
ast::CmpOp::Lt => Ok(Type::BooleanLiteral(b1 < b2)),
ast::CmpOp::LtE => Ok(Type::BooleanLiteral(b1 <= b2)),
ast::CmpOp::Gt => Ok(Type::BooleanLiteral(b1 > b2)),
ast::CmpOp::GtE => Ok(Type::BooleanLiteral(b1 >= b2)),
ast::CmpOp::In => {
Some(Type::BooleanLiteral(memchr::memmem::find(b2, b1).is_some()))
Ok(Type::BooleanLiteral(memchr::memmem::find(b2, b1).is_some()))
}
ast::CmpOp::NotIn => {
Some(Type::BooleanLiteral(memchr::memmem::find(b2, b1).is_none()))
Ok(Type::BooleanLiteral(memchr::memmem::find(b2, b1).is_none()))
}
ast::CmpOp::Is => {
if b1 == b2 {
Some(KnownClass::Bool.to_instance(self.db))
Ok(KnownClass::Bool.to_instance(self.db))
} else {
Some(Type::BooleanLiteral(false))
Ok(Type::BooleanLiteral(false))
}
}
ast::CmpOp::IsNot => {
if b1 == b2 {
Some(KnownClass::Bool.to_instance(self.db))
Ok(KnownClass::Bool.to_instance(self.db))
} else {
Some(Type::BooleanLiteral(true))
Ok(Type::BooleanLiteral(true))
}
}
}
Expand Down Expand Up @@ -2991,7 +3005,7 @@ impl<'db> TypeInferenceBuilder<'db> {
).expect("infer_binary_type_comparison should never return None for `CmpOp::Eq`");

match eq_result {
Type::Todo => return Some(Type::Todo),
Type::Todo => return Ok(Type::Todo),
ty => match ty.bool(self.db) {
Truthiness::AlwaysTrue => eq_count += 1,
Truthiness::AlwaysFalse => not_eq_count += 1,
Expand All @@ -3001,11 +3015,11 @@ impl<'db> TypeInferenceBuilder<'db> {
}

if eq_count >= 1 {
Some(Type::BooleanLiteral(op.is_in()))
Ok(Type::BooleanLiteral(op.is_in()))
} else if not_eq_count == rhs_elements.len() {
Some(Type::BooleanLiteral(op.is_not_in()))
Ok(Type::BooleanLiteral(op.is_not_in()))
} else {
Some(KnownClass::Bool.to_instance(self.db))
Ok(KnownClass::Bool.to_instance(self.db))
}
}
ast::CmpOp::Is | ast::CmpOp::IsNot => {
Expand All @@ -3016,7 +3030,7 @@ impl<'db> TypeInferenceBuilder<'db> {
"infer_binary_type_comparison should never return None for `CmpOp::Eq`",
);

Some(match eq_result {
Ok(match eq_result {
Type::Todo => Type::Todo,
ty => match ty.bool(self.db) {
Truthiness::AlwaysFalse => Type::BooleanLiteral(op.is_is_not()),
Expand All @@ -3029,16 +3043,19 @@ impl<'db> TypeInferenceBuilder<'db> {

// Lookup the rich comparison `__dunder__` methods on instances
(Type::Instance(left_class_ty), Type::Instance(right_class_ty)) => match op {
ast::CmpOp::Lt => {
perform_rich_comparison(self.db, left_class_ty, right_class_ty, "__lt__")
}
ast::CmpOp::Lt => perform_rich_comparison(
self.db,
left_class_ty,
right_class_ty,
RichCompareOperator::Lt,
),
// TODO: implement mapping from `ast::CmpOp` to rich comparison methods
_ => Some(Type::Todo),
_ => Ok(Type::Todo),
},
// TODO: handle more types
_ => match op {
ast::CmpOp::Is | ast::CmpOp::IsNot => Some(KnownClass::Bool.to_instance(self.db)),
_ => Some(Type::Todo),
ast::CmpOp::Is | ast::CmpOp::IsNot => Ok(KnownClass::Bool.to_instance(self.db)),
_ => Ok(Type::Todo),
},
}
}
Expand All @@ -3053,7 +3070,7 @@ impl<'db> TypeInferenceBuilder<'db> {
left: &[Type<'db>],
op: RichCompareOperator,
right: &[Type<'db>],
) -> Option<Type<'db>> {
) -> Result<Type<'db>, CompareUnsupportedError<'db>> {
// Compare paired elements from left and right slices
for (l_ty, r_ty) in left.iter().copied().zip(right.iter().copied()) {
let eq_result = self
Expand All @@ -3062,7 +3079,7 @@ impl<'db> TypeInferenceBuilder<'db> {

match eq_result {
// If propagation is required, return the result as is
Type::Todo => return Some(Type::Todo),
Type::Todo => return Ok(Type::Todo),
ty => match ty.bool(self.db) {
// Types are equal, continue to the next pair
Truthiness::AlwaysTrue => continue,
Expand All @@ -3072,7 +3089,7 @@ impl<'db> TypeInferenceBuilder<'db> {
}
// If the intermediate result is ambiguous, we cannot determine the final result as BooleanLiteral.
// In this case, we simply return a bool instance.
Truthiness::Ambiguous => return Some(KnownClass::Bool.to_instance(self.db)),
Truthiness::Ambiguous => return Ok(KnownClass::Bool.to_instance(self.db)),
},
}
}
Expand All @@ -3082,7 +3099,7 @@ impl<'db> TypeInferenceBuilder<'db> {
// We return a comparison of the slice lengths based on the operator.
let (left_len, right_len) = (left.len(), right.len());

Some(Type::BooleanLiteral(match op {
Ok(Type::BooleanLiteral(match op {
RichCompareOperator::Eq => left_len == right_len,
RichCompareOperator::Ne => left_len != right_len,
RichCompareOperator::Lt => left_len < right_len,
Expand Down Expand Up @@ -3556,6 +3573,26 @@ impl From<RichCompareOperator> for ast::CmpOp {
}
}

impl RichCompareOperator {
const fn dunder_name(self) -> &'static str {
match self {
RichCompareOperator::Eq => "__eq__",
RichCompareOperator::Ne => "__ne__",
RichCompareOperator::Lt => "__lt__",
RichCompareOperator::Le => "__le__",
RichCompareOperator::Gt => "__gt__",
RichCompareOperator::Ge => "__ge__",
}
}
}

/// Error produced when a comparison operator cannot be applied to a pair of types.
///
/// The recorded operand types are the ones at the point where inference failed.
/// They can differ from the types of the whole comparison expression (for
/// example, when one operand is a union and only one of its elements is
/// unsupported), which lets the diagnostic name both the failing pair and the
/// enclosing comparison.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct CompareUnsupportedError<'db> {
// The comparison operator that is not supported for the operand pair below.
op: ast::CmpOp,
// Type of the left operand for which the comparison failed.
left_ty: Type<'db>,
// Type of the right operand for which the comparison failed.
right_ty: Type<'db>,
}

fn format_import_from_module(level: u32, module: Option<&str>) -> String {
format!(
"{}{}",
Expand Down Expand Up @@ -3636,26 +3673,35 @@ fn perform_rich_comparison<'db>(
db: &'db dyn Db,
left: ClassType<'db>,
right: ClassType<'db>,
dunder_name: &str,
) -> Option<Type<'db>> {
op: RichCompareOperator,
) -> Result<Type<'db>, CompareUnsupportedError<'db>> {
// The following resource has details about the rich comparison algorithm:
// https://snarky.ca/unravelling-rich-comparison-operators/
//
// TODO: the reflected dunder actually has priority if the r.h.s. is a strict subclass of the
// l.h.s.
// TODO: `object.__ne__` will call `__eq__` if `__ne__` is not defined

let dunder = left.class_member(db, dunder_name);
let dunder = left.class_member(db, op.dunder_name());
if !dunder.is_unbound() {
// TODO: this currently gives the return type even if the arg types are invalid
// (e.g. int.__lt__ with string instance should be None, currently bool)
return dunder
.call(db, &[Type::Instance(left), Type::Instance(right)])
.return_ty(db);
.return_ty(db)
.ok_or_else(|| CompareUnsupportedError {
op: op.into(),
left_ty: Type::Instance(left),
right_ty: Type::Instance(right),
});
}

// TODO: reflected dunder -- (==, ==), (!=, !=), (<, >), (>, <), (<=, >=), (>=, <=)
None
Err(CompareUnsupportedError {
op: op.into(),
left_ty: Type::Instance(left),
right_ty: Type::Instance(right),
})
}

#[cfg(test)]
Expand Down

0 comments on commit fb66f71

Please sign in to comment.