1package nervana
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "fmt"
8 "io"
9 "net/http"
10 "time"
11)
12
// Client sends text to a single HTTP endpoint, authenticating every
// request with a bearer API key.
type Client struct {
	// cli performs the HTTP round trips.
	cli *http.Client
	// endpoint is the URL requests are POSTed to; apiKey is sent as a
	// bearer token in the Authorization header.
	endpoint, apiKey string
}
18
19func NewClient(endpoint string, apiKey string) *Client {
20 return &Client{
21 cli: &http.Client{
22 Timeout: 5 * time.Second,
23 },
24 endpoint: endpoint,
25 apiKey: apiKey,
26 }
27}
28
// NervanaItem is one element of the JSON array that MakeRequest decodes
// from a successful (200) response. The json tags give the camelCase wire
// keys. The meaning of each field is defined by the remote service.
// NOTE(review): confirm the field semantics against the API docs.
type NervanaItem struct {
	Text        string `json:"text"`
	Label       string `json:"label"`
	EntityId    string `json:"entityId"` // NOTE(review): Go style would be EntityID; kept as-is because renaming an exported field breaks callers.
	Description string `json:"description"`
}
35
36func (c *Client) newRequest(ctx context.Context, text string) (*http.Request, error) {
37 payload := map[string]string{
38 "text": text,
39 "language": "en",
40 }
41
42 b, err := json.Marshal(payload)
43 if err != nil {
44 return nil, err
45 }
46
47 req, err := http.NewRequestWithContext(ctx, "POST", c.endpoint, bytes.NewReader(b))
48
49 req.Header.Set("Authorization", "Bearer "+c.apiKey)
50
51 return req, err
52}
53
54func (c *Client) MakeRequest(ctx context.Context, text string) ([]NervanaItem, error) {
55 req, err := c.newRequest(ctx, text)
56 if err != nil {
57 return nil, err
58 }
59
60 resp, err := c.cli.Do(req)
61 if err != nil {
62 return nil, err
63 }
64 defer resp.Body.Close()
65
66 if resp.StatusCode != 200 {
67 io.Copy(io.Discard, resp.Body)
68 return nil, fmt.Errorf("received non-200 response code: %d", resp.StatusCode)
69 }
70
71 var nervanaResp []NervanaItem
72 if err := json.NewDecoder(resp.Body).Decode(&nervanaResp); err != nil {
73 return nil, err
74 }
75
76 return nervanaResp, nil
77}