diff --git a/worker/go.mod b/worker/go.mod index 19146cc..dd46262 100644 --- a/worker/go.mod +++ b/worker/go.mod @@ -18,6 +18,7 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/mocktools/go-smtp-mock/v2 v2.5.4 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/redis/go-redis/v9 v9.18.0 // indirect github.com/robfig/cron/v3 v3.0.1 // indirect diff --git a/worker/go.sum b/worker/go.sum index d656f64..a862024 100644 --- a/worker/go.sum +++ b/worker/go.sum @@ -24,6 +24,8 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mocktools/go-smtp-mock/v2 v2.5.4 h1:U89Y4SuOhDFUfboMYUtXzWDp7hNLrofRa5yNqGSESSM= +github.com/mocktools/go-smtp-mock/v2 v2.5.4/go.mod h1:qBGjYXy5jKKVFhDnB39DYQfn4hWfcqWAlJTcvrku3rg= github.com/openai/openai-go/v3 v3.30.0 h1:T8VkhqAm6BuvxwpVG+Aw+H4TcYIsbj9nqytjpWcE/aU= github.com/openai/openai-go/v3 v3.30.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/worker/internal/services/notify.go b/worker/internal/services/notify.go index 376f4d0..99e108f 100644 --- a/worker/internal/services/notify.go +++ b/worker/internal/services/notify.go @@ -5,11 +5,11 @@ import ( "pkg/i18n" "pkg/models" u "pkg/utils" - "strconv" "strings" "github.com/go-resty/resty/v2" "github.com/samber/lo" + "github.com/spf13/cast" mail "github.com/wneessen/go-mail" "go.uber.org/zap" ) @@ -41,30 +41,40 @@ func SendWebhook(webhook models.NotifyWebhook) error { return nil } -func SendEmail(to string, shareInfo *models.RedisShareInfo, ip string) error { - host := u.GetEnv("smtp.host") - if host == "" { +type EmailTemplateData struct { + Locale string + IP string + ShareType models.ShareType + FileName string +} + +func SendEmail(to string, emailTemplateData EmailTemplateData, options ...mail.Option) error { + smtp := u.GetEnvMap("smtp") + if smtp["host"] == "" { zap.L().Warn("smtp host is empty, skip share notify email", zap.String("to", to)) return nil } - username := u.GetEnv("smtp.username") - password := u.GetEnv("smtp.password") - from := u.GetEnvWithDefault("smtp.from", username) - if from == "" { - return fmt.Errorf("smtp.from or smtp.username is required") + host := cast.ToString(smtp["host"]) + if host == "" { + return fmt.Errorf("smtp.host is required") } + username := cast.ToString(smtp["username"]) + if username == "" { + return fmt.Errorf("smtp.username is required") + } + port := lo.Ternary(cast.ToInt(smtp["port"]) != 0, cast.ToInt(smtp["port"]), mail.DefaultPortSSL) templateData := map[string]any{ - "IP": ip, + "IP": emailTemplateData.IP, "SiteURL": u.GetEnv("site.url"), - "ShareType": i18n.T(shareInfo.Locale, lo.Ternary(shareInfo.Type == models.ShareTypeText, "share_type_text", "share_type_file")), - "FileName": shareInfo.FileName, + "ShareType": i18n.T(emailTemplateData.Locale, lo.Ternary(emailTemplateData.ShareType == models.ShareTypeText, "share_type_text", "share_type_file")), + "FileName": emailTemplateData.FileName, } - subject := i18n.TWithData(shareInfo.Locale, "notify_email_subject", templateData) - body := i18n.TWithData(shareInfo.Locale, "notify_email_body", templateData) + subject := i18n.TWithData(emailTemplateData.Locale, "notify_email_subject", templateData) + body := i18n.TWithData(emailTemplateData.Locale, "notify_email_body", templateData) message := mail.NewMsg() - if err := message.From(from); err != nil { + if err := message.From(username); err != nil { return err } if err := message.To(to); err != nil { @@ -73,21 +83,22 @@ func SendEmail(to string, shareInfo *models.RedisShareInfo, ip string) error { message.Subject(subject) message.SetBodyString(mail.TypeTextPlain, body) - port, err := strconv.Atoi(u.GetEnvWithDefault("smtp.port", "587")) - if err != nil { - return err + options = append([]mail.Option{ + mail.WithPort(port), + mail.WithUsername(username), + mail.WithSMTPAuth(mail.SMTPAuthAutoDiscover), + }, options...) + + password := cast.ToString(smtp["password"]) + if password != "" { + options = append(options, mail.WithPassword(password)) } - options := []mail.Option{ - mail.WithPort(port), - } - if port == mail.DefaultPortSSL { + if cast.ToString(smtp["protocol"]) == "ssl" { options = append(options, mail.WithSSL()) - } else { - options = append(options, mail.WithTLSPortPolicy(mail.TLSMandatory)) } - if username != "" { - options = append(options, mail.WithUsername(username), mail.WithPassword(password), mail.WithSMTPAuth(mail.SMTPAuthAutoDiscover)) + if cast.ToString(smtp["protocol"]) == "tls" { + options = append(options, mail.WithTLSPortPolicy(mail.TLSMandatory)) } client, err := mail.NewClient(host, options...) diff --git a/worker/internal/tasks/share.go b/worker/internal/tasks/share.go index 20c4a5d..9bf0e0f 100644 --- a/worker/internal/tasks/share.go +++ b/worker/internal/tasks/share.go @@ -78,7 +78,12 @@ func ShareNotify(ctx context.Context, task *asynq.Task) error { } for _, email := range shareInfo.NotifyEmails { - if err := services.SendEmail(email, shareInfo, payload.IP); err != nil { + if err := services.SendEmail(email, services.EmailTemplateData{ + Locale: shareInfo.Locale, + ShareType: shareInfo.Type, + FileName: shareInfo.FileName, + IP: payload.IP, + }); err != nil { errs = append(errs, err) continue } diff --git a/worker/test/services/notify_test.go b/worker/test/services/notify_test.go new file mode 100644 index 0000000..68d790a --- /dev/null +++ b/worker/test/services/notify_test.go @@ -0,0 +1,162 @@ +package services + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + + "pkg/i18n" + "pkg/models" + "pkg/utils" + "worker/internal/services" + + smtpmock "github.com/mocktools/go-smtp-mock/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wneessen/go-mail" +) + +func TestMain(m *testing.M) { + utils.InitTestViper(utils.EnvOption{ + ConfigType: []string{"yaml"}, + ConfigData: strings.NewReader(""), + }) + if err := i18n.Init(); err != nil { + panic(err) + } + os.Exit(m.Run()) +} + +// ============= SendWebhook ============= + +func TestSendWebhook_DefaultMethodIsPost(t *testing.T) { + var gotMethod string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotMethod = r.Method + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + err := services.SendWebhook(models.NotifyWebhook{URL: ts.URL}) + require.NoError(t, err) + assert.Equal(t, "POST", gotMethod) +} + +func TestSendWebhook_CustomMethod(t *testing.T) { + var gotMethod string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotMethod = r.Method + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + err := services.SendWebhook(models.NotifyWebhook{URL: ts.URL, Method: " put "}) + require.NoError(t, err) + assert.Equal(t, "PUT", gotMethod) +} + +func TestSendWebhook_CustomHeaders(t *testing.T) { + var gotHeader string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotHeader = r.Header.Get("X-Custom-Token") + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + err := services.SendWebhook(models.NotifyWebhook{ + URL: ts.URL, + Headers: map[string]string{"X-Custom-Token": "secret123"}, + }) + require.NoError(t, err) + assert.Equal(t, "secret123", gotHeader) +} + +func TestSendWebhook_FormDataSetsContentType(t *testing.T) { + var gotContentType string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotContentType = r.Header.Get("Content-Type") + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + err := services.SendWebhook(models.NotifyWebhook{ + URL: ts.URL, + BodyType: "form-data", + Body: "key=value", + }) + require.NoError(t, err) + assert.Equal(t, "application/x-www-form-urlencoded", gotContentType) +} + +func TestSendWebhook_BodyNoneSkipsBody(t *testing.T) { + var gotBodyLen int + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b, _ := io.ReadAll(r.Body) + gotBodyLen = len(b) + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + err := services.SendWebhook(models.NotifyWebhook{ + URL: ts.URL, + BodyType: "none", + Body: "should-not-be-sent", + }) + require.NoError(t, err) + assert.Zero(t, gotBodyLen) +} + +func TestSendWebhook_4xxReturnsError(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + defer ts.Close() + + err := services.SendWebhook(models.NotifyWebhook{URL: ts.URL}) + assert.Error(t, err) +} + +func TestSendWebhook_5xxReturnsError(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer ts.Close() + + err := services.SendWebhook(models.NotifyWebhook{URL: ts.URL}) + assert.Error(t, err) +} + +func TestSendWebhook_InvalidURLReturnsError(t *testing.T) { + err := services.SendWebhook(models.NotifyWebhook{URL: "http://127.0.0.1:1"}) + assert.Error(t, err) +} + +// ============= SendEmail ============= + +func TestSendEmail_HappyPath(t *testing.T) { + utils.SetEnv("smtp.host", "localhost") + utils.SetEnv("smtp.username", "sender@example.com") + t.Cleanup(func() { utils.SetEnv("smtp.host", "") }) + server := smtpmock.New(smtpmock.ConfigurationAttr{ + LogToStdout: true, + LogServerActivity: true, + }) + require.NoError(t, server.Start()) + t.Cleanup(func() { _ = server.Stop() }) + + host, port := "127.0.0.1", server.PortNumber() + utils.SetEnv("smtp.host", host) + utils.SetEnv("smtp.port", fmt.Sprintf("%d", port)) + + err := services.SendEmail("recipient@example.com", services.EmailTemplateData{ + Locale: "en", + ShareType: models.ShareTypeText, + FileName: "report.pdf", + IP: "1.2.3.4", + }, mail.WithHELO("localhost"), mail.WithTLSPolicy(mail.NoTLS), mail.WithSMTPAuth(mail.SMTPAuthNoAuth)) + require.NoError(t, err) +}