Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,7 @@ func (ctx *RequestCtx) Hijacked() bool {
// All the values are removed from ctx after returning from the top
// RequestHandler. Additionally, Close method is called on each value
// implementing io.Closer before removing the value from ctx.
func (ctx *RequestCtx) SetUserValue(key string, value interface{}) {
func (ctx *RequestCtx) SetUserValue(key interface{}, value interface{}) {
ctx.userValues.Set(key, value)
}

Expand All @@ -688,7 +688,7 @@ func (ctx *RequestCtx) SetUserValueBytes(key []byte, value interface{}) {
}

// UserValue returns the value stored via SetUserValue* under the given key.
func (ctx *RequestCtx) UserValue(key string) interface{} {
func (ctx *RequestCtx) UserValue(key interface{}) interface{} {
return ctx.userValues.Get(key)
}

Expand All @@ -698,11 +698,24 @@ func (ctx *RequestCtx) UserValueBytes(key []byte) interface{} {
return ctx.userValues.GetBytes(key)
}

// VisitUserValues calls visitor for each existing userValue.
// VisitUserValues calls visitor for each existing userValue with a key that is a string or []byte.
//
// visitor must not retain references to key and value after returning.
// Make key and/or value copies if you need storing them after returning.
func (ctx *RequestCtx) VisitUserValues(visitor func([]byte, interface{})) {
for i, n := 0, len(ctx.userValues); i < n; i++ {
kv := &ctx.userValues[i]
if _, ok := kv.key.(string); ok {
visitor(s2b(kv.key.(string)), kv.value)
}
}
}

// VisitUserValuesAll calls visitor for each existing userValue.
//
// visitor must not retain references to key and value after returning.
// Make key and/or value copies if you need storing them after returning.
func (ctx *RequestCtx) VisitUserValuesAll(visitor func(interface{}, interface{})) {
for i, n := 0, len(ctx.userValues); i < n; i++ {
kv := &ctx.userValues[i]
visitor(kv.key, kv.value)
Expand All @@ -715,7 +728,7 @@ func (ctx *RequestCtx) ResetUserValues() {
}

// RemoveUserValue removes the given key and the value under it in ctx.
func (ctx *RequestCtx) RemoveUserValue(key string) {
func (ctx *RequestCtx) RemoveUserValue(key interface{}) {
ctx.userValues.Remove(key)
}

Expand Down Expand Up @@ -2696,10 +2709,7 @@ func (ctx *RequestCtx) Err() error {
// This method is present to make RequestCtx implement the context interface.
// This method is the same as calling ctx.UserValue(key)
func (ctx *RequestCtx) Value(key interface{}) interface{} {
if keyString, ok := key.(string); ok {
return ctx.UserValue(keyString)
}
return nil
return ctx.UserValue(key)
}

var fakeServer = &Server{
Expand Down
2 changes: 1 addition & 1 deletion server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1737,7 +1737,7 @@ func TestRequestCtxUserValue(t *testing.T) {
vlen := 0
ctx.VisitUserValues(func(key []byte, value interface{}) {
vlen++
v := ctx.UserValueBytes(key)
v := ctx.UserValue(key)
if v != value {
t.Fatalf("unexpected value obtained from VisitUserValues for key: %q, expecting: %#v but got: %#v", key, v, value)
}
Expand Down
33 changes: 21 additions & 12 deletions userdata.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,21 @@ import (
)

type userDataKV struct {
key []byte
key interface{}
value interface{}
}

type userData []userDataKV

func (d *userData) Set(key string, value interface{}) {
func (d *userData) Set(key interface{}, value interface{}) {
if b, ok := key.([]byte); ok {
key = string(b)
}
args := *d
n := len(args)
for i := 0; i < n; i++ {
kv := &args[i]
if string(kv.key) == key {
if kv.key == key {
kv.value = value
return
}
Expand All @@ -30,36 +33,39 @@ func (d *userData) Set(key string, value interface{}) {
if c > n {
args = args[:n+1]
kv := &args[n]
kv.key = append(kv.key[:0], key...)
kv.key = key
kv.value = value
*d = args
return
}

kv := userDataKV{}
kv.key = append(kv.key[:0], key...)
kv.key = key
kv.value = value
*d = append(args, kv)
}

func (d *userData) SetBytes(key []byte, value interface{}) {
d.Set(b2s(key), value)
d.Set(key, value)
}

func (d *userData) Get(key string) interface{} {
func (d *userData) Get(key interface{}) interface{} {
if b, ok := key.([]byte); ok {
key = b2s(b)
}
args := *d
n := len(args)
for i := 0; i < n; i++ {
kv := &args[i]
if string(kv.key) == key {
if kv.key == key {
return kv.value
}
}
return nil
}

func (d *userData) GetBytes(key []byte) interface{} {
return d.Get(b2s(key))
return d.Get(key)
}

func (d *userData) Reset() {
Expand All @@ -74,12 +80,15 @@ func (d *userData) Reset() {
*d = (*d)[:0]
}

func (d *userData) Remove(key string) {
func (d *userData) Remove(key interface{}) {
if b, ok := key.([]byte); ok {
key = b2s(b)
}
args := *d
n := len(args)
for i := 0; i < n; i++ {
kv := &args[i]
if string(kv.key) == key {
if kv.key == key {
n--
args[i], args[n] = args[n], args[i]
args[n].value = nil
Expand All @@ -91,5 +100,5 @@ func (d *userData) Remove(key string) {
}

func (d *userData) RemoveBytes(key []byte) {
d.Remove(b2s(key))
d.Remove(key)
}