@@ -115,8 +115,7 @@ func (t *templateData) processAST(f *ast.File) error {
115115 continue
116116 }
117117 for _ , field := range st .Fields .List {
118- se , ok := field .Type .(* ast.StarExpr )
119- if len (field .Names ) == 0 || ! ok {
118+ if len (field .Names ) == 0 {
120119 continue
121120 }
122121
@@ -132,13 +131,25 @@ func (t *templateData) processAST(f *ast.File) error {
132131 continue
133132 }
134133
134+ se , ok := field .Type .(* ast.StarExpr )
135+ if ! ok {
136+ switch x := field .Type .(type ) {
137+ case * ast.MapType :
138+ t .addMapType (x , ts .Name .String (), fieldName .String (), false )
139+ continue
140+ }
141+
142+ logf ("Skipping field type %T, fieldName=%v" , field .Type , fieldName )
143+ continue
144+ }
145+
135146 switch x := se .X .(type ) {
136147 case * ast.ArrayType :
137148 t .addArrayType (x , ts .Name .String (), fieldName .String ())
138149 case * ast.Ident :
139150 t .addIdent (x , ts .Name .String (), fieldName .String ())
140151 case * ast.MapType :
141- t .addMapType (x , ts .Name .String (), fieldName .String ())
152+ t .addMapType (x , ts .Name .String (), fieldName .String (), true )
142153 case * ast.SelectorExpr :
143154 t .addSelectorExpr (x , ts .Name .String (), fieldName .String ())
144155 default :
@@ -232,7 +243,7 @@ func (t *templateData) addIdent(x *ast.Ident, receiverType, fieldName string) {
232243 t .Getters = append (t .Getters , newGetter (receiverType , fieldName , x .String (), zeroValue , namedStruct ))
233244}
234245
235- func (t * templateData ) addMapType (x * ast.MapType , receiverType , fieldName string ) {
246+ func (t * templateData ) addMapType (x * ast.MapType , receiverType , fieldName string , isAPointer bool ) {
236247 var keyType string
237248 switch key := x .Key .(type ) {
238249 case * ast.Ident :
@@ -253,7 +264,9 @@ func (t *templateData) addMapType(x *ast.MapType, receiverType, fieldName string
253264
254265 fieldType := fmt .Sprintf ("map[%v]%v" , keyType , valueType )
255266 zeroValue := fmt .Sprintf ("map[%v]%v{}" , keyType , valueType )
256- t .Getters = append (t .Getters , newGetter (receiverType , fieldName , fieldType , zeroValue , false ))
267+ ng := newGetter (receiverType , fieldName , fieldType , zeroValue , false )
268+ ng .MapType = ! isAPointer
269+ t .Getters = append (t .Getters , ng )
257270}
258271
259272func (t * templateData ) addSelectorExpr (x * ast.SelectorExpr , receiverType , fieldName string ) {
@@ -300,6 +313,7 @@ type getter struct {
300313 FieldType string
301314 ZeroValue string
302315 NamedStruct bool // Getter for named struct.
316+ MapType bool
303317}
304318
305319type byName []* getter
@@ -332,6 +346,14 @@ func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() *{{.FieldType}} {
332346 }
333347 return {{.ReceiverVar}}.{{.FieldName}}
334348}
349+ {{else if .MapType}}
350+ // Get{{.FieldName}} returns the {{.FieldName}} map if it's non-nil, an empty map otherwise.
351+ func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} {
352+ if {{.ReceiverVar}} == nil || {{.ReceiverVar}}.{{.FieldName}} == nil {
353+ return {{.ZeroValue}}
354+ }
355+ return {{.ReceiverVar}}.{{.FieldName}}
356+ }
335357{{else}}
336358// Get{{.FieldName}} returns the {{.FieldName}} field if it's non-nil, zero value otherwise.
337359func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} {
@@ -368,6 +390,16 @@ func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) {
368390 {{.ReceiverVar}} = nil
369391 {{.ReceiverVar}}.Get{{.FieldName}}()
370392}
393+ {{else if .MapType}}
394+ func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) {
395+ zeroValue := {{.FieldType}}{}
396+ {{.ReceiverVar}} := &{{.ReceiverType}}{ {{.FieldName}}: zeroValue }
397+ {{.ReceiverVar}}.Get{{.FieldName}}()
398+ {{.ReceiverVar}} = &{{.ReceiverType}}{}
399+ {{.ReceiverVar}}.Get{{.FieldName}}()
400+ {{.ReceiverVar}} = nil
401+ {{.ReceiverVar}}.Get{{.FieldName}}()
402+ }
371403{{else}}
372404func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) {
373405 var zeroValue {{.FieldType}}
0 commit comments