chore(sec): unify usage of crypto/rand.Read
(#7453)
- Unify the usage of [`crypto/rand.Read`](https://pkg.go.dev/crypto/rand#Read) to `util.CryptoRandomBytes`. - Refactor `util.CryptoRandomBytes` to never return an error. It is documented by Go, https://go.dev/issue/66821, to always succeed. So if we still receive a error or if the returned bytes read is not equal to the expected bytes to be read we panic (just to be on the safe side). - This simplifies a lot of code to no longer care about error handling. Reviewed-on: https://codeberg.org/forgejo/forgejo/pulls/7453 Reviewed-by: Earl Warren <earl-warren@noreply.codeberg.org> Co-authored-by: Gusted <postmaster@gusted.xyz> Co-committed-by: Gusted <postmaster@gusted.xyz>
This commit is contained in:
parent
99fc04b763
commit
53df0bf9a4
25 changed files with 61 additions and 163 deletions
|
@ -5,10 +5,8 @@
|
|||
package generate
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"forgejo.org/modules/util"
|
||||
|
@ -18,18 +16,11 @@ import (
|
|||
|
||||
// NewInternalToken generate a new value intended to be used by INTERNAL_TOKEN.
|
||||
func NewInternalToken() (string, error) {
|
||||
secretBytes := make([]byte, 32)
|
||||
_, err := io.ReadFull(rand.Reader, secretBytes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
secretKey := base64.RawURLEncoding.EncodeToString(secretBytes)
|
||||
secretKey := base64.RawURLEncoding.EncodeToString(util.CryptoRandomBytes(32))
|
||||
|
||||
now := time.Now()
|
||||
|
||||
var internalToken string
|
||||
internalToken, err = jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
internalToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"nbf": now.Unix(),
|
||||
}).SignedString([]byte(secretKey))
|
||||
if err != nil {
|
||||
|
@ -54,14 +45,9 @@ func DecodeJwtSecret(src string) ([]byte, error) {
|
|||
}
|
||||
|
||||
// NewJwtSecret generates a new base64 encoded value intended to be used for JWT secrets.
|
||||
func NewJwtSecret() ([]byte, string, error) {
|
||||
bytes := make([]byte, 32)
|
||||
_, err := rand.Read(bytes)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
return bytes, base64.RawURLEncoding.EncodeToString(bytes), nil
|
||||
func NewJwtSecret() ([]byte, string) {
|
||||
bytes := util.CryptoRandomBytes(32)
|
||||
return bytes, base64.RawURLEncoding.EncodeToString(bytes)
|
||||
}
|
||||
|
||||
// NewSecretKey generate a new value intended to be used by SECRET_KEY.
|
||||
|
|
|
@ -26,8 +26,7 @@ func TestDecodeJwtSecret(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestNewJwtSecret(t *testing.T) {
|
||||
secret, encoded, err := NewJwtSecret()
|
||||
require.NoError(t, err)
|
||||
secret, encoded := NewJwtSecret()
|
||||
assert.Len(t, secret, 32)
|
||||
decoded, err := DecodeJwtSecret(encoded)
|
||||
require.NoError(t, err)
|
||||
|
|
|
@ -17,10 +17,11 @@ package keying
|
|||
|
||||
import (
|
||||
"crypto/hkdf"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
|
||||
"forgejo.org/modules/util"
|
||||
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
|
@ -95,10 +96,7 @@ func (k *Key) Encrypt(plaintext, additionalData []byte) []byte {
|
|||
}
|
||||
|
||||
// Generate a random nonce.
|
||||
nonce := make([]byte, aeadNonceSize)
|
||||
if n, err := rand.Read(nonce); err != nil || n != aeadNonceSize {
|
||||
panic(err)
|
||||
}
|
||||
nonce := util.CryptoRandomBytes(aeadNonceSize)
|
||||
|
||||
// Returns the ciphertext of this plaintext.
|
||||
return e.Seal(nonce, nonce, plaintext, additionalData)
|
||||
|
|
|
@ -80,10 +80,7 @@ func loadLFSFrom(rootCfg ConfigProvider) error {
|
|||
jwtSecretBase64 := loadSecret(rootCfg.Section("server"), "LFS_JWT_SECRET_URI", "LFS_JWT_SECRET")
|
||||
LFS.JWTSecretBytes, err = generate.DecodeJwtSecret(jwtSecretBase64)
|
||||
if err != nil {
|
||||
LFS.JWTSecretBytes, jwtSecretBase64, err = generate.NewJwtSecret()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error generating JWT Secret for custom config: %v", err)
|
||||
}
|
||||
LFS.JWTSecretBytes, jwtSecretBase64 = generate.NewJwtSecret()
|
||||
|
||||
// Save secret
|
||||
saveCfg, err := rootCfg.PrepareSaving()
|
||||
|
|
|
@ -138,10 +138,7 @@ func loadOAuth2From(rootCfg ConfigProvider) {
|
|||
if InstallLock {
|
||||
jwtSecretBytes, err := generate.DecodeJwtSecret(jwtSecretBase64)
|
||||
if err != nil {
|
||||
jwtSecretBytes, jwtSecretBase64, err = generate.NewJwtSecret()
|
||||
if err != nil {
|
||||
log.Fatal("error generating JWT secret: %v", err)
|
||||
}
|
||||
jwtSecretBytes, jwtSecretBase64 = generate.NewJwtSecret()
|
||||
saveCfg, err := rootCfg.PrepareSaving()
|
||||
if err != nil {
|
||||
log.Fatal("save oauth2.JWT_SECRET failed: %v", err)
|
||||
|
@ -161,10 +158,7 @@ var generalSigningSecret atomic.Pointer[[]byte]
|
|||
func GetGeneralTokenSigningSecret() []byte {
|
||||
old := generalSigningSecret.Load()
|
||||
if old == nil || len(*old) == 0 {
|
||||
jwtSecret, _, err := generate.NewJwtSecret()
|
||||
if err != nil {
|
||||
log.Fatal("Unable to generate general JWT secret: %v", err)
|
||||
}
|
||||
jwtSecret, _ := generate.NewJwtSecret()
|
||||
if generalSigningSecret.CompareAndSwap(old, &jwtSecret) {
|
||||
return jwtSecret
|
||||
}
|
||||
|
|
|
@ -88,10 +88,16 @@ func CryptoRandomString(length int64) (string, error) {
|
|||
// CryptoRandomBytes generates `length` crypto bytes
|
||||
// This differs from CryptoRandomString, as each byte in CryptoRandomString is generated by [0,61] range
|
||||
// This function generates totally random bytes, each byte is generated by [0,255] range
|
||||
func CryptoRandomBytes(length int64) ([]byte, error) {
|
||||
func CryptoRandomBytes(length int64) []byte {
|
||||
// crypto/rand.Read is documented to never return a error.
|
||||
// https://go.dev/issue/66821
|
||||
buf := make([]byte, length)
|
||||
_, err := rand.Read(buf)
|
||||
return buf, err
|
||||
n, err := rand.Read(buf)
|
||||
if err != nil || n != int(length) {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
// ToUpperASCII returns s with all ASCII letters mapped to their upper case.
|
||||
|
|
|
@ -163,20 +163,18 @@ func Test_RandomString(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_RandomBytes(t *testing.T) {
|
||||
bytes1, err := util.CryptoRandomBytes(32)
|
||||
require.NoError(t, err)
|
||||
|
||||
bytes2, err := util.CryptoRandomBytes(32)
|
||||
require.NoError(t, err)
|
||||
bytes1 := util.CryptoRandomBytes(32)
|
||||
bytes2 := util.CryptoRandomBytes(32)
|
||||
|
||||
assert.Len(t, bytes1, 32)
|
||||
assert.Len(t, bytes2, 32)
|
||||
assert.NotEqual(t, bytes1, bytes2)
|
||||
|
||||
bytes3, err := util.CryptoRandomBytes(256)
|
||||
require.NoError(t, err)
|
||||
|
||||
bytes4, err := util.CryptoRandomBytes(256)
|
||||
require.NoError(t, err)
|
||||
bytes3 := util.CryptoRandomBytes(256)
|
||||
bytes4 := util.CryptoRandomBytes(256)
|
||||
|
||||
assert.Len(t, bytes3, 256)
|
||||
assert.Len(t, bytes4, 256)
|
||||
assert.NotEqual(t, bytes3, bytes4)
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue