Skip to content

Commit 7ecea65

Browse files
authored
[UNW-177] CLI should create SSH config if it does not exist (#9)
1 parent c8c189f commit 7ecea65

File tree

2 files changed

+72
-5
lines changed

2 files changed

+72
-5
lines changed

ssh/config.go

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,17 @@ func getUnweaveSSHConfigPath() string {
1616
return filepath.Join(config.GetGlobalConfigPath(), "ssh_config")
1717
}
1818

19+
var sshDirPath = func() string {
20+
homeDir, err := os.UserHomeDir()
21+
if err != nil {
22+
ui.Errorf("Failed to get user home directory: %v", err)
23+
os.Exit(1)
24+
}
25+
return filepath.Join(homeDir, ".ssh")
26+
}()
27+
28+
var sshConfigPath = filepath.Join(sshDirPath, "config")
29+
1930
func AddHost(alias, host, user string, port int, identityFile string) error {
2031
configEntry := fmt.Sprintf(`Host %s
2132
HostName %s
@@ -43,13 +54,16 @@ func AddHost(alias, host, user string, port int, identityFile string) error {
4354
return err
4455
}
4556

46-
home, err := os.UserHomeDir()
57+
// Add an Include directive to the user's ssh config to unweave_global SSH configs - used for vscode-remote:
58+
err = os.MkdirAll(sshDirPath, 0700)
4759
if err != nil {
48-
ui.Errorf("Failed to get user home directory: %v", err)
49-
os.Exit(1)
60+
fmt.Println("Failed to create .ssh folder:", err)
61+
}
62+
if _, err := os.Stat(sshConfigPath); os.IsNotExist(err) {
63+
if _, err = os.Create(sshConfigPath); err != nil {
64+
return err
65+
}
5066
}
51-
sshConfigPath := filepath.Join(home, ".ssh", "config")
52-
5367
lines, err := readLines(sshConfigPath)
5468
if err != nil {
5569
return err

ssh/config_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package ssh
2+
3+
import (
4+
"io/ioutil"
5+
"os"
6+
"path/filepath"
7+
"strings"
8+
"testing"
9+
10+
. "github.com/franela/goblin"
11+
)
12+
13+
func TestConfig(t *testing.T) {
14+
g := Goblin(t)
15+
16+
g.Describe("AddHost", func() {
17+
testCfgPath := filepath.Join(sshDirPath, "test_config")
18+
unweaveConfigPath := getUnweaveSSHConfigPath()
19+
20+
g.BeforeEach(func() {
21+
if _, err := os.Stat(testCfgPath); os.IsNotExist(err) {
22+
initialConfig := "# Test SSH Config\n"
23+
err = ioutil.WriteFile(testCfgPath, []byte(initialConfig), 0600)
24+
g.Assert(err).Equal(nil)
25+
}
26+
27+
sshConfigPath = testCfgPath
28+
})
29+
30+
g.AfterEach(func() {
31+
// Only perform create and destroy operations on the test config file
32+
if testCfgPath != sshConfigPath {
33+
return
34+
}
35+
36+
err := os.Remove(testCfgPath)
37+
g.Assert(err).Equal(nil)
38+
})
39+
40+
g.It("should add the Include directive to the .ssh/config file", func() {
41+
err := AddHost("example", "example.com", "user", 22, "")
42+
g.Assert(err).Equal(nil)
43+
44+
configData, err := ioutil.ReadFile(sshConfigPath)
45+
g.Assert(err).Equal(nil)
46+
47+
configContent := string(configData)
48+
expectedInclude := "Include " + unweaveConfigPath
49+
50+
g.Assert(strings.Contains(configContent, expectedInclude)).Equal(true)
51+
})
52+
})
53+
}

0 commit comments

Comments
 (0)