-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[red-knot] Inference for comparison of union types #13781
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# Comparison: Unions | ||
|
||
## Union on one side of the comparison | ||
|
||
Comparisons on union types need to consider all possible cases: | ||
|
||
```py | ||
one_or_two = 1 if flag else 2 | ||
|
||
reveal_type(one_or_two <= 2) # revealed: Literal[True] | ||
reveal_type(one_or_two <= 1) # revealed: bool | ||
reveal_type(one_or_two <= 0) # revealed: Literal[False] | ||
|
||
reveal_type(2 >= one_or_two) # revealed: Literal[True] | ||
reveal_type(1 >= one_or_two) # revealed: bool | ||
reveal_type(0 >= one_or_two) # revealed: Literal[False] | ||
|
||
reveal_type(one_or_two < 1) # revealed: Literal[False] | ||
reveal_type(one_or_two < 2) # revealed: bool | ||
reveal_type(one_or_two < 3) # revealed: Literal[True] | ||
|
||
reveal_type(one_or_two > 0) # revealed: Literal[True] | ||
reveal_type(one_or_two > 1) # revealed: bool | ||
reveal_type(one_or_two > 2) # revealed: Literal[False] | ||
|
||
reveal_type(one_or_two == 3) # revealed: Literal[False] | ||
reveal_type(one_or_two == 1) # revealed: bool | ||
|
||
reveal_type(one_or_two != 3) # revealed: Literal[True] | ||
reveal_type(one_or_two != 1) # revealed: bool | ||
|
||
a_or_ab = "a" if flag else "ab" | ||
|
||
reveal_type(a_or_ab in "ab") # revealed: Literal[True] | ||
reveal_type("a" in a_or_ab) # revealed: Literal[True] | ||
|
||
reveal_type("c" not in a_or_ab) # revealed: Literal[True] | ||
reveal_type("a" not in a_or_ab) # revealed: Literal[False] | ||
|
||
reveal_type("b" in a_or_ab) # revealed: bool | ||
reveal_type("b" not in a_or_ab) # revealed: bool | ||
|
||
one_or_none = 1 if flag else None | ||
|
||
reveal_type(one_or_none is None) # revealed: bool | ||
reveal_type(one_or_none is not None) # revealed: bool | ||
``` | ||
|
||
## Union on both sides of the comparison | ||
|
||
With unions on both sides, we need to consider the full cross product of | ||
options when building the resulting (union) type: | ||
|
||
```py | ||
small = 1 if flag_s else 2 | ||
large = 2 if flag_l else 3 | ||
|
||
reveal_type(small <= large) # revealed: Literal[True] | ||
reveal_type(small >= large) # revealed: bool | ||
|
||
reveal_type(small < large) # revealed: bool | ||
reveal_type(small > large) # revealed: Literal[False] | ||
``` | ||
|
||
## Unsupported operations | ||
|
||
Make sure we emit a diagnostic if *any* of the possible comparisons is | ||
unsupported. For now, we fall back to `bool` for the result type instead of | ||
trying to infer something more precise from the other (supported) variants: | ||
|
||
```py | ||
x = [1, 2] if flag else 1 | ||
|
||
result = 1 in x # error: "Operator `in` is not supported" | ||
reveal_type(result) # revealed: bool | ||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -57,7 +57,7 @@ use crate::types::{ | |
}; | ||
use crate::Db; | ||
|
||
use super::KnownClass; | ||
use super::{KnownClass, UnionBuilder}; | ||
|
||
/// Infer all types for a [`ScopeId`], including all definitions and expressions in that scope. | ||
/// Use when checking a scope, or needing to provide a type for an arbitrary expression in the | ||
|
@@ -2717,6 +2717,21 @@ impl<'db> TypeInferenceBuilder<'db> { | |
// - `[ast::CompOp::Is]`: return `false` if unequal, `bool` if equal | ||
// - `[ast::CompOp::IsNot]`: return `true` if unequal, `bool` if equal | ||
match (left, right) { | ||
(Type::Union(union), other) => { | ||
let mut builder = UnionBuilder::new(self.db); | ||
for element in union.elements(self.db) { | ||
builder = builder.add(self.infer_binary_type_comparison(*element, op, other)?); | ||
} | ||
Some(builder.build()) | ||
} | ||
(other, Type::Union(union)) => { | ||
let mut builder = UnionBuilder::new(self.db); | ||
for element in union.elements(self.db) { | ||
builder = builder.add(self.infer_binary_type_comparison(other, op, *element)?); | ||
} | ||
Some(builder.build()) | ||
Comment on lines
+2728
to
+2732
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would be nice to avoid the duplication here, but I couldn't think of a good way to do this without introducing much more code. I also looked for ways to replace the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't mind the repetition (or use of a for loop, I'm a huge fan of for loops) but maybe https://docs.rs/itertools/latest/itertools/trait.Itertools.html#method.fold_options There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It will require refactoring this method (so shouldn't be done in this PR), but in the long term something like #13787 seems like a good idea. I also think we'll want to emit a separate diagnostic on each member of the union that would be invalid for this operation (or one diagnostic that mentions all possibly erroring Union members). So I'd also just leave this as it is for now There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Hm, not quite. union
.elements(self.db)
.iter()
.fold_while(Some(UnionBuilder::new(self.db)), |builder, element| {
if let Some(ty) = self.infer_binary_type_comparison(other, op, *element) {
FoldWhile::Continue(builder.map(|b| b.add(ty)))
} else {
FoldWhile::Done(None)
}
})
.into_inner()
.map(UnionBuilder::build) I'll go with the |
||
} | ||
|
||
(Type::IntLiteral(n), Type::IntLiteral(m)) => match op { | ||
ast::CmpOp::Eq => Some(Type::BooleanLiteral(n == m)), | ||
ast::CmpOp::NotEq => Some(Type::BooleanLiteral(n != m)), | ||
|
@@ -2908,6 +2923,7 @@ impl<'db> TypeInferenceBuilder<'db> { | |
} | ||
} | ||
} | ||
|
||
// Lookup the rich comparison `__dunder__` methods on instances | ||
(Type::Instance(left_class_ty), Type::Instance(right_class_ty)) => match op { | ||
ast::CmpOp::Lt => { | ||
|
@@ -2917,7 +2933,10 @@ impl<'db> TypeInferenceBuilder<'db> { | |
_ => Some(Type::Todo), | ||
}, | ||
// TODO: handle more types | ||
_ => Some(Type::Todo), | ||
_ => match op { | ||
ast::CmpOp::Is | ast::CmpOp::IsNot => Some(KnownClass::Bool.to_instance(self.db)), | ||
sharkdp marked this conversation as resolved.
Show resolved
Hide resolved
|
||
_ => Some(Type::Todo), | ||
}, | ||
} | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice catch, I really thought we had this already 🤦
There are unit tests at the bottom of this module for
is_subtype_of
, ideally we'd add one for this case. (Though the union-builder tests are great, too, and do transitively test this.)