-
Notifications
You must be signed in to change notification settings - Fork 329
Expand file tree
/
Copy pathgeneric.py
More file actions
200 lines (159 loc) · 6.83 KB
/
Copy pathgeneric.py
File metadata and controls
200 lines (159 loc) · 6.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
from collections.abc import Iterable
import sqlalchemy as sa
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import attributes, class_mapper, ColumnProperty
from sqlalchemy.orm.interfaces import MapperProperty, PropComparator
from sqlalchemy.orm.session import _state_session
from sqlalchemy.util import set_creation_order
from .exceptions import ImproperlyConfigured
from .functions import identity
from .functions.orm import _get_class_registry
def _release_tuple(version):
    """Return the numeric release components at the start of ``version``.

    Parsing stops at the first component that does not start with a digit,
    and a component carrying a suffix contributes only its leading digits,
    so pre-release and development versions parse instead of raising
    ``ValueError``::

        '2.0.22'      -> (2, 0, 22)
        '2.1.0b1'     -> (2, 1, 0)
        '2.0.30.dev0' -> (2, 0, 30)
    """
    release = []
    for component in version.split('.'):
        digits = ''
        for char in component:
            if char not in '0123456789':
                break
            digits += char
        if not digits:
            break
        release.append(int(digits))
        if digits != component:
            # A suffix such as 'b1' or 'rc2' ends the release segment.
            break
    return tuple(release)


class GenericAttributeImpl(attributes.ScalarAttributeImpl):
    """Attribute implementation backing a generic relationship.

    The referenced object is resolved from two pieces of state on the parent:
    the discriminator attribute (holding the target's class name) and the id
    attribute(s) (holding the target's primary key values).
    """

    # Parsed once, tolerating pre-release/dev suffixes (see _release_tuple).
    _SQLALCHEMY_VERSION = _release_tuple(sa.__version__)

    def __init__(self, *args, **kwargs):
        """
        The constructor of attributes.AttributeImpl changed in SQLAlchemy 2.0.22,
        adding a 'default_function' required positional argument before 'dispatch'.
        This adjustment ensures compatibility across versions by inserting None for
        'default_function' in versions >= 2.0.22.

        Arguments received: (class, key, dispatch)
        Required by AttributeImpl: (class, key, default_function, dispatch)
        Setting None as default_function here.

        The version comparison ignores pre-release/dev suffixes, so versions
        such as '2.1.0b1' no longer make this constructor raise ValueError.
        """
        if self._SQLALCHEMY_VERSION >= (2, 0, 22):
            args = (*args[:2], None, *args[2:])
        super().__init__(*args, **kwargs)

    def get(self, state, dict_, passive=attributes.PASSIVE_OFF):
        """Return the object referenced by the generic relationship on ``state``.

        A value already present in ``dict_`` is returned as-is. Otherwise the
        target class is looked up in the class registry by the state's
        discriminator value and loaded with ``session.get`` using the state's
        id values.

        Returns ``None`` when the state has no session, when the
        discriminator is not a registered class name, or when no matching
        row is found.
        """
        if self.key in dict_:
            return dict_[self.key]

        # NOTE(review): ``passive`` is not consulted, so this may query the
        # database even for non-loading accesses; confirm that is intended.
        session = _state_session(state)
        if session is None:
            # Not attached to a session, so there is nothing to query with.
            return None

        # TODO: Perhaps optimize with some sort of lookup?
        discriminator = self.get_state_discriminator(state)
        target_class = _get_class_registry(state.class_).get(discriminator)
        if target_class is None:
            # Unknown discriminator; return nothing.
            return None

        return session.get(target_class, self.get_state_id(state))

    def get_state_discriminator(self, state):
        """Return the discriminator value currently held by ``state``."""
        discriminator = self.parent_token.discriminator
        if isinstance(discriminator, hybrid_property):
            # Hybrids are not plain column attributes; evaluate on the instance.
            return getattr(state.obj(), discriminator.__name__)
        else:
            return state.attrs[discriminator.key].value

    def get_state_id(self, state):
        """Return the id attribute values of ``state`` as a tuple, in order."""
        return tuple(
            state.attrs[id_prop.key].value for id_prop in self.parent_token.id
        )

    def set(
        self,
        state,
        dict_,
        initiator,
        passive=attributes.PASSIVE_OFF,
        check_old=None,
        pop=False,
    ):
        """Point the generic relationship on ``state`` at a new target.

        Despite its name, ``initiator`` is treated as the value being
        assigned: either a mapped instance or ``None``. It is stored under
        this attribute's key, and the discriminator and id attributes in
        ``dict_`` are overwritten to match it. ``None`` clears them all.
        ``passive``, ``check_old`` and ``pop`` are accepted but not used.

        NOTE(review): values are written straight into ``dict_`` rather than
        through attribute events; verify change tracking for persistent
        instances.
        """
        prop = self.parent_token
        dict_[self.key] = initiator

        if initiator is None:
            # Nullify relationship args.
            for id_prop in prop.id:
                dict_[id_prop.key] = None
            dict_[prop.discriminator.key] = None
        else:
            # class_mapper raises if the assigned object's class is not
            # mapped, i.e. the assignment cannot be supported.
            class_ = type(initiator)
            mapper = class_mapper(class_)
            pk = mapper.identity_key_from_instance(initiator)[1]

            # Id attributes receive the primary key values positionally.
            for index, id_prop in enumerate(prop.id):
                dict_[id_prop.key] = pk[index]
            dict_[prop.discriminator.key] = class_.__name__
class GenericRelationshipProperty(MapperProperty):
    """A generic form of the relationship property.

    Lets each row of the parent model reference a single row of any other
    mapped model. The reference is stored on the parent as a discriminator
    (the referenced model's class name) plus the referenced row's primary
    key.

    :param discriminator:
        Field that stores the class name of the referenced model. May be a
        column, a column name, or a hybrid property.
    :param id:
        Field, or iterable of fields, that stores the primary key of the
        referenced row, in the same order as that model's primary key.
        Entries may be columns or column names.
    :param doc:
        Optional docstring passed to the instrumented attribute.
    """

    def __init__(self, discriminator, id, doc=None):
        super().__init__()
        # Raw configuration; resolved against the parent mapper in init().
        self._discriminator_col = discriminator
        self._id_cols = id
        self._id = None
        self._discriminator = None
        self.doc = doc
        set_creation_order(self)

    def _column_to_property(self, column):
        """Return the parent mapper attribute matching ``column``.

        Hybrid properties are matched by name among the mapper's ORM
        descriptors. Anything else is matched by column name against the
        first column of each ``ColumnProperty``. Returns ``None`` when
        nothing matches.
        """
        if isinstance(column, hybrid_property):
            attr_key = column.__name__
            for key, attr in self.parent.all_orm_descriptors.items():
                if key == attr_key:
                    return attr
        else:
            for attr in self.parent.attrs.values():
                if isinstance(attr, ColumnProperty):
                    if attr.columns[0].name == column.name:
                        return attr
        return None

    def init(self):
        """Resolve the configured discriminator and id fields on the parent.

        Raises:
            ImproperlyConfigured: if the discriminator or any id field cannot
                be matched to an attribute of the parent mapper.
        """
        def convert_strings(column):
            # Plain strings name columns of the parent mapper.
            if isinstance(column, str):
                return self.parent.columns[column]
            return column

        self._discriminator_col = convert_strings(self._discriminator_col)
        self._id_cols = convert_strings(self._id_cols)
        if isinstance(self._id_cols, Iterable):
            self._id_cols = list(map(convert_strings, self._id_cols))
        else:
            self._id_cols = [self._id_cols]

        self.discriminator = self._column_to_property(self._discriminator_col)
        if self.discriminator is None:
            raise ImproperlyConfigured('Could not find discriminator descriptor.')

        self.id = list(map(self._column_to_property, self._id_cols))
        # Fail at configuration time rather than with an AttributeError on
        # ``None.key`` when the relationship is first read or assigned.
        if any(id_prop is None for id_prop in self.id):
            raise ImproperlyConfigured('Could not find id descriptor.')

    class Comparator(PropComparator):
        """Builds SQL criteria against the discriminator and id columns."""

        def __init__(self, prop, parentmapper):
            self.property = prop
            self._parententity = parentmapper

        def __eq__(self, other):
            # Match the discriminator (class name) and every id column
            # against the identity of ``other``.
            discriminator = type(other).__name__
            q = self.property._discriminator_col == discriminator
            other_id = identity(other)
            for index, id_col in enumerate(self.property._id_cols):
                q &= id_col == other_id[index]
            return q

        def __ne__(self, other):
            return ~(self == other)

        def is_type(self, other):
            """Match rows referencing ``other`` or any mapped subclass of it.

            Subclasses at any inheritance depth are included, because the
            discriminator stores the concrete class name of the target.
            """
            mapper = sa.inspect(other)
            # self_and_descendants walks the whole inheritance tree, whereas
            # _inheriting_mappers lists only direct subclasses.
            class_names = [
                submapper.class_.__name__
                for submapper in mapper.self_and_descendants
            ]
            return self.property._discriminator_col.in_(class_names)

    def instrument_class(self, mapper):
        """Register the instrumented attribute backed by GenericAttributeImpl."""
        attributes.register_attribute(
            mapper.class_,
            self.key,
            comparator=self.Comparator(self, mapper),
            parententity=mapper,
            doc=self.doc,
            impl_class=GenericAttributeImpl,
            parent_token=self,
        )
def generic_relationship(*args, **kwargs):
    """Create a :class:`GenericRelationshipProperty`.

    Every positional and keyword argument is forwarded unchanged to the
    property's constructor (``discriminator``, ``id``, ``doc``).
    """
    property_cls = GenericRelationshipProperty
    return property_cls(*args, **kwargs)