1818import queue
1919
2020from google .cloud .exceptions import NotFound
21+ from google .cloud .spanner_v1 import BatchCreateSessionsRequest
22+ from google .cloud .spanner_v1 import Session
2123from google .cloud .spanner_v1 ._helpers import _metadata_with_prefix
2224
2325
@@ -30,14 +32,18 @@ class AbstractSessionPool(object):
3032 :type labels: dict (str -> str) or None
3133 :param labels: (Optional) user-assigned labels for sessions created
3234 by the pool.
35+
36+ :type database_role: str
37+ :param database_role: (Optional) user-assigned database_role for the session.
3338 """
3439
3540 _database = None
3641
37- def __init__ (self , labels = None ):
42+ def __init__ (self , labels = None , database_role = None ):
3843 if labels is None :
3944 labels = {}
4045 self ._labels = labels
46+ self ._database_role = database_role
4147
4248 @property
4349 def labels (self ):
@@ -48,6 +54,15 @@ def labels(self):
4854 """
4955 return self ._labels
5056
57+ @property
58+ def database_role (self ):
59+ """User-assigned database_role for sessions created by the pool.
60+
61+ :rtype: str
62+ :returns: database_role assigned by the user
63+ """
64+ return self ._database_role
65+
5166 def bind (self , database ):
5267 """Associate the pool with a database.
5368
@@ -104,9 +119,9 @@ def _new_session(self):
104119 :rtype: :class:`~google.cloud.spanner_v1.session.Session`
105120 :returns: new session instance.
106121 """
107- if self .labels :
108- return self ._database . session ( labels = self .labels )
109- return self . _database . session ( )
122+ return self ._database . session (
123+ labels = self .labels , database_role = self .database_role
124+ )
110125
111126 def session (self , ** kwargs ):
112127 """Check out a session from the pool.
@@ -146,13 +161,22 @@ class FixedSizePool(AbstractSessionPool):
146161 :type labels: dict (str -> str) or None
147162 :param labels: (Optional) user-assigned labels for sessions created
148163 by the pool.
164+
165+ :type database_role: str
166+ :param database_role: (Optional) user-assigned database_role for the session.
149167 """
150168
151169 DEFAULT_SIZE = 10
152170 DEFAULT_TIMEOUT = 10
153171
154- def __init__ (self , size = DEFAULT_SIZE , default_timeout = DEFAULT_TIMEOUT , labels = None ):
155- super (FixedSizePool , self ).__init__ (labels = labels )
172+ def __init__ (
173+ self ,
174+ size = DEFAULT_SIZE ,
175+ default_timeout = DEFAULT_TIMEOUT ,
176+ labels = None ,
177+ database_role = None ,
178+ ):
179+ super (FixedSizePool , self ).__init__ (labels = labels , database_role = database_role )
156180 self .size = size
157181 self .default_timeout = default_timeout
158182 self ._sessions = queue .LifoQueue (size )
@@ -167,9 +191,14 @@ def bind(self, database):
167191 self ._database = database
168192 api = database .spanner_api
169193 metadata = _metadata_with_prefix (database .name )
194+ self ._database_role = self ._database_role or self ._database .database_role
195+ request = BatchCreateSessionsRequest (
196+ session_template = Session (creator_role = self .database_role ),
197+ )
170198
171199 while not self ._sessions .full ():
172200 resp = api .batch_create_sessions (
201+ request = request ,
173202 database = database .name ,
174203 session_count = self .size - self ._sessions .qsize (),
175204 metadata = metadata ,
@@ -243,10 +272,13 @@ class BurstyPool(AbstractSessionPool):
243272 :type labels: dict (str -> str) or None
244273 :param labels: (Optional) user-assigned labels for sessions created
245274 by the pool.
275+
276+ :type database_role: str
277+ :param database_role: (Optional) user-assigned database_role for the session.
246278 """
247279
248- def __init__ (self , target_size = 10 , labels = None ):
249- super (BurstyPool , self ).__init__ (labels = labels )
280+ def __init__ (self , target_size = 10 , labels = None , database_role = None ):
281+ super (BurstyPool , self ).__init__ (labels = labels , database_role = database_role )
250282 self .target_size = target_size
251283 self ._database = None
252284 self ._sessions = queue .LifoQueue (target_size )
@@ -259,6 +291,7 @@ def bind(self, database):
259291 when needed.
260292 """
261293 self ._database = database
294+ self ._database_role = self ._database_role or self ._database .database_role
262295
263296 def get (self ):
264297 """Check a session out from the pool.
@@ -340,10 +373,20 @@ class PingingPool(AbstractSessionPool):
340373 :type labels: dict (str -> str) or None
341374 :param labels: (Optional) user-assigned labels for sessions created
342375 by the pool.
376+
377+ :type database_role: str
378+ :param database_role: (Optional) user-assigned database_role for the session.
343379 """
344380
345- def __init__ (self , size = 10 , default_timeout = 10 , ping_interval = 3000 , labels = None ):
346- super (PingingPool , self ).__init__ (labels = labels )
381+ def __init__ (
382+ self ,
383+ size = 10 ,
384+ default_timeout = 10 ,
385+ ping_interval = 3000 ,
386+ labels = None ,
387+ database_role = None ,
388+ ):
389+ super (PingingPool , self ).__init__ (labels = labels , database_role = database_role )
347390 self .size = size
348391 self .default_timeout = default_timeout
349392 self ._delta = datetime .timedelta (seconds = ping_interval )
@@ -360,9 +403,15 @@ def bind(self, database):
360403 api = database .spanner_api
361404 metadata = _metadata_with_prefix (database .name )
362405 created_session_count = 0
406+ self ._database_role = self ._database_role or self ._database .database_role
407+
408+ request = BatchCreateSessionsRequest (
409+ session_template = Session (creator_role = self .database_role ),
410+ )
363411
364412 while created_session_count < self .size :
365413 resp = api .batch_create_sessions (
414+ request = request ,
366415 database = database .name ,
367416 session_count = self .size - created_session_count ,
368417 metadata = metadata ,
@@ -470,13 +519,27 @@ class TransactionPingingPool(PingingPool):
470519 :type labels: dict (str -> str) or None
471520 :param labels: (Optional) user-assigned labels for sessions created
472521 by the pool.
522+
523+ :type database_role: str
524+ :param database_role: (Optional) user-assigned database_role for the session.
473525 """
474526
475- def __init__ (self , size = 10 , default_timeout = 10 , ping_interval = 3000 , labels = None ):
527+ def __init__ (
528+ self ,
529+ size = 10 ,
530+ default_timeout = 10 ,
531+ ping_interval = 3000 ,
532+ labels = None ,
533+ database_role = None ,
534+ ):
476535 self ._pending_sessions = queue .Queue ()
477536
478537 super (TransactionPingingPool , self ).__init__ (
479- size , default_timeout , ping_interval , labels = labels
538+ size ,
539+ default_timeout ,
540+ ping_interval ,
541+ labels = labels ,
542+ database_role = database_role ,
480543 )
481544
482545 self .begin_pending_transactions ()
@@ -489,6 +552,7 @@ def bind(self, database):
489552 when needed.
490553 """
491554 super (TransactionPingingPool , self ).bind (database )
555+ self ._database_role = self ._database_role or self ._database .database_role
492556 self .begin_pending_transactions ()
493557
494558 def put (self , session ):
0 commit comments