|
28 | 28 | use crate::types::{IntersectionType, Type, UnionType}; |
29 | 29 | use crate::{Db, FxOrderSet}; |
30 | 30 |
|
| 31 | +use super::builtins_symbol_ty_by_name; |
| 32 | + |
31 | 33 | pub(crate) struct UnionBuilder<'db> { |
32 | 34 | elements: FxOrderSet<Type<'db>>, |
33 | 35 | db: &'db dyn Db, |
@@ -56,7 +58,24 @@ impl<'db> UnionBuilder<'db> { |
56 | 58 | self |
57 | 59 | } |
58 | 60 |
|
59 | | - pub(crate) fn build(self) -> Type<'db> { |
| 61 | + /// Performs the following normalizations: |
| 62 | + /// - Replaces `Literal[True,False]` with `bool`. |
| 63 | + /// - TODO For enums `E` with members `X1`,...,`Xn`, replaces |
| 64 | + /// `Literal[E.X1,...,E.Xn]` with `E`. |
| 65 | + fn simplify(&mut self) { |
| 66 | + if self |
| 67 | + .elements |
| 68 | + .is_superset(&[Type::BooleanLiteral(true), Type::BooleanLiteral(false)].into()) |
| 69 | + { |
| 70 | + let bool_ty = builtins_symbol_ty_by_name(self.db, "bool"); |
| 71 | + self.elements.remove(&Type::BooleanLiteral(true)); |
| 72 | + self.elements.remove(&Type::BooleanLiteral(false)); |
| 73 | + self.elements.insert(bool_ty); |
| 74 | + } |
| 75 | + } |
| 76 | + |
| 77 | + pub(crate) fn build(mut self) -> Type<'db> { |
| 78 | + self.simplify(); |
60 | 79 | match self.elements.len() { |
61 | 80 | 0 => Type::Never, |
62 | 81 | 1 => self.elements[0], |
@@ -247,17 +266,38 @@ impl<'db> InnerIntersectionBuilder<'db> { |
247 | 266 | mod tests { |
248 | 267 | use super::{IntersectionBuilder, IntersectionType, Type, UnionBuilder, UnionType}; |
249 | 268 | use crate::db::tests::TestDb; |
250 | | - |
251 | | - fn setup_db() -> TestDb { |
252 | | - TestDb::new() |
253 | | - } |
| 269 | + use crate::program::{Program, SearchPathSettings}; |
| 270 | + use crate::python_version::PythonVersion; |
| 271 | + use crate::types::builtins_symbol_ty_by_name; |
| 272 | + use crate::ProgramSettings; |
| 273 | + use ruff_db::system::{DbWithTestSystem, SystemPathBuf}; |
254 | 274 |
|
255 | 275 | impl<'db> UnionType<'db> { |
256 | 276 | fn elements_vec(self, db: &'db TestDb) -> Vec<Type<'db>> { |
257 | 277 | self.elements(db).into_iter().collect() |
258 | 278 | } |
259 | 279 | } |
260 | 280 |
|
| 281 | + fn setup_db() -> TestDb { |
| 282 | + let db = TestDb::new(); |
| 283 | + |
| 284 | + let src_root = SystemPathBuf::from("/src"); |
| 285 | + db.memory_file_system() |
| 286 | + .create_directory_all(&src_root) |
| 287 | + .unwrap(); |
| 288 | + |
| 289 | + Program::from_settings( |
| 290 | + &db, |
| 291 | + &ProgramSettings { |
| 292 | + target_version: PythonVersion::default(), |
| 293 | + search_paths: SearchPathSettings::new(src_root), |
| 294 | + }, |
| 295 | + ) |
| 296 | + .expect("Valid search path settings"); |
| 297 | + |
| 298 | + db |
| 299 | + } |
| 300 | + |
261 | 301 | #[test] |
262 | 302 | fn build_union() { |
263 | 303 | let db = setup_db(); |
@@ -296,6 +336,33 @@ mod tests { |
296 | 336 | assert_eq!(ty, t0); |
297 | 337 | } |
298 | 338 |
|
| 339 | + #[test] |
| 340 | + fn build_union_bool() { |
| 341 | + let db = setup_db(); |
| 342 | + let bool_ty = builtins_symbol_ty_by_name(&db, "bool"); |
| 343 | + |
| 344 | + let t0 = Type::BooleanLiteral(true); |
| 345 | + let t1 = Type::BooleanLiteral(true); |
| 346 | + let t2 = Type::BooleanLiteral(false); |
| 347 | + let t3 = Type::IntLiteral(17); |
| 348 | + |
| 349 | + let Type::Union(union) = UnionBuilder::new(&db).add(t0).add(t1).add(t3).build() else { |
| 350 | + panic!("expected a union"); |
| 351 | + }; |
| 352 | + assert_eq!(union.elements_vec(&db), &[t0, t3]); |
| 353 | + let Type::Union(union) = UnionBuilder::new(&db) |
| 354 | + .add(t0) |
| 355 | + .add(t1) |
| 356 | + .add(t2) |
| 357 | + .add(t3) |
| 358 | + .build() |
| 359 | + else { |
| 360 | + panic!("expected a union"); |
| 361 | + }; |
| 362 | + |
| 363 | + assert_eq!(union.elements_vec(&db), &[t3, bool_ty]); |
| 364 | + } |
| 365 | + |
299 | 366 | #[test] |
300 | 367 | fn build_union_flatten() { |
301 | 368 | let db = setup_db(); |
|
0 commit comments