193 lines
4.7 KiB
Go
193 lines
4.7 KiB
Go
package auth
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// User holds the subject identifier ("sub" claim) of an authenticated user.
// Richer user types embed it so the subject is always decoded.
type User struct {
	Sub string `json:"sub"` // subject identifier claim
}
|
|
|
|
type AuthentikUser struct {
|
|
User
|
|
Email string `json:"email"`
|
|
EmailVerified bool `json:"email_verified"`
|
|
Name string `json:"name"`
|
|
GivenName string `json:"given_name"`
|
|
PreferredUsername string `json:"preferred_username"`
|
|
Nickname string `json:"nickname"`
|
|
Groups []string `json:"groups"`
|
|
}
|
|
|
|
type Auth struct {
|
|
authConfig AuthConfig
|
|
clientConfig ClientConfig
|
|
}
|
|
|
|
// Token holds the result of a token endpoint call. The tagged fields are
// decoded from the JSON response body. CreatedAt has no JSON tag; it is set
// locally (GetTokenFromCode stamps it with time.Now before decoding),
// presumably so callers can derive the expiry from ExpiredIn.
type Token struct {
	AccessToken  string `json:"access_token"`
	TokenType    string `json:"token_type"`
	ExpiredIn    int    `json:"expires_in"` // lifetime; seconds per RFC 6749 §5.1
	RefreshToken string `json:"refresh_token"`
	CreatedAt    time.Time
}
|
|
|
|
func NewAuthWithConfig(config ClientConfig, authConfig AuthConfig) (Auth, error) {
|
|
a := Auth{}
|
|
a.authConfig = authConfig
|
|
a.clientConfig = config
|
|
return a, nil
|
|
}
|
|
|
|
func NewAuthWithConfigurationURL(config ClientConfig, url string) (Auth, error) {
|
|
a := Auth{}
|
|
a.clientConfig = config
|
|
authConfig := AuthConfig{}
|
|
|
|
res, err := http.Get(url)
|
|
if err != nil {
|
|
return Auth{}, fmt.Errorf("%w: %q", ErrCantGetConfiguratorData, err)
|
|
}
|
|
defer res.Body.Close()
|
|
|
|
bodyContent, err := io.ReadAll(res.Body)
|
|
if err != nil {
|
|
return Auth{}, fmt.Errorf("%w: %q", ErrCantGetConfiguratorData, err)
|
|
}
|
|
|
|
err = json.Unmarshal(bodyContent, &authConfig)
|
|
if err != nil {
|
|
return Auth{}, fmt.Errorf("%w: %q", ErrCantGetConfiguratorData, err)
|
|
}
|
|
|
|
a.authConfig = authConfig
|
|
return a, nil
|
|
}
|
|
|
|
func (a Auth) GetAuthorizationURL(state string) (string, error) {
|
|
if a.authConfig.AuthorizationEndpoint == "" {
|
|
return "", fmt.Errorf("%w: %s", ErrCantGetAuthorizationURL, "AuthorizationEndpoint in config is empty")
|
|
}
|
|
|
|
if a.clientConfig.ClientID == "" {
|
|
return "", fmt.Errorf("%w: %s", ErrCantGetAuthorizationURL, "clientid in config is empty")
|
|
}
|
|
|
|
url, err := url.Parse(a.authConfig.AuthorizationEndpoint)
|
|
if err != nil {
|
|
return "", fmt.Errorf("%w: %q", ErrCantGetAuthorizationURL, err)
|
|
}
|
|
|
|
values := url.Query()
|
|
|
|
values.Set("client_id", a.clientConfig.ClientID)
|
|
if a.clientConfig.RedirectURL != "" {
|
|
values.Set("redirect_uri", a.clientConfig.RedirectURL)
|
|
}
|
|
|
|
if len(a.clientConfig.Scope) > 0 {
|
|
values.Set("scope", strings.Join(a.clientConfig.Scope, "+"))
|
|
}
|
|
|
|
if state != "" {
|
|
values.Set("state", state)
|
|
}
|
|
|
|
values.Set("response_type", "code")
|
|
|
|
url.RawQuery = values.Encode()
|
|
|
|
return url.String(), nil
|
|
}
|
|
|
|
func (a Auth) GetUserInfo(accessToken string, user any) error {
|
|
req, err := http.NewRequest("GET", a.authConfig.UserinfoEndpoint, nil)
|
|
|
|
req.Header.Add("Authorization", "Bearer "+accessToken)
|
|
|
|
hc := http.Client{}
|
|
resp, err := hc.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("%w: %q", ErrCreateRequestForUserInfo, err)
|
|
}
|
|
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return fmt.Errorf("%w: %s (%v)", ErrCantGetUserInfo, "server response with nuon 200 status code", resp.StatusCode)
|
|
}
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return fmt.Errorf("%w: %q", ErrCreateRequestForUserInfo, err)
|
|
}
|
|
|
|
err = json.Unmarshal(body, &user)
|
|
if err != nil {
|
|
return fmt.Errorf("%w: %s: %q", ErrCantGetUserInfo, "json unmarshal", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a Auth) GetTokenFromCode(code string) (Token, error) {
|
|
form := url.Values{}
|
|
form.Add("grant_type", "authorization_code")
|
|
form.Add("code", code)
|
|
|
|
req, err := http.NewRequest("POST", a.authConfig.TokenEndpoint, strings.NewReader(form.Encode()))
|
|
if err != nil {
|
|
return Token{}, fmt.Errorf("%w: %q", ErrCantCreateTokenRequests, err)
|
|
}
|
|
|
|
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
|
req.SetBasicAuth(a.clientConfig.ClientID, a.clientConfig.ClientSecret)
|
|
|
|
hc := http.Client{}
|
|
|
|
resp, err := hc.Do(req)
|
|
if err != nil {
|
|
return Token{}, fmt.Errorf("%w: %q", ErrCantSendRequestsForToken, err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return Token{}, fmt.Errorf("%w: %q", ErrCantSendRequestsForToken, err)
|
|
}
|
|
fmt.Println(string(body))
|
|
fmt.Println(resp.StatusCode)
|
|
|
|
if resp.StatusCode != 200 {
|
|
var er struct {
|
|
Error string `json:"error"`
|
|
ErrorDescription string `json:"error_description"`
|
|
}
|
|
|
|
err = json.Unmarshal(body, &er)
|
|
if err != nil {
|
|
return Token{}, fmt.Errorf("%w: %s", ErrWrongResponseFromServer, string(body))
|
|
}
|
|
if er.ErrorDescription != "" {
|
|
return Token{}, fmt.Errorf("%w: %s", ErrWrongResponseFromServer, er.ErrorDescription)
|
|
}
|
|
|
|
return Token{}, fmt.Errorf("%w: %s", ErrWrongResponseFromServer, string(body))
|
|
}
|
|
|
|
t := Token{}
|
|
t.CreatedAt = time.Now()
|
|
|
|
err = json.Unmarshal(body, &t)
|
|
if err != nil {
|
|
return Token{}, fmt.Errorf("%w: %q", ErrCantGetTokenForCode, err)
|
|
}
|
|
|
|
return t, nil
|
|
}
|