add suffix support to generate endpoint

this change is triggered by the presence of "suffix", particularly
useful for code completion tasks
This commit is contained in:
Michael Yang
2024-06-20 19:13:36 -07:00
parent 987dbab0b0
commit d290e87513
6 changed files with 155 additions and 27 deletions

View File

@@ -151,6 +151,8 @@ func (t *Template) Vars() []string {
type Values struct {
Messages []api.Message
Tools []api.Tool
Prompt string
Suffix string
// forceLegacy is a flag used to test compatibility with legacy templates
forceLegacy bool
@@ -204,7 +206,13 @@ func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
func (t *Template) Execute(w io.Writer, v Values) error {
system, messages := collate(v.Messages)
if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
if v.Prompt != "" && v.Suffix != "" {
return t.Template.Execute(w, map[string]any{
"Prompt": v.Prompt,
"Suffix": v.Suffix,
"Response": "",
})
} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
return t.Template.Execute(w, map[string]any{
"System": system,
"Messages": messages,

View File

@@ -359,3 +359,38 @@ Answer: `,
})
}
}
func TestExecuteWithSuffix(t *testing.T) {
tmpl, err := Parse(`{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
{{- else }}{{ .Prompt }}
{{- end }}`)
if err != nil {
t.Fatal(err)
}
cases := []struct {
name string
values Values
expect string
}{
{
"message", Values{Messages: []api.Message{{Role: "user", Content: "hello"}}}, "hello",
},
{
"prompt suffix", Values{Prompt: "def add(", Suffix: "return x"}, "<PRE> def add( <SUF>return x <MID>",
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
var b bytes.Buffer
if err := tmpl.Execute(&b, tt.values); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(b.String(), tt.expect); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}
}