diff --git a/crockford32/crockford32.go b/crockford32/crockford32.go index d19053a..6d66f19 100644 --- a/crockford32/crockford32.go +++ b/crockford32/crockford32.go @@ -1,15 +1,18 @@ package crockford32 -import "slices" +import ( + "fmt" + "slices" +) -// CrockfordEncode takes a byte array and encodes it into a character +// Encode takes a byte array and encodes it into a character // string according to Crockford's base 32 encoding. Every 5-bits // corresponds to a character. This specific implementation uses Big // Endian byte order to fit the ULID spec ("network byte ordering") // // https://www.crockford.com/base32.html // https://git.wisellama.rocks/Mirrors/ulid-spec -func CrockfordEncode(bytes []byte) string { +func Encode(bytes []byte) string { // Crockford is a base 32 encoding. // 2^5 = 32, so every 5 bits will give us a character. // @@ -97,9 +100,80 @@ func CrockfordEncode(bytes []byte) string { output := make([]rune, 0, len(intList)) for _, i := range intList { lookup := i & 0b11111 - character := crockfordEncodeMap[lookup] + character := encodeMap[lookup] output = append(output, character) } return string(output) } + +// Decode takes in a Crockford base-32 encoded string and parses the +// bytes out of it. Each character corresponds to 5 bits. +// +// The string may contain optional dashes to make it more readable, +// similar to a UUID. +// +// This implementation does not handle the additional check symbols +// mentioned in the spec. It intentionally ignores them. +func Decode(input string) ([]byte, error) { + // Decode each character into an 8-bit uint + intList := make([]uint8, 0, len(input)) + for _, r := range input { + _, ignored := ignoredSymbols[r] + if ignored { + continue + } + + value, ok := decodeMap[r] + if !ok { + return nil, fmt.Errorf("invalid character: %c", r) + } + + intList = append(intList, value) + } + + slices.Reverse(intList) + + // Go right to left across the bits placing each 5-bit chunk into + // the byte array + output := make([]byte, 0, len(input)) + bitsRemaining := uint8(8) + currentByte := byte(0) + for _, value := range intList { + mask := uint8(0b11111) + + // We have all the space we need + if bitsRemaining > 5 { + // Just use what we need + shift := 8 - bitsRemaining + shifted := value << shift + currentByte = currentByte | shifted + bitsRemaining -= 5 + } else { + // Take our remaining bits and them fill in the rest from the next value + oldV := value + + // Shift this value up to fill in the remainder of this byte + oldV = oldV & (mask >> (5 - bitsRemaining)) + shift := 8 - bitsRemaining + oldV = oldV << shift + currentByte = currentByte | oldV + + // Start the next byte with the remainder of the this value + output = append(output, currentByte) + currentByte = byte(0) + + value = value >> bitsRemaining + currentByte = currentByte | value + bitsUsed := 5 - bitsRemaining + bitsRemaining = 8 - bitsUsed + } + } + + if currentByte != 0 { + output = append(output, currentByte) + } + + slices.Reverse(output) + return output, nil +} diff --git a/crockford32/crockford32_maps.go b/crockford32/crockford32_maps.go index 1692df3..645c64f 100644 --- a/crockford32/crockford32_maps.go +++ b/crockford32/crockford32_maps.go @@ -1,9 +1,9 @@ package crockford32 var ( - // crockfordEncodeMap takes binary values and converts them to + // encodeMap takes binary values and converts them to // characters in Crockford base32. - crockfordEncodeMap = map[uint8]rune{ + encodeMap = map[uint8]rune{ 0: '0', 1: '1', 2: '2', @@ -38,9 +38,9 @@ var ( 31: 'Z', } - // crockfordDecodeMap takes characters and converts them to binary + // decodeMap takes characters and converts them to binary // values based on Crockford base32. - crockfordDecodeMap = map[rune]uint8{ + decodeMap = map[rune]uint8{ '0': 0, 'O': 0, 'o': 0, @@ -102,4 +102,13 @@ var ( 'Z': 31, 'z': 31, } + + ignoredSymbols = map[rune]struct{}{ + '-': {}, + '*': {}, + '$': {}, + '=': {}, + 'u': {}, + 'U': {}, + } ) diff --git a/crockford32/crockford32_test.go b/crockford32/crockford32_test.go index f4e3775..5eec015 100644 --- a/crockford32/crockford32_test.go +++ b/crockford32/crockford32_test.go @@ -1,10 +1,11 @@ package crockford32 import ( + "fmt" "testing" ) -func TestCrockfordEncode(t *testing.T) { +func TestEncode(t *testing.T) { runAll := true type testData struct { @@ -58,10 +59,97 @@ func TestCrockfordEncode(t *testing.T) { if !test.RunIt { t.SkipNow() } - output := CrockfordEncode(test.Input) + output := Encode(test.Input) if test.Expected != output { t.Errorf("expected %v, got %v", test.Expected, output) } }) } } + +func TestDecode(t *testing.T) { + runAll := true + + type testData struct { + TestName string + RunIt bool + Input string + Expected []byte + Error bool + } + + tests := []testData{ + { + TestName: "nil", + RunIt: false || runAll, + Input: "", + Expected: []byte{}, + Error: false, + }, + { + TestName: "invalid character", + RunIt: false || runAll, + Input: "ABCDF!@#$%", + Expected: nil, + Error: true, + }, + { + TestName: "1 byte", + RunIt: false || runAll, + Input: "0Z", + Expected: []byte{0b11111}, + Error: false, + }, + { + TestName: "2 bytes", + RunIt: false || runAll, + Input: "00ZZ", + Expected: []byte{0b00000011, 0b11111111}, + Error: false, + }, + { + TestName: "valid ulid", + RunIt: false || runAll, + Input: "01HPS3K5JR06AFVGQT5ZYC0GEK", + Expected: []byte{1, 141, 178, 57, 150, 88, 1, 148, 253, 194, 250, 47, 252, 192, 65, 211}, + Error: false, + }, + { + TestName: "quick brown fox", + RunIt: false || runAll, + Input: "1A6GS90E5TPJRVB41H74VVQDRG6CVVR41N7AVBGECG6YXK5E8G78T3541P62YKS41J6YSSE", + Expected: []byte("The quick brown fox jumps over the lazy dog."), + Error: false, + }, + { + TestName: "max ULID", + RunIt: false || runAll, + Input: "7ZZZZZZZZZZZZZZZZZZZZZZZZZ", + Expected: []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, + Error: false, + }, + } + + for _, test := range tests { + t.Run(test.TestName, func(t *testing.T) { + if !test.RunIt { + t.SkipNow() + } + output, err := Decode(test.Input) + if test.Error && err == nil { + t.Errorf("expected an error") + } else { + msg := fmt.Sprintf("expected %b, got %b", test.Expected, output) + if len(test.Expected) != len(output) { + t.Fatalf(msg) + } + + for i := range test.Expected { + if test.Expected[i] != output[i] { + t.Errorf(msg) + } + } + } + }) + } +} diff --git a/ulid.go b/ulid.go index 59f6522..6db1396 100644 --- a/ulid.go +++ b/ulid.go @@ -21,7 +21,7 @@ func NewULIDString(t time.Time, entropy io.Reader) (string, error) { return "", err } - s := crockford32.CrockfordEncode(bytes) + s := crockford32.Encode(bytes) return s, nil } @@ -45,7 +45,7 @@ func NewULID(t time.Time, entropy io.Reader) ([]byte, error) { return nil, fmt.Errorf("failed to read bytes from entropy source: %w", err) } - msBytes, err := GetMSBytes(t) + msBytes, err := TimeMSBytes(t) if err != nil { return nil, err } @@ -67,12 +67,29 @@ func NewULID(t time.Time, entropy io.Reader) ([]byte, error) { return ulidBytes, nil } -// GetMSBytes returns the given Unix time in milliseconds as a 6-byte +// ParseULID expects a Crockford base-32 encoded string, and it will +// parse out the ULID bytes from the string. +func ParseULID(s string) ([]byte, error) { + b, err := crockford32.Decode(s) + if err != nil { + return nil, fmt.Errorf("error decoding string: %w", err) + } + + // Validate time + _, err = GetTime(b) + if err != nil { + return nil, err + } + + return b, nil +} + +// TimeMSBytes returns the given Unix time in milliseconds as a 6-byte // array. It truncates the 64-bit Unix epoch time down to 48-bits (6 // bytes) and returns that 6 byte array. According to the ULID spec, // 48-bits is enough room that we won't run out of space until 10889 // AD. -func GetMSBytes(t time.Time) ([]byte, error) { +func TimeMSBytes(t time.Time) ([]byte, error) { ms := uint64(t.UnixMilli()) // Put the 64-bit int into a byte array @@ -87,3 +104,27 @@ func GetMSBytes(t time.Time) ([]byte, error) { // output. return bytes[2:], nil } + +// GetTime parses the first 6 bytes (48-bits) of the given ULID bytes +// into a time.Time value. It fails if the time value was too large to +// be properly encoded. +func GetTime(u []byte) (time.Time, error) { + zeroTime := time.Time{} + if len(u) != 16 { + return zeroTime, errors.New("invalid ULID bytes") + } + + // Zero pad to get 8 bytes + timeBytes := []byte{0, 0} + timeBytes = append(timeBytes, u[:6]...) + + epoch := binary.BigEndian.Uint64(timeBytes) + + maxTime := uint64(2<<48) - 1 + + if epoch > maxTime { + return zeroTime, errors.New("time value was too large") + } + + return time.UnixMilli(int64(epoch)), nil +} diff --git a/ulid_test.go b/ulid_test.go index 35eca5a..296c69a 100644 --- a/ulid_test.go +++ b/ulid_test.go @@ -14,7 +14,7 @@ func getTestRandomSource() *rand.Rand { return rand.New(rand.NewSource(0)) } -func TestGetMSBytes(t *testing.T) { +func TestTimeMSBytes(t *testing.T) { runAll := true type testData struct { @@ -22,7 +22,7 @@ func TestGetMSBytes(t *testing.T) { RunIt bool Time time.Time Expected []byte - Err bool + Error bool } tests := []testData{ @@ -31,28 +31,28 @@ func TestGetMSBytes(t *testing.T) { RunIt: false || runAll, Time: time.Date(2024, 02, 16, 14, 02, 15, 17, time.UTC), Expected: []byte{0x01, 0x8D, 0xB2, 0x39, 0x96, 0x58}, - Err: false, + Error: false, }, { TestName: "zero time", RunIt: false || runAll, Time: time.Time{}, // zero time overflows when using Unix epoch Expected: []byte{}, - Err: true, + Error: true, }, { TestName: "max time", RunIt: false || runAll, Time: time.UnixMilli(1 << 48), Expected: []byte{}, - Err: true, + Error: true, }, { TestName: "max time minus 1", RunIt: false || runAll, Time: time.UnixMilli(int64(1<<48) - 1), Expected: []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, - Err: false, + Error: false, }, } @@ -62,12 +62,16 @@ func TestGetMSBytes(t *testing.T) { t.SkipNow() } - output, err := GetMSBytes(test.Time) - if test.Err { + output, err := TimeMSBytes(test.Time) + if test.Error { if err == nil { t.Errorf("expected an error") } } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + errMsg := fmt.Sprintf("expected: %X, received %X", test.Expected, output) if len(test.Expected) != len(output) { t.Fatal(errMsg) @@ -91,7 +95,7 @@ func TestNewULID(t *testing.T) { Time time.Time Entropy io.Reader Expected []byte - Err bool + Error bool } tests := []testData{ @@ -101,7 +105,7 @@ func TestNewULID(t *testing.T) { Time: time.Time{}, Entropy: nil, Expected: nil, - Err: true, + Error: true, }, { TestName: "Unix zero time", @@ -109,7 +113,7 @@ func TestNewULID(t *testing.T) { Time: time.Unix(0, 0), Entropy: getTestRandomSource(), Expected: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x94, 0xFD, 0xC2, 0xFA, 0x2F, 0xFC, 0xC0, 0x41, 0xD3}, - Err: false, + Error: false, }, { TestName: "time overflow", @@ -117,7 +121,7 @@ func TestNewULID(t *testing.T) { Time: time.Time{}, // zero time overflows when using Unix epoch time Entropy: getTestRandomSource(), Expected: nil, - Err: true, + Error: true, }, { TestName: "seed 0, real time", @@ -125,7 +129,7 @@ func TestNewULID(t *testing.T) { Time: time.Date(2024, 02, 16, 14, 02, 15, 17, time.UTC), Entropy: getTestRandomSource(), Expected: []byte{0x01, 0x8D, 0xB2, 0x39, 0x96, 0x58, 0x01, 0x94, 0xFD, 0xC2, 0xFA, 0x2F, 0xFC, 0xC0, 0x41, 0xD3}, - Err: false, + Error: false, }, } @@ -136,11 +140,15 @@ func TestNewULID(t *testing.T) { } output, err := NewULID(test.Time, test.Entropy) - if test.Err { + if test.Error { if err == nil { t.Errorf("expected an error") } } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + errMsg := fmt.Sprintf("expected: %X, received %X", test.Expected, output) if len(test.Expected) != len(output) { t.Fatal(errMsg) @@ -164,7 +172,7 @@ func TestULIDString(t *testing.T) { Time time.Time Entropy io.Reader Expected string - Err bool + Error bool } tests := []testData{ @@ -174,7 +182,7 @@ func TestULIDString(t *testing.T) { Time: time.Time{}, Entropy: nil, Expected: "", - Err: true, + Error: true, }, { TestName: "Unix zero time", @@ -182,7 +190,7 @@ func TestULIDString(t *testing.T) { Time: time.Unix(0, 0), Entropy: getTestRandomSource(), Expected: "000000000006AFVGQT5ZYC0GEK", - Err: false, + Error: false, }, { TestName: "time overflow", @@ -190,7 +198,7 @@ func TestULIDString(t *testing.T) { Time: time.Time{}, // zero time overflows when using Unix epoch time Entropy: getTestRandomSource(), Expected: "", - Err: true, + Error: true, }, { TestName: "seed 0, real time", @@ -198,7 +206,7 @@ func TestULIDString(t *testing.T) { Time: time.Date(2024, 02, 16, 14, 02, 15, 17, time.UTC), Entropy: getTestRandomSource(), Expected: "01HPS3K5JR06AFVGQT5ZYC0GEK", - Err: false, + Error: false, }, } @@ -209,11 +217,15 @@ func TestULIDString(t *testing.T) { } output, err := NewULIDString(test.Time, test.Entropy) - if test.Err { + if test.Error { if err == nil { t.Errorf("expected an error") } } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if test.Expected != output { log.Fatalf("expected: %s, received %s", test.Expected, output) } @@ -221,3 +233,76 @@ func TestULIDString(t *testing.T) { }) } } + +func TestULIDFromString(t *testing.T) { + runAll := true + + type testData struct { + TestName string + RunIt bool + Input string + Expected []byte + Error bool + } + + tests := []testData{ + { + TestName: "empty string", + RunIt: false || runAll, + Input: "", + Expected: nil, + Error: true, + }, + { + TestName: "valid ulid", + RunIt: false || runAll, + Input: "01HPS3K5JR06AFVGQT5ZYC0GEK", + Expected: []byte{1, 141, 178, 57, 150, 88, 1, 148, 253, 194, 250, 47, 252, 192, 65, 211}, + Error: false, + }, + { + TestName: "time overflow", + RunIt: false || runAll, + Input: "ZZZZZZZZZZZZZZZZZZZZZZZZZZ", + Expected: nil, + Error: true, + }, + { + TestName: "max ULID", + RunIt: false || runAll, + Input: "7ZZZZZZZZZZZZZZZZZZZZZZZZZ", + Expected: []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, + Error: false, + }, + } + + for _, test := range tests { + t.Run(test.TestName, func(t *testing.T) { + if !test.RunIt { + t.SkipNow() + } + + output, err := ParseULID(test.Input) + if test.Error { + if err == nil { + t.Errorf("expected an error") + } + } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + msg := fmt.Sprintf("expected %b, got %b", test.Expected, output) + if len(test.Expected) != len(output) { + t.Fatalf(msg) + } + + for i := range test.Expected { + if test.Expected[i] != output[i] { + t.Errorf(msg) + } + } + } + }) + } +}