Skip to content

Commit fe7f3f7

Browse files
authored
Deprecate Node Type, add CLI support for MultiGPU instances (#13)
* Temporarily update the CLI to use GPU.Type as NodeType
* update gomod to latest commit
* Function to parse HardwareSpec
* Deprecate NodeTypeID add new flags
* Wire Cobra to pass HardwareSpec values
* Feedback pass 1
* Pair programming pass
1 parent ff25c7c commit fe7f3f7

6 files changed

Lines changed: 160 additions & 57 deletions

File tree

cmd/session.go

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -124,31 +124,23 @@ func generateSSHKey(ctx context.Context) (string, []byte, error) {
124124

125125
func sessionCreate(ctx context.Context, execConfig types.ExecConfig, gitConfig types.GitConfig) (string, error) {
126126
var region, image *string
127-
var nodeTypeIDs []string
128127

129128
if config.Config.Project.DefaultProvider == "" && config.Provider == "" {
130129
ui.Errorf("No provider specified. Either set a default provider in you project config or specify a provider with the --provider flag")
131130
os.Exit(1)
132131
}
133-
134132
provider := config.Config.Project.DefaultProvider
135133
if config.Provider != "" {
136134
provider = config.Provider
137135
}
138136

139-
if p, ok := config.Config.Project.Providers[provider]; ok {
140-
nodeTypeIDs = p.NodeTypes
141-
}
142-
if len(config.NodeTypeID) != 0 {
143-
nodeTypeIDs = []string{config.NodeTypeID}
137+
spec, err := parseHardwareSpec()
138+
if err != nil {
139+
return "", err
144140
}
145141
if config.NodeRegion != "" {
146142
region = &config.NodeRegion
147143
}
148-
if len(nodeTypeIDs) == 0 {
149-
ui.Errorf("No node types specified")
150-
return "", fmt.Errorf("no node types specified")
151-
}
152144

153145
if config.BuildID != "" {
154146
image = &config.BuildID
@@ -164,7 +156,7 @@ func sessionCreate(ctx context.Context, execConfig types.ExecConfig, gitConfig t
164156

165157
params := types.ExecCreateParams{
166158
Provider: types.Provider(provider),
167-
NodeTypeID: "",
159+
HardwareSpec: spec,
168160
Region: region,
169161
SSHKeyName: sshKeyName,
170162
SSHPublicKey: sshPublicKey,
@@ -175,7 +167,7 @@ func sessionCreate(ctx context.Context, execConfig types.ExecConfig, gitConfig t
175167
Source: execConfig.Src,
176168
}
177169

178-
sessionID, err := session.Create(ctx, params, nodeTypeIDs)
170+
sessionID, err := session.Create(ctx, params)
179171
if err != nil {
180172
var e *types.Error
181173
if errors.As(err, &e) {
@@ -385,3 +377,31 @@ func formatExecCobraOpts(execs []types.Exec, prepend ...string) ([]string, map[i
385377

386378
return options, optionMap
387379
}
380+
381+
func parseHardwareSpec() (types.HardwareSpec, error) {
382+
return types.HardwareSpec{
383+
GPU: types.GPU{
384+
Count: types.HardwareRequestRange{
385+
Min: config.GPUs,
386+
Max: config.GPUs,
387+
},
388+
Type: config.GPUType,
389+
RAM: types.HardwareRequestRange{
390+
Min: config.GPUMemory,
391+
Max: config.GPUMemory,
392+
},
393+
},
394+
CPU: types.HardwareRequestRange{
395+
Min: config.CPUs,
396+
Max: config.CPUs,
397+
},
398+
RAM: types.HardwareRequestRange{
399+
Min: config.Memory,
400+
Max: config.Memory,
401+
},
402+
Storage: types.HardwareRequestRange{
403+
Min: config.HDD,
404+
Max: config.HDD,
405+
},
406+
}, nil
407+
}

config/flags.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,23 @@ var BuildID = ""
1313
// CreateExec is used to denote whether to create a new exec when running commands that require an exec.
1414
var CreateExec = true
1515

16-
// NodeTypeID is the ID of the provider specific node type to use when creating a new session
17-
var NodeTypeID = ""
16+
// GPUs is the number of GPUs to allocate for a gpuType.
17+
var GPUs int
18+
19+
// GPUMemory is the amount of GPU memory to request, if applicable for the chosen GPU type.
20+
var GPUMemory int
21+
22+
// GPUType is the type of GPU to use.
23+
var GPUType string
24+
25+
// CPUs is the number of VCPUs to allocate.
26+
var CPUs int
27+
28+
// Memory is the amount of RAM to allocate in GB.
29+
var Memory int
30+
31+
// HDD is the amount of storage to allocate in GB.
32+
var HDD int
1833

1934
// NodeRegion is the region to use when creating a new session
2035
var NodeRegion = ""

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ require (
1111
github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06
1212
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
1313
github.com/spf13/cobra v1.6.1
14-
github.com/unweave/unweave v0.0.0-20230507172101-139ff3bb1192
14+
github.com/unweave/unweave v0.0.0-20230525135826-dacfce72a65a
1515
)
1616

1717
require (

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
7373
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
7474
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
7575
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
76-
github.com/unweave/unweave v0.0.0-20230507172101-139ff3bb1192 h1:uzTIOVmrTNWEn800oE/47GOKmgiKoQJW/mf7oLtxHEk=
77-
github.com/unweave/unweave v0.0.0-20230507172101-139ff3bb1192/go.mod h1:JUa40hxqyuBllT/k+SW8W8PGQmDEoH9nON9Mp/7fjJA=
76+
github.com/unweave/unweave v0.0.0-20230525135826-dacfce72a65a h1:V0RfEQxzMkHnfPVbbePzz+VycsPfvagUCQHedcRdJR4=
77+
github.com/unweave/unweave v0.0.0-20230525135826-dacfce72a65a/go.mod h1:JUa40hxqyuBllT/k+SW8W8PGQmDEoH9nON9Mp/7fjJA=
7878
golang.org/x/crypto v0.1.0 h1:MDRAIl0xIo9Io2xV565hzXHw3zVseKrJKodhohM5CjU=
7979
golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
8080
golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=

main.go

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,6 @@ func init() {
120120
Hidden: true,
121121
}
122122
boxCmd.Flags().StringVar(&config.Provider, "provider", "", "Provider to use")
123-
boxCmd.Flags().StringVar(&config.NodeTypeID, "type", "", "Node type to use, eg. `gpu_1x_a100`")
124123
boxCmd.Flags().StringVar(&config.NodeRegion, "region", "", "Region to use, eg. `us_west_2`")
125124

126125
rootCmd.AddCommand(boxCmd)
@@ -135,9 +134,14 @@ func init() {
135134
codeCmd.Flags().BoolVar(&config.CreateExec, "new", false, "Create a new")
136135
codeCmd.Flags().StringVarP(&config.BuildID, "image", "i", "", "Build ID of the container image to use")
137136
codeCmd.Flags().StringVar(&config.Provider, "provider", "", "Provider to use")
138-
codeCmd.Flags().StringVar(&config.NodeTypeID, "type", "", "Node type to use, eg. `gpu_1x_a100`")
139137
codeCmd.Flags().StringVar(&config.NodeRegion, "region", "", "Region to use, eg. `us_west_2`")
140138
codeCmd.Flags().StringVar(&config.SSHPrivateKeyPath, "prv", "", "Absolute Path to the private key to use")
139+
codeCmd.Flags().IntVar(&config.GPUs, "gpus", 0, "Number of GPUs to allocate for a gpuType, e.g., 2")
140+
codeCmd.Flags().IntVar(&config.GPUMemory, "gpu-mem", 0, "Memory of GPU if applicable for a gpuType, e.g., 12")
141+
codeCmd.Flags().StringVar(&config.GPUType, "gpu-type", "", "Type of GPU to use, e.g., rtx_5000")
142+
codeCmd.Flags().IntVar(&config.CPUs, "cpus", 0, "Number of VCPUs to allocate, e.g., 4")
143+
codeCmd.Flags().IntVar(&config.Memory, "mem", 0, "Amount of RAM to allocate in GB, e.g., 16")
144+
codeCmd.Flags().IntVar(&config.HDD, "hdd", 0, "Amount of hard-disk space to allocate in GB")
141145
rootCmd.AddCommand(codeCmd)
142146

143147
rootCmd.AddCommand(&cobra.Command{
@@ -218,8 +222,13 @@ func init() {
218222
}
219223
newCmd.Flags().StringVarP(&config.BuildID, "image", "i", "", "Build ID of the container image to use")
220224
newCmd.Flags().StringVar(&config.Provider, "provider", "", "Provider to use")
221-
newCmd.Flags().StringVar(&config.NodeTypeID, "type", "", "Node type to use, eg. `gpu_1x_a100`")
222225
newCmd.Flags().StringVar(&config.NodeRegion, "region", "", "Region to use, eg. `us_west_2`")
226+
newCmd.Flags().IntVar(&config.GPUs, "gpus", 0, "Number of GPUs to allocate for a gpuType, e.g., 2")
227+
newCmd.Flags().IntVar(&config.GPUMemory, "gpu-mem", 0, "Memory of GPU if applicable for a gpuType, e.g., 12")
228+
newCmd.Flags().StringVar(&config.GPUType, "gpu-type", "", "Type of GPU to use, e.g., rtx_5000")
229+
newCmd.Flags().IntVar(&config.CPUs, "cpus", 0, "Number of VCPUs to allocate, e.g., 4")
230+
newCmd.Flags().IntVar(&config.Memory, "mem", 0, "Amount of RAM to allocate in GB, e.g., 16")
231+
newCmd.Flags().IntVar(&config.HDD, "hdd", 0, "Amount of hard-disk space to allocate in GB")
223232
rootCmd.AddCommand(newCmd)
224233

225234
lsCmd := &cobra.Command{
@@ -257,9 +266,14 @@ func init() {
257266
sshCmd.Flags().BoolVar(&config.NoCopySource, "no-copy", false, "Do not copy source code to the session")
258267
sshCmd.Flags().StringVarP(&config.BuildID, "image", "i", "", "Build ID of the container image to use")
259268
sshCmd.Flags().StringVar(&config.Provider, "provider", "", "Provider to use")
260-
sshCmd.Flags().StringVar(&config.NodeTypeID, "type", "", "Node type to use, eg. `gpu_1x_a100`")
261269
sshCmd.Flags().StringVar(&config.NodeRegion, "region", "", "Region to use, eg. `us_west_2`")
262270
sshCmd.Flags().StringVar(&config.SSHPrivateKeyPath, "prv", "", "Absolute Path to the private key to use")
271+
sshCmd.Flags().IntVar(&config.GPUs, "gpus", 0, "Number of GPUs to allocate for a gpuType, e.g., 2")
272+
sshCmd.Flags().IntVar(&config.GPUMemory, "gpu-mem", 0, "Memory of GPU if applicable for a gpuType, e.g., 12")
273+
sshCmd.Flags().StringVar(&config.GPUType, "gpu-type", "", "Type of GPU to use, e.g., rtx_5000")
274+
sshCmd.Flags().IntVar(&config.CPUs, "cpus", 0, "Number of VCPUs to allocate, e.g., 4")
275+
sshCmd.Flags().IntVar(&config.Memory, "mem", 0, "Amount of RAM to allocate in GB, e.g., 16")
276+
sshCmd.Flags().IntVar(&config.HDD, "hdd", 0, "Amount of hard-disk space to allocate in GB")
263277
rootCmd.AddCommand(sshCmd)
264278

265279
// SSH Key commands

session/session.go

Lines changed: 89 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -10,48 +10,102 @@ import (
1010
"github.com/unweave/unweave/api/types"
1111
)
1212

13-
// Create attempts to create a session using the node types provided
14-
// until the first successful creation. If none of the node types are successful, it
15-
// returns 503 out of capacity error.
16-
func Create(ctx context.Context, params types.ExecCreateParams, nodeTypeIDs []string) (string, error) {
17-
uwc := config.InitUnweaveClient()
13+
// Create attempts to create a session using the hardware spec in params. If the spec names no GPU type, it tries the GPU types configured for the provider in turn, moving on to the next one on a 503 out-of-capacity error.
14+
// Renders newly created sessions to the UI implicitly.
15+
func Create(ctx context.Context, params types.ExecCreateParams) (string, error) {
16+
if params.HardwareSpec.GPU.Type == "" {
17+
exec, err := createSessionFromConfigGPUTypes(ctx, params)
18+
renderSessionCreated(exec)
1819

19-
var err error
20-
var exec *types.Exec
20+
return exec.ID, err
21+
}
2122

22-
for _, nodeTypeID := range nodeTypeIDs {
23-
params.NodeTypeID = nodeTypeID
24-
25-
owner, projectName := config.GetProjectOwnerAndName()
26-
exec, err = uwc.Exec.Create(ctx, owner, projectName, params)
27-
if err == nil {
28-
results := []ui.ResultEntry{
29-
{Key: "Name", Value: exec.Name},
30-
{Key: "ID", Value: exec.ID},
31-
{Key: "Provider", Value: exec.Provider.DisplayName()},
32-
{Key: "Type", Value: exec.NodeTypeID},
33-
{Key: "Region", Value: exec.Region},
34-
{Key: "Status", Value: fmt.Sprintf("%s", exec.Status)},
35-
{Key: "SSHKey", Value: fmt.Sprintf("%s", exec.SSHKey.Name)},
23+
exec, err := createSession(ctx, params, params.HardwareSpec.GPU.Type)
24+
if err != nil {
25+
var e *types.Error
26+
if errors.As(err, &e) {
27+
if err != nil {
28+
return "", err
3629
}
37-
38-
ui.ResultTitle("Session Created:")
39-
ui.Result(results, ui.IndentWidth)
40-
return exec.ID, nil
30+
} else {
31+
return "", err
4132
}
33+
}
34+
renderSessionCreated(exec)
35+
36+
return exec.ID, err
37+
}
4238

39+
func createSession(ctx context.Context, params types.ExecCreateParams, gpuType string) (*types.Exec, error) {
40+
uwc := config.InitUnweaveClient()
41+
owner, projectName := config.GetProjectOwnerAndName()
42+
43+
useParams := params
44+
useParams.HardwareSpec.GPU.Type = gpuType
45+
46+
exec, err := uwc.Exec.Create(ctx, owner, projectName, useParams)
47+
if err != nil {
48+
return nil, err
49+
}
50+
51+
return exec, nil
52+
}
53+
54+
func createSessionFromConfigGPUTypes(ctx context.Context, params types.ExecCreateParams) (*types.Exec, error) {
55+
gpuTypesFromConfig := gpuTypesFromConfig()
56+
var err error
57+
var exec *types.Exec
58+
for _, gpuType := range gpuTypesFromConfig {
59+
exec, err = createSession(ctx, params, gpuType)
4360
if err != nil {
44-
var e *types.Error
45-
if errors.As(err, &e) {
46-
// If error 503, it's mostly likely an out of capacity error. Continue to
47-
// next node type.
48-
if e.Code == 503 {
49-
continue
50-
}
51-
return "", err
61+
if isOutOfCapacityError(err) {
62+
continue
5263
}
64+
return nil, err
5365
}
66+
67+
return exec, nil
5468
}
55-
// Return the last error - which will be a 503 if it's an out of capacity error.
56-
return "", err
69+
70+
return nil, err
71+
}
72+
73+
func isOutOfCapacityError(err error) bool {
74+
var e *types.Error
75+
if errors.As(err, &e) && e.Code == 503 {
76+
return true
77+
}
78+
return false
79+
}
80+
81+
func gpuTypesFromConfig() []string {
82+
var gpuTypeIDs []string
83+
provider := config.Config.Project.DefaultProvider
84+
if config.Provider != "" {
85+
provider = config.Provider
86+
}
87+
if p, ok := config.Config.Project.Providers[provider]; ok {
88+
gpuTypeIDs = p.NodeTypes
89+
}
90+
return gpuTypeIDs
91+
}
92+
93+
func renderSessionCreated(exec *types.Exec) {
94+
if exec == nil {
95+
return
96+
}
97+
98+
results := []ui.ResultEntry{
99+
{Key: "Name", Value: exec.Name},
100+
{Key: "ID", Value: exec.ID},
101+
{Key: "Provider", Value: exec.Provider.DisplayName()},
102+
{Key: "Type", Value: exec.NodeTypeID},
103+
{Key: "Region", Value: exec.Region},
104+
{Key: "Status", Value: fmt.Sprintf("%s", exec.Status)},
105+
{Key: "SSHKey", Value: fmt.Sprintf("%s", exec.SSHKey.Name)},
106+
}
107+
108+
ui.ResultTitle("Session Created:")
109+
ui.Result(results, ui.IndentWidth)
110+
return
57111
}

0 commit comments

Comments
 (0)