Skip to content

Commit 32f0cf1

Browse files
committed
removed fmt strings and replaced with inline SQL | added unit tests
1 parent 87c355e commit 32f0cf1

2 files changed

Lines changed: 92 additions & 41 deletions

File tree

plugins/database/mssql/mssql.go

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@ import (
55
"database/sql"
66
"errors"
77
"fmt"
8-
"strconv"
98
"strings"
109

1110
_ "github.com/denisenkom/go-mssqldb"
12-
multierror "github.com/hashicorp/go-multierror"
11+
"github.com/hashicorp/go-multierror"
12+
"github.com/hashicorp/go-secure-stdlib/parseutil"
1313
"github.com/hashicorp/go-secure-stdlib/strutil"
14-
dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
14+
15+
"github.com/hashicorp/vault/sdk/database/dbplugin/v5"
1516
"github.com/hashicorp/vault/sdk/database/helper/connutil"
1617
"github.com/hashicorp/vault/sdk/database/helper/dbutil"
1718
"github.com/hashicorp/vault/sdk/helper/dbtxn"
@@ -98,20 +99,14 @@ func (m *MSSQL) Initialize(ctx context.Context, req dbplugin.InitializeRequest)
9899
return dbplugin.InitializeResponse{}, fmt.Errorf("invalid username template - did you reference a field that isn't available? : %w", err)
99100
}
100101

101-
containedDB := false
102-
containedDBRaw, err := strutil.GetString(req.Config, "contained_db")
103-
if err != nil {
104-
return dbplugin.InitializeResponse{}, fmt.Errorf("failed to retrieve contained_db: %w", err)
105-
}
106-
if containedDBRaw != "" {
107-
containedDB, err = strconv.ParseBool(containedDBRaw)
102+
if v, ok := req.Config["contained_db"]; ok {
103+
containedDB, err := parseutil.ParseBool(v)
108104
if err != nil {
109-
return dbplugin.InitializeResponse{}, fmt.Errorf("parsing error: incorrect boolean operator provided for contained_db: %w", err)
105+
return dbplugin.InitializeResponse{}, fmt.Errorf(`invalid value for "contained_db": %w`, err)
110106
}
107+
m.containedDB = containedDB
111108
}
112109

113-
m.containedDB = containedDB
114-
115110
resp := dbplugin.InitializeResponse{
116111
Config: newConf,
117112
}
@@ -221,24 +216,32 @@ func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error {
221216

222217
// Check if DB is contained
223218
if m.containedDB {
224-
revokeStmt, err := db.PrepareContext(ctx, fmt.Sprintf("DROP USER IF EXISTS [%s]", username))
219+
revokeQuery :=
220+
`DECLARE @stmt nvarchar(max);
221+
SET @stmt = 'DROP USER IF EXISTS ' + QuoteName(@username);
222+
EXEC(@stmt);`
223+
revokeStmt, err := db.PrepareContext(ctx, revokeQuery)
225224
if err != nil {
226225
return err
227226
}
228227
defer revokeStmt.Close()
229-
if _, err := revokeStmt.ExecContext(ctx); err != nil {
228+
if _, err := revokeStmt.ExecContext(ctx, sql.Named("username", username)); err != nil {
230229
return err
231230
}
232231
return nil
233232
}
234233

235234
// First disable server login
236-
disableStmt, err := db.PrepareContext(ctx, fmt.Sprintf("ALTER LOGIN [%s] DISABLE;", username))
237-
if err != nil {
235+
disableQuery :=
236+
`DECLARE @stmt nvarchar(max);
237+
SET @stmt = 'ALTER LOGIN ' + QuoteName(@username) + ' DISABLE';
238+
EXEC(@stmt);`
239+
disableStmt, err := db.PrepareContext(ctx, disableQuery)
240+
if err != nil{
238241
return err
239242
}
240243
defer disableStmt.Close()
241-
if _, err := disableStmt.ExecContext(ctx); err != nil {
244+
if _, err := disableStmt.ExecContext(ctx, sql.Named("username", username)); err != nil {
242245
return err
243246
}
244247

@@ -316,12 +319,12 @@ func (m *MSSQL) revokeUserDefault(ctx context.Context, username string) error {
316319
}
317320

318321
// Drop this login
319-
stmt, err = db.PrepareContext(ctx, fmt.Sprintf(dropLoginSQL, username, username))
322+
stmt, err = db.PrepareContext(ctx, dropLoginSQL)
320323
if err != nil {
321324
return err
322325
}
323326
defer stmt.Close()
324-
if _, err := stmt.ExecContext(ctx); err != nil {
327+
if _, err := stmt.ExecContext(ctx, sql.Named("username", username)); err != nil {
325328
return err
326329
}
327330

@@ -418,14 +421,12 @@ END
418421
`
419422

420423
const dropLoginSQL = `
421-
IF EXISTS
422-
(SELECT name
423-
FROM master.sys.server_principals
424-
WHERE name = N'%s')
425-
BEGIN
426-
DROP LOGIN [%s]
427-
END
428-
`
424+
DECLARE @stmt nvarchar(max)
425+
SET @stmt = 'IF EXISTS (SELECT name FROM [master].[sys].[server_principals] WHERE [name] = ' + QuoteName(@username, '''') + ') ' +
426+
'BEGIN ' +
427+
'DROP LOGIN ' + QuoteName(@username) + ' ' +
428+
'END'
429+
EXEC (@stmt)`
429430

430431
const alterLoginSQL = `
431432
ALTER LOGIN [{{username}}] WITH PASSWORD = '{{password}}'

plugins/database/mssql/mssql_test.go

Lines changed: 63 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@ import (
1111
"time"
1212

1313
mssqlhelper "github.com/hashicorp/vault/helper/testhelpers/mssql"
14-
dbplugin "github.com/hashicorp/vault/sdk/database/dbplugin/v5"
14+
"github.com/hashicorp/vault/sdk/database/dbplugin/v5"
1515
dbtesting "github.com/hashicorp/vault/sdk/database/dbplugin/v5/testing"
1616
"github.com/hashicorp/vault/sdk/helper/dbtxn"
17+
"github.com/stretchr/testify/assert"
1718
)
1819

1920
func TestInitialize(t *testing.T) {
@@ -43,6 +44,15 @@ func TestInitialize(t *testing.T) {
4344
},
4445
},
4546
"contained_db set": {
47+
dbplugin.InitializeRequest{
48+
Config: map[string]interface{}{
49+
"connection_url": connURL,
50+
"contained_db": true,
51+
},
52+
VerifyConnection: true,
53+
},
54+
},
55+
"contained_db set string": {
4656
dbplugin.InitializeRequest{
4757
Config: map[string]interface{}{
4858
"connection_url": connURL,
@@ -253,7 +263,10 @@ func TestUpdateUser_password(t *testing.T) {
253263
dbtesting.AssertInitializeCircleCiTest(t, db, initReq)
254264
defer dbtesting.AssertClose(t, db)
255265

256-
createTestMSSQLUser(t, connURL, dbUser, initPassword, testMSSQLLogin)
266+
err := createTestMSSQLUser(connURL, dbUser, initPassword, testMSSQLLogin)
267+
if err != nil {
268+
t.Fatalf("Failed to create user: %s", err)
269+
}
257270

258271
assertCredsExist(t, connURL, dbUser, initPassword)
259272

@@ -317,7 +330,10 @@ func TestDeleteUser(t *testing.T) {
317330
dbtesting.AssertInitializeCircleCiTest(t, db, initReq)
318331
defer dbtesting.AssertClose(t, db)
319332

320-
createTestMSSQLUser(t, connURL, dbUser, initPassword, testMSSQLLogin)
333+
err := createTestMSSQLUser(connURL, dbUser, initPassword, testMSSQLLogin)
334+
if err != nil {
335+
t.Fatalf("Failed to create user: %s", err)
336+
}
321337

322338
assertCredsExist(t, connURL, dbUser, initPassword)
323339

@@ -341,6 +357,44 @@ func TestDeleteUser(t *testing.T) {
341357
assertCredsDoNotExist(t, connURL, dbUser, initPassword)
342358
}
343359

360+
func TestSQLSanitization(t *testing.T) {
361+
cleanup, connURL := mssqlhelper.PrepareMSSQLTestContainer(t)
362+
defer cleanup()
363+
364+
injectionString := "vaultuser]"
365+
dbUser := "vaultuser"
366+
initPassword := "p4$sw0rd"
367+
368+
initReq := dbplugin.InitializeRequest{
369+
Config: map[string]interface{}{
370+
"connection_url": connURL,
371+
},
372+
VerifyConnection: true,
373+
}
374+
375+
db := new()
376+
377+
dbtesting.AssertInitializeCircleCiTest(t, db, initReq)
378+
defer dbtesting.AssertClose(t, db)
379+
380+
err := createTestMSSQLUser(connURL, dbUser, initPassword, testMSSQLLogin)
381+
if err != nil {
382+
t.Fatalf("Failed to create user: %s", err)
383+
}
384+
385+
assertCredsExist(t, connURL, dbUser, initPassword)
386+
387+
deleteReq := dbplugin.DeleteUserRequest{
388+
Username: injectionString,
389+
}
390+
391+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
392+
defer cancel()
393+
_, err = db.DeleteUser(ctx, deleteReq)
394+
395+
assert.EqualError(t, err, "mssql: Cannot alter the login 'vaultuser]', because it does not exist or you do not have permission.")
396+
}
397+
344398
func assertCredsExist(t testing.TB, connURL, username, password string) {
345399
t.Helper()
346400
err := testCredsExist(connURL, username, password)
@@ -369,18 +423,18 @@ func testCredsExist(connURL, username, password string) error {
369423
return db.Ping()
370424
}
371425

372-
func createTestMSSQLUser(t *testing.T, connURL string, username, password, query string) {
426+
func createTestMSSQLUser(connURL string, username, password, query string) error {
373427
db, err := sql.Open("mssql", connURL)
374428
defer db.Close()
375429
if err != nil {
376-
t.Fatal(err)
430+
return err
377431
}
378432

379433
// Start a transaction
380434
ctx := context.Background()
381435
tx, err := db.BeginTx(ctx, nil)
382436
if err != nil {
383-
t.Fatal(err)
437+
return err
384438
}
385439
defer func() {
386440
_ = tx.Rollback()
@@ -391,24 +445,20 @@ func createTestMSSQLUser(t *testing.T, connURL string, username, password, query
391445
"password": password,
392446
}
393447
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, query); err != nil {
394-
t.Fatal(err)
448+
return err
395449
}
396450
// Commit the transaction
397451
if err := tx.Commit(); err != nil {
398-
t.Fatal(err)
452+
return err
399453
}
454+
return nil
400455
}
401456

402457
const testMSSQLRole = `
403458
CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}';
404459
CREATE USER [{{name}}] FOR LOGIN [{{name}}];
405460
GRANT SELECT, INSERT, UPDATE, DELETE ON SCHEMA::dbo TO [{{name}}];`
406461

407-
const testMSSQLDrop = `
408-
DROP USER [{{name}}];
409-
DROP LOGIN [{{name}}];
410-
`
411-
412462
const testMSSQLLogin = `
413463
CREATE LOGIN [{{name}}] WITH PASSWORD = '{{password}}';
414464
`

0 commit comments

Comments
 (0)