Skip to content

Commit 6930d0c

Browse files
committed
another fix
1 parent dd23a4d commit 6930d0c

File tree

1 file changed

+81
-40
lines changed

1 file changed

+81
-40
lines changed

vm/src/builtins/type.rs

Lines changed: 81 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,7 +1019,7 @@ impl Constructor for PyType {
10191019
attributes.insert(identifier!(vm, __hash__), vm.ctx.none.clone().into());
10201020
}
10211021

1022-
let heaptype_slots: Option<PyRef<PyTuple<PyStrRef>>> =
1022+
let (heaptype_slots, add_dict): (Option<PyRef<PyTuple<PyStrRef>>>, bool) =
10231023
if let Some(x) = attributes.get(identifier!(vm, __slots__)) {
10241024
let slots = if x.class().is(vm.ctx.types.str_type) {
10251025
let x = unsafe { x.downcast_unchecked_ref::<PyStr>() };
@@ -1036,9 +1036,26 @@ impl Constructor for PyType {
10361036
let tuple = elements.into_pytuple(vm);
10371037
tuple.try_into_typed(vm)?
10381038
};
1039-
Some(slots)
1039+
1040+
// Check if __dict__ is in slots
1041+
let dict_name = "__dict__";
1042+
let has_dict = slots.iter().any(|s| s.as_str() == dict_name);
1043+
1044+
// Filter out __dict__ from slots
1045+
let filtered_slots = if has_dict {
1046+
let filtered: Vec<PyStrRef> = slots
1047+
.iter()
1048+
.filter(|s| s.as_str() != dict_name)
1049+
.cloned()
1050+
.collect();
1051+
PyTuple::new_ref_typed(filtered, &vm.ctx)
1052+
} else {
1053+
slots
1054+
};
1055+
1056+
(Some(filtered_slots), has_dict)
10401057
} else {
1041-
None
1058+
(None, false)
10421059
};
10431060

10441061
// FIXME: this is a temporary fix. multi bases with multiple slots will break object
@@ -1051,8 +1068,10 @@ impl Constructor for PyType {
10511068
let member_count: usize = base_member_count + heaptype_member_count;
10521069

10531070
let mut flags = PyTypeFlags::heap_type_flags();
1054-
// Only add HAS_DICT and MANAGED_DICT if __slots__ is not defined.
1055-
if heaptype_slots.is_none() {
1071+
// Add HAS_DICT and MANAGED_DICT if:
1072+
// 1. __slots__ is not defined, OR
1073+
// 2. __dict__ is in __slots__
1074+
if heaptype_slots.is_none() || add_dict {
10561075
flags |= PyTypeFlags::HAS_DICT | PyTypeFlags::MANAGED_DICT;
10571076
}
10581077

@@ -1446,63 +1465,85 @@ impl Representable for PyType {
14461465
}
14471466
}
14481467

1449-
fn find_base_dict_descr(cls: &Py<PyType>, vm: &VirtualMachine) -> Option<PyObjectRef> {
1450-
let dict_attr = identifier!(vm, __dict__);
1451-
1452-
// First check if the class itself has a user-defined __dict__ property
1453-
if let Some(descr) = cls.attributes.read().get(&dict_attr) {
1454-
// If it's a property (user-defined), return it
1455-
if descr.class().is(vm.ctx.types.property_type) {
1456-
return Some(descr.clone());
1468+
// Equivalent to CPython's get_builtin_base_with_dict
1469+
fn get_builtin_base_with_dict(typ: &Py<PyType>, vm: &VirtualMachine) -> Option<PyTypeRef> {
1470+
let mut current = Some(typ.to_owned());
1471+
while let Some(t) = current {
1472+
// In CPython: type->tp_dictoffset != 0 && !(type->tp_flags & Py_TPFLAGS_HEAPTYPE)
1473+
// Special case: type itself is a builtin with dict support
1474+
if t.is(vm.ctx.types.type_type) {
1475+
return Some(t);
1476+
}
1477+
// We check HAS_DICT flag (equivalent to tp_dictoffset != 0) and HEAPTYPE
1478+
if t.slots.flags.contains(PyTypeFlags::HAS_DICT)
1479+
&& !t.slots.flags.contains(PyTypeFlags::HEAPTYPE)
1480+
{
1481+
return Some(t);
14571482
}
1458-
// Skip getset descriptors in the current class to avoid recursion
1483+
current = t.__base__();
14591484
}
1485+
None
1486+
}
14601487

1461-
// Then check bases (like original implementation)
1462-
cls.iter_base_chain().skip(1).find_map(|cls| {
1463-
// TODO: should actually be some translation of:
1464-
// cls.slot_dictoffset != 0 && !cls.flags.contains(HEAPTYPE)
1465-
if cls.is(vm.ctx.types.type_type) {
1466-
cls.get_attr(dict_attr)
1467-
} else {
1468-
None
1469-
}
1470-
})
1488+
// Equivalent to CPython's get_dict_descriptor
1489+
fn get_dict_descriptor(base: &Py<PyType>, vm: &VirtualMachine) -> Option<PyObjectRef> {
1490+
let dict_attr = identifier!(vm, __dict__);
1491+
// Use _PyType_Lookup (which is lookup_ref in RustPython)
1492+
base.lookup_ref(dict_attr, vm)
14711493
}
14721494

14731495
fn subtype_get_dict(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult {
1474-
// TODO: obj.class().as_pyref() need to be supported
1475-
let ret = match find_base_dict_descr(obj.class(), vm) {
1476-
Some(descr) => vm.call_get_descriptor(&descr, obj).unwrap_or_else(|| {
1496+
let base = get_builtin_base_with_dict(obj.class(), vm);
1497+
1498+
if let Some(base_type) = base {
1499+
if let Some(descr) = get_dict_descriptor(&base_type, vm) {
1500+
// Call the descriptor's tp_descr_get
1501+
vm.call_get_descriptor(&descr, obj.clone())
1502+
.unwrap_or_else(|| {
1503+
Err(vm.new_type_error(format!(
1504+
"this __dict__ descriptor does not support '{}' objects",
1505+
obj.class().name()
1506+
)))
1507+
})
1508+
} else {
14771509
Err(vm.new_type_error(format!(
14781510
"this __dict__ descriptor does not support '{}' objects",
1479-
descr.class()
1511+
obj.class().name()
14801512
)))
1481-
})?,
1482-
None => object::object_get_dict(obj, vm)?.into(),
1483-
};
1484-
Ok(ret)
1513+
}
1514+
} else {
1515+
// PyObject_GenericGetDict
1516+
object::object_get_dict(obj, vm).map(Into::into)
1517+
}
14851518
}
14861519

14871520
fn subtype_set_dict(obj: PyObjectRef, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
1488-
let cls = obj.class();
1489-
match find_base_dict_descr(cls, vm) {
1490-
Some(descr) => {
1521+
// Following CPython's subtype_setdict exactly
1522+
let base = get_builtin_base_with_dict(obj.class(), vm);
1523+
1524+
if let Some(base_type) = base {
1525+
if let Some(descr) = get_dict_descriptor(&base_type, vm) {
1526+
// Call the descriptor's tp_descr_set
14911527
let descr_set = descr
14921528
.class()
14931529
.mro_find_map(|cls| cls.slots.descr_set.load())
14941530
.ok_or_else(|| {
14951531
vm.new_type_error(format!(
14961532
"this __dict__ descriptor does not support '{}' objects",
1497-
cls.name()
1533+
obj.class().name()
14981534
))
14991535
})?;
15001536
descr_set(&descr, obj, PySetterValue::Assign(value), vm)
1537+
} else {
1538+
Err(vm.new_type_error(format!(
1539+
"this __dict__ descriptor does not support '{}' objects",
1540+
obj.class().name()
1541+
)))
15011542
}
1502-
None => {
1503-
object::object_set_dict(obj, value.try_into_value(vm)?, vm)?;
1504-
Ok(())
1505-
}
1543+
} else {
1544+
// PyObject_GenericSetDict
1545+
object::object_set_dict(obj, value.try_into_value(vm)?, vm)?;
1546+
Ok(())
15061547
}
15071548
}
15081549

0 commit comments

Comments
 (0)