@@ -164,16 +164,17 @@ func TestConcurrencyManagerBasic(t *testing.T) {
164164 }
165165 latchSpans , lockSpans := c .collectSpans (t , txn , ts , reqs )
166166
167- c .requestsByName [reqName ] = concurrency.Request {
168- Txn : txn ,
169- Timestamp : ts ,
170- // TODO(nvanbenschoten): test Priority
171- ReadConsistency : readConsistency ,
172- WaitPolicy : waitPolicy ,
173- Requests : reqUnions ,
174- LatchSpans : latchSpans ,
175- LockSpans : lockSpans ,
176- }
167+ c .requestsByName [reqName ] = testReq {
168+ Request : concurrency.Request {
169+ Txn : txn ,
170+ Timestamp : ts ,
171+ // TODO(nvanbenschoten): test Priority
172+ ReadConsistency : readConsistency ,
173+ WaitPolicy : waitPolicy ,
174+ Requests : reqUnions ,
175+ LatchSpans : latchSpans ,
176+ LockSpans : lockSpans ,
177+ }}
177178 return ""
178179
179180 case "sequence" :
@@ -190,8 +191,8 @@ func TestConcurrencyManagerBasic(t *testing.T) {
190191 c .mu .Unlock ()
191192
192193 opName := fmt .Sprintf ("sequence %s" , reqName )
193- mon .runAsync (opName , func (ctx context.Context ) {
194- guard , resp , err := m .SequenceReq (ctx , prev , req )
194+ cancel := mon .runAsync (opName , func (ctx context.Context ) {
195+ guard , resp , err := m .SequenceReq (ctx , prev , req . Request )
195196 if err != nil {
196197 log .Eventf (ctx , "sequencing complete, returned error: %v" , err )
197198 } else if resp != nil {
@@ -205,6 +206,8 @@ func TestConcurrencyManagerBasic(t *testing.T) {
205206 log .Event (ctx , "sequencing complete, returned no guard" )
206207 }
207208 })
209+ req .cancel = cancel
210+ c .requestsByName [reqName ] = req
208211 return c .waitAndCollect (t , mon )
209212
210213 case "finish" :
@@ -477,6 +480,11 @@ func TestConcurrencyManagerBasic(t *testing.T) {
477480 })
478481}
479482
483+ type testReq struct {
484+ cancel func ()
485+ concurrency.Request
486+ }
487+
480488// cluster encapsulates the state of a running cluster and a set of requests.
481489// It serves as the test harness in TestConcurrencyManagerBasic - maintaining
482490// transaction and request declarations, recording the state of in-flight
@@ -491,7 +499,7 @@ type cluster struct {
491499 // Definitions.
492500 txnCounter uint32
493501 txnsByName map [string ]* roachpb.Transaction
494- requestsByName map [string ]concurrency. Request
502+ requestsByName map [string ]testReq
495503
496504 // Request state. Cleared on reset.
497505 mu syncutil.Mutex
@@ -511,6 +519,7 @@ type txnRecord struct {
511519type txnPush struct {
512520 ctx context.Context
513521 pusher , pushee uuid.UUID
522+ count int
514523}
515524
516525func newCluster () * cluster {
@@ -520,7 +529,7 @@ func newCluster() *cluster {
520529 rangeDesc : & roachpb.RangeDescriptor {RangeID : 1 },
521530
522531 txnsByName : make (map [string ]* roachpb.Transaction ),
523- requestsByName : make (map [string ]concurrency. Request ),
532+ requestsByName : make (map [string ]testReq ),
524533 guardsByReqName : make (map [string ]* concurrency.Guard ),
525534 txnRecords : make (map [uuid.UUID ]* txnRecord ),
526535 txnPushes : make (map [uuid.UUID ]* txnPush ),
@@ -533,6 +542,9 @@ func (c *cluster) makeConfig() concurrency.Config {
533542 RangeDesc : c .rangeDesc ,
534543 Settings : c .st ,
535544 IntentResolver : c ,
545+ OnContentionEvent : func (ev * roachpb.ContentionEvent ) {
546+ ev .Duration = 1234 * time .Millisecond // for determinism
547+ },
536548 TxnWaitMetrics : txnwait .NewMetrics (time .Minute ),
537549 }
538550}
@@ -680,11 +692,18 @@ func (r *txnRecord) asTxn() (*roachpb.Transaction, chan struct{}) {
680692func (c * cluster ) registerPush (ctx context.Context , pusher , pushee uuid.UUID ) (* txnPush , error ) {
681693 c .mu .Lock ()
682694 defer c .mu .Unlock ()
683- if _ , ok := c .txnPushes [pusher ]; ok {
684- return nil , errors .Errorf ("txn %v already pushing" , pusher )
695+ if p , ok := c .txnPushes [pusher ]; ok {
696+ if pushee != p .pushee {
697+ return nil , errors .Errorf ("pusher %s can't push two txns %s and %s at the same time" ,
698+ pusher .Short (), pushee .Short (), p .pushee .Short (),
699+ )
700+ }
701+ p .count ++
702+ return p , nil
685703 }
686704 p := & txnPush {
687705 ctx : ctx ,
706+ count : 1 ,
688707 pusher : pusher ,
689708 pushee : pushee ,
690709 }
@@ -695,7 +714,17 @@ func (c *cluster) registerPush(ctx context.Context, pusher, pushee uuid.UUID) (*
695714func (c * cluster ) unregisterPush (push * txnPush ) {
696715 c .mu .Lock ()
697716 defer c .mu .Unlock ()
698- delete (c .txnPushes , push .pusher )
717+ p , ok := c .txnPushes [push .pusher ]
718+ if ! ok {
719+ return
720+ }
721+ p .count --
722+ if p .count == 0 {
723+ delete (c .txnPushes , push .pusher )
724+ }
725+ if p .count < 0 {
726+ panic (fmt .Sprintf ("negative count: %+v" , p ))
727+ }
699728}
700729
701730// detectDeadlocks looks at all in-flight transaction pushes and determines
@@ -792,7 +821,7 @@ func (c *cluster) resetNamespace() {
792821 defer c .mu .Unlock ()
793822 c .txnCounter = 0
794823 c .txnsByName = make (map [string ]* roachpb.Transaction )
795- c .requestsByName = make (map [string ]concurrency. Request )
824+ c .requestsByName = make (map [string ]testReq )
796825 c .txnRecords = make (map [uuid.UUID ]* txnRecord )
797826}
798827
@@ -871,7 +900,7 @@ func (m *monitor) runSync(opName string, fn func(context.Context)) {
871900 atomic .StoreInt32 (& g .finished , 1 )
872901}
873902
874- func (m * monitor ) runAsync (opName string , fn func (context.Context )) {
903+ func (m * monitor ) runAsync (opName string , fn func (context.Context )) ( cancel func ()) {
875904 m .seq ++
876905 ctx , collect , cancel := tracing .ContextWithRecordingSpan (context .Background (), opName )
877906 g := & monitoredGoroutine {
@@ -887,6 +916,7 @@ func (m *monitor) runAsync(opName string, fn func(context.Context)) {
887916 fn (ctx )
888917 atomic .StoreInt32 (& g .finished , 1 )
889918 }()
919+ return cancel
890920}
891921
892922func (m * monitor ) numMonitored () int {
0 commit comments