Spaces:
Configuration error
Configuration error
package e2e_test | |
import ( | |
"bytes" | |
"context" | |
"encoding/json" | |
"fmt" | |
"io" | |
"net/http" | |
"os" | |
"github.com/mudler/LocalAI/core/schema" | |
. "github.com/onsi/ginkgo/v2" | |
. "github.com/onsi/gomega" | |
"github.com/sashabaranov/go-openai" | |
"github.com/sashabaranov/go-openai/jsonschema" | |
) | |
var _ = Describe("E2E test", func() { | |
Context("Generating", func() { | |
BeforeEach(func() { | |
// | |
}) | |
// Check that the GPU was used | |
AfterEach(func() { | |
// | |
}) | |
Context("text", func() { | |
It("correctly", func() { | |
model := "gpt-4" | |
resp, err := client.CreateChatCompletion(context.TODO(), | |
openai.ChatCompletionRequest{ | |
Model: model, Messages: []openai.ChatCompletionMessage{ | |
{ | |
Role: "user", | |
Content: "How much is 2+2?", | |
}, | |
}}) | |
Expect(err).ToNot(HaveOccurred()) | |
Expect(len(resp.Choices)).To(Equal(1), fmt.Sprint(resp)) | |
Expect(resp.Choices[0].Message.Content).To(Or(ContainSubstring("4"), ContainSubstring("four")), fmt.Sprint(resp.Choices[0].Message.Content)) | |
}) | |
}) | |
Context("function calls", func() { | |
It("correctly invoke", func() { | |
params := jsonschema.Definition{ | |
Type: jsonschema.Object, | |
Properties: map[string]jsonschema.Definition{ | |
"location": { | |
Type: jsonschema.String, | |
Description: "The city and state, e.g. San Francisco, CA", | |
}, | |
"unit": { | |
Type: jsonschema.String, | |
Enum: []string{"celsius", "fahrenheit"}, | |
}, | |
}, | |
Required: []string{"location"}, | |
} | |
f := openai.FunctionDefinition{ | |
Name: "get_current_weather", | |
Description: "Get the current weather in a given location", | |
Parameters: params, | |
} | |
t := openai.Tool{ | |
Type: openai.ToolTypeFunction, | |
Function: &f, | |
} | |
dialogue := []openai.ChatCompletionMessage{ | |
{Role: openai.ChatMessageRoleUser, Content: "What is the weather in Boston today?"}, | |
} | |
resp, err := client.CreateChatCompletion(context.TODO(), | |
openai.ChatCompletionRequest{ | |
Model: openai.GPT4, | |
Messages: dialogue, | |
Tools: []openai.Tool{t}, | |
}, | |
) | |
Expect(err).ToNot(HaveOccurred()) | |
Expect(len(resp.Choices)).To(Equal(1), fmt.Sprint(resp)) | |
msg := resp.Choices[0].Message | |
Expect(len(msg.ToolCalls)).To(Equal(1), fmt.Sprint(msg.ToolCalls)) | |
Expect(msg.ToolCalls[0].Function.Name).To(Equal("get_current_weather"), fmt.Sprint(msg.ToolCalls[0].Function.Name)) | |
Expect(msg.ToolCalls[0].Function.Arguments).To(ContainSubstring("Boston"), fmt.Sprint(msg.ToolCalls[0].Function.Arguments)) | |
}) | |
}) | |
Context("json", func() { | |
It("correctly", func() { | |
model := "gpt-4" | |
req := openai.ChatCompletionRequest{ | |
ResponseFormat: &openai.ChatCompletionResponseFormat{Type: openai.ChatCompletionResponseFormatTypeJSONObject}, | |
Model: model, | |
Messages: []openai.ChatCompletionMessage{ | |
{ | |
Role: "user", | |
Content: "An animal with 'name', 'gender' and 'legs' fields", | |
}, | |
}, | |
} | |
resp, err := client.CreateChatCompletion(context.TODO(), req) | |
Expect(err).ToNot(HaveOccurred()) | |
Expect(len(resp.Choices)).To(Equal(1), fmt.Sprint(resp)) | |
var i map[string]interface{} | |
err = json.Unmarshal([]byte(resp.Choices[0].Message.Content), &i) | |
Expect(err).ToNot(HaveOccurred()) | |
Expect(i).To(HaveKey("name")) | |
Expect(i).To(HaveKey("gender")) | |
Expect(i).To(HaveKey("legs")) | |
}) | |
}) | |
Context("images", func() { | |
It("correctly", func() { | |
resp, err := client.CreateImage(context.TODO(), | |
openai.ImageRequest{ | |
Prompt: "test", | |
Size: openai.CreateImageSize512x512, | |
}, | |
) | |
Expect(err).ToNot(HaveOccurred()) | |
Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp)) | |
Expect(resp.Data[0].URL).To(ContainSubstring("png"), fmt.Sprint(resp.Data[0].URL)) | |
}) | |
It("correctly changes the response format to url", func() { | |
resp, err := client.CreateImage(context.TODO(), | |
openai.ImageRequest{ | |
Prompt: "test", | |
Size: openai.CreateImageSize512x512, | |
ResponseFormat: openai.CreateImageResponseFormatURL, | |
}, | |
) | |
Expect(err).ToNot(HaveOccurred()) | |
Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp)) | |
Expect(resp.Data[0].URL).To(ContainSubstring("png"), fmt.Sprint(resp.Data[0].URL)) | |
}) | |
It("correctly changes the response format to base64", func() { | |
resp, err := client.CreateImage(context.TODO(), | |
openai.ImageRequest{ | |
Prompt: "test", | |
Size: openai.CreateImageSize512x512, | |
ResponseFormat: openai.CreateImageResponseFormatB64JSON, | |
}, | |
) | |
Expect(err).ToNot(HaveOccurred()) | |
Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp)) | |
Expect(resp.Data[0].B64JSON).ToNot(BeEmpty(), fmt.Sprint(resp.Data[0].B64JSON)) | |
}) | |
}) | |
Context("embeddings", func() { | |
It("correctly", func() { | |
resp, err := client.CreateEmbeddings(context.TODO(), | |
openai.EmbeddingRequestStrings{ | |
Input: []string{"doc"}, | |
Model: openai.AdaEmbeddingV2, | |
}, | |
) | |
Expect(err).ToNot(HaveOccurred()) | |
Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp)) | |
Expect(resp.Data[0].Embedding).ToNot(BeEmpty()) | |
}) | |
}) | |
Context("vision", func() { | |
It("correctly", func() { | |
model := "gpt-4o" | |
resp, err := client.CreateChatCompletion(context.TODO(), | |
openai.ChatCompletionRequest{ | |
Model: model, Messages: []openai.ChatCompletionMessage{ | |
{ | |
Role: "user", | |
MultiContent: []openai.ChatMessagePart{ | |
{ | |
Type: openai.ChatMessagePartTypeText, | |
Text: "What is in the image?", | |
}, | |
{ | |
Type: openai.ChatMessagePartTypeImageURL, | |
ImageURL: &openai.ChatMessageImageURL{ | |
URL: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", | |
Detail: openai.ImageURLDetailLow, | |
}, | |
}, | |
}, | |
}, | |
}}) | |
Expect(err).ToNot(HaveOccurred()) | |
Expect(len(resp.Choices)).To(Equal(1), fmt.Sprint(resp)) | |
Expect(resp.Choices[0].Message.Content).To(Or(ContainSubstring("wooden"), ContainSubstring("grass")), fmt.Sprint(resp.Choices[0].Message.Content)) | |
}) | |
}) | |
Context("text to audio", func() { | |
It("correctly", func() { | |
res, err := client.CreateSpeech(context.Background(), openai.CreateSpeechRequest{ | |
Model: openai.TTSModel1, | |
Input: "Hello!", | |
Voice: openai.VoiceAlloy, | |
}) | |
Expect(err).ToNot(HaveOccurred()) | |
defer res.Close() | |
_, err = io.ReadAll(res) | |
Expect(err).ToNot(HaveOccurred()) | |
}) | |
}) | |
Context("audio to text", func() { | |
It("correctly", func() { | |
downloadURL := "https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav" | |
file, err := downloadHttpFile(downloadURL) | |
Expect(err).ToNot(HaveOccurred()) | |
req := openai.AudioRequest{ | |
Model: openai.Whisper1, | |
FilePath: file, | |
} | |
resp, err := client.CreateTranscription(context.Background(), req) | |
Expect(err).ToNot(HaveOccurred()) | |
Expect(resp.Text).To(ContainSubstring("This is the"), fmt.Sprint(resp.Text)) | |
}) | |
}) | |
Context("reranker", func() { | |
It("correctly", func() { | |
modelName := "jina-reranker-v1-base-en" | |
req := schema.JINARerankRequest{ | |
Model: modelName, | |
Query: "Organic skincare products for sensitive skin", | |
Documents: []string{ | |
"Eco-friendly kitchenware for modern homes", | |
"Biodegradable cleaning supplies for eco-conscious consumers", | |
"Organic cotton baby clothes for sensitive skin", | |
"Natural organic skincare range for sensitive skin", | |
"Tech gadgets for smart homes: 2024 edition", | |
"Sustainable gardening tools and compost solutions", | |
"Sensitive skin-friendly facial cleansers and toners", | |
"Organic food wraps and storage solutions", | |
"All-natural pet food for dogs with allergies", | |
"Yoga mats made from recycled materials", | |
}, | |
TopN: 3, | |
} | |
serialized, err := json.Marshal(req) | |
Expect(err).To(BeNil()) | |
Expect(serialized).ToNot(BeNil()) | |
rerankerEndpoint := apiEndpoint + "/rerank" | |
resp, err := http.Post(rerankerEndpoint, "application/json", bytes.NewReader(serialized)) | |
Expect(err).To(BeNil()) | |
Expect(resp).ToNot(BeNil()) | |
body, err := io.ReadAll(resp.Body) | |
Expect(err).ToNot(HaveOccurred()) | |
Expect(resp.StatusCode).To(Equal(200), fmt.Sprintf("body: %s, response: %+v", body, resp)) | |
deserializedResponse := schema.JINARerankResponse{} | |
err = json.Unmarshal(body, &deserializedResponse) | |
Expect(err).To(BeNil()) | |
Expect(deserializedResponse).ToNot(BeZero()) | |
Expect(deserializedResponse.Model).To(Equal(modelName)) | |
Expect(len(deserializedResponse.Results)).To(BeNumerically(">", 0)) | |
}) | |
}) | |
}) | |
}) | |
// downloadHttpFile fetches url with an HTTP GET and stores the response body
// in a newly created temporary file, returning the path of that file.
//
// An error is returned when the request fails, when the server answers with a
// non-2xx status, or when the body cannot be written to disk completely. In
// every error case the temporary file is removed and the returned path is
// empty. On success the caller owns the file and is responsible for deleting
// it.
func downloadHttpFile(url string) (string, error) {
	resp, err := http.Get(url)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	// Without this check an error page (404, 500, ...) would be saved and
	// handed back as if it were the requested file.
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return "", fmt.Errorf("downloading %q: unexpected status %s", url, resp.Status)
	}

	tmpfile, err := os.CreateTemp("", "example")
	if err != nil {
		return "", err
	}

	if _, err := io.Copy(tmpfile, resp.Body); err != nil {
		// Cleanup is best effort; the copy error is the one worth reporting.
		_ = tmpfile.Close()
		_ = os.Remove(tmpfile.Name())
		return "", fmt.Errorf("downloading %q: %w", url, err)
	}

	// Close explicitly instead of deferring: a failed close can mean the data
	// never fully reached the disk, so its error has to be surfaced.
	if err := tmpfile.Close(); err != nil {
		_ = os.Remove(tmpfile.Name())
		return "", fmt.Errorf("downloading %q: %w", url, err)
	}

	return tmpfile.Name(), nil
}