Skip to content

Commit 84f9c7e

Browse files
authored
feat!: support multiple generative models (#29)
1 parent de16134 commit 84f9c7e

8 files changed

Lines changed: 148 additions & 24 deletions

File tree

README.md

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,12 @@ If you don't already have one, create a key in [Google AI Studio](https://makers
3737
The system chat message must begin with an exclamation mark and is used for internal operations.
3838
A short list of supported system commands:
3939

40-
| Command | Description
41-
| --- | ---
42-
| !q | Quit the application
43-
| !p | Delete the history used as chat context by the model
44-
| !m | Toggle input mode (single-line <-> multi-line)
40+
| Command | Description |
41+
|---------|------------------------------------------------------|
42+
| !q | Quit the application |
43+
| !p | Delete the history used as chat context by the model |
44+
| !i | Toggle input mode (single-line <-> multi-line) |
45+
| !m | Select generative model |
4546

4647
### CLI help
4748
```console
@@ -54,7 +55,8 @@ Usage:
5455
Flags:
5556
-f, --format render markdown-formatted response (default true)
5657
-h, --help help for this command
57-
-m, --multiline read input as a multi-line string
58+
-m, --model string generative model name (default "gemini-pro")
59+
--multiline read input as a multi-line string
5860
-s, --style string markdown format style (ascii, dark, light, pink, notty, dracula) (default "auto")
5961
-t, --term string multi-line input terminator (default "$")
6062
-v, --version version for this command

cmd/gemini/main.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,16 @@ func run() int {
2222
}
2323

2424
var opts cli.ChatOpts
25+
rootCmd.Flags().StringVarP(&opts.Model, "model", "m", gemini.DefaultModel, "generative model name")
2526
rootCmd.Flags().BoolVarP(&opts.Format, "format", "f", true, "render markdown-formatted response")
2627
rootCmd.Flags().StringVarP(&opts.Style, "style", "s", "auto",
2728
"markdown format style (ascii, dark, light, pink, notty, dracula)")
28-
rootCmd.Flags().BoolVarP(&opts.Multiline, "multiline", "m", false, "read input as a multi-line string")
29+
rootCmd.Flags().BoolVar(&opts.Multiline, "multiline", false, "read input as a multi-line string")
2930
rootCmd.Flags().StringVarP(&opts.Terminator, "term", "t", "$", "multi-line input terminator")
3031

3132
rootCmd.RunE = func(_ *cobra.Command, _ []string) error {
3233
apiKey := os.Getenv(apiKeyEnv)
33-
chatSession, err := gemini.NewChatSession(context.Background(), apiKey)
34+
chatSession, err := gemini.NewChatSession(context.Background(), opts.Model, apiKey)
3435
if err != nil {
3536
return err
3637
}

gemini/chat_session.go

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,36 @@ package gemini
22

33
import (
44
"context"
5+
"sync"
56

67
"github.com/google/generative-ai-go/genai"
78
"google.golang.org/api/option"
89
)
910

10-
// ChatSession represents a gemini-pro powered chat session.
11+
const DefaultModel = "gemini-pro"
12+
13+
// ChatSession represents a gemini powered chat session.
1114
type ChatSession struct {
12-
ctx context.Context
15+
ctx context.Context
16+
1317
client *genai.Client
1418
session *genai.ChatSession
19+
20+
loadModels sync.Once
21+
models []string
1522
}
1623

17-
// NewChatSession returns a new ChatSession.
18-
func NewChatSession(ctx context.Context, apiKey string) (*ChatSession, error) {
24+
// NewChatSession returns a new [ChatSession].
25+
func NewChatSession(ctx context.Context, model, apiKey string) (*ChatSession, error) {
1926
client, err := genai.NewClient(ctx, option.WithAPIKey(apiKey))
2027
if err != nil {
2128
return nil, err
2229
}
30+
2331
return &ChatSession{
2432
ctx: ctx,
2533
client: client,
26-
session: client.GenerativeModel("gemini-pro").StartChat(),
34+
session: client.GenerativeModel(model).StartChat(),
2735
}, nil
2836
}
2937

@@ -37,12 +45,36 @@ func (c *ChatSession) SendMessageStream(input string) *genai.GenerateContentResp
3745
return c.session.SendMessageStream(c.ctx, genai.Text(input))
3846
}
3947

48+
// SetGenerativeModel sets the name of the generative model for the chat.
49+
// It preserves the history from the previous chat session.
50+
func (c *ChatSession) SetGenerativeModel(model string) {
51+
history := c.session.History
52+
c.session = c.client.GenerativeModel(model).StartChat()
53+
c.session.History = history
54+
}
55+
56+
// ListModels returns a list of the supported generative model names.
57+
func (c *ChatSession) ListModels() []string {
58+
c.loadModels.Do(func() {
59+
c.models = []string{DefaultModel}
60+
iter := c.client.ListModels(c.ctx)
61+
for {
62+
modelInfo, err := iter.Next()
63+
if err != nil {
64+
break
65+
}
66+
c.models = append(c.models, modelInfo.Name)
67+
}
68+
})
69+
return c.models
70+
}
71+
4072
// ClearHistory clears chat history.
4173
func (c *ChatSession) ClearHistory() {
4274
c.session.History = make([]*genai.Content, 0)
4375
}
4476

45-
// Close closes the genai.Client.
77+
// Close closes the chat session.
4678
func (c *ChatSession) Close() error {
4779
return c.client.Close()
4880
}

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ require (
66
github.com/charmbracelet/glamour v0.7.0
77
github.com/chzyer/readline v1.5.1
88
github.com/google/generative-ai-go v0.18.0
9+
github.com/manifoldco/promptui v0.9.0
910
github.com/muesli/termenv v0.15.2
1011
github.com/spf13/cobra v1.8.1
1112
google.golang.org/api v0.196.0

go.sum

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,13 @@ github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd3
2525
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
2626
github.com/charmbracelet/glamour v0.7.0 h1:2BtKGZ4iVJCDfMF229EzbeR1QRKLWztO9dMtjmqZSng=
2727
github.com/charmbracelet/glamour v0.7.0/go.mod h1:jUMh5MeihljJPQbJ/wf4ldw2+yBP59+ctV36jASy7ps=
28+
github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI=
2829
github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM=
2930
github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ=
31+
github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI=
3032
github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI=
3133
github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk=
34+
github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU=
3235
github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04=
3336
github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
3437
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
@@ -93,6 +96,8 @@ github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2
9396
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
9497
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
9598
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
99+
github.com/manifoldco/promptui v0.9.0 h1:3V4HzJk1TtXW1MTZMP7mdlwbBpIinw3HztaIlYthEiA=
100+
github.com/manifoldco/promptui v0.9.0/go.mod h1:ka04sppxSGFAtxX0qhlYQjISsg9mR4GWtQEhdbn6Pgg=
96101
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
97102
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
98103
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
@@ -169,6 +174,7 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
169174
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
170175
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
171176
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
177+
golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
172178
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
173179
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
174180
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=

internal/cli/chat.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111

1212
// ChatOpts represents Chat configuration options.
1313
type ChatOpts struct {
14+
Model string
1415
Format bool
1516
Style string
1617
Multiline bool

internal/cli/command.go

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ const (
1717
systemCmdPrefix = "!"
1818
systemCmdQuit = "!q"
1919
systemCmdPurgeHistory = "!p"
20-
systemCmdToggleInputMode = "!m"
20+
systemCmdSelectInputMode = "!i"
21+
systemCmdSelectModel = "!m"
2122
)
2223

2324
type command interface {
@@ -44,18 +45,39 @@ func (c *systemCommand) run(message string) bool {
4445
case systemCmdPurgeHistory:
4546
c.chat.model.ClearHistory()
4647
c.print("Cleared the chat history.")
47-
case systemCmdToggleInputMode:
48+
case systemCmdSelectInputMode:
49+
multiline, err := selectInputMode(c.chat.opts.Multiline)
50+
if err != nil {
51+
c.error(err)
52+
break
53+
}
54+
if multiline == c.chat.opts.Multiline {
55+
c.printSelectedCurrent()
56+
break
57+
}
58+
c.chat.opts.Multiline = multiline
4859
if c.chat.opts.Multiline {
49-
c.print("Switched to single-line input mode.")
50-
c.chat.reader.HistoryEnable()
51-
c.chat.opts.Multiline = false
52-
} else {
5360
c.print("Switched to multi-line input mode.")
5461
// disable history for multi-line messages since it is
5562
// unusable for future requests
5663
c.chat.reader.HistoryDisable()
57-
c.chat.opts.Multiline = true
64+
} else {
65+
c.print("Switched to single-line input mode.")
66+
c.chat.reader.HistoryEnable()
67+
}
68+
case systemCmdSelectModel:
69+
model, err := selectModel(c.chat.opts.Model, c.chat.model.ListModels())
70+
if err != nil {
71+
c.error(err)
72+
break
73+
}
74+
if model == c.chat.opts.Model {
75+
c.printSelectedCurrent()
76+
break
5877
}
78+
c.chat.opts.Model = model
79+
c.chat.model.SetGenerativeModel(model)
80+
c.print(fmt.Sprintf("Selected '%s' generative model.", model))
5981
default:
6082
c.print("Unknown system command.")
6183
}
@@ -66,6 +88,14 @@ func (c *systemCommand) print(message string) {
6688
fmt.Printf("%s%s\n", c.chat.prompt.cli, message)
6789
}
6890

91+
func (c *systemCommand) printSelectedCurrent() {
92+
fmt.Printf("%sThe selection is unchanged.\n", c.chat.prompt.cli)
93+
}
94+
95+
func (c *systemCommand) error(err error) {
96+
fmt.Printf(color.Red("%s%s\n"), c.chat.prompt.cli, err)
97+
}
98+
6999
type geminiCommand struct {
70100
chat *Chat
71101
spinner *spinner
@@ -104,7 +134,7 @@ func (c *geminiCommand) runBlocking(message string) {
104134
var buf strings.Builder
105135
for _, candidate := range response.Candidates {
106136
for _, part := range candidate.Content.Parts {
107-
buf.WriteString(fmt.Sprintf("%s", part))
137+
fmt.Fprintf(&buf, "%s", part)
108138
}
109139
}
110140
output, err := glamour.Render(buf.String(), c.chat.opts.Style)
@@ -138,6 +168,6 @@ func (c *geminiCommand) runStreaming(message string) {
138168
}
139169

140170
func (c *geminiCommand) printFlush(message string) {
141-
fmt.Fprintf(c.writer, "%s", message)
142-
c.writer.Flush()
171+
_, _ = c.writer.WriteString(message)
172+
_ = c.writer.Flush()
143173
}

internal/cli/select.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package cli
2+
3+
import (
4+
"slices"
5+
6+
"github.com/manifoldco/promptui"
7+
)
8+
9+
var (
10+
inputMode = []string{"single-line", "multi-line"}
11+
)
12+
13+
// selectModel returns the selected generative model name.
14+
func selectModel(current string, models []string) (string, error) {
15+
prompt := promptui.Select{
16+
Label: "Select generative model",
17+
HideSelected: true,
18+
Items: models,
19+
CursorPos: slices.Index(models, current),
20+
}
21+
22+
_, result, err := prompt.Run()
23+
if err != nil {
24+
return "", err
25+
}
26+
27+
return result, nil
28+
}
29+
30+
// selectInputMode returns true if multiline input is selected;
31+
// otherwise, it returns false.
32+
func selectInputMode(multiline bool) (bool, error) {
33+
var cursorPos int
34+
if multiline {
35+
cursorPos = 1
36+
}
37+
38+
prompt := promptui.Select{
39+
Label: "Select input mode",
40+
HideSelected: true,
41+
Items: inputMode,
42+
CursorPos: cursorPos,
43+
}
44+
45+
_, result, err := prompt.Run()
46+
if err != nil {
47+
return false, err
48+
}
49+
50+
return result == inputMode[1], nil
51+
}

0 commit comments

Comments
 (0)